In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and report library versions

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Building CLIP from Scratch: Dual Encoders and Contrastive Training

*Part 2 of the Vizuara series on Contrastive Pretraining (CLIP-style)*
*Estimated time: 60 minutes*

## 1. Why Does This Matter?

CLIP (Contrastive Language-Image Pretraining) showed that a single model can learn to place images and text in a shared embedding space. This shared space enables zero-shot classification and image-text retrieval, and CLIP-style image encoders serve as the visual backbone for modern VLMs such as LLaVA.

In Notebook 1, we learned the contrastive loss function. Now we will build the full CLIP architecture from scratch -- an image encoder, a text encoder, and the training pipeline that connects them.

By the end of this notebook, you will:
- Build a Vision Transformer (ViT) image encoder from scratch
- Build a Transformer text encoder from scratch
- Train a complete CLIP model on CIFAR-10 with synthetic captions
- Visualize the learned multimodal embedding space

In [None]:
# Setup and GPU check
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
torch.manual_seed(42)
np.random.seed(42)

## 2. Building Intuition

Think of CLIP as two translators working side by side. One translator speaks "image" -- they can take any photograph and describe its content as a string of numbers (a vector). The other translator speaks "text" -- they can take any sentence and also produce a string of numbers.

The magic is that both translators produce numbers in the **same language**. A photo of a dog and the text "a photo of a dog" both map to similar vectors. A photo of a dog and the text "a photo of a car" map to very different vectors.

How do we train these translators? We show them millions of image-caption pairs and say: "The vector for this image should be close to the vector for this caption, and far from the vectors of all other captions in the batch."

### Think About This
- Why do we need two separate encoders instead of one? Could a single network process both images and text?
- What is the minimum information the image encoder needs to capture to match a caption?
- Why would normalizing embeddings to unit length be important?

## 3. The Mathematics

### The Dual Encoder Architecture

Let $f_\theta$ be the image encoder and $g_\phi$ be the text encoder. Given an image $x_I$ and text $x_T$:

$$z_I = f_\theta(x_I) \in \mathbb{R}^d, \quad z_T = g_\phi(x_T) \in \mathbb{R}^d$$

Computationally, each encoder passes its input (pixels or tokens) through a neural network and outputs a $d$-dimensional embedding vector. Both vectors live in the same space.

We normalize both embeddings:
$$\hat{z}_I = \frac{z_I}{\|z_I\|}, \quad \hat{z}_T = \frac{z_T}{\|z_T\|}$$

This ensures all embeddings lie on the unit hypersphere, so cosine similarity equals the dot product.
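
To make the claim concrete, here is a minimal sketch in plain Python (no PyTorch, so the arithmetic stays visible). The helper names `dot` and `l2_normalize` and the toy vectors are illustrative assumptions, not part of the notebook's model code.

```python
import math

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def l2_normalize(v):
    """Scale v to unit Euclidean length."""
    norm = math.sqrt(dot(v, v))
    return [a / norm for a in v]

# Toy embeddings (hypothetical values, d = 3)
z_img = [3.0, 4.0, 0.0]
z_txt = [1.0, 2.0, 2.0]

# Cosine similarity computed from the raw vectors
cosine = dot(z_img, z_txt) / (
    math.sqrt(dot(z_img, z_img)) * math.sqrt(dot(z_txt, z_txt))
)

# Dot product of the normalized vectors
z_img_hat = l2_normalize(z_img)
z_txt_hat = l2_normalize(z_txt)
dot_of_normalized = dot(z_img_hat, z_txt_hat)

print(f"cosine similarity:         {cosine:.4f}")
print(f"dot of normalized vectors: {dot_of_normalized:.4f}")
```

Both lines print the same value. In the model code later in this notebook, `F.normalize(x, dim=-1)` applies the same scaling to every row of a batch.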

### The CLIP Loss

For a batch of $N$ image-text pairs, the CLIP loss is:

$$\mathcal{L} = \frac{1}{2}\left(-\frac{1}{N}\sum_{i=1}^{N}\log\frac{e^{\hat{z}_{I_i}\cdot\hat{z}_{T_i}/\tau}}{\sum_{j=1}^{N}e^{\hat{z}_{I_i}\cdot\hat{z}_{T_j}/\tau}} - \frac{1}{N}\sum_{i=1}^{N}\log\frac{e^{\hat{z}_{T_i}\cdot\hat{z}_{I_i}/\tau}}{\sum_{j=1}^{N}e^{\hat{z}_{T_i}\cdot\hat{z}_{I_j}/\tau}}\right)$$

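This is a symmetric cross-entropy loss: the first term treats each image as a query and classifies which of the $N$ texts matches it, and the second term does the same from texts to images. The temperature $\tau$ controls the sharpness of the softmax: a smaller $\tau$ amplifies the differences between similarities.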

Let us compute a quick example with $N=2$, $\tau=0.5$. Suppose:
$$\hat{z}_{I_1} \cdot \hat{z}_{T_1} = 0.8, \quad \hat{z}_{I_1} \cdot \hat{z}_{T_2} = 0.1$$

Image-to-text loss for pair 1:
$$-\log\frac{e^{0.8/0.5}}{e^{0.8/0.5} + e^{0.1/0.5}} = -\log\frac{e^{1.6}}{e^{1.6} + e^{0.2}} = -\log\frac{4.95}{4.95 + 1.22} = -\log(0.802) = 0.220$$

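A loss of 0.220 corresponds to a softmax probability of about 0.80 on the correct caption, so the model is already fairly confident in the right match. This is what we want; a perfect match would push the loss toward 0.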
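
The arithmetic above can be checked with a few lines of plain Python. This is a sketch rather than the notebook's training code: the function name `cross_entropy_row` is an assumption, and the second call with $\tau = 0.07$ (the default temperature used by `MiniCLIP` later on) is added only to show how a smaller temperature sharpens the softmax.

```python
import math

def cross_entropy_row(similarities, target, tau):
    """-log softmax(similarities / tau)[target], computed via log-sum-exp."""
    logits = [s / tau for s in similarities]
    largest = max(logits)
    log_sum_exp = largest + math.log(sum(math.exp(x - largest) for x in logits))
    return log_sum_exp - logits[target]

# Similarities of image 1 to text 1 (the match) and text 2, from the example
sims_image_1 = [0.8, 0.1]

loss_tau_05 = cross_entropy_row(sims_image_1, target=0, tau=0.5)
loss_tau_007 = cross_entropy_row(sims_image_1, target=0, tau=0.07)

print(f"tau = 0.5  -> loss = {loss_tau_05:.3f}")
print(f"tau = 0.07 -> loss = {loss_tau_007:.6f}")
```

With the same two similarities, the smaller temperature yields a much smaller loss, because the gap of 0.7 in cosine similarity becomes a gap of 10 in the logits.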

## 4. Let's Build It -- Component by Component

### 4.1 Patch Embedding (ViT Image Encoder)

A Vision Transformer splits an image into patches, embeds each patch, and processes them with self-attention.
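
Before looking at the PyTorch module, it can help to see the patch bookkeeping on its own. The sketch below uses plain nested lists and a hypothetical helper `patchify`; it only counts and flattens patches and does not learn any projection.

```python
def patchify(image, patch_size):
    """Split an H x W x C nested-list image into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [
                value
                for row in image[top:top + patch_size]
                for pixel in row[left:left + patch_size]
                for value in pixel
            ]
            patches.append(patch)
    return patches

# A dummy 32 x 32 RGB image filled with zeros (CIFAR-10 resolution)
img_size, patch_size, channels = 32, 4, 3
image = [[[0.0] * channels for _ in range(img_size)] for _ in range(img_size)]

toy_patches = patchify(image, patch_size)
num_patches = len(toy_patches)
values_per_patch = len(toy_patches[0])
sequence_length = num_patches + 1  # +1 for the CLS token

print(f"patches: {num_patches}, values per patch: {values_per_patch}, "
      f"sequence length with CLS: {sequence_length}")
```

Each flattened patch holds 4 x 4 x 3 = 48 values. The module below maps every patch to an `embed_dim`-sized vector with a strided `Conv2d`, which acts as a linear projection applied patch by patch.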

In [None]:
class PatchEmbedding(nn.Module):
    """Split image into patches and embed them."""
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Conv2d with kernel_size=patch_size and stride=patch_size
        # acts like a patch-wise linear projection
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

        # CLS token and positional embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim) * 0.02
        )

    def forward(self, x):
        B = x.shape[0]
        # Project patches: (B, C, H, W) -> (B, embed_dim, H/P, W/P)
        x = self.projection(x)  # (B, embed_dim, num_patches_h, num_patches_w)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches+1, embed_dim)

        # Add positional embeddings
        x = x + self.pos_embed
        return x

# Test patch embedding
patch_embed = PatchEmbedding(img_size=32, patch_size=4, embed_dim=128)
test_img = torch.randn(2, 3, 32, 32)
patches = patch_embed(test_img)
print(f"Input: {test_img.shape}")
print(f"Patches: {patches.shape}  (batch, num_patches+1, embed_dim)")
print(f"Number of patches: {patch_embed.num_patches} = (32/4)^2")

### 4.2 Transformer Encoder Block

In [None]:
class TransformerBlock(nn.Module):
    """Standard Transformer encoder block with multi-head attention."""
    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention with residual
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out

        # MLP with residual
        x = x + self.mlp(self.norm2(x))
        return x

# Test
block = TransformerBlock(embed_dim=128, num_heads=4)
out = block(patches)
print(f"Transformer block output: {out.shape}")

### 4.3 Full Image Encoder

In [None]:
class ImageEncoder(nn.Module):
    """Vision Transformer image encoder for CLIP."""
    def __init__(self, img_size=32, patch_size=4, in_channels=3,
                 embed_dim=128, depth=4, num_heads=4, output_dim=128):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.projection = nn.Linear(embed_dim, output_dim)

    def forward(self, x):
        x = self.patch_embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Use CLS token as image representation
        cls_output = x[:, 0]
        return self.projection(cls_output)

image_encoder = ImageEncoder(output_dim=128)
test_out = image_encoder(test_img)
print(f"Image encoder output: {test_out.shape}")
print(f"Parameters: {sum(p.numel() for p in image_encoder.parameters()):,}")

In [None]:
# Visualization checkpoint: check that embeddings are diverse
with torch.no_grad():
    random_imgs = torch.randn(20, 3, 32, 32)
    random_embs = F.normalize(image_encoder(random_imgs), dim=-1)

sim_matrix = (random_embs @ random_embs.T).numpy()
plt.figure(figsize=(6, 5))
plt.imshow(sim_matrix, cmap='RdYlGn', vmin=-1, vmax=1)
plt.colorbar(label='Cosine Similarity')
plt.title('Image Embedding Similarities\n(Should be noisy, not uniform)', fontsize=13)
plt.xlabel('Image Index')
plt.ylabel('Image Index')
plt.tight_layout()
plt.show()

### 4.4 Text Encoder

In [None]:
class TextEncoder(nn.Module):
    """Transformer text encoder for CLIP."""
    def __init__(self, vocab_size=1000, max_len=32, embed_dim=128,
                 depth=2, num_heads=4, output_dim=128):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.randn(1, max_len, embed_dim) * 0.02
        )
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.projection = nn.Linear(embed_dim, output_dim)

    def forward(self, tokens):
        B, L = tokens.shape
        x = self.token_embed(tokens) + self.pos_embed[:, :L]
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Mean pooling over the sequence (PAD positions are included in the average)
        x = x.mean(dim=1)
        return self.projection(x)

text_encoder = TextEncoder(vocab_size=200, output_dim=128)
test_tokens = torch.randint(0, 200, (2, 10))
text_out = text_encoder(test_tokens)
print(f"Text encoder output: {text_out.shape}")
print(f"Parameters: {sum(p.numel() for p in text_encoder.parameters()):,}")
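
The text encoder above averages over every position, including padding. As a point of comparison, here is a sketch of a padding-aware mean in plain Python. It is a possible variant, not what the notebook's `TextEncoder` does; the function `mean_pool`, its `ignore_padding` flag, and the toy vectors are all assumptions for illustration.

```python
def mean_pool(token_vectors, token_ids, pad_id=0, ignore_padding=False):
    """Average token vectors; optionally skip positions whose id equals pad_id."""
    if ignore_padding:
        kept = [vec for vec, tok in zip(token_vectors, token_ids) if tok != pad_id]
        kept = kept or token_vectors  # all-PAD sequence: fall back to the plain mean
    else:
        kept = token_vectors
    dim = len(token_vectors[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]

# Two real tokens followed by two PAD positions (hypothetical 2-d outputs)
token_ids = [5, 7, 0, 0]
token_vectors = [[1.0, 1.0], [3.0, 3.0], [10.0, 10.0], [10.0, 10.0]]

plain = mean_pool(token_vectors, token_ids)
masked = mean_pool(token_vectors, token_ids, ignore_padding=True)

print("plain mean: ", plain)   # [6.0, 6.0]
print("masked mean:", masked)  # [2.0, 2.0]
```

In this notebook every caption template has five words, so padding affects all captions equally and the plain mean is a reasonable simplification. The distinction matters more for variable-length captions, where one would typically also pass a `key_padding_mask` to the attention layers.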

### 4.5 Full CLIP Model

In [None]:
class MiniCLIP(nn.Module):
    """Complete CLIP model with image and text encoders."""
    def __init__(self, image_encoder, text_encoder, temperature=0.07):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.logit_scale = nn.Parameter(
            torch.log(torch.tensor(1.0 / temperature))
        )

    def forward(self, images, tokens):
        # Encode both modalities
        img_emb = F.normalize(self.image_encoder(images), dim=-1)
        txt_emb = F.normalize(self.text_encoder(tokens), dim=-1)

        # Compute scaled similarity matrix
        logit_scale = self.logit_scale.exp().clamp(max=100)
        logits = logit_scale * (img_emb @ txt_emb.T)
        return logits, img_emb, txt_emb

    def clip_loss(self, logits):
        """Symmetric contrastive loss."""
        labels = torch.arange(logits.size(0), device=logits.device)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.T, labels)
        return (loss_i2t + loss_t2i) / 2

clip_model = MiniCLIP(
    ImageEncoder(output_dim=128),
    TextEncoder(vocab_size=200, output_dim=128)
).to(device)
total_params = sum(p.numel() for p in clip_model.parameters())
print(f"Total Mini-CLIP parameters: {total_params:,}")
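
`MiniCLIP` stores the logarithm of the inverse temperature and exponentiates it in `forward`, which keeps the scale positive while it is being learned. The sketch below replays that computation in plain Python for the default `temperature=0.07`; the variable names are illustrative.

```python
import math

temperature = 0.07                               # default used by MiniCLIP above
log_logit_scale = math.log(1.0 / temperature)    # value stored in the parameter
logit_scale = min(math.exp(log_logit_scale), 100.0)  # exp() followed by clamp(max=100)

# Effect on a single cosine similarity
cosine_similarity = 0.8
scaled_logit = logit_scale * cosine_similarity

# The clamp at 100 bounds how small the effective temperature can become
min_temperature = 1.0 / 100.0

print(f"log(1/tau)   = {log_logit_scale:.3f}")
print(f"logit_scale  = {logit_scale:.2f}")
print(f"scaled logit = {scaled_logit:.2f}")
print(f"smallest reachable temperature = {min_temperature}")
```

At initialization the scale is roughly 14.3, so cosine similarities in $[-1, 1]$ become logits in roughly $[-14.3, 14.3]$ before the softmax.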

## 5. Your Turn

### TODO: Build the Caption Tokenizer

In [None]:
class SimpleTokenizer:
    """
    A simple word-level tokenizer for CIFAR-10 captions.
    Vocabulary: PAD=0, then common words mapped to integers.
    """
    def __init__(self, max_len=12):
        self.max_len = max_len
        # ============ TODO ============
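        # Build a vocabulary that covers the caption templates used in Section 6:
        # "a photo of a [class]", "an image of a [class]", "a picture of a [class]"
        # (words missing from the vocabulary are encoded as PAD) where classes are:
        # airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck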
        #
        # Step 1: Create a word_to_idx dictionary starting with {"<PAD>": 0}
        # Step 2: Add all unique words from the templates
        # Hint: split each template into words and add unseen words
        # ==============================

        self.word_to_idx = {"<PAD>": 0}  # YOUR CODE to extend this
        self.vocab_size = len(self.word_to_idx)

    def encode(self, text):
        """Convert text to padded token tensor."""
        words = text.lower().split()
        # Words missing from the vocabulary fall back to index 0 (PAD)
        tokens = [self.word_to_idx.get(w, 0) for w in words]
        # Pad or truncate
        if len(tokens) < self.max_len:
            tokens = tokens + [0] * (self.max_len - len(tokens))
        else:
            tokens = tokens[:self.max_len]
        return torch.tensor(tokens)

    def batch_encode(self, texts):
        return torch.stack([self.encode(t) for t in texts])

In [None]:
# Verification
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck']

tokenizer = SimpleTokenizer()
test_text = "a photo of a dog"
tokens = tokenizer.encode(test_text)
print(f"'{test_text}' -> {tokens}")
assert tokens[0] != 0, "First token should not be PAD"
assert tokenizer.vocab_size > 10, f"Vocab too small: {tokenizer.vocab_size}"
print(f"Vocabulary size: {tokenizer.vocab_size}")
print("Correct! Your tokenizer works.")
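
If the padding and truncation logic in `encode` is unclear, the following sketch isolates it on a toy vocabulary. The vocabulary and the helper names `pad_or_truncate` and `toy_encode` are hypothetical and deliberately unrelated to the CIFAR-10 exercise, so the sketch does not give away the solution.

```python
def pad_or_truncate(token_ids, max_len, pad_id=0):
    """Return exactly max_len ids: right-pad with pad_id or cut off the tail."""
    if len(token_ids) < max_len:
        return token_ids + [pad_id] * (max_len - len(token_ids))
    return token_ids[:max_len]

# Hypothetical toy vocabulary
toy_vocab = {"<PAD>": 0, "hello": 1, "big": 2, "world": 3}

def toy_encode(text, max_len):
    # Words missing from the vocabulary fall back to the PAD id, as in encode()
    ids = [toy_vocab.get(word, 0) for word in text.lower().split()]
    return pad_or_truncate(ids, max_len)

print(toy_encode("Hello world", max_len=4))                # [1, 3, 0, 0]
print(toy_encode("hello big big world hello", max_len=4))  # [1, 2, 2, 3]
print(toy_encode("hello strange world", max_len=4))        # [1, 0, 3, 0]
```

The last call shows why vocabulary coverage matters: an unknown word becomes indistinguishable from padding.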

### TODO: Implement the Training Step

In [None]:
def train_step(clip_model, images, captions, optimizer, tokenizer):
    """
    Perform one training step of Mini-CLIP.

    Args:
        clip_model: the MiniCLIP model
        images: batch of images (B, C, H, W)
        captions: list of caption strings
        optimizer: the optimizer
        tokenizer: text tokenizer

    Returns:
        loss value, accuracy
    """
    # ============ TODO ============
    # Step 1: Tokenize captions using tokenizer.batch_encode(captions)
    # Step 2: Move images and tokens to device
    # Step 3: Forward pass through clip_model to get logits
    # Step 4: Compute clip_model.clip_loss(logits)
    # Step 5: Backward pass and optimizer step
    # Step 6: Compute accuracy (how often argmax matches diagonal)
    # ==============================

    loss = ???  # YOUR CODE HERE
    accuracy = ???  # YOUR CODE HERE

    return loss.item(), accuracy

In [None]:
# Verification cell
print("The training loop in section 7 performs these same steps inline.")
print("Compare your train_step against it and make sure it returns (loss_value, accuracy_value)")
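
Step 6 of the TODO asks how often the argmax of each row lands on the diagonal. The sketch below shows that metric on a small hand-written logits matrix in plain Python; the helper names and the numbers are assumptions chosen for illustration.

```python
def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def contrastive_accuracy(logits):
    """Fraction of rows whose largest logit sits on the diagonal."""
    correct = sum(1 for i, row in enumerate(logits) if argmax(row) == i)
    return correct / len(logits)

# Hypothetical 3 x 3 image-to-text logits: rows are images, columns are texts
toy_logits = [
    [9.0, 1.0, 2.0],   # image 0 picks text 0 (correct)
    [0.5, 7.0, 3.0],   # image 1 picks text 1 (correct)
    [8.0, 2.0, 4.0],   # image 2 picks text 0 (wrong, should be text 2)
]

accuracy = contrastive_accuracy(toy_logits)
print(f"accuracy = {accuracy:.3f}")
```

On a PyTorch tensor the same quantity can be obtained by comparing `logits.argmax(dim=1)` with `torch.arange(N)`, which is what the training loop in section 7 does.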

## 6. Putting It All Together

In [None]:
# Prepare CIFAR-10 dataset with captions
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)

# Caption templates for each class
CAPTION_TEMPLATES = [
    "a photo of a {}",
    "an image of a {}",
    "a picture of a {}",
]

def get_caption(label_idx):
    """Generate a caption for a CIFAR-10 label."""
    class_name = CIFAR10_CLASSES[label_idx]
    template = CAPTION_TEMPLATES[np.random.randint(len(CAPTION_TEMPLATES))]
    return template.format(class_name)

# Create dataloader
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2
)

print(f"Training samples: {len(train_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")

# Sample captions
for i in range(5):
    _, label = train_dataset[i]
    print(f"  Class {label} ({CIFAR10_CLASSES[label]}): '{get_caption(label)}'")
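
One property of these synthetic captions is worth keeping in mind when reading the training accuracy later: three templates times ten classes give at most 30 distinct captions, while each batch holds 128 pairs. Identical captions tokenize to identical inputs, yet the diagonal-only target still treats them as negatives. The sketch below simulates one batch to show how common such repeats are. It assumes uniformly random labels and templates and uses a local random generator; it is not part of the training pipeline.

```python
import random
from collections import Counter

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
templates = ["a photo of a {}", "an image of a {}", "a picture of a {}"]

rng = random.Random(0)  # local generator, leaves the notebook's global seeds untouched
batch_size = 128

# Simulate one batch with uniformly random labels and templates
batch_captions = [
    rng.choice(templates).format(rng.choice(classes))
    for _ in range(batch_size)
]

counts = Counter(batch_captions)
num_distinct = len(counts)
max_possible = len(classes) * len(templates)
rows_with_duplicate = sum(c for c in counts.values() if c > 1)

print(f"distinct captions in the batch: {num_distinct} (at most {max_possible})")
print(f"rows whose caption also appears in another row: {rows_with_duplicate}")
```

Because at least 98 of the 128 rows must share their caption with another row, the contrastive accuracy reported during training should be read as a relative signal rather than as a classification accuracy.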

In [None]:
# Visualize sample images with their captions
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i in range(10):
    img, label = train_dataset[i]
    ax = axes[i // 5, i % 5]
    img_display = img.permute(1, 2, 0).numpy() * 0.5 + 0.5
    ax.imshow(img_display)
    ax.set_title(get_caption(label), fontsize=9)
    ax.axis('off')
plt.suptitle('CIFAR-10 Images with Generated Captions', fontsize=14)
plt.tight_layout()
plt.show()

## 7. Training and Results

In [None]:
# Initialize model and optimizer
tokenizer = SimpleTokenizer(max_len=8)
clip_model = MiniCLIP(
    ImageEncoder(img_size=32, patch_size=4, embed_dim=128, depth=4,
                 num_heads=4, output_dim=128),
    TextEncoder(vocab_size=tokenizer.vocab_size, max_len=8, embed_dim=128,
                depth=2, num_heads=4, output_dim=128),
    temperature=0.07
).to(device)

optimizer = torch.optim.AdamW(clip_model.parameters(), lr=3e-4, weight_decay=0.01)
# T_max=10 matches num_epochs in the training loop (one scheduler.step() per epoch)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

print(f"Model parameters: {sum(p.numel() for p in clip_model.parameters()):,}")
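
To get a feel for the learning-rate schedule configured above, the sketch below evaluates the closed-form cosine annealing curve in plain Python. It assumes the scheduler's default minimum learning rate of 0 and one `scheduler.step()` per epoch; the function name `cosine_annealing_lr` is illustrative.

```python
import math

def cosine_annealing_lr(step, base_lr=3e-4, t_max=10, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

schedule = [cosine_annealing_lr(step) for step in range(11)]
for step, lr in enumerate(schedule):
    print(f"after {step:2d} scheduler steps: lr = {lr:.2e}")
```

The learning rate starts at 3e-4, passes through half that value after five steps, and reaches 0 after ten. With ten epochs, the last epoch therefore trains at the small but nonzero rate listed for nine steps.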

In [None]:
# Training loop
num_epochs = 10
train_losses = []
train_accs = []

for epoch in range(num_epochs):
    clip_model.train()
    epoch_loss = 0
    epoch_acc = 0
    n_batches = 0

    for images, labels in train_loader:
        images = images.to(device)
        captions = [get_caption(label.item()) for label in labels]
        tokens = tokenizer.batch_encode(captions).to(device)

        # Forward pass
        logits, _, _ = clip_model(images, tokens)
        loss = clip_model.clip_loss(logits)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        with torch.no_grad():
            preds = logits.argmax(dim=1)
            gt = torch.arange(logits.size(0), device=device)
            acc = (preds == gt).float().mean().item()

        epoch_loss += loss.item()
        epoch_acc += acc
        n_batches += 1

    scheduler.step()
    avg_loss = epoch_loss / n_batches
    avg_acc = epoch_acc / n_batches
    train_losses.append(avg_loss)
    train_accs.append(avg_acc)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f} | Acc: {avg_acc:.3f}")

In [None]:
# Visualize training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(range(1, num_epochs+1), train_losses, 'b-o', linewidth=2, markersize=6)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Contrastive Loss', fontsize=12)
ax1.set_title('Training Loss', fontsize=14)
ax1.grid(True, alpha=0.3)

ax2.plot(range(1, num_epochs+1), train_accs, 'g-o', linewidth=2, markersize=6)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Contrastive Accuracy', fontsize=12)
ax2.set_title('Training Accuracy', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Final Output

In [None]:
# Visualize the learned embedding space
clip_model.eval()
all_img_embs = []
all_labels = []

with torch.no_grad():
    for images, labels in torch.utils.data.DataLoader(
        train_dataset, batch_size=256, shuffle=False
    ):
        img_emb = F.normalize(
            clip_model.image_encoder(images.to(device)), dim=-1
        ).cpu()
        all_img_embs.append(img_emb)
        all_labels.append(labels)
        if len(all_labels) * 256 >= 5000:
            break

all_img_embs = torch.cat(all_img_embs)[:5000]
all_labels = torch.cat(all_labels)[:5000]

# t-SNE visualization
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
emb_2d = tsne.fit_transform(all_img_embs.numpy())

plt.figure(figsize=(10, 8))
scatter = plt.scatter(emb_2d[:, 0], emb_2d[:, 1],
                      c=all_labels.numpy(), cmap='tab10', s=5, alpha=0.6)
plt.colorbar(scatter, ticks=range(10),
             label='Class')
plt.title('t-SNE of Mini-CLIP Image Embeddings\n(Colors = CIFAR-10 classes)', fontsize=14)
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.tight_layout()
plt.show()
print("Congratulations! You have built CLIP from scratch!")
print("Similar classes should cluster together in the embedding space.")

## 9. Reflection and Next Steps

### Reflection Questions
1. How does the number of Transformer layers affect the quality of embeddings? What is the minimum depth needed?
2. Why do we use separate encoders for images and text instead of a shared one?
3. What would happen if we increased the batch size from 128 to 1024? Why does batch size matter so much in contrastive learning?
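4. Print `clip_model.logit_scale.exp().item()` after training and take its reciprocal to get the learned temperature. If it has become smaller than the initial 0.07, what does this tell us about the model's confidence?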

### Optional Challenges
1. Try using a ResNet instead of ViT as the image encoder. Compare the learned representations.
2. Add more diverse caption templates and see if it improves zero-shot performance.
3. Try training for more epochs and track zero-shot classification accuracy on the test set.
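
For the zero-shot challenge, the core step is to embed one prompt per class and pick the class whose prompt is most similar to the image embedding. The sketch below shows that decision rule in plain Python on toy vectors. The function `zero_shot_predict` and all embedding values are hypothetical stand-ins for the outputs of the trained encoders.

```python
import math

def normalize(v):
    """Scale v to unit Euclidean length."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]

def zero_shot_predict(image_emb, class_text_embs):
    """Return the class whose prompt embedding has the highest cosine similarity."""
    image_emb = normalize(image_emb)
    scores = {
        name: sum(a * b for a, b in zip(image_emb, normalize(text_emb)))
        for name, text_emb in class_text_embs.items()
    }
    return max(scores, key=scores.get), scores

# Hypothetical 3-d embeddings standing in for the text encoder's output
# on prompts such as "a photo of a dog"
class_text_embs = {
    "dog":  [0.9, 0.1, 0.0],
    "cat":  [0.1, 0.9, 0.0],
    "ship": [0.0, 0.1, 0.9],
}
# Hypothetical embedding standing in for the image encoder's output
image_emb = [0.8, 0.3, 0.1]

predicted, scores = zero_shot_predict(image_emb, class_text_embs)
print(f"predicted class: {predicted}")
for name, score in scores.items():
    print(f"  {name}: {score:.3f}")
```

With the trained model, the toy vectors would be replaced by `clip_model.text_encoder` outputs for the ten class prompts and `clip_model.image_encoder` outputs for test images, and accuracy is the fraction of images whose predicted class matches the label.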