In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Zero-Shot Transfer and Evaluation: The Magic of CLIP

*Part 3 of the Vizuara series on Contrastive Pretraining (CLIP-style)*
*Estimated time: 45 minutes*

## 1. Why Does This Matter?

The most remarkable property of CLIP is **zero-shot classification** -- classifying images into categories the model has never been explicitly trained on. No fine-tuning, no labeled data, just natural language prompts.

This is a paradigm shift: instead of training a separate classifier for every task, you describe the task in words and CLIP figures out the rest.

By the end of this notebook, you will:
- Implement zero-shot classification using CLIP embeddings
- Evaluate performance on unseen test data
- Engineer prompts to improve accuracy
- Build an image retrieval system
- Understand CLIP's limitations through targeted experiments

In [None]:
# Setup and GPU check
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
torch.manual_seed(42)
np.random.seed(42)

CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck']

## 2. Building Intuition

Let us think about how zero-shot classification works with a simple analogy.

Imagine you walk into a room full of paintings. You have never seen these particular paintings before, but someone hands you a note that says "Find the painting of a dog." You look around, compare each painting against your mental image of "dog," and point to the one that matches best.

CLIP does exactly this, but mathematically:
1. Embed the image into a vector
2. Embed "a photo of a dog" into a vector
3. Compare them using cosine similarity
4. Pick the label with the highest similarity

The key insight is that the text encoder provides a **flexible classifier interface**. You can define any set of classes just by writing text prompts.

### Think About This
- What if we used "a bright photograph of a cute dog" instead of "a photo of a dog"? Would it matter?
- Why might CLIP struggle with the difference between "two cats" and "three cats"?
- Could you use CLIP for sentiment classification? How would you phrase the prompts?

## 3. The Mathematics

### Zero-Shot Prediction

Given an image $x$ and a set of $K$ class labels $\{c_1, c_2, \ldots, c_K\}$, zero-shot prediction works as follows:

1. Create text prompts: $t_k = \text{"a photo of a } c_k\text{"}$
2. Encode: $z_I = f(x)$, $z_{T_k} = g(t_k)$ for each $k$
3. Normalize: $\hat{z}_I = z_I / \|z_I\|$, $\hat{z}_{T_k} = z_{T_k} / \|z_{T_k}\|$
4. Predict: $\hat{y} = \arg\max_k \frac{\exp(\hat{z}_I \cdot \hat{z}_{T_k} / \tau)}{\sum_{j=1}^{K} \exp(\hat{z}_I \cdot \hat{z}_{T_j} / \tau)}$

Computationally, this is a softmax over the similarities between the image embedding and each class text embedding, with temperature scaling.

Let us work through a simple example with 3 classes. Suppose the similarities are:
- $\text{sim}(\text{image}, \text{"dog"}) = 0.85$
- $\text{sim}(\text{image}, \text{"cat"}) = 0.60$
- $\text{sim}(\text{image}, \text{"car"}) = 0.10$

With $\tau = 0.07$:
$$p(\text{dog}) = \frac{e^{0.85/0.07}}{e^{0.85/0.07} + e^{0.60/0.07} + e^{0.10/0.07}} = \frac{e^{12.14}}{e^{12.14} + e^{8.57} + e^{1.43}}$$

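Since $e^{12.14} \gg e^{8.57} \gg e^{1.43}$, almost all probability mass goes to "dog": $p(\text{dog}) \approx 0.97$, $p(\text{cat}) \approx 0.03$, and $p(\text{car}) < 10^{-4}$. If "dog" is the correct label, this is exactly what we want -- the model is very confident about the correct class. Note that a small temperature $\tau$ only sharpens the distribution; it never changes which class has the highest similarity, so the argmax prediction is the same for any $\tau > 0$.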

### Prompt Ensembling

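A simple trick to improve accuracy: use multiple prompt templates per class, average their L2-normalized embeddings, and re-normalize the result:

$$\bar{z}_{T_k} = \frac{1}{M}\sum_{m=1}^{M} \frac{g(\text{template}_m(c_k))}{\|g(\text{template}_m(c_k))\|}, \qquad \hat{z}_{T_k} = \frac{\bar{z}_{T_k}}{\|\bar{z}_{T_k}\|}$$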

This reduces sensitivity to any single prompt wording.

## 4. Let's Build It -- Component by Component

### 4.1 Rebuild the Model Architecture

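We need the same architecture from Notebook 2, so we rebuild it here and then briefly retrain it on CIFAR-10 images paired with templated captions such as "a photo of a dog". Because those captions name the same ten classes we evaluate on later, the "zero-shot" accuracy in this notebook measures how well the model matches images to text prompts. It does not measure generalization to categories the model has never seen.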

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and embed each one.

    A convolution whose kernel and stride both equal ``patch_size`` projects every
    patch to ``embed_dim`` features. A learnable [CLS] token is prepended, and
    learnable positional embeddings (one per patch plus one for [CLS]) are added.

    Expects images of shape (B, in_channels, img_size, img_size); returns a
    (B, num_patches + 1, embed_dim) token sequence.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        # kernel == stride, so each output location covers exactly one patch
        self.projection = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(0.02 * torch.randn(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        batch_size = x.size(0)
        # (B, C, H, W) -> (B, D, H/p, W/p) -> (B, D, N) -> (B, N, D)
        patch_tokens = self.projection(x).flatten(start_dim=2).transpose(1, 2)
        cls = self.cls_token.expand(batch_size, -1, -1)
        sequence = torch.cat((cls, patch_tokens), dim=1)
        return sequence + self.pos_embed

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block (batch-first).

    Computes ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``. Dropout is applied
    inside the attention and after each MLP linear layer. No attention mask is
    used, so every position attends to every other position.
    """

    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention sub-layer with residual connection.
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)[0]
        # Position-wise feed-forward sub-layer with residual connection.
        return x + self.mlp(self.norm2(x))

class ImageEncoder(nn.Module):
    """Vision-Transformer image tower.

    Patch-embeds the image, runs ``depth`` Transformer blocks, layer-normalizes
    the sequence and linearly projects the [CLS] token (position 0) to
    ``output_dim``. The output is NOT L2-normalized; MiniCLIP does that.
    """

    def __init__(self, img_size=32, patch_size=4, embed_dim=128, depth=4, num_heads=4, output_dim=128):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads) for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.projection = nn.Linear(embed_dim, output_dim)

    def forward(self, x):
        tokens = self.patch_embed(x)
        for block in self.blocks:
            tokens = block(tokens)
        cls_feature = self.norm(tokens)[:, 0]
        return self.projection(cls_feature)

class TextEncoder(nn.Module):
    """Transformer text tower.

    Embeds (B, seq_len) token ids, adds learnable positional embeddings, runs
    ``depth`` Transformer blocks, layer-normalizes and mean-pools over all
    positions, then projects to ``output_dim``. Padding positions are included
    in the mean because no mask is applied. ``seq_len`` must not exceed
    ``max_len``, since only ``max_len`` positional embeddings exist.
    """

    def __init__(self, vocab_size=200, max_len=12, embed_dim=128, depth=2, num_heads=4, output_dim=128):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(0.02 * torch.randn(1, max_len, embed_dim))
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads) for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.projection = nn.Linear(embed_dim, output_dim)

    def forward(self, tokens):
        seq_len = tokens.size(1)
        hidden = self.token_embed(tokens)
        hidden = hidden + self.pos_embed[:, :seq_len]
        for block in self.blocks:
            hidden = block(hidden)
        pooled = self.norm(hidden).mean(dim=1)
        return self.projection(pooled)

class MiniCLIP(nn.Module):
    """Dual-encoder CLIP model with a learnable temperature.

    ``logit_scale`` stores log(1 / temperature). At forward time it is
    exponentiated and capped at 100 before scaling the cosine similarities.
    """

    def __init__(self, image_encoder, text_encoder, temperature=0.07):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        initial_scale = torch.tensor(1.0 / temperature)
        self.logit_scale = nn.Parameter(initial_scale.log())

    def forward(self, images, tokens):
        """Return ``(logits, image_embeddings, text_embeddings)``.

        Both embedding sets are L2-normalized; ``logits[i, j]`` is the scaled
        cosine similarity between image ``i`` and text ``j``.
        """
        image_features = F.normalize(self.image_encoder(images), dim=-1)
        text_features = F.normalize(self.text_encoder(tokens), dim=-1)
        scale = torch.clamp(self.logit_scale.exp(), max=100)
        logits = scale * (image_features @ text_features.t())
        return logits, image_features, text_features

    def clip_loss(self, logits):
        """Symmetric cross-entropy: the i-th image and i-th text form the positive pair."""
        targets = torch.arange(logits.shape[0], device=logits.device)
        image_to_text = F.cross_entropy(logits, targets)
        text_to_image = F.cross_entropy(logits.t(), targets)
        return 0.5 * (image_to_text + text_to_image)

print("Architecture rebuilt.")

### 4.2 Simple Tokenizer

In [None]:
class SimpleTokenizer:
    """Whitespace, word-level tokenizer with a template-derived vocabulary.

    The vocabulary is every lower-cased word produced by filling each template
    with each class name. Id 0 is ``<PAD>``; the remaining words get ids
    1..V-1 in sorted order. Words outside the vocabulary are ALSO mapped to 0,
    so they are indistinguishable from padding. To use new words (e.g. new
    class names or prompt wording), build the tokenizer with them and retrain.

    Args:
        max_len: every encoded sequence is truncated/padded to this length.
        class_names: class names used to build the vocabulary. Defaults to the
            notebook's ``CIFAR10_CLASSES`` (read when the tokenizer is created).
        templates: prompt templates with a single ``{}`` slot. Defaults to
            ``SimpleTokenizer.DEFAULT_TEMPLATES``.
    """

    DEFAULT_TEMPLATES = (
        "a photo of a {}", "an image of a {}", "a picture of a {}",
        "a blurry photo of a {}", "a close-up photo of a {}",
    )

    def __init__(self, max_len=12, class_names=None, templates=None):
        if class_names is None:
            class_names = CIFAR10_CLASSES
        if templates is None:
            templates = self.DEFAULT_TEMPLATES
        self.max_len = max_len
        # Build vocab from all possible caption words
        words = {
            word
            for cls in class_names
            for tmpl in templates
            for word in tmpl.format(cls).lower().split()
        }
        self.word_to_idx = {"<PAD>": 0}
        self.word_to_idx.update({word: idx for idx, word in enumerate(sorted(words), start=1)})
        self.vocab_size = len(self.word_to_idx)

    def encode(self, text):
        """Encode ``text`` as a (max_len,) int64 tensor of ids (0 = padding / unknown word)."""
        ids = [self.word_to_idx.get(word, 0) for word in text.lower().split()]
        ids = ids[:self.max_len]
        ids += [0] * (self.max_len - len(ids))
        # explicit dtype: torch.tensor([]) would otherwise be float32, which nn.Embedding rejects
        return torch.tensor(ids, dtype=torch.long)

    def batch_encode(self, texts):
        """Encode a sequence of strings as a (len(texts), max_len) int64 tensor."""
        texts = list(texts)
        if not texts:
            # torch.stack([]) raises; return an empty batch with the right shape instead
            return torch.zeros((0, self.max_len), dtype=torch.long)
        return torch.stack([self.encode(t) for t in texts])

# Shared tokenizer: used for the training captions and the zero-shot prompts.
tokenizer = SimpleTokenizer(max_len=12)
print("Vocabulary size:", tokenizer.vocab_size)

### 4.3 Train the Model (Quick Training)

In [None]:
# Load data: CIFAR-10 with pixels rescaled from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=True, num_workers=2,
)

# Initialize model: both towers project to a shared 128-d space.
# The text tower's max_len (12) matches SimpleTokenizer's default max_len.
clip_model = MiniCLIP(
    image_encoder=ImageEncoder(output_dim=128),
    text_encoder=TextEncoder(vocab_size=tokenizer.vocab_size, max_len=12, output_dim=128),
).to(device)

optimizer = torch.optim.AdamW(clip_model.parameters(), lr=3e-4, weight_decay=0.01)

# Templates that turn class labels into training captions (one is sampled per image).
CAPTION_TEMPLATES = [
    "a photo of a {}",
    "an image of a {}",
    "a picture of a {}",
]

def get_caption(label_idx):
    """Build a training caption for class ``label_idx``.

    The template is chosen with NumPy's global RNG, so captions are
    reproducible once ``np.random.seed`` has been called.
    """
    class_name = CIFAR10_CLASSES[label_idx]
    choice = np.random.randint(len(CAPTION_TEMPLATES))
    return CAPTION_TEMPLATES[choice].format(class_name)

In [None]:
# Train for 10 epochs
num_epochs = 10
for epoch in range(num_epochs):
    clip_model.train()
    running_loss, num_batches = 0.0, 0
    for images, labels in train_loader:
        # Turn each integer label into a randomly templated caption, then tokenize.
        batch_captions = [get_caption(int(label)) for label in labels]
        tokens = tokenizer.batch_encode(batch_captions).to(device)
        images = images.to(device)

        # Symmetric contrastive loss over the batch's image/caption pairs.
        logits = clip_model(images, tokens)[0]
        loss = clip_model.clip_loss(logits)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {running_loss/num_batches:.4f}")
print("Training complete!")

## 5. Your Turn

### TODO: Implement Zero-Shot Classification

In [None]:
def zero_shot_classify(clip_model, images, class_names, tokenizer, templates=None):
    """
    Classify images using zero-shot text prompts.

    Each class is represented by the mean of its L2-normalized prompt
    embeddings (one per template), re-normalized afterwards (prompt
    ensembling). Images are scored against every class by cosine similarity
    times the model's learned ``logit_scale`` (exponentiated and capped at 100,
    as in ``MiniCLIP.forward``). A softmax over classes gives the probabilities.

    Args:
        clip_model: trained MiniCLIP model (uses ``image_encoder``,
            ``text_encoder`` and ``logit_scale``); it is switched to eval mode
        images: (B, C, H, W) image tensor; prompt tokens are moved to its device
        class_names: list of class name strings
        tokenizer: text tokenizer with ``batch_encode(list[str]) -> (N, L)`` tensor
        templates: list of prompt templates with one ``{}`` slot
            (default: ["a photo of a {}"])

    Returns:
        predictions: (B,) tensor of predicted class indices
        probabilities: (B, K) tensor of class probabilities

    Raises:
        ValueError: if ``class_names`` or ``templates`` is empty.
    """
    if templates is None:
        templates = ["a photo of a {}"]
    num_classes, num_templates = len(class_names), len(templates)
    if num_classes == 0 or num_templates == 0:
        raise ValueError("class_names and templates must both be non-empty")

    clip_model.eval()
    with torch.no_grad():
        # Steps 1-2: image embeddings, L2-normalized -> (B, D)
        img_emb = F.normalize(clip_model.image_encoder(images), dim=-1)

        # Step 3: encode all K*M prompts in a single batch (class-major order),
        # normalize each one, then average over the M templates of each class.
        prompts = [tmpl.format(name) for name in class_names for tmpl in templates]
        tokens = tokenizer.batch_encode(prompts).to(images.device)
        prompt_emb = F.normalize(clip_model.text_encoder(tokens), dim=-1)
        class_emb = prompt_emb.reshape(num_classes, num_templates, -1).mean(dim=1)

        # Step 4: re-normalize the averaged class embeddings -> (K, D)
        class_emb = F.normalize(class_emb, dim=-1)

        # Steps 5-6: temperature-scaled cosine similarities -> softmax over classes
        logit_scale = clip_model.logit_scale.exp().clamp(max=100)
        logits = logit_scale * (img_emb @ class_emb.T)
        probabilities = logits.softmax(dim=-1)

        # Step 7: most probable class per image
        predictions = probabilities.argmax(dim=-1)

    return predictions, probabilities

In [None]:
# Verification
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False)
test_images, test_labels = next(iter(test_loader))

preds, probs = zero_shot_classify(clip_model, test_images.to(device),
                                   CIFAR10_CLASSES, tokenizer)

# Check shapes
assert preds.shape == (100,), f"Expected shape (100,), got {preds.shape}"
assert probs.shape == (100, 10), f"Expected shape (100, 10), got {probs.shape}"

# Check that each row is a probability distribution over the 10 classes ...
row_sums = probs.sum(dim=1)
assert (probs >= 0).all(), "Probabilities must be non-negative"
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4), "Each row of probabilities must sum to 1"
# ... and that every prediction is a class with maximal probability.
pred_probs = probs.gather(1, preds.view(-1, 1).to(probs.device)).squeeze(1)
assert torch.allclose(pred_probs, probs.max(dim=1).values), "Predictions must be the argmax of the probabilities"

acc = (preds.cpu() == test_labels).float().mean()
print(f"Zero-shot accuracy on 100 test images: {acc:.1%}")
print("(Even modest accuracy shows the model learned meaningful representations!)")
print("Correct implementation!")

### TODO: Implement Prompt Ensembling

In [None]:
def evaluate_with_prompt_ensemble(clip_model, test_loader, class_names,
                                    tokenizer, template_sets):
    """
    Evaluate zero-shot accuracy using different prompt template sets.

    Every batch of ``test_loader`` is classified with ``zero_shot_classify``
    using each template set in turn, and the accuracy is printed and recorded.
    Class text embeddings are recomputed per batch. That is acceptable for the
    small MiniCLIP text tower but wasteful for large models.

    Args:
        clip_model: trained MiniCLIP model; inputs are moved to the device of
            its parameters
        test_loader: iterable of ``(images, labels)`` batches
        class_names: list of class name strings (index = label id)
        tokenizer: tokenizer passed through to ``zero_shot_classify``
        template_sets: dict of {name: [template_list]}

    Returns:
        dict of {name: accuracy}

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model_device = next(clip_model.parameters()).device
    results = {}
    for name, templates in template_sets.items():
        correct, total = 0, 0
        for images, labels in test_loader:
            preds, _ = zero_shot_classify(clip_model, images.to(model_device),
                                          class_names, tokenizer, templates=templates)
            correct += (preds.cpu() == labels).sum().item()
            total += labels.numel()
        if total == 0:
            raise ValueError("test_loader yielded no samples; cannot compute accuracy")

        accuracy = correct / total
        results[name] = accuracy
        print(f"  {name}: {accuracy:.1%}")

    return results

In [None]:
# Verification
template_sets = {
    "single": ["a photo of a {}"],
    "three": ["a photo of a {}", "an image of a {}", "a picture of a {}"],
    "descriptive": ["a photo of a {}", "a clear photo of a {}", "a close-up of a {}"],
}

# The tokenizer maps unknown words to the padding id 0, so template words outside
# its vocabulary are silently dropped from the prompt. Flag them before comparing sets.
known_words = set(tokenizer.word_to_idx)
for set_name, set_templates in template_sets.items():
    template_words = {w for t in set_templates for w in t.format("").lower().split()}
    oov_words = sorted(template_words - known_words)
    if oov_words:
        print(f"⚠️ Template set '{set_name}' uses out-of-vocabulary words {oov_words}; "
              f"they are encoded as <PAD> and contribute nothing.")

test_loader_eval = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

print("Zero-shot accuracy with different prompt templates:")
results = evaluate_with_prompt_ensemble(clip_model, test_loader_eval,
                                         CIFAR10_CLASSES, tokenizer, template_sets)
print("\nPrompt ensembling typically improves accuracy!")

## 6. Putting It All Together

In [None]:
# Full evaluation on CIFAR-10 test set
clip_model.eval()
correct = 0
total = 0

# Prompt ensembling: one text embedding per class, computed by averaging the
# L2-normalized embeddings of several templates and re-normalizing the mean.
templates = ["a photo of a {}", "an image of a {}", "a picture of a {}"]
class_embeddings = []

with torch.no_grad():
    for cls in CIFAR10_CLASSES:
        tokens = tokenizer.batch_encode([tmpl.format(cls) for tmpl in templates]).to(device)
        per_template = F.normalize(clip_model.text_encoder(tokens), dim=-1)
        class_embeddings.append(per_template.mean(dim=0))
    # (num_classes, embed_dim), unit-norm rows
    class_embeddings = F.normalize(torch.stack(class_embeddings), dim=-1)

    # Predict the class whose text embedding has the highest cosine similarity.
    for images, labels in test_loader_eval:
        img_embs = F.normalize(clip_model.image_encoder(images.to(device)), dim=-1)
        preds = (img_embs @ class_embeddings.T).argmax(dim=1).cpu()
        correct += int((preds == labels).sum())
        total += labels.size(0)

accuracy = correct / total
print(f"Overall zero-shot accuracy on CIFAR-10 test set: {accuracy:.1%}")
print(f"Random chance would be: {1/10:.1%}")

In [None]:
# Per-class accuracy
class_correct = {cls: 0 for cls in CIFAR10_CLASSES}
class_total = {cls: 0 for cls in CIFAR10_CLASSES}

with torch.no_grad():
    for images, labels in test_loader_eval:
        img_embs = F.normalize(clip_model.image_encoder(images.to(device)), dim=-1)
        preds = (img_embs @ class_embeddings.T).argmax(dim=1).cpu()
        # Tally hits and totals keyed by the TRUE class of each image.
        for true_idx, is_hit in zip(labels.tolist(), (preds == labels).tolist()):
            true_name = CIFAR10_CLASSES[true_idx]
            class_total[true_name] += 1
            class_correct[true_name] += int(is_hit)

fig, ax = plt.subplots(figsize=(12, 5))
class_accs = [class_correct[c] / class_total[c] for c in CIFAR10_CLASSES]
# Traffic-light colouring: >50% green, >30% orange, otherwise red.
colors = ['green' if a > 0.5 else ('orange' if a > 0.3 else 'red') for a in class_accs]
ax.bar(CIFAR10_CLASSES, class_accs, color=colors, edgecolor='black', linewidth=0.5)
ax.axhline(y=0.1, color='gray', linestyle='--', label='Random chance')
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_title('Per-Class Zero-Shot Accuracy', fontsize=14)
ax.set_ylim(0, 1.0)
ax.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## 7. Text-to-Image Retrieval

In [None]:
# Image retrieval demo: find images matching a text query
def text_to_image_retrieval(clip_model, text_query, dataset, tokenizer,
                             top_k=5, n_candidates=1000, batch_size=256):
    """Retrieve the images most similar to a text query from a random candidate pool.

    ``min(n_candidates, len(dataset))`` distinct indices are sampled with NumPy's
    global RNG. Those images are embedded in batches and ranked by cosine
    similarity to the query embedding.

    Args:
        clip_model: MiniCLIP-style model exposing ``image_encoder`` and
            ``text_encoder``; inputs are moved to the device of its parameters.
        text_query: free-text query string.
        dataset: indexable dataset returning ``(image_tensor, label)`` pairs.
        tokenizer: tokenizer with ``encode(str) -> 1-D token tensor``.
        top_k: number of best matches to return.
        n_candidates: size of the random candidate pool (capped at ``len(dataset)``).
        batch_size: number of images encoded per forward pass.

    Returns:
        top_indices: tensor of positions into the candidate lists, best match first.
        all_sims: 1-D tensor of similarities, aligned with the candidate lists.
        all_imgs: list of candidate image tensors as returned by ``dataset``.
        all_labels: list of candidate labels.

    Raises:
        ValueError: if the candidate pool would be empty.
    """
    pool_size = min(n_candidates, len(dataset))
    if pool_size <= 0:
        raise ValueError("no candidates to retrieve from (empty dataset or n_candidates <= 0)")
    model_device = next(clip_model.parameters()).device

    clip_model.eval()
    with torch.no_grad():
        # Encode query -> (1, D)
        tokens = tokenizer.encode(text_query).unsqueeze(0).to(model_device)
        txt_emb = F.normalize(clip_model.text_encoder(tokens), dim=-1)

        # Encode candidate images
        indices = np.random.choice(len(dataset), pool_size, replace=False)
        all_sims = []
        all_imgs = []
        all_labels = []

        for start in range(0, pool_size, batch_size):
            # Fetch each sample once: every dataset[i] re-runs the transform.
            samples = [dataset[int(i)] for i in indices[start:start + batch_size]]
            batch_imgs = [img for img, _ in samples]
            img_embs = F.normalize(clip_model.image_encoder(torch.stack(batch_imgs).to(model_device)), dim=-1)
            # squeeze(0), not squeeze(): a final batch of one image must stay 1-D for torch.cat
            all_sims.append((txt_emb @ img_embs.T).squeeze(0).cpu())
            all_imgs.extend(batch_imgs)
            all_labels.extend(label for _, label in samples)

        all_sims = torch.cat(all_sims)
        top_indices = all_sims.argsort(descending=True)[:top_k]

    return top_indices, all_sims, all_imgs, all_labels

In [None]:
# Run retrieval queries
queries = ["a photo of a dog", "a photo of a ship", "a photo of a truck"]

fig, axes = plt.subplots(len(queries), 6, figsize=(18, 3 * len(queries)))
for q_idx, query in enumerate(queries):
    row_axes = axes[q_idx]
    top_idx, sims, imgs, labels = text_to_image_retrieval(
        clip_model, query, test_dataset, tokenizer, top_k=5
    )

    # Column 0: the query text. Columns 1..5: the ranked matches.
    query_ax = row_axes[0]
    query_ax.text(0.5, 0.5, f'Query:\n"{query}"', ha='center', va='center',
                  fontsize=11, transform=query_ax.transAxes)
    query_ax.axis('off')

    for rank, idx in enumerate(top_idx, start=1):
        match_ax = row_axes[rank]
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1], CHW -> HWC for imshow.
        match_ax.imshow(imgs[idx].permute(1, 2, 0).numpy() * 0.5 + 0.5)
        match_ax.set_title(f"#{rank}: {CIFAR10_CLASSES[labels[idx]]}\nsim={sims[idx]:.3f}", fontsize=9)
        match_ax.axis('off')

plt.suptitle('Text-to-Image Retrieval Results', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 8. Final Output

In [None]:
# Confusion matrix
from sklearn.metrics import confusion_matrix
import itertools

all_preds = []
all_true = []

clip_model.eval()
with torch.no_grad():
    for images, labels in test_loader_eval:
        img_embs = F.normalize(clip_model.image_encoder(images.to(device)), dim=-1)
        preds = (img_embs @ class_embeddings.T).argmax(dim=1).cpu()
        all_preds.extend(preds.numpy())
        all_true.extend(labels.numpy())

# Pin the label set so the matrix is always num_classes x num_classes and lines up
# with the tick labels below, even if some class never appears in all_true/all_preds.
class_ids = np.arange(len(CIFAR10_CLASSES))
cm = confusion_matrix(all_true, all_preds, labels=class_ids)

plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation='nearest', cmap='Blues')
plt.colorbar()
plt.xticks(class_ids, CIFAR10_CLASSES, rotation=45, ha='right')
plt.yticks(class_ids, CIFAR10_CLASSES)

# Add text annotations (white text on dark cells for readability)
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, format(cm[i, j], 'd'), ha='center', va='center',
             color='white' if cm[i, j] > cm.max() / 2 else 'black', fontsize=8)

plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.title('Zero-Shot Classification Confusion Matrix', fontsize=14)
plt.tight_layout()
plt.show()

print(f"\nFinal Zero-Shot Accuracy: {np.mean(np.array(all_preds) == np.array(all_true)):.1%}")
print("Congratulations! You have built a complete zero-shot classification system!")

## 9. Reflection and Next Steps

### Reflection Questions
1. Which classes does the model confuse most often? Why might that be?
2. How would you adapt this system for a completely different dataset (e.g., medical images)?
3. What are the limits of zero-shot classification? When would fine-tuning be necessary?
4. How does prompt engineering affect performance? What makes a good prompt?

### Optional Challenges
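1. Try classifying into sub-categories: instead of "dog," use "golden retriever," "poodle," etc. Does CLIP distinguish them? (Note: `SimpleTokenizer` only knows the words that appeared in its templates, and unknown words are encoded as padding. You will need to extend the vocabulary and retrain the model before this experiment is meaningful.)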
2. Implement image-to-text retrieval: given an image, find the best-matching caption.
3. Try adversarial prompts: "a photo that is NOT a dog." How does CLIP handle negation?
4. Compare single-template vs. multi-template ensembling across all 10 classes.