# SimCLR from Scratch

## This notebook implements SimCLR end‑to‑end

1. Build augmentations

2. Build the SimCLR model

3. Implement NT‑Xent loss

4. Train on STL‑10 (unlabeled)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torchvision.datasets import STL10
import matplotlib.pyplot as plt
from PIL import Image
import os
import random
import numpy as np

In [2]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(41)

# 1. SimCLR Augmentations

In [3]:
class SimCLRAugmentations:
    def __init__(self, image_size=96):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=9),
            transforms.ToTensor(),
        ])


    def __call__(self, x):
        v1 = self.transform(x)
        v2 = self.transform(x)
        return v1, v2

In [4]:
class STL10SimCLR(STL10):
    def __init__(self, *args, simclr_transform=None, **kwargs):
        super().__init__(*args, transform=None, **kwargs)
        self.simclr_transform = simclr_transform

    def __getitem__(self, index):
        img, _ = super().__getitem__(index)  # raw PIL image
        v1, v2 = self.simclr_transform(img)
        return v1, v2

# 2. SimCLR Model
2.1 simclr model
2.2 projection head (mlp layer)

In [5]:
class SimCLR(nn.Module):
    def __init__(self, projection_dim=128):
        super().__init__()
        self.encoder = models.resnet18(weights=None)
        self.encoder.fc = nn.Identity()
        
        
        self.projection_head = ProjectionHead(
            input_dim=512,
            hidden_dim=512,
            output_dim=projection_dim
        )
    

    def forward(self, x):
        h = self.encoder(x)
        z = self.projection_head(h)
        return z

In [6]:
class ProjectionHead(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim)
        )
    
    
    def forward(self, x):
        x = self.net(x)
        return F.normalize(x, dim=1)

# 3. NT-XENT Loss

In [7]:
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z):
        z = F.normalize(z, dim=1)
        N = z.size(0)
        B = N // 2

        sim = torch.matmul(z, z.T) / self.temperature

        # mask self-similarity
        mask = torch.eye(N, device=z.device, dtype=torch.bool)
        sim.masked_fill_(mask, -1e9)

        # positive indices
        targets = torch.arange(N, device=z.device)
        targets[:B] += B
        targets[B:] -= B

        loss = F.cross_entropy(sim, targets)
        return loss

# 4. Train on SimCLR 
4.1 load dataset
4.2 training loop

In [8]:
transform = SimCLRAugmentations(image_size=96)

dataset = STL10SimCLR(
    root="./data",
    split="unlabeled",
    download=True,
    simclr_transform=transform
)

loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
    num_workers=4,
    drop_last=True
)

100%|██████████| 2.64G/2.64G [05:45<00:00, 7.65MB/s] 


In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SimCLR().to(device)
loss_fn = NTXentLoss(temperature=0.5)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4
)

In [10]:
epochs = 50 # increase to 200 for real runs
best_loss = float("inf")
ckpt_path = "simclr_encoder.pth"

for epoch in range(epochs):
    total_loss = 0
    
    for v1, v2 in loader:
        v1 = v1.to(device)
        v2 = v2.to(device)
        
        z1 = model(v1)
        z2 = model(v2)
        z = torch.cat([z1, z2], dim=0)
        
        loss = loss_fn(z)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
        total_loss += loss.item()
        
    avg_loss = total_loss / len(loader)

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.encoder.state_dict(), ckpt_path)
    
    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")


torch.save(model.encoder.state_dict(), "simclr_encoder_final.pth")

Epoch 1/50 | Loss: 5.9972
Epoch 2/50 | Loss: 5.7838
Epoch 3/50 | Loss: 5.7180
Epoch 4/50 | Loss: 5.6743
Epoch 5/50 | Loss: 5.6409
Epoch 6/50 | Loss: 5.6197
Epoch 7/50 | Loss: 5.5949
Epoch 8/50 | Loss: 5.5820
Epoch 9/50 | Loss: 5.5759
Epoch 10/50 | Loss: 5.5629
Epoch 11/50 | Loss: 5.5518
Epoch 12/50 | Loss: 5.5437
Epoch 13/50 | Loss: 5.5374
Epoch 14/50 | Loss: 5.5289
Epoch 15/50 | Loss: 5.5247
Epoch 16/50 | Loss: 5.5175
Epoch 17/50 | Loss: 5.5106
Epoch 18/50 | Loss: 5.5067
Epoch 19/50 | Loss: 5.4998
Epoch 20/50 | Loss: 5.4998
Epoch 21/50 | Loss: 5.4940
Epoch 22/50 | Loss: 5.4904
Epoch 23/50 | Loss: 5.4872
Epoch 24/50 | Loss: 5.4808
Epoch 25/50 | Loss: 5.4799
Epoch 26/50 | Loss: 5.4769
Epoch 27/50 | Loss: 5.4729
Epoch 28/50 | Loss: 5.4703
Epoch 29/50 | Loss: 5.4649
Epoch 30/50 | Loss: 5.4629
Epoch 31/50 | Loss: 5.4601
Epoch 32/50 | Loss: 5.4592
Epoch 33/50 | Loss: 5.4550
Epoch 34/50 | Loss: 5.4529
Epoch 35/50 | Loss: 5.4495
Epoch 36/50 | Loss: 5.4466
Epoch 37/50 | Loss: 5.4455
Epoch 38/5

# 5. Validation
5.1 Linear probing
5.2 KNN testing

In [21]:
import torchvision.models as models
import torch.nn as nn

def ResNet18():
    model = models.resnet18(weights=None)
    model.fc = nn.Identity()   # remove classification head
    return model

In [22]:
encoder = ResNet18().cuda()
encoder.load_state_dict(torch.load("simclr_encoder_final.pth"))
encoder = encoder.cuda()

for param in encoder.parameters():
    param.requires_grad = False

encoder.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### 5.1 Linear probing

In [23]:
linear_head = torch.nn.Linear(512, 10).cuda()

In [24]:
from torchvision.datasets import STL10
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(96),
    transforms.ToTensor()
])

train_set = STL10(
    root="./data",
    split="train",
    download=True,
    transform=transform
)

test_set = STL10(
    root="./data",
    split="test",
    download=True,
    transform=transform
)

train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4)

In [25]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(linear_head.parameters(), lr=1e-3)


for epoch in range(20):
    linear_head.train()
    total_loss = 0

    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()

        with torch.no_grad():
            features = encoder(x)

        logits = linear_head(features)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch}: Loss = {total_loss/len(train_loader):.4f}")

Epoch 0: Loss = 1.4850
Epoch 1: Loss = 0.9440
Epoch 2: Loss = 0.8786
Epoch 3: Loss = 0.8488
Epoch 4: Loss = 0.8236
Epoch 5: Loss = 0.8115
Epoch 6: Loss = 0.7941
Epoch 7: Loss = 0.7853
Epoch 8: Loss = 0.7703
Epoch 9: Loss = 0.7624
Epoch 10: Loss = 0.7633
Epoch 11: Loss = 0.7495
Epoch 12: Loss = 0.7416
Epoch 13: Loss = 0.7349
Epoch 14: Loss = 0.7280
Epoch 15: Loss = 0.7296
Epoch 16: Loss = 0.7195
Epoch 17: Loss = 0.7147
Epoch 18: Loss = 0.7106
Epoch 19: Loss = 0.7112


In [26]:
linear_head.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()
        features = encoder(x)
        logits = linear_head(features)
        preds = logits.argmax(dim=1)

        correct += (preds == y).sum().item()
        total += y.size(0)

acc = correct / total * 100
print(f"Linear Probe Accuracy: {acc:.2f}%")

Linear Probe Accuracy: 71.50%


### 5.2 KNN testing

In [27]:
import numpy as np

def extract_features(loader):
    feats = []
    labels = []

    with torch.no_grad():
        for x, y in loader:
            x = x.cuda()
            f = encoder(x)
            feats.append(f.cpu())
            labels.append(y)

    return torch.cat(feats), torch.cat(labels)

train_feats, train_labels = extract_features(train_loader)
test_feats, test_labels = extract_features(test_loader)

In [28]:
train_feats = torch.nn.functional.normalize(train_feats, dim=1)
test_feats = torch.nn.functional.normalize(test_feats, dim=1)

In [29]:
def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=20):
    correct = 0

    for i in range(test_feats.size(0)):
        sim = torch.matmul(train_feats, test_feats[i])
        topk = sim.topk(k).indices
        pred = train_labels[topk].mode()[0]

        correct += (pred == test_labels[i]).item()

    return correct / test_feats.size(0) * 100

In [30]:
acc_knn = knn_accuracy(
    train_feats, train_labels,
    test_feats, test_labels,
    k=20
)

print(f"k-NN Accuracy (k=20): {acc_knn:.2f}%")

k-NN Accuracy (k=20): 67.29%
