# CLIP (Contrastive Language-Image Pre-training) from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/clip_multimodal.ipynb)

CLIP aligns text and images in a shared embedding space using contrastive learning.

Key Concepts:
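1. **Image Encoder:** Maps each image to an embedding vector $I_f$.
2. **Text Encoder:** Maps each text (a sequence of token ids) to an embedding vector $T_f$.
3. **Contrastive Loss:** Maximizes the cosine similarity (the dot product of the L2-normalized embeddings) for the correct (image, text) pairs and minimizes it for every other pairing in the batch.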

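$$ \mathcal{L} = \frac{1}{2}\left(\mathrm{CE}(s \, I T^\top, \text{labels}) + \mathrm{CE}(s \, T I^\top, \text{labels})\right) $$

Here the rows of $I$ and $T$ are the L2-normalized image and text embeddings of the batch, $s$ is the learned logit scale, and $\text{labels} = (0, 1, \dots, N-1)$ marks the diagonal (matching) pairs.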

In [None]:
!pip install torch torchvision matplotlib transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

# Prefer a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Encoders (Simplified)

In [None]:
class ImageEncoder(nn.Module):
    """Small convolutional encoder that maps a batch of images to embeddings.

    Three stride-2 3x3 convolutions (3 -> 32 -> 64 -> 128 channels), each
    followed by ReLU, are reduced by global average pooling and projected to
    ``embed_dim`` by a linear layer.

    Args:
        embed_dim: Size of the output embedding for each image.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Every convolution halves the spatial resolution (stride 2, padding 1).
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(3, 32, **conv_kwargs)
        self.conv2 = nn.Conv2d(32, 64, **conv_kwargs)
        self.conv3 = nn.Conv2d(64, 128, **conv_kwargs)
        # Pooling to 1x1 makes the encoder independent of the input resolution.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, embed_dim)

    def forward(self, x):
        """Return one ``embed_dim``-sized vector per image in the batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        # (batch, 128, 1, 1) -> (batch, 128)
        pooled = torch.flatten(self.global_pool(features), start_dim=1)
        return self.fc(pooled)

class TextEncoder(nn.Module):
    """Encode a batch of token-id sequences into fixed-size text embeddings.

    Token embeddings plus learned positional embeddings are passed through a
    2-layer Transformer encoder, mean-pooled over the sequence dimension and
    projected to ``embed_dim``.

    Args:
        vocab_size: Number of distinct token ids (size of the token table).
        embed_dim: Size of the output embedding for each sequence.
        max_seq_len: Longest sequence the encoder accepts; it sizes the
            positional-embedding table.
    """

    def __init__(self, vocab_size, embed_dim, max_seq_len=32):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(vocab_size, 128)
        # Self-attention followed by mean pooling is permutation-invariant, so
        # without positional information the encoder would ignore word order.
        self.pos_embedding = nn.Embedding(max_seq_len, 128)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True),
            num_layers=2
        )
        self.fc = nn.Linear(128, embed_dim)

    def forward(self, x):
        """Return one embedding per sequence.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, embed_dim)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        positions = torch.arange(seq_len, device=x.device)
        # (seq_len, 128) positional table broadcasts over the batch dimension.
        x = self.embedding(x) + self.pos_embedding(positions)
        x = self.transformer(x)
        # Avg pooling over sequence
        x = x.mean(dim=1)
        return self.fc(x)

## 2. CLIP Model

In [None]:
class CLIP(nn.Module):
    """Minimal CLIP model: paired image/text encoders plus a learned logit scale.

    Args:
        vocab_size: Number of distinct token ids, forwarded to the text encoder.
        embed_dim: Size of the shared image/text embedding space.
    """

    def __init__(self, vocab_size, embed_dim=256):
        super().__init__()
        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(vocab_size, embed_dim)
        # forward() multiplies the similarities by exp(self.temperature), so
        # this parameter holds the *log* of the logit scale. Initialising it
        # to log(1 / 0.07) makes the effective softmax temperature start at
        # 0.07 (scale ~14.3); initialising to 0.07 itself would give a scale
        # of only exp(0.07) ~ 1.07.
        self.temperature = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07)))

    def forward(self, image, text):
        """Return scaled cosine similarities between every image and every text.

        Args:
            image: Batch of images accepted by the image encoder.
            text: Batch of token-id sequences accepted by the text encoder.

        Returns:
            Logits of shape ``(num_images, num_texts)``; entry ``[i, j]`` is
            the scaled cosine similarity between image ``i`` and text ``j``.
        """
        # F.normalize guards the division with a small eps, so an all-zero
        # embedding yields zeros instead of NaN.
        I_e = F.normalize(self.image_encoder(image), dim=-1)
        T_e = F.normalize(self.text_encoder(text), dim=-1)

        # Cap the scale at 100 (as in the CLIP paper) so a runaway learned
        # temperature cannot destabilise training.
        logit_scale = self.temperature.exp().clamp(max=100.0)

        # Scaled pairwise cosine similarities
        logits = logit_scale * (I_e @ T_e.t())
        return logits

## 3. Dummy Training Example

In [None]:
# Mock data and hyperparameters for a single demonstration step
vocab_size = 1000
batch_size = 8
embed_dim = 64

model = CLIP(vocab_size, embed_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Random tensors standing in for a batch of RGB images and tokenised captions
images = torch.randn((batch_size, 3, 224, 224)).to(device)
text = torch.randint(low=0, high=vocab_size, size=(batch_size, 20)).to(device)

# image[i] is paired with text[i], so the target class for row i is i
# (the diagonal of the similarity matrix).
labels = torch.arange(batch_size).to(device)

# Start from clean gradients before computing this step's loss
optimizer.zero_grad()

# Forward pass: rows index images, columns index texts
logits_per_image = model(images, text)
logits_per_text = logits_per_image.T

# Symmetric contrastive loss: image->text and text->image cross-entropy
loss_i = F.cross_entropy(logits_per_image, labels)
loss_t = F.cross_entropy(logits_per_text, labels)
loss = 0.5 * (loss_i + loss_t)

print(f"Initial Loss: {loss.item():.4f}")

loss.backward()
optimizer.step()

print("Training step complete. In a real scenario, use COCO or Flickr30k datasets.")