In [1]:
#!/usr/bin/env python
"""
MNIST SimCLR vs Supervised — Noisy-Label Robustness Experiment (PyTorch)

Single-file, reproducible pipeline to:
  1) Train a baseline supervised classifier on MNIST (clean or noisy labels)
  2) Pretrain with SimCLR on MNIST (no labels), then fine-tune supervised
  3) Compare results side-by-side and save metrics to CSV

Usage examples
--------------
# 1) Baseline supervised on clean MNIST
python mnist_ssl_noise_experiment.py --device cuda --mode baseline --epochs 10

# 2) SimCLR pretrain + fine-tune on clean MNIST
python mnist_ssl_noise_experiment.py --device cuda --mode simclr_then_finetune --pretrain-epochs 50 --epochs 10

# 3) Baseline supervised with 40% symmetric label noise on train set
python mnist_ssl_noise_experiment.py --device cuda --mode baseline --noise-rate 0.40 --epochs 10

# 4) SimCLR pretrain + fine-tune with 40% label-noisy training set
python mnist_ssl_noise_experiment.py --device cuda \
  --mode simclr_then_finetune --noise-rate 0.40 --pretrain-epochs 50 --epochs 10

Outputs
-------
- runs/metrics.csv: one row per run with settings and test accuracy
- runs/checkpoints/: model checkpoints
- runs/logs.txt: basic text log

Notes
-----
- SimCLR uses unlabeled views of the *same* MNIST images (train split), independent
  of any label corruption you may introduce for the downstream fine-tuning.
- We slightly tweak ResNet-18 for 1-channel 28x28 images.
- Augmentations are adapted for grayscale & small images.

"""
from __future__ import annotations
import os
import csv
import math
import time
import random
import argparse
from dataclasses import dataclass, asdict
from typing import Tuple, Optional
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models




# --------------------
# Utilities
# --------------------

def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def ensure_dir(p: str):
    os.makedirs(p, exist_ok=True)


# --------------------
# Data
# --------------------

class NoisyLabelWrapper(Dataset):
    """Wrap a torchvision dataset to inject symmetric label noise on-the-fly.

    The noisy labels are generated once at init for reproducibility.
    """
    def __init__(self, base: Dataset, noise_rate: float, num_classes: int, seed: int = 42):
        self.base = base
        self.noise_rate = noise_rate
        self.num_classes = num_classes
        g = random.Random(seed)
        # precompute noisy labels
        self.targets = []
        for _, y in base:
            if g.random() < noise_rate:
                ny = g.randrange(num_classes)
                if ny == y:
                    ny = (ny + 1) % num_classes
                self.targets.append(ny)
            else:
                self.targets.append(y)
    def __len__(self): return len(self.base)
    def __getitem__(self, i):
        x, _ = self.base[i]
        return x, self.targets[i]



class SimCLRAugment:
    """Augmentations for SimCLR adapted to MNIST (1x28x28).

    We resize to 32x32 for a more comfortable ReNet-18 receptive field.
    """
    def __init__(self, size: int = 32):
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.RandomResizedCrop(size=size, scale=(0.6, 1.0)),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
        ])

    def __call__(self, x):
        xi = self.transform(x)
        xj = self.transform(x)
        return xi, xj


def get_dataloaders(batch_size: int, noise_rate: float, num_workers: int,
                    seed: int, for_simclr: bool, dataset: str):
    """
    Returns train_loader, test_loader, ssl_dataset (if for_simclr=True).
    """
    if dataset.lower() in ['cifar10', 'cifar100']:
        mean, std, num_classes = get_cifar_norm(dataset)
        sup_train_tf = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        root = os.path.join("data", dataset.lower())

        if dataset.lower() == 'cifar10':
            base_train = datasets.CIFAR10(root=root, train=True, download=True, transform=sup_train_tf)
            test = datasets.CIFAR10(root=root, train=False, download=True, transform=test_tf)
        else:
            base_train = datasets.CIFAR100(root=root, train=True, download=True, transform=sup_train_tf)
            test = datasets.CIFAR100(root=root, train=False, download=True, transform=test_tf)

        train = NoisyLabelWrapper(base_train, noise_rate=noise_rate, num_classes=num_classes, seed=seed) \
                if noise_rate > 0 else base_train

        train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers, pin_memory=True)
        test_loader = DataLoader(test, batch_size=batch_size, shuffle=False,
                                 num_workers=num_workers, pin_memory=True)

        ssl_ds = None
        if for_simclr:
            # Important: SimCLR pipeline should NOT include normalization
            tf = SimCLRAugmentCIFAR(size=32)
            if dataset.lower() == 'cifar10':
                ssl_ds = datasets.CIFAR10(root=root, train=True, download=True, transform=tf)
            else:
                ssl_ds = datasets.CIFAR100(root=root, train=True, download=True, transform=tf)

        return train_loader, test_loader, ssl_ds

    elif dataset.lower() == 'mnist':
        # keep your existing MNIST path if you still want MNIST runs
        ...
    else:
        raise ValueError("Unsupported dataset")


# --------------------
# Models
# --------------------
class ResNet18Small(nn.Module):
    """
    ResNet-18 adapted for small images.
    - conv1: kernel=3, stride=1, padding=1
    - remove first maxpool
    - in_channels selectable (1 for MNIST, 3 for CIFAR)
    """
    def __init__(self, num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        base = models.resnet18(weights=None)
        base.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        base.maxpool = nn.Identity()
        self.features = nn.Sequential(*list(base.children())[:-1])  # global avgpool yields [B,512,1,1]
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x, return_feat=False):
        x = self.features(x)
        x = torch.flatten(x, 1)
        if return_feat:
            return x
        return self.fc(x)


class ProjectionHead(nn.Module):
    def __init__(self, in_dim: int = 512, hidden_dim: int = 512, out_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(inplace=True), nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        z = self.net(x)
        z = F.normalize(z, dim=1)
        return z


class SimCLR(nn.Module):
    def __init__(self, encoder: ResNet18Small, proj: ProjectionHead):
        super().__init__()
        self.encoder = encoder
        self.proj = proj

    def forward(self, x):
        feat = self.encoder(x, return_feat=True)
        z = self.proj(feat)
        return z


# --------------------
# Losses
# --------------------
class NTXent(nn.Module):
    def __init__(self, temperature: float = 0.5):
        super().__init__()
        self.t = temperature

    def forward(self, z_i, z_j):
        # z_i, z_j: [B, D] normalized
        B, D = z_i.shape
        z = torch.cat([z_i, z_j], dim=0)  # [2B, D]
        sim = torch.matmul(z, z.T)  # cosine sim since normalized
        # mask out self-similarity
        mask = torch.eye(2 * B, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(mask, -9e15)
        # positives: for each i in [0..B-1], positive is i+B; for i in [B..2B-1], positive is i-B
        pos = torch.cat([torch.arange(B, 2 * B), torch.arange(0, B)]).to(z.device)
        logits = sim / self.t
        labels = pos
        loss = F.cross_entropy(logits, labels)
        return loss


# --------------------
# Training / Evaluation
# --------------------
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str) -> Tuple[float, float]:
    model.eval()
    correct, total, loss_accum = 0, 0, 0.0
    ce = nn.CrossEntropyLoss()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = ce(logits, y)
        loss_accum += loss.item() * x.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    return correct / total, loss_accum / total


def train_supervised(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader,
                     device: str, epochs: int = 10, lr: float = 1e-3, wd: float = 1e-4,
                     save_path: Optional[str] = None):
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    ce = nn.CrossEntropyLoss()
    best_acc = 0.0
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = ce(logits, y)
            loss.backward()
            opt.step()
        acc, tloss = evaluate(model, test_loader, device)
        print(f"[Supervised] Epoch {ep:03d} | Test Acc={acc*100:.2f}% | Test Loss={tloss:.4f}")
        if save_path and acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), save_path)
    return best_acc


def pretrain_simclr(ssl_model: SimCLR, ssl_loader: DataLoader, device: str, epochs: int = 50,
                    lr: float = 1e-3, wd: float = 1e-4, temp: float = 0.5):
    loss_fn = NTXent(temperature=temp)
    opt = torch.optim.Adam(ssl_model.parameters(), lr=lr, weight_decay=wd)
    ssl_model.train()
    for ep in range(1, epochs + 1):
        loss_accum, n = 0.0, 0
        for (x_i, x_j), _ in ssl_loader:
            x_i, x_j = x_i.to(device), x_j.to(device)
            opt.zero_grad()
            z_i = ssl_model(x_i)
            z_j = ssl_model(x_j)
            loss = loss_fn(z_i, z_j)
            loss.backward()
            opt.step()
            bs = x_i.size(0)
            loss_accum += loss.item() * bs
            n += bs
        print(f"[SimCLR] Epoch {ep:03d} | Loss={loss_accum / max(n,1):.4f}")


# --------------------
# Orchestration
# --------------------
@dataclass
class RunResult:
    mode: str
    noise_rate: float
    pretrain_epochs: int
    finetune_epochs: int
    test_acc: float
    timestamp: float


def save_metrics_row(path_csv: str, row: RunResult):
    header = list(asdict(row).keys())
    exists = os.path.exists(path_csv)
    with open(path_csv, 'a', newline='') as f:
        w = csv.DictWriter(f, fieldnames=header)
        if not exists:
            w.writeheader()
        w.writerow(asdict(row))



def build_ssl_loader(batch_size: int, num_workers: int, dataset: str) -> DataLoader:
  #CIFAR version
    if dataset.lower() in ['cifar10', 'cifar100']:
        tf = SimCLRAugmentCIFAR(size=32)
        root = os.path.join("data", dataset.lower())
        base = datasets.CIFAR10 if dataset.lower()=='cifar10' else datasets.CIFAR100
        ds = base(root=root, train=True, download=True, transform=tf)

        class _Wrap(Dataset):
            def __init__(self, ds): self.ds = ds
            def __len__(self): return len(self.ds)
            def __getitem__(self, idx):
                (xi, xj), _ = self.ds[idx]
                return (xi, xj), 0

        ssl = _Wrap(ds)
        return DataLoader(ssl, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, drop_last=True, pin_memory=True)
    else:
        #MNIST version
      tf = SimCLRAugment(size=32)
      ssl_train = datasets.MNIST(root=os.path.join("data","mnist"), train=True, download=True, transform=tf)
      # The transform already returns two views, but torchvision MNIST expects a single tensor.
      # Wrap to conform to ((xi,xj), _)
      class _Wrap(Dataset):
          def __init__(self, ds): self.ds = ds
          def __len__(self): return len(self.ds)
          def __getitem__(self, idx):
              (xi, xj), _ = self.ds[idx]
              return (xi, xj), 0
      ssl_train = _Wrap(ssl_train)
      ssl_loader = DataLoader(ssl_train, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True, pin_memory=True)
      return ssl_loader

def save_encoder(encoder: nn.Module, path: str):
    ensure_dir(os.path.dirname(path))
    torch.save({'features': encoder.features.state_dict()}, path)

def load_encoder_into_classifier(path: str, clf: nn.Module):
    ckpt = torch.load(path, map_location='cpu')
    state = ckpt.get('features', ckpt)  # backward-compat if raw dict saved
    missing, unexpected = clf.features.load_state_dict(state, strict=False)
    print(f"Loaded encoder from {path}. Missing={missing}, Unexpected={unexpected}")
    return clf


def run_ex(args):
    """Extended runner with reusable pretrained encoder support.
    Modes: 'baseline', 'simclr_then_finetune', 'finetune_from_pretrained'
    """
    set_seed(args.seed)
    device = args.device if torch.cuda.is_available() and args.device.startswith('cuda') else 'cpu'
    ensure_dir('runs'); ensure_dir('runs/checkpoints')

    num_classes = 100 if args.dataset.lower()=='cifar100' else 10
    in_ch = 3  # CIFAR is RGB


    # Supervised data (label noise applied here only)
    train_loader, test_loader, _ = get_dataloaders(
        batch_size=args.batch_size,
        noise_rate=args.noise_rate,
        num_workers=args.workers,
        seed=args.seed,
        for_simclr=False,
        dataset=args.dataset
    )

    if args.mode == 'baseline':
        model = ResNet18Small(num_classes=10).to(device)
        ckpt = os.path.join('runs', 'checkpoints', f"baseline_noise{args.noise_rate:.2f}.pt")
        acc = train_supervised(model, train_loader, test_loader, device,
                               epochs=args.epochs, lr=args.lr, wd=args.wd,
                               save_path=ckpt)
        print(f"Final Test Acc (baseline, noise={args.noise_rate:.2f}): {acc*100:.2f}%")
        save_metrics_row(os.path.join('runs','metrics.csv'), RunResult(
            mode='baseline', noise_rate=args.noise_rate,
            pretrain_epochs=0, finetune_epochs=args.epochs,
            test_acc=acc, timestamp=time.time()
        ))
        return

    if args.mode == 'simclr_then_finetune':
        # Pretrain on unlabeled data
        ssl_loader = build_ssl_loader(args.batch_size_ssl, args.workers, args.dataset)
        encoder = ResNet18Small(num_classes=num_classes, in_channels=in_ch).to(device)
        proj = ProjectionHead(in_dim=512, hidden_dim=512, out_dim=args.proj_dim).to(device)
        ssl_model = SimCLR(encoder, proj).to(device)

        pretrain_simclr(ssl_model, ssl_loader, device,
                        epochs=args.pretrain_epochs, lr=args.pretrain_lr,
                        wd=args.pretrain_wd, temp=args.temperature)
        # Save encoder for reuse
        if getattr(args, 'save_pretrained_encoder_path', ''):
            save_encoder(encoder, args.save_pretrained_encoder_path)
            print(f"Saved pretrained encoder to: {args.save_pretrained_encoder_path}")
        # Fine-tune
        clf = ResNet18Small(num_classes=num_classes, in_channels=in_ch).to(device)
        clf.features.load_state_dict(encoder.features.state_dict())
        if getattr(args, 'reset_classifier_head', True):
            clf.fc.reset_parameters()
        if args.freeze_backbone:
            for p in clf.features.parameters():
                p.requires_grad = False
        ckpt = os.path.join('runs', 'checkpoints',
                            f"simclr_noise{args.noise_rate:.2f}_freeze{int(args.freeze_backbone)}.pt")
        acc = train_supervised(clf, train_loader, test_loader, device,
                               epochs=args.epochs, lr=args.lr, wd=args.wd,
                               save_path=ckpt)
        print(f"Final Test Acc (SimCLR→FT, noise={args.noise_rate:.2f}, freeze={args.freeze_backbone}): {acc*100:.2f}%")
        save_metrics_row(os.path.join('runs','metrics.csv'), RunResult(
            mode=f"simclr_then_finetune_freeze{int(args.freeze_backbone)}", noise_rate=args.noise_rate,
            pretrain_epochs=args.pretrain_epochs, finetune_epochs=args.epochs,
            test_acc=acc, timestamp=time.time()
        ))
        return

    if args.mode == 'finetune_from_pretrained':
        if not getattr(args, 'pretrained_encoder_path', '') or not os.path.exists(args.pretrained_encoder_path):
            raise FileNotFoundError("Set args.pretrained_encoder_path to a valid encoder .pt saved from SimCLR pretrain.")
        clf = ResNet18Small(num_classes=num_classes, in_channels=in_ch).to(device)
        load_encoder_into_classifier(args.pretrained_encoder_path, clf)
        if getattr(args, 'reset_classifier_head', True):
            clf.fc.reset_parameters()
        if args.freeze_backbone:
            for p in clf.features.parameters():
                p.requires_grad = False
        ckpt = os.path.join('runs', 'checkpoints',
                            f"ft_from_pretrained_noise{args.noise_rate:.2f}_freeze{int(args.freeze_backbone)}.pt")
        acc = train_supervised(clf, train_loader, test_loader, device,
                               epochs=args.epochs, lr=args.lr, wd=args.wd,
                               save_path=ckpt)
        print(f"Final Test Acc (FT from pretrained, noise={args.noise_rate:.2f}, freeze={args.freeze_backbone}): {acc*100:.2f}%")
        save_metrics_row(os.path.join('runs','metrics.csv'), RunResult(
            mode=f"finetune_from_pretrained_freeze{int(args.freeze_backbone)}", noise_rate=args.noise_rate,
            pretrain_epochs=0, finetune_epochs=args.epochs,
            test_acc=acc, timestamp=time.time()
        ))
        return

    raise ValueError(f"Unknown mode: {args.mode}")


In [2]:
from torchvision import datasets, transforms

CIFAR10_MEAN, CIFAR10_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
CIFAR100_MEAN, CIFAR100_STD = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

def get_cifar_norm(dataset):
    if dataset.lower() == 'cifar10':
        return CIFAR10_MEAN, CIFAR10_STD, 10
    elif dataset.lower() == 'cifar100':
        return CIFAR100_MEAN, CIFAR100_STD, 100
    else:
        raise ValueError("dataset must be 'cifar10' or 'cifar100'")

class SimCLRAugmentCIFAR:
    """SimCLR-style augs for 32x32 color images."""
    def __init__(self, size=32):
        self.base = transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply(
                [transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
        ])
    def __call__(self, img):
        return self.base(img), self.base(img)








In [3]:
from types import SimpleNamespace

args = SimpleNamespace(
    # Core run switches
    mode='baseline',                  # 'baseline', 'simclr_then_finetune', 'finetune_from_pretrained'
    dataset='cifar10',                # 'cifar10' or 'cifar100' (keep 'mnist' if you want both)
    device='cuda',
    seed=42,
    batch_size=256,
    workers=1,
    #epochs=50,                        # CIFAR usually benefits from more epochs than MNIST
    epochs=10,
    lr=1e-3,
    wd=1e-4,
    noise_rate=0.0,                   # symmetric label noise on train set

    # SSL pretrain
    #pretrain_epochs=200,              # SimCLR typically needs longer on CIFAR
    pretrain_epochs=10,
    pretrain_lr=1e-3,
    pretrain_wd=1e-4,
    batch_size_ssl=256,
    proj_dim=128,
    temperature=0.5,
    freeze_backbone=False,

    # Reuse / checkpointing
    save_pretrained_encoder_path='runs/checkpoints/simclr_cifar_encoder.pt',
    pretrained_encoder_path='',
    reset_classifier_head=True,
)




In [7]:
!nividia-smi


/bin/bash: line 1: nividia-smi: command not found


In [None]:
start = time.time()

args.mode = 'baseline'
args.noise_rate = 0.0
run_ex(args)

end = time.time()
print(f"Duration: {(end - start)/60:.2f} min ({end - start:.1f} sec)")

Files already downloaded and verified
Files already downloaded and verified


In [None]:
start = time.time()

args.dataset = 'cifar10'
args.mode = 'simclr_then_finetune'
args.noise_rate = 0.0             # irrelevant to pretrain; used for the FT in this step
args.pretrain_epochs = 10
args.save_pretrained_encoder_path = 'runs/checkpoints/simclr_cifar10_encoder_e10.pt'
run_ex(args)                       # (or your run(args) if you merged)

end = time.time()
print(f"Duration: {(end - start)/60:.2f} min ({end - start:.1f} sec)")
