# Contrastive Learning with SimCLR

In this notebook, we’ll explore **SimCLR (Simple Framework for Contrastive Learning of Visual Representations)**, one of the most influential approaches in **self-supervised learning (SSL)**.

The key idea: *learn useful image representations by comparing similar and dissimilar image pairs, without using labels.*

## 📘 1. What is Contrastive Learning?

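Contrastive learning aims to pull **similar (positive)** pairs close together and push **dissimilar (negative)** pairs apart in the latent space.

SimCLR makes this idea practical for vision. A positive pair is two augmented views of the same image, and the other images in the same mini-batch serve as negatives, so no labels, memory bank, or specialized architecture are needed.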
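A toy illustration of the per-anchor form of the loss, using hand-picked 2-D embeddings. All values and the name `single_anchor_loss` are made up for illustration. The loss is the negative log-softmax score of the positive among all candidates, so it drops as the positive moves closer to the anchor than the negatives.

```python
import torch
import torch.nn.functional as F

def single_anchor_loss(anchor, positive, negatives, temperature=0.5):
    """Hypothetical helper: contrastive loss for one anchor = -log softmax score of its positive."""
    candidates = torch.cat([positive.unsqueeze(0), negatives], dim=0)   # (1 + K, d), positive at index 0
    sims = F.cosine_similarity(anchor.unsqueeze(0), candidates, dim=1) / temperature
    return -F.log_softmax(sims, dim=0)[0]

anchor = torch.tensor([1.0, 0.0])
negatives = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])

close = single_anchor_loss(anchor, torch.tensor([0.9, 0.1]), negatives)  # positive near the anchor
far = single_anchor_loss(anchor, torch.tensor([0.0, -1.0]), negatives)   # positive orthogonal to it
print(close.item(), far.item())
```

Minimizing this quantity therefore pulls the positive toward the anchor and pushes the negatives away, which is exactly the behavior described above.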

## ⚙️ 2. Core Components of SimCLR

SimCLR has **four main components:**

1. **Data Augmentation:** Create two different augmented versions of the same image.
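2. **Encoder Network:** A CNN backbone (typically a ResNet) maps each view to a representation vector `h`.
3. **Projection Head:** A small MLP that maps `h` to a lower-dimensional vector `z` on which the contrastive loss is computed.
4. **Contrastive Loss (NT-Xent):** The normalized temperature-scaled cross-entropy loss rewards high cosine similarity between the two views of the same image, relative to all other views in the batch.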
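For reference, NT-Xent as defined in the SimCLR paper, written with the two views of image $k$ stored at positions $k$ and $k+N$ of the $2N$ concatenated embeddings (the layout produced by `torch.cat([z_i, z_j])` in the loss code). Here $\tau$ is the temperature and the indicator $\mathbb{1}_{[k \neq i]}$ removes the anchor's similarity with itself from the denominator.

$$
\mathrm{sim}(\mathbf{u}, \mathbf{v}) = \frac{\mathbf{u}^\top \mathbf{v}}{\lVert \mathbf{u} \rVert \, \lVert \mathbf{v} \rVert}
$$

$$
\ell_{i,j} = -\log \frac{\exp\big(\mathrm{sim}(\mathbf{z}_i, \mathbf{z}_j)/\tau\big)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\big(\mathrm{sim}(\mathbf{z}_i, \mathbf{z}_k)/\tau\big)}
$$

$$
\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \big[\ell_{k,\,k+N} + \ell_{k+N,\,k}\big]
$$

Each of the $2N$ views acts once as an anchor, so every positive pair contributes two terms.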

## 🔧 3. Imports and Setup

```python
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
```

## 🖼️ 4. Data Augmentation for Contrastive Learning

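```python
# SimCLR-style augmentations for 32x32 CIFAR-10 images
transform_simclr = T.Compose([
    T.RandomResizedCrop(size=32),                               # random crop, resized back to 32x32
    T.RandomHorizontalFlip(),
    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),  # color distortion 80% of the time
    T.RandomGrayscale(p=0.2),
    T.ToTensor()
])

# Each sample here is a single augmented view; the labels are never used for training
dataset = CIFAR10(root='./data', download=True, transform=transform_simclr)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
```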

## 🧩 5. Building the SimCLR Model

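```python
class SimCLR(nn.Module):
    def __init__(self, base_model='resnet18', projection_dim=128):
        super().__init__()
        # Randomly initialized backbone (no pretrained weights); torchvision >= 0.13 API
        self.encoder = getattr(models, base_model)(weights=None)
        feat_dim = self.encoder.fc.in_features  # 512 for resnet18/34, 2048 for resnet50
        self.encoder.fc = nn.Identity()         # drop the classifier so the encoder outputs h

        # Projection head: MLP with one hidden layer, used only by the contrastive loss
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)          # representation kept for downstream tasks
        z = self.projector(h)        # embedding fed to the loss
        z = F.normalize(z, dim=1)    # unit norm, so dot products are cosine similarities
        return h, z
```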

## 📉 6. Contrastive Loss (NT-Xent)

```python
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent loss for N positive pairs (z_i[k], z_j[k])."""
    batch_size = z_i.shape[0]
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)   # (2N, d)
    sim = z @ z.T / temperature                             # (2N, 2N) cosine similarities / tau

    # Mask self-similarity: -inf on the diagonal contributes exp(-inf) = 0 to the denominator
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float('-inf'))

    # Row k's positive is column k + N (first half) or k - N (second half)
    targets = torch.cat([torch.arange(batch_size, 2 * batch_size, device=z.device),
                         torch.arange(0, batch_size, device=z.device)])

    # Cross-entropy per row = -log(exp(pos) / sum over k != i of exp(sim_ik)), averaged over 2N rows
    return F.cross_entropy(sim, targets)
```
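As a cross-check, here is an explicit per-anchor version of NT-Xent (the name `nt_xent_reference` is hypothetical). It is much slower than the vectorized loss but shows each term directly. It also shows why collapsing every embedding to the same vector does not solve the task: all similarities become equal and each term reduces to log(2N - 1), which for N = 4 is log 7 ≈ 1.946, whereas well-separated embeddings reach a lower loss.

```python
import math
import torch
import torch.nn.functional as F

def nt_xent_reference(z_i, z_j, temperature=0.5):
    """Hypothetical loop version of NT-Xent: one term per anchor, 2N anchors in total."""
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    n = z_i.shape[0]
    total = 0.0
    for a in range(2 * n):
        p = a + n if a < n else a - n                  # index of the anchor's positive view
        sims = z @ z[a] / temperature                  # anchor vs. all 2N embeddings
        others = torch.cat([sims[:a], sims[a + 1:]])   # drop the self-similarity
        total += -(sims[p] - torch.logsumexp(others, dim=0))
    return total / (2 * n)

# Collapsed embeddings: every vector identical, so the loss equals log(2N - 1)
ones = torch.ones(4, 8)
collapsed = nt_xent_reference(ones, ones.clone())

# Aligned views, orthogonal negatives: a much lower loss
eye = torch.eye(4, 8)
aligned = nt_xent_reference(eye, eye.clone())
print(collapsed.item(), math.log(7), aligned.item())
```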

## 7. Training Loop (Simplified Example)

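```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimCLR().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# SimCLR needs two independently augmented views of every image, so wrap the
# augmentation pipeline to return a (view_1, view_2) pair for each sample.
class TwoViews:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        return self.transform(img), self.transform(img)

pair_dataset = CIFAR10(root='./data', train=True, download=True,
                       transform=TwoViews(transform_simclr))
pair_loader = DataLoader(pair_dataset, batch_size=128, shuffle=True, drop_last=True)

for epoch in range(1):  # keep it short for demo
    model.train()
    running_loss = 0.0
    for (x_i, x_j), _ in pair_loader:  # labels are ignored
        x_i, x_j = x_i.to(device), x_j.to(device)
        _, z_i = model(x_i)
        _, z_j = model(x_j)

        loss = nt_xent_loss(z_i, z_j)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean loss over the epoch rather than the last batch only
    print(f"Epoch [{epoch+1}] Loss: {running_loss / len(pair_loader):.4f}")
```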
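The loss value alone can be hard to interpret. One extra diagnostic, sketched here as an assumption rather than part of SimCLR itself, is the fraction of anchors whose most similar other view is their own positive. `positive_pair_accuracy` is a hypothetical helper that uses the same k / k + N layout as the loss.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def positive_pair_accuracy(z_i, z_j):
    """Hypothetical helper: share of the 2N anchors whose nearest other view is their positive."""
    n = z_i.shape[0]
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    sim = z @ z.T
    sim.fill_diagonal_(float('-inf'))                  # ignore each anchor's self-similarity
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return (sim.argmax(dim=1) == targets).float().mean().item()

eye = torch.eye(4, 8)
print(positive_pair_accuracy(eye, eye.clone()))         # views match their own image
print(positive_pair_accuracy(eye, eye[[1, 2, 3, 0]]))   # views shuffled across images
```

Inside the training loop it could be called on `z_i.cpu()` and `z_j.cpu()` for each batch; it should rise well above the chance level of roughly 1 / (2N - 1) as training progresses.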

## 8. Key Takeaways

- SimCLR learns without labels by comparing augmented views.
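- The **contrastive loss** pulls together the representations of two views of the same image and pushes apart views of different images.
- After pretraining, the projection head is usually discarded and the encoder output `h` is used downstream, either by training a linear classifier on frozen features or by fine-tuning the whole encoder for classification.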

Next → `03-Pretext_Tasks_in_Self_Supervised_Learning.ipynb`