In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Training Multimodal Models: Contrastive Learning and Instruction Tuning

*Part 3 of the Vizuara series on Multimodal Fusion Architectures*
*Estimated time: 50 minutes*

## 1. Why Does This Matter?

In the previous notebooks, we built multimodal architectures. But building the architecture is only half the story. The other half is **training** -- and modern multimodal models use a clever two-stage training recipe:

1. **Stage 1 (Alignment):** Teach the model that an image of a dog and the text "a photo of a dog" should have similar representations. This is done with **contrastive learning** (CLIP-style).
2. **Stage 2 (Instruction Tuning):** Teach the model to follow instructions like "Describe this image in detail" or "What color is the car?" This is done with **next-token prediction**.
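To make Stage 2 concrete, here is a minimal sketch of a next-token prediction loss on an image plus an instruction. It assumes one common design, which this notebook does not build: a frozen vision encoder, a trainable linear projector that turns the image vector into a few "visual tokens", and a small causal transformer standing in for the language model. Every module and size below is a hypothetical toy stand-in, not a real pretrained model. The loss is computed only on answer tokens by setting the other labels to -100, which `F.cross_entropy` ignores by default.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes, for illustration only
VOCAB, D_VIS, D_LM, N_PREFIX = 100, 128, 64, 4

# Stand-in "pretrained" vision encoder, frozen during Stage 2
vision_encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, D_VIS))
vision_encoder.requires_grad_(False)

# Trainable projector: one image vector -> N_PREFIX visual tokens in the LM embedding space
projector = nn.Linear(D_VIS, N_PREFIX * D_LM)

# Tiny stand-in language model: token embedding + causal transformer + LM head
tok_embed = nn.Embedding(VOCAB, D_LM)
layer = nn.TransformerEncoderLayer(d_model=D_LM, nhead=4, batch_first=True)
lm_body = nn.TransformerEncoder(layer, num_layers=2)
lm_head = nn.Linear(D_LM, VOCAB)

def instruction_tuning_loss(images, input_ids, labels):
    """images: (B, 3, 32, 32); input_ids: (B, T) prompt + answer tokens;
    labels: (B, T), equal to input_ids on answer tokens and -100 elsewhere."""
    B, T = input_ids.shape
    with torch.no_grad():
        vis = vision_encoder(images)                        # (B, D_VIS), no gradient
    prefix = projector(vis).view(B, N_PREFIX, D_LM)         # (B, P, D_LM)
    x = torch.cat([prefix, tok_embed(input_ids)], dim=1)    # (B, P + T, D_LM)
    L = x.size(1)
    # Causal mask: position t may only attend to positions <= t
    causal = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)
    logits = lm_head(lm_body(x, mask=causal))               # (B, P + T, VOCAB)
    # Position t predicts token t + 1: the last visual token predicts input_ids[:, 0]
    logits = logits[:, N_PREFIX - 1:-1, :]                  # (B, T, VOCAB)
    return F.cross_entropy(logits.reshape(-1, VOCAB), labels.reshape(-1), ignore_index=-100)

images = torch.randn(2, 3, 32, 32)
input_ids = torch.randint(0, VOCAB, (2, 6))
labels = input_ids.clone()
labels[:, :3] = -100          # first 3 tokens play the role of the instruction: no loss there
loss = instruction_tuning_loss(images, input_ids, labels)
loss.backward()
print(loss.item())
print(all(p.grad is None for p in vision_encoder.parameters()))  # frozen: no gradients
```

Only the projector and the stand-in language model receive gradients here; keeping the vision encoder frozen preserves whatever it learned in Stage 1.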

By the end of this notebook, you will have:
- Implemented contrastive loss from scratch and trained a mini-CLIP model
- Visualized the learned alignment space and used it to classify and retrieve images by embedding similarity
- Understood why freezing pretrained components matters

In [None]:
# Setup
!pip install torch torchvision matplotlib numpy scikit-learn -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Building Intuition

Think of contrastive learning as a **matching game**. You are given a set of images and a set of text descriptions. Your job is to match each image with its correct description.

Before training, the image of a dog might be close to the text "airplane" in the embedding space. After training, it should be close to "a photo of a dog" and far from everything else.

The trick is in the "far from everything else" part. You do not just push matching pairs together -- you simultaneously push non-matching pairs apart. This is what makes the embeddings useful: they create a structured space where similar concepts cluster together.

### Think About This

Why is cosine similarity used instead of Euclidean distance for comparing image and text embeddings? What happens if one encoder produces much larger vectors than the other?

## 3. The Mathematics

### 3.1 Cosine Similarity

$$\text{sim}(a, b) = \frac{a \cdot b}{\|a\| \|b\|}$$

This measures the angle between two vectors, ignoring their magnitudes. A value of 1 means perfectly aligned, 0 means perpendicular, -1 means opposite.

Let us compute: $a = [3, 4]$, $b = [4, 3]$:

$$\text{sim}(a, b) = \frac{3 \times 4 + 4 \times 3}{\sqrt{9+16} \times \sqrt{16+9}} = \frac{24}{5 \times 5} = \frac{24}{25} = 0.96$$

These vectors are very similar: a cosine similarity of 0.96 corresponds to an angle of only about 16°. This is the kind of score we want for a matching image-text pair.
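A quick check of the worked example in PyTorch. `F.cosine_similarity` is the built-in equivalent of the formula; the last lines also scale one vector by 10 to show that cosine similarity ignores magnitude while Euclidean distance (`torch.dist`) does not.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([3.0, 4.0])
b = torch.tensor([4.0, 3.0])

# Formula from above: dot product divided by the product of the norms
manual = (a @ b) / (a.norm() * b.norm())
# Built-in equivalent; dim=0 because a and b are single 1-D vectors
builtin = F.cosine_similarity(a, b, dim=0)
print(f"{manual.item():.2f} {builtin.item():.2f}")  # 0.96 0.96

# Scaling b by 10 leaves the angle, and so the cosine similarity, unchanged...
scaled = F.cosine_similarity(a, 10 * b, dim=0)
print(f"{scaled.item():.2f}")  # 0.96
# ...but changes the Euclidean distance dramatically
print(f"{torch.dist(a, b).item():.2f} {torch.dist(a, 10 * b).item():.2f}")  # 1.41 45.22
```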

### 3.2 Contrastive Loss (InfoNCE)

Given a batch of $N$ image-text pairs, the contrastive loss for image $i$ is:

$$\mathcal{L}_i = -\log \frac{\exp(\text{sim}(v_i, t_i) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(v_i, t_j) / \tau)}$$

The numerator contains the similarity of the **matching pair** $(v_i, t_i)$. The denominator sums over **all pairs**, including non-matching ones. The temperature $\tau$ controls how sharp the distribution is.

Let us compute with $N = 3$, $\tau = 0.1$:

Suppose similarities are: $\text{sim}(v_1, t_1) = 0.9$ (match), $\text{sim}(v_1, t_2) = 0.2$, $\text{sim}(v_1, t_3) = 0.1$:

$$\mathcal{L}_1 = -\log \frac{\exp(0.9 / 0.1)}{\exp(9) + \exp(2) + \exp(1)} = -\log \frac{8103.1}{8103.1 + 7.39 + 2.72} = -\log(0.9988) = 0.0012$$

The loss is very small because the matching pair has much higher similarity. This tells us the model is doing a good job distinguishing the correct match.

Now consider a bad model where $\text{sim}(v_1, t_1) = 0.3$, $\text{sim}(v_1, t_2) = 0.4$, $\text{sim}(v_1, t_3) = 0.3$:

$$\mathcal{L}_1 = -\log \frac{\exp(3)}{\exp(3) + \exp(4) + \exp(3)} = -\log \frac{20.1}{20.1 + 54.6 + 20.1} = -\log(0.212) = 1.55$$

Much higher loss! The model cannot distinguish the matching pair from non-matching ones.
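Both hand calculations can be reproduced with a log-softmax, because $\mathcal{L}_i$ is the negative log of a softmax probability over $\text{sim}(v_i, t_j)/\tau$. The helper name `infonce_row` is made up for this illustration; the last lines check that `F.cross_entropy` with the positive's index as the target gives the same number, which is the form the training code relies on later.

```python
import torch
import torch.nn.functional as F

tau = 0.1
good = torch.tensor([0.9, 0.2, 0.1])  # sim(v1, t1) is the match; t2, t3 are negatives
bad = torch.tensor([0.3, 0.4, 0.3])

def infonce_row(sims, tau, positive=0):
    """Hypothetical helper: L_i = -log softmax(sims / tau)[positive]."""
    return -F.log_softmax(sims / tau, dim=0)[positive]

loss_good = infonce_row(good, tau)
loss_bad = infonce_row(bad, tau)
print(f"{loss_good.item():.4f}")  # 0.0012
print(f"{loss_bad.item():.2f}")   # 1.55

# Same value via cross-entropy: a batch of one row whose correct class is index 0
ce = F.cross_entropy((good / tau).unsqueeze(0), torch.tensor([0]))
print(torch.allclose(ce, loss_good))
```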

### 3.3 Temperature Parameter

$\tau$ controls how "picky" the model is. Small $\tau$ (like 0.07) makes the softmax very sharp -- the model needs high similarity for matching pairs and low for non-matching. Large $\tau$ (like 1.0) makes it more lenient.

With $\tau = 0.07$ and similarities $[0.8, 0.7, 0.1]$:
$$\text{softmax}([0.8/0.07, 0.7/0.07, 0.1/0.07]) = \text{softmax}([11.4, 10.0, 1.43]) \approx [0.81, 0.19, 0.00]$$

With $\tau = 1.0$:
$$\text{softmax}([0.8, 0.7, 0.1]) \approx [0.42, 0.38, 0.21]$$

The small $\tau$ creates much sharper distinctions. This is exactly what we want for good alignment.
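The two softmax results can be checked directly by dividing the similarities by $\tau$ before the softmax:

```python
import torch

sims = torch.tensor([0.8, 0.7, 0.1])
results = {}
for tau in (0.07, 1.0):
    # Dividing by a small tau stretches the gaps between logits before the softmax
    probs = torch.softmax(sims / tau, dim=0)
    results[tau] = [round(p, 2) for p in probs.tolist()]
    print(tau, results[tau])
# 0.07 [0.81, 0.19, 0.0]
# 1.0 [0.42, 0.38, 0.21]
```

The similarities themselves are identical in both runs; only the scale applied before the softmax changes, which is why $\tau$ acts as a sharpness knob.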

## 4. Let's Build It -- Component by Component

### 4.1 Creating an Image-Text Paired Dataset

In [None]:
class CIFAR10PairedDataset(Dataset):
    """
    Pairs CIFAR-10 images with text descriptions of their class.
    This simulates the image-text pairs used in CLIP training.
    """
    def __init__(self, cifar_dataset):
        self.cifar = cifar_dataset
        self.class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                           'dog', 'frog', 'horse', 'ship', 'truck']
        self.descriptions = [
            'a photo of an airplane',
            'a photo of an automobile',
            'a photo of a bird',
            'a photo of a cat',
            'a photo of a deer',
            'a photo of a dog',
            'a photo of a frog',
            'a photo of a horse',
            'a photo of a ship',
            'a photo of a truck',
        ]

    def __len__(self):
        return len(self.cifar)

    def __getitem__(self, idx):
        img, label = self.cifar[idx]
        # Return image, text ID, and label. The class index doubles as the text ID:
        # text ID k stands for self.descriptions[k]
        return img, label, label

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cifar_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
cifar_test = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)

# Use subsets for speed
train_paired = CIFAR10PairedDataset(torch.utils.data.Subset(cifar_train, range(5000)))
test_paired = CIFAR10PairedDataset(torch.utils.data.Subset(cifar_test, range(1000)))

train_loader = DataLoader(train_paired, batch_size=128, shuffle=True)
test_loader = DataLoader(test_paired, batch_size=128, shuffle=False)

print(f"Training pairs: {len(train_paired)}")
print(f"Test pairs: {len(test_paired)}")

### 4.2 Building the Mini-CLIP Model

In [None]:
class MiniCLIP(nn.Module):
    """
    A simplified CLIP model with separate image and text encoders.

    The image encoder is a small CNN.
    The text encoder is an embedding lookup (since we have 10 classes).
    Both produce vectors in the same shared embedding space.
    """
    def __init__(self, embed_dim=128, num_classes=10):
        super().__init__()
        # Image encoder: small CNN
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, embed_dim)
        )

        # Text encoder: learned embeddings for each class description
        self.text_encoder = nn.Embedding(num_classes, embed_dim)

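        # Learnable log inverse temperature, log(1/tau), initialized at tau = 0.07 as in CLIP.
        # exp() of this parameter is the logit scale that multiplies cosine similarities.
        self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))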

    def encode_image(self, images):
        """Encode images to normalized embeddings."""
        features = self.image_encoder(images)
        return F.normalize(features, dim=-1)

    def encode_text(self, text_ids):
        """Encode text IDs to normalized embeddings."""
        features = self.text_encoder(text_ids)
        return F.normalize(features, dim=-1)

    def forward(self, images, text_ids):
        """
        Returns: image_features, text_features, logit scale exp(self.temperature) = 1/tau
        """
        img_features = self.encode_image(images)
        txt_features = self.encode_text(text_ids)
        return img_features, txt_features, self.temperature.exp()

model = MiniCLIP().to(device)
print(f"MiniCLIP parameters: {sum(p.numel() for p in model.parameters()):,}")
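`MiniCLIP.temperature` stores $\log(1/\tau)$, and its `exp()` is the multiplier (CLIP calls it the logit scale) applied to cosine similarities. Because it is learnable, nothing stops it from growing without bound; the CLIP paper reports clipping it so that logits are never scaled by more than 100. The sketch below shows one way to apply such a cap. The function name `clamped_logit_scale` and the standalone parameter are hypothetical, and the training loop in this notebook does not clamp.

```python
import math
import torch
import torch.nn as nn

# Hypothetical stand-in for MiniCLIP.temperature: stores log(1/tau), initialized at tau = 0.07
log_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07)))

def clamped_logit_scale(param, max_scale=100.0):
    """Return exp(param), capped at max_scale (equivalently, tau >= 1/max_scale)."""
    return param.clamp(max=math.log(max_scale)).exp()

print(f"{clamped_logit_scale(log_scale).item():.2f}")  # 14.29, i.e. tau = 0.07

# Simulate the parameter drifting upward during training
with torch.no_grad():
    log_scale.fill_(10.0)  # exp(10) would be about 22026 without the cap
capped = clamped_logit_scale(log_scale)
print(f"{capped.item():.1f}")  # 100.0
```

One design caveat: clamping inside the forward pass gives the parameter zero gradient whenever the cap is active, so it cannot move until something else changes it. An alternative is to clamp the parameter's data in place after each optimizer step.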

### 4.3 Implementing Contrastive Loss from Scratch

In [None]:
def contrastive_loss(img_features, txt_features, temperature):
    """
    Symmetric contrastive loss (InfoNCE).

    img_features: (B, D) -- normalized image embeddings
    txt_features: (B, D) -- normalized text embeddings
    temperature: scalar -- the logit scale 1/tau (e.g. 14.29 for tau = 0.07);
                 it multiplies the similarities, so larger values give a sharper softmax

    Loss = (image-to-text loss + text-to-image loss) / 2
    """
    # Compute the logits matrix: (B, B)
    # logits[i, j] = cosine similarity between image i and text j, times 1/tau
    logits = (img_features @ txt_features.T) * temperature

    # Labels: the diagonal (image i matches text i)
    labels = torch.arange(len(img_features), device=img_features.device)

    # Image-to-text loss: for each image, find its matching text
    loss_i2t = F.cross_entropy(logits, labels)

    # Text-to-image loss: for each text, find its matching image
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2

# Test with dummy data
dummy_img = F.normalize(torch.randn(4, 128), dim=-1)
dummy_txt = F.normalize(torch.randn(4, 128), dim=-1)
loss = contrastive_loss(dummy_img, dummy_txt, temperature=torch.tensor(14.29))  # 14.29 = 1/0.07
print(f"Dummy contrastive loss: {loss.item():.4f}")
print(f"Chance-level reference: log(4) = {np.log(4):.4f} (random embeddings land near this, not exactly on it)")
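One subtlety of this setup: with only 10 captions, a batch of 128 contains many copies of each text ID, yet `contrastive_loss` labels only the diagonal as correct, so an image is also pushed away from duplicate copies of its own caption. The sketch below measures how often that happens and shows one possible fix, assuming we treat every column with the same text ID as a positive and spread the target probability evenly across them (a multi-positive target in the spirit of supervised contrastive learning; CLIP itself does not do this, since its web captions are almost always distinct). `multi_positive_contrastive_loss` is a hypothetical name, and `logit_scale` plays the role of this notebook's `temperature` argument (1/τ).

```python
import torch
import torch.nn.functional as F

def multi_positive_contrastive_loss(img_features, txt_features, text_ids, logit_scale):
    """Hypothetical variant of contrastive_loss: all columns sharing a text ID are positives."""
    logits = (img_features @ txt_features.T) * logit_scale   # (B, B)
    pos = (text_ids[:, None] == text_ids[None, :]).float()   # (B, B), symmetric positive mask
    targets = pos / pos.sum(dim=1, keepdim=True)             # each row sums to 1
    loss_i2t = -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
    loss_t2i = -(targets * F.log_softmax(logits.T, dim=1)).sum(dim=1).mean()
    return (loss_i2t + loss_t2i) / 2

torch.manual_seed(0)
B = 128
text_ids = torch.randint(0, 10, (B,))  # stand-in for one batch of CIFAR-10 labels
same = text_ids[:, None] == text_ids[None, :]
dup_frac = (same.sum().item() - B) / (B * (B - 1))
print(f"Off-diagonal 'negatives' that are really the same caption: {dup_frac:.1%}")  # roughly 10%

img = F.normalize(torch.randn(B, 128), dim=-1)
txt = F.normalize(torch.randn(10, 128), dim=-1)[text_ids]  # duplicate captions share one embedding
print(multi_positive_contrastive_loss(img, txt, text_ids, logit_scale=14.29).item())
```

When every text ID in the batch is unique, the positive mask is the identity and this reduces to the standard symmetric loss above.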

In [None]:
# Visualization checkpoint: similarity matrix before training
model.eval()
with torch.no_grad():
    # Take the first training image of each class, in class order, so row i shows class i
    first_idx = [cifar_train.targets.index(c) for c in range(10)]
    sample_imgs = torch.stack([cifar_train[i][0] for i in first_idx]).to(device)

    img_feat, txt_feat, temp = model(sample_imgs, torch.arange(10).to(device))
    sim_matrix = (img_feat @ txt_feat.T).cpu().numpy()

class_names = ['airplane', 'auto', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(sim_matrix, cmap='coolwarm', vmin=-1, vmax=1)
ax.set_xticks(range(10))
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.set_yticks(range(10))
ax.set_yticklabels([f'{name} image' for name in class_names])
ax.set_title('Similarity Matrix BEFORE Training\n(should be random)', fontsize=14)
plt.colorbar(im)
plt.tight_layout()
plt.show()
print("Before training, similarities are random -- no alignment yet!")

## 5. Your Turn

### TODO: Implement Hard Negative Mining

In the basic contrastive loss, all non-matching pairs are treated equally. But some non-matching pairs are harder than others -- e.g., a cat image vs "a photo of a dog" is harder to distinguish than a cat image vs "a photo of an airplane".

In [None]:
def contrastive_loss_hard_negatives(img_features, txt_features, temperature, k=3):
    """
    Contrastive loss with hard negative mining.

    Instead of using ALL non-matching pairs in the denominator,
    only use the top-k hardest negatives (highest similarity non-matches).

    Args:
        img_features: (B, D)
        txt_features: (B, D)
        temperature: scalar
        k: number of hard negatives to keep
    """
    # ============ TODO ============
    # Step 1: Compute similarity matrix (B, B)
    # Step 2: Create a mask for the diagonal (matching pairs)
    # Step 3: For each row, find the top-k non-diagonal similarities
    # Step 4: Compute loss using only matching pair + top-k hard negatives
    # Hint: Use torch.topk on the non-matching similarities
    # ==============================

    loss = None  # YOUR CODE HERE
    return loss

In [None]:
# Verification
# If implemented, compare losses:
if contrastive_loss_hard_negatives(dummy_img, dummy_txt, torch.tensor(14.29)) is not None:
    loss_all = contrastive_loss(dummy_img, dummy_txt, torch.tensor(14.29))
    loss_hard = contrastive_loss_hard_negatives(dummy_img, dummy_txt, torch.tensor(14.29), k=2)
    print(f"Loss (all negatives): {loss_all.item():.4f}")
    print(f"Loss (hard negatives only): {loss_hard.item():.4f}")
    print("Hard negative loss should be <= all-negatives loss (dropping easy negatives shrinks the denominator)")
else:
    print("TODO: Implement contrastive_loss_hard_negatives above")

### TODO: Implement Temperature Scheduling

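Instead of a learnable temperature, implement a schedule that starts high (lenient) and decreases during training (stricter). The scheduler returns $\tau$ itself, while `contrastive_loss` multiplies the similarities by its `temperature` argument, so pass `1 / tau` when plugging the schedule into training.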

In [None]:
class TemperatureScheduler:
    """
    Linear temperature schedule: starts at tau_start, ends at tau_end.
    """
    def __init__(self, tau_start=1.0, tau_end=0.07, total_steps=1000):
        # ============ TODO ============
        # Store the start, end, and total steps
        # ==============================
        pass

    def get_temperature(self, step):
        """
        Returns the temperature at a given training step.
        Linear interpolation between tau_start and tau_end.
        """
        # ============ TODO ============
        # Compute: tau = tau_start + (tau_end - tau_start) * (step / total_steps)
        # Clamp between tau_end and tau_start
        # ==============================
        return None  # YOUR CODE HERE

In [None]:
# Verification
scheduler = TemperatureScheduler(1.0, 0.07, 100)
if scheduler.get_temperature(0) is not None:
    temps = [scheduler.get_temperature(s) for s in range(101)]
    plt.plot(temps)
    plt.xlabel('Step')
    plt.ylabel('Temperature')
    plt.title('Temperature Schedule')
    plt.show()
    print(f"Start: {temps[0]:.3f}, End: {temps[-1]:.3f}")
else:
    print("TODO: Implement TemperatureScheduler above")

## 6. Putting It All Together

In [None]:
# Full training pipeline
model = MiniCLIP().to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

train_losses = []
retrieval_accs = []

## 7. Training and Results

In [None]:
# Train the contrastive model
epochs = 20

for epoch in range(epochs):
    model.train()
    epoch_loss = 0

    for imgs, text_ids, labels in train_loader:
        imgs, text_ids = imgs.to(device), text_ids.to(device)

        optimizer.zero_grad()
        img_feat, txt_feat, temp = model(imgs, text_ids)
        loss = contrastive_loss(img_feat, txt_feat, temp)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Evaluate: zero-shot classification accuracy
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        # Encode all 10 class descriptions
        all_text_feat = model.encode_text(torch.arange(10).to(device))  # (10, D)

        for imgs, text_ids, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            img_feat = model.encode_image(imgs)  # (B, D)

            # Find nearest text for each image
            sims = img_feat @ all_text_feat.T  # (B, 10)
            preds = sims.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    acc = correct / total
    retrieval_accs.append(acc)

    if (epoch + 1) % 5 == 0:
        tau_val = 1 / model.temperature.exp().item()  # the parameter stores log(1/tau)
        print(f"Epoch {epoch+1:3d}: Loss={avg_loss:.4f}, Acc={acc:.4f}, Tau={tau_val:.3f}")

In [None]:
# Visualization checkpoint: training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(train_losses, color='#2196F3', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Contrastive Loss')
axes[0].set_title('Training Loss')
axes[0].grid(True, alpha=0.3)

axes[1].plot(retrieval_accs, color='#4CAF50', linewidth=2)
axes[1].axhline(y=0.1, color='gray', linestyle='--', alpha=0.5, label='Random (10%)')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Zero-Shot Classification Accuracy')
axes[1].set_title('Zero-Shot-Style Accuracy (nearest caption, no classifier head)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
print(f"\nFinal zero-shot accuracy: {retrieval_accs[-1]:.4f}")
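print("No classification head was trained: each image is assigned its nearest caption embedding.")
print("Caveat: these 10 captions were also the training texts, so this is not open-vocabulary zero-shot.")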

In [None]:
# Visualization checkpoint: similarity matrix AFTER training
model.eval()
with torch.no_grad():
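    # Take the first test image of each class, in class order, so the diagonal holds the matching pairs
    first_idx = [cifar_test.targets.index(c) for c in range(10)]
    sample_imgs = torch.stack([cifar_test[i][0] for i in first_idx]).to(device)
    sample_labels = torch.tensor([cifar_test[i][1] for i in first_idx])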

    img_feat = model.encode_image(sample_imgs)
    txt_feat = model.encode_text(torch.arange(10).to(device))
    sim_matrix = (img_feat @ txt_feat.T).cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(sim_matrix, cmap='coolwarm', vmin=-1, vmax=1)
ax.set_xticks(range(10))
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.set_yticks(range(10))
actual_labels = [class_names[sample_labels[i].item()] for i in range(10)]
ax.set_yticklabels([f'{actual_labels[i]}' for i in range(10)])
ax.set_title('Similarity Matrix AFTER Training\n(diagonal should be bright)', fontsize=14)
plt.colorbar(im)
plt.tight_layout()
plt.show()
print("After training, matching image-text pairs have high similarity (bright diagonal)!")

In [None]:
# Visualize the learned embedding space with t-SNE
from sklearn.manifold import TSNE

model.eval()
all_img_features = []
all_labels = []

with torch.no_grad():
    for imgs, text_ids, labels in test_loader:
        imgs = imgs.to(device)
        feat = model.encode_image(imgs)
        all_img_features.append(feat.cpu())
        all_labels.append(labels)

all_img_features = torch.cat(all_img_features).numpy()
all_labels = torch.cat(all_labels).numpy()

# t-SNE reduction
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
features_2d = tsne.fit_transform(all_img_features)

fig, ax = plt.subplots(figsize=(10, 8))
colors = plt.cm.tab10(np.linspace(0, 1, 10))
for c in range(10):
    mask = all_labels == c
    ax.scatter(features_2d[mask, 0], features_2d[mask, 1],
              c=[colors[c]], label=class_names[c], alpha=0.5, s=10)
ax.legend(fontsize=9, ncol=2)
ax.set_title('Learned Image Embedding Space (t-SNE)\nImages of the same class should cluster together', fontsize=14)
ax.set_xlabel('t-SNE dim 1')
ax.set_ylabel('t-SNE dim 2')
plt.tight_layout()
plt.show()

## 8. Final Output

In [None]:
# Zero-shot image retrieval demonstration
model.eval()

print("=" * 60)
print("ZERO-SHOT IMAGE RETRIEVAL")
print("=" * 60)

fig, axes = plt.subplots(2, 5, figsize=(20, 8))

for query_idx in range(10):
    query_text = test_paired.descriptions[query_idx]  # e.g. "a photo of an airplane"

    with torch.no_grad():
        txt_feat = model.encode_text(torch.tensor([query_idx]).to(device))

        # Search through test set for best match
        best_sim = -1
        best_img = None
        best_label = None

        for imgs, text_ids, labels in test_loader:
            imgs = imgs.to(device)
            img_feat = model.encode_image(imgs)
            sims = (img_feat @ txt_feat.T).squeeze(1)  # (B,)
            max_sim, max_idx = sims.max(0)
            max_idx = max_idx.item()  # plain int, safe for indexing both CPU and GPU tensors

            if max_sim.item() > best_sim:
                best_sim = max_sim.item()
                best_img = imgs[max_idx].cpu()
                best_label = labels[max_idx].item()

    row, col = query_idx // 5, query_idx % 5
    axes[row, col].imshow(best_img.permute(1, 2, 0) * 0.5 + 0.5)
    match_str = "MATCH" if best_label == query_idx else "MISS"
    color = 'green' if best_label == query_idx else 'red'
    axes[row, col].set_title(
        f'Query: "{query_text}"\n'
        f'Retrieved: {class_names[best_label]} ({match_str})\n'
        f'Similarity: {best_sim:.3f}',
        fontsize=9, color=color
    )
    axes[row, col].axis('off')

plt.suptitle('Zero-Shot Text-to-Image Retrieval with Mini-CLIP', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nCongratulations! You have trained a contrastive model from scratch!")
print("It retrieves images for a text query purely by embedding similarity, with no separate retrieval model.")
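The retrieval loop above re-encodes the whole test set once per query (10 passes). Since the image embeddings do not depend on the query, a common alternative is to encode the images once and rank them for all queries with a single matrix product plus `torch.topk`. A minimal sketch, with random normalized vectors standing in for `model.encode_image` / `model.encode_text` outputs; `retrieve_top_k` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def retrieve_top_k(text_feats, image_feats, k=5):
    """text_feats: (Q, D), image_feats: (N, D), both L2-normalized.
    Returns (Q, k) similarity scores and image indices, best match first."""
    sims = text_feats @ image_feats.T   # (Q, N) cosine similarities
    return sims.topk(k, dim=1)          # values and indices, sorted in descending order

torch.manual_seed(0)
image_feats = F.normalize(torch.randn(1000, 128), dim=-1)  # stand-in for encoded test images
text_feats = F.normalize(torch.randn(10, 128), dim=-1)     # stand-in for the 10 encoded captions
scores, idx = retrieve_top_k(text_feats, image_feats, k=5)
print(scores.shape, idx.shape)  # torch.Size([10, 5]) torch.Size([10, 5])
```

To use it with the trained model, you would concatenate `model.encode_image(imgs)` over `test_loader` into one `(1000, D)` tensor under `torch.no_grad()`, encode `torch.arange(10)` with `model.encode_text`, and call the helper once.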

## 9. Reflection and Next Steps

### Reflection Questions
1. The contrastive loss treats all non-matching pairs equally. But "cat" vs "dog" is harder to distinguish than "cat" vs "airplane." How would you modify the loss to account for semantic similarity between classes?
2. We used a batch size of 128. CLIP uses batch sizes of 32,768. Why does larger batch size help contrastive learning? (Hint: think about the number of negative examples.)
3. After contrastive pretraining, we get zero-shot classification "for free." What are the limitations of this zero-shot approach compared to supervised fine-tuning?

### Optional Challenges
1. Implement the hard negative mining TODO above and compare retrieval accuracy with and without hard negatives.
2. Add data augmentation (random crops, color jitter) to the training images and measure the effect on alignment quality.
3. Replace the text embedding lookup with a character-level model that can handle novel text queries (not just the 10 class descriptions).