
"""
Hybrid Garbage Classification Model 

This script implements a hybrid model for garbage classification that combines:
- Image features using ResNet50 (ImageNet-pretrained)
- Text features using DistilBERT

The model fuses features from both modalities with a learned attention gate, aiming to
improve classification performance beyond what either modality achieves individually.
The script trains the whole model end-to-end with per-epoch validation, keeps the
checkpoint with the lowest validation loss, and evaluates it on a held-out test set.
"""

In [None]:
import os
import re
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from transformers import DistilBertModel, DistilBertTokenizer
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [None]:
# Choose the compute device once for the whole script: a CUDA GPU when one is
# visible to PyTorch, otherwise the CPU. Models and batches are moved here.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Constants and Hyperparameters
NUM_CLASSES = 4 # Number of garbage categories to classify
MAX_TEXT_LENGTH = 24 # Maximum length of text tokens for BERT processing
BATCH_SIZE = 32 # Number of samples processed in each training batch
EPOCHS = 5 # Number of complete passes through the training dataset
IMAGE_SIZE = 224 # Input image size (224x224 is standard for many CNN architectures)
LEARNING_RATE = 2e-5 # Base learning rate for optimizer

In [None]:
# Dataset locations. Each split directory contains one sub-folder per class,
# and the folder name is the class label (see get_files_with_labels).
_GARBAGE_DATA_ROOT = "/work/TALC/enel645_2025w/garbage_data"
TRAIN_PATH = _GARBAGE_DATA_ROOT + "/CVPR_2024_dataset_Train"
VAL_PATH = _GARBAGE_DATA_ROOT + "/CVPR_2024_dataset_Val"
TEST_PATH = _GARBAGE_DATA_ROOT + "/CVPR_2024_dataset_Test"

In [None]:
# Text Preprocessing
def extract_text_from_path(file_path):
    """
    Extract a text description of an image from its filename.

    The base filename (no directory, no extension) is used as the textual
    input: underscores become spaces and digits are removed. The whitespace
    left behind by those removals is then collapsed, so for example
    ``".../plastic_bottle_12.jpg"`` yields ``"plastic bottle"``.

    Args:
        file_path: Path to the image file.

    Returns:
        Cleaned, whitespace-normalised text derived from the filename. This is
        an empty string if the filename contains only digits and underscores.
    """
    stem, _ = os.path.splitext(os.path.basename(file_path))
    text = re.sub(r'\d+', '', stem.replace('_', ' '))
    # Dropping digits leaves doubled and trailing spaces ("bottle  cap",
    # "bottle "); collapse them so the description really is clean.
    return ' '.join(text.split())

In [None]:
# Load dataset paths
def get_files_with_labels(root_path):
    """
    Collect every image under ``root_path`` together with its text and label.

    Each immediate sub-directory of ``root_path`` is one class. Classes are
    numbered 0..K-1 in sorted order of their folder names. Plain files at the
    top level (e.g. ``.DS_Store``, README) are ignored and do not consume a
    label index. Only ``.png``, ``.jpg`` and ``.jpeg`` files (case-insensitive)
    are collected. Within each class, files are returned in sorted order so
    that results are reproducible.

    Note: labels are computed per call. They agree across the train, val and
    test splits only if every split has the same set of class folder names.

    Args:
        root_path: Path to the root directory containing class subdirectories.

    Returns:
        Tuple of lists ``(image_paths, texts, labels)``, one entry per image.
        ``texts`` comes from ``extract_text_from_path``.
    """
    image_paths, texts, labels = [], [], []
    # Only real directories are classes. Enumerating every entry would let a
    # stray file shift the class indices (and can yield labels >= NUM_CLASSES).
    class_folders = sorted(
        name for name in os.listdir(root_path)
        if os.path.isdir(os.path.join(root_path, name))
    )
    for label, class_name in enumerate(class_folders):
        class_path = os.path.join(root_path, class_name)
        # os.listdir order is filesystem-dependent; sort for reproducibility.
        for file_name in sorted(os.listdir(class_path)):
            if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            file_path = os.path.join(class_path, file_name)
            image_paths.append(file_path)
            texts.append(extract_text_from_path(file_path))
            labels.append(label)
    return image_paths, texts, labels

In [None]:
# Training-time image preprocessing with random augmentation.
# Flips, small rotations and colour jitter show the model varied views of each
# training image, which helps reduce overfitting.
train_transform = transforms.Compose([
    # Fixed square input size for the image backbone
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    # Geometric augmentation: mirror with probability 0.5, rotate within +/-20 degrees
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=20),
    # Photometric augmentation
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # PIL image -> float tensor, then per-channel standardisation with the
    # ImageNet statistics used by the ImageNet-pretrained backbone
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Deterministic preprocessing for validation and test images: no random
# augmentation, only resizing, tensor conversion and ImageNet normalisation.
test_transform = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Dataset Class
class HybridDataset(Dataset):
    """
    Dataset yielding an (image, tokenized text, label) triple per sample.

    Images are read from disk and passed through ``transform``. Texts are
    tokenized with the supplied (DistilBERT) tokenizer, padded or truncated to
    ``max_len`` tokens.
    """
    def __init__(self, image_paths, texts, labels, tokenizer, transform, max_len=MAX_TEXT_LENGTH):
        """
        Store the sample lists and the preprocessing components.

        Args:
            image_paths: List of paths to image files.
            texts: List of text descriptions, aligned with ``image_paths``.
            labels: List of integer class labels, aligned with ``image_paths``.
            tokenizer: Hugging Face tokenizer used to encode ``texts``.
            transform: Callable applied to each RGB PIL image.
            max_len: Token length every text is padded/truncated to.

        Raises:
            ValueError: If the three sample lists differ in length.
        """
        if not (len(image_paths) == len(texts) == len(labels)):
            raise ValueError(
                f"Length mismatch: {len(image_paths)} images, "
                f"{len(texts)} texts, {len(labels)} labels"
            )
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_len = max_len

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Load, preprocess and return sample ``idx``.

        Args:
            idx: Index of the sample to retrieve.

        Returns:
            Dict with keys:
              - 'image': output of ``transform`` on the RGB image.
              - 'input_ids', 'attention_mask': 1-D tensors of length ``max_len``
                (the batch dimension added by ``return_tensors='pt'`` is removed).
              - 'label': scalar ``torch.long`` tensor.
        """
        # Use a context manager so the file handle is released deterministically;
        # convert() returns a new, fully loaded image that outlives the file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        encoding = self.tokenizer(
            self.texts[idx], padding='max_length', truncation=True,
            max_length=self.max_len, return_tensors='pt'
        )
        return {
            'image': image,
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

In [None]:
# Attention-based Fusion Module
class AttentionFusion(nn.Module):
    """
    Gated fusion of an image vector and a text vector.

    A single linear unit scores the concatenated features. Its sigmoid gives a
    per-sample gate ``g`` in (0, 1), and the output is
    ``g * image + (1 - g) * text``. The model can therefore lean on whichever
    modality is more informative for each sample.

    ``img_dim`` and ``text_dim`` must be equal because the two inputs are
    mixed element-wise.
    """
    def __init__(self, img_dim=1024, text_dim=1024):
        super().__init__()
        # One scalar score per sample from the concatenated (img_dim + text_dim) features
        self.attn = nn.Linear(img_dim + text_dim, 1)

    def forward(self, img_features, text_features):
        """
        Blend the two feature tensors with a learned per-sample gate.

        Args:
            img_features: Tensor of shape (batch, img_dim).
            text_features: Tensor of shape (batch, text_dim).

        Returns:
            Tensor with the inputs' shape: the gated mix of the two modalities.
        """
        joint = torch.cat([img_features, text_features], dim=1)
        gate = self.attn(joint).sigmoid()  # (batch, 1), broadcasts across features
        return gate * img_features + (1 - gate) * text_features

In [None]:
# Hybrid Model
class HybridModel(nn.Module):
    """
    End-to-end multimodal classifier over an image and a text per sample.

    Architecture:
    1. Image branch: ResNet50 backbone (final fc layer removed)
       -> Linear(2048, 1024) -> BatchNorm1d -> ReLU
    2. Text branch: DistilBERT, hidden state of the first token
       -> Linear(768, 1024) -> BatchNorm1d -> ReLU
    3. AttentionFusion: per-sample sigmoid gate mixing the two 1024-d vectors
    4. Classifier: Linear(1024, num_classes) producing logits

    Both projection branches use BatchNorm1d, so in training mode every batch
    must contain more than one sample.
    """
    def __init__(self, num_classes=NUM_CLASSES):
        """
        Build the model; downloads pretrained ResNet50 and DistilBERT weights
        if they are not already cached.

        Args:
            num_classes: Number of output classes (logits).
        """
        super().__init__()
        # Image branch: pretrained ResNet50 with its classification head dropped
        self.image_model = models.resnet50(weights="IMAGENET1K_V2")
        self.image_model = nn.Sequential(*list(self.image_model.children())[:-1])
        self.image_fc = nn.Sequential(
            nn.Linear(2048, 1024),  # project pooled 2048-d ResNet features to 1024
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )

        # Text branch: pretrained DistilBERT encoder
        self.text_model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.text_fc = nn.Sequential(
            nn.Linear(768, 1024),  # project DistilBERT's 768-d hidden size to 1024
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )

        # Fusion of the two 1024-d embeddings, then the classification head
        self.fusion = AttentionFusion(img_dim=1024, text_dim=1024)
        self.classifier = nn.Linear(1024, num_classes)

    def forward(self, images, input_ids, attention_mask):
        """
        Compute class logits for a batch.

        Args:
            images: Image batch, (batch, 3, H, W).
            input_ids: Token ids, (batch, seq_len).
            attention_mask: DistilBERT attention mask, (batch, seq_len).

        Returns:
            Logits of shape (batch, num_classes).
        """
        # The truncated ResNet ends in global average pooling, (B, 2048, 1, 1).
        # flatten(..., 1) keeps the batch dimension even when B == 1, unlike a
        # bare squeeze().
        img_features = torch.flatten(self.image_model(images), 1)
        img_features = self.image_fc(img_features)

        # Element [0] is the last hidden state, (B, seq_len, 768). Use the first
        # token's vector ([CLS] with the default DistilBERT tokenizer).
        text_hidden = self.text_model(input_ids=input_ids, attention_mask=attention_mask)[0]
        text_features = self.text_fc(text_hidden[:, 0, :])

        fused_features = self.fusion(img_features, text_features)
        return self.classifier(fused_features)

In [None]:
# Training and Evaluation Functions
def _run_epoch(model, loader, criterion, optimizer=None, desc=None):
    """
    Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in eval mode with gradients
    disabled (validation).

    The loss is averaged per sample: each batch's loss is weighted by its
    size, so a smaller final batch is not over-weighted. This assumes
    ``criterion`` returns a per-batch mean, as the default
    ``nn.CrossEntropyLoss()`` does.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images, input_ids, attn_mask)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError(f"{desc or 'loader'}: data loader yielded no samples")
    return total_loss / total, correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=EPOCHS,
                checkpoint_path='best_hybrid_model.pth'):
    """
    Train the model, validating after every epoch and keeping the best weights.

    After each training epoch the model is evaluated on ``val_loader`` and
    ``scheduler.step()`` is called once. Whenever the validation loss improves,
    the state_dict is written to ``checkpoint_path``. At the end, the best
    checkpoint written during *this* call, if any, is loaded back.

    Args:
        model: Model to train (already on ``device``).
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        criterion: Loss function.
        optimizer: Optimization algorithm.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        epochs: Number of training epochs.
        checkpoint_path: File the best state_dict is saved to.

    Returns:
        The model holding the weights with the lowest validation loss seen in
        this call. If no epoch produced an improving finite loss (e.g.
        ``epochs == 0`` or NaN losses), the model keeps its final weights.
    """
    best_val_loss = float('inf')
    saved_best = False

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, optimizer, desc=f"Epoch {epoch+1}/{epochs}"
        )
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, desc="Validating")

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        scheduler.step()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            saved_best = True
            print("Saved Best Model!")

    # Only reload a checkpoint written by this run. A file left over from an
    # earlier run, or a missing file, must not be loaded.
    if saved_best:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model

In [None]:
def evaluate_model(model, test_loader, class_names=None):
    """
    Evaluate ``model`` on ``test_loader``, print metrics and plot a confusion matrix.

    Prints the overall accuracy and a scikit-learn classification report, then
    shows the confusion matrix as a seaborn heatmap.

    Args:
        model: Trained model with the ``(images, input_ids, attention_mask)``
            forward signature.
        test_loader: DataLoader yielding dicts with 'image', 'input_ids',
            'attention_mask' and 'label'.
        class_names: Optional sequence of names indexed by label id, used in
            the report and as heatmap tick labels. Defaults to numeric labels.

    Returns:
        Tuple ``(all_preds, all_labels)`` of per-sample predicted and true labels.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attention_mask'].to(device)
            # Labels are only compared on the CPU, so they are not moved to the device.
            labels = batch['label']

            preds = model(images, input_ids, attn_mask).argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    if not all_labels:
        raise ValueError("evaluate_model: test_loader yielded no samples")

    accuracy = (np.array(all_preds) == np.array(all_labels)).mean()
    print(f"Test Accuracy: {accuracy:.4f}")
    print("Classification Report:")
    # With names supplied, fix the label set to 0..K-1 so names line up even
    # if some class never appears in the predictions or the labels.
    label_ids = None if class_names is None else list(range(len(class_names)))
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    tick_labels = "auto" if class_names is None else list(class_names)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.show()

    return all_preds, all_labels

In [None]:
# Main Execution
if __name__ == "__main__":
    # Tokenizer must match the DistilBERT checkpoint used in HybridModel
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    # Collect (path, text, label) lists for each split
    train_image_paths, train_texts, train_labels = get_files_with_labels(TRAIN_PATH)
    val_image_paths, val_texts, val_labels = get_files_with_labels(VAL_PATH)
    test_image_paths, test_texts, test_labels = get_files_with_labels(TEST_PATH)

    train_dataset = HybridDataset(train_image_paths, train_texts, train_labels, tokenizer, train_transform)
    val_dataset = HybridDataset(val_image_paths, val_texts, val_labels, tokenizer, test_transform)
    test_dataset = HybridDataset(test_image_paths, test_texts, test_labels, tokenizer, test_transform)

    # The model's BatchNorm1d layers reject a training batch of size 1, so drop
    # the final batch only when it would contain exactly one sample.
    drop_singleton = len(train_dataset) % BATCH_SIZE == 1
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                              drop_last=drop_singleton)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    model = HybridModel().to(device)

    # Every trainable sub-module needs a parameter group, otherwise AdamW never
    # updates it. The projection layers, fusion gate and classifier are all
    # freshly initialised, so they share the higher "head" learning rate.
    head_params = [
        *model.image_fc.parameters(),
        *model.text_fc.parameters(),
        *model.fusion.parameters(),
        *model.classifier.parameters(),
    ]
    optimizer = optim.AdamW([
        {"params": model.image_model.parameters(), "lr": 1e-4},
        {"params": model.text_model.parameters(), "lr": 5e-6},
        {"params": head_params, "lr": 1e-4},
    ])
    # train_model calls scheduler.step() once per epoch, so anneal over EPOCHS
    # to reach eta_min by the final epoch.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

    criterion = nn.CrossEntropyLoss()

    print("Starting training...")
    model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=EPOCHS)

    print("Evaluating on test set...")
    evaluate_model(model, test_loader)