# PyTorch Tutorial: Multimodal AI (CLIP)

The future of AI is **Multimodal**: models that can work with text, images, audio, and video together. One of the models that kicked off this era is **CLIP (Contrastive Language-Image Pre-training)**, released by OpenAI in 2021, which learns a joint representation of images and text.

In this notebook, we will build the core components of CLIP from scratch.

## Learning Objectives
- Understand Contrastive Learning
- Implement Dual Encoders (Image & Text)
- Implement the Contrastive Loss function
- Understand Zero-Shot Classification


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

torch.manual_seed(42)

## 1. The Core Idea: Contrastive Learning

CLIP doesn't predict labels (like "cat" or "dog"). Instead, it learns to match **images** to their **captions**.

- **Positive pairs**: (Image of dog, "A photo of a dog") -> Maximize similarity
- **Negative pairs**: (Image of dog, "A photo of a car") -> Minimize similarity

We project both images and text into a shared **embedding space**.
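To make "similarity in a shared embedding space" concrete, here is a minimal sketch. The three vectors are hand-picked, hypothetical stand-ins for learned embeddings (a real model would produce them); the point is only that a matching pair should score higher than a mismatched one.

```python
import torch
import torch.nn.functional as F

# Hypothetical 3-d embeddings; real CLIP embeddings are learned and much larger.
dog_image = torch.tensor([0.9, 0.1, 0.0])
dog_caption = torch.tensor([0.8, 0.2, 0.1])   # "A photo of a dog"
car_caption = torch.tensor([0.0, 0.1, 0.9])   # "A photo of a car"

# Cosine similarity compares direction only and ignores vector length.
pos_sim = F.cosine_similarity(dog_image, dog_caption, dim=0)
neg_sim = F.cosine_similarity(dog_image, car_caption, dim=0)

print(f"positive pair: {pos_sim.item():.3f}")
print(f"negative pair: {neg_sim.item():.3f}")
```

Training pushes the encoders toward this situation: high similarity for positive pairs, low similarity for negative pairs.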

## 2. Building the Dual Encoder

We need two separate neural networks:
1. **Image Encoder**: Turns pixels into a vector.
2. **Text Encoder**: Turns text into a vector.

In [None]:
class ImageEncoder(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        # Simple CNN for demonstration
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        # Two stride-2 convs shrink a 32x32 input to 8x8, with 32 channels
        self.fc = nn.Linear(32 * 8 * 8, embedding_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, 128)
        self.rnn = nn.LSTM(128, 64, batch_first=True)
        self.fc = nn.Linear(64, embedding_dim)

    def forward(self, x):
        x = self.embedding(x)
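        # hidden has shape (num_layers, batch, hidden_size)
        _, (hidden, _) = self.rnn(x)
        # Use the last layer's final hidden state as the sentence summary
        x = hidden[-1]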
        x = self.fc(x)
        return x

embed_dim = 64
image_model = ImageEncoder(embed_dim)
text_model = TextEncoder(vocab_size=1000, embedding_dim=embed_dim)
print("Encoders created!")
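The shapes in the encoders above can be checked with a short sketch. It assumes the same layer hyperparameters as the cell above and a 32x32 input; `conv_out_size` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
import torch
import torch.nn as nn

def conv_out_size(size, kernel_size, stride, padding):
    """Spatial output size of a Conv2d layer (assuming dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

# Two stride-2 convolutions: 32 -> 16 -> 8
after_conv1 = conv_out_size(32, kernel_size=3, stride=2, padding=1)
after_conv2 = conv_out_size(after_conv1, kernel_size=3, stride=2, padding=1)
print(after_conv1, after_conv2)  # 16 8

# Cross-check against real layers with the same hyperparameters.
convs = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
)
feature_map = convs(torch.randn(1, 3, 32, 32))
flat_features = feature_map.flatten(1).size(1)
print(feature_map.shape, flat_features)

# The LSTM returns hidden with shape (num_layers, batch, hidden_size),
# so hidden[-1] is one vector per sequence in the batch.
rnn = nn.LSTM(128, 64, batch_first=True)
token_vectors = torch.randn(4, 10, 128)  # (batch, seq_len, embedding)
_, (hidden, _) = rnn(token_vectors)
print(hidden.shape, hidden[-1].shape)
```

If the input resolution changes, the `32 * 8 * 8` in the linear layer has to change with it; the helper gives a quick way to recompute that number.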

## 3. The CLIP Model

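The CLIP model puts the two encoders together. Crucially, it L2-normalizes both embeddings, so the dot product between an image vector and a text vector is exactly their Cosine Similarity. The similarities are then divided by a learnable temperature, which controls how sharply the model separates matches from non-matches.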

In [None]:
class CLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
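        # Learnable temperature, initialized to 0.07. (The original CLIP learns
        # the log of the inverse temperature instead, which keeps the scale positive.)
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)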

    def forward(self, images, text):
        # 1. Get Embeddings
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(text)

        # 2. Normalize each row to unit L2 norm (F.normalize guards against zero norms)
        image_features = F.normalize(image_features, dim=1)
        text_features = F.normalize(text_features, dim=1)

        # 3. Compute Similarity Matrix (Batch x Batch)
        # logits[i][j] = similarity between image i and text j
        logits = (image_features @ text_features.t()) / self.temperature
        
        return logits

model = CLIP(image_model, text_model)
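The two ingredients of the forward pass, normalization and temperature, can be checked in isolation. This is a small sketch on random features (not on the model above); the variable names are chosen for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
image_features = torch.randn(4, 64)
text_features = torch.randn(4, 64)

# Normalize each row to unit length, then take all pairwise dot products.
image_unit = F.normalize(image_features, dim=1)
text_unit = F.normalize(text_features, dim=1)
similarity = image_unit @ text_unit.t()  # (4, 4), entries in [-1, 1]

# The same values as pairwise cosine similarity on the raw features.
reference = F.cosine_similarity(
    image_features.unsqueeze(1), text_features.unsqueeze(0), dim=2
)
matches_cosine = torch.allclose(similarity, reference, atol=1e-5)
print("dot product of unit vectors == cosine similarity:", matches_cosine)

# Dividing by a small temperature stretches the logits, which sharpens the softmax.
temperature = 0.07
soft = similarity.softmax(dim=1)
sharp = (similarity / temperature).softmax(dim=1)
print("max prob per row, no temperature:", soft.max(dim=1).values)
print("max prob per row, temperature 0.07:", sharp.max(dim=1).values)
```

Without the temperature, cosine similarities are confined to [-1, 1], so the softmax over a batch stays close to uniform and the loss has little room to reward confident matches.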

## 4. Contrastive Loss

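We want the diagonal of the similarity matrix to be high (each image matches its own text) and the off-diagonal entries to be low.

We use **Cross Entropy Loss** on both axes. Each row is a classification problem over the texts in the batch ("which caption belongs to this image?"), and each column is one over the images ("which image belongs to this caption?"). In both cases the correct class for index `i` is `i`, and the final loss is the average of the two.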
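Written out, the symmetric loss looks as follows. The notation is an assumption made for this note and does not appear elsewhere in the notebook: $N$ is the batch size and $\ell_{ij}$ is `logits[i][j]`, the temperature-scaled similarity between image $i$ and text $j$.

$$
\mathcal{L}_{\text{image}} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp(\ell_{ii})}{\sum_{j=1}^{N} \exp(\ell_{ij})},
\qquad
\mathcal{L}_{\text{text}} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp(\ell_{ii})}{\sum_{j=1}^{N} \exp(\ell_{ji})}
$$

$$
\mathcal{L} = \frac{1}{2}\left(\mathcal{L}_{\text{image}} + \mathcal{L}_{\text{text}}\right)
$$

The first term normalizes over each row (all texts for a fixed image), the second over each column (all images for a fixed text).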

In [None]:
def contrastive_loss(logits):
    batch_size = logits.size(0)
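    # Correct match for row i is column i; keep targets on the same device as logits
    targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ...]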
    
    # Loss for (Image -> Text)
    loss_i = F.cross_entropy(logits, targets)
    # Loss for (Text -> Image)
    loss_t = F.cross_entropy(logits.t(), targets)
    
    return (loss_i + loss_t) / 2

# Dummy Forward Pass
images = torch.randn(4, 3, 32, 32)
text = torch.randint(0, 1000, (4, 10))

logits = model(images, text)
loss = contrastive_loss(logits)

print(f"Logits shape: {logits.shape}")
print(f"Loss: {loss.item()}")
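A quick sanity check helps build intuition for the numbers this loss produces. The sketch below restates the loss so it runs on its own and feeds it two hand-built logit matrices (both are illustrative inputs, not model outputs): one with no information at all, and one where every image matches its own text perfectly.

```python
import math
import torch
import torch.nn.functional as F

def contrastive_loss(logits):
    # Same symmetric cross-entropy as above, restated so this cell is standalone.
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_i = F.cross_entropy(logits, targets)
    loss_t = F.cross_entropy(logits.t(), targets)
    return (loss_i + loss_t) / 2

batch_size = 4

# All similarities equal: the model is guessing, so the loss is log(batch_size).
uninformative = torch.zeros(batch_size, batch_size)
guess_loss = contrastive_loss(uninformative).item()

# Cosine similarity 1 on the diagonal, 0 elsewhere, scaled by temperature 0.07.
aligned = torch.eye(batch_size) / 0.07
aligned_loss = contrastive_loss(aligned).item()

print(f"uninformative logits: {guess_loss:.4f} (log {batch_size} = {math.log(batch_size):.4f})")
print(f"perfectly aligned logits: {aligned_loss:.6f}")
```

An untrained model with a very large loss, or a trained model stuck near `log(batch_size)`, are both signals worth investigating.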

## 5. Zero-Shot Classification

Once trained, how do we classify an image?
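1. Give the image to the Image Encoder.
2. Give a list of possible class names (e.g., "dog", "cat", "bird") to the Text Encoder, typically wrapped in a caption-like prompt such as "A photo of a dog".
3. Compare the image embedding with each text embedding. The class with the highest similarity score wins!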

This is why it's called "Zero-Shot": it can classify categories it has never seen during training, as long as it understands the words.
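The three steps above can be sketched as a small function. To keep the example self-contained, it works directly on embeddings; the hand-written vectors below are hypothetical stand-ins for what a trained image encoder and text encoder would produce, and `zero_shot_classify` is a name made up for this illustration.

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(image_features, class_text_features, temperature=0.07):
    """Return (predicted class index, class probabilities) for each image."""
    image_unit = F.normalize(image_features, dim=-1)
    text_unit = F.normalize(class_text_features, dim=-1)
    logits = image_unit @ text_unit.t() / temperature  # (num_images, num_classes)
    probs = logits.softmax(dim=-1)
    return probs.argmax(dim=-1), probs

class_names = ["dog", "cat", "bird"]

# Stand-ins for the text embeddings of "A photo of a dog", "... cat", "... bird".
class_text_features = torch.tensor([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

# Stand-in for one image embedding that points mostly in the "cat" direction.
image_features = torch.tensor([[0.2, 0.9, 0.1]])

pred, probs = zero_shot_classify(image_features, class_text_features)
predicted_name = class_names[pred.item()]
print(predicted_name, probs)
```

With the encoders from this notebook, the text side would first need the prompts converted to token IDs, a step this tutorial does not cover.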

## Key Takeaways

1. **Multimodal**: Combining vision and text in one shared space.
2. **Contrastive Loss**: Learning by comparing positive and negative pairs.
3. **Dual Encoders**: Two towers (Image & Text) that meet at the end.
4. **Zero-Shot**: Classifying without explicit training on those classes.