## Garbage Classification Transfer Learning


## Imports and Configuration

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset, Subset
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR
from torchvision.models.efficientnet import EfficientNet_B0_Weights
import os
import re
import logging
import sys
import numpy as np
from collections import Counter
from transformers import DistilBertModel, DistilBertTokenizer
import wandb
from sklearn.model_selection import StratifiedKFold, train_test_split
from PIL import Image
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import time
import spacy
from nltk.corpus import stopwords

NOTES = '''
'''

# ========================================= GLOBAL CONFIGURATION ================================================
# Data Directories
DATA_DIR = r"C:\NN Data\garbage_data\kfold_garbage_data"
CLASSES = ["Black", "Blue", "Green", "TTR"]

# ========================================= Experiment Settings =========================================
WANDB_RUN_NAME = "experiment_multimodal_gated_only"
MODEL_NAME = "experiment_multimodal_gated_only"

# ========================================= Data Settings =========================================
IMAGE_SIZE = (224, 224)  # Input image size for EfficientNetV2-S
NUM_CLASSES = 4  # Number of output classes for classification
MAX_LEN = 40  # Maximum token length for DistilBERT tokenizer
TEST_SIZE = 0.2  # Test dataset size split
K_FOLDS = 5  # Number of folds for stratified k-fold cross-validation

# ========================================= Training Hyperparameters =========================================
BATCH_SIZE = 64  # Number of samples per batch
GRAD_ACCUM_STEPS = 4
EPOCHS = 50  # Maximum number of training epochs
DROPOUT_IMAGE = 0.2 # Reduce from 0.3
DROPOUT_TEXT = 0.1 # Reduce from 0.2
DROPOUT_FUSION = 0.2 
DROPOUT_CLASSIFIER = 0.1
PATIENCE = 10  # Number of epochs to wait before early stopping
CONVERGENCE_THRESHOLD = 0.001  # Minimum improvement in validation loss to continue training

# ========================================= Optimization Settings =========================================
OPTIMIZER = "AdamW"
LR_SCHEDULING_FACTOR = 0.3
LEARNING_RATE_UNFREEZE_IMAGE = 1e-5
LEARNING_RATE_UNFREEZE_TEXT = 1e-5
LEARNING_RATE_FUSION = 1e-3
LEARNING_RATE_CLASSIFIER = 5e-3
LEARNING_RATE_IMAGE = 0.001 # # EfficientNetB0
LEARNING_RATE_TEXT = 0.00002 # DistilBERT Uncased
WEIGHT_DECAY_TEXT = 1e-3  # Reduce from 1e-2
WEIGHT_DECAY_IMAGE = 1e-4  # Reduce from 1e-3
WEIGHT_DECAY_FUSION = 4e-4 
WEIGHT_DECAY_CLASSIFIER = 1e-3  # Reduce from 1e-4
LABEL_SMOOTHING_PREDICTION = 0.05 # Reduce from 0.1

# ========================================= System Settings =========================================
NUM_WORKERS = 4  # Dataloader parallelization

# Wandb Configuration
WANDB_CONFIG = {
    "entity": "shcau-university-of-calgary-in-alberta",
    "project": "transfer_learning_garbage",
    "name": WANDB_RUN_NAME,
    "tags": ["distilBERT", "efficientnet", "CVPR_2024_dataset"],
    "notes": NOTES,
    "config": {
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "dataset": "CVPR_2024_dataset",
        "image_size": IMAGE_SIZE,
        "num_workers": NUM_WORKERS,
        "num_classes": NUM_CLASSES,
        "max_len": MAX_LEN,
        "learning_rate_image": LEARNING_RATE_IMAGE,
        "learning_rate_text": LEARNING_RATE_TEXT,
        "learning_rate_fusion": LEARNING_RATE_FUSION,
        "learning_rate_classifier": LEARNING_RATE_CLASSIFIER,
        "learning_rate_unfreeze_image": LEARNING_RATE_UNFREEZE_IMAGE, # learning rate for unfrozen EfficientNet layers
        "learning_rate_unfreeze_text": LEARNING_RATE_UNFREEZE_TEXT, # learning rate for unfrozen DistilBERT layers
        "dropout_image": DROPOUT_IMAGE,
        "dropout_text": DROPOUT_TEXT,
        "dropout_classifier": DROPOUT_CLASSIFIER,
        "convergence_threshold": CONVERGENCE_THRESHOLD,
        "patience": PATIENCE,
        "weight_decay_text": WEIGHT_DECAY_TEXT,
        "weight_decay_image": WEIGHT_DECAY_IMAGE,
        "weight_decay_classifier": WEIGHT_DECAY_CLASSIFIER,
        "label_smoothing_prediction": LABEL_SMOOTHING_PREDICTION,
        "optimizer": OPTIMIZER 
    },
    "job_type": "train",
    "resume": "allow",
}

# Normalization Stats
NORMALIZATION_STATS = EfficientNet_B0_Weights.IMAGENET1K_V1.transforms()

## Initialize Logging

In [2]:
LOG_FILE = "experiment_multimodal_gated_only.txt"  # Log file name

# Configure logging
logging.basicConfig(
    level=logging.INFO,  # Log everything (INFO and above)
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(LOG_FILE, mode='w'),  # Overwrite log file on each run
        logging.StreamHandler(sys.stdout)  # Print log messages to console too
    ]
)

In [3]:
# Log the configuration
logging.info("[CONFIG] ============================== Experiment Configuration ==============================")

# Log top-level keys
logging.info(f"[CONFIG] Experiment Name: {WANDB_CONFIG['name']}")
logging.info(f"[CONFIG] Entity: {WANDB_CONFIG['entity']}")
logging.info(f"[CONFIG] Project: {WANDB_CONFIG['project']}")
logging.info(f"[CONFIG] Tags: {', '.join(WANDB_CONFIG['tags'])}")
logging.info(f"[CONFIG] Notes: {WANDB_CONFIG['notes']}")
logging.info(f"[CONFIG] Job Type: {WANDB_CONFIG['job_type']}")
logging.info(f"[CONFIG] Resume: {WANDB_CONFIG['resume']}")

# Log nested configuration (under 'config')
logging.info("[CONFIG] ------------------------------ Hyperparameters ------------------------------")
for key, value in WANDB_CONFIG["config"].items():
    logging.info(f"[CONFIG] {key}: {value}")

logging.info("[CONFIG] =============================================================================")

2025-03-25 19:16:31,862 - INFO - [CONFIG] Experiment Name: experiment_multimodal_gated_only
2025-03-25 19:16:31,863 - INFO - [CONFIG] Entity: shcau-university-of-calgary-in-alberta
2025-03-25 19:16:31,864 - INFO - [CONFIG] Project: transfer_learning_garbage
2025-03-25 19:16:31,865 - INFO - [CONFIG] Tags: distilBERT, efficientnet, CVPR_2024_dataset
2025-03-25 19:16:31,865 - INFO - [CONFIG] Notes: 

2025-03-25 19:16:31,866 - INFO - [CONFIG] Job Type: train
2025-03-25 19:16:31,867 - INFO - [CONFIG] Resume: allow
2025-03-25 19:16:31,868 - INFO - [CONFIG] ------------------------------ Hyperparameters ------------------------------
2025-03-25 19:16:31,868 - INFO - [CONFIG] epochs: 50
2025-03-25 19:16:31,869 - INFO - [CONFIG] batch_size: 64
2025-03-25 19:16:31,870 - INFO - [CONFIG] dataset: CVPR_2024_dataset
2025-03-25 19:16:31,870 - INFO - [CONFIG] image_size: (224, 224)
2025-03-25 19:16:31,871 - INFO - [CONFIG] num_workers: 4
2025-03-25 19:16:31,872 - INFO - [CONFIG] num_classes: 4
2025-03

## Weights and Biases Setup

In [4]:
def initialize_wandb(fold):
    """Initialize wandb for each fold with a unique run name."""
    wandb.init(
        entity=WANDB_CONFIG["entity"],
        project=WANDB_CONFIG["project"],
        name=f"{WANDB_RUN_NAME}_fold_{fold + 1}",
        tags=WANDB_CONFIG["tags"],
        notes=WANDB_CONFIG["notes"],
        config=WANDB_CONFIG["config"],
        job_type=WANDB_CONFIG["job_type"],
        resume=WANDB_CONFIG["resume"],
    )

## Helper Function

In [5]:
# Load SpaCy for lemmatization
nlp = spacy.load("en_core_web_sm")

# Load NLTK stopwords
stop_words = set(stopwords.words("english"))

def preprocess_text(text):
    """Standardize text, remove stopwords, and apply lemmatization."""
    # 1. Standardize text (lowercasing & trimming spaces)
    text = text.strip().lower()

    # 2. Remove stopwords
    text_tokens = text.split()
    text = " ".join([word for word in text_tokens if word not in stop_words])

    # 3. Lemmatization
    doc = nlp(text)
    text = " ".join([token.lemma_ for token in doc])

    return text

def read_text_files_with_labels_and_image_paths(path):
    """Extract text from file names, apply preprocessing, and return labels with image paths."""
    texts, labels, image_paths = [], [], []
    class_folders = sorted(os.listdir(path))
    label_map = {class_name: idx for idx, class_name in enumerate(class_folders)}

    for class_name in class_folders:
        class_path = os.path.join(path, class_name)
        if os.path.isdir(class_path):
            file_names = sorted(os.listdir(class_path))  # Sort to ensure order consistency
            for file_name in file_names:
                file_path = os.path.join(class_path, file_name)
                if os.path.isfile(file_path):
                    # Extract filename without extension
                    file_name_no_ext, _ = os.path.splitext(file_name)

                    # Replace underscores with spaces
                    text = file_name_no_ext.replace("_", " ")

                    # Remove numbers
                    text_without_digits = re.sub(r"\d+", "", text)

                    # Apply preprocessing
                    preprocessed_text = preprocess_text(text_without_digits)

                    texts.append(preprocessed_text)
                    labels.append(label_map[class_name])
                    image_paths.append(file_path)

    return np.array(texts), np.array(labels), np.array(image_paths)

## Data Setup

In [6]:
class CustomTextDataset(Dataset):
    """Dataset class for text data."""
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }
    
# Custom dataset class for images
class ImageDataset(Dataset):
    """Dataset class for image data."""
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)



class MultimodalDataset(Dataset):
    """Dataset class for multimodal data (image + text)."""
    def __init__(self, image_dataset, text_dataset):
        self.image_dataset = image_dataset
        self.text_dataset = text_dataset

    def __len__(self):
        return min(len(self.image_dataset), len(self.text_dataset))

    def __getitem__(self, idx):
        image, label = self.image_dataset[idx]
        text_data = self.text_dataset[idx]
        return {
            "image": image,
            "input_ids": text_data["input_ids"],
            "attention_mask": text_data["attention_mask"],
            "label": label
        }

## Main Experiment

### Model Definition

In [7]:
# ======================== Gated Fusion ========================
class GatedFusion(nn.Module):
    def __init__(self, feature_dim):
        super(GatedFusion, self).__init__()
        self.gate = nn.Linear(2 * feature_dim, feature_dim)  # Learnable gate
        self.sigmoid = nn.Sigmoid()  # Activation

    def forward(self, text_feat, image_feat):
        combined_feat = torch.cat((text_feat, image_feat), dim=1)
        gate_value = self.sigmoid(self.gate(combined_feat))  # Value between 0-1
        fused_feat = (gate_value * text_feat) + ((1 - gate_value) * image_feat)  # Weighted fusion
        return fused_feat

# ======================== Multimodal Classifier (Last Feature Extractor Layer Unfrozen) ========================
class MultimodalClassifier(nn.Module):
    """Multimodal model combining EfficientNetB0 and DistilBERT with partial fine-tuning."""
    def __init__(self, num_classes):
        super(MultimodalClassifier, self).__init__()

        # ----------- Image Feature Extractor (EfficientNetB0) -----------
        self.image_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        
        # Freeze all layers except the last one
        for param in self.image_model.features.parameters():
            param.requires_grad = False
        for param in self.image_model.features[-3:].parameters():  # Unfreeze last feature layer
            param.requires_grad = True

        num_ftrs = self.image_model.classifier[1].in_features
        self.image_model.classifier = nn.Identity()  # Remove classifier
        self.image_fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(DROPOUT_IMAGE)
        )

        # ----------- Text Feature Extractor (DistilBERT) -----------
        self.text_model = DistilBertModel.from_pretrained('distilbert-base-uncased')

        # Freeze all layers except the last transformer layer
        for param in self.text_model.parameters():
            param.requires_grad = False
        for param in self.text_model.transformer.layer[-2:].parameters():  # Unfreeze last transformer layer
            param.requires_grad = True

        self.text_fc = nn.Sequential(
            nn.Linear(self.text_model.config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(DROPOUT_TEXT)
        )

        # ----------- Normalize Features -----------
        self.text_norm = nn.LayerNorm(512)
        self.image_norm = nn.LayerNorm(512)

        # ----------- Gated Fusion -----------
        self.gated_fusion = GatedFusion(feature_dim=512)

        # ----------- Fully Connected Fusion & Classification -----------
        self.fusion_fc = nn.Sequential(
            nn.Linear(512, 512),  # Increase dimension
            nn.BatchNorm1d(512),  # Add batch normalization
            nn.ReLU(),            # Use GELU activation

            nn.Linear(512, 256),  # Intermediate layer
            nn.BatchNorm1d(256),  # Batch normalization
            nn.ReLU(),            # GELU activation
        )

        self.dropout = nn.Dropout(DROPOUT_CLASSIFIER)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, input_ids, attention_mask, image_inputs):
        text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_features = self.text_fc(text_output.last_hidden_state[:, 0, :])
        text_features = self.text_norm(text_features)
        image_features = self.image_fc(self.image_model(image_inputs))
        image_features = self.image_norm(image_features)
        gated_feat = self.gated_fusion(text_features, image_features)
        fused_features = self.fusion_fc(gated_feat)
        output = self.classifier(self.dropout(fused_features))
        return output


### Data setup

In [8]:
# Load dataset
texts, labels, image_paths = read_text_files_with_labels_and_image_paths(DATA_DIR)

# Log first and last 4 samples
logging.info("First 4 samples of dataset:\n")
logging.info(f"Texts: {texts[:4]}")
logging.info(f"Labels: {labels[:4]}")
logging.info(f"Image Paths: {image_paths[:4]}")

logging.info("\nLast 4 samples of dataset:\n")
logging.info(f"Texts: {texts[-4:]}")
logging.info(f"Labels: {labels[-4:]}")
logging.info(f"Image Paths: {image_paths[-4:]}")

2025-03-25 19:17:00,674 - INFO - First 4 samples of dataset:

2025-03-25 19:17:00,674 - INFO - Texts: ['aero bar wrapper' 'break glass' 'break rubber' 'butter paper']
2025-03-25 19:17:00,675 - INFO - Labels: [0 0 0 0]
2025-03-25 19:17:00,676 - INFO - Image Paths: ['C:\\NN Data\\garbage_data\\kfold_garbage_data\\Black\\Aero_bar_wrapper_1.png'
 'C:\\NN Data\\garbage_data\\kfold_garbage_data\\Black\\Broken_Glass_5291.png'
 'C:\\NN Data\\garbage_data\\kfold_garbage_data\\Black\\Broken_rubber_7263.png'
 'C:\\NN Data\\garbage_data\\kfold_garbage_data\\Black\\Butter_Paper_9976.png']
2025-03-25 19:17:00,676 - INFO - 
Last 4 samples of dataset:

2025-03-25 19:17:00,677 - INFO - Texts: ['wristwatch' 'xbox controller' 'xbox one controller' 'zipper file bag']
2025-03-25 19:17:00,677 - INFO - Labels: [3 3 3 3]
2025-03-25 19:17:00,678 - INFO - Image Paths: ['C:\\NN Data\\garbage_data\\kfold_garbage_data\\TTR\\wristwatch_3782.png'
 'C:\\NN Data\\garbage_data\\kfold_garbage_data\\TTR\\xbox_controller_

### Split into test set and development set

In [9]:
# Split into a test set and development set
train_texts, test_texts, train_labels, test_labels, train_image_paths, test_image_paths = train_test_split(
    texts, labels, image_paths, test_size=TEST_SIZE, stratify=labels, random_state=42
)

# Log first 4 samples of test set
logging.info("First 4 samples of test set:\n")
logging.info(f"Texts: {test_texts[:4]}")
logging.info(f"Labels: {test_labels[:4]}")
logging.info(f"Image Paths: {test_image_paths[:4]}")

logging.info("\nLast 4 samples of test set:\n")
logging.info(f"Texts: {test_texts[-4:]}")
logging.info(f"Labels: {test_labels[-4:]}")
logging.info(f"Image Paths: {test_image_paths[-4:]}")

2025-03-25 19:17:00,698 - INFO - First 4 samples of test set:

2025-03-25 19:17:00,698 - INFO - Texts: ['ballast light' 'old phone' 'milk jug lid tab' 'dirty dish sponge']
2025-03-25 19:17:00,699 - INFO - Labels: [3 3 0 0]
2025-03-25 19:17:00,700 - INFO - Image Paths: ['C:\\NN Data\\garbage_data\\kfold_garbage_data\\TTR\\ballast_light_286.png'
 'C:\\NN Data\\garbage_data\\kfold_garbage_data\\TTR\\Old_Phones_7828.png'
 'C:\\NN Data\\garbage_data\\kfold_garbage_data\\Black\\milk_jug_lid_tab_1137.png'
 'C:\\NN Data\\garbage_data\\kfold_garbage_data\\Black\\dirty_dish_sponge_437.png']
2025-03-25 19:17:00,700 - INFO - 
Last 4 samples of test set:

2025-03-25 19:17:00,701 - INFO - Texts: ['empty glass jar' 'non - stretchy plastic' 'backpack' 'piece break glass']
2025-03-25 19:17:00,702 - INFO - Labels: [1 0 3 0]
2025-03-25 19:17:00,702 - INFO - Image Paths: ['C:\\NN Data\\garbage_data\\kfold_garbage_data\\Blue\\empty_glass_jar_1609.png'
 'C:\\NN Data\\garbage_data\\kfold_garbage_data\\Black\

### Define Transformations

In [10]:
# Define transformations
transform = {
    "train": transforms.Compose([
        transforms.Resize(IMAGE_SIZE), 
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(20),
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZATION_STATS.mean, std=NORMALIZATION_STATS.std)  # Apply correct normalization
    ]),
    "val": transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZATION_STATS.mean, std=NORMALIZATION_STATS.std)  # Only resize + normalize
    ]),
    "test": transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZATION_STATS.mean, std=NORMALIZATION_STATS.std)  # Only resize + normalize
    ])
}

# Tokenizer for DistilBERT
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

### DataLoader for test set

Create the dataloader for the test set and set aside for model evaluation.

In [11]:
# Create test dataset
test_image_dataset = ImageDataset(test_image_paths, test_labels, transform["test"])
test_text_dataset = CustomTextDataset(test_texts, test_labels, tokenizer, max_len=MAX_LEN)  # Ensure tokenizer is defined
test_multimodal_dataset = MultimodalDataset(test_image_dataset, test_text_dataset)

# DataLoader for test set
test_loader = DataLoader(test_multimodal_dataset, batch_size=BATCH_SIZE, shuffle=False)

Take a peek at a batch in the test set to verify that data has been correctly organized.

In [12]:
# Get one batch
for batch in test_loader:
    images = batch["image"]  # Image tensor
    input_ids = batch["input_ids"]  # Tokenized text tensor
    attention_mask = batch["attention_mask"]  # Attention mask
    labels = batch["label"]  # Labels tensor

    # Log shapes of tensors
    logging.info("[INFO] One Batch Sample Inspection:")
    logging.info(f"   Images Shape: {images.shape}")
    logging.info(f"   Input IDs Shape: {input_ids.shape}")
    logging.info(f"   Attention Mask Shape: {attention_mask.shape}")
    logging.info(f"   Labels Shape: {labels.shape}")

    # Log first sample details
    logging.info("\n[INFO] First Sample:")
    logging.info(f"   Image Tensor: {images[0]}")
    logging.info(f"   Input IDs: {input_ids[0]}")
    logging.info(f"   Attention Mask: {attention_mask[0]}")
    logging.info(f"   Label: {labels[0]}")

    break  # Stop after inspecting one batch


2025-03-25 19:17:02,735 - INFO - [INFO] One Batch Sample Inspection:
2025-03-25 19:17:02,736 - INFO -    Images Shape: torch.Size([64, 3, 224, 224])
2025-03-25 19:17:02,737 - INFO -    Input IDs Shape: torch.Size([64, 40])
2025-03-25 19:17:02,738 - INFO -    Attention Mask Shape: torch.Size([64, 40])
2025-03-25 19:17:02,739 - INFO -    Labels Shape: torch.Size([64])
2025-03-25 19:17:02,739 - INFO - 
[INFO] First Sample:
2025-03-25 19:17:02,771 - INFO -    Image Tensor: tensor([[[2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         ...,
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489]],

        [[2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
         [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286,

### Apply Stratified K-Fold on the development set to split into train/val

In [13]:
# Initialize Stratified K-Fold
skf = StratifiedKFold(n_splits=K_FOLDS, shuffle=True, random_state=42)

for fold, (train_idx, val_idx) in enumerate(skf.split(train_texts, train_labels)):
    logging.info(f"[INFO] Fold {fold + 1}/{K_FOLDS}")

    # Extract labels for current fold
    train_labels_fold = train_labels[train_idx]
    val_labels_fold = train_labels[val_idx]

    # Log class distributions
    logging.info("[INFO] Class Distributions:")
    logging.info(f"   Train Class Distribution: {Counter(train_labels_fold)}")
    logging.info(f"   Validation Class Distribution: {Counter(val_labels_fold)}")


2025-03-25 19:17:02,785 - INFO - [INFO] Fold 1/5
2025-03-25 19:17:02,786 - INFO - [INFO] Class Distributions:
2025-03-25 19:17:02,787 - INFO -    Train Class Distribution: Counter({np.int64(1): 3590, np.int64(0): 1754, np.int64(2): 1708, np.int64(3): 1542})
2025-03-25 19:17:02,788 - INFO -    Validation Class Distribution: Counter({np.int64(1): 898, np.int64(0): 438, np.int64(2): 427, np.int64(3): 386})
2025-03-25 19:17:02,788 - INFO - [INFO] Fold 2/5
2025-03-25 19:17:02,789 - INFO - [INFO] Class Distributions:
2025-03-25 19:17:02,790 - INFO -    Train Class Distribution: Counter({np.int64(1): 3591, np.int64(0): 1753, np.int64(2): 1708, np.int64(3): 1542})
2025-03-25 19:17:02,791 - INFO -    Validation Class Distribution: Counter({np.int64(1): 897, np.int64(0): 439, np.int64(2): 427, np.int64(3): 386})
2025-03-25 19:17:02,792 - INFO - [INFO] Fold 3/5
2025-03-25 19:17:02,793 - INFO - [INFO] Class Distributions:
2025-03-25 19:17:02,794 - INFO -    Train Class Distribution: Counter({np.in

### Verify k-fold was applied correctly

In [14]:
# Ensure no data leakage in folds
for fold, (train_idx, val_idx) in enumerate(skf.split(train_texts, train_labels)):
    train_set = set(train_idx)
    val_set = set(val_idx)

    # Check for intersection (should be empty)
    intersection = train_set.intersection(val_set)
    assert len(intersection) == 0, f"Data leakage detected in Fold {fold + 1}"

    logging.info(f"[INFO] No data leakage detected in Fold {fold + 1}")

2025-03-25 19:17:02,812 - INFO - [INFO] No data leakage detected in Fold 1
2025-03-25 19:17:02,814 - INFO - [INFO] No data leakage detected in Fold 2
2025-03-25 19:17:02,815 - INFO - [INFO] No data leakage detected in Fold 3
2025-03-25 19:17:02,817 - INFO - [INFO] No data leakage detected in Fold 4
2025-03-25 19:17:02,818 - INFO - [INFO] No data leakage detected in Fold 5


In [15]:
# for fold, (train_idx, val_idx) in enumerate(skf.split(train_texts, train_labels)):
#     train_labels_fold = train_labels[train_idx]
#     val_labels_fold = train_labels[val_idx]

#     plt.figure(figsize=(10, 4))
#     plt.hist(train_labels_fold, bins=len(set(train_labels)), alpha=0.6, label="Train")
#     plt.hist(val_labels_fold, bins=len(set(train_labels)), alpha=0.6, label="Validation")
#     plt.title(f"Class Distribution in Fold {fold + 1}")
#     plt.legend()
#     plt.show()

## Train Model

### Evaluation Function

In [16]:
def evaluate_model(model, dataloader, device):
    model.eval()
    total_loss = 0
    correct, total = 0, 0
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch in dataloader:
            # Move data to the appropriate device
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)  # Compute batch loss

            # Aggregate loss for averaging
            total_loss += loss.item() * labels.size(0)  # Multiply by batch size for proper averaging
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / total  # Normalize loss over total samples
    accuracy = correct / total  # Compute accuracy

    return avg_loss, accuracy


## Adaptive Weight Decay

In [17]:
def adaptive_weight_decay(epoch, warmup_epochs=5, decay_factors=(0.1, 1.0)):
    """
    Returns a scaled weight decay based on epoch number.
    During warm-up, it applies a lower decay (decay_factors[0]).
    After warm-up, it applies full weight decay (decay_factors[1]).
    """
    if epoch < warmup_epochs:
        return decay_factors[0]  # Use lower decay during warm-up
    return decay_factors[1]  # Use normal decay afterward


In [18]:
def get_warmup_lr(epoch, warmup_epochs, base_lr):
    """
    Linear warmup schedule for the learning rate.
    """
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    else:
        return base_lr

### Train Loop

In [19]:
def train_model(model, dataloaders, criterion, optimizer, device, fold, use_mixup=True):
    initialize_wandb(fold)
    wandb.watch(model, log="all")

    best_val_loss = float("inf")  # Track best validation loss
    epochs_without_improvement = 0  # Track epochs without improvement until equals patience

    # ================ ReduceLROnPlateau Scheduler ================
    plateau_scheduler = ReduceLROnPlateau(
        optimizer, mode="min", factor=LR_SCHEDULING_FACTOR, patience=3, verbose=True
    )
    
    # AMP GradScaler
    scaler = GradScaler()

    epoch_start_time = time.time()  # Start total training timer
    logging.info("[TRAIN INFO] Starting Training...")

    # Warmup settings
    WARMUP_EPOCHS = 8  # Number of epochs for warmup
    base_lr_image = LEARNING_RATE_IMAGE  # Base learning rate for EfficientNet
    base_lr_text = LEARNING_RATE_TEXT  # Base learning rate for DistilBERT
    base_lr_fusion = LEARNING_RATE_FUSION  # Base learning rate for fusion layer
    base_lr_classifier = LEARNING_RATE_CLASSIFIER  # Base learning rate for classifier

    # Training Loop
    for epoch in range(EPOCHS):
        logging.info(f"[TRAIN INFO] ============================== Epoch {epoch + 1}/{EPOCHS} ==============================")
        
        # Apply learning rate warmup
        if epoch < WARMUP_EPOCHS:
            warmup_lr_image = get_warmup_lr(epoch, WARMUP_EPOCHS, base_lr_image)
            warmup_lr_text = get_warmup_lr(epoch, WARMUP_EPOCHS, base_lr_text)
            warmup_lr_fusion = get_warmup_lr(epoch, WARMUP_EPOCHS, base_lr_fusion)
            warmup_lr_classifier = get_warmup_lr(epoch, WARMUP_EPOCHS, base_lr_classifier)

            # Update learning rates for each parameter group
            optimizer.param_groups[0]["lr"] = warmup_lr_image  # Unfrozen EfficientNet layer
            optimizer.param_groups[1]["lr"] = warmup_lr_text  # Unfrozen DistilBERT layer
            optimizer.param_groups[2]["lr"] = warmup_lr_image  # Image FC layer
            optimizer.param_groups[3]["lr"] = warmup_lr_text  # Text FC layer
            optimizer.param_groups[4]["lr"] = warmup_lr_fusion  # Fusion layer
            optimizer.param_groups[5]["lr"] = warmup_lr_classifier  # Classifier layer

        model.train()  # Set model to training modes
        total_train_loss = 0  # Track total training loss for the epoch
        batch_train_loss = 0  # Track batch loss for gradient accumulation
        step = 0  # Track the number of batches processed
        optimizer.zero_grad()

        # Training phase
        for step, batch in enumerate(dataloaders["train_loader"], 1):
            # Move data to device
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            with autocast():
                outputs = model(input_ids, attention_mask, images)  # Send inputs to network and receive outputs
                loss = criterion(outputs, labels) / GRAD_ACCUM_STEPS  # Compute loss (no normalization for gradient accumulation)

            # Backward pass and optimizer step
            scaler.scale(loss).backward()  # Scale loss and backpropagate

            batch_train_loss += loss.item()
            total_train_loss += loss.item() * GRAD_ACCUM_STEPS  # Undo normalization for total loss

            step += 1

            # Perform optimizer step before learning rate scheduler step
            if step % GRAD_ACCUM_STEPS == 0 or step == len(dataloaders["train_loader"]):
                # Gradient Clipping
                scaler.unscale_(optimizer)  # Unscale gradients before clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients to a max norm of 1.0

                # Optimizer step
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                # Log batch loss
                logging.info(f"[TRAIN INFO] Batch {step}/{len(dataloaders['train_loader'])}, Accumulated loss over {GRAD_ACCUM_STEPS} batches: {batch_train_loss:.4f}")
                batch_train_loss = 0  # Reset batch loss for the next accumulation

        # Validation step to see how well model performs this epoch
        logging.info(f"[TRAIN INFO] Evaluating model...")
        val_loss, val_acc = evaluate_model(model, dataloaders["val_loader"], device)
        avg_train_loss = total_train_loss / len(dataloaders["train_loader"])

        # **Learning Rate Scheduler Handling**
        plateau_scheduler.step(val_loss)  

        # Log weight decay and learning rate updates
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": avg_train_loss,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "train_val_loss_diff": avg_train_loss - val_loss,  # Track overfitting tendency
            "early_stopping_epochs": epochs_without_improvement,  # Track early stopping
            "learning_rate_image": optimizer.param_groups[0]["lr"],  # Log learning rates
            "learning_rate_text": optimizer.param_groups[1]["lr"],
            "learning_rate_fusion": optimizer.param_groups[4]["lr"],
            "learning_rate_classifier": optimizer.param_groups[5]["lr"],
        })

        logging.info(f"[TRAIN INFO] Epoch {epoch + 1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Check for improvement in validation loss
        if val_loss < best_val_loss - CONVERGENCE_THRESHOLD:  # If loss improves, save the model
            best_val_loss = val_loss
            epochs_without_improvement = 0  # Reset epochs without improvement counter for patience
            torch.save(model.state_dict(), f"{MODEL_NAME}_fold_{fold+1}.pth")
            logging.info(f"[TRAIN INFO] Best Model Saved for Fold {fold + 1}")
        else:
            epochs_without_improvement += 1  # Increment until patience reached

        # Early stopping if no improvement for epochs
        if epochs_without_improvement >= PATIENCE:
            total_training_time = time.time() - epoch_start_time
            logging.info(f"[TRAIN INFO] Early stopping at epoch {epoch + 1} as validation loss did not improve for {PATIENCE} epochs.")
            logging.info(f"[TRAIN INFO] Total Time: {total_training_time:.2f}s")
            wandb.finish()
            break

    total_training_time = time.time() - epoch_start_time
    logging.info(f"[TRAIN INFO] Fold {fold + 1} Training Complete at epoch {epoch + 1}. Total Time: {total_training_time:.2f}s")
    wandb.finish()

In [None]:
# Initialize Stratified K-Fold
skf = StratifiedKFold(n_splits=K_FOLDS, shuffle=True, random_state=42)

logging.info("[K-FOLD INFO] Starting Stratified K-Fold Cross-Validation...")

for fold, (train_idx, val_idx) in enumerate(skf.split(train_texts, train_labels)):

    
    fold_start_time = time.time()  # Start timing for this fold
    logging.info(f"[K-FOLD INFO] ============================== Fold {fold+1}/{K_FOLDS} ==============================")

    # Get train and validation subsets
    train_texts_fold = train_texts[train_idx]
    val_texts_fold = train_texts[val_idx]
    train_labels_fold = train_labels[train_idx]
    val_labels_fold = train_labels[val_idx]
    train_image_paths_fold = train_image_paths[train_idx]
    val_image_paths_fold = train_image_paths[val_idx]

    logging.info(f"[K-FOLD INFO] Fold {fold+1}:")
    logging.info(f"   Train Samples: {len(train_texts_fold)}")
    logging.info(f"   Validation Samples: {len(val_texts_fold)}")

    # Create dataset objects
    train_image_dataset = ImageDataset(train_image_paths_fold, train_labels_fold, transform["train"])
    val_image_dataset = ImageDataset(val_image_paths_fold, val_labels_fold, transform["val"])
    
    train_text_dataset = CustomTextDataset(train_texts_fold, train_labels_fold, tokenizer, max_len=MAX_LEN)
    val_text_dataset = CustomTextDataset(val_texts_fold, val_labels_fold, tokenizer, max_len=MAX_LEN)

    # Create multimodal datasets
    train_multimodal_dataset = MultimodalDataset(train_image_dataset, train_text_dataset)
    val_multimodal_dataset = MultimodalDataset(val_image_dataset, val_text_dataset)

    logging.info(f"[K-FOLD INFO] Created multimodal datasets for Fold {fold+1}")

    # Create DataLoaders
    dataloaders = {
        "train_loader": DataLoader(train_multimodal_dataset, batch_size=BATCH_SIZE, shuffle=True),
        "val_loader": DataLoader(val_multimodal_dataset, batch_size=BATCH_SIZE, shuffle=False)
    }

    logging.info(f"[K-FOLD INFO] DataLoaders initialized for Fold {fold+1}:")
    logging.info(f"   Train batches: {len(dataloaders['train_loader'])}, Validation batches: {len(dataloaders['val_loader'])}")

    # Initialize model, optimizer, and criterion
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MultimodalClassifier(num_classes=NUM_CLASSES).to(device)

    logging.info(f"[K-FOLD INFO] Model initialized on {device} for Fold {fold+1}")

    # Define Optimizer using AdamW
    optimizer = optim.AdamW([
        {"params": model.image_model.features[-3:].parameters(), "lr": LEARNING_RATE_UNFREEZE_IMAGE, "weight_decay": WEIGHT_DECAY_IMAGE},  # Unfrozen EfficientNet layer
        {"params": model.text_model.transformer.layer[-2:].parameters(), "lr": LEARNING_RATE_UNFREEZE_TEXT, "weight_decay": WEIGHT_DECAY_TEXT},  # Unfrozen DistilBERT layer
        {"params": model.image_fc.parameters(), "lr": LEARNING_RATE_IMAGE, "weight_decay": 0}, 
        {"params": model.text_fc.parameters(), "lr": LEARNING_RATE_TEXT, "weight_decay": 0},
        {"params": model.fusion_fc.parameters(), "lr": LEARNING_RATE_FUSION, "weight_decay": WEIGHT_DECAY_FUSION},  
        {"params": model.classifier.parameters(), "lr": LEARNING_RATE_CLASSIFIER, "weight_decay": WEIGHT_DECAY_CLASSIFIER}  
    ], betas=(0.9, 0.999), eps=1e-8)  # Default AdamW betas and eps


    logging.info(f"[K-FOLD INFO] Optimizer initialized for Fold {fold+1}:")
    # Define Loss Function
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING_PREDICTION) 

    logging.info(f"[K-FOLD INFO] Loss function initialized for Fold {fold+1}")

    # Train model for this fold
    train_model(model, dataloaders, criterion, optimizer, device, fold, use_mixup=True)

    # Clear GPU cache
    torch.cuda.empty_cache()

    # Measure Fold Time
    fold_time = time.time() - fold_start_time
    logging.info(f"[K-FOLD INFO] Fold {fold+1} completed in {fold_time:.2f} seconds")


2025-03-25 19:17:02,879 - INFO - [K-FOLD INFO] Starting Stratified K-Fold Cross-Validation...
2025-03-25 19:17:02,886 - INFO - [K-FOLD INFO] Fold 1:
2025-03-25 19:17:02,886 - INFO -    Train Samples: 8594
2025-03-25 19:17:02,887 - INFO -    Validation Samples: 2149
2025-03-25 19:17:02,888 - INFO - [K-FOLD INFO] Created multimodal datasets for Fold 1
2025-03-25 19:17:02,889 - INFO - [K-FOLD INFO] DataLoaders initialized for Fold 1:
2025-03-25 19:17:02,890 - INFO -    Train batches: 135, Validation batches: 34
2025-03-25 19:17:03,660 - INFO - [K-FOLD INFO] Model initialized on cuda for Fold 1
2025-03-25 19:17:03,661 - INFO - [K-FOLD INFO] Optimizer initialized for Fold 1:
2025-03-25 19:17:03,662 - INFO - [K-FOLD INFO] Loss function initialized for Fold 1


wandb: Currently logged in as: shcau (shcau-university-of-calgary-in-alberta) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  scaler = GradScaler()


2025-03-25 19:17:05,791 - INFO - [TRAIN INFO] Starting Training...


  with autocast():


2025-03-25 19:17:13,844 - INFO - [TRAIN INFO] Batch 4/135, Accumulated loss over 4 batches: 1.1097
2025-03-25 19:17:21,919 - INFO - [TRAIN INFO] Batch 8/135, Accumulated loss over 4 batches: 1.4189
2025-03-25 19:17:29,936 - INFO - [TRAIN INFO] Batch 12/135, Accumulated loss over 4 batches: 1.3560
2025-03-25 19:17:37,761 - INFO - [TRAIN INFO] Batch 16/135, Accumulated loss over 4 batches: 1.3363
2025-03-25 19:17:45,558 - INFO - [TRAIN INFO] Batch 20/135, Accumulated loss over 4 batches: 1.3328
2025-03-25 19:17:53,138 - INFO - [TRAIN INFO] Batch 24/135, Accumulated loss over 4 batches: 1.2556
2025-03-25 19:18:00,680 - INFO - [TRAIN INFO] Batch 28/135, Accumulated loss over 4 batches: 1.2571
2025-03-25 19:18:08,211 - INFO - [TRAIN INFO] Batch 32/135, Accumulated loss over 4 batches: 1.2654
2025-03-25 19:18:16,340 - INFO - [TRAIN INFO] Batch 36/135, Accumulated loss over 4 batches: 1.2163
2025-03-25 19:18:24,097 - INFO - [TRAIN INFO] Batch 40/135, Accumulated loss over 4 batches: 1.2204
20

0,1
early_stopping_epochs,▁▁▁▁▁▁▁▂▃▃▁▂▃▃▄▅▆▆▇█
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
learning_rate_classifier,▁▂▃▄▅▆▇██████▃▃▃▃▁▁▁
learning_rate_fusion,▁▂▃▄▅▆▇██████▃▃▃▃▁▁▁
learning_rate_image,▁▂▃▄▅▆▇██████▃▃▃▃▁▁▁
learning_rate_text,▁▂▃▄▅▆▇██████▃▃▃▃▁▁▁
train_loss,█▅▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁
train_val_loss_diff,█▇▆▅▅▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁
val_accuracy,▁▅▆▇██▇▇▇██▇▇▇██████
val_loss,█▄▂▂▁▁▁▂▁▁▁▂▁▂▁▁▁▂▁▁

0,1
early_stopping_epochs,9.0
epoch,20.0
learning_rate_classifier,0.00045
learning_rate_fusion,9e-05
learning_rate_image,9e-05
learning_rate_text,0.0
train_loss,0.2445
train_val_loss_diff,-0.14096
val_accuracy,0.88274
val_loss,0.38546


2025-03-25 21:03:49,705 - INFO - [TRAIN INFO] Fold 1 Training Complete at epoch 20. Total Time: 6403.91s
2025-03-25 21:03:49,723 - INFO - [K-FOLD INFO] Fold 1 completed in 6406.84 seconds
2025-03-25 21:03:49,726 - INFO - [K-FOLD INFO] Fold 2:
2025-03-25 21:03:49,726 - INFO -    Train Samples: 8594
2025-03-25 21:03:49,727 - INFO -    Validation Samples: 2149
2025-03-25 21:03:49,727 - INFO - [K-FOLD INFO] Created multimodal datasets for Fold 2
2025-03-25 21:03:49,728 - INFO - [K-FOLD INFO] DataLoaders initialized for Fold 2:
2025-03-25 21:03:49,729 - INFO -    Train batches: 135, Validation batches: 34
2025-03-25 21:03:50,464 - INFO - [K-FOLD INFO] Model initialized on cuda for Fold 2
2025-03-25 21:03:50,466 - INFO - [K-FOLD INFO] Optimizer initialized for Fold 2:
2025-03-25 21:03:50,466 - INFO - [K-FOLD INFO] Loss function initialized for Fold 2


2025-03-25 21:03:51,187 - INFO - [TRAIN INFO] Starting Training...
2025-03-25 21:03:57,623 - INFO - [TRAIN INFO] Batch 4/135, Accumulated loss over 4 batches: 1.0719
2025-03-25 21:04:06,017 - INFO - [TRAIN INFO] Batch 8/135, Accumulated loss over 4 batches: 1.3964
2025-03-25 21:04:14,196 - INFO - [TRAIN INFO] Batch 12/135, Accumulated loss over 4 batches: 1.3534
2025-03-25 21:04:21,867 - INFO - [TRAIN INFO] Batch 16/135, Accumulated loss over 4 batches: 1.3247
2025-03-25 21:04:29,611 - INFO - [TRAIN INFO] Batch 20/135, Accumulated loss over 4 batches: 1.3092
2025-03-25 21:04:37,404 - INFO - [TRAIN INFO] Batch 24/135, Accumulated loss over 4 batches: 1.2952
2025-03-25 21:04:45,209 - INFO - [TRAIN INFO] Batch 28/135, Accumulated loss over 4 batches: 1.2671
2025-03-25 21:04:52,821 - INFO - [TRAIN INFO] Batch 32/135, Accumulated loss over 4 batches: 1.2432
2025-03-25 21:05:00,619 - INFO - [TRAIN INFO] Batch 36/135, Accumulated loss over 4 batches: 1.2396
2025-03-25 21:05:08,198 - INFO - [T

In [None]:
# for fold in range(K_FOLDS):
#     logging.info(f"\n[TEST INFO] Evaluating Fold {fold + 1} on Test Set...")

#     # Load best model for the fold
#     model = MultimodalClassifier(num_classes=NUM_CLASSES).to(device)
#     model_path = f"best_model_fold_{fold + 1}.pth"
    
#     try:
#         model.load_state_dict(torch.load(model_path))
#         logging.info(f"[TEST INFO] Loaded best model for Fold {fold + 1} from {model_path}")
#     except FileNotFoundError:
#         logging.error(f"[ERROR] Model file {model_path} not found! Skipping Fold {fold + 1} evaluation.")
#         continue  # Skip to the next fold if model file is missing

#     model.eval()  # Set to evaluation mode

#     # Evaluate model on test data
#     test_loss, test_acc = evaluate_model(model, test_loader, device)

#     # Log test set performance for the fold
#     logging.info(f"[TEST INFO] Fold {fold + 1} Test Performance:")
#     logging.info(f"   Test Loss: {test_loss:.4f}")
#     logging.info(f"   Test Accuracy: {test_acc:.2f}%")