# Implementation of the SIMCLR with resnet50 backbone

In [1]:
import wandb
wandb.login()  # Opens a browser once to authenticate


[34m[1mwandb[0m: Currently logged in as: [33manaliju[0m ([33manaliju-paris[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# =========== GLOBAL CONFIGURATION ===========
import os
import ssl
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, random_split, Dataset, Subset
from torchvision.models import resnet50
from PIL import Image
import numpy as np
import random

In [3]:

# Prevent nondeterminism
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# torch.backends.cudnn.enabled = False

CONFIG = {
    "LOCAL_OR_COLAB": "LOCAL",
    "DATA_DIR_LOCAL": "/users/c/carvalhj/datasets/EuroSAT_RGB/",
    "DATA_DIR_COLAB": "/content/EuroSAT_RGB",
    "ZIP_PATH": "/content/EuroSAT.zip",
    "EUROSAT_URL": "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
    "SEED": 42,  # Default seed (will be overridden per run)
    "BATCH_SIZE": 256,
    "LR": 1e-4,
    "LR_LINEAR": 1e-4,
    "EPOCHS_SIMCLR": 100,
    "EPOCHS_LINEAR": 100,
    "PROJ_DIM": 128,
    "FEATURE_DIM": 2048, # ResNet50 feature dimension = 2048
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# split fractions
TRAIN_FRAC = 0.6
VAL_FRAC   = 0.2
TEST_FRAC  = 0.2

PRETRAINED = False

TEMPERATURE = 0.1

LINEAR_PROB_TRAIN_SPLIT = 0.8

In [4]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark     = False

def prepare_data():
    if CONFIG["LOCAL_OR_COLAB"] == "LOCAL":
        return CONFIG["DATA_DIR_LOCAL"]

    if not os.path.exists(CONFIG["DATA_DIR_COLAB"]):
        print("Downloading EuroSAT RGB...")
        ssl._create_default_https_context = ssl._create_unverified_context
        urllib.request.urlretrieve(CONFIG["EUROSAT_URL"], CONFIG["ZIP_PATH"])
        with zipfile.ZipFile(CONFIG["ZIP_PATH"], 'r') as zip_ref:
            zip_ref.extractall("/content")
        os.rename("/content/2750", CONFIG["DATA_DIR_COLAB"])
        print("EuroSAT RGB dataset downloaded and extracted.")
    return CONFIG["DATA_DIR_COLAB"]

In [5]:

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark     = False

def prepare_data():
    if CONFIG["LOCAL_OR_COLAB"] == "LOCAL":
        return CONFIG["DATA_DIR_LOCAL"]

    if not os.path.exists(CONFIG["DATA_DIR_COLAB"]):
        print("Downloading EuroSAT RGB...")
        ssl._create_default_https_context = ssl._create_unverified_context
        urllib.request.urlretrieve(CONFIG["EUROSAT_URL"], CONFIG["ZIP_PATH"])
        with zipfile.ZipFile(CONFIG["ZIP_PATH"], 'r') as zip_ref:
            zip_ref.extractall("/content")
        os.rename("/content/2750", CONFIG["DATA_DIR_COLAB"])
        print("EuroSAT RGB dataset downloaded and extracted.")
    return CONFIG["DATA_DIR_COLAB"]


def compute_mean_std(dataset, batch_size):
    loader = DataLoader(dataset, batch_size, shuffle=False, num_workers=2)
    mean = 0.0
    std = 0.0
    n_samples = 0

    for data, _ in loader:
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)  # (B, C, H*W)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        n_samples += batch_samples

    mean /= n_samples
    std /= n_samples
    return mean.tolist(), std.tolist()


class TwoCropsTransform:
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return [self.base_transform(x), self.base_transform(x)]
    
class SimCLRDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, _ = self.dataset[idx]
        x1, x2 = self.transform(x)
        return x1, x2



def get_data_loaders(data_dir, batch_size):
    """
    Returns:
      - train_loader (SimCLR-style: yields (x1, x2) pairs)
      - val_loader  (standard: yields (image, label))
      - test_loader (standard: yields (image, label))
      - num_classes (int)
    """

    # Compute channel‐wise mean/std on the training split (unaugmented).
    # build a separate ImageFolder with `transform=ToTensor()`.
    dataset_for_stats = datasets.ImageFolder(
        root=data_dir,
        transform=transforms.ToTensor()
    )
    total_len = len(dataset_for_stats)
    n_train = int(TRAIN_FRAC * total_len)
    n_val   = int(VAL_FRAC   * total_len)
    n_test  = total_len - n_train - n_val

    train_for_stats, val_for_stats, test_for_stats = random_split(
        dataset_for_stats,
        [n_train, n_val, n_test]
    )
    # Compute mean/std on train_for_stats
    mean, std = compute_mean_std(train_for_stats, batch_size)

    # IMAGEFOLDER FOR TRAIN (NO TRANSFORM HERE)
    # We want the SimCLR augmentations applied on‐the‐fly, so we leave transform=None.
    dataset_train_no_transform = datasets.ImageFolder(
        root=data_dir,
        transform=None
    )

    generator = torch.Generator().manual_seed(seed)
    train_indices, val_indices, test_indices = random_split(
        list(range(total_len)),
        [n_train, n_val, n_test],
        generator=generator
    )
    
    all_indices = list(range(total_len))
    rnd = random.Random(seed)
    rnd.shuffle(all_indices)
    train_indices = all_indices[:n_train]
    val_indices   = all_indices[n_train : n_train + n_val]
    test_indices  = all_indices[n_train + n_val : ]

    # Now build Subsets pointing to dataset_train_no_transform:
    train_subset_no_transform = Subset(dataset_train_no_transform, train_indices)

    # 3) IMAGEFOLDER FOR VALIDATION/TEST (WITH EVAL TRANSFORM)
    eval_transform = transforms.Compose([
        transforms.Resize(72),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    dataset_eval = datasets.ImageFolder(
        root=data_dir,
        transform=eval_transform
    )
    val_subset = Subset(dataset_eval, val_indices)
    test_subset = Subset(dataset_eval, test_indices)

    # DEFINE SIMCLR TRANSFORMS FOR TRAIN
    normalize = transforms.Normalize(mean=mean, std=std)
    augment_transform = transforms.Compose([
        transforms.RandomResizedCrop(64, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    simclr_transform = TwoCropsTransform(augment_transform)
    train_ds_simclr = SimCLRDataset(train_subset_no_transform, simclr_transform)

    train_loader = DataLoader(
        train_ds_simclr,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=2
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2
    )
    test_loader = DataLoader(
        test_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2
    )

    num_classes = len(dataset_for_stats.classes)
    return train_loader, val_loader, test_loader, num_classes

class ProjectionHead(nn.Module):
    def __init__(self, input_dim, proj_dim=128, hidden_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim)
        )

    def forward(self, x):
        return self.net(x)

class SimCLRModel(nn.Module):
    def __init__(self, base_encoder, proj_dim=128):
        super().__init__()
        self.encoder = base_encoder
        self.encoder.fc = nn.Identity()
        self.projection_head = ProjectionHead(input_dim=CONFIG["FEATURE_DIM"], proj_dim=proj_dim)

    def forward(self, x):
        feat = self.encoder(x)
        proj = self.projection_head(feat)
        return feat, proj

class NTXentLoss(nn.Module):
    def __init__(self, batch_size, temperature=0.5, device='cuda'):
        super().__init__()
        self.temperature = temperature
        self.batch_size = batch_size
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, zis, zjs):
        N = zis.size(0)
        z = F.normalize(torch.cat([zis, zjs], dim=0), dim=1)
        sim = torch.matmul(z, z.T) / self.temperature
        mask = torch.eye(2 * N, dtype=torch.bool).to(self.device)
        sim = sim.masked_fill(mask, -1e9)
        labels = torch.cat([torch.arange(N, 2 * N), torch.arange(0, N)]).to(self.device)
        # return self.x(sim, labels)
        return self.criterion(sim, labels)

def train_simclr(model, loader, optimizer, criterion, device, epochs):
    model.train()
    model.to(device)
    for epoch in range(epochs):
        total_loss = 0
        for (x1, x2) in loader:
            x1, x2 = x1.to(device), x2.to(device)
            _, z1 = model(x1)
            _, z2 = model(x2)
            loss = criterion(z1, z2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg = total_loss / len(loader)
        print(f"[SimCLR] Epoch {epoch+1}/{epochs} - Loss: {avg:.4f}")
    print("Finished SimCLR pretraining.")

def evaluate(classifier, backbone, loader, device):
    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            features = backbone(images)
            outputs = classifier(features)
            total += labels.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()
    return correct / total * 100



In [None]:
def train_linear_probe(backbone, train_loader, val_loader, device, epochs, lr, run_id):
    for p in backbone.parameters():
        p.requires_grad = False

    base_ds = train_loader.dataset
    while isinstance(base_ds, Subset):
        base_ds = base_ds.dataset
    num_classes = len(base_ds.classes)

    # ── Build a single‐layer classifier on top of the frozen features ──
    classifier = nn.Linear(CONFIG["FEATURE_DIM"], num_classes).to(device)

    optimizer = optim.Adam(classifier.parameters(), lr=lr)

    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        classifier.train()
        correct, total = 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            features = backbone(images)
            outputs = classifier(features)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total += labels.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()

        train_acc = correct / total * 100
        val_acc = evaluate(classifier, backbone, val_loader, device)
        print(f"[Linear] Epoch {epoch+1}/{epochs} - "
              f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

    torch.save(classifier.state_dict(), f"linear_probe_seed{run_id}.pth")
    return val_acc

def load_evaluate_model(model_path, device, data_dir, seed):
    results = []

    backbone = resnet50(weights=None if not PRETRAINED else "DEFAULT")
    backbone.fc = nn.Identity()  # same as in SimCLRModel
    backbone.load_state_dict(torch.load(model_path), strict=False)
    backbone.to(device)    
    backbone.eval()
    for p in backbone.parameters():
        p.requires_grad = False

    # (Make sure 'data_dir' and 'seed' are in scope; if not, pass them in.)
    _, _, test_loader, _ = get_data_loaders(data_dir, CONFIG["BATCH_SIZE"])
    print(f"Starting linear probe on EuroSAT test split (seed={seed})...")

    # Split EuroSAT test‐subset into 80%/20% for probe‐train vs. probe‐val
    full_test_ds = test_loader.dataset  # this is a Subset of dataset_eval
    train_size = int(LINEAR_PROB_TRAIN_SPLIT * len(full_test_ds))
    val_size = len(full_test_ds) - train_size
    train_dataset, val_dataset = random_split(full_test_ds, [train_size, val_size])

    train_loader_from_test = DataLoader(
        train_dataset,
        batch_size=CONFIG["BATCH_SIZE"],
        shuffle=True,
        num_workers=2
    )
    val_loader_from_test = DataLoader(
        val_dataset,
        batch_size=CONFIG["BATCH_SIZE"],
        shuffle=False,
        num_workers=2
    )

    linear_probe_val_acc = train_linear_probe(
        backbone,
        train_loader_from_test,
        val_loader_from_test,
        DEVICE,
        epochs=CONFIG["EPOCHS_LINEAR"],
        lr=CONFIG["LR_LINEAR"],
        run_id=seed
    )
    print(f"[Linear‐Probe on EuroSAT test] Final Val Acc = {linear_probe_val_acc:.2f}%\n")

    # ─── 4) Save & return ───────────────────────────────────────────
    results.append({
        "seed": seed,
        "val_acc": linear_probe_val_acc
    })
    with open("linear_probe_results.txt", "a") as f:
        f.write(f"Seed: {seed}, Val Acc: {linear_probe_val_acc:.2f}%\n")
    print("Results saved to linear_probe_results.txt")
    return results

In [7]:
# =========== RUN EVERYTHING ===========

# Define the list of seeds for each run
seeds = [42]

for seed in seeds:
    print(f"\n=== Starting run with seed {seed} ===")
    set_seed(seed)
    
    data_dir = prepare_data()
    train_loader, val_loader, test_loader, num_classes = get_data_loaders(data_dir, CONFIG["BATCH_SIZE"])

    # Initialize base encoder and SimCLR model
    base_encoder = resnet50(weights=None if not PRETRAINED else "DEFAULT")
    simclr_model = SimCLRModel(base_encoder, proj_dim=CONFIG["PROJ_DIM"])
    optimizer = optim.Adam(simclr_model.parameters(), lr=CONFIG["LR"])
    bs = CONFIG["BATCH_SIZE"]
    loss_fn = NTXentLoss(bs, temperature=TEMPERATURE, device=DEVICE)

    print("Starting SimCLR training...")
    train_simclr(simclr_model, train_loader, optimizer, loss_fn, DEVICE, CONFIG["EPOCHS_SIMCLR"])

    print("Saving encoder...")
    torch.save(simclr_model.state_dict(), f"simclr_model_seed{seed}_temperature{TEMPERATURE}_bs{bs}.pth")



=== Starting run with seed 42 ===
Starting SimCLR training...
[SimCLR] Epoch 1/100 - Loss: 5.8148
[SimCLR] Epoch 2/100 - Loss: 5.4270
[SimCLR] Epoch 3/100 - Loss: 5.1051
[SimCLR] Epoch 4/100 - Loss: 4.8714
[SimCLR] Epoch 5/100 - Loss: 4.6871
[SimCLR] Epoch 6/100 - Loss: 4.4835
[SimCLR] Epoch 7/100 - Loss: 4.2467
[SimCLR] Epoch 8/100 - Loss: 4.0137
[SimCLR] Epoch 9/100 - Loss: 3.7987
[SimCLR] Epoch 10/100 - Loss: 3.5824
[SimCLR] Epoch 11/100 - Loss: 3.3657
[SimCLR] Epoch 12/100 - Loss: 3.1135
[SimCLR] Epoch 13/100 - Loss: 2.9641
[SimCLR] Epoch 14/100 - Loss: 2.8231
[SimCLR] Epoch 15/100 - Loss: 2.6822
[SimCLR] Epoch 16/100 - Loss: 2.5939
[SimCLR] Epoch 17/100 - Loss: 2.5121
[SimCLR] Epoch 18/100 - Loss: 2.4416
[SimCLR] Epoch 19/100 - Loss: 2.3567
[SimCLR] Epoch 20/100 - Loss: 2.2716
[SimCLR] Epoch 21/100 - Loss: 2.2032
[SimCLR] Epoch 22/100 - Loss: 2.1246
[SimCLR] Epoch 23/100 - Loss: 2.0770
[SimCLR] Epoch 24/100 - Loss: 2.0192
[SimCLR] Epoch 25/100 - Loss: 1.9564
[SimCLR] Epoch 26/100

In [8]:

# Run the evaluation
for seed in seeds:
    results = load_evaluate_model(f"simclr_model_seed{seed}.pth", DEVICE, data_dir, seed)
    print(f"Results for seed {seed}: {results}")
    

  backbone.load_state_dict(torch.load(model_path), strict=False)


Starting linear probe on EuroSAT test split (seed=42)...
[Linear] Epoch 1/100 - Train Acc: 9.54%, Val Acc: 7.13%
[Linear] Epoch 2/100 - Train Acc: 9.51%, Val Acc: 13.15%
[Linear] Epoch 3/100 - Train Acc: 11.44%, Val Acc: 12.22%
[Linear] Epoch 4/100 - Train Acc: 15.05%, Val Acc: 14.91%
[Linear] Epoch 5/100 - Train Acc: 18.89%, Val Acc: 14.91%
[Linear] Epoch 6/100 - Train Acc: 19.88%, Val Acc: 17.87%
[Linear] Epoch 7/100 - Train Acc: 21.78%, Val Acc: 19.72%
[Linear] Epoch 8/100 - Train Acc: 23.56%, Val Acc: 25.83%
[Linear] Epoch 9/100 - Train Acc: 24.44%, Val Acc: 19.54%
[Linear] Epoch 10/100 - Train Acc: 25.14%, Val Acc: 20.83%
[Linear] Epoch 11/100 - Train Acc: 24.86%, Val Acc: 27.78%
[Linear] Epoch 12/100 - Train Acc: 26.46%, Val Acc: 26.76%
[Linear] Epoch 13/100 - Train Acc: 26.11%, Val Acc: 21.20%
[Linear] Epoch 14/100 - Train Acc: 26.97%, Val Acc: 27.13%
[Linear] Epoch 15/100 - Train Acc: 27.78%, Val Acc: 25.19%
[Linear] Epoch 16/100 - Train Acc: 28.73%, Val Acc: 24.35%
[Linear] Ep