In [None]:
import os
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, CLIPVisionModel, CLIPImageProcessor
from PIL import Image
from tqdm.notebook import tqdm  # Use tqdm notebook for nice progress bars

# Select the compute device: use a CUDA GPU when PyTorch can see one, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

In [None]:
# --- Paths ---
# Relative path from 'notebooks/' to 'data/'
DATA_ROOT = "../data/processed"
IMAGES_DIR = os.path.join(DATA_ROOT, "images")
CSV_PATH = os.path.join(DATA_ROOT, "final_metadata_10k.csv")

# --- Hyperparameters ---
BATCH_SIZE = 32
LEARNING_RATE = 5e-5
EPOCHS = 3
MAX_TEXT_LEN = 128
PROJECTION_DIM = 384 # Must match MiniLM dimension
VISION_MODEL_NAME = "openai/clip-vit-base-patch16"
TEXT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

In [None]:
class Cap3DDataset(Dataset):
    """Paired (image, caption) dataset built from a metadata CSV.

    Each CSV row provides a ``uid`` (used to locate ``<images_dir>/<uid>.png``) and a
    ``text_description`` caption. Items are returned as tensors ready for the CLIP vision
    encoder (``pixel_values``) and the text encoder (``input_ids`` / ``attention_mask``).
    """

    def __init__(self, csv_path, images_dir, tokenizer_name, vision_model_name, max_text_len=None):
        """
        Args:
            csv_path (str): Path to the metadata CSV; must contain 'uid' and
                'text_description' columns.
            images_dir (str): Directory containing the images, named '<uid>.png'.
            tokenizer_name (str): Name/path passed to AutoTokenizer.from_pretrained.
            vision_model_name (str): Name/path passed to CLIPImageProcessor.from_pretrained.
            max_text_len (int, optional): Length captions are padded/truncated to.
                Defaults to the notebook-level MAX_TEXT_LEN constant.
        """
        # Read 'uid' as str so numeric-looking IDs (e.g. with leading zeros) are not
        # converted by pandas type inference, which would break the '<uid>.png' lookup.
        metadata = pd.read_csv(csv_path, dtype={'uid': str})
        # A missing caption or uid comes back as NaN. The tokenizer rejects non-string
        # input, so drop such rows up front instead of crashing in the middle of an epoch.
        self.df = metadata.dropna(subset=['uid', 'text_description']).reset_index(drop=True)
        self.images_dir = images_dir
        self.max_text_len = MAX_TEXT_LEN if max_text_len is None else max_text_len

        # Initialize Processors
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)

    def __len__(self):
        """Number of usable (uid, caption) rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and tokenize the pair at position ``idx``.

        Returns:
            dict with 'pixel_values' (processor output with the batch dim squeezed),
            and 'input_ids' / 'attention_mask' of length ``self.max_text_len``.
        """
        row = self.df.iloc[idx]
        uid = row['uid']
        text = row['text_description']

        image_path = os.path.join(self.images_dir, f"{uid}.png")

        try:
            # The context manager closes the underlying file handle deterministically.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except OSError as err:
            # OSError covers both a missing file (FileNotFoundError) and an unreadable or
            # corrupt one (PIL's UnidentifiedImageError). Either would otherwise abort training.
            print(f"Warning: Could not load image for UID {uid} at {image_path} ({err})")
            # Black placeholder keeps the batch shape valid; the processor resizes it.
            image = Image.new('RGB', (224, 224), color='black')

        # return_tensors="pt" adds a leading batch dim of 1; squeeze it for the DataLoader.
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

        text_inputs = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt"
        )

        return {
            'pixel_values': pixel_values,
            'input_ids': text_inputs['input_ids'].squeeze(0),
            'attention_mask': text_inputs['attention_mask'].squeeze(0)
        }

# Sanity check: build the dataset and inspect one sample.
# Any failure is reported and then re-raised. Swallowing it would let "Run All" continue
# into the training cell, which would then fail with a NameError on `dataset` or, worse,
# silently use a stale `dataset` left in the kernel from an earlier run.
try:
    dataset = Cap3DDataset(CSV_PATH, IMAGES_DIR, TEXT_MODEL_NAME, VISION_MODEL_NAME)
    print(f"Dataset loaded successfully with {len(dataset)} samples.")
    sample = dataset[0]
    print("Sample keys:", sample.keys())
    print("Image shape:", sample['pixel_values'].shape)
    print("Text shape:", sample['input_ids'].shape)
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Please check your paths in Cell 2.")
    raise
    print("Please check your paths in Cell 2.")

In [None]:
class MultiModalContrastiveModel(nn.Module):
    """Dual-encoder (text + image) model for CLIP-style contrastive training.

    Text is encoded with ``AutoModel`` and mean-pooled over non-padding tokens. Images are
    encoded with ``CLIPVisionModel`` and linearly projected into the text embedding space.
    Both embeddings are L2-normalised, so their dot product is a cosine similarity.
    """

    def __init__(self, text_model_name, vision_model_name, projection_dim=384, max_logit_scale=100.0):
        """
        Args:
            text_model_name (str): Name/path of the text encoder (AutoModel).
            vision_model_name (str): Name/path of the CLIP vision encoder.
            projection_dim (int): Output size of the image projection. Text embeddings are
                not projected, so this must equal the text encoder's hidden size.
            max_logit_scale (float): Upper bound on the returned temperature multiplier
                exp(logit_scale); CLIP uses 100 to keep training stable.

        Raises:
            ValueError: if ``projection_dim`` differs from the text encoder's hidden size.
        """
        super().__init__()

        # 1. Text Encoder
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        text_hidden_size = self.text_encoder.config.hidden_size
        # Fail fast with a clear message instead of a matmul shape error during training.
        if projection_dim != text_hidden_size:
            raise ValueError(
                f"projection_dim ({projection_dim}) must equal the text encoder hidden size "
                f"({text_hidden_size}) because text embeddings are not projected."
            )

        # 2. Vision Encoder
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)

        # 3. Projection Head: vision hidden size -> text embedding size
        vision_hidden_size = self.vision_encoder.config.hidden_size
        self.vision_projection = nn.Linear(vision_hidden_size, projection_dim)

        # Learnable temperature (log-space), initialised to log(1/0.07) ~= 2.6592 as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
        # Plain attribute (not a parameter/buffer), so saved state dicts are unaffected.
        self.max_logit_scale = float(max_logit_scale)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Encode a batch of texts and images.

        Returns:
            (text_embeds, image_embeds, logit_scale): L2-normalised embeddings of size
            ``projection_dim`` each, and the scalar similarity multiplier
            exp(logit_scale), clamped to ``max_logit_scale``.
        """
        # --- Text Embedding: masked mean pooling over token states ---
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = text_outputs.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # clamp avoids division by zero for an all-padding row
        text_embeds = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        # --- Image Embedding: pooled output of the CLIP vision tower, then projection ---
        vision_outputs = self.vision_encoder(pixel_values=pixel_values)
        image_embeds = self.vision_projection(vision_outputs.pooler_output)

        # --- Normalization ---
        text_embeds = F.normalize(text_embeds, p=2, dim=1)
        image_embeds = F.normalize(image_embeds, p=2, dim=1)

        # Bound the temperature multiplier so logits cannot blow up during training.
        logit_scale = self.logit_scale.exp().clamp(max=self.max_logit_scale)
        return text_embeds, image_embeds, logit_scale

In [None]:
# --- 1. Data Loader ---
# num_workers=0 keeps loading in the main process, which is simplest to debug.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# --- 2. Model & Optimizer ---
model = MultiModalContrastiveModel(TEXT_MODEL_NAME, VISION_MODEL_NAME, projection_dim=PROJECTION_DIM).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()


def symmetric_contrastive_loss(text_embeds, image_embeds, logit_scale, loss_fn):
    """CLIP-style symmetric contrastive loss for a batch of matched (text, image) pairs.

    Row i of ``text_embeds`` is the positive for row i of ``image_embeds``; every other
    row in the batch acts as a negative. The loss is the mean of the text->image and
    image->text cross-entropies over the (batch, batch) similarity matrix.
    """
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
    logits_per_image = logits_per_text.t()
    # Targets are the diagonal indices 0..B-1, created directly on the logits' device.
    labels = torch.arange(logits_per_text.shape[0], device=logits_per_text.device)
    return (loss_fn(logits_per_text, labels) + loss_fn(logits_per_image, labels)) / 2


# --- 3. Training Loop ---
print("Starting Training...")

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        pixel_values = batch['pixel_values'].to(device)

        text_embeds, image_embeds, logit_scale = model(input_ids, attention_mask, pixel_values)
        total_loss = symmetric_contrastive_loss(text_embeds, image_embeds, logit_scale, criterion)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # One .item() per step means a single device->host sync instead of two.
        loss_value = total_loss.item()
        epoch_loss += loss_value
        progress_bar.set_postfix({"loss": loss_value})

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")

print("Training finished!")

In [None]:
# Persist only the learned weights (the state dict), not the pickled module object.
save_path = "../models/contrastive_finetuned"
os.makedirs(save_path, exist_ok=True)

weights_file = os.path.join(save_path, "model_state_dict.pth")
torch.save(model.state_dict(), weights_file)

print(f"Model saved to {save_path}")