ENEL-645 Assignment 2


Team Members:
Jaskirat Singh
Kate Reimann
Riley Koppang
Roxanne Mai

// Intro and purpose

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image

# Hugging Face transformers for text
from transformers import DistilBertModel, DistilBertTokenizer

import os
import re

In [None]:
# Data pre-processing (the data split has already been done)
class MultiModalDataset(Dataset):
    """
    Dataset that pairs each image with its text and an integer label.

    Each item is a dict with:
      - 'image': the image converted to RGB, passed through
        ``image_transform`` when one is given,
      - 'input_ids' / 'attention_mask': 1-D tensors of length
        ``max_text_len`` produced by the tokenizer,
      - 'label': the label as a ``torch.long`` scalar tensor.
    """
    def __init__(self,
                 image_paths,
                 texts,
                 labels,
                 tokenizer,
                 image_transform=None,
                 max_text_len=32):
        """
        Args:
            image_paths (List[str]): Paths to image files.
            texts (List[str]): Corresponding text for each image.
            labels (List[int]): Integer labels for classification.
            tokenizer: DistilBertTokenizer (or any Hugging Face tokenizer).
            image_transform: torchvision transforms applied to each image,
                or None to return the RGB PIL image unchanged.
            max_text_len (int): Every text is padded/truncated to exactly
                this many tokens.

        Raises:
            ValueError: if image_paths, texts and labels differ in length
                (``__len__`` only looks at labels, so a mismatch would
                otherwise surface later as an IndexError or silently
                dropped samples).
        """
        if not (len(image_paths) == len(texts) == len(labels)):
            raise ValueError(
                "image_paths, texts and labels must have the same length "
                f"(got {len(image_paths)}, {len(texts)}, {len(labels)})"
            )
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_text_len = max_text_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # ---- Image ----
        # Use Image.open as a context manager so the file handle is closed
        # deterministically; convert() returns an independent image, so it
        # remains valid after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.image_transform is not None:
            image = self.image_transform(image)

        # ---- Text ----
        # Calling the tokenizer directly is the supported replacement for the
        # deprecated encode_plus(); the keyword arguments are the same.
        encoding = self.tokenizer(
            str(self.texts[idx]),
            add_special_tokens=True,
            max_length=self.max_text_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # return_tensors='pt' adds a leading batch dimension of size 1.
        # Remove only that dimension: a bare squeeze() would also collapse the
        # sequence dimension when max_text_len == 1.
        input_ids = encoding['input_ids'].squeeze(0)            # [max_text_len]
        attention_mask = encoding['attention_mask'].squeeze(0)  # [max_text_len]

        # ---- Label ----
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {
            'image': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label,
        }

Image transforms and text tokenization

In [None]:
# Image pre-processing pipeline (adjust as needed):
#   1. resize every image to 224x224,
#   2. convert the PIL image to a tensor,
#   3. normalize each channel with the standard ImageNet mean / std.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Tokenizer for the text branch; uses the same checkpoint name as the
# DistilBERT encoder loaded in MultiModalClassifier.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

Multimodal (image + text) classifier

In [None]:
class MultiModalClassifier(nn.Module):
    """
    Late-fusion image + text classifier.

    A pretrained ResNet18 encodes the image and a pretrained DistilBERT encodes
    the text. Each modality is linearly projected to ``projection_dim`` and
    L2-normalized. The two vectors are concatenated and passed to a single
    linear classification layer.
    """

    def __init__(self, num_classes=4, projection_dim=128,
                 text_model_name='distilbert-base-uncased'):
        """
        Args:
            num_classes (int): number of output classes.
            projection_dim (int): size of each modality's projected feature
                vector; the classifier input is 2 * projection_dim.
            text_model_name (str): Hugging Face model id or local path of the
                DistilBERT encoder. It should match the tokenizer used to
                build ``input_ids``.
        """
        super().__init__()

        # ----- Image feature extractor (ImageNet-pretrained ResNet18) -----
        # The final fc layer is kept, so the image "features" are the
        # ImageNet class scores (fc.out_features == 1000 for resnet18).
        self.image_model = self._pretrained_resnet18()
        image_feature_dim = self.image_model.fc.out_features

        # ----- Text feature extractor (DistilBERT) -----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        text_hidden_size = self.text_model.config.hidden_size  # 768 for distilbert-base

        # ----- Per-modality projections to a shared size -----
        self.image_proj = nn.Linear(image_feature_dim, projection_dim)
        self.text_proj = nn.Linear(text_hidden_size, projection_dim)

        # ----- Classifier over the concatenated [image, text] features -----
        self.classifier = nn.Linear(2 * projection_dim, num_classes)

    @staticmethod
    def _pretrained_resnet18():
        """Build a ResNet18 with ImageNet weights on old and new torchvision."""
        try:
            # torchvision >= 0.13: IMAGENET1K_V1 are the weights that the
            # deprecated pretrained=True flag loads.
            weights = models.ResNet18_Weights.IMAGENET1K_V1
        except AttributeError:
            # Older torchvision has no weight enums; fall back to the legacy flag.
            return models.resnet18(pretrained=True)
        return models.resnet18(weights=weights)

    def forward(self, images, input_ids, attention_mask):
        """
        Args:
            images: Tensor [batch, 3, 224, 224]
            input_ids: Tensor [batch, max_len]
            attention_mask: Tensor [batch, max_len]

        Returns:
            Tensor [batch, num_classes] of unnormalized class logits.
        """
        # ----- Image branch -----
        f_image = self.image_model(images)                 # [batch, image_feature_dim]
        x_image = self.image_proj(f_image)                 # [batch, projection_dim]
        # L2-normalize so both modalities enter the fusion on the same scale.
        xnorm_image = F.normalize(x_image, p=2, dim=1)

        # ----- Text branch -----
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = text_outputs[0]                # [batch, seq_len, hidden_size]
        # DistilBERT has no pooler, so use the hidden state of the first token
        # (the [CLS] token when the tokenizer adds special tokens).
        cls_text = last_hidden_state[:, 0]                 # [batch, hidden_size]
        x_text = self.text_proj(cls_text)                  # [batch, projection_dim]
        xnorm_text = F.normalize(x_text, p=2, dim=1)

        # ----- Fuse & classify -----
        combined = torch.cat([xnorm_image, xnorm_text], dim=1)  # [batch, 2*projection_dim]
        return self.classifier(combined)                         # [batch, num_classes]