# PyTorch Tutorial: Multimodal AI (CLIP)

Much current AI research is **multimodal**: models that jointly process text, images, audio, and video. A landmark model in this area is **CLIP (Contrastive Language-Image Pre-training)**, released by OpenAI in 2021, which learns a joint representation of images and text.

In this notebook, we will build the core components of CLIP from scratch.

## Learning Objectives
- Understand Contrastive Learning
- Implement Dual Encoders (Image & Text)
- Implement the Contrastive Loss function
- Understand Zero-Shot Classification


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

torch.manual_seed(42)

## 1. The Core Idea: Contrastive Learning

CLIP doesn't predict labels (like "cat" or "dog"). Instead, it learns to match **images** to their **captions**.

- **Positive pairs**: (Image of dog, "A photo of a dog") -> Maximize similarity
- **Negative pairs**: (Image of dog, "A photo of a car") -> Minimize similarity

We project both images and text into a shared **embedding space**.
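
As a toy illustration of the shared space, the sketch below compares cosine similarity for a matching and a non-matching pair. The 3-dimensional vectors are hand-picked, hypothetical values, not outputs of any real encoder.

```python
import torch
import torch.nn.functional as F

# Hypothetical embeddings, chosen by hand for illustration only
dog_image = torch.tensor([0.9, 0.1, 0.2])
dog_caption = torch.tensor([0.8, 0.2, 0.1])   # "A photo of a dog"
car_caption = torch.tensor([0.1, 0.9, 0.3])   # "A photo of a car"

# Cosine similarity: dot product of the two vectors after scaling each to unit length
pos_sim = F.cosine_similarity(dog_image, dog_caption, dim=0)
neg_sim = F.cosine_similarity(dog_image, car_caption, dim=0)

print(f"positive pair similarity: {pos_sim.item():.3f}")
print(f"negative pair similarity: {neg_sim.item():.3f}")
```

Training aims to produce this pattern for real data: matching pairs score close to 1, non-matching pairs score much lower.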

## 2. Building the Dual Encoder

We need two separate neural networks:
1. **Image Encoder**: Turns pixels into a vector.
2. **Text Encoder**: Turns text into a vector.

In [None]:
class ImageEncoder(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        # Simple CNN for demonstration
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(32 * 8 * 8, embedding_dim) # Assuming 32x32 input

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, 128)
        self.rnn = nn.LSTM(128, 64, batch_first=True)
        self.fc = nn.Linear(64, embedding_dim)

    def forward(self, x):
        x = self.embedding(x)
        _, (hidden, _) = self.rnn(x)
        x = hidden[-1]
        x = self.fc(x)
        return x

embed_dim = 64
image_model = ImageEncoder(embed_dim)
text_model = TextEncoder(vocab_size=1000, embedding_dim=embed_dim)
print("Encoders created!")
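
The `32 * 8 * 8` input size of the image encoder's linear layer follows from the convolution output-size formula. The sketch below works through that arithmetic and cross-checks it against real layers. The helper name `conv_out_size` is introduced here for illustration.

```python
import torch
import torch.nn as nn

def conv_out_size(size, kernel_size=3, stride=2, padding=1):
    # Standard Conv2d output-size formula (dilation = 1)
    return (size + 2 * padding - kernel_size) // stride + 1

after_conv1 = conv_out_size(32)            # spatial size after the first conv
after_conv2 = conv_out_size(after_conv1)   # spatial size after the second conv
flat_features = 32 * after_conv2 * after_conv2  # 32 output channels
print(after_conv1, after_conv2, flat_features)

# Cross-check against layers with the same hyperparameters as ImageEncoder
convs = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
)
out = convs(torch.randn(1, 3, 32, 32))
print(out.shape)
```

With a different input resolution the flattened size changes, so the `nn.Linear` input dimension would need to change with it.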

## 3. The CLIP Model

The CLIP model puts them together. Crucially, it normalizes the embeddings so we can compute Cosine Similarity easily.

In [None]:
class CLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
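        # Learnable temperature, initialized to 0.07 (the starting value used in the CLIP paper).
        # Note: the paper learns it in log space, which keeps the effective scale positive.
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)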

    def forward(self, images, text):
        # 1. Get Embeddings
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(text)

        # 2. Normalize to unit L2 norm (F.normalize guards against division by zero)
        image_features = F.normalize(image_features, dim=1)
        text_features = F.normalize(text_features, dim=1)

        # 3. Compute Similarity Matrix (Batch x Batch)
        # logits[i][j] = similarity between image i and text j
        logits = (image_features @ text_features.t()) / self.temperature
        
        return logits

model = CLIP(image_model, text_model)
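
To see why normalization makes cosine similarity easy to compute, the sketch below checks that a matrix product of unit-length embeddings matches pairwise cosine similarity. The random tensors are stand-ins for encoder outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
image_features = torch.randn(4, 64)  # stand-ins for image encoder outputs
text_features = torch.randn(4, 64)   # stand-ins for text encoder outputs

# Normalize each row to unit length, then take all pairwise dot products
img_n = F.normalize(image_features, dim=1)
txt_n = F.normalize(text_features, dim=1)
sim_matrix = img_n @ txt_n.t()  # (4, 4)

# Reference: cosine similarity for every (image, text) pair via broadcasting
reference = F.cosine_similarity(
    image_features.unsqueeze(1), text_features.unsqueeze(0), dim=2
)

print(sim_matrix.shape)
print(torch.allclose(sim_matrix, reference, atol=1e-5))
```

One matrix multiplication therefore yields the full batch-by-batch similarity matrix, which is what the loss operates on.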

## 4. Contrastive Loss

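We want the diagonal entries of the similarity matrix to be high (each image matches its own text) and the off-diagonal entries to be low.

We apply **Cross Entropy Loss** along both axes: once over rows (image to text) and once over columns (text to image), then average the two.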

In [None]:
def contrastive_loss(logits):
    batch_size = logits.size(0)
    # [0, 1, 2, ...], created on the same device as logits so this also works on GPU
    targets = torch.arange(batch_size, device=logits.device)
    
    # Loss for (Image -> Text)
    loss_i = F.cross_entropy(logits, targets)
    # Loss for (Text -> Image)
    loss_t = F.cross_entropy(logits.t(), targets)
    
    return (loss_i + loss_t) / 2

# Dummy Forward Pass
images = torch.randn(4, 3, 32, 32)
text = torch.randint(0, 1000, (4, 10))

logits = model(images, text)
loss = contrastive_loss(logits)

print(f"Logits shape: {logits.shape}")
print(f"Loss: {loss.item()}")
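
Two reference values help when judging a loss reading like the one above. The sketch below evaluates the same symmetric loss on hand-built logits: a matrix of zeros (no information) and a scaled identity matrix (perfect alignment). The function is restated as `symmetric_loss` only to keep the example self-contained.

```python
import math
import torch
import torch.nn.functional as F

def symmetric_loss(logits):
    # Same computation as contrastive_loss: cross-entropy over rows and columns
    targets = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

n = 4
uninformative = torch.zeros(n, n)  # every pair looks equally similar
aligned = torch.eye(n) / 0.07      # cosine 1 on the diagonal, 0 elsewhere, divided by the temperature

loss_uninformative = symmetric_loss(uninformative).item()
loss_aligned = symmetric_loss(aligned).item()

print(f"uninformative: {loss_uninformative:.4f} (log N = {math.log(n):.4f})")
print(f"aligned:       {loss_aligned:.6f}")
```

A model that cannot tell pairs apart sits at `log(N)`, so a training loss that stays near that value suggests the encoders are not yet learning the alignment.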

## 5. Zero-Shot Classification

Once trained, how do we classify an image?
1. Give the image to the Image Encoder.
2. Give a list of possible class names (e.g., "dog", "cat", "bird") to the Text Encoder.
3. The class with the highest similarity score wins!

This is why it is called "zero-shot": the model can classify into categories it was never explicitly trained to predict, as long as the text encoder has learned what the class names mean.
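
The three steps above can be sketched with precomputed embeddings. The vectors below are hypothetical stand-ins for what trained encoders might produce, and the prompt template `"a photo of a {class}"` is assumed.

```python
import torch
import torch.nn.functional as F

class_names = ["dog", "cat", "bird"]

# Hypothetical text embeddings, one row per prompt "a photo of a {class}"
text_embeddings = F.normalize(torch.tensor([
    [1.0, 0.1, 0.0],   # dog
    [0.1, 1.0, 0.0],   # cat
    [0.0, 0.1, 1.0],   # bird
]), dim=1)

# Hypothetical image embedding (deliberately placed near the "cat" prompt)
image_embedding = F.normalize(torch.tensor([[0.2, 0.9, 0.1]]), dim=1)

temperature = 0.07
logits = (image_embedding @ text_embeddings.t()) / temperature  # (1, 3)
probs = logits.softmax(dim=1)
predicted = class_names[probs.argmax(dim=1).item()]

print(predicted)
print(probs)
```

Adding a new class only requires encoding one more prompt; no weights are updated.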

## 6. Understanding Temperature in Contrastive Learning

The **temperature** parameter is crucial for contrastive learning. Let's understand why.

In [None]:
# Temperature controls the "sharpness" of the probability distribution
import matplotlib.pyplot as plt
import numpy as np

def softmax_with_temp(logits, temperature):
    logits = logits / temperature
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / exp_logits.sum()

# Example: similarities between one image and 5 text candidates
similarities = np.array([0.8, 0.6, 0.5, 0.3, 0.1])  # Image matches text 0

temperatures = [0.01, 0.07, 0.5, 1.0]
fig, axes = plt.subplots(1, 4, figsize=(14, 3))

for ax, temp in zip(axes, temperatures):
    probs = softmax_with_temp(similarities, temp)
    ax.bar(range(5), probs)
    ax.set_title(f'τ = {temp}')
    ax.set_xlabel('Text Index')
    ax.set_ylabel('Probability')
    ax.set_ylim(0, 1.1)

plt.suptitle('Effect of Temperature on Probability Distribution', y=1.02)
plt.tight_layout()
plt.show()

print("Low temperature (τ=0.01): Almost one-hot, very confident")
print("CLIP initial value (τ=0.07): Sharp but not extreme")
print("High temperature (τ=1.0): Much flatter distribution, weak separation between candidates")
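
To put numbers on the sharpness, the sketch below reports the top probability and the entropy (in nats) at each temperature. Using entropy as the sharpness measure is a choice made here, not something the plot depends on.

```python
import numpy as np

def softmax_with_temp(logits, temperature):
    scaled = logits / temperature
    exp_scaled = np.exp(scaled - scaled.max())  # subtract the max for numerical stability
    return exp_scaled / exp_scaled.sum()

similarities = np.array([0.8, 0.6, 0.5, 0.3, 0.1])
results = {}
for temp in [0.01, 0.07, 0.5, 1.0]:
    probs = softmax_with_temp(similarities, temp)
    # Small epsilon avoids log(0) when a probability underflows
    entropy = -(probs * np.log(probs + 1e-12)).sum()
    results[temp] = (probs[0], entropy)
    print(f"tau={temp}: top prob={probs[0]:.3f}, entropy={entropy:.3f} nats")
```

Entropy rises toward `log(5)` as the temperature grows, meaning the five candidates become harder to tell apart.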

## 7. InfoNCE Loss (The Formal Math)

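CLIP uses a symmetric **InfoNCE** loss, a contrastive objective from the noise-contrastive estimation (NCE) family. It is ordinary cross-entropy applied to the similarity matrix. For the image-to-text direction:

$$\mathcal{L}_{i \to t} = -\log \frac{\exp(\text{sim}(I_i, T_i) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(I_i, T_j) / \tau)}$$

Here $\text{sim}$ is cosine similarity, $\tau$ is the temperature, and $N$ is the batch size. The text-to-image term $\mathcal{L}_{t \to i}$ swaps the roles of images and texts, and the final loss averages the two directions.

This is equivalent to:
1. For each image, classify which of the N texts is its match
2. Cross-entropy loss where the "correct class" is the diagonal

**Why it works**: Minimizing InfoNCE maximizes a lower bound on the mutual information between the image and text representations.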
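
As a check on the formula, the sketch below evaluates it term by term and compares the result with `F.cross_entropy`. It also computes the InfoNCE mutual-information bound, $\log N - \mathcal{L}$. The embeddings are random stand-ins, so the bound is not expected to be informative here.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, tau = 8, 0.07
image_emb = F.normalize(torch.randn(N, 16), dim=1)
text_emb = F.normalize(torch.randn(N, 16), dim=1)
sim = image_emb @ text_emb.t()  # sim[i, j] = sim(I_i, T_j)

# Formula, term by term: -log( exp(s_ii / tau) / sum_j exp(s_ij / tau) ), averaged over i
numerator = torch.exp(sim.diag() / tau)
denominator = torch.exp(sim / tau).sum(dim=1)
loss_manual = -torch.log(numerator / denominator).mean()

# Same quantity via cross-entropy with the diagonal as the target class
loss_ce = F.cross_entropy(sim / tau, torch.arange(N))

# InfoNCE lower bound on mutual information (in nats)
mi_lower_bound = math.log(N) - loss_ce.item()

print(f"manual: {loss_manual.item():.4f}, cross_entropy: {loss_ce.item():.4f}")
print(f"MI lower bound: {mi_lower_bound:.4f} (cannot exceed log N = {math.log(N):.4f})")
```

Because the bound is capped at `log N`, larger batches allow the loss to certify more shared information, which is one reason contrastive methods favor large batch sizes.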

## 8. Hard Negative Mining

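Not all negatives are equal. **Hard negatives** are non-matching pairs that nonetheless look similar; they provide the strongest learning signal. The function below keeps the standard loss and reports hard-negative statistics for monitoring.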

In [None]:
def contrastive_loss_with_hard_negatives(image_features, text_features, temperature=0.07):
    """
    Symmetric contrastive (InfoNCE) loss with hard-negative monitoring.
    
    The loss is the standard one; hard negatives are only measured, not mined or reweighted.
    Hard negatives: non-matching pairs with high similarity
    - Easy negative: "dog" vs "airplane" (very different)
    - Hard negative: "golden retriever" vs "labrador" (similar but different)
    """
    batch_size = image_features.shape[0]
    
    # Normalize features
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    
    # Compute similarity matrix
    logits = (image_features @ text_features.t()) / temperature
    
    # Labels are the diagonal (i matches i)
    labels = torch.arange(batch_size, device=logits.device)
    
    # Standard InfoNCE loss
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.t(), labels)
    
    # Hard negative analysis (for monitoring only; does not affect the loss)
    with torch.no_grad():
        # Undo the temperature scaling so the metrics are plain cosine similarities
        sims = logits * temperature
        
        # Find hardest negative for each image: mask the positives on the diagonal
        diag_mask = torch.eye(batch_size, dtype=torch.bool, device=sims.device)
        hardest_neg_sim = sims.masked_fill(diag_mask, float('-inf')).max(dim=1).values.mean()
        
        # Positive similarity (diagonal)
        pos_sim = sims.diag().mean()
        
        # Margin between positive and hardest negative
        margin = pos_sim - hardest_neg_sim
    
    return (loss_i2t + loss_t2i) / 2, {
        'positive_sim': pos_sim.item(),
        'hardest_neg_sim': hardest_neg_sim.item(),
        'margin': margin.item()
    }

# Demo
batch_size = 8
embed_dim = 64
img_feats = torch.randn(batch_size, embed_dim)
txt_feats = torch.randn(batch_size, embed_dim)

loss, metrics = contrastive_loss_with_hard_negatives(img_feats, txt_feats)
print(f"Loss: {loss.item():.4f}")
print(f"Positive similarity: {metrics['positive_sim']:.4f}")
print(f"Hardest negative similarity: {metrics['hardest_neg_sim']:.4f}")
print(f"Margin (higher is better): {metrics['margin']:.4f}")
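
The function above only monitors hard negatives. One possible way to actually use them, sketched below, is to restrict each image's loss to its positive plus its `k` most similar wrong texts. This is an illustrative variant, not the method used by the original CLIP, which contrasts against all in-batch negatives. The value of `k` and the random features are assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, k, temperature = 8, 3, 0.07
img = F.normalize(torch.randn(batch_size, 64), dim=-1)
txt = F.normalize(torch.randn(batch_size, 64), dim=-1)
sim = img @ txt.t()

# Exclude the positives on the diagonal, then keep the k most similar wrong texts per image
diag = torch.eye(batch_size, dtype=torch.bool)
neg_sim = sim.masked_fill(diag, float('-inf'))
hard_vals, hard_idx = neg_sim.topk(k, dim=1)  # each (batch_size, k)

# Logits restricted to [positive, k hardest negatives]; the positive sits in column 0
restricted_logits = torch.cat([sim.diag().unsqueeze(1), hard_vals], dim=1) / temperature
targets = torch.zeros(batch_size, dtype=torch.long)
hard_loss = F.cross_entropy(restricted_logits, targets)

# Full image-to-text loss for comparison
full_loss = F.cross_entropy(sim / temperature, torch.arange(batch_size))
print(f"hard-negative loss: {hard_loss.item():.4f}, full loss: {full_loss.item():.4f}")
```

The restricted loss is never larger than the full one because its softmax denominator has fewer terms. In practice the gain from mining depends on how the hard negatives are sourced.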

## 9. Vision Transformer (ViT) Image Encoder

The original CLIP release included both ResNet and **Vision Transformer (ViT)** image encoders. Its best-performing variants use ViTs, as do many later CLIP-style models.

In [None]:
class PatchEmbedding(nn.Module):
    """Split image into patches and project to embedding dimension."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, 
                              kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, n_patches_h, n_patches_w)
        x = self.proj(x)
        # Flatten spatial dimensions
        x = x.flatten(2).transpose(1, 2)  # (B, n_patches, embed_dim)
        return x

class ViTImageEncoder(nn.Module):
    """Simplified Vision Transformer for CLIP."""
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, 
                 depth=12, num_heads=12, output_dim=512):
        super().__init__()
        
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        n_patches = (img_size // patch_size) ** 2
        
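        # CLS token and position embeddings
        # (position embeddings get a small random init so positions are distinguishable from the start)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim) * 0.02)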
        
        # Transformer blocks
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, 
            dim_feedforward=embed_dim * 4, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # Project to output dimension
        self.proj = nn.Linear(embed_dim, output_dim)
        self.ln = nn.LayerNorm(embed_dim)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)
        
        # Add CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, n_patches+1, embed_dim)
        
        # Add position embeddings
        x = x + self.pos_embed
        
        # Transformer
        x = self.transformer(x)
        
        # Use CLS token as image representation
        x = self.ln(x[:, 0])
        x = self.proj(x)
        
        return x

# Demo
vit_encoder = ViTImageEncoder(img_size=32, patch_size=4, embed_dim=128, 
                               depth=4, num_heads=4, output_dim=64)
dummy_image = torch.randn(2, 3, 32, 32)
output = vit_encoder(dummy_image)
print(f"ViT output shape: {output.shape}")  # (2, 64)
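
The patch embedding uses a strided convolution, which is the same as cutting the image into non-overlapping patches, flattening each one, and applying a single linear map. The sketch below checks that equivalence with `F.unfold`, using the same sizes as the demo above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
img_size, patch_size, embed_dim = 32, 4, 128
proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, 3, img_size, img_size)

# Path 1: strided convolution, as in PatchEmbedding
tokens_conv = proj(x).flatten(2).transpose(1, 2)  # (2, n_patches, embed_dim)

# Path 2: extract patches explicitly, then apply the conv weights as a linear map
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (2, 3*4*4, n_patches)
weight = proj.weight.view(embed_dim, -1)                          # (embed_dim, 3*4*4)
tokens_linear = patches.transpose(1, 2) @ weight.t() + proj.bias

n_patches = (img_size // patch_size) ** 2
print(n_patches, tokens_conv.shape)
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-4))
```

With the CLS token prepended, the transformer in the demo sees a sequence of `n_patches + 1` tokens.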

## 10. FAANG Interview Questions

### Q1: Explain how CLIP achieves zero-shot classification without being trained on specific classes.

**Answer**: CLIP learns a shared embedding space where both images and text are mapped to vectors. During training:
1. It sees a very large number of (image, caption) pairs collected from the internet (about 400 million in the original paper)
2. It learns to maximize similarity between matching pairs (InfoNCE loss)
3. The text encoder learns semantic representations of natural language

For zero-shot classification:
1. Create text prompts like "a photo of a {class}" for each class
2. Encode the image and all prompts
3. The class with highest cosine similarity wins

This works because:
- The model learns semantic representations of language rather than memorizing class names
- It can generalize to many concepts describable in natural language, though accuracy varies with how well a concept is covered by the training data
- No retraining or fine-tuning is needed for new classes

---

### Q2: Why does CLIP use a learnable temperature parameter instead of a fixed value?

**Answer**: Temperature controls the sharpness of the softmax distribution.

**Fixed temperature issues**:
- Too low: Loss becomes numerically unstable, gradients explode
- Too high: All pairs become equally likely, slow learning

**Learnable temperature**:
- Starts at a reasonable value (e.g., 0.07)
- Model learns optimal sharpness during training
- Adapts to the difficulty of the task
- Typically decreases during training as embeddings become more separable

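The CLIP paper learns the temperature as a log-parameterized scalar, initialized to the equivalent of 0.07 and clipped so that logits are never scaled by more than 100. This avoids tuning it as a hyperparameter.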
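
A minimal sketch of a log-parameterized temperature, assuming an initial value of 0.07 and a maximum logit scale of 100. The class name `LogitScale` and its arguments are hypothetical, introduced only for this example.

```python
import math
import torch
import torch.nn as nn

class LogitScale(nn.Module):
    """Hypothetical helper: stores the temperature as log(1 / tau)."""
    def __init__(self, init_temperature=0.07, max_scale=100.0):
        super().__init__()
        self.log_scale = nn.Parameter(torch.tensor(math.log(1.0 / init_temperature)))
        self.max_log_scale = math.log(max_scale)

    def forward(self, similarities):
        # exp() keeps the scale positive; the clamp caps it at max_scale
        scale = self.log_scale.clamp(max=self.max_log_scale).exp()
        return similarities * scale

scaler = LogitScale()
sims = torch.tensor([[0.8, 0.2], [0.1, 0.9]])
logits = scaler(sims)
effective_temperature = 1.0 / scaler.log_scale.exp().item()

print(logits)
print(f"effective temperature: {effective_temperature:.4f}")
```

Optimizing in log space means a gradient step can never push the temperature to zero or below, which a directly learned temperature does not guarantee.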

---

### Q3: What are the computational challenges of training CLIP and how are they addressed?

**Answer**:

**Challenge 1: Large batch sizes needed**
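- Contrastive learning benefits from many negatives; each sample in a batch of N sees N - 1 in-batch negatives
- CLIP used a batch size of 32,768
- **Solution**: Distributed training across many GPUs, gathering embeddings from all devices so every sample is contrasted against the full batch. Plain gradient accumulation does not add negatives, because each micro-batch is only contrasted against itself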

**Challenge 2: Memory for similarity matrix**
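- Similarity matrix is $N \times N$ where N = batch size
- For N=32768: about 4 GiB in fp32 ($32768^2 \times 4$ bytes) just for the matrix
- **Solution**: Mixed precision, gradient checkpointing, and sharding the similarity computation so each GPU only computes the rows for its local batch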

**Challenge 3: Expensive text encoding**
- Must encode all text in batch for every forward pass
- **Solution**: Cache text embeddings during evaluation, use efficient transformers

**Challenge 4: Data quality**
- Need high-quality (image, text) pairs at massive scale
- **Solution**: Filter web data, use CLIP itself to filter (bootstrapping)
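
The memory figure in Challenge 2 can be reproduced with a few lines of arithmetic. The device count below is a hypothetical value used only to show how sharding by rows shrinks the per-device share.

```python
def similarity_matrix_bytes(batch_size, bytes_per_element):
    # An N x N matrix of logits
    return batch_size * batch_size * bytes_per_element

N = 32768
fp32_gib = similarity_matrix_bytes(N, 4) / 2**30
fp16_gib = similarity_matrix_bytes(N, 2) / 2**30

# If rows are split evenly across devices, each holds (N / world_size) x N logits
world_size = 256  # hypothetical number of devices
per_device_mib = similarity_matrix_bytes(N, 4) / world_size / 2**20

print(f"fp32: {fp32_gib} GiB, fp16: {fp16_gib} GiB")
print(f"per device (fp32, {world_size} devices): {per_device_mib} MiB")
```

These numbers cover only the logits themselves; the softmax and its gradients need additional buffers of the same shape.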

---

### Q4: How would you extend CLIP to handle more than 2 modalities (e.g., add audio)?

**Answer**: Multi-modal extensions like ImageBind use:

1. **Anchor modality**: Use images as the "hub" that connects all modalities
2. **Shared embedding space**: All modalities map to same dimensional space
3. **Pairwise contrastive learning**:
   - Image-Text pairs (existing)
   - Image-Audio pairs (new)
   - Image-Depth pairs (new)
4. **Emergent alignments**: Audio-Text alignment emerges through shared image space

Key insight: You don't need paired data for all modality combinations - alignment transfers through the anchor modality.
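
The anchor-modality idea can be sketched as a sum of image-anchored contrastive losses. The tensors below are random stand-ins for encoder outputs, and the helper name `info_nce` is introduced here. This is a simplified illustration, not ImageBind's actual training code.

```python
import torch
import torch.nn.functional as F

def info_nce(a, b, temperature=0.07):
    """Symmetric InfoNCE between two batches of paired embeddings."""
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.t()) / temperature
    targets = torch.arange(a.size(0), device=a.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

torch.manual_seed(0)
batch, dim = 8, 64

# Two separate paired datasets, both anchored on images
image_emb_a = torch.randn(batch, dim, requires_grad=True)  # images paired with captions
text_emb = torch.randn(batch, dim, requires_grad=True)
image_emb_b = torch.randn(batch, dim, requires_grad=True)  # images paired with audio clips
audio_emb = torch.randn(batch, dim, requires_grad=True)

# Only image-anchored terms are trained; there is no audio-text term
total_loss = info_nce(image_emb_a, text_emb) + info_nce(image_emb_b, audio_emb)
total_loss.backward()

print(f"total loss: {total_loss.item():.4f}")
```

Audio and text never appear in the same loss term, yet both are pulled toward the image embeddings, which is how an audio-text alignment can emerge indirectly.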

---

### Q5: What's the difference between early fusion and late fusion in multimodal models?

**Answer**:

| Aspect | Early Fusion | Late Fusion (CLIP) |
|--------|-------------|-------------------|
| **Architecture** | Combine inputs before encoder | Separate encoders, combine embeddings |
| **Interaction** | Deep cross-modal attention | Dot product / shallow interaction |
| **Flexibility** | Single model, fixed modalities | Modular, can swap encoders |
| **Efficiency** | One forward pass | Can pre-compute embeddings |
| **Use cases** | VQA, detailed understanding | Retrieval, classification |

CLIP uses **late fusion** because:
1. Enables efficient retrieval (pre-compute all embeddings)
2. Modular design (an encoder can be swapped or frozen, though a new encoder must be re-aligned to the shared embedding space)
3. Scales to massive datasets
4. Simple to parallelize training
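
The retrieval benefit of late fusion can be sketched as follows: gallery embeddings are computed once, and each query costs a single matrix-vector product. The gallery is random, and the query is constructed as a noisy copy of item 123 purely so the example has a known answer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim = 64

# Gallery embeddings are computed once, offline (random stand-ins here)
gallery = F.normalize(torch.randn(1000, dim), dim=1)

# Hypothetical text query: a noisy copy of gallery item 123
query = F.normalize(gallery[123] + 0.05 * torch.randn(dim), dim=0)

# One matrix-vector product scores the query against the whole gallery
scores = gallery @ query  # (1000,)
top_scores, top_idx = scores.topk(5)

print(top_idx.tolist())
print(top_scores.tolist())
```

An early-fusion model would instead need a full forward pass for every (query, gallery item) pair, which is why dual encoders are the usual choice for large-scale retrieval.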

## Key Takeaways

1. **Multimodal**: Combining vision and text in one shared embedding space.
2. **Contrastive Loss (InfoNCE)**: Learning by comparing positive and negative pairs.
3. **Dual Encoders**: Two towers (Image & Text) that meet through similarity computation.
4. **Zero-Shot**: Classifying without explicit training on those classes.
5. **Temperature**: Controls softmax sharpness; CLIP learns it during training instead of tuning it by hand.
6. **Hard Negatives**: The most informative training signal comes from challenging examples.
7. **Vision Transformers**: CLIP's strongest image encoders are ViTs, although ResNet variants were also released.