# VisionRec V2 - Multimodal Training Pipeline

This notebook implements the training pipeline for a custom multimodal embedding model that learns a shared embedding space for both images and text using triplet loss.

## Architecture Overview
- **Dual Encoder**: Separate encoders for images and text
- **Image Encoder**: ResNet-50 backbone with projection head
- **Text Encoder**: Embedding layer + bidirectional LSTM encoder with projection head
- **Loss Function**: Triplet Loss with semi-hard negative mining
- **Triplet Types**: Image-Image, Text-Text, and Cross-Modal (image and text embeddings mined together)

## Training Strategy
1. Load and preprocess multimodal dataset
2. Create triplets with semi-hard negative mining
3. Train dual encoders to produce normalized embeddings
4. Evaluate on the validation and test sets after every epoch
5. Save trained model weights and configuration

## 1. Import Dependencies

Import all required libraries for dataset handling, model building, training, and visualization.

In [None]:
import os
import json
import random
import numpy as np
from pathlib import Path
from PIL import Image
from typing import List, Tuple, Dict
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

import matplotlib.pyplot as plt
from tqdm import tqdm

# Seed every random number generator the pipeline relies on
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus all CUDA devices when available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# Pick the GPU when one is visible, otherwise stay on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 2. Configuration

Define all hyperparameters and paths for training.

In [None]:
class Config:
    """Central collection of paths and hyperparameters read by the rest of the notebook."""

    # Paths
    # NOTE(review): machine-specific absolute Windows path — must be edited before running elsewhere.
    DATASET_ROOT = r"E:\Projects\AI Based\RecTrio\datasets\animals\raw-img"
    # NOTE(review): relative path under a hidden ".kaggle" directory — confirm this is
    # intended rather than the absolute "/kaggle/working/..." location.
    OUTPUT_DIR = r".kaggle/working/outputs/model"
    
    # Model architecture
    EMBEDDING_DIM = 512        # size of the shared image/text embedding
    IMAGE_SIZE = 224           # images are resized to IMAGE_SIZE x IMAGE_SIZE
    TEXT_MAX_LENGTH = 50       # token sequences are padded/truncated to this length
    VOCAB_SIZE = 10000         # upper bound for the vocabulary; overwritten with the real size once it is built
    TEXT_EMBED_DIM = 300       # word-embedding width of the text encoder
    TEXT_HIDDEN_DIM = 256      # LSTM hidden size (per direction) of the text encoder
    
    # Training hyperparameters
    BATCH_SIZE = 32
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    
    # Triplet loss
    MARGIN = 0.5
    MINING_STRATEGY = 'semi-hard'  # semi-hard negative mining (only reported/saved; the loss class is not switched on it)
    
    # Data split
    TRAIN_SPLIT = 0.8
    VAL_SPLIT = 0.1
    TEST_SPLIT = 0.1  # only reported; the test set is whatever remains after train and val
    
    # Training
    NUM_WORKERS = 2  # DataLoader worker processes
    SAVE_FREQ = 5  # Save model every N epochs
    
config = Config()

# Create output directory
os.makedirs(config.OUTPUT_DIR, exist_ok=True)

# Echo the key settings so they are recorded in the notebook output
print("Configuration:")
print(f"  Dataset: {config.DATASET_ROOT}")
print(f"  Output: {config.OUTPUT_DIR}")
print(f"  Embedding Dim: {config.EMBEDDING_DIM}")
print(f"  Batch Size: {config.BATCH_SIZE}")
print(f"  Epochs: {config.NUM_EPOCHS}")
print(f"  Learning Rate: {config.LEARNING_RATE}")
print(f"  Margin: {config.MARGIN}")

## 3. Dataset Preparation

Build vocabulary and prepare the multimodal dataset with image paths, text labels, and class IDs.

In [None]:
def load_dataset(dataset_root: str):
    """
    Load an image dataset laid out as one sub-directory per class.

    Expected structure: dataset_root/class_name/image_files

    Class folders and the files inside each folder are visited in sorted
    order. ``os.listdir`` returns entries in arbitrary, platform-dependent
    order, so sorting keeps the sample list (and any seeded shuffle applied
    to it afterwards) identical from run to run and machine to machine.

    Args:
        dataset_root: Directory whose immediate sub-directories are the classes.

    Returns:
        A 3-tuple ``(dataset, class_to_id, class_names)``:
        - dataset: list of ``(image_path, text_label, class_id)`` tuples, where
          the text label is the class folder name;
        - class_to_id: dict mapping class name to its integer id;
        - class_names: sorted list of class names (index == class id).
    """
    class_names = sorted(
        entry for entry in os.listdir(dataset_root)
        if os.path.isdir(os.path.join(dataset_root, entry))
    )
    class_to_id = {name: idx for idx, name in enumerate(class_names)}

    # Only these extensions (case-insensitive) are treated as images
    image_suffixes = ('.jpg', '.jpeg', '.png')

    dataset = []
    for class_name in class_names:
        class_dir = os.path.join(dataset_root, class_name)
        class_id = class_to_id[class_name]

        for img_file in sorted(os.listdir(class_dir)):
            if img_file.lower().endswith(image_suffixes):
                img_path = os.path.join(class_dir, img_file)
                dataset.append((img_path, class_name, class_id))

    return dataset, class_to_id, class_names

# Load dataset
# full_dataset is a list of (image_path, class_name, class_id) tuples.
print("Loading dataset...")
full_dataset, class_to_id, class_names = load_dataset(config.DATASET_ROOT)

print(f"\nDataset Statistics:")
print(f"  Total samples: {len(full_dataset)}")
print(f"  Number of classes: {len(class_names)}")
print(f"  Classes: {class_names}")

# Count samples per class
# (keyed by the class-name label, i.e. the second element of each tuple)
class_counts = defaultdict(int)
for _, label, _ in full_dataset:
    class_counts[label] += 1

# Report the class balance in class-id order
print("\nSamples per class:")
for class_name in class_names:
    print(f"  {class_name}: {class_counts[class_name]}")

In [None]:
# Word-level vocabulary built from the class-name labels
class Vocabulary:
    """
    Bidirectional word <-> index mapping.

    Index 0 is reserved for ``<PAD>`` and index 1 for ``<UNK>``; real words
    are appended after them in descending frequency order.
    """

    def __init__(self, max_vocab_size=10000):
        reserved = ('<PAD>', '<UNK>')
        self.word2idx = {token: position for position, token in enumerate(reserved)}
        self.idx2word = dict(enumerate(reserved))
        self.word_counts = defaultdict(int)
        self.max_vocab_size = max_vocab_size

    def build_vocab(self, texts: List[str]):
        """Count lower-cased, whitespace-separated words and register the most frequent ones."""
        for text in texts:
            for word in text.lower().split():
                self.word_counts[word] += 1

        # Two slots are already taken by <PAD> and <UNK>
        budget = self.max_vocab_size - 2
        by_frequency = sorted(self.word_counts.items(),
                              key=lambda item: item[1], reverse=True)

        for word, _count in by_frequency[:budget]:
            next_idx = len(self.word2idx)
            self.word2idx[word] = next_idx
            self.idx2word[next_idx] = word

    def encode(self, text: str, max_length: int) -> List[int]:
        """Map text to exactly ``max_length`` indices (unknown words -> <UNK>, right-padded with <PAD>)."""
        ids = [self.word2idx.get(token, self.word2idx['<UNK>'])
               for token in text.lower().split()]

        if len(ids) >= max_length:
            # Too long (or exactly right): keep the leading tokens only
            return ids[:max_length]

        missing = max_length - len(ids)
        return ids + [self.word2idx['<PAD>']] * missing

    def __len__(self):
        return len(self.word2idx)

# Build vocabulary
# The only text in this dataset is the class-name label of each sample.
vocab = Vocabulary(max_vocab_size=config.VOCAB_SIZE)
all_labels = [label for _, label, _ in full_dataset]
vocab.build_vocab(all_labels)

print(f"\nVocabulary built:")
print(f"  Vocabulary size: {len(vocab)}")
print(f"  Sample words: {list(vocab.word2idx.keys())[:10]}")

# Update config with actual vocab size
# (the model later sizes its embedding table from config.VOCAB_SIZE)
config.VOCAB_SIZE = len(vocab)

In [None]:
# Split dataset into train, validation, and test sets
# Shuffles full_dataset in place with the `random` module (seeded by set_seed above).
random.shuffle(full_dataset)

train_size = int(config.TRAIN_SPLIT * len(full_dataset))
val_size = int(config.VAL_SPLIT * len(full_dataset))

# The test set takes whatever is left after the train and validation slices,
# so TEST_SPLIT itself is not used to compute a size.
train_data = full_dataset[:train_size]
val_data = full_dataset[train_size:train_size + val_size]
test_data = full_dataset[train_size + val_size:]

# The percentages printed here are the configured ratios, not measured ones.
print(f"\nDataset splits:")
print(f"  Train: {len(train_data)} samples ({config.TRAIN_SPLIT * 100}%)")
print(f"  Validation: {len(val_data)} samples ({config.VAL_SPLIT * 100}%)")
print(f"  Test: {len(test_data)} samples ({config.TEST_SPLIT * 100}%)")

## 4. Custom Dataset Class

PyTorch Dataset class that handles multimodal data loading and preprocessing.

In [None]:
class MultimodalDataset(Dataset):
    """
    Multimodal dataset for image-text embedding learning.

    Each sample is a dict with the transformed image, the encoded text
    tokens, the integer class id and the raw text label.

    Args:
        data: List of ``(image_path, text_label, class_id)`` tuples.
        vocab: Vocabulary used to turn the text label into token indices.
        transform: Optional callable applied to the loaded RGB PIL image.
        max_text_length: Length every token sequence is padded/truncated to.
        fallback_image_size: Side length of the all-zero image returned when
            an image cannot be loaded. It has to match the size produced by
            ``transform`` or batching will fail on a mixed batch.
    """

    def __init__(self, data: List[Tuple], vocab: Vocabulary,
                 transform=None, max_text_length=50, fallback_image_size=224):
        self.data = data
        self.vocab = vocab
        self.transform = transform
        self.max_text_length = max_text_length
        self.fallback_image_size = fallback_image_size

        # Group sample indices by class for triplet sampling
        self.class_to_indices = defaultdict(list)
        for idx, (_, _, class_id) in enumerate(data):
            self.class_to_indices[class_id].append(idx)

        self.classes = list(self.class_to_indices.keys())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict (keys: image, text, class_id, text_label)."""
        img_path, text_label, class_id = self.data[idx]

        # Load and transform image
        try:
            # The context manager releases the file handle as soon as the
            # pixels are decoded instead of waiting for garbage collection.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort an epoch
            print(f"Error loading image {img_path}: {e}")
            # Return a blank image if loading fails
            image = torch.zeros(3, self.fallback_image_size, self.fallback_image_size)

        # Encode text
        text_tokens = self.vocab.encode(text_label, self.max_text_length)
        text_tokens = torch.tensor(text_tokens, dtype=torch.long)

        return {
            'image': image,
            'text': text_tokens,
            'class_id': class_id,
            'text_label': text_label
        }

    def get_triplet(self, anchor_idx):
        """
        Generate a triplet of sample indices: (anchor, positive, negative).

        The positive shares the anchor's class (and is a different sample
        whenever the class has more than one); the negative is drawn from a
        randomly chosen other class.

        Raises:
            IndexError: if the dataset contains only one class, because
                there is no other class to draw a negative from.
        """
        anchor_class = self.data[anchor_idx][2]

        # Get positive (same class, different sample)
        positive_candidates = [i for i in self.class_to_indices[anchor_class]
                               if i != anchor_idx]
        if positive_candidates:
            positive_idx = random.choice(positive_candidates)
        else:
            positive_idx = anchor_idx  # Fallback if only one sample in class

        # Get negative (different class)
        negative_classes = [c for c in self.classes if c != anchor_class]
        negative_class = random.choice(negative_classes)
        negative_idx = random.choice(self.class_to_indices[negative_class])

        return anchor_idx, positive_idx, negative_idx

# Image transformations
# Training pipeline: resize plus light augmentation (flip, small rotation, colour jitter).
train_transform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    # presumably the standard ImageNet channel statistics, matching the
    # pretrained ResNet-50 backbone — confirm if the backbone changes
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

# Evaluation pipeline: same resize and normalization, no random augmentation.
test_transform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

# Create datasets
# Validation and test both use the deterministic test_transform.
train_dataset = MultimodalDataset(train_data, vocab, train_transform, 
                                 config.TEXT_MAX_LENGTH)
val_dataset = MultimodalDataset(val_data, vocab, test_transform, 
                               config.TEXT_MAX_LENGTH)
test_dataset = MultimodalDataset(test_data, vocab, test_transform, 
                                config.TEXT_MAX_LENGTH)

print(f"\nDatasets created:")
print(f"  Train dataset: {len(train_dataset)} samples")
print(f"  Val dataset: {len(val_dataset)} samples")
print(f"  Test dataset: {len(test_dataset)} samples")

# Test dataset
# Smoke check: fetch one training sample and show what the loader will receive.
sample = train_dataset[0]
print(f"\nSample data:")
print(f"  Image shape: {sample['image'].shape}")
print(f"  Text tokens shape: {sample['text'].shape}")
print(f"  Class ID: {sample['class_id']}")
print(f"  Text label: {sample['text_label']}")

## 5. Model Architecture

Dual encoder architecture with separate encoders for images and text, both projecting to a shared embedding space.

In [None]:
class ImageEncoder(nn.Module):
    """
    ResNet-50 feature extractor followed by a two-layer projection head.

    ``forward`` returns one L2-normalized vector of size ``embedding_dim``
    per input image.
    """

    def __init__(self, embedding_dim=512):
        super().__init__()

        # Pre-trained ResNet-50 with its final FC layer dropped
        resnet = models.resnet50(pretrained=True)
        feature_layers = list(resnet.children())
        feature_layers.pop()
        self.backbone = nn.Sequential(*feature_layers)

        # Projection head: 2048-d backbone features -> embedding space
        head_layers = [
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, embedding_dim),
        ]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, x):
        pooled = self.backbone(x)
        # Collapse everything except the batch dimension
        flat = pooled.view(pooled.size(0), -1)
        projected = self.projection(flat)
        # Unit-length rows so dot products are cosine similarities
        return F.normalize(projected, p=2, dim=1)


class TextEncoder(nn.Module):
    """
    Text encoder using embedding layer + bidirectional LSTM + projection head.
    Outputs L2-normalized embeddings.

    Padding is excluded from the encoding: the LSTM only runs over the real
    tokens of each sequence and the mean pooling divides by the real token
    count. The embedding of a label therefore does not depend on how many
    ``<PAD>`` tokens follow it.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Word-embedding width.
        hidden_dim: LSTM hidden size per direction.
        embedding_dim: Size of the output embedding.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=256,
                 embedding_dim=512, num_layers=2):
        super(TextEncoder, self).__init__()

        # Embedding layer (index 0 is the padding token and stays a zero vector)
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # LSTM
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=True, dropout=0.3)

        # Projection head (bidirectional LSTM outputs hidden_dim * 2)
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim)
        )

    def forward(self, x):
        """
        Encode a batch of token-index sequences.

        Args:
            x: Long tensor [B, max_length]; sequences are assumed to be
               right-padded with the padding index (as Vocabulary.encode does).

        Returns:
            Tensor [B, embedding_dim] with unit-length rows.
        """
        # Embed tokens
        embedded = self.embedding(x)

        # Number of real tokens per row; clamp so an all-padding row is still
        # a valid length-1 sequence for packing.
        lengths = (x != self.embedding.padding_idx).sum(dim=1).clamp(min=1)

        # Pack so neither LSTM direction ever consumes padding positions
        # (pack_padded_sequence requires the lengths on the CPU).
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        # Unpack back to [B, max_length, 2 * hidden_dim]; padded steps are zeros
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1)
        )

        # Mean over the real tokens only: padded steps contribute zero to the
        # sum and are left out of the divisor.
        text_features = lstm_out.sum(dim=1) / lengths.unsqueeze(1).to(lstm_out.dtype)

        # Project to embedding space
        embeddings = self.projection(text_features)

        # L2 normalize
        embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings


class DualEncoderModel(nn.Module):
    """
    Pairs an image encoder with a text encoder that both emit vectors of the
    same size, so the two modalities can be compared directly.
    """

    def __init__(self, vocab_size, embedding_dim=512,
                 text_embed_dim=300, text_hidden_dim=256):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.image_encoder = ImageEncoder(embedding_dim)
        self.text_encoder = TextEncoder(
            vocab_size, text_embed_dim, text_hidden_dim, embedding_dim
        )

    def encode_image(self, images):
        """Run only the image branch."""
        return self.image_encoder(images)

    def encode_text(self, texts):
        """Run only the text branch."""
        return self.text_encoder(texts)

    def forward(self, images, texts):
        """
        Encode one batch of each modality.

        Args:
            images: Batch of images [B, 3, H, W]
            texts: Batch of text tokens [B, max_length]

        Returns:
            (image_embeddings, text_embeddings), each [B, embedding_dim]
        """
        image_embeddings = self.image_encoder(images)
        text_embeddings = self.text_encoder(texts)
        return image_embeddings, text_embeddings

# Initialize model
# config.VOCAB_SIZE already holds the real vocabulary size at this point.
model = DualEncoderModel(
    vocab_size=config.VOCAB_SIZE,
    embedding_dim=config.EMBEDDING_DIM,
    text_embed_dim=config.TEXT_EMBED_DIM,
    text_hidden_dim=config.TEXT_HIDDEN_DIM
).to(device)

print(f"\nModel Architecture:")
print(f"  Image Encoder: ResNet-50 + Projection Head")
print(f"  Text Encoder: Embedding + BiLSTM + Projection Head")
print(f"  Embedding Dimension: {config.EMBEDDING_DIM}")

# Count parameters
# (total vs. those with requires_grad=True)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

## 6. Triplet Loss with Semi-Hard Negative Mining

Implement a triplet loss with semi-hard negative mining and apply it to three views of each batch: image-image, text-text, and cross-modal (image and text embeddings mined together).

In [None]:
class TripletLoss(nn.Module):
    """
    Triplet loss with online mining inside a batch.

    One triplet is mined per anchor:
    - positive: the farthest sample of the same class (hardest positive);
    - negative: the closest semi-hard negative, i.e. a different-class
      sample farther than the positive but within ``margin`` of it; if the
      batch has none, the closest negative overall is used instead.

    Only triplets with a strictly positive loss are averaged.

    Args:
        margin: Required gap between positive and negative distances.
    """

    def __init__(self, margin=0.5):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, embeddings, labels):
        """
        Compute triplet loss with semi-hard negative mining.

        Args:
            embeddings: [batch_size, embedding_dim] - L2 normalized
            labels: [batch_size] - class labels

        Returns:
            loss: scalar tensor. When no triplet is active this is a zero
                that is still connected to ``embeddings`` in the autograd
                graph, so calling ``backward()`` on it is safe.
            num_triplets: number of triplets with a positive loss
        """
        # Compute pairwise distances (using dot product since normalized)
        # Squared Euclidean distance of unit vectors = 2 - 2 * cosine_similarity
        dot_product = torch.matmul(embeddings, embeddings.t())
        distances = 2.0 - 2.0 * dot_product

        # Same-class mask; positives additionally exclude the anchor itself
        labels = labels.unsqueeze(1)
        same_class = labels == labels.t()
        positive_mask = same_class.clone()
        positive_mask.fill_diagonal_(False)  # Exclude self

        # Negative pairs are simply the different-class pairs
        negative_mask = ~same_class

        # For each anchor, find hardest positive and semi-hard negative
        losses = []

        for i in range(embeddings.size(0)):
            pos_dists = distances[i][positive_mask[i]]
            neg_dists = distances[i][negative_mask[i]]
            # An anchor needs at least one positive and one negative in the batch
            if pos_dists.numel() == 0 or neg_dists.numel() == 0:
                continue

            # Hardest positive (farthest positive)
            hardest_positive_dist = pos_dists.max()

            # Semi-hard negatives: farther than positive but within margin
            semi_hard_negatives = neg_dists[
                (neg_dists > hardest_positive_dist) &
                (neg_dists < hardest_positive_dist + self.margin)
            ]

            if semi_hard_negatives.numel() > 0:
                # Use hardest semi-hard negative
                hardest_negative_dist = semi_hard_negatives.min()
            else:
                # If no semi-hard negative, use hardest negative
                hardest_negative_dist = neg_dists.min()

            # Compute triplet loss
            loss = F.relu(hardest_positive_dist - hardest_negative_dist + self.margin)

            if loss > 0:
                losses.append(loss)

        if not losses:
            # A fresh torch.tensor(0.0) has no grad_fn, so backward() on it
            # raises; multiplying the embeddings by zero keeps the graph.
            return embeddings.sum() * 0.0, 0

        return torch.stack(losses).mean(), len(losses)


def compute_multimodal_triplet_loss(model, batch, triplet_loss_fn, device):
    """
    Average the triplet loss over three views of one batch:
    1. Image-Image: image embeddings only
    2. Text-Text: text embeddings only
    3. Cross-modal: image and text embeddings concatenated into one pool,
       so an image and a text of the same class act as positives

    Args:
        model: Callable mapping (images, texts) to (image_emb, text_emb).
        batch: Dict with tensors under 'image', 'text' and 'class_id'.
        triplet_loss_fn: Callable mapping (embeddings, labels) to (loss, count).
        device: Device the batch tensors are moved to.

    Returns:
        (total_loss, stats): the mean of the three losses, and a dict with
        each loss as a Python float plus the triplet count of each view.
    """
    class_ids = batch['class_id'].to(device)
    img_emb, txt_emb = model(batch['image'].to(device), batch['text'].to(device))

    # Single-modality views
    img_loss, img_count = triplet_loss_fn(img_emb, class_ids)
    txt_loss, txt_count = triplet_loss_fn(txt_emb, class_ids)

    # Joint view: stack both modalities and duplicate the labels to match
    joint_emb = torch.cat((img_emb, txt_emb), dim=0)
    joint_ids = torch.cat((class_ids, class_ids), dim=0)
    cross_loss, cross_count = triplet_loss_fn(joint_emb, joint_ids)

    # Equal weighting of the three views
    mean_loss = (img_loss + txt_loss + cross_loss) / 3.0

    stats = {
        'loss_img_img': img_loss.item(),
        'loss_txt_txt': txt_loss.item(),
        'loss_cross_modal': cross_loss.item(),
        'num_img_img': img_count,
        'num_txt_txt': txt_count,
        'num_cross': cross_count
    }
    return mean_loss, stats

# Initialize loss function
# A single TripletLoss instance is shared by all three views (image, text, cross-modal).
triplet_loss_fn = TripletLoss(margin=config.MARGIN)

# MINING_STRATEGY is printed for the record only; it is not passed to TripletLoss.
print(f"\nTriplet Loss Configuration:")
print(f"  Margin: {config.MARGIN}")
print(f"  Mining Strategy: {config.MINING_STRATEGY}")
print(f"  Triplet Types: Image-Image, Text-Text, Cross-Modal")

## 7. Training and Validation Functions

Implement training loop with validation and metric computation.

In [None]:
def _topk_hit_rate(scores, labels, k):
    """Fraction of rows of ``scores`` whose top-k columns include a column with the row's own label."""
    # topk raises if k exceeds the number of candidates, so clip it
    k = min(k, scores.size(1))
    _, top_indices = scores.topk(k, dim=1)
    retrieved_labels = labels[top_indices.to(labels.device)]       # [N, k]
    hits = (retrieved_labels == labels.unsqueeze(1)).any(dim=1)    # [N]
    return hits.sum().item() / len(labels)


def compute_retrieval_accuracy(image_embeddings, text_embeddings, labels, k=5):
    """
    Compute class-level top-k retrieval accuracy in both directions:
    - Image-to-Text: for each image, is a text of the same class among the
      k most similar texts?
    - Text-to-Image: for each text, is an image of the same class among the
      k most similar images?

    Row i of both embedding matrices is expected to belong to ``labels[i]``.

    Args:
        image_embeddings: [N, embedding_dim]
        text_embeddings: [N, embedding_dim]
        labels: [N] class labels
        k: top-k cutoff; clipped to the number of candidates when larger

    Returns:
        img2txt_acc: Image-to-text retrieval accuracy (float in [0, 1])
        txt2img_acc: Text-to-image retrieval accuracy (float in [0, 1])
        Both are 0.0 for an empty input.
    """
    if len(labels) == 0:
        # Nothing to retrieve; avoids a division by zero
        return 0.0, 0.0

    # Compute similarity matrix (cosine similarity for normalized embeddings)
    similarity = torch.matmul(image_embeddings, text_embeddings.t())

    # Rows are images for image-to-text, and texts after the transpose
    img2txt_acc = _topk_hit_rate(similarity, labels, k)
    txt2img_acc = _topk_hit_rate(similarity.t(), labels, k)

    return img2txt_acc, txt2img_acc


def train_epoch(model, dataloader, optimizer, triplet_loss_fn, device, epoch):
    """
    Train for one epoch.

    Args:
        model: Model to optimize; put into train mode here.
        dataloader: Iterable of batches accepted by compute_multimodal_triplet_loss.
        optimizer: Optimizer stepping the model parameters.
        triplet_loss_fn: Loss callable forwarded to compute_multimodal_triplet_loss.
        device: Device the batches are moved to.
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        Dict with the per-batch averages of 'loss', 'loss_img_img',
        'loss_txt_txt' and 'loss_cross_modal' (all 0.0 for an empty loader).
    """
    model.train()

    total_loss = 0.0
    total_img_img_loss = 0.0
    total_txt_txt_loss = 0.0
    total_cross_modal_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch} [Train]")

    for batch in pbar:
        optimizer.zero_grad()

        # Compute loss
        loss, loss_dict = compute_multimodal_triplet_loss(
            model, batch, triplet_loss_fn, device
        )

        # Backward pass. A batch without any active triplet can yield a
        # constant zero that is detached from the graph; backward() would
        # raise on it and there is nothing to learn, so skip the update.
        if loss.requires_grad:
            loss.backward()
            optimizer.step()

        # Accumulate losses
        loss_value = loss.item()
        total_loss += loss_value
        total_img_img_loss += loss_dict['loss_img_img']
        total_txt_txt_loss += loss_dict['loss_txt_txt']
        total_cross_modal_loss += loss_dict['loss_cross_modal']

        # Update progress bar
        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'img_img': f"{loss_dict['loss_img_img']:.4f}",
            'txt_txt': f"{loss_dict['loss_txt_txt']:.4f}",
            'cross': f"{loss_dict['loss_cross_modal']:.4f}"
        })

    # Compute averages (guard against an empty loader)
    num_batches = max(len(dataloader), 1)
    return {
        'loss': total_loss / num_batches,
        'loss_img_img': total_img_img_loss / num_batches,
        'loss_txt_txt': total_txt_txt_loss / num_batches,
        'loss_cross_modal': total_cross_modal_loss / num_batches
    }


def validate(model, dataloader, triplet_loss_fn, device, epoch, split_name="Val"):
    """
    Evaluate the model on one split: average triplet loss plus top-5
    retrieval accuracy in both directions.

    Each batch is encoded once; the embeddings produced while computing the
    loss are reused for the retrieval metric.

    Args:
        model: Model to evaluate; put into eval mode here.
        dataloader: Iterable of batches for the split.
        triplet_loss_fn: Loss callable forwarded to compute_multimodal_triplet_loss.
        device: Device the batches are moved to.
        epoch: Epoch number, used only for the progress-bar label.
        split_name: Label shown in the progress bar.

    Returns:
        Dict with 'loss', 'img2txt_acc', 'txt2img_acc' and 'avg_acc'
        (all 0.0 when the loader yields no batches).
    """
    model.eval()

    total_loss = 0.0
    all_image_embeddings = []
    all_text_embeddings = []
    all_labels = []

    # Holds the embeddings of the most recent forward pass so they can be
    # reused after the loss helper returns.
    latest = {}

    def _encode_and_capture(images, texts):
        latest['embeddings'] = model(images, texts)
        return latest['embeddings']

    with torch.no_grad():
        pbar = tqdm(dataloader, desc=f"Epoch {epoch} [{split_name}]")

        for batch in pbar:
            # Compute loss (the wrapper records the embeddings it produces)
            loss, _ = compute_multimodal_triplet_loss(
                _encode_and_capture, batch, triplet_loss_fn, device
            )
            loss_value = loss.item()
            total_loss += loss_value

            # Collect embeddings for retrieval accuracy
            image_emb, text_emb = latest['embeddings']
            all_image_embeddings.append(image_emb.cpu())
            all_text_embeddings.append(text_emb.cpu())
            all_labels.append(batch['class_id'])

            pbar.set_postfix({'loss': f"{loss_value:.4f}"})

    if not all_labels:
        # Empty split: torch.cat on an empty list would raise
        return {'loss': 0.0, 'img2txt_acc': 0.0, 'txt2img_acc': 0.0, 'avg_acc': 0.0}

    # Concatenate all embeddings
    all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
    all_text_embeddings = torch.cat(all_text_embeddings, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # Compute retrieval accuracy
    img2txt_acc, txt2img_acc = compute_retrieval_accuracy(
        all_image_embeddings, all_text_embeddings, all_labels, k=5
    )

    avg_loss = total_loss / len(dataloader)

    return {
        'loss': avg_loss,
        'img2txt_acc': img2txt_acc,
        'txt2img_acc': txt2img_acc,
        'avg_acc': (img2txt_acc + txt2img_acc) / 2
    }

print("Training and validation functions defined.")

## 8. Create Data Loaders

Create PyTorch DataLoaders for training, validation, and testing.

In [None]:
# Create data loaders
# Settings common to all three splits; pinned memory is only useful with CUDA.
_shared_loader_args = {
    'batch_size': config.BATCH_SIZE,
    'num_workers': config.NUM_WORKERS,
    'pin_memory': torch.cuda.is_available(),
}

# Only the training split is reshuffled every epoch
train_loader = DataLoader(train_dataset, shuffle=True, **_shared_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **_shared_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **_shared_loader_args)

print(f"\nData Loaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")
print(f"  Batch size: {config.BATCH_SIZE}")

## 9. Optimizer and Scheduler

Configure optimizer and learning rate scheduler.

In [None]:
# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY
)

# Learning rate scheduler (reduce on plateau)
# Halves the learning rate after 3 epochs without improvement of the
# monitored value (mode='min'). `verbose` is deliberately not passed: the
# argument is deprecated since PyTorch 2.2 and its default is False on every
# version. The current learning rate can be read from
# optimizer.param_groups[0]['lr'] after each scheduler step.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-7
)

print(f"\nOptimizer: AdamW")
print(f"  Learning Rate: {config.LEARNING_RATE}")
print(f"  Weight Decay: {config.WEIGHT_DECAY}")
print(f"\nScheduler: ReduceLROnPlateau")
print(f"  Factor: 0.5")
print(f"  Patience: 3 epochs")

## 10. Main Training Loop

Execute the full training loop with validation, checkpointing, and metric tracking.

In [None]:
# Per-epoch metric log. The two train-accuracy series are declared here but
# nothing in this loop appends to them.
history = {
    series: []
    for series in (
        'train_loss',
        'val_loss',
        'test_loss',
        'train_img2txt_acc',
        'train_txt2img_acc',
        'val_img2txt_acc',
        'val_txt2img_acc',
        'test_img2txt_acc',
        'test_txt2img_acc',
        'learning_rates',
    )
}

best_val_loss = float('inf')
best_val_acc = 0.0


def _checkpoint_state(epoch, val_metrics, include_config):
    """Build the dict handed to torch.save; 'config' is added only for the best-model files."""
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_metrics['loss'],
        'val_acc': val_metrics['avg_acc']
    }
    if include_config:
        state['config'] = {
            'embedding_dim': config.EMBEDDING_DIM,
            'margin': config.MARGIN,
            'vocab_size': config.VOCAB_SIZE,
            'text_embed_dim': config.TEXT_EMBED_DIM,
            'text_hidden_dim': config.TEXT_HIDDEN_DIM
        }
    return state


print(f"\n{'='*60}")
print(f"Starting Training for {config.NUM_EPOCHS} epochs")
print(f"{'='*60}\n")

for epoch in range(1, config.NUM_EPOCHS + 1):
    print(f"\n--- Epoch {epoch}/{config.NUM_EPOCHS} ---")

    # One optimisation pass, then evaluation on both held-out splits
    train_metrics = train_epoch(model, train_loader, optimizer,
                                triplet_loss_fn, device, epoch)
    val_metrics = validate(model, val_loader, triplet_loss_fn,
                           device, epoch, split_name="Val")
    test_metrics = validate(model, test_loader, triplet_loss_fn,
                            device, epoch, split_name="Test")

    # The scheduler watches the validation loss
    scheduler.step(val_metrics['loss'])
    current_lr = optimizer.param_groups[0]['lr']

    # Record this epoch's numbers
    epoch_record = (
        ('train_loss', train_metrics['loss']),
        ('val_loss', val_metrics['loss']),
        ('test_loss', test_metrics['loss']),
        ('val_img2txt_acc', val_metrics['img2txt_acc']),
        ('val_txt2img_acc', val_metrics['txt2img_acc']),
        ('test_img2txt_acc', test_metrics['img2txt_acc']),
        ('test_txt2img_acc', test_metrics['txt2img_acc']),
        ('learning_rates', current_lr),
    )
    for series, value in epoch_record:
        history[series].append(value)

    # Print epoch summary
    print(f"\nEpoch {epoch} Summary:")
    print(f"  Train Loss: {train_metrics['loss']:.4f}")
    print(f"  Val Loss: {val_metrics['loss']:.4f}")
    print(f"  Test Loss: {test_metrics['loss']:.4f}")
    print(f"  Val Image->Text Acc: {val_metrics['img2txt_acc']:.4f}")
    print(f"  Val Text->Image Acc: {val_metrics['txt2img_acc']:.4f}")
    print(f"  Test Image->Text Acc: {test_metrics['img2txt_acc']:.4f}")
    print(f"  Test Text->Image Acc: {test_metrics['txt2img_acc']:.4f}")
    print(f"  Learning Rate: {current_lr:.2e}")

    # New lowest validation loss -> overwrite best_model.pth
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        torch.save(_checkpoint_state(epoch, val_metrics, include_config=True),
                   os.path.join(config.OUTPUT_DIR, 'best_model.pth'))
        print(f"  ✓ Saved best model (loss: {val_metrics['loss']:.4f})")

    # New highest validation accuracy -> overwrite best_model_acc.pth
    if val_metrics['avg_acc'] > best_val_acc:
        best_val_acc = val_metrics['avg_acc']
        torch.save(_checkpoint_state(epoch, val_metrics, include_config=True),
                   os.path.join(config.OUTPUT_DIR, 'best_model_acc.pth'))
        print(f"  ✓ Saved best accuracy model (acc: {val_metrics['avg_acc']:.4f})")

    # Periodic checkpoint every SAVE_FREQ epochs (no 'config' entry)
    if epoch % config.SAVE_FREQ == 0:
        torch.save(_checkpoint_state(epoch, val_metrics, include_config=False),
                   os.path.join(config.OUTPUT_DIR, f'checkpoint_epoch_{epoch}.pth'))
        print(f"  ✓ Saved checkpoint at epoch {epoch}")

print(f"\n{'='*60}")
print(f"Training Complete!")
print(f"{'='*60}")
print(f"Best Validation Loss: {best_val_loss:.4f}")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")

## 11. Save Training History

Save training history and model metadata for analysis.

In [None]:
def _write_json(path, payload):
    """Serialize ``payload`` to ``path`` as JSON with two-space indentation."""
    with open(path, 'w') as f:
        json.dump(payload, f, indent=2)


# Per-epoch metrics collected by the training loop
history_path = os.path.join(config.OUTPUT_DIR, 'training_history.json')
_write_json(history_path, history)

print(f"\nTraining history saved to: {history_path}")

# Both directions of the vocabulary mapping
vocab_path = os.path.join(config.OUTPUT_DIR, 'vocabulary.json')
_write_json(vocab_path, {
    'word2idx': vocab.word2idx,
    'idx2word': vocab.idx2word
})

print(f"Vocabulary saved to: {vocab_path}")

# Class-name <-> class-id mappings
class_mapping_path = os.path.join(config.OUTPUT_DIR, 'class_mapping.json')
_write_json(class_mapping_path, {
    'class_to_id': class_to_id,
    'class_names': class_names
})

print(f"Class mapping saved to: {class_mapping_path}")

# Hyperparameters used for this run, plus the best validation results
config_path = os.path.join(config.OUTPUT_DIR, 'training_config.json')
config_dict = {
    'embedding_dim': config.EMBEDDING_DIM,
    'image_size': config.IMAGE_SIZE,
    'text_max_length': config.TEXT_MAX_LENGTH,
    'vocab_size': config.VOCAB_SIZE,
    'text_embed_dim': config.TEXT_EMBED_DIM,
    'text_hidden_dim': config.TEXT_HIDDEN_DIM,
    'batch_size': config.BATCH_SIZE,
    'num_epochs': config.NUM_EPOCHS,
    'learning_rate': config.LEARNING_RATE,
    'weight_decay': config.WEIGHT_DECAY,
    'margin': config.MARGIN,
    'mining_strategy': config.MINING_STRATEGY,
    'best_val_loss': best_val_loss,
    'best_val_acc': best_val_acc
}
_write_json(config_path, config_dict)

print(f"Training configuration saved to: {config_path}")

print(f"\n{'='*60}")
print("All training artifacts saved successfully!")
print(f"{'='*60}")

## 12. Visualize Training Progress

Plot training, validation, and test metrics to analyze model performance.

In [None]:
# 2x2 dashboard of the per-epoch curves stored in ``history``:
#   top-left     -> train / validation / test loss
#   top-right    -> image-to-text retrieval accuracy (validation vs. test)
#   bottom-left  -> text-to-image retrieval accuracy (validation vs. test)
#   bottom-right -> mean of the two retrieval accuracies (validation vs. test)
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# One x-axis tick per recorded epoch, starting at 1
epochs = range(1, len(history['train_loss']) + 1)

# Per-epoch mean of the two retrieval directions
val_avg_acc = [
    (i2t + t2i) / 2
    for i2t, t2i in zip(history['val_img2txt_acc'], history['val_txt2img_acc'])
]
test_avg_acc = [
    (i2t + t2i) / 2
    for i2t, t2i in zip(history['test_img2txt_acc'], history['test_txt2img_acc'])
]

# --- Loss panel ---
loss_panel = axes[0, 0]
loss_curves = (
    ('train_loss', 'b-', 'Train Loss'),
    ('val_loss', 'r-', 'Validation Loss'),
    ('test_loss', 'g-', 'Test Loss'),
)
for loss_key, line_fmt, curve_label in loss_curves:
    loss_panel.plot(epochs, history[loss_key], line_fmt, label=curve_label, linewidth=2)
loss_panel.set_xlabel('Epoch', fontsize=12)
loss_panel.set_ylabel('Loss', fontsize=12)
loss_panel.set_title('Training, Validation, and Test Loss', fontsize=14, fontweight='bold')
loss_panel.legend(fontsize=10)
loss_panel.grid(True, alpha=0.3)

# --- Accuracy panels (identical formatting, y-axis fixed to [0, 1]) ---
accuracy_panels = (
    (axes[0, 1], history['val_img2txt_acc'], history['test_img2txt_acc'],
     'Val Image->Text', 'Test Image->Text',
     'Image-to-Text Retrieval Accuracy (Top-5)'),
    (axes[1, 0], history['val_txt2img_acc'], history['test_txt2img_acc'],
     'Val Text->Image', 'Test Text->Image',
     'Text-to-Image Retrieval Accuracy (Top-5)'),
    (axes[1, 1], val_avg_acc, test_avg_acc,
     'Validation Avg', 'Test Avg',
     'Average Retrieval Accuracy'),
)
for acc_panel, val_curve, test_curve, val_label, test_label, panel_title in accuracy_panels:
    acc_panel.plot(epochs, val_curve, 'r-', label=val_label, linewidth=2)
    acc_panel.plot(epochs, test_curve, 'g-', label=test_label, linewidth=2)
    acc_panel.set_xlabel('Epoch', fontsize=12)
    acc_panel.set_ylabel('Accuracy', fontsize=12)
    acc_panel.set_title(panel_title, fontsize=14, fontweight='bold')
    acc_panel.legend(fontsize=10)
    acc_panel.grid(True, alpha=0.3)
    acc_panel.set_ylim([0, 1])

# Write the figure to disk before displaying it
metrics_plot_path = os.path.join(config.OUTPUT_DIR, 'training_metrics.png')
plt.tight_layout()
plt.savefig(metrics_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"\nTraining visualization saved to: {metrics_plot_path}")

## 13. Additional Visualizations

Plot the learning rate schedule over the training epochs (log-scaled y-axis).

In [None]:
# Learning rate per epoch on a log-scaled y-axis; reuses ``epochs`` defined
# by the metrics plotting cell.
lr_plot_path = os.path.join(config.OUTPUT_DIR, 'learning_rate_schedule.png')

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.plot(epochs, history['learning_rates'], 'b-', linewidth=2)
ax.set_yscale('log')
ax.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Learning Rate', fontsize=12)
ax.grid(True, alpha=0.3)

# Write the figure to disk before displaying it
plt.tight_layout()
plt.savefig(lr_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Learning rate schedule saved to: {lr_plot_path}")

## 14. Final Summary and Model Information

Display final training statistics and saved model information.

In [None]:
# Final report: dataset sizes, architecture, hyperparameters, best validation
# metrics, last-epoch test metrics, and the locations of the saved artifacts.
rule = "=" * 80

print(f"\n{rule}")
print(" " * 25 + "TRAINING COMPLETE - FINAL SUMMARY")
print(rule)

print("\nDataset Information:")
print(f"  Total samples: {len(full_dataset):,}")
print(f"  Number of classes: {len(class_names)}")
print(f"  Train samples: {len(train_data):,}")
print(f"  Validation samples: {len(val_data):,}")
print(f"  Test samples: {len(test_data):,}")

print("\nModel Architecture:")
print(f"  Embedding dimension: {config.EMBEDDING_DIM}")
print("  Image encoder: ResNet-50 + Projection Head")
print("  Text encoder: Embedding + BiLSTM + Projection Head")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

print("\nTraining Configuration:")
print(f"  Epochs: {config.NUM_EPOCHS}")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  Initial learning rate: {config.LEARNING_RATE}")
print(f"  Final learning rate: {history['learning_rates'][-1]:.2e}")
print(f"  Triplet margin: {config.MARGIN}")
print(f"  Mining strategy: {config.MINING_STRATEGY}")

print("\nBest Performance:")
print(f"  Best validation loss: {best_val_loss:.4f}")
print(f"  Best validation accuracy: {best_val_acc:.4f}")

# Last-epoch test metrics. The average is derived from ``history`` here rather
# than read from a variable created by the plotting cell, so this summary still
# runs when the visualization cells are skipped.
final_test_i2t = history['test_img2txt_acc'][-1]
final_test_t2i = history['test_txt2img_acc'][-1]
final_test_avg = (final_test_i2t + final_test_t2i) / 2

print("\nFinal Test Performance:")
print(f"  Test loss: {history['test_loss'][-1]:.4f}")
print(f"  Test Image->Text accuracy: {final_test_i2t:.4f}")
print(f"  Test Text->Image accuracy: {final_test_t2i:.4f}")
print(f"  Test average accuracy: {final_test_avg:.4f}")

print("\nSaved Files:")
print(f"  Model (best loss): {os.path.join(config.OUTPUT_DIR, 'best_model.pth')}")
print(f"  Model (best accuracy): {os.path.join(config.OUTPUT_DIR, 'best_model_acc.pth')}")
print(f"  Training history: {os.path.join(config.OUTPUT_DIR, 'training_history.json')}")
print(f"  Vocabulary: {os.path.join(config.OUTPUT_DIR, 'vocabulary.json')}")
print(f"  Class mapping: {os.path.join(config.OUTPUT_DIR, 'class_mapping.json')}")
print(f"  Configuration: {os.path.join(config.OUTPUT_DIR, 'training_config.json')}")
print(f"  Visualizations: {os.path.join(config.OUTPUT_DIR, '*.png')}")

print(f"\n{rule}")
print("Next Steps:")
print("  1. Review training visualizations to assess model convergence")
print("  2. Proceed to inference notebook for model evaluation")
print("  3. Build vector database for similarity search")
print("  4. Integrate with recommendation system")
print(f"{rule}\n")