# SimCLR from Scratch

## This notebook implements SimCLR end‑to‑end

1. Build augmentations

2. Build the SimCLR model

3. Implement NT‑Xent loss

4. Train on STL‑10 (unlabeled)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torchvision.datasets import STL10
import matplotlib.pyplot as plt
from PIL import Image

# 1. SimCLR Augmentations

In [18]:
class SimCLRAugmentations:
    """Produce two independently augmented views of one image.

    Pipeline, in order: random resized crop, horizontal flip (p=0.5), colour
    jitter (applied with p=0.8), random grayscale (p=0.2), Gaussian blur
    (applied with p=0.5) and conversion to a tensor.

    Args:
        image_size: Output size handed to ``RandomResizedCrop`` (int or a
            (height, width) sequence). It also determines the blur kernel,
            which is about 10% of the shorter side, forced to be odd.
    """

    def __init__(self, image_size=96):
        # RandomResizedCrop accepts an int or a sequence; use the shorter
        # side to size the blur kernel in either case.
        if isinstance(image_size, (tuple, list)):
            side = min(image_size)
        else:
            side = image_size

        # The SimCLR recipe uses a kernel of roughly 10% of the image side.
        # torchvision's GaussianBlur only accepts odd, positive kernel sizes,
        # so bump even values up by one. The default 96 still yields 9.
        kernel_size = max(3, int(0.1 * side))
        if kernel_size % 2 == 0:
            kernel_size += 1

        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # Blur only half of the views, as in the SimCLR recipe; blurring
            # every view removes the sharp/blurred contrast between views.
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=kernel_size)
            ], p=0.5),
            transforms.ToTensor(),
        ])

    def __call__(self, x):
        """Return two views of ``x``, each from a separate pass through the pipeline."""
        v1 = self.transform(x)
        v2 = self.transform(x)
        return v1, v2

In [19]:
class STL10SimCLR(STL10):
    """STL10 dataset whose items are a pair of augmented views.

    The parent class is always constructed with ``transform=None`` so that it
    hands back the untransformed image; ``simclr_transform`` is then applied
    to that image and is expected to return exactly two views. The target
    returned by the parent is discarded.

    Args:
        *args: Positional arguments forwarded to ``STL10``.
        simclr_transform: Callable mapping one image to a pair of views.
        **kwargs: Keyword arguments forwarded to ``STL10``.
    """

    def __init__(self, *args, simclr_transform=None, **kwargs):
        self.simclr_transform = simclr_transform
        super().__init__(*args, transform=None, **kwargs)

    def __getitem__(self, index):
        """Return the two augmented views for the sample at ``index``."""
        # The parent yields (image, target); only the image is used.
        image, _target = super().__getitem__(index)
        first_view, second_view = self.simclr_transform(image)
        return first_view, second_view


# 2. SimCLR Model
- 2.1 SimCLR model
- 2.2 Projection head (MLP layer)

In [26]:
class SimCLR(nn.Module):
    """ResNet-18 encoder followed by a projection head.

    The encoder's final fully connected layer is replaced with an identity,
    so the encoder outputs its pooled features directly; the projection head
    maps those features to the embedding used by the contrastive loss.

    Args:
        projection_dim: Size of the embedding produced by the projection head.
    """

    def __init__(self, projection_dim=128):
        super().__init__()
        self.encoder = models.resnet18(weights=None)

        # Read the feature width from the classifier that is about to be
        # removed rather than hard-coding 512, so the head always matches the
        # encoder (the value is 512 for ResNet-18).
        feature_dim = self.encoder.fc.in_features
        self.encoder.fc = nn.Identity()

        self.projection_head = ProjectionHead(
            input_dim=feature_dim,
            hidden_dim=feature_dim,
            output_dim=projection_dim
        )

    def forward(self, x):
        """Encode ``x`` and return the projected embedding."""
        h = self.encoder(x)
        z = self.projection_head(h)
        return z

In [21]:
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) with unit-norm output rows.

    Args:
        input_dim: Width of the incoming feature vectors.
        hidden_dim: Width of the hidden layer.
        output_dim: Width of the returned embeddings.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` and L2-normalise the result along dimension 1."""
        projected = self.net(x)
        return F.normalize(projected, dim=1)

# 3. NT-Xent Loss

In [22]:
class NTXentLoss(nn.Module):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss.

    Args:
        temperature: Scalar that divides the similarity logits.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z):
        """Compute the loss for a stacked batch of embeddings.

        Args:
            z: Tensor of shape ``(2 * N, D)`` in which row ``i`` and row
                ``i + N`` are the two views of the same sample. The plain dot
                product is used as the similarity, so rows should already be
                L2-normalised for it to equal cosine similarity.

        Returns:
            Scalar tensor: the cross-entropy averaged over all ``2 * N``
            anchors, where each anchor's correct class is its other view.

        Raises:
            ValueError: If ``z`` is not 2-D or has zero or an odd number of rows.
        """
        if z.dim() != 2 or z.shape[0] == 0 or z.shape[0] % 2 != 0:
            raise ValueError(
                "NTXentLoss expects z of shape (2 * N, D) with N >= 1, "
                f"got {tuple(z.shape)}"
            )

        batch_size = z.shape[0] // 2
        num_rows = 2 * batch_size
        sim = torch.matmul(z, z.T) / self.temperature

        # An anchor must never be scored against itself. -inf (rather than a
        # large finite literal) is representable in every float dtype and
        # contributes exactly zero to the softmax. The out-of-place fill keeps
        # the matmul result untouched for autograd.
        self_mask = torch.eye(num_rows, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(self_mask, float("-inf"))

        # The positive for row i sits N rows away: i + N for the first half,
        # i - N for the second half. Using that column as the class index
        # makes the cross-entropy equal to NT-Xent.
        targets = (torch.arange(num_rows, device=z.device) + batch_size) % num_rows

        loss = F.cross_entropy(sim, targets)
        return loss

# 4. Train SimCLR on STL-10
- 4.1 Load dataset
- 4.2 Training loop

In [23]:
# Augmentation callable: maps one image to a pair of views built from
# 96-pixel random crops.
transform = SimCLRAugmentations(image_size=96)

# Unlabeled STL-10 split stored under ./data (download=True asks torchvision
# to fetch it when needed). Each item is the pair of views from `transform`.
dataset = STL10SimCLR(
    root="./data",
    split="unlabeled",
    download=True,
    simclr_transform=transform
)

# Shuffled batches of 256 view pairs loaded by 2 worker processes.
# drop_last=True discards a final short batch, so every batch has 256 pairs.
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=2,
    drop_last=True
)

In [24]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model (default projection size), contrastive loss and optimizer used by
# the training loop.
model = SimCLR().to(device)
loss_fn = NTXentLoss(temperature=0.5)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4
)

In [None]:
epochs = 200


for epoch in range(epochs):
    # Set training mode explicitly so BatchNorm uses batch statistics even if
    # an earlier cell left the model in eval() mode.
    model.train()
    total_loss = 0

    for v1, v2 in loader:
        # Stack the two views along the batch dimension: rows i and i + N of
        # the result are the two views of sample i, which is the layout the
        # loss expects. One encoder pass over all 2N images replaces two
        # separate passes, so BatchNorm statistics cover both views.
        views = torch.cat([v1, v2], dim=0).to(device)
        z = model(views)

        loss = loss_fn(z)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() detaches the scalar so the graph is not kept alive.
        total_loss += loss.item()

    # Mean loss per batch for this epoch.
    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(loader):.4f}")

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79f39a17e700>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x79f39a17e700>^^
^Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()^
^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^ ^ ^ ^

NotImplementedError: Module [SimCLR] is missing the required "forward" function