In [19]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import os
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import torch.nn.functional as F
import random
import cv2
from transformers import AutoImageProcessor,ConvNextModel
from tqdm import tqdm 
from torch.optim.lr_scheduler import ReduceLROnPlateau
from itertools import cycle
import math
from tqdm import tqdm 
from functools import lru_cache

In [20]:
import warnings
warnings.filterwarnings("ignore")  # NOTE: silences every warning, including pandas' SettingWithCopyWarning

# Fixed dataset metadata.
# Short category code -> category name as it appears in the CSV `Category` column.
category_dict = {
    'c1': 'Men Tshirts',
    'c2': 'Sarees',
    'c3': 'Kurtis',
    'c4': 'Women Tshirts',
    'c5': 'Women Tops & Tunics',
}
# Short category code -> one entry per attribute column (attr_1, attr_2, ...).
# The list length is used as the number of attributes for the category;
# each entry is presumably the number of classes of that attribute.
semi_classes_dict = {
    'c1': [4, 2, 2, 3, 2],
    'c2': [4, 6, 3, 8, 4, 3, 4, 5, 9, 2],
    'c3': [13, 2, 2, 2, 2, 2, 2, 3, 2],
    'c4': [7, 3, 3, 3, 6, 3, 2, 2],
    'c5': [12, 4, 2, 7, 2, 3, 6, 4, 4, 6],
}

In [21]:
# Environment paths (Kaggle layout) and the category being trained.
input_path = "/kaggle/input/meesho/visual-taxonomy"
working_path = "/kaggle/working"
test_c_name = "c4"  # key into category_dict / semi_classes_dict

# Checkpoint path in the working directory, keyed by category code
# (presumably the output of the first, joint training stage).
BEST_MODEL_FROM_BASE_FIRST_TRAINING = f'{working_path}/best_model_{test_c_name}.pth'

# Training hyperparameters.
NUM_EPOCH = 4          # presumably epochs for the joint multi-attribute stage
NUM_ATTR_EPOCHS = 1    # presumably epochs per attribute head
LEARNING_RATE = 0.001  # NOTE(review): train_model's AdamW is created with lr=1e-5; confirm where this is used
NUM_OF_SEMI_CLASSES_OF_COLUMNS = semi_classes_dict[test_c_name]  # per-attribute entries for this category

In [22]:
# Load the competition CSVs from the input directory.
train_csv_path = f'{input_path}/train.csv'
test_csv_path = f'{input_path}/test.csv'
df_train = pd.read_csv(train_csv_path)
df_test = pd.read_csv(test_csv_path)

In [23]:
# Sanity check: distinct category names in the training CSV (compare with the values of category_dict).
df_train['Category'].unique()

array(['Men Tshirts', 'Sarees', 'Kurtis', 'Women Tshirts',
       'Women Tops & Tunics'], dtype=object)

In [24]:
# max_nans (default 3) is a hyperparameter chosen from the EDA: rows with more missing attributes than this are left unimputed. It can be raised above 3.
def fill_nan_with_prioritized_similarity(df, num_attrs, max_nans=3):
    """Impute missing ``attr_*`` values from rows that share the later attributes.

    For each row with at most ``max_nans`` missing values among
    ``attr_1 .. attr_{num_attrs}``, every missing ``attr_k`` is replaced by the
    most frequent value of ``attr_k`` among all rows of ``df`` that agree with
    this row on each of its non-missing attributes after ``attr_k``. If that
    subset has no non-missing ``attr_k`` value, the column-wide mode is used;
    if the whole column is missing, the cell stays NaN. Ties resolve to the
    first entry of ``Series.mode()`` (modes are returned sorted).

    Args:
        df: DataFrame containing columns ``attr_1`` .. ``attr_{num_attrs}``.
        num_attrs: number of attribute columns to consider.
        max_nans: rows with more missing attributes than this are returned
            unchanged.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    attr_columns = [f'attr_{i}' for i in range(1, num_attrs + 1)]
    # Lazily filled cache: (attr, conditions) -> mode or None. Many rows share a
    # condition tuple, so each subset is filtered only once. The empty condition
    # tuple caches the column-wide mode used as the fallback.
    mode_cache = {}

    def get_mode_for_subset(attr, conditions):
        """Return the mode of ``attr`` over rows matching ``conditions``, or None."""
        key = (attr, tuple(conditions.items()))
        if key in mode_cache:
            return mode_cache[key]
        subset = df
        for col, val in conditions.items():
            subset = subset[subset[col] == val]
        mode_val = None
        if not subset.empty:
            modes = subset[attr].mode()  # mode() ignores NaN by default
            if not modes.empty:
                mode_val = modes.iloc[0]
        mode_cache[key] = mode_val
        return mode_val

    def fill_row(row):
        if row[attr_columns].isna().sum() > max_nans:
            return row  # too many gaps to impute; leave the row as-is
        for attr_idx, attr in enumerate(attr_columns):
            if not pd.isna(row[attr]):
                continue
            # Condition only on attributes after this one (earlier attributes of
            # this row may already hold values imputed in this same pass).
            conditions = {
                col: row[col]
                for col in attr_columns[attr_idx + 1:]
                if not pd.isna(row[col])
            }
            fill_value = get_mode_for_subset(attr, conditions)
            if fill_value is None:
                # Fallback: column-wide mode, computed once and cached.
                fill_value = get_mode_for_subset(attr, {})
            if fill_value is not None:
                row[attr] = fill_value
        return row

    return df.apply(fill_row, axis=1)

# 1st training data -> df_sub

# 2nd training data -> df_temp

In [25]:
def do_preprocessing(test_c_name, test_category):
    """Build the two training frames for one category from the global ``df_train``.

    Args:
        test_c_name: category code (key into ``semi_classes_dict``), e.g. ``'c4'``.
        test_category: category name as it appears in ``df_train['Category']``.

    Returns:
        (df_sub, df_temp):
        - df_sub: the category's rows with all-NaN columns dropped, NaNs imputed
          by ``fill_nan_with_prioritized_similarity`` and any row that still has
          a NaN removed (training data for the first model).
        - df_temp: the category's rows with NaNs imputed but nothing dropped
          (training data for the second model).
    """
    num_attrs = len(semi_classes_dict[test_c_name])
    attr_columns = [f'attr_{i}' for i in range(1, num_attrs + 1)]
    df_category = df_train[df_train['Category'] == test_category]

    # First model. Non-inplace dropna: dropna(inplace=True) on a boolean-filtered
    # slice of df_train triggers pandas' SettingWithCopyWarning.
    df_sub = df_category.dropna(axis=1, how='all')
    print(f"Length of {test_c_name} df before removing NaN values: {len(df_sub)}")

    df_sub = fill_nan_with_prioritized_similarity(df_sub, num_attrs=num_attrs)
    # nunique() excludes NaN, so the count is right whether or not NaNs remain.
    num_unique_per_attr = [df_sub[col].nunique() for col in attr_columns]
    print(f"Number of features for original df: {num_unique_per_attr}")
    print()

    df_sub = df_sub.dropna(axis=0, how='any')
    print(f"Length of {test_c_name} df after removing NaN values: {len(df_sub)}")

    num_unique_per_attr = [df_sub[col].nunique() for col in attr_columns]
    print(f"Number of features for new df: {num_unique_per_attr}")

    # Second model: impute, but keep rows that still contain NaNs.
    df_temp = fill_nan_with_prioritized_similarity(df_category, num_attrs=num_attrs)
    return df_sub, df_temp

In [26]:
class LabelEncoderDict:
    """One sklearn ``LabelEncoder`` per attribute column.

    ``fit`` learns each column's classes (missing values excluded); ``transform``
    maps labels to integer codes column by column.
    """

    def __init__(self):
        # column name -> fitted LabelEncoder
        self.encoders = {}

    def fit(self, df, columns):
        """Fit a label encoder for each column, ignoring missing values.

        ``Series.dropna`` removes every missing marker pandas recognises (NaN,
        None, pd.NA), not only float NaN, so the encoder never has to sort a mix
        of strings and None.
        """
        for col in columns:
            le = LabelEncoder()
            le.fit(df[col].dropna().unique().tolist())
            self.encoders[col] = le

    def transform(self, df, columns):
        """Encode ``df[columns]`` into a float array of shape (len(df), len(columns)).

        Raises:
            ValueError: raised by LabelEncoder for labels not seen during ``fit``.
        """
        encoded = np.zeros((len(df), len(columns)))
        for i, col in enumerate(columns):
            encoded[:, i] = self.encoders[col].transform(df[col])
        return encoded

    def get_num_classes(self, column):
        """Number of classes learned for ``column``."""
        return len(self.encoders[column].classes_)

class MultiLabelImageDataset(Dataset):
    """Dataset yielding ``(image, labels)`` with one integer label per attribute column.

    Images are loaded from ``image_dir`` as ``<id zero-padded to 6 chars>.jpg``.
    If a file cannot be opened the error is printed and a blank 512x512 RGB
    image is used instead, so one bad file does not abort an epoch.
    """

    def __init__(self, df, image_dir, transform_basic=None, transform_augmented=None, attr_columns=10, do_transform=True):
        """
        Args:
            df: DataFrame with an ``id`` column and integer-castable label columns.
            image_dir: directory containing the images.
            transform_basic: transform used when no augmentation is applied.
            transform_augmented: transform used for ~50% of samples when
                ``do_transform`` is True (falls back to ``transform_basic`` if None).
            attr_columns: list of label column names, or an int ``n`` meaning
                ``attr_1`` .. ``attr_n``.
            do_transform: enable random augmentation (disable for validation).
        """
        self.df = df
        self.image_dir = image_dir
        self.transform_basic = transform_basic
        self.transform_augmented = transform_augmented
        if isinstance(attr_columns, int):
            # An int (including the default 10) is a column count, not a column name.
            attr_columns = [f'attr_{i}' for i in range(1, attr_columns + 1)]
        self.attr_columns = attr_columns
        self.do_transform = do_transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Read from the columns, not from a row Series: a row of an all-numeric
        # frame is upcast to float, which would turn id 7 into '0007.0'.
        img_name = str(self.df['id'].iloc[idx]).zfill(6)
        img_path = os.path.join(self.image_dir, f"{img_name}.jpg")
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            # Best-effort: substitute a blank image rather than crash the loader.
            print(f"Error loading image {img_path}: {e}")
            image = Image.new('RGB', (512, 512))
        use_augmented = (
            self.do_transform
            and self.transform_augmented is not None
            and random.random() > 0.5
        )
        # Without an augmented transform, always use the basic one so every
        # sample has the same type (otherwise raw PIL images reach collation).
        transform = self.transform_augmented if use_augmented else self.transform_basic
        if transform is not None:
            image = transform(image)
        labels = torch.tensor(self.df[self.attr_columns].iloc[idx].astype(int).values, dtype=torch.long)
        return image, labels

def prepare_data(df, label_encoders, image_dir, batch_size=32, test_size=0.1, num_attr_columns=10, image_size=512):
    """Encode labels, split train/validation and build the multi-attribute DataLoaders.

    Args:
        df: DataFrame with an ``id`` column and raw labels in ``attr_1`` ..
            ``attr_{num_attr_columns}``. Missing labels are not handled here;
            every label must be known to ``label_encoders``.
        label_encoders: a fitted ``LabelEncoderDict``.
        image_dir: directory containing the images.
        batch_size: batch size for both loaders.
        test_size: fraction of rows held out for validation.
        num_attr_columns: number of attribute columns to encode.
        image_size: side length images are resized/cropped to (default 512).

    Returns:
        (train_loader, val_loader, label_encoders, num_classes_per_attr)
    """
    attr_columns = [f'attr_{i}' for i in range(1, num_attr_columns + 1)]
    encoded_labels = label_encoders.transform(df, attr_columns)
    df_encoded = df.copy()
    for i, col in enumerate(attr_columns):
        df_encoded[col] = encoded_labels[:, i]
    # Fixed random_state keeps the split reproducible across runs.
    train_df, val_df = train_test_split(df_encoded, test_size=test_size, random_state=24)

    # Standard ImageNet mean/std normalisation.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    transform_augmented = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.9),  # flip 90% of the time
        transforms.RandomRotation(degrees=5),  # rotate by up to +/-5 degrees
        transforms.RandomResizedCrop(size=(image_size, image_size), scale=(0.8, 1.0)),  # random crop, resized back
        transforms.RandomPerspective(distortion_scale=0.1, p=0.5),  # mild perspective distortion
        transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15), shear=5),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    train_dataset = MultiLabelImageDataset(
        train_df,
        image_dir,
        transform_basic=transform,
        transform_augmented=transform_augmented,
        attr_columns=attr_columns
    )
    # Validation never augments (do_transform=False), so no augmented transform is needed.
    val_dataset = MultiLabelImageDataset(
        val_df,
        image_dir,
        transform_basic=transform,
        transform_augmented=None,
        attr_columns=attr_columns,
        do_transform=False
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    # NOTE(review): drop_last=True also drops up to batch_size-1 validation rows from the metrics.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=True)
    num_classes_per_attr = [label_encoders.get_num_classes(col) for col in attr_columns]
    return train_loader, val_loader, label_encoders, num_classes_per_attr

In [27]:
class MultiLabelClassifier(nn.Module):
    """ConvNeXt backbone with one attention-style classification head per attribute.

    ``forward`` returns a list with one logits tensor per attribute, or only the
    logits of ``attr_index`` when it is given.
    """

    def __init__(self, num_classes_per_attr):
        super(MultiLabelClassifier, self).__init__()
        # Pretrained ConvNeXt-Base backbone; trainable unless frozen explicitly.
        self.backbone = ConvNextModel.from_pretrained("facebook/convnext-base-384-22k-1k")
        backbone_features = self.backbone.config.hidden_sizes[-1]  # channels of the last stage
        # 1x1 conv projection to 1024 channels; independent of spatial size.
        self.feature_processor = nn.Sequential(
            nn.Conv2d(backbone_features, 1024, kernel_size=1),
            nn.GELU(),
            nn.Dropout(0.1)
        )
        # One head per attribute. The head layout defines the state_dict keys,
        # so it must stay stable to keep saved checkpoints loadable.
        self.classifier_heads = nn.ModuleList(
            self._build_head(num_classes) for num_classes in num_classes_per_attr
        )

    @staticmethod
    def _build_head(num_classes):
        """Build one head: [0] spatial convs, [1] SE-style channel gate, [2] classifier MLP.

        The three parts are combined by ``_apply_head``; calling the returned
        Sequential directly would chain them incorrectly.
        """
        return nn.Sequential(
            # [0] Spatial branch: two grouped 3x3 convs, 1024 -> 512 channels.
            nn.Sequential(
                nn.Conv2d(1024, 512, kernel_size=3, padding=1, groups=32),
                nn.GELU(),
                nn.Conv2d(512, 512, kernel_size=3, padding=1, groups=32),
                nn.GELU(),
            ),
            # [1] Channel gate: global pool -> bottleneck MLP -> weights in (0, 1).
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(512, 128),
                nn.GELU(),
                nn.Linear(128, 512),
                nn.Sigmoid(),
            ),
            # [2] Classifier applied to the gated spatial features.
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(512, 1024),
                nn.LayerNorm(1024),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(1024, 512),
                nn.LayerNorm(512),
                nn.Sigmoid(),
                nn.Dropout(0.1),
                nn.Linear(512, num_classes)
            )
        )

    @staticmethod
    def _set_requires_grad(module, flag):
        """Set ``requires_grad`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = flag

    def freeze_backbone(self):
        """Freeze the backbone model"""
        self._set_requires_grad(self.backbone, False)

    def unfreeze_backbone(self):
        """Unfreeze the backbone model"""
        self._set_requires_grad(self.backbone, True)

    def freeze_feature_processor(self):
        """Freeze the feature processor"""
        self._set_requires_grad(self.feature_processor, False)

    def unfreeze_feature_processor(self):
        """Unfreeze the feature processor"""
        self._set_requires_grad(self.feature_processor, True)

    def set_classifier_head_trainable(self, attr_index):
        """Make only the head at ``attr_index`` trainable (0 for attr_1, 1 for attr_2, ...)."""
        for i, head in enumerate(self.classifier_heads):
            self._set_requires_grad(head, i == attr_index)

    def freeze_all_except_head(self, attr_index):
        """Freeze backbone, feature processor and every head except ``attr_index``."""
        self.freeze_backbone()
        self.freeze_feature_processor()
        self.set_classifier_head_trainable(attr_index)

    def unfreeze_all(self):
        """Unfreeze all model components"""
        self.unfreeze_backbone()
        self.unfreeze_feature_processor()
        for head in self.classifier_heads:
            self._set_requires_grad(head, True)

    def _apply_head(self, classifier_head, processed_features):
        """Run one head: spatial branch, channel gating, then classifier -> logits."""
        spatial_features = classifier_head[0](processed_features)
        channel_attention = classifier_head[1](spatial_features)  # (B, C)
        # Reshape the gate to (B, C, 1, 1) so it broadcasts over H and W.
        channel_attention = channel_attention.view(-1, spatial_features.size(1), 1, 1)
        return classifier_head[2](spatial_features * channel_attention)

    def forward(self, x, attr_index=None, return_features=False):
        """
        Args:
            x: input image batch.
            attr_index: if given, only that head's logits are returned
                (0 for attr_1, ...); otherwise a list with every head's logits.
            return_features: also return the processed backbone features
                (used e.g. for the MMD loss).
        """
        # last_hidden_state is fed to a Conv2d, so it is channels-first (B, C, H, W).
        features = self.backbone(x).last_hidden_state
        processed_features = self.feature_processor(features)
        if attr_index is not None:
            outputs = self._apply_head(self.classifier_heads[attr_index], processed_features)
        else:
            outputs = [self._apply_head(head, processed_features) for head in self.classifier_heads]
        if return_features:
            return outputs, processed_features
        return outputs

In [28]:
class MultiLabelCELoss(nn.Module):
    """Mean of the per-attribute cross-entropy losses.

    ``outputs`` is a sequence of logits tensors, one per attribute; column ``i``
    of ``targets`` holds the class indices for ``outputs[i]``.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        per_attribute = (
            self.criterion(logits, targets[:, column])
            for column, logits in enumerate(outputs)
        )
        return sum(per_attribute) / len(outputs)

class CombinedLoss(nn.Module):
    """Cross-entropy plus a weighted multi-bandwidth Gaussian MMD between two feature batches.

    ``forward`` returns ``(total, ce, mmd, coral)`` with
    ``total = ce + lambda_mmd * mmd``; ``coral`` is always a zero tensor, kept so
    callers that unpack four values keep working.
    """

    def __init__(self, lambda_mmd=0.1, chunk_size=1024):
        super(CombinedLoss, self).__init__()
        self.lambda_mmd = lambda_mmd
        # Rows per block when building kernel matrices (bounds each distance matrix).
        self.chunk_size = chunk_size
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, outputs, labels, source_features, target_features):
        # Cross-entropy: averaged over heads when ``outputs`` is a list of logits.
        if isinstance(outputs, list):
            ce_loss = 0
            for i, output in enumerate(outputs):
                if isinstance(labels, list):
                    label = labels[i]
                else:
                    label = labels[:, i] if labels.dim() > 1 else labels
                ce_loss += self.cross_entropy_loss(output, label)
            ce_loss = ce_loss / len(outputs)
        else:
            ce_loss = self.cross_entropy_loss(outputs, labels)
        mmd_loss = self.maximum_mean_discrepancy(source_features, target_features)
        total_loss = ce_loss + self.lambda_mmd * mmd_loss
        # Zero CORAL term kept for interface compatibility.
        return total_loss, ce_loss, mmd_loss, torch.tensor(0.0, device=ce_loss.device)

    @staticmethod
    def _squared_distances(x, y):
        """Pairwise squared Euclidean distances between rows of 2-D ``x`` and ``y``.

        Uses ||x||^2 + ||y||^2 - 2 x.y, which needs an (n, m) matrix instead of the
        (n, m, dim) broadcast difference; clamped at 0 against round-off.
        """
        x_sq = x.pow(2).sum(dim=1, keepdim=True)   # (n, 1)
        y_sq = y.pow(2).sum(dim=1).unsqueeze(0)    # (1, m)
        return (x_sq + y_sq - 2.0 * (x @ y.t())).clamp_min(0.0)

    def gaussian_kernel(self, x, y, bandwidth):
        """Kernel matrix exp(-||x_i - y_j||^2 / (2 * bandwidth^2)) of shape (x.size(0), y.size(0)).

        Inputs are flattened to (batch, -1) first.
        """
        x = x.reshape(x.size(0), -1)
        y = y.reshape(y.size(0), -1)
        return torch.exp(-self._squared_distances(x, y) / (2 * bandwidth * bandwidth))

    def _mean_kernels(self, x, y, bandwidths):
        """Mean Gaussian kernel value over all (x_i, y_j) pairs, one entry per bandwidth.

        Computed in ``chunk_size`` x ``chunk_size`` blocks; each block's distance
        matrix is reused for every bandwidth. The result stays differentiable.
        """
        total = x.new_zeros(len(bandwidths))
        for i in range(0, x.size(0), self.chunk_size):
            x_chunk = x[i:i + self.chunk_size]
            for j in range(0, y.size(0), self.chunk_size):
                sq_dist = self._squared_distances(x_chunk, y[j:j + self.chunk_size])
                total = total + torch.stack(
                    [torch.exp(-sq_dist / (2 * bw * bw)).sum() for bw in bandwidths]
                )
        return total / (x.size(0) * y.size(0))

    def maximum_mean_discrepancy(self, source_features, target_features):
        """Biased MMD^2 estimate averaged over bandwidths ``dim * 2**k`` for k = -3..2.

        All sums are kept as tensors (no ``.item()``), so the result remains in
        the autograd graph and the MMD term actually produces gradients.
        """
        source = source_features.reshape(source_features.size(0), -1)
        target = target_features.reshape(target_features.size(0), -1)
        dim = source.size(1)
        bandwidths = [dim * (2 ** i) for i in range(-3, 3)]
        source_term = self._mean_kernels(source, source, bandwidths)
        target_term = self._mean_kernels(target, target, bandwidths)
        cross_term = self._mean_kernels(source, target, bandwidths)
        return (source_term + target_term - 2 * cross_term).mean()

def _repeat_forever(loader):
    """Yield batches from ``loader`` endlessly, starting a fresh pass when one ends.

    Used instead of ``itertools.cycle``, which keeps every element of its first
    pass (here: every validation image batch) in memory to replay it.
    """
    while True:
        yield from loader


def train_model(model, train_loader, val_loader, num_epochs, num_classes_per_attr, model_type, lr=1e-5):
    """Jointly train all attribute heads with cross-entropy plus MMD feature alignment.

    Each training step also draws a batch from ``val_loader`` (its labels are
    ignored) whose features are aligned to the training features through the
    MMD term of ``CombinedLoss``. After each epoch the model is evaluated on
    ``val_loader`` and per-attribute accuracy and weighted precision/recall/F1
    are printed.

    Checkpoints (paths relative to the current working directory):
        - ``best_model_{model_type}.pth``: written whenever the exact-match
          validation accuracy (all attributes correct) ties or beats the best.
        - ``best_model_end_{model_type}.pth``: weights after the last epoch run.
    Training stops early after 4 consecutive epochs without improvement.

    Args:
        model: a model whose forward accepts ``return_features`` (e.g.
            ``MultiLabelClassifier``); it is wrapped in ``DataParallel`` here.
        train_loader: labelled ``(images, labels)`` batches.
        val_loader: validation batches, also used as the unlabelled target domain.
        num_epochs: maximum number of epochs.
        num_classes_per_attr: one entry per attribute head.
        model_type: suffix for checkpoint file names.
        lr: AdamW learning rate (default 1e-5).

    Returns:
        The trained model, wrapped in ``torch.nn.DataParallel``.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model = torch.nn.DataParallel(model)
    criterion = CombinedLoss()
    criterion2 = MultiLabelCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Halve the learning rate when overall validation accuracy plateaus.
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
    early_stopping_patience = 4
    early_stopping_counter = 0
    best_val_overall_acc = 0.0
    num_attrs = len(num_classes_per_attr)
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        correct_predictions = [0] * num_attrs
        total_predictions = 0
        overall_correct = 0
        # True / predicted labels per attribute, for the metrics printed below.
        train_true_labels = [[] for _ in range(num_attrs)]
        train_predicted_labels = [[] for _ in range(num_attrs)]
        # Unlabelled target-domain batches for the MMD term.
        target_batches = _repeat_forever(val_loader)
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", unit="batch") as t:
            for images, labels in t:
                target_images, _ = next(target_batches)
                images = images.to(device)
                labels = labels.to(device)
                target_images = target_images.to(device)
                outputs, source_features = model(images, return_features=True)
                _, target_features = model(target_images, return_features=True)
                loss, _ce_loss, _mmd_loss, _coral_loss = criterion(outputs, labels, source_features, target_features)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                t.set_postfix(loss=loss.item())
                all_labels_match = torch.ones(labels.size(0), dtype=torch.bool, device=device)
                for i, output in enumerate(outputs):
                    _, predicted = torch.max(output, 1)
                    correct_predictions[i] += (predicted == labels[:, i]).sum().item()
                    all_labels_match &= (predicted == labels[:, i])
                    train_true_labels[i].extend(labels[:, i].cpu().numpy())
                    train_predicted_labels[i].extend(predicted.cpu().numpy())
                overall_correct += all_labels_match.sum().item()
                total_predictions += labels.size(0)

        # Validation phase
        model.eval()
        val_loss = 0
        val_correct_predictions = [0] * num_attrs
        val_total_predictions = 0
        val_overall_correct = 0
        val_true_labels = [[] for _ in range(num_attrs)]
        val_predicted_labels = [[] for _ in range(num_attrs)]
        with tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", unit="batch") as v:
            with torch.no_grad():
                for images, labels in v:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images, return_features=False)
                    loss = criterion2(outputs, labels)
                    val_loss += loss.item()
                    v.set_postfix(loss=loss.item())
                    all_labels_match_val = torch.ones(labels.size(0), dtype=torch.bool, device=device)
                    for i, output in enumerate(outputs):
                        _, predicted = torch.max(output, 1)
                        val_correct_predictions[i] += (predicted == labels[:, i]).sum().item()
                        all_labels_match_val &= (predicted == labels[:, i])
                        val_true_labels[i].extend(labels[:, i].cpu().numpy())
                        val_predicted_labels[i].extend(predicted.cpu().numpy())
                    val_overall_correct += all_labels_match_val.sum().item()
                    val_total_predictions += labels.size(0)

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Training Loss: {train_loss/len(train_loader):.4f}')
        print(f'Validation Loss: {val_loss/len(val_loader):.4f}')
        for i in range(num_attrs):
            train_acc = 100 * correct_predictions[i] / total_predictions
            val_acc = 100 * val_correct_predictions[i] / val_total_predictions
            train_precision = precision_score(train_true_labels[i], train_predicted_labels[i], average='weighted')
            train_recall = recall_score(train_true_labels[i], train_predicted_labels[i], average='weighted')
            train_f1 = f1_score(train_true_labels[i], train_predicted_labels[i], average='weighted')
            val_precision = precision_score(val_true_labels[i], val_predicted_labels[i], average='weighted')
            val_recall = recall_score(val_true_labels[i], val_predicted_labels[i], average='weighted')
            val_f1 = f1_score(val_true_labels[i], val_predicted_labels[i], average='weighted')
            print(f'Attribute {i+1} - Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%')
            print(f'Attribute {i+1} - Train Precision: {train_precision:.2f}, Train Recall: {train_recall:.2f}, Train F1-Score: {train_f1:.2f}')
            print(f'Attribute {i+1} - Val Precision: {val_precision:.2f}, Val Recall: {val_recall:.2f}, Val F1-Score: {val_f1:.2f}')
            print()
        overall_train_acc = 100 * overall_correct / total_predictions
        overall_val_acc = 100 * val_overall_correct / val_total_predictions
        print(f'Overall Train Accuracy: {overall_train_acc:.2f}%')
        print(f'Overall Validation Accuracy: {overall_val_acc:.2f}%')

        # Checkpointing and early stopping on exact-match validation accuracy.
        if overall_val_acc >= best_val_overall_acc:
            best_val_overall_acc = overall_val_acc
            # NOTE(review): saved from the DataParallel wrapper, so keys carry a
            # "module." prefix, unlike the end-of-training checkpoint below.
            torch.save(model.state_dict(), f'best_model_{model_type}.pth')
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
        lr_scheduler.step(overall_val_acc)
        if early_stopping_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break
    torch.save(model.module.state_dict(), f'best_model_end_{model_type}.pth')
    return model

In [29]:
class SingleAttributeDataset(Dataset):
    """Dataset yielding ``(image, label)`` for one attribute, skipping rows where it is missing.

    Images are loaded from ``image_dir`` as ``<id zero-padded to 6 chars>.jpg``;
    a blank 512x512 RGB image is substituted (and the error printed) when a file
    cannot be opened.
    """

    def __init__(self, df, image_dir, attribute, transform_basic=None, transform_augmented=None, do_transform=True):
        """
        Args:
            df: DataFrame with an ``id`` column and an encoded ``attribute`` column.
            image_dir: directory containing the images.
            attribute: name of the label column.
            transform_basic: transform used when no augmentation is applied.
            transform_augmented: transform used for ~50% of samples when
                ``do_transform`` is True (falls back to ``transform_basic`` if None).
            do_transform: enable random augmentation (disable for validation).
        """
        # Keep only labelled rows; reset the index so positions run 0..len-1.
        self.df = df[df[attribute].notna()].reset_index(drop=True)
        self.image_dir = image_dir
        self.transform_basic = transform_basic
        self.transform_augmented = transform_augmented
        self.attribute = attribute
        self.do_transform = do_transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Read from the columns, not from a row Series: a row of an all-numeric
        # frame is upcast to float, which would turn id 7 into '0007.0'.
        img_name = str(self.df['id'].iloc[idx]).zfill(6)
        img_path = os.path.join(self.image_dir, f"{img_name}.jpg")
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            # Best-effort: substitute a blank image rather than crash the loader.
            print(f"Error loading image {img_path}: {e}")
            image = Image.new('RGB', (512, 512))
        use_augmented = (
            self.do_transform
            and self.transform_augmented is not None
            and random.random() > 0.5
        )
        # Without an augmented transform, always use the basic one so every
        # sample has the same type (otherwise raw PIL images reach collation).
        transform = self.transform_augmented if use_augmented else self.transform_basic
        if transform is not None:
            image = transform(image)
        label = torch.tensor(self.df[self.attribute].iloc[idx], dtype=torch.long)
        return image, label

def prepare_attribute_data(df, image_dir, label_encoders, batch_size=32, num_attr_columns=10, test_size=0.1, image_size=512):
    """Build one train/validation DataLoader pair per attribute column.

    For each ``attr_k`` only rows where that attribute is present are used, so
    partially labelled rows still contribute to the attributes they do have.

    Args:
        df: DataFrame with an ``id`` column and raw labels in ``attr_1`` ..
            ``attr_{num_attr_columns}`` (missing values allowed).
        image_dir: directory containing the images.
        label_encoders: fitted ``LabelEncoderDict``; non-missing labels are
            encoded with ``label_encoders.encoders[col]``.
        batch_size: batch size for every loader.
        num_attr_columns: number of attribute columns.
        test_size: validation fraction for each attribute.
        image_size: side length images are resized/cropped to (default 512).

    Returns:
        dict mapping each attribute column to
        ``{'train': DataLoader, 'val': DataLoader, 'num_samples': int, 'num_classes': int}``.
    """
    attr_columns = [f'attr_{i}' for i in range(1, num_attr_columns + 1)]
    # Encode only the non-missing labels of each attribute.
    df_encoded = df.copy()
    for col in attr_columns:
        mask = df_encoded[col].notna()
        df_encoded.loc[mask, col] = label_encoders.encoders[col].transform(df_encoded.loc[mask, col])

    # Standard ImageNet mean/std normalisation.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    transform_augmented = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.9),  # flip 90% of the time
        transforms.RandomRotation(degrees=5),  # rotate by up to +/-5 degrees
        transforms.RandomResizedCrop(size=(image_size, image_size), scale=(0.8, 1.0)),
        transforms.RandomPerspective(distortion_scale=0.1, p=0.5),
        transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15), shear=5),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    attribute_loaders = {}
    for attr in attr_columns:
        attr_df = df_encoded[df_encoded[attr].notna()].copy()
        try:
            train_df, val_df = train_test_split(
                attr_df,
                test_size=test_size,
                stratify=attr_df[attr],  # keep class proportions in both splits
                random_state=24
            )
        except ValueError as err:
            # Stratification fails when a class has a single sample (or the test
            # split cannot hold every class); fall back to a plain random split.
            print(f"{attr}: stratified split failed ({err}); using a random split")
            train_df, val_df = train_test_split(attr_df, test_size=test_size, random_state=24)

        train_dataset = SingleAttributeDataset(
            train_df,
            image_dir,
            attr,
            transform_basic=transform,
            transform_augmented=transform_augmented,
            do_transform=True
        )
        val_dataset = SingleAttributeDataset(
            val_df,
            image_dir,
            attr,
            transform_basic=transform,
            transform_augmented=None,
            do_transform=False
        )
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
        # NOTE(review): drop_last=True also drops up to batch_size-1 validation rows.
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
        attribute_loaders[attr] = {
            'train': train_loader,
            'val': val_loader,
            'num_samples': len(attr_df),
            'num_classes': label_encoders.get_num_classes(attr),
        }
        print(f"{attr}: {len(train_df)} train samples, {len(val_df)} val samples")
    return attribute_loaders

In [30]:
def _weighted_classification_metrics(labels, preds):
    """
    Return (accuracy, weighted precision, weighted recall, weighted F1) for one split.

    When there are no samples (e.g. every batch was dropped by drop_last=True) four NaNs are
    returned instead of computing metrics on empty input. zero_division=0 gives the same 0.0
    value sklearn substitutes by default, without the warning.
    """
    if len(labels) == 0:
        nan = float('nan')
        return nan, nan, nan, nan
    return (
        accuracy_score(labels, preds),
        precision_score(labels, preds, average='weighted', zero_division=0),
        recall_score(labels, preds, average='weighted', zero_division=0),
        f1_score(labels, preds, average='weighted', zero_division=0),
    )


def train_attribute_model(model, dataloaders, num_epochs, model_type, device='cuda', max_grad_norm=1.0):
    """
    Fine-tune one classification head at a time, attribute after attribute.

    For each attribute (in ``dataloaders`` key order, which also defines the head index) the
    model's ``freeze_all_except_head(attr_index)`` is called, a fresh AdamW optimizer over the
    still-trainable parameters and a ReduceLROnPlateau scheduler are created, and the model is
    trained for ``num_epochs``. Whenever validation accuracy for the current attribute does not
    decrease, the weights are saved to ``best_model_attr_{model_type}.pth``. After the last
    epoch of an attribute, the best weights seen for it are restored before moving on. This
    way the shared checkpoint ends up holding the best head of every attribute, rather than
    the last-epoch head of all attributes but the final one.

    Parameters:
    - model: The classifier, either plain or wrapped in torch.nn.DataParallel /
      DistributedDataParallel. The underlying module must provide
      ``freeze_all_except_head(attr_index)`` and accept ``model(images, attr_index=...)``.
    - dataloaders: Dict mapping attribute name -> {'train': DataLoader, 'val': DataLoader, ...}.
    - num_epochs: Number of epochs to train each attribute.
    - model_type: Tag used in the checkpoint filenames.
    - device: The device to run the model on (default is 'cuda').
    - max_grad_norm: Maximum norm for gradient clipping (default is 1.0).

    Side effects: writes ``best_model_attr_{model_type}.pth`` during training and
    ``best_model_attr_end_{model_type}.pth`` at the end; both hold the unwrapped state dict.
    """
    # Unwrap data-parallel wrappers for head freezing and checkpointing; a plain module is used
    # as-is, so callers may pass either form.
    parallel_types = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    base_model = model.module if isinstance(model, parallel_types) else model
    criterion = torch.nn.CrossEntropyLoss()

    for attr_index, attr in enumerate(dataloaders.keys()):
        train_loader = dataloaders[attr]['train']
        val_loader = dataloaders[attr]['val']
        best_val_acc = 0.0
        best_state = None

        base_model.freeze_all_except_head(attr_index)
        model.to(device)
        # Clear gradients left from the previous attribute: those parameters may now be frozen
        # and their stale gradients must not enter gradient clipping or an optimizer step.
        model.zero_grad(set_to_none=True)
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params)
        # Cut the LR 10x after 3 epochs without val-loss improvement. The current LR is printed
        # every epoch below, so the scheduler's (deprecated) verbose flag is not used.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=3,
        )

        for epoch in range(num_epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            all_train_preds = []
            all_train_labels = []
            for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} train {attr}"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images, attr_index=attr_index)
                loss = criterion(outputs, labels)
                train_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                # Clip only the parameters actually being optimised.
                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                optimizer.step()
                _, preds = torch.max(outputs, 1)
                all_train_preds.extend(preds.cpu().numpy())
                all_train_labels.extend(labels.cpu().numpy())
            train_accuracy, train_precision, train_recall, train_f1 = _weighted_classification_metrics(
                all_train_labels, all_train_preds
            )

            # Validation phase
            model.eval()
            val_loss = 0.0
            all_val_preds = []
            all_val_labels = []
            with torch.no_grad():
                for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} val {attr}"):
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images, attr_index=attr_index)
                    val_loss += criterion(outputs, labels).item()
                    _, preds = torch.max(outputs, 1)
                    all_val_preds.extend(preds.cpu().numpy())
                    all_val_labels.extend(labels.cpu().numpy())

            num_val_batches = len(val_loader)
            if num_val_batches:
                avg_val_loss = val_loss / num_val_batches
                scheduler.step(avg_val_loss)
            else:
                # No validation batches: there is no loss to drive the scheduler with.
                avg_val_loss = float('nan')
            val_accuracy, val_precision, val_recall, val_f1 = _weighted_classification_metrics(
                all_val_labels, all_val_preds
            )

            # A NaN accuracy (no validation samples) compares False, so nothing is saved then.
            if val_accuracy >= best_val_acc:
                best_val_acc = val_accuracy
                # Keep a CPU copy so the best weights can be restored after the last epoch.
                best_state = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()}
                torch.save(best_state, f'best_model_attr_{model_type}.pth')

            current_lr = optimizer.param_groups[0]['lr']
            avg_train_loss = train_loss / max(len(train_loader), 1)
            print(f"Epoch {epoch+1}/{num_epochs}, Attr: {attr}, Learning Rate: {current_lr:.6f}")
            print(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
                  f"Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1-score: {train_f1:.4f}")
            print(f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, "
                  f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1-score: {val_f1:.4f}")

        if best_state is not None:
            # Continue from the best weights for this attribute, so the next attribute's first
            # checkpoint does not overwrite them with this attribute's last-epoch weights.
            base_model.load_state_dict(best_state)

    torch.save(base_model.state_dict(), f'best_model_attr_end_{model_type}.pth')

In [31]:
import pickle
import gc
def main(test_c_name):
    """
    Train and save all artifacts for one category.

    Steps:
    1. Preprocess the category's data (``do_preprocessing``).
    2. Fit label encoders.
    3. Train the base multi-label model (``train_model``).
    4. Reload the base checkpoint ``BEST_MODEL_FROM_BASE_FIRST_TRAINING``.
    5. Fine-tune one attribute at a time (``train_attribute_model``).
    6. Pickle the label encoders to ``label_encoders_{test_c_name}.pkl`` in the current directory.

    Parameters:
    - test_c_name: Category key such as 'c4'; must exist in ``category_dict`` and
      ``semi_classes_dict``.
    """
    # Set image directory
    image_dir = f'{input_path}/train_images'
    # Initialize device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def train_single_model(data1, data2, num_attr_columns, model_type):
        """Fit encoders, train the base model, fine-tune per attribute, then save encoders."""
        attr_columns = [f'attr_{i}' for i in range(1, num_attr_columns + 1)]
        print(f"Preparing data for {model_type}")
        label_encoders = LabelEncoderDict()
        label_encoders.fit(data2, attr_columns)
        train_loader, val_loader, label_encoders, num_classes_per_attr = prepare_data(
            data1, label_encoders, image_dir, batch_size=4, num_attr_columns=num_attr_columns)

        print(f"Initializing model {model_type}")
        model = MultiLabelClassifier(semi_classes_dict[test_c_name]).to(device)

        print(f"Training model with base model {model_type}")
        model = train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCH,
                            num_classes_per_attr=num_classes_per_attr, model_type=model_type)
        print(f"\nTraining model with attribute only {model_type}")

        # Reload the best base-training checkpoint into the *unwrapped* module, so this works
        # whether train_model returned a plain or a DataParallel model.
        # NOTE(review): assumes the file holds keys without a 'module.' prefix, and the path is
        # derived from the *global* test_c_name — confirm if main() is run for another category.
        parallel_types = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
        base_model = model.module if isinstance(model, parallel_types) else model
        base_model.load_state_dict(torch.load(BEST_MODEL_FROM_BASE_FIRST_TRAINING, map_location=device))
        # train_attribute_model accesses `model.module`, so hand it a DataParallel wrapper.
        model = torch.nn.DataParallel(base_model).to(device)

        attr_loaders = prepare_attribute_data(data2, image_dir, label_encoders, batch_size=4,
                                              num_attr_columns=num_attr_columns)
        # Pass the detected device explicitly; the function's own default is 'cuda'.
        train_attribute_model(model, attr_loaders, NUM_ATTR_EPOCHS, model_type, device=device)

        # Save label encoders
        with open(f'label_encoders_{model_type}.pkl', 'wb') as f:
            pickle.dump(label_encoders, f)
        print("---------------------------------------------")
        print(f"number of classes for {model_type} is {num_classes_per_attr}")
        print("---------------------------------------------")

        # Free up memory: drop references first, collect, then release cached CUDA blocks.
        del model, base_model, train_loader, val_loader, attr_loaders, label_encoders, num_classes_per_attr
        gc.collect()
        torch.cuda.empty_cache()

    df1, df2 = do_preprocessing(test_c_name, category_dict[test_c_name])
    train_single_model(df1, df2, num_attr_columns=len(semi_classes_dict[test_c_name]), model_type=test_c_name)

In [None]:
# Run the full training pipeline for the category configured above.
main(test_c_name=test_c_name)

Length of c4 df before removing NaN values: 18774
Number of features for original df: [7, 3, 3, 3, 6, 3, 2, 2]

Length of c4 df after removing NaN values: 15565
Number of features for new df: [7, 3, 3, 3, 6, 3, 2, 2]
Preparing data for c4
Initializing model c4


config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/354M [00:00<?, ?B/s]

Training model with base model c4


Epoch 1/4 [Training]:   0%|          | 17/3502 [00:18<52:54,  1.10batch/s, loss=0.935]

In [None]:
# Test rows for the selected category. `.copy()` makes this an independent frame, so the
# prediction columns added later do not write into a view of df_test.
test_df_semi = df_test[df_test['Category'] == category_dict[test_c_name]].copy()

test_df_semi.head()

In [None]:
import torch
from torchvision import transforms
from PIL import Image
import pickle
from tqdm import tqdm
import time
def load_model(model_path, num_classes_per_attr, device):
    """
    Build a MultiLabelClassifier, load saved weights into it, and ready it for inference.

    The weights are loaded into the bare (unwrapped) model, so the checkpoint is expected to
    contain keys without a 'module.' prefix. The model is then wrapped in
    torch.nn.DataParallel, moved to ``device`` and put in eval mode.

    Returns:
    - The DataParallel-wrapped model.
    """
    classifier = MultiLabelClassifier(num_classes_per_attr)
    checkpoint = torch.load(model_path, map_location=device)
    classifier.load_state_dict(checkpoint)
    parallel_model = torch.nn.DataParallel(classifier)
    parallel_model.to(device)
    parallel_model.eval()
    return parallel_model
    
def load_label_encoders(encoder_path):
    """
    Unpickle and return the label-encoder object stored at ``encoder_path``.

    Security note: pickle.load can execute arbitrary code; only load files from a trusted
    source (e.g. ones written by this notebook's training step).
    """
    with open(encoder_path, mode='rb') as handle:
        return pickle.load(handle)

def preprocess_image(image_path, image_size=(512,512)):
    """
    Load an image file and turn it into a normalized single-image batch tensor.

    The image is converted to RGB, resized to ``image_size``, turned into a tensor and
    normalized with the mean/std below. A leading batch dimension of size 1 is then added.
    NOTE(review): the training-time basic ``transform`` is defined outside this view —
    confirm it uses the same resize and normalization.

    Parameters:
    - image_path: Path to the image file.
    - image_size: Size passed to ``transforms.Resize``; previously ignored (always 512x512).
      The default keeps that behaviour.

    Returns:
    - The transformed image tensor with a batch dimension of 1 prepended.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Close the file promptly; convert() returns a new, fully loaded image that remains
    # usable after the context manager closes the source file.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension

def inference(images, model, label_encoders):
    """
    Run the model on ``images`` and decode the highest-scoring class of every attribute head.

    ``model(x=images, return_features=False)`` is iterated as a sequence of logits tensors,
    one per attribute in order (attr_1, attr_2, ...). Each prediction is read with
    ``.item()``, so this only works for a batch containing exactly one image.

    Returns:
    - List of decoded labels, one per attribute head.

    Raises:
    - KeyError if an attribute has no encoder in ``label_encoders.encoders``.
    """
    with torch.no_grad():
        head_outputs = model(x=images, return_features=False)

    decoded = []
    for position, logits in enumerate(head_outputs, start=1):
        class_index = torch.max(logits, 1)[1]
        attr_key = f'attr_{position}'
        if attr_key not in label_encoders.encoders:
            raise KeyError(f"Encoder for {attr_key} not found in the loaded label encoders.")
        encoder = label_encoders.encoders[attr_key]
        decoded.append(encoder.inverse_transform([class_index.item()])[0])
    return decoded

In [None]:
# Inference on the category's test images with the base-training checkpoint.
model_path = f"best_model_{test_c_name}.pth"
encoder_path = f"{working_path}/label_encoders_{test_c_name}.pkl"
num_classes_per_attr = NUM_OF_SEMI_CLASSES_OF_COLUMNS
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model(model_path, num_classes_per_attr, device)
label_encoders = load_label_encoders(encoder_path)
image_dir = f"{input_path}/test_images"
# One past the last attribute index: range(1, interval) yields 1..N. Later cells reuse it.
interval = len(semi_classes_dict[test_c_name]) + 1
preds = {f'attr_{i}': [] for i in range(1, interval)}
t1 = time.time()
for val in tqdm(test_df_semi['id'], desc='Processing Images', total=len(test_df_semi)):
    image_path = f"{image_dir}/{str(val).zfill(6)}.jpg"
    image = preprocess_image(image_path).to(device)  # Preprocess and send image to device
    predictions = inference(image, model, label_encoders)  # Use the already loaded model and encoders
    for i in range(1, interval):
        preds[f'attr_{i}'].append(predictions[i-1])
# Measure once so the reported duration and throughput agree.
elapsed = time.time() - t1
throughput = len(test_df_semi) / elapsed if elapsed > 0 else float('nan')
print(f'Time taken to process images is {elapsed} seconds, which is {throughput} images per second')

In [None]:
# Attach one prediction column per attribute. `assign` returns a new frame instead of writing
# into what may be a filtered view of df_test.
test_df_semi = test_df_semi.assign(**{f'attr_{i}': preds[f'attr_{i}'] for i in range(1, interval)})
test_df_semi.head()

In [None]:
# Persist base-model predictions for offline validation (relative to the working directory).
validation_csv_path = f'test_validation_df_{test_c_name}.csv'
test_df_semi.to_csv(validation_csv_path, index=False)

In [None]:
# Inference with the per-attribute fine-tuned checkpoint.
# Drop the previously loaded model (if any) first so it can be freed before the next one is built.
model = None
if torch.cuda.is_available():
    torch.cuda.empty_cache()
model_path = f'best_model_attr_{test_c_name}.pth'
encoder_path = f"{working_path}/label_encoders_{test_c_name}.pkl"
num_classes_per_attr = NUM_OF_SEMI_CLASSES_OF_COLUMNS
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model(model_path, num_classes_per_attr, device)
label_encoders = load_label_encoders(encoder_path)
image_dir = f"{input_path}/test_images"
# Recomputed here so this cell does not depend on the earlier inference cell having run.
interval = len(semi_classes_dict[test_c_name]) + 1
preds = {f'attr_{i}': [] for i in range(1, interval)}
t1 = time.time()
for val in tqdm(test_df_semi['id'], desc='Processing Images', total=len(test_df_semi)):
    image_path = f"{image_dir}/{str(val).zfill(6)}.jpg"
    image = preprocess_image(image_path).to(device)  # Preprocess and send image to device
    predictions = inference(image, model, label_encoders)  # Use the already loaded model and encoders
    for i in range(1, interval):
        preds[f'attr_{i}'].append(predictions[i-1])
# Measure once so the reported duration and throughput agree.
elapsed = time.time() - t1
throughput = len(test_df_semi) / elapsed if elapsed > 0 else float('nan')
print(f'Time taken to process images is {elapsed} seconds, which is {throughput} images per second')

In [None]:
# Overwrite the prediction columns with the fine-tuned model's outputs. `assign` returns a new
# frame instead of writing into what may be a filtered view of df_test.
test_df_semi = test_df_semi.assign(**{f'attr_{i}': preds[f'attr_{i}'] for i in range(1, interval)})
test_df_semi.head()

In [None]:
# Persist fine-tuned-model predictions for offline validation (relative to the working directory).
attr_validation_csv_path = f'test_attr_validation_df_{test_c_name}.csv'
test_df_semi.to_csv(attr_validation_csv_path, index=False)