In [26]:
import time
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.models import ResNet50_Weights
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, confusion_matrix


## FashionDataset Class Definition

This code defines a custom dataset class called `FashionDataset` for handling fashion image data. It inherits from PyTorch's `Dataset` class and is designed to work with a CSV file containing metadata and a directory of images.

Key features of the `FashionDataset` class include:

- Loading images and their corresponding labels (color, type, season, gender) from the CSV file and image directory.
- Encoding categorical labels using `LabelEncoder` for efficient processing.
- Applying different transformations to images based on whether they belong to minority classes, which helps in handling imbalanced datasets.
- Skipping missing image files and tracking their IDs.
- Providing utility methods to get the number of unique classes and sample counts for each category.

This class is essential for preparing the fashion data for training models, ensuring that the data is properly loaded, labeled, and augmented as needed.
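`__getitem__` calls `inverse_transform` on all four labels of every sample it loads, even though a row's minority status never changes. One option is to compute the flags once up front. The helper below is a hypothetical sketch, not part of the notebook. It assumes the same `minority_classes` keys (`'color'`, `'articleType'`, `'season'`, `'gender'`) and CSV columns used here.

```python
import pandas as pd

# Mapping from minority_classes keys to CSV columns, mirroring the pairs checked in FashionDataset.__getitem__
COLUMN_FOR_KEY = {'color': 'baseColour', 'articleType': 'articleType', 'season': 'season', 'gender': 'gender'}

def compute_minority_mask(df, minority_classes):
    # Hypothetical helper: True for every row whose label is a minority class in at least one category
    mask = pd.Series(False, index=df.index)
    for key, column in COLUMN_FOR_KEY.items():
        mask |= df[column].isin(minority_classes.get(key, []))
    return mask.to_numpy()  # positional array, aligned with iloc[idx]

# Toy rows for illustration only (not taken from styles.csv)
toy = pd.DataFrame({
    'baseColour': ['Black', 'Teal', 'Black'],
    'articleType': ['Tshirts', 'Tshirts', 'Clutches'],
    'season': ['Summer', 'Summer', 'Summer'],
    'gender': ['Men', 'Men', 'Men'],
})
mask = compute_minority_mask(toy, {'color': ['Teal'], 'articleType': ['Clutches']})
print(mask.tolist())  # [False, True, True]
```

Inside `FashionDataset.__init__`, the result could be stored (for example as `self.is_minority`) and read with `self.is_minority[idx]` in `__getitem__`. This replaces four `inverse_transform` calls per sample with a single array lookup.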

In [3]:
class FashionDataset(Dataset):
    # Custom dataset for fashion images, loading images and labels (color, type, season, gender)
    # from a CSV file and image directory. Supports minority class augmentations and skips missing images.
    def __init__(self, csv_data, img_dir, transform=None, minority_classes=None):
        # Initialize the dataset with CSV data, image directory, transformations, and minority classes
        # Args:
        #     csv_data (pd.DataFrame): Metadata with labels
        #     img_dir (str): Path to image files
        #     transform (dict, optional): Transformations for default and minority classes
        #     minority_classes (dict, optional): Minority labels for each category
        self.data = csv_data
        self.img_dir = img_dir
        self.transform = transform
        self.minority_classes = minority_classes or {}  # Default to empty dict if None
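        # IDs of skipped images. Note: with num_workers > 0, DataLoader workers append to their
        # own copies of the dataset, so this list stays empty in the main process.
        self.skipped_ids = []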

        # Initialize label encoders and encode labels for each category
        self.color_encoder = LabelEncoder().fit(self.data['baseColour'])
        self.type_encoder = LabelEncoder().fit(self.data['articleType'])
        self.season_encoder = LabelEncoder().fit(self.data['season'])
        self.gender_encoder = LabelEncoder().fit(self.data['gender'])

        # Transform categorical labels into encoded numerical values
        self.colors = self.color_encoder.transform(self.data['baseColour'])
        self.types = self.type_encoder.transform(self.data['articleType'])
        self.seasons = self.season_encoder.transform(self.data['season'])
        self.genders = self.gender_encoder.transform(self.data['gender'])

    def __len__(self):
        # Return the total number of samples in the dataset
        return len(self.data)

    def __getitem__(self, idx):
        # Retrieve a sample (image and labels) by index
        # Returns image and encoded labels (color, type, season, gender)
        # Skips missing images and applies transformations based on minority class status
        img_id = self.data.iloc[idx]['id']
        img_name = os.path.join(self.img_dir, f"{img_id}.jpg")

        try:
            # Load and convert image to RGB format
            image = Image.open(img_name).convert('RGB')
        except FileNotFoundError:
            # Handle missing image files by skipping them
            print(f"Skipped ID {img_id}: Image not found")
            self.skipped_ids.append(img_id)
            return None

        # Extract encoded labels for the current sample
        color = self.colors[idx]
        type_ = self.types[idx]
        season = self.seasons[idx]
        gender = self.genders[idx]

        # Apply transformations if specified
        if self.transform:
            # Determine if the sample belongs to a minority class
            is_minority = any(
                encoder.inverse_transform([label])[0] in self.minority_classes.get(category, [])
                for encoder, label, category in [
                    (self.color_encoder, color, 'color'),
                    (self.type_encoder, type_, 'articleType'),
                    (self.season_encoder, season, 'season'),
                    (self.gender_encoder, gender, 'gender')
                ]
            )
            # Choose transformation based on minority status
            transform_key = 'minority' if is_minority else 'default'
            image = self.transform[transform_key](image)

        # Return the processed image and its labels
        return image, color, type_, season, gender

    def get_num_classes(self):
        # Return a dictionary with the number of unique classes for each category
        return {
            'color': len(self.color_encoder.classes_),
            'type': len(self.type_encoder.classes_),
            'season': len(self.season_encoder.classes_),
            'gender': len(self.gender_encoder.classes_)
        }

    def get_class_counts(self):
        # Return a dictionary with sample counts for each class in each category
        return {
            'color': np.bincount(self.colors),
            'type': np.bincount(self.types),
            'season': np.bincount(self.seasons),
            'gender': np.bincount(self.genders)
        }

## MultiTaskModel for Fashion Attribute Classification


This code defines the `MultiTaskModel` class, a PyTorch neural network for multi-task classification of fashion item attributes (color, type, season, gender).

### Key Features:
- Uses a ResNet50 backbone pre-trained on ImageNet, with its weights frozen, to extract image features.
- Removes the final fully connected layer (`fc`), so the backbone ends at global average pooling and yields a 2048-dimensional feature vector per image after flattening.
- Employs four linear heads (`color_head`, `type_head`, `season_head`, `gender_head`) for attribute-specific predictions.
- Predicts all four attributes in one forward pass: the backbone features are computed once per image and shared by every head.

### Why Use ResNet50?
ResNet50 is chosen for its depth (50 layers) and its residual (skip) connections, which ease the optimization of deep networks by mitigating vanishing gradients. Because it is pre-trained on ImageNet, it provides strong general-purpose features for fashion images, which reduces training time and the amount of labeled data needed. Other CNNs (e.g., VGG, MobileNet) could be used instead. ResNet50 offers a reasonable trade-off between accuracy and compute, which makes it a common choice for transfer learning in multi-task settings.

### Why Do All Four Heads Take 2048 Inputs?
Each head is a linear layer with 2048 input features and one output per class (e.g., `num_colors`, `num_types`). The 2048 comes from the ResNet50 backbone. With the final `fc` layer removed, global average pooling produces an output of shape `(batch, 2048, 1, 1)`, which is flattened to a 2048-dimensional vector and fed to every head. Matching this dimension is required for compatibility with the backbone, and it lets each head use the full shared feature set.

In [4]:
class MultiTaskModel(nn.Module):
    # Custom PyTorch model for multi-task classification of fashion attributes
    # Predicts color, type, season, and gender using a shared ResNet50 backbone
    def __init__(self, num_colors, num_types, num_seasons, num_genders):
        # Initialize the model with number of classes for each attribute
        super(MultiTaskModel, self).__init__()
        
        # Load pre-trained ResNet50 and remove the final classification layer
        self.base_model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.base_model = nn.Sequential(*list(self.base_model.children())[:-1])
        
        # Freeze backbone parameters to use pre-trained weights without updates
        for param in self.base_model.parameters():
            param.requires_grad = False
            
        # Define task-specific linear layers for each attribute
        self.color_head = nn.Linear(2048, num_colors)   # For color classification
        self.type_head = nn.Linear(2048, num_types)     # For type classification
        self.season_head = nn.Linear(2048, num_seasons) # For season classification
        self.gender_head = nn.Linear(2048, num_genders) # For gender classification

    def forward(self, x):
        # Forward pass to predict all attributes
        # Input x: Batch of images
        # Output: Tuple of predictions for color, type, season, and gender
        features = self.base_model(x)                    # Extract features using ResNet50
        features = features.view(features.size(0), -1)   # Flatten features to 2048-dim vectors
        color_output = self.color_head(features)        # Predict color classes
        type_output = self.type_head(features)          # Predict type classes
        season_output = self.season_head(features)      # Predict season classes
        gender_output = self.gender_head(features)      # Predict gender classes
        return color_output, type_output, season_output, gender_output

## Data Transformations for Fashion Dataset

This code defines a dictionary of transformations (`transform`) used by the `FashionDataset` class to preprocess images. It includes two sets of transformations:

- **Default**: Applied to non-minority samples. It resizes images to 224x224 (the resolution ResNet50 was pre-trained at on ImageNet), applies a random horizontal flip (50% probability), converts images to tensors, and normalizes them with the ImageNet mean and standard deviation.
- **Minority**: Applied to samples that belong to at least one minority class. It uses the same steps as the default pipeline plus a random rotation of up to ±15 degrees, adding diversity to underrepresented classes.

#### Why Resize to 224x224?
Images are resized to 224x224 to match the resolution at which ResNet50 was pre-trained on ImageNet, so the learned features transfer without retraining the backbone. ResNet50 ends in adaptive average pooling and would technically accept other sizes. However, a single fixed size is still needed to stack images into batches. Note that `Resize((224, 224))` does not preserve the aspect ratio, so non-square product images are slightly stretched.

In [5]:
transform = {
    'default': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    'minority': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
}

## Data Loading and Preprocessing

This code handles the loading and preprocessing of the fashion dataset for training the `MultiTaskModel`.

#### Key Steps:
- **Data Loading**: Loads metadata from a CSV file (`styles.csv`) and specifies the image directory.
- **Data Cleaning**: Removes rows with missing values in critical columns (`id`, `baseColour`, `articleType`, `season`, `gender`) to ensure data integrity.
- **Minority Class Definition**: Identifies minority classes for gender, season, color, and article type based on exploratory data analysis (EDA) to apply targeted augmentations in `FashionDataset`.
- **Data Splitting**: Splits the dataset into 85% train+validation and 15% test sets using a fixed random seed for reproducibility.
- **Cross-Validation Setup**: Initializes 5-fold cross-validation to split train+validation data for robust model evaluation.

These steps prepare the dataset for use with the `FashionDataset` class, ensuring clean data and proper splits for training, validation, and testing.
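The split sizes can be checked on a toy DataFrame. This sketch assumes 100 rows with only an `id` column (not the real `styles.csv`). It uses the same `test_size`, `n_splits` and seeds as the notebook.

```python
import pandas as pd
from sklearn.model_selection import train_test_split, KFold

toy = pd.DataFrame({'id': range(100)})  # stand-in for the cleaned metadata
train_val, test = train_test_split(toy, test_size=0.15, random_state=42)

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
fold_sizes = []
seen = []
for train_idx, val_idx in kfold.split(train_val):
    fold_sizes.append((len(train_idx), len(val_idx)))
    seen.extend(train_val.iloc[val_idx]['id'])  # indices are positional, hence .iloc

print(len(train_val), len(test))  # 85 15
print(fold_sizes)                 # [(68, 17), (68, 17), (68, 17), (68, 17), (68, 17)]
```

Every train+validation row lands in exactly one validation fold. `KFold.split` returns positional indices, which is why the training loop selects fold rows with `.iloc`.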

In [6]:
csv_file = '/kaggle/input/fashion-product-images-dataset/fashion-dataset/styles.csv'
img_dir = '/kaggle/input/fashion-product-images-dataset/fashion-dataset/images'
data = pd.read_csv(csv_file, on_bad_lines='skip')
# Handle missing values
data = data.dropna(subset=['id', 'baseColour', 'articleType', 'season', 'gender'])

# Define minority classes (based on eda_report_final.md)
# Count each label once instead of recomputing value_counts() for every unique value
color_counts = data['baseColour'].value_counts()
article_counts = data['articleType'].value_counts()
minority_classes = {
    'gender': ['Girls', 'Unisex', 'Boys'],  # Men: 49.8%, Women: 37.3%, others <6%
    'season': ['Spring'],  # Spring: 3.7%
    'color': color_counts[color_counts < 1000].index.tolist(),  # Colors with <1000 occurrences
    'articleType': article_counts[article_counts < 100].index.tolist()  # Article types with <100 occurrences
}

# Split into train+val (85%) and test (15%)
train_val_data, test_data = train_test_split(data, test_size=0.15, random_state=42)

# Setup K-Fold Cross-Validation (k=5)
kfold = KFold(n_splits=5, shuffle=True, random_state=42)

## DataLoader and Custom Collate Function

This code defines utilities for creating a PyTorch `DataLoader` and handling batches in the `FashionDataset`.

#### Key Components:
- **`create_dataloader`**: Creates a `DataLoader` that batches and loads data from a dataset (e.g., `FashionDataset`). It supports a configurable batch size, shuffling, and parallel loading in `num_workers` worker processes. `pin_memory` is enabled only when CUDA is available, because page-locked host memory speeds up copies to the GPU.
- **`custom_collate_fn`**: A custom collate function that filters out `None` values in a batch (e.g., caused by missing images in `FashionDataset`) before applying the default collation to create tensors.

These utilities ensure efficient and robust data loading for training and evaluation, handling issues like missing images gracefully.
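The filtering behavior can be seen with a toy dataset. In this sketch, the hypothetical `ToyDataset` returns `None` for odd indices to mimic missing images, and `num_workers=0` keeps everything in one process. `custom_collate_fn` is copied from the notebook.

```python
import torch
from torch.utils.data import Dataset, DataLoader, default_collate

class ToyDataset(Dataset):
    # Hypothetical stand-in for FashionDataset: odd indices simulate missing images
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        if idx % 2 == 1:
            return None
        return torch.full((3,), float(idx)), idx  # (fake "image", label)

def custom_collate_fn(batch):
    # Same logic as the notebook: drop None items, then collate the rest
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return default_collate(batch)

loader = DataLoader(ToyDataset(), batch_size=4, shuffle=False, collate_fn=custom_collate_fn, num_workers=0)
batches = [b for b in loader if b is not None]
for images, labels in batches:
    print(images.shape, labels.tolist())
# torch.Size([2, 3]) [0, 2]
# torch.Size([2, 3]) [4, 6]
```

Because `None` items are dropped, a batch can be smaller than `batch_size`. A batch whose items are all missing comes back as `None`, which the training and validation loops skip.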

In [7]:
# Create a DataLoader for the dataset with custom settings
def create_dataloader(dataset, batch_size=32, shuffle=True, num_workers=4):
    # Returns a DataLoader for batching and loading data
    # Args:
    #     dataset: PyTorch dataset (e.g., FashionDataset)
    #     batch_size: Number of samples per batch (default: 32)
    #     shuffle: Whether to shuffle data (default: True)
    #     num_workers: Number of subprocesses for data loading (default: 4)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=custom_collate_fn,  # Use custom collate function to handle None values
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available()  # Page-locked memory speeds up host-to-GPU copies
    )

# Custom collate function to filter out None values in a batch
def custom_collate_fn(batch):
    # Remove None items (e.g., from missing images in FashionDataset)
    batch = [item for item in batch if item is not None]
    # If batch is empty, return None
    if not batch:
        return None
    # Use default collate function to process valid items
    return torch.utils.data.dataloader.default_collate(batch)

## Class Weights, Model, and Training Setup

This code sets up the class weights, model, loss functions, and optimizer for training the `MultiTaskModel`.

#### Key Components:
- **`compute_class_weights`**: Calculates inverse-frequency weights for each class to counter class imbalance. The weights are normalized so that a task's weights sum to its number of classes, which makes the average weight 1.
- **Dataset Initialization**: Creates a `FashionDataset` instance to retrieve the number of classes and class counts for each attribute (color, type, season, gender).
- **Class Weights**: Computes weights for each task to address class imbalance and moves them to the appropriate device (GPU/CPU).
- **Model Initialization**: Instantiates the `MultiTaskModel` with the number of classes for each attribute and moves it to the device.
- **Loss Functions**: Defines `CrossEntropyLoss` for each task with corresponding class weights to prioritize minority classes.
- **Optimizer**: Uses the Adam optimizer with a learning rate of 0.001. Because the backbone is frozen, only the four heads receive gradient updates.

This setup ensures the model is ready for training, with weighted losses to handle class imbalance effectively.
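For a task with $C$ classes and per-class counts $n_c$, the weights are $w_c = C \cdot \frac{1/n_c}{\sum_j 1/n_j}$, ignoring the tiny `1e-6` guard. The sketch below repeats the notebook's function on made-up counts (not dataset statistics) to show the resulting scale.

```python
import numpy as np
import torch

def compute_class_weights(class_counts, num_classes):
    # Same computation as the notebook: inverse frequency, rescaled to sum to num_classes
    weights = 1.0 / (np.array(class_counts) + 1e-6)
    weights = weights / weights.sum() * num_classes
    return torch.tensor(weights, dtype=torch.float32)

# Toy counts: a majority class, a mid-size class and a rare class
w = compute_class_weights([900, 90, 10], 3)
print(w)  # approximately tensor([0.0297, 0.2970, 2.6733])
```

A class with 90 times fewer samples gets 90 times the weight, so its errors dominate the loss. That is the intent, but it can also make training noisier for very rare classes.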

#### Why Use CrossEntropyLoss?
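`CrossEntropyLoss` suits multi-class classification tasks such as color, type, season, and gender. It applies `log_softmax` and the negative log-likelihood loss in one step, so the heads output raw logits. Its `weight` argument scales each sample's loss by the weight of its true class, which is how class imbalance is handled here.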
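This sketch uses random logits and toy weights, not model outputs. It checks that `nn.CrossEntropyLoss` matches `log_softmax` followed by `nll_loss`. It also checks that, with `weight` set, the default `reduction='mean'` divides by the sum of the target classes' weights rather than by the batch size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)              # raw head outputs: 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])
weights = torch.tensor([0.5, 1.0, 2.0])  # toy class weights

ce = nn.CrossEntropyLoss(weight=weights)(logits, targets)
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets, weight=weights)

# Manual weighted mean: sum of weighted per-sample losses / sum of the weights actually used
log_probs = F.log_softmax(logits, dim=1)
w = weights[targets]
manual = -(w * log_probs[torch.arange(4), targets]).sum() / w.sum()

print(torch.allclose(ce, nll), torch.allclose(ce, manual))  # True True
```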

In [8]:
def compute_class_weights(class_counts, num_classes):
    weights = 1.0 / (np.array(class_counts) + 1e-6)  # Inverse frequency
    weights = weights / weights.sum() * num_classes  # Normalize
    return torch.tensor(weights, dtype=torch.float32)

# Initialize dataset to get class counts
full_dataset = FashionDataset(data, img_dir, transform=transform)
num_classes = full_dataset.get_num_classes()
class_counts = full_dataset.get_class_counts()

# Compute class weights for each task
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
color_weights = compute_class_weights(class_counts['color'], num_classes['color']).to(device)
type_weights = compute_class_weights(class_counts['type'], num_classes['type']).to(device)
season_weights = compute_class_weights(class_counts['season'], num_classes['season']).to(device)
gender_weights = compute_class_weights(class_counts['gender'], num_classes['gender']).to(device)

# Initialize model, loss functions, and optimizer
model = MultiTaskModel(
    num_classes['color'],
    num_classes['type'],
    num_classes['season'],
    num_classes['gender']
).to(device)

# Define loss functions with class weights for each task
criterion_color = nn.CrossEntropyLoss(weight=color_weights)
criterion_type = nn.CrossEntropyLoss(weight=type_weights)
criterion_season = nn.CrossEntropyLoss(weight=season_weights)
criterion_gender = nn.CrossEntropyLoss(weight=gender_weights)

# Adam optimizer with learning rate 0.001
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [9]:
num_epochs = 10
k = 5
fold_results = []
fold_metrics = []
best_fold = 0
best_avg_f1 = 0.0
best_model_path = '/kaggle/working/best_model.pth'
tasks = ['color', 'type', 'season', 'gender']

## K-Fold Training and Validation Loop

This code implements the k-fold cross-validation loop for training and evaluating the `MultiTaskModel` on the fashion dataset.

#### Key Components:
- **K-Fold Cross-Validation**: Uses `KFold` (from Cell 5) to split `train_val_data` into training and validation sets for each fold.
- **Dataset and DataLoader**: Creates `FashionDataset` and `DataLoader` instances for training and validation using `create_dataloader` (Cell 6).
- **Model Setup**: Initializes a fresh `MultiTaskModel` and Adam optimizer per fold, with a `ReduceLROnPlateau` scheduler that halves the learning rate when validation loss stops improving.
- **Checkpointing**: Resumes training from the fold's latest checkpoint (if one exists) and saves a checkpoint every 10 epochs. With `num_epochs = 10`, that is a single checkpoint after the final epoch.
- **Training Loop**: Trains the model on each batch by computing the task-specific losses (color, type, season, gender), summing them, backpropagating, and updating parameters.
- **Validation Loop**: Evaluates the model on the validation set, calculating loss and macro-averaged metrics (F1-score, precision, recall) plus accuracy for each task.
- **Early Stopping**: Stops training a fold when validation loss has not improved for 3 consecutive epochs.
- **Model Saving**: Whenever validation loss improves, saves the model if its average macro F1-score across the four tasks is the best seen so far in any fold.
- **Metrics Tracking**: Stores training and validation metrics for analysis.
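The early-stopping rule can be isolated into a small helper. `EarlyStopping` is a hypothetical class, not part of the notebook or PyTorch. It reproduces the loop's `patience`/`counter` logic so its behavior can be checked on a hand-made loss sequence.

```python
class EarlyStopping:
    # Hypothetical helper mirroring the notebook's patience logic on validation loss
    def __init__(self, patience=3):
        self.patience = patience
        self.best = float('inf')
        self.counter = 0

    def step(self, val_loss):
        # Returns True when training should stop
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience

stopper = EarlyStopping(patience=3)
for epoch, loss in enumerate([1.00, 0.90, 0.95, 0.96, 0.97, 0.80], start=1):
    if stopper.step(loss):
        print(f"stop after epoch {epoch}")  # stop after epoch 5
        break
```

Training stops after three consecutive epochs without improvement, even though epoch 6 would have improved. A larger `patience` trades compute for robustness to noisy validation loss.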

#### Why is Macro-Averaged F1-Score the Most Significant Metric?
The macro-averaged F1-score is the primary metric because the classes are imbalanced (e.g., 'Spring' at 3.7% and 'Girls' under 6%). Macro averaging computes F1 for each class and averages those scores with equal weight, so the minority classes (identified in Cell 5) count as much as the majority classes. Accuracy, by contrast, can look high while minority classes are mostly misclassified. F1 also balances precision and recall. The mean of the four tasks' macro F1-scores (`avg_f1`) drives model selection, which makes it well suited to this imbalanced multi-task, multi-class problem.
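A two-class toy example with made-up labels shows the gap between accuracy and macro F1 for a classifier that always predicts the majority class. `zero_division=0` is passed so the never-predicted class scores 0 without a warning.

```python
from sklearn.metrics import accuracy_score, f1_score

# Toy imbalanced labels: nine samples of class 0, one of class 1; the "model" always predicts 0
y_true = [0] * 9 + [1]
y_pred = [0] * 10

acc = accuracy_score(y_true, y_pred)
macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
print(f"accuracy={acc:.3f}, macro F1={macro_f1:.3f}")  # accuracy=0.900, macro F1=0.474
```

Class 0 gets F1 = 18/19 and class 1 gets 0. Their unweighted mean, 9/19 ≈ 0.474, exposes the ignored minority class that 90% accuracy hides.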

This loop ensures robust training and evaluation, optimizing model performance across all folds.
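To compare folds after the loop, `fold_results` can be summarized with pandas. `summarize_folds` is a hypothetical helper. It reports the metrics exactly as stored, and in this loop those are the metrics of the last epoch each fold ran, not of its best epoch.

```python
import pandas as pd

def summarize_folds(fold_results, metrics=('avg_f1', 'val_f1_color', 'val_f1_type', 'val_f1_season', 'val_f1_gender')):
    # Hypothetical helper: mean and sample standard deviation of each metric across folds
    df = pd.DataFrame(fold_results).set_index('fold')
    return df[list(metrics)].agg(['mean', 'std']).T  # one row per metric, columns 'mean' and 'std'
```

After training, `summarize_folds(fold_results)` gives one row per metric. A large standard deviation relative to the mean suggests the score depends heavily on which rows land in the validation fold.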

In [13]:
# Training loop for k-fold cross-validation
for fold, (train_idx, val_idx) in enumerate(kfold.split(train_val_data)):
    # Print current fold number
    print(f"Fold {fold+1}/{k}")
    
    # Split train and validation data for the current fold
    train_data = train_val_data.iloc[train_idx].reset_index(drop=True)
    val_data = train_val_data.iloc[val_idx].reset_index(drop=True)
    
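    # Create train and validation datasets
    # Validation uses a deterministic pipeline (no random flips/rotations) so its metrics are not noisy
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_dataset = FashionDataset(train_data, img_dir, transform=transform, minority_classes=minority_classes)
    val_dataset = FashionDataset(val_data, img_dir, transform={'default': eval_transform, 'minority': eval_transform},
                                 minority_classes=minority_classes)
    # Each FashionDataset fits its own LabelEncoders, so a fold missing a rare class would get shifted
    # label indices. Reuse the encoders fitted on the full dataset so train, validation and the model
    # heads (sized from full_dataset) agree on class indices.
    for ds in (train_dataset, val_dataset):
        ds.color_encoder, ds.type_encoder = full_dataset.color_encoder, full_dataset.type_encoder
        ds.season_encoder, ds.gender_encoder = full_dataset.season_encoder, full_dataset.gender_encoder
        ds.colors = ds.color_encoder.transform(ds.data['baseColour'])
        ds.types = ds.type_encoder.transform(ds.data['articleType'])
        ds.seasons = ds.season_encoder.transform(ds.data['season'])
        ds.genders = ds.gender_encoder.transform(ds.data['gender'])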

    # Create DataLoaders for train and validation
    train_loader = create_dataloader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    val_loader = create_dataloader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

    # Initialize model and optimizer for the current fold
    model = MultiTaskModel(
        num_classes['color'],   # Number of color classes
        num_classes['type'],    # Number of type classes
        num_classes['season'],  # Number of season classes
        num_classes['gender']   # Number of gender classes
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer with learning rate 0.001
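    # Halve the LR after 2 epochs without val-loss improvement. The deprecated `verbose` flag is omitted;
    # the current LR is already printed every 10 batches.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, min_lr=1e-6)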

    # Load latest checkpoint if available
    checkpoint_files = glob.glob(f'/kaggle/working/checkpoint_fold_{fold+1}_epoch_*.pth')
    start_epoch = 0
    best_val_loss = float('inf')
    if checkpoint_files:
        latest_checkpoint = max(checkpoint_files, key=lambda x: int(x.split('_epoch_')[-1].split('.pth')[0]))
        checkpoint = torch.load(latest_checkpoint, map_location=device)  # Load tensors onto the current device
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint['best_val_loss']
        print(f"Resuming Fold {fold+1} from checkpoint {latest_checkpoint} at Epoch {start_epoch}")

    # Initialize metrics storage for this fold
    epoch_metrics = {
        'train_loss': [],
        'val_loss': [],
        'val_f1_color': [], 'val_f1_type': [], 'val_f1_season': [], 'val_f1_gender': [],
        'val_precision_color': [], 'val_precision_type': [], 'val_precision_season': [], 'val_precision_gender': [],
        'val_recall_color': [], 'val_recall_type': [], 'val_recall_season': [], 'val_recall_gender': [],
        'val_accuracy_color': [], 'val_accuracy_type': [], 'val_accuracy_season': [], 'val_accuracy_gender': [],
        'loss_color': [], 'loss_type': [], 'loss_season': [], 'loss_gender': []
    }

    # Training loop for each epoch
    patience = 3
    counter = 0
    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()
        model.train()  # Set model to training mode
        running_loss = 0.0
        running_loss_color = 0.0
        running_loss_type = 0.0
        running_loss_season = 0.0
        running_loss_gender = 0.0
        count = 0
        for batch in train_loader:
            batch_start_time = time.time()
            if batch is None:
                continue
            # Load batch data and move to device
            images, colors, types, seasons, genders = batch
            images, colors, types, seasons, genders = (
                images.to(device), colors.to(device), types.to(device),
                seasons.to(device), genders.to(device)
            )
            optimizer.zero_grad()  # Clear gradients
            # Forward pass
            color_pred, type_pred, season_pred, gender_pred = model(images)
            # Compute loss for each task
            loss_color = criterion_color(color_pred, colors)
            loss_type = criterion_type(type_pred, types)
            loss_season = criterion_season(season_pred, seasons)
            loss_gender = criterion_gender(gender_pred, genders)
            total_loss = loss_color + loss_type + loss_season + loss_gender
            total_loss.backward()  # Backpropagation
            optimizer.step()      # Update model parameters
            # Accumulate losses
            running_loss += total_loss.item()
            running_loss_color += loss_color.item()
            running_loss_type += loss_type.item()
            running_loss_season += loss_season.item()
            running_loss_gender += loss_gender.item()
            count += 1
            batch_time = time.time() - batch_start_time
            if count % 10 == 0:
                print(f"Fold {fold+1}, Epoch {epoch+1}, Batch {count}, Loss: {total_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}, Batch Time: {batch_time:.2f}s")
        # Compute average training loss
        avg_train_loss = running_loss / count if count > 0 else float('inf')
        epoch_metrics['train_loss'].append(avg_train_loss)
        epoch_metrics['loss_color'].append(running_loss_color / count)
        epoch_metrics['loss_type'].append(running_loss_type / count)
        epoch_metrics['loss_season'].append(running_loss_season / count)
        epoch_metrics['loss_gender'].append(running_loss_gender / count)
        epoch_time = time.time() - epoch_start_time
        print(f"Fold {fold+1}, Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Epoch {epoch+1} Time: {epoch_time:.2f}s")

        # Validation phase
        model.eval()  # Set model to evaluation mode
        val_loss = 0.0
        val_count = 0
        val_color_true, val_color_pred = [], []
        val_type_true, val_type_pred = [], []
        val_season_true, val_season_pred = [], []
        val_gender_true, val_gender_pred = [], []
        with torch.no_grad():  # Disable gradient computation
            for batch in val_loader:
                if batch is None:
                    continue
                # Load batch data and move to device
                images, colors, types, seasons, genders = batch
                images, colors, types, seasons, genders = (
                    images.to(device), colors.to(device), types.to(device),
                    seasons.to(device), genders.to(device)
                )
                # Forward pass
                color_pred, type_pred, season_pred, gender_pred = model(images)
                # Compute validation loss
                val_loss += (criterion_color(color_pred, colors) + criterion_type(type_pred, types) +
                             criterion_season(season_pred, seasons) + criterion_gender(gender_pred, genders)).item()
                val_count += 1
                # Store predictions and true labels for metrics
                val_color_true.extend(colors.cpu().numpy())
                val_color_pred.extend(torch.argmax(color_pred, dim=1).cpu().numpy())
                val_type_true.extend(types.cpu().numpy())
                val_type_pred.extend(torch.argmax(type_pred, dim=1).cpu().numpy())
                val_season_true.extend(seasons.cpu().numpy())
                val_season_pred.extend(torch.argmax(season_pred, dim=1).cpu().numpy())
                val_gender_true.extend(genders.cpu().numpy())
                val_gender_pred.extend(torch.argmax(gender_pred, dim=1).cpu().numpy())
        # Compute average validation loss
        avg_val_loss = val_loss / val_count if val_count > 0 else float('inf')
        # Calculate macro-averaged metrics for each task
        val_f1_color = f1_score(val_color_true, val_color_pred, average='macro')
        val_f1_type = f1_score(val_type_true, val_type_pred, average='macro')
        val_f1_season = f1_score(val_season_true, val_season_pred, average='macro')
        val_f1_gender = f1_score(val_gender_true, val_gender_pred, average='macro')
        val_precision_color = precision_score(val_color_true, val_color_pred, average='macro')
        val_precision_type = precision_score(val_type_true, val_type_pred, average='macro')
        val_precision_season = precision_score(val_season_true, val_season_pred, average='macro')
        val_precision_gender = precision_score(val_gender_true, val_gender_pred, average='macro')
        val_recall_color = recall_score(val_color_true, val_color_pred, average='macro')
        val_recall_type = recall_score(val_type_true, val_type_pred, average='macro')
        val_recall_season = recall_score(val_season_true, val_season_pred, average='macro')
        val_recall_gender = recall_score(val_gender_true, val_gender_pred, average='macro')
        val_accuracy_color = accuracy_score(val_color_true, val_color_pred)
        val_accuracy_type = accuracy_score(val_type_true, val_type_pred)
        val_accuracy_season = accuracy_score(val_season_true, val_season_pred)
        val_accuracy_gender = accuracy_score(val_gender_true, val_gender_pred)
        avg_f1 = (val_f1_color + val_f1_type + val_f1_season + val_f1_gender) / 4
        # Store validation metrics
        epoch_metrics['val_loss'].append(avg_val_loss)
        epoch_metrics['val_f1_color'].append(val_f1_color)
        epoch_metrics['val_f1_type'].append(val_f1_type)
        epoch_metrics['val_f1_season'].append(val_f1_season)
        epoch_metrics['val_f1_gender'].append(val_f1_gender)
        epoch_metrics['val_precision_color'].append(val_precision_color)
        epoch_metrics['val_precision_type'].append(val_precision_type)
        epoch_metrics['val_precision_season'].append(val_precision_season)
        epoch_metrics['val_precision_gender'].append(val_precision_gender)
        epoch_metrics['val_recall_color'].append(val_recall_color)
        epoch_metrics['val_recall_type'].append(val_recall_type)
        epoch_metrics['val_recall_season'].append(val_recall_season)
        epoch_metrics['val_recall_gender'].append(val_recall_gender)
        epoch_metrics['val_accuracy_color'].append(val_accuracy_color)
        epoch_metrics['val_accuracy_type'].append(val_accuracy_type)
        epoch_metrics['val_accuracy_season'].append(val_accuracy_season)
        epoch_metrics['val_accuracy_gender'].append(val_accuracy_gender)
        # Print validation metrics
        print(f"Fold {fold+1}, Validation Loss: {avg_val_loss:.4f}")
        print(f"Fold {fold+1}, Validation F1-Scores (Macro) - Color: {val_f1_color:.4f}, Type: {val_f1_type:.4f}, Season: {val_f1_season:.4f}, Gender: {val_f1_gender:.4f}, Avg F1: {avg_f1:.4f}")
        print(f"Fold {fold+1}, Validation Precision (Macro) - Color: {val_precision_color:.4f}, Type: {val_precision_type:.4f}, Season: {val_precision_season:.4f}, Gender: {val_precision_gender:.4f}")
        print(f"Fold {fold+1}, Validation Recall (Macro) - Color: {val_recall_color:.4f}, Type: {val_recall_type:.4f}, Season: {val_recall_season:.4f}, Gender: {val_recall_gender:.4f}")
        print(f"Fold {fold+1}, Validation Accuracy - Color: {val_accuracy_color:.4f}, Type: {val_accuracy_type:.4f}, Season: {val_accuracy_season:.4f}, Gender: {val_accuracy_gender:.4f}")

        # Update learning rate based on validation loss
        scheduler.step(avg_val_loss)

        # Save checkpoint every 10th epoch
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'/kaggle/working/checkpoint_fold_{fold+1}_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss
            }, checkpoint_path)
            print(f"Saved checkpoint for Fold {fold+1} at {checkpoint_path}")

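        # Early stopping: reset the patience counter whenever this fold's validation loss improves.
        # The single best_model.pth (shared across folds) is overwritten only when the loss improves
        # and the average F1 also beats every previously saved checkpoint.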
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            counter = 0
            if avg_f1 > best_avg_f1:
                best_avg_f1 = avg_f1
                best_fold = fold + 1
                torch.save(model.state_dict(), best_model_path)
                print(f"Saved best model from Fold {fold+1} at {best_model_path} (Avg F1: {avg_f1:.4f})")
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping triggered for Fold {fold+1}")
                break

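    # Store results for this fold. The val_* metrics below come from the last epoch that ran
    # (e.g. the epoch that triggered early stopping), not necessarily the epoch with best_val_loss.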
    fold_results.append({
        'fold': fold+1,
        'best_val_loss': best_val_loss,
        'val_f1_color': val_f1_color,
        'val_f1_type': val_f1_type,
        'val_f1_season': val_f1_season,
        'val_f1_gender': val_f1_gender,
        'val_precision_color': val_precision_color,
        'val_precision_type': val_precision_type,
        'val_precision_season': val_precision_season,
        'val_precision_gender': val_precision_gender,
        'val_recall_color': val_recall_color,
        'val_recall_type': val_recall_type,
        'val_recall_season': val_recall_season,
        'val_recall_gender': val_recall_gender,
        'val_accuracy_color': val_accuracy_color,
        'val_accuracy_type': val_accuracy_type,
        'val_accuracy_season': val_accuracy_season,
        'val_accuracy_gender': val_accuracy_gender,
        'avg_f1': avg_f1,
        'skipped_ids_train': train_dataset.skipped_ids,
        'skipped_ids_val': val_dataset.skipped_ids
    })
    fold_metrics.append(epoch_metrics)

Fold 1/5




Fold 1, Epoch 1, Batch 10, Loss: 9.8735, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 1, Batch 20, Loss: 9.8006, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 1, Batch 30, Loss: 8.7932, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 1, Batch 40, Loss: 9.1644, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 1, Batch 50, Loss: 7.4616, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 1, Batch 60, Loss: 7.4593, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 1, Batch 70, Loss: 6.6816, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 1, Batch 80, Loss: 8.2556, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 1, Batch 90, Loss: 6.3210, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 1, Batch 100, Loss: 5.4569, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 1, Batch 110, Loss: 7.2192, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 1, Batch 120, Loss: 5.0299, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 1, Batch 130, Loss: 5.3274, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 1, Batch 140, Loss: 5.6630, LR: 0

Fold 1, Validation Loss: 15.0288
Fold 1, Validation F1-Scores (Macro) - Color: 0.0663, Type: 0.0188, Season: 0.6244, Gender: 0.6773, Avg F1: 0.3467
Fold 1, Validation Precision (Macro) - Color: 0.0798, Type: 0.0314, Season: 0.5984, Gender: 0.6297
Fold 1, Validation Recall (Macro) - Color: 0.0751, Type: 0.0158, Season: 0.6919, Gender: 0.7953
Fold 1, Validation Accuracy - Color: 0.2202, Type: 0.0131, Season: 0.6026, Gender: 0.8265
Saved best model from Fold 1 at /kaggle/working/best_model.pth (Avg F1: 0.3467)
Fold 1, Epoch 2, Batch 10, Loss: 3.1974, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 2, Batch 20, Loss: 3.6017, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 2, Batch 30, Loss: 2.0701, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 2, Batch 40, Loss: 5.2835, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 2, Batch 50, Loss: 4.4174, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 2, Batch 60, Loss: 3.5223, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 2, Batch 70, Loss: 3.3261, LR: 0.00

Fold 1, Epoch 3, Batch 10, Loss: 2.6097, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 20, Loss: 2.7511, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 30, Loss: 2.6516, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 40, Loss: 2.7273, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 50, Loss: 2.9999, LR: 0.001000, Batch Time: 0.11s
Fold 1, Epoch 3, Batch 60, Loss: 2.1630, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 70, Loss: 3.6720, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 80, Loss: 3.6472, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 3, Batch 90, Loss: 2.3933, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 100, Loss: 2.2442, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 3, Batch 110, Loss: 3.7841, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 120, Loss: 3.1039, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 130, Loss: 2.7479, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 3, Batch 140, Loss: 3.1748, LR: 0

Fold 1, Epoch 4, Batch 10, Loss: 2.4068, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 4, Batch 20, Loss: 3.1775, LR: 0.001000, Batch Time: 0.13s
Fold 1, Epoch 4, Batch 30, Loss: 2.8376, LR: 0.001000, Batch Time: 0.11s
Fold 1, Epoch 4, Batch 40, Loss: 3.8021, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 4, Batch 50, Loss: 2.8274, LR: 0.001000, Batch Time: 0.11s
Fold 1, Epoch 4, Batch 60, Loss: 3.2732, LR: 0.001000, Batch Time: 0.12s
Skipped ID 39403: Image file /kaggle/input/fashion-product-images-dataset/fashion-dataset/images/39403.jpg not found
Fold 1, Epoch 4, Batch 70, Loss: 2.7063, LR: 0.001000, Batch Time: 0.11s
Fold 1, Epoch 4, Batch 80, Loss: 2.9099, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 4, Batch 90, Loss: 2.5356, LR: 0.001000, Batch Time: 0.11s
Fold 1, Epoch 4, Batch 100, Loss: 3.1718, LR: 0.001000, Batch Time: 0.12s
Fold 1, Epoch 4, Batch 110, Loss: 2.0863, LR: 0.001000, Batch Time: 0.11s
Fold 1, Epoch 4, Batch 120, Loss: 2.7202, LR: 0.001000, Batch Time: 0.13s

Fold 1, Validation Loss: 20.6303
Fold 1, Validation F1-Scores (Macro) - Color: 0.0817, Type: 0.0328, Season: 0.6667, Gender: 0.7146, Avg F1: 0.3739
Fold 1, Validation Precision (Macro) - Color: 0.0922, Type: 0.0377, Season: 0.6377, Gender: 0.6564
Fold 1, Validation Recall (Macro) - Color: 0.1234, Type: 0.0385, Season: 0.7168, Gender: 0.8197
Fold 1, Validation Accuracy - Color: 0.2603, Type: 0.0182, Season: 0.6497, Gender: 0.8586
Early stopping triggered for Fold 1
Fold 2/5




Fold 2, Epoch 1, Batch 10, Loss: 9.6210, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 1, Batch 20, Loss: 7.3753, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 1, Batch 30, Loss: 8.0572, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 1, Batch 40, Loss: 8.0563, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 1, Batch 50, Loss: 7.4747, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 1, Batch 60, Loss: 8.0587, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 1, Batch 70, Loss: 8.3885, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 1, Batch 80, Loss: 5.5247, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 1, Batch 90, Loss: 5.8194, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 1, Batch 100, Loss: 6.0073, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 1, Batch 110, Loss: 5.5881, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 1, Batch 120, Loss: 7.1520, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 1, Batch 130, Loss: 8.1553, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 1, Batch 140, Loss: 5.7137, LR: 0

Fold 2, Validation Loss: 18.2874
Fold 2, Validation F1-Scores (Macro) - Color: 0.0663, Type: 0.0198, Season: 0.6328, Gender: 0.6733, Avg F1: 0.3481
Fold 2, Validation Precision (Macro) - Color: 0.0684, Type: 0.0271, Season: 0.6036, Gender: 0.6162
Fold 2, Validation Recall (Macro) - Color: 0.0869, Type: 0.0325, Season: 0.6900, Gender: 0.7864
Fold 2, Validation Accuracy - Color: 0.2426, Type: 0.0146, Season: 0.6267, Gender: 0.8424
Saved best model from Fold 2 at /kaggle/working/best_model.pth (Avg F1: 0.3481)
Fold 2, Epoch 2, Batch 10, Loss: 3.9077, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 2, Batch 20, Loss: 3.7965, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 2, Batch 30, Loss: 2.6804, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 2, Batch 40, Loss: 3.7320, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 2, Batch 50, Loss: 2.9686, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 2, Batch 60, Loss: 3.9903, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 2, Batch 70, Loss: 3.1025, LR: 0.00

Fold 2, Epoch 3, Batch 10, Loss: 3.7274, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 3, Batch 20, Loss: 2.9910, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 3, Batch 30, Loss: 2.4251, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 3, Batch 40, Loss: 2.9676, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 3, Batch 50, Loss: 4.6254, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 3, Batch 60, Loss: 2.4867, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 3, Batch 70, Loss: 3.2761, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 3, Batch 80, Loss: 2.0207, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 3, Batch 90, Loss: 3.3797, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 3, Batch 100, Loss: 2.6337, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 3, Batch 110, Loss: 2.3788, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 3, Batch 120, Loss: 3.3205, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 3, Batch 130, Loss: 4.8689, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 3, Batch 140, Loss: 3.9754, LR: 0

Fold 2, Epoch 4, Batch 10, Loss: 1.8573, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 4, Batch 20, Loss: 2.7294, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 4, Batch 30, Loss: 3.2501, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 4, Batch 40, Loss: 2.5443, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 4, Batch 50, Loss: 2.2191, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 4, Batch 60, Loss: 2.5408, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 4, Batch 70, Loss: 2.6901, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 4, Batch 80, Loss: 3.0441, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 4, Batch 90, Loss: 2.6443, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 4, Batch 100, Loss: 2.9463, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 4, Batch 110, Loss: 2.7030, LR: 0.001000, Batch Time: 0.12s
Fold 2, Epoch 4, Batch 120, Loss: 2.3439, LR: 0.001000, Batch Time: 0.11s
Fold 2, Epoch 4, Batch 130, Loss: 2.4454, LR: 0.001000, Batch Time: 0.13s
Fold 2, Epoch 4, Batch 140, Loss: 3.7905, LR: 0

Fold 2, Validation Loss: 24.0192
Fold 2, Validation F1-Scores (Macro) - Color: 0.0875, Type: 0.0280, Season: 0.6700, Gender: 0.7213, Avg F1: 0.3767
Fold 2, Validation Precision (Macro) - Color: 0.0896, Type: 0.0296, Season: 0.6400, Gender: 0.6658
Fold 2, Validation Recall (Macro) - Color: 0.1130, Type: 0.0331, Season: 0.7220, Gender: 0.8244
Fold 2, Validation Accuracy - Color: 0.2455, Type: 0.0198, Season: 0.6514, Gender: 0.8594
Early stopping triggered for Fold 2
Fold 3/5




Fold 3, Epoch 1, Batch 10, Loss: 10.2857, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 1, Batch 20, Loss: 9.0265, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 1, Batch 30, Loss: 8.1291, LR: 0.001000, Batch Time: 0.13s
Skipped ID 39425: Image file /kaggle/input/fashion-product-images-dataset/fashion-dataset/images/39425.jpg not found
Fold 3, Epoch 1, Batch 40, Loss: 7.3091, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 1, Batch 50, Loss: 6.1402, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 1, Batch 60, Loss: 6.3885, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 1, Batch 70, Loss: 6.7217, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 1, Batch 80, Loss: 8.7025, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 1, Batch 90, Loss: 11.3368, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 1, Batch 100, Loss: 5.5931, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 1, Batch 110, Loss: 5.7603, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 1, Batch 120, Loss: 5.2706, LR: 0.001000, Batch Time: 0.12s

Fold 3, Validation Loss: 17.4834
Fold 3, Validation F1-Scores (Macro) - Color: 0.0704, Type: 0.0335, Season: 0.6479, Gender: 0.6884, Avg F1: 0.3601
Fold 3, Validation Precision (Macro) - Color: 0.0856, Type: 0.0447, Season: 0.6195, Gender: 0.6287
Fold 3, Validation Recall (Macro) - Color: 0.0899, Type: 0.0367, Season: 0.6999, Gender: 0.8025
Fold 3, Validation Accuracy - Color: 0.2200, Type: 0.0370, Season: 0.6327, Gender: 0.8398
Saved best model from Fold 3 at /kaggle/working/best_model.pth (Avg F1: 0.3601)
Fold 3, Epoch 2, Batch 10, Loss: 3.4443, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 2, Batch 20, Loss: 7.4697, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 2, Batch 30, Loss: 3.8788, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 2, Batch 40, Loss: 4.2115, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 2, Batch 50, Loss: 3.1568, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 2, Batch 60, Loss: 3.7573, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 2, Batch 70, Loss: 3.5892, LR: 0.00

Fold 3, Epoch 3, Batch 10, Loss: 2.8282, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 3, Batch 20, Loss: 3.8507, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 3, Batch 30, Loss: 2.8722, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 3, Batch 40, Loss: 2.8287, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 3, Batch 50, Loss: 3.4391, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 3, Batch 60, Loss: 3.1066, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 3, Batch 70, Loss: 2.7286, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 3, Batch 80, Loss: 3.5345, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 3, Batch 90, Loss: 3.4188, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 3, Batch 100, Loss: 3.7362, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 3, Batch 110, Loss: 3.4286, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 3, Batch 120, Loss: 2.3021, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 3, Batch 130, Loss: 3.4234, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 3, Batch 140, Loss: 3.3597, LR: 0

Fold 3, Epoch 4, Batch 10, Loss: 2.8885, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 4, Batch 20, Loss: 2.4011, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 4, Batch 30, Loss: 2.6797, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 4, Batch 40, Loss: 2.5674, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 4, Batch 50, Loss: 3.3029, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 4, Batch 60, Loss: 2.5515, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 4, Batch 70, Loss: 4.3022, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 4, Batch 80, Loss: 3.8097, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 4, Batch 90, Loss: 2.4905, LR: 0.001000, Batch Time: 0.13s
Fold 3, Epoch 4, Batch 100, Loss: 2.3630, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 4, Batch 110, Loss: 2.9113, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 4, Batch 120, Loss: 2.7122, LR: 0.001000, Batch Time: 0.12s
Fold 3, Epoch 4, Batch 130, Loss: 2.8492, LR: 0.001000, Batch Time: 0.11s
Fold 3, Epoch 4, Batch 140, Loss: 3.4946, LR: 0

Fold 3, Validation Loss: 23.1391
Fold 3, Validation F1-Scores (Macro) - Color: 0.0828, Type: 0.0404, Season: 0.6605, Gender: 0.7202, Avg F1: 0.3760
Fold 3, Validation Precision (Macro) - Color: 0.0958, Type: 0.0496, Season: 0.6314, Gender: 0.6597
Fold 3, Validation Recall (Macro) - Color: 0.0983, Type: 0.0384, Season: 0.7196, Gender: 0.8277
Fold 3, Validation Accuracy - Color: 0.2538, Type: 0.0355, Season: 0.6390, Gender: 0.8589
Early stopping triggered for Fold 3
Fold 4/5




Fold 4, Epoch 1, Batch 10, Loss: 10.0544, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 1, Batch 20, Loss: 9.5584, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 1, Batch 30, Loss: 8.9005, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 1, Batch 40, Loss: 7.6096, LR: 0.001000, Batch Time: 0.11s
Fold 4, Epoch 1, Batch 50, Loss: 7.3650, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 1, Batch 60, Loss: 7.5503, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 1, Batch 70, Loss: 7.1494, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 1, Batch 80, Loss: 6.8752, LR: 0.001000, Batch Time: 0.11s
Fold 4, Epoch 1, Batch 90, Loss: 5.2513, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 1, Batch 100, Loss: 6.7270, LR: 0.001000, Batch Time: 0.11s
Fold 4, Epoch 1, Batch 110, Loss: 6.6401, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 1, Batch 120, Loss: 6.5523, LR: 0.001000, Batch Time: 0.11s
Fold 4, Epoch 1, Batch 130, Loss: 6.9701, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 1, Batch 140, Loss: 5.4575, LR: 

Fold 4, Validation Loss: 14.2334
Fold 4, Validation F1-Scores (Macro) - Color: 0.2423, Type: 0.0317, Season: 0.6194, Gender: 0.6785, Avg F1: 0.3930
Fold 4, Validation Precision (Macro) - Color: 0.2491, Type: 0.0409, Season: 0.5913, Gender: 0.6222
Fold 4, Validation Recall (Macro) - Color: 0.3235, Type: 0.0359, Season: 0.6890, Gender: 0.7972
Fold 4, Validation Accuracy - Color: 0.3985, Type: 0.0274, Season: 0.6102, Gender: 0.8347
Saved best model from Fold 4 at /kaggle/working/best_model.pth (Avg F1: 0.3930)
Fold 4, Epoch 2, Batch 10, Loss: 3.1843, LR: 0.001000, Batch Time: 0.11s
Skipped ID 39401: Image file /kaggle/input/fashion-product-images-dataset/fashion-dataset/images/39401.jpg not found
Fold 4, Epoch 2, Batch 20, Loss: 4.8901, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 2, Batch 30, Loss: 3.2312, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 2, Batch 40, Loss: 1.7118, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 2, Batch 50, Loss: 2.3412, LR: 0.001000, Batch Time: 0.13s

Fold 4, Epoch 3, Batch 10, Loss: 2.8384, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 20, Loss: 2.9484, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 30, Loss: 3.2086, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 40, Loss: 3.0102, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 50, Loss: 3.2715, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 60, Loss: 3.2649, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 70, Loss: 3.6307, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 80, Loss: 3.6794, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 90, Loss: 3.4608, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 100, Loss: 2.1700, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 3, Batch 110, Loss: 2.5800, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 3, Batch 120, Loss: 3.1860, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 130, Loss: 5.2150, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 3, Batch 140, Loss: 2.5724, LR: 0

Fold 4, Epoch 4, Batch 10, Loss: 2.1555, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 4, Batch 20, Loss: 2.2021, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 4, Batch 30, Loss: 2.8285, LR: 0.001000, Batch Time: 0.11s
Fold 4, Epoch 4, Batch 40, Loss: 3.7260, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 4, Batch 50, Loss: 3.4710, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 4, Batch 60, Loss: 3.0944, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 4, Batch 70, Loss: 3.9838, LR: 0.001000, Batch Time: 0.11s
Fold 4, Epoch 4, Batch 80, Loss: 4.1849, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 4, Batch 90, Loss: 2.3927, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 4, Batch 100, Loss: 3.1969, LR: 0.001000, Batch Time: 0.12s
Fold 4, Epoch 4, Batch 110, Loss: 2.5287, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 4, Batch 120, Loss: 3.0185, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 4, Batch 130, Loss: 2.3231, LR: 0.001000, Batch Time: 0.13s
Fold 4, Epoch 4, Batch 140, Loss: 2.4597, LR: 0

Fold 4, Validation Loss: 19.6364
Fold 4, Validation F1-Scores (Macro) - Color: 0.2969, Type: 0.0420, Season: 0.6583, Gender: 0.7026, Avg F1: 0.4249
Fold 4, Validation Precision (Macro) - Color: 0.2858, Type: 0.0449, Season: 0.6290, Gender: 0.6421
Fold 4, Validation Recall (Macro) - Color: 0.4201, Type: 0.0453, Season: 0.7089, Gender: 0.8155
Fold 4, Validation Accuracy - Color: 0.4657, Type: 0.0334, Season: 0.6485, Gender: 0.8534
Early stopping triggered for Fold 4
Fold 5/5




Fold 5, Epoch 1, Batch 10, Loss: 10.2194, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 1, Batch 20, Loss: 8.4697, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 1, Batch 30, Loss: 7.6328, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 1, Batch 40, Loss: 7.2539, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 1, Batch 50, Loss: 8.1824, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 1, Batch 60, Loss: 6.3494, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 1, Batch 70, Loss: 7.5950, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 1, Batch 80, Loss: 7.3437, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 1, Batch 90, Loss: 7.3042, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 1, Batch 100, Loss: 5.7118, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 1, Batch 110, Loss: 7.0501, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 1, Batch 120, Loss: 4.7708, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 1, Batch 130, Loss: 5.2638, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 1, Batch 140, Loss: 5.9792, LR: 

Fold 5, Validation Loss: 12.3533
Fold 5, Validation F1-Scores (Macro) - Color: 0.2714, Type: 0.0307, Season: 0.6355, Gender: 0.6834, Avg F1: 0.4053
Fold 5, Validation Precision (Macro) - Color: 0.2693, Type: 0.0361, Season: 0.6143, Gender: 0.6267
Fold 5, Validation Recall (Macro) - Color: 0.3404, Type: 0.0290, Season: 0.7030, Gender: 0.7899
Fold 5, Validation Accuracy - Color: 0.5010, Type: 0.0317, Season: 0.6062, Gender: 0.8460
Saved best model from Fold 5 at /kaggle/working/best_model.pth (Avg F1: 0.4053)
Fold 5, Epoch 2, Batch 10, Loss: 2.7746, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 2, Batch 20, Loss: 3.2333, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 2, Batch 30, Loss: 5.0225, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 2, Batch 40, Loss: 3.7041, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 2, Batch 50, Loss: 3.4939, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 2, Batch 60, Loss: 2.8950, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 2, Batch 70, Loss: 6.9072, LR: 0.00

Fold 5, Epoch 3, Batch 10, Loss: 2.1855, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 20, Loss: 2.8926, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 30, Loss: 2.4129, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 40, Loss: 2.5849, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 50, Loss: 3.9259, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 60, Loss: 1.8808, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 70, Loss: 2.3532, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 80, Loss: 2.5266, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 90, Loss: 2.4777, LR: 0.001000, Batch Time: 0.13s
Fold 5, Epoch 3, Batch 100, Loss: 2.5939, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 110, Loss: 3.4952, LR: 0.001000, Batch Time: 0.13s
Fold 5, Epoch 3, Batch 120, Loss: 3.0998, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 130, Loss: 4.3076, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 3, Batch 140, Loss: 2.6010, LR: 0

Fold 5, Epoch 4, Batch 10, Loss: 3.1778, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 4, Batch 20, Loss: 3.0147, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 4, Batch 30, Loss: 4.6906, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 4, Batch 40, Loss: 1.8728, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 4, Batch 50, Loss: 1.8967, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 4, Batch 60, Loss: 2.8500, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 4, Batch 70, Loss: 3.2646, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 4, Batch 80, Loss: 3.1452, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 4, Batch 90, Loss: 1.8550, LR: 0.001000, Batch Time: 0.13s
Fold 5, Epoch 4, Batch 100, Loss: 2.7099, LR: 0.001000, Batch Time: 0.12s
Fold 5, Epoch 4, Batch 110, Loss: 2.8721, LR: 0.001000, Batch Time: 0.11s
Fold 5, Epoch 4, Batch 120, Loss: 1.9094, LR: 0.001000, Batch Time: 0.12s
Skipped ID 12347: Image file /kaggle/input/fashion-product-images-dataset/fashion-dataset/images/12347.jpg not found

### Plotting Metrics for K-Fold Cross-Validation

This code plots the performance metrics of the `MultiTaskModel` for each fold of the k-fold cross-validation. Every figure is saved as a PNG under `/kaggle/working/` and then closed, so nothing is displayed inline.

#### Key Components:
- **Loss Curves**: Plots training and validation loss for each fold, showing model convergence over epochs.
- **Task Metrics**: For each task (color, type, season, gender), plots the macro-averaged F1-score, precision, and recall together with accuracy over epochs; comparing the macro scores with accuracy reveals whether minority classes (e.g., 'Spring', 'Girls') are actually being learned.
- **Loss Contribution**: Creates a bar plot of each task's average loss (color, type, season, gender) for each fold. A higher bar suggests a harder task, although raw cross-entropy values are not directly comparable between tasks with different numbers of classes.
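One caveat when reading the loss-contribution bars: the cross-entropy of a model that assigns equal probability to all K classes is ln K, so a task with many classes starts from a higher loss than a binary one. The minimal sketch below checks this with PyTorch's `nn.CrossEntropyLoss`; the helper name `uniform_ce_baseline` and the class counts in the loop are illustrative assumptions, not values taken from this dataset.

```python
import math

import torch
import torch.nn as nn


def uniform_ce_baseline(num_classes: int, batch_size: int = 8) -> float:
    """Cross-entropy loss of a model that gives every class the same probability."""
    logits = torch.zeros(batch_size, num_classes)            # equal logits -> uniform softmax
    targets = torch.randint(0, num_classes, (batch_size,))   # loss is identical for any labels
    return nn.CrossEntropyLoss()(logits, targets).item()


# Hypothetical class counts for illustration; the real per-task counts would
# come from the dataset's label encoders.
for k in (2, 4, 40):
    print(f"K={k:>2}: uniform-prediction loss={uniform_ce_baseline(k):.4f}, ln K={math.log(k):.4f}")
```

Dividing each task's average loss by ln K for that task is one possible way to put the bars on a comparable scale; it is a suggested normalization, not something this notebook's plotting code does.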

The task-metric plots emphasize the macro-averaged F1-score because it weights every class equally, which makes it the most informative of these metrics for imbalanced labels: a model that ignores a minority class is penalized even when its accuracy stays high. Together, these visualizations help assess model performance and identify areas for improvement across folds.
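To see why macro averaging matters for imbalanced labels, consider a toy case (made-up labels, not from the fashion dataset) where a classifier always predicts the majority class. Accuracy looks respectable, while the macro scores expose that the minority class is never recovered. The sketch also passes `zero_division=0`, which tells scikit-learn to score a class with no predicted samples as 0; without it, `precision_score` (and, depending on the scikit-learn version, `f1_score`) emits an `UndefinedMetricWarning`, which is expected whenever the model never predicts some class on the validation set.

```python
from sklearn.metrics import accuracy_score, f1_score, precision_score

# Toy labels, not from the fashion dataset: 8 samples of a majority class (0),
# 2 samples of a minority class (1), and a classifier that always predicts 0.
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 10

acc = accuracy_score(y_true, y_pred)
# zero_division=0 scores the never-predicted class as 0 instead of warning
macro_precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
weighted_f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

print(f"accuracy={acc:.4f}, macro precision={macro_precision:.4f}, "
      f"macro F1={macro_f1:.4f}, weighted F1={weighted_f1:.4f}")
# prints: accuracy=0.8000, macro precision=0.4000, macro F1=0.4444, weighted F1=0.7111
```

The weighted F1 sits close to accuracy because it is dominated by the majority class, whereas the macro F1 is halved by the minority class scoring 0.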

In [14]:
# Plot metrics for each fold
for fold, epoch_metrics in enumerate(fold_metrics):
    # Define epoch range for plotting
    epochs = range(1, len(epoch_metrics['train_loss']) + 1)
    
    # Plot training and validation loss curves
    plt.figure(figsize=(12, 8))
    plt.plot(epochs, epoch_metrics['train_loss'], label='Train Loss', color='#1f77b4')
    plt.plot(epochs, epoch_metrics['val_loss'], label='Validation Loss', color='#ff7f0e')
    plt.title(f'Fold {fold+1} Loss Curve')  # Set title for loss plot
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'/kaggle/working/fold_{fold+1}_loss_curve.png')  # Save loss plot
    plt.close()

    # Plot metrics (F1, precision, recall, accuracy) for each task
    for task in tasks:
        plt.figure(figsize=(12, 8))
        plt.plot(epochs, epoch_metrics[f'val_f1_{task}'], label='F1-Score (Macro)', color='#1f77b4')
        plt.plot(epochs, epoch_metrics[f'val_precision_{task}'], label='Precision (Macro)', color='#ff7f0e')
        plt.plot(epochs, epoch_metrics[f'val_recall_{task}'], label='Recall (Macro)', color='#2ca02c')
        plt.plot(epochs, epoch_metrics[f'val_accuracy_{task}'], label='Accuracy', color='#d62728')
        plt.title(f'Fold {fold+1} {task.capitalize()} Metrics')  # Set title for task metrics
        plt.xlabel('Epoch')
        plt.ylabel('Score')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'/kaggle/working/fold_{fold+1}_{task}_metrics.png')  # Save task metrics plot
        plt.close()

    # Plot average loss contribution per task; building the list from `tasks`
    # keeps each bar aligned with its task label
    plt.figure(figsize=(10, 6))
    loss_contributions = [np.mean(epoch_metrics[f'loss_{task}']) for task in tasks]
    plt.bar(tasks, loss_contributions, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'])
    plt.title(f'Fold {fold+1} Average Loss Contribution per Task')  # Set title for loss contribution
    plt.xlabel('Task')
    plt.ylabel('Average Loss')
    plt.savefig(f'/kaggle/working/fold_{fold+1}_loss_contribution.png')  # Save loss contribution plot
    plt.close()

### Fold Results Summary and Cross-Fold Visualizations

This code summarizes the k-fold cross-validation results for the `MultiTaskModel` and visualizes average metrics across folds.

#### Key Components:
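- **Results Summary**: Prints per-fold metrics from the last epoch run in each fold, including best validation loss, macro-averaged F1-scores, precision, recall, and accuracy for each task (color, type, season, gender), plus the image IDs that `FashionDataset` skipped. It also reports the best fold, i.e. the fold whose saved checkpoint had the highest average F1-score (a checkpoint is saved only at an epoch where that fold's validation loss also improved).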
- **Average Metrics Calculation**: Computes mean metrics (F1, precision, recall, accuracy, loss) across epochs for each fold and task.
- **Metrics Plot**: Plots task-specific F1-score, precision, recall, and accuracy across folds to show performance consistency.
- **Loss Contribution Plot**: Creates a bar plot of average loss per task across folds, highlighting tasks with higher difficulty.
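A compact way to get the cross-fold view described above is to load `fold_results` into a pandas DataFrame and take the mean and standard deviation of each metric across folds. The sketch below assumes each entry has the keys built in the training loop (`fold`, `best_val_loss`, `avg_f1`, and `val_<metric>_<task>`); the function name `summarize_folds` is hypothetical, and the toy input contains made-up numbers used only to show the output shape. Note that pandas' `std` is the sample standard deviation (ddof=1) by default.

```python
import pandas as pd


def summarize_folds(fold_results, tasks):
    """Mean and standard deviation of each per-fold validation metric.

    Assumes every entry of `fold_results` is a dict with the keys built in the
    training loop; list-valued fields such as the skipped IDs are dropped.
    """
    metric_cols = ['best_val_loss', 'avg_f1'] + [
        f'val_{metric}_{task}'
        for task in tasks
        for metric in ('f1', 'precision', 'recall', 'accuracy')
    ]
    df = pd.DataFrame(fold_results).set_index('fold')[metric_cols]
    # Rows become metrics, columns become the aggregate statistics
    return df.agg(['mean', 'std']).T


# Toy input with made-up numbers (one task only), just to show the output shape
toy_results = [
    {'fold': 1, 'best_val_loss': 2.0, 'avg_f1': 0.50,
     'val_f1_color': 0.20, 'val_precision_color': 0.30,
     'val_recall_color': 0.40, 'val_accuracy_color': 0.50,
     'skipped_ids_train': [101], 'skipped_ids_val': []},
    {'fold': 2, 'best_val_loss': 4.0, 'avg_f1': 0.70,
     'val_f1_color': 0.40, 'val_precision_color': 0.50,
     'val_recall_color': 0.60, 'val_accuracy_color': 0.70,
     'skipped_ids_train': [], 'skipped_ids_val': [202]},
]
summary = summarize_folds(toy_results, ['color'])
print(summary.round(4))
```

With the notebook's own data, the call would be `summarize_folds(fold_results, tasks)`; the `std` column gives a quick read on how consistent each metric is across folds.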



In [31]:
# Print summary of results for each fold
print("\nFold Results Summary:")
for result in fold_results:
    # Display best validation loss and average F1-score for the fold
    print(f"Fold {result['fold']}: Best Validation Loss: {result['best_val_loss']:.4f}, Avg F1: {result['avg_f1']:.4f}")
    # Display macro-averaged F1-scores for each task
    print(f"Fold {result['fold']}: Validation F1-Scores (Macro) - Color: {result['val_f1_color']:.4f}, Type: {result['val_f1_type']:.4f}, Season: {result['val_f1_season']:.4f}, Gender: {result['val_f1_gender']:.4f}")
    # Display macro-averaged precision for each task
    print(f"Fold {result['fold']}: Validation Precision (Macro) - Color: {result['val_precision_color']:.4f}, Type: {result['val_precision_type']:.4f}, Season: {result['val_precision_season']:.4f}, Gender: {result['val_precision_gender']:.4f}")
    # Display macro-averaged recall for each task
    print(f"Fold {result['fold']}: Validation Recall (Macro) - Color: {result['val_recall_color']:.4f}, Type: {result['val_recall_type']:.4f}, Season: {result['val_recall_season']:.4f}, Gender: {result['val_recall_gender']:.4f}")
    # Display accuracy for each task
    print(f"Fold {result['fold']}: Validation Accuracy - Color: {result['val_accuracy_color']:.4f}, Type: {result['val_accuracy_type']:.4f}, Season: {result['val_accuracy_season']:.4f}, Gender: {result['val_accuracy_gender']:.4f}")
# Print the best fold and its average F1-score
print(f"\nBest Fold: {best_fold} with Avg F1-Score: {best_avg_f1:.4f}")

# Initialize dictionary to store per-fold averages of each task-specific metric
avg_metrics = {f'loss_{task}': [] for task in tasks}
for task in tasks:
    # Add task-specific validation metrics to dictionary
    avg_metrics[f'val_f1_{task}'] = []
    avg_metrics[f'val_precision_{task}'] = []
    avg_metrics[f'val_recall_{task}'] = []
    avg_metrics[f'val_accuracy_{task}'] = []

# For each fold, average every tracked metric over its epochs
for fold in range(k):
    for metric in avg_metrics:
        avg_metrics[metric].append(np.mean(fold_metrics[fold][metric]))

# Plot average metrics across folds
plt.figure(figsize=(12, 8))
for metric in ['val_f1', 'val_precision', 'val_recall', 'val_accuracy']:
    for task in tasks:
        # Plot task-specific metric across folds
        plt.plot(range(1, k+1), avg_metrics[f'{metric}_{task}'], label=f"{metric.replace('val_', '').capitalize()} {task.capitalize()}")
plt.title('Average Metrics Across Folds')  # Set title for metrics plot
plt.xlabel('Fold')
plt.ylabel('Score')
plt.legend()
plt.grid(True)
plt.savefig('/kaggle/working/average_metrics_across_folds.png')  # Save metrics plot
plt.close()

# Plot average loss contribution across folds
plt.figure(figsize=(10, 6))
plt.bar(tasks, [np.mean(avg_metrics[f'loss_{task}']) for task in tasks], color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'])
plt.title('Average Loss Contribution Across Folds')  # Set title for loss contribution
plt.xlabel('Task')
plt.ylabel('Average Loss')
plt.savefig('/kaggle/working/average_loss_contribution.png')  # Save loss contribution plot
plt.close()


Fold Results Summary:
Fold 1: Best Validation Loss: 15.0288, Avg F1: 0.3739
Fold 1: Validation F1-Scores (Macro) - Color: 0.0817, Type: 0.0328, Season: 0.6667, Gender: 0.7146
Fold 1: Validation Precision (Macro) - Color: 0.0922, Type: 0.0377, Season: 0.6377, Gender: 0.6564
Fold 1: Validation Recall (Macro) - Color: 0.1234, Type: 0.0385, Season: 0.7168, Gender: 0.8197
Fold 1: Validation Accuracy - Color: 0.2603, Type: 0.0182, Season: 0.6497, Gender: 0.8586
Fold 2: Best Validation Loss: 18.2874, Avg F1: 0.3767
Fold 2: Validation F1-Scores (Macro) - Color: 0.0875, Type: 0.0280, Season: 0.6700, Gender: 0.7213
Fold 2: Validation Precision (Macro) - Color: 0.0896, Type: 0.0296, Season: 0.6400, Gender: 0.6658
Fold 2: Validation Recall (Macro) - Color: 0.1130, Type: 0.0331, Season: 0.7220, Gender: 0.8244
Fold 2: Validation Accuracy - Color: 0.2455, Type: 0.0198, Season: 0.6514, Gender: 0.8594
Fold 3: Best Validation Loss: 17.4834, Avg F1: 0.3760
Fold 3: Validation F1-Scores (Macro) - Color: 0

## Final Test Evaluation and Visualization

This code loads the best `MultiTaskModel` weights (saved in Cell 8), evaluates the model on the test set, and visualizes the results.

#### Key Components:
- **Test Dataset and DataLoader**: Creates a `FashionDataset` and `DataLoader` for the test set (from Cell 5) with the same transformations and minority class handling.
- **Model Evaluation**: Loads the best model weights, sets the model to evaluation mode, and computes predictions on the test set without gradient computation.
- **Metrics Calculation**: Calculates average test loss and macro-averaged F1-score, precision, recall, and accuracy for each task (color, type, season, gender).
- **Results Output**: Prints the test loss and task-specific metrics; any skipped (missing) image IDs are reported by `FashionDataset` (see the `Skipped ID` line in the output).
- **Confusion Matrices**: Plots a 2x2 grid with one confusion matrix per task to visualize prediction errors.
- **Class-wise F1-Scores**: Plots bar charts of per-class F1-scores for each task, highlighting class-specific performance.
- **Error Handling**: Checks if the best model file exists, skipping evaluation if not found.
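The sixteen metric calls in the evaluation cell follow one pattern per task. A minimal sketch of a helper that computes them in one place is shown here, assuming scikit-learn 0.22 or later (where the `zero_division` parameter was added); `task_metrics` is a hypothetical name. Passing `zero_division=0` gives undefined per-class precision or recall (for example, a class that is never predicted) the same value of 0 that scikit-learn's default uses, but without emitting `UndefinedMetricWarning`.

```python
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def task_metrics(y_true, y_pred):
    """Macro-averaged F1/precision/recall plus accuracy for one task."""
    # zero_division=0 scores undefined per-class precision/recall as 0 and suppresses the warning
    return {
        'f1': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'precision': precision_score(y_true, y_pred, average='macro', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='macro', zero_division=0),
        'accuracy': accuracy_score(y_true, y_pred),
    }

# Toy labels for illustration; in the notebook these would be pairs such as
# (test_color_true, test_color_pred)
toy_predictions = {
    'toy_binary': ([0, 0, 1, 1], [0, 1, 1, 1]),
    'toy_unpredicted_class': ([0, 1, 2], [0, 1, 1]),  # class 2 is never predicted
}
toy_results = {task: task_metrics(t, p) for task, (t, p) in toy_predictions.items()}
for task, m in toy_results.items():
    print(task, {name: round(float(value), 4) for name, value in m.items()})
```

Keying results by task name also makes the print statements and plots easy to drive from a single loop over `tasks`.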


In [30]:
# Final evaluation on the test set using the best model
if os.path.exists(best_model_path):
    # Create test dataset and DataLoader
    test_dataset = FashionDataset(test_data, img_dir, transform=transform, minority_classes=minority_classes)
    test_loader = create_dataloader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
    
    # Load best model weights and set to evaluation mode
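    # map_location lets weights saved on one device (e.g. GPU) load onto the current device
    model.load_state_dict(torch.load(best_model_path, map_location=device))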
    model.eval()
    
    # Initialize test metrics
    test_loss = 0.0
    test_count = 0
    test_color_true, test_color_pred = [], []
    test_type_true, test_type_pred = [], []
    test_season_true, test_season_pred = [], []
    test_gender_true, test_gender_pred = [], []
    
    # Evaluate on test set without gradient computation
    with torch.no_grad():
        for batch in test_loader:
            if batch is None:
                continue
            # Load batch data and move to device
            images, colors, types, seasons, genders = batch
            images, colors, types, seasons, genders = (
                images.to(device), colors.to(device), types.to(device),
                seasons.to(device), genders.to(device)
            )
            # Forward pass
            color_pred, type_pred, season_pred, gender_pred = model(images)
            # Compute total test loss
            test_loss += (criterion_color(color_pred, colors) + criterion_type(type_pred, types) +
                          criterion_season(season_pred, seasons) + criterion_gender(gender_pred, genders)).item()
            test_count += 1
            # Store true and predicted labels for metrics
            test_color_true.extend(colors.cpu().numpy())
            test_color_pred.extend(torch.argmax(color_pred, dim=1).cpu().numpy())
            test_type_true.extend(types.cpu().numpy())
            test_type_pred.extend(torch.argmax(type_pred, dim=1).cpu().numpy())
            test_season_true.extend(seasons.cpu().numpy())
            test_season_pred.extend(torch.argmax(season_pred, dim=1).cpu().numpy())
            test_gender_true.extend(genders.cpu().numpy())
            test_gender_pred.extend(torch.argmax(gender_pred, dim=1).cpu().numpy())
    
    # Compute average test loss
    avg_test_loss = test_loss / test_count if test_count > 0 else float('inf')
    
    # Calculate macro-averaged metrics for each task
    # zero_division=0 scores undefined per-class precision/recall as 0 (sklearn's default value) without warnings
    test_f1_color = f1_score(test_color_true, test_color_pred, average='macro', zero_division=0)
    test_f1_type = f1_score(test_type_true, test_type_pred, average='macro', zero_division=0)
    test_f1_season = f1_score(test_season_true, test_season_pred, average='macro', zero_division=0)
    test_f1_gender = f1_score(test_gender_true, test_gender_pred, average='macro', zero_division=0)
    test_precision_color = precision_score(test_color_true, test_color_pred, average='macro', zero_division=0)
    test_precision_type = precision_score(test_type_true, test_type_pred, average='macro', zero_division=0)
    test_precision_season = precision_score(test_season_true, test_season_pred, average='macro', zero_division=0)
    test_precision_gender = precision_score(test_gender_true, test_gender_pred, average='macro', zero_division=0)
    test_recall_color = recall_score(test_color_true, test_color_pred, average='macro', zero_division=0)
    test_recall_type = recall_score(test_type_true, test_type_pred, average='macro', zero_division=0)
    test_recall_season = recall_score(test_season_true, test_season_pred, average='macro', zero_division=0)
    test_recall_gender = recall_score(test_gender_true, test_gender_pred, average='macro', zero_division=0)
    test_accuracy_color = accuracy_score(test_color_true, test_color_pred)
    test_accuracy_type = accuracy_score(test_type_true, test_type_pred)
    test_accuracy_season = accuracy_score(test_season_true, test_season_pred)
    test_accuracy_gender = accuracy_score(test_gender_true, test_gender_pred)
    
    # Print test metrics
    print(f"\nTest Loss: {avg_test_loss:.4f}")
    print(f"Test F1-Scores (Macro) - Color: {test_f1_color:.4f}, Type: {test_f1_type:.4f}, Season: {test_f1_season:.4f}, Gender: {test_f1_gender:.4f}")
    print(f"Test Precision (Macro) - Color: {test_precision_color:.4f}, Type: {test_precision_type:.4f}, Season: {test_precision_season:.4f}, Gender: {test_precision_gender:.4f}")
    print(f"Test Recall (Macro) - Color: {test_recall_color:.4f}, Type: {test_recall_type:.4f}, Season: {test_recall_season:.4f}, Gender: {test_recall_gender:.4f}")
    print(f"Test Accuracy - Color: {test_accuracy_color:.4f}, Type: {test_accuracy_type:.4f}, Season: {test_accuracy_season:.4f}, Gender: {test_accuracy_gender:.4f}")
    

    # Plot confusion matrices for each task
    plt.figure(figsize=(12, 10))
    for i, (task, true, pred) in enumerate([
        ('color', test_color_true, test_color_pred),
        ('type', test_type_true, test_type_pred),
        ('season', test_season_true, test_season_pred),
        ('gender', test_gender_true, test_gender_pred)
    ]):
        cm = confusion_matrix(true, pred)
        plt.subplot(2, 2, i+1)
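        # Annotate cell counts only when the matrix is small enough to read (not for 143 article types)
        sns.heatmap(cm, annot=cm.shape[0] <= 20, fmt='d', cmap='Blues')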
        plt.title(f'Test Confusion Matrix - {task.capitalize()}')
        plt.xlabel('Predicted')
        plt.ylabel('True')
    plt.tight_layout()
    plt.savefig('/kaggle/working/test_confusion_matrices.png')
    plt.close()

    # Plot class-wise F1-scores for each task
    for task, true, pred in [
        ('color', test_color_true, test_color_pred),
        ('type', test_type_true, test_type_pred),
        ('season', test_season_true, test_season_pred),
        ('gender', test_gender_true, test_gender_pred)
    ]:
        class_f1_scores = f1_score(true, pred, average=None, zero_division=0)
        plt.figure(figsize=(10, 6))
        plt.bar(range(len(class_f1_scores)), class_f1_scores, color='#1f77b4')
        plt.title(f'Test Class-wise F1-Scores - {task.capitalize()}')
        plt.xlabel('Class Index')
        plt.ylabel('F1-Score')
        plt.savefig(f'/kaggle/working/test_class_f1_scores_{task}.png')
        plt.close()
else:
    # Handle case where best model file is missing
    print(f"Error: Best model file {best_model_path} not found. No test evaluation performed.")

Skipped ID 39410: Image file /kaggle/input/fashion-product-images-dataset/fashion-dataset/images/39410.jpg not found





Test Loss: 17.0247
Test F1-Scores (Macro) - Color: 0.0597, Type: 0.0231, Season: 0.6192, Gender: 0.6782
Test Precision (Macro) - Color: 0.0721, Type: 0.0260, Season: 0.5958, Gender: 0.6198
Test Recall (Macro) - Color: 0.0648, Type: 0.0222, Season: 0.6959, Gender: 0.7955
Test Accuracy - Color: 0.2310, Type: 0.0203, Season: 0.5943, Gender: 0.8427


# Further Enhancements Needed

To improve the performance of the `MultiTaskModel` in predicting fashion attributes (gender, article type, season, color), several enhancements are worth considering. Most of them target weaknesses caused by the dataset's class imbalance.

### Key Observations and Proposed Improvements:

***Class Imbalance and High Cardinality:*** 
As the assignment required, the model was trained to predict four attributes: gender, article type (143 classes), season, and color (46 classes). Class imbalance in the two high-cardinality targets led to very low macro-averaged F1-scores on the test set: 0.0597 for color and 0.0231 for article type, versus 0.6192 for season and 0.6782 for gender. Article type was hit hardest, with a test accuracy of only 0.0203.

***Solution:*** Reduce the number of classes for article type and color:

- **Article Type:** Instead of predicting 143 article types, cluster similar products into fewer, more balanced classes. Alternatively, predict a higher-level category such as `masterCategory` or `subCategory` instead of `articleType`. These columns have fewer classes and are likely to be less imbalanced, which could improve model performance.

- **Color:** Group the 46 colors into color families (e.g., reds, blues, neutrals) to reduce the number of classes and improve generalization. This also accommodates future products with new colors not present in the dataset.
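Before switching targets, it is worth checking the class counts directly. A minimal sketch, assuming the metadata is loaded into a pandas DataFrame with `articleType`, `subCategory`, and `masterCategory` columns (the column names used by the Kaggle dataset's `styles.csv`, stated here as an assumption). The helper `class_balance_report` and the toy rows are hypothetical and for illustration only; they are not real dataset statistics.

```python
import pandas as pd

def class_balance_report(df, columns):
    """Summarize class count and class-size spread for each candidate target column."""
    rows = []
    for col in columns:
        counts = df[col].value_counts()
        rows.append({
            'column': col,
            'n_classes': int(counts.size),
            'smallest_class': int(counts.min()),
            'largest_class': int(counts.max()),
            # Ratio of largest to smallest class: 1.0 means perfectly balanced
            'imbalance_ratio': counts.max() / counts.min(),
        })
    return pd.DataFrame(rows).set_index('column')

# Toy rows for illustration only
toy_styles = pd.DataFrame({
    'articleType': ['Tshirts', 'Tshirts', 'Shirts', 'Jeans', 'Trousers', 'Watches', 'Watches'],
    'subCategory': ['Topwear', 'Topwear', 'Topwear', 'Bottomwear', 'Bottomwear', 'Watches', 'Watches'],
})
balance_report = class_balance_report(toy_styles, ['articleType', 'subCategory'])
print(balance_report)
```

On the real metadata, the same call with `['articleType', 'subCategory', 'masterCategory']` would show whether the coarser targets actually reduce imbalance before any retraining.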


**Scalability for Colors:** The dataset currently includes 46 colors, but new products may introduce additional colors, increasing the number of color classes over time.

- **Solution:** Categorize colors into predefined color families (e.g., warm, cool, neutral tones) to create a fixed, manageable number of classes. This approach ensures the model remains robust as the dataset grows, avoiding the need to retrain for every new color.
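A hue/saturation rule is one way to define such fixed families. The sketch below presumes that each color name has already been mapped to an RGB triple in [0, 1]; the dataset stores color names, not RGB values, so that lookup is an assumption. The thresholds are also arbitrary choices rather than tuned values: saturation below 0.2 or value below 0.15 counts as neutral, and hue below 75° or from 330° up counts as warm. `color_family` and `example_rgb` are hypothetical names.

```python
import colorsys

def color_family(rgb, sat_threshold=0.2, value_threshold=0.15):
    """Map an (r, g, b) triple with components in [0, 1] to 'neutral', 'warm', or 'cool'."""
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    # Greys, whites, near-blacks and pale tones carry little hue information
    if s < sat_threshold or v < value_threshold:
        return 'neutral'
    hue_deg = h * 360.0
    # Reds, oranges, yellows and pinks (wrapping around 0 degrees) count as warm
    if hue_deg < 75.0 or hue_deg >= 330.0:
        return 'warm'
    # Greens, cyans, blues and purples count as cool
    return 'cool'

# Hypothetical name-to-RGB lookup; the real 46 color names would need their own mapping
example_rgb = {
    'Red': (1.0, 0.0, 0.0),
    'Navy Blue': (0.0, 0.0, 0.5),
    'Grey': (0.5, 0.5, 0.5),
    'Yellow': (1.0, 1.0, 0.0),
    'Beige': (0.96, 0.96, 0.86),
}
families = {name: color_family(rgb) for name, rgb in example_rgb.items()}
print(families)
```

Because every RGB input lands in one of three families, a new color name only needs an RGB entry; it never creates a new output class or requires changing the model head.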


**Model Performance:** The low macro-averaged F1-scores highlight the model's struggle with minority classes, because macro averaging weights every class equally regardless of its size. Clustering or otherwise reducing the classes for article type and color should raise these scores by producing more balanced class distributions.
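A toy example makes the gap between accuracy and macro-F1 concrete; the labels are invented for illustration. A classifier that always predicts the majority class scores high accuracy, but macro averaging gives each minority class it never gets right the same weight as the majority class.

```python
from sklearn.metrics import accuracy_score, f1_score

# Eight samples of class 0 and one each of classes 1 and 2 (toy data)
toy_true = [0] * 8 + [1, 2]
# A degenerate model that always predicts the majority class
toy_pred = [0] * 10

toy_acc = accuracy_score(toy_true, toy_pred)
toy_macro_f1 = f1_score(toy_true, toy_pred, average='macro', zero_division=0)
toy_per_class_f1 = f1_score(toy_true, toy_pred, average=None, zero_division=0)
print(f"accuracy={toy_acc:.4f}, macro F1={toy_macro_f1:.4f}")  # accuracy=0.8000, macro F1=0.2963
print("per-class F1:", toy_per_class_f1.round(4))
```

The same pattern appears in the test results above, for example color accuracy of 0.2310 against a macro F1 of 0.0597.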

**Future Steps:**

- Cluster similar article types into groups to reduce the number of classes.

- Explore predicting `masterCategory` or `subCategory` instead of `articleType` and compare performance.

- Define color families based on hue, saturation, or perceptual similarity to simplify color classification.

- Re-evaluate the model with these changes, focusing on macro-averaged F1-scores to ensure balanced performance across all classes.
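One way to implement the article-type clustering step is to group classes whose images look alike. This minimal sketch assumes you have a feature vector for each training image and its `articleType` label. How those features are extracted (for example, pooled activations from the ResNet-50 backbone) is itself an assumption. The approach averages the features per class, then runs k-means on the class centroids. `class_centroids` and `cluster_classes` are hypothetical helpers, `n_clusters` would need tuning, and the toy 2-D features stand in for real embeddings. K-means on centroids is only one option; grouping classes that the current model frequently confuses is an alternative.

```python
import numpy as np
from sklearn.cluster import KMeans

def class_centroids(features, labels):
    """Average the feature vectors of each class; returns (sorted class names, centroid matrix)."""
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    centroids = np.stack([features[labels == c].mean(axis=0) for c in classes])
    return classes, centroids

def cluster_classes(classes, centroids, n_clusters, seed=0):
    """Assign each original class to one of n_clusters coarse groups via k-means."""
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    group_ids = km.fit_predict(centroids)
    return {str(c): int(g) for c, g in zip(classes, group_ids)}

# Toy 2-D "embeddings": upper-body garments near (0, 0), lower-body garments near (5, 5)
toy_features = [[0.0, 0.0], [0.2, 0.0], [0.1, 0.1], [5.0, 5.0], [5.2, 5.0], [5.0, 5.2]]
toy_labels = ['Tshirts', 'Tshirts', 'Shirts', 'Jeans', 'Trousers', 'Trousers']
toy_classes, toy_centroids = class_centroids(toy_features, toy_labels)
class_to_group = cluster_classes(toy_classes, toy_centroids, n_clusters=2)
print(class_to_group)

# Each sample's coarse training label is then the group of its original class
coarse_labels = [class_to_group[label] for label in toy_labels]
```

The resulting mapping can replace `articleType` as a training target, and the re-evaluation step above then shows whether the coarser classes actually improve macro-averaged F1.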

These enhancements should help address the dataset's challenges, improve accuracy and generalization, and better align the model with practical fashion product classification.

## Thank You