# SimCLR‑Lite (Self‑Supervised Contrastive Learning) on CIFAR‑10  
**Portfolio notebook (Deep Learning / Representation Learning)**

This notebook implements a minimal, research‑style **SimCLR** pipeline:
1. Self‑supervised pretraining (NT‑Xent contrastive loss)
2. Linear probing (freeze encoder, train a linear classifier)
3. t‑SNE visualization of learned representations

Tip for CPU: reduce the number of epochs (`EPOCHS_SSL`, `EPOCHS_LINEAR`), the batch size (`BATCH_SIZE`), and the number of t‑SNE samples (`NUM_SAMPLES_TSNE`).


## 0) Install & imports
If you're on Colab, you can run the install cell. If you already have PyTorch, skip it.


In [None]:
# (Optional) Install dependencies (useful on Google Colab)
# !pip -q install torch torchvision tqdm numpy matplotlib scikit-learn


In [None]:
import os, time, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Report the installed PyTorch version, then choose where tensors will live:
# "cuda" when a GPU is visible to PyTorch, otherwise "cpu".
print("torch:", torch.__version__)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)


## 1) Reproducibility & config


In [None]:
# One seed for every RNG this notebook touches: Python, NumPy, and PyTorch
# (CPU generator plus all CUDA devices). The seeding order is kept fixed.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# --- Self-supervised pretraining ---
EPOCHS_SSL = 10      # safe default; 50-200 gives stronger representations
BATCH_SIZE = 256     # on CPU, 64 or 128 is more practical
LR_SSL = 3e-4
TEMPERATURE = 0.2    # NT-Xent temperature
PROJ_DIM = 128       # output width of the projection head

# --- Linear probe ---
EPOCHS_LINEAR = 15   # safe default; 30-100 gives a better probe
LR_LINEAR = 0.1

# --- Data loading ---
NUM_WORKERS = 2
PIN_MEMORY = device == "cuda"   # pinned host memory only when a GPU is in use

# --- Artifacts (checkpoint, logs, figures) ---
OUT_DIR = "outputs_notebook"
os.makedirs(OUT_DIR, exist_ok=True)

print("Config loaded.")


## 2) SimCLR augmentations (two random views)
SimCLR learns by pulling together the embeddings of two different augmentations of the same image while pushing them away from the other images in the batch.


In [None]:
class TwoCropsTransform:
    """Wrap a transform so that every input produces a pair of views.

    ``base_transform`` is called twice on the same input; when it is a
    random augmentation the two results are different views of one image.
    """

    def __init__(self, base_transform):
        # Callable applied once per view.
        self.base_transform = base_transform

    def __call__(self, x):
        """Return a 2-tuple: ``base_transform(x)`` evaluated twice, in order."""
        return tuple(self.base_transform(x) for _ in range(2))

def simclr_augment(image_size=32):
    """Build the random augmentation pipeline that produces one SSL view.

    Steps, in order: random resized crop, horizontal flip, colour jitter
    (applied with probability 0.8), random grayscale, Gaussian blur (applied
    to every sample), tensor conversion, and per-channel normalisation.

    Args:
        image_size: Output side length passed to ``RandomResizedCrop``.

    Returns:
        A ``transforms.Compose`` holding the steps above.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)
    jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)

    steps = [
        transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)


## 3) Dataset & dataloader (SSL)


In [None]:
# The CIFAR-10 training split with the two-view transform attached, so each
# image is returned as a pair of augmented views.
ssl_transform = TwoCropsTransform(simclr_augment(image_size=32))
train_ssl_ds = CIFAR10("data", train=True, transform=ssl_transform, download=True)

# Shuffled loader; the incomplete final batch is dropped so every step sees
# exactly BATCH_SIZE images.
ssl_loader_opts = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=True,
)
train_ssl_dl = DataLoader(train_ssl_ds, **ssl_loader_opts)

print("SSL train batches:", len(train_ssl_dl))


## 4) Model: ResNet18 encoder + projection head
We use a CIFAR-friendly ResNet18 (the first conv is replaced by a 3×3, stride‑1 conv and the max‑pool is removed) followed by an MLP projection head.


In [None]:
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to encoder features.

    Args:
        in_dim: Width of the incoming feature vectors.
        hidden_dim: Width of the hidden layer.
        out_dim: Width of the projected output.
    """

    def __init__(self, in_dim, hidden_dim=512, out_dim=128):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        # Kept under the attribute name ``net`` so state_dict keys stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` through the MLP and return the result."""
        projected = self.net(x)
        return projected

class SimCLRLite(nn.Module):
    """ResNet-18 encoder followed by an MLP projection head.

    The ResNet-18 is created without pretrained weights and modified for
    small images: ``conv1`` becomes a 3x3, stride-1, padding-1 convolution,
    and both ``maxpool`` and the final ``fc`` layer are replaced by identity.

    Args:
        proj_dim: Output width of the projection head.
    """

    def __init__(self, proj_dim=128):
        super().__init__()
        encoder = resnet18(weights=None)

        # Read the feature width before the classifier layer is removed.
        feat_dim = encoder.fc.in_features
        encoder.fc = nn.Identity()
        encoder.maxpool = nn.Identity()
        encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        self.backbone = encoder
        self.projector = ProjectionHead(feat_dim, hidden_dim=512, out_dim=proj_dim)

    def encode(self, x):
        """Return backbone features for ``x`` (projection head not applied)."""
        return self.backbone(x)

    def forward(self, x):
        """Return ``(h, z)``: backbone features and the L2-normalised projection."""
        h = self.encode(x)
        z = F.normalize(self.projector(h), dim=1)
        return h, z

# Build the SimCLR model and place it on the selected device.
model = SimCLRLite(proj_dim=PROJ_DIM)
model = model.to(device)
print("Model ready.")


## 5) NT-Xent loss (SimCLR objective)


In [None]:
def nt_xent_loss(z1, z2, temperature=0.2):
    """Compute the NT-Xent contrastive loss for two batches of embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair; every
    other row in the concatenated batch acts as a negative.

    Args:
        z1: ``[B, D]`` embeddings of the first view. Expected to be
            L2-normalised; this function does not normalise them.
        z2: ``[B, D]`` embeddings of the second view, same expectations.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar tensor: the loss averaged over all ``2B`` views.
    """
    n_pairs = z1.size(0)
    n_views = 2 * n_pairs

    stacked = torch.cat((z1, z2), dim=0)                       # [2B, D]
    logits = torch.mm(stacked, stacked.t()) / temperature      # [2B, 2B]

    # A view must never be scored against itself.
    self_mask = torch.eye(n_views, device=stacked.device).bool()
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The partner of row i is row i + B for the first half, row i - B for the second.
    rows = torch.arange(n_views, device=logits.device)
    partners = torch.cat((rows[n_pairs:], rows[:n_pairs]), dim=0)
    positives = logits[rows, partners]                         # [2B]

    per_view = torch.logsumexp(logits, dim=1) - positives
    return per_view.mean()


## 6) Self-supervised training


In [None]:
# AdamW over every model parameter (backbone and projection head).
optimizer = optim.AdamW(model.parameters(), lr=LR_SSL, weight_decay=1e-4)
ssl_losses = []   # one mean NT-Xent value per epoch

model.train()
start = time.time()

for epoch in range(1, EPOCHS_SSL + 1):
    epoch_total = 0.0
    pbar = tqdm(train_ssl_dl, desc=f"SSL Epoch {epoch}/{EPOCHS_SSL}")

    for (x1, x2), _ in pbar:
        # Labels are discarded: pretraining only uses the two augmented views.
        x1, x2 = (view.to(device, non_blocking=True) for view in (x1, x2))

        # Each view gets its own forward pass; only the projection z is used.
        z1 = model(x1)[1]
        z2 = model(x2)[1]
        loss = nt_xent_loss(z1, z2, temperature=TEMPERATURE)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_total += batch_loss
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    avg = epoch_total / len(train_ssl_dl)
    ssl_losses.append(avg)
    print(f"Epoch {epoch}: avg_ssl_loss={avg:.4f}")

elapsed = time.time() - start
print(f"SSL training done in {elapsed/60:.1f} min")

# Store the weights together with the hyper-parameters that produced them.
ckpt_path = os.path.join(OUT_DIR, "simclr_lite_cifar10.pt")
ssl_config = {
    "EPOCHS_SSL": EPOCHS_SSL,
    "BATCH_SIZE": BATCH_SIZE,
    "LR_SSL": LR_SSL,
    "TEMPERATURE": TEMPERATURE,
    "PROJ_DIM": PROJ_DIM,
    "SEED": SEED,
}
torch.save({"model": model.state_dict(), "config": ssl_config}, ckpt_path)
print("Saved checkpoint:", ckpt_path)


### Plot SSL loss


In [None]:
# Mean NT-Xent loss per pretraining epoch (epochs numbered from 1).
fig, ax = plt.subplots(figsize=(6, 4))
epochs_axis = range(1, len(ssl_losses) + 1)
ax.plot(epochs_axis, ssl_losses)
ax.set_xlabel("Epoch")
ax.set_ylabel("NT-Xent Loss")
ax.set_title("SimCLR SSL Training Loss")
ax.grid(True)
plt.show()


## 7) Linear probe (freeze encoder)
We train a linear classifier on frozen features to measure representation quality.


In [None]:
# Per-channel statistics shared by the train and test pipelines.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2470, 0.2435, 0.2616)

# Training pipeline: light augmentation (padded random crop + flip).
tf_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
# Test pipeline: tensor conversion and normalisation only.
tf_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

train_sup_ds = CIFAR10("data", train=True, transform=tf_train, download=True)
test_sup_ds = CIFAR10("data", train=False, transform=tf_test, download=True)

sup_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
train_sup_dl = DataLoader(train_sup_ds, shuffle=True, **sup_loader_opts)
test_sup_dl = DataLoader(test_sup_ds, shuffle=False, **sup_loader_opts)

# Freeze the pretrained model: eval mode and no gradients for any parameter.
model.eval()
model.requires_grad_(False)

# Measure the encoder's output width by encoding two real training images.
with torch.no_grad():
    probe_images, _ = next(iter(train_sup_dl))
    feat_dim = model.encode(probe_images[:2].to(device)).shape[1]

print("Feature dim:", feat_dim)


In [None]:
class LinearClassifier(nn.Module):
    """A single fully connected layer mapping features to class logits.

    Args:
        in_dim: Width of the incoming feature vectors.
        num_classes: Number of output logits.
    """

    def __init__(self, in_dim, num_classes=10):
        super().__init__()
        # Attribute name ``fc`` determines the state_dict keys.
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x):
        """Return the logits for feature batch ``x``."""
        logits = self.fc(x)
        return logits

def acc(logits, y):
    """Return the fraction of rows whose arg-max logit equals the label, as a float."""
    predicted = logits.argmax(dim=1)
    hits = predicted.eq(y).float()
    return hits.mean().item()

# Linear probe: only `clf` is optimised; the encoder supplies fixed features.
clf = LinearClassifier(feat_dim, 10).to(device)
criterion = nn.CrossEntropyLoss()
opt = optim.SGD(clf.parameters(), lr=LR_LINEAR, momentum=0.9, weight_decay=1e-4)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS_LINEAR)

train_acc_hist, test_acc_hist, train_loss_hist = [], [], []

# The encoder must be in eval mode while its features are used, so its
# normalisation layers do not switch to batch statistics.
model.eval()

for epoch in range(1, EPOCHS_LINEAR + 1):
    clf.train()
    # Accumulate per-sample totals rather than per-batch means: the last batch
    # is smaller than the rest (the loaders do not drop it), so averaging batch
    # means would over-weight its samples.
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pbar = tqdm(train_sup_dl, desc=f"Linear Epoch {epoch}/{EPOCHS_LINEAR}")
    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # No graph is needed through the frozen encoder.
        with torch.no_grad():
            feat = model.encode(x)

        logits = clf(feat)
        loss = criterion(logits, y)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        batch_n = y.size(0)
        batch_correct = (logits.argmax(1) == y).sum().item()
        batch_loss = loss.item()

        # CrossEntropyLoss returns the batch mean; re-weight by batch size.
        loss_sum += batch_loss * batch_n
        n_correct += batch_correct
        n_seen += batch_n
        pbar.set_postfix(loss=f"{batch_loss:.4f}", acc=f"{batch_correct / batch_n * 100:.1f}%")

    # Cosine schedule advances once per epoch.
    sched.step()

    train_loss = loss_sum / n_seen
    train_acc = n_correct / n_seen

    clf.eval()
    test_correct, test_seen = 0, 0
    with torch.no_grad():
        for x, y in test_sup_dl:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            feat = model.encode(x)
            logits = clf(feat)
            test_correct += (logits.argmax(1) == y).sum().item()
            test_seen += y.size(0)
    test_acc = test_correct / test_seen

    train_loss_hist.append(train_loss)
    train_acc_hist.append(train_acc)
    test_acc_hist.append(test_acc)

    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, train_acc={train_acc*100:.2f}%, test_acc={test_acc*100:.2f}%")

# Persist the per-epoch curves next to the other notebook outputs.
log_path = os.path.join(OUT_DIR, "linear_probe_log.json")
with open(log_path, "w", encoding="utf-8") as f:
    json.dump({
        "train_loss": train_loss_hist,
        "train_acc": train_acc_hist,
        "test_acc": test_acc_hist
    }, f, indent=2)

print("Saved linear probe log:", log_path)


### Plot Linear Probe accuracy


In [None]:
# Train and test accuracy of the linear probe per epoch, shown in percent.
fig, ax = plt.subplots(figsize=(6, 4))
for curve, curve_label in ((train_acc_hist, "Train Acc"), (test_acc_hist, "Test Acc")):
    ax.plot(range(1, len(curve) + 1), [value * 100 for value in curve], label=curve_label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy (%)")
ax.set_title("Linear Probe Accuracy")
ax.grid(True)
ax.legend()
plt.show()


## 8) t-SNE visualization


In [None]:
NUM_SAMPLES_TSNE = 2000  # CPU: 500-1500 | GPU: 5000-10000

# No augmentation here: images are only converted to tensors and normalised.
tf_plain = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

tsne_ds = CIFAR10("data", train=False, transform=tf_plain, download=True)
tsne_dl = DataLoader(tsne_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

features, labels = [], []
collected = 0

# Encode test images in dataset order and stop once enough have been gathered.
model.eval()
with torch.no_grad():
    for batch_x, batch_y in tqdm(tsne_dl, desc="Collecting features"):
        encoded = model.encode(batch_x.to(device))
        features.append(encoded.detach().cpu().numpy())
        labels.append(batch_y.numpy())
        collected += batch_x.size(0)
        if collected >= NUM_SAMPLES_TSNE:
            break

# The final batch may overshoot the target, so trim both arrays to the same length.
X = np.concatenate(features, axis=0)[:NUM_SAMPLES_TSNE]
y = np.concatenate(labels, axis=0)[:NUM_SAMPLES_TSNE]
print("X shape:", X.shape)


In [None]:
# Project the encoder features to 2-D. t-SNE requires perplexity < n_samples,
# so clamp it in case only a handful of samples were collected.
tsne = TSNE(
    n_components=2,
    init="pca",
    learning_rate="auto",
    perplexity=min(30, len(X) - 1),
    random_state=SEED,
)
Z = tsne.fit_transform(X)

tsne_path = os.path.join(OUT_DIR, "tsne.png")

# Build the figure once; the same figure is written to disk and then shown.
fig, ax = plt.subplots(figsize=(8, 6))
# Labels are 10 discrete classes: use a qualitative colormap with a fixed
# 0-9 range so each class keeps its colour even if some are absent.
points = ax.scatter(Z[:, 0], Z[:, 1], c=y, s=6, alpha=0.8, cmap="tab10", vmin=0, vmax=9)
# num=None -> one legend entry per class index present in `y`.
ax.legend(
    *points.legend_elements(num=None),
    title="Class",
    loc="upper left",
    bbox_to_anchor=(1.02, 1.0),
    fontsize="small",
)
ax.set_title("CIFAR-10 Representations (t-SNE)")
ax.set_xlabel("Dim-1")
ax.set_ylabel("Dim-2")
ax.grid(True)

# Save before show(): some backends release the figure once it has been shown.
fig.savefig(tsne_path, dpi=200, bbox_inches="tight")
print("Saved:", tsne_path)
plt.show()


## 9) GitHub README suggestions
- Explain NT-Xent and the role of augmentations  
- Report final linear probe test accuracy  
- Add `outputs_notebook/tsne.png` as a figure
