In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.models import resnet18
from itertools import product
import numpy as np
import random
import copy
import os, ssl, urllib.request, zipfile

# ─── CONFIG ─────────────────────────────────────────────────────────────────────
LOCAL_OR_COLAB = "COLAB"
SEED           = 42
NUM_EPOCHS     = 20
DEVICE         = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# split fractions
TRAIN_FRAC = 0.6
VAL_FRAC   = 0.2
TEST_FRAC  = 0.2

# hyperparameter grid
BATCH_SIZES = [32, 64]
GRID        = [
    (2e-4,    0.1  ),  # SimCLR
    (1.875e-4,0.5  ),  # SatMIP
    (3.75e-4, 0.5  ),  # SatMIPS
]

# ─── DATASET DOWNLOAD ────────────────────────────────────────────────────────────
if LOCAL_OR_COLAB == "LOCAL":
    DATA_DIR = "/home/juliana/internship_LINUX/datasets/EuroSAT_RGB"
else:
    data_root = "/content/EuroSAT_RGB"
    zip_path  = "/content/EuroSAT.zip"
    if not os.path.exists(data_root):
        ssl._create_default_https_context = ssl._create_unverified_context
        urllib.request.urlretrieve(
            "https://madm.dfki.de/files/sentinel/EuroSAT.zip", zip_path
        )
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall("/content")
        os.rename("/content/2750", data_root)
    DATA_DIR = data_root

# ─── HELPERS ─────────────────────────────────────────────────────────────────────
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark     = False

def get_data_loaders(data_dir, batch_size):
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485,0.456,0.406],
            std =[0.229,0.224,0.225]
        )
    ])
    ds = datasets.ImageFolder(root=data_dir, transform=tf)
    n   = len(ds)
    n_train = int(TRAIN_FRAC * n)
    n_val   = int(VAL_FRAC   * n)
    n_test  = n - n_train - n_val
    train_ds, val_ds, test_ds = random_split(ds, [n_train, n_val, n_test])
    return (
        DataLoader(train_ds, batch_size, shuffle=True),
        DataLoader(val_ds,   batch_size, shuffle=False),
        DataLoader(test_ds,  batch_size, shuffle=False),
        len(ds.classes)
    )

def build_model(n_cls, pretrained=False):
    m = resnet18(weights=None if not pretrained else "DEFAULT")
    m.fc = nn.Linear(m.fc.in_features, n_cls)
    return m.to(DEVICE)

def train_one_epoch(model, loader, opt, crit, sched=None):
    model.train()
    tot_loss, corr, tot = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad()
        logits = model(xb)
        loss   = crit(logits, yb)
        loss.backward()
        opt.step()
        if sched: sched.step()
        tot_loss += loss.item()
        preds    = logits.argmax(dim=1)
        corr    += (preds==yb).sum().item()
        tot     += yb.size(0)
    return tot_loss/len(loader), 100*corr/tot

def evaluate(model, loader):
    model.eval()
    corr, tot = 0,0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            preds = model(xb).argmax(dim=1)
            corr += (preds==yb).sum().item()
            tot  += yb.size(0)
    return 100 * corr / tot

# ─── PHASE 1: GRID SEARCH ────────────────────────────────────────────────────────
def hyperparam_search(pretrained = True):
    best_val = -1.0
    best_cfg = None
    best_model = None
    # loop over all combos in one go
    for bs, (lr, wd) in product(BATCH_SIZES, GRID):
        print(f"\n>>> Testing BS={bs}, LR={lr:.1e}, WD={wd}")
        set_seed(SEED)
        tr_dl, val_dl, te_dl, n_cls = get_data_loaders(DATA_DIR, bs)
        model = build_model(n_cls, pretrained = pretrained)

        # optimizer + paper schedule
        opt = optim.AdamW(model.parameters(),
                          lr=lr, betas=(0.9,0.98), eps=1e-8, weight_decay=wd)
        total_steps  = NUM_EPOCHS * len(tr_dl)
        warmup_steps = len(tr_dl)
        sched = SequentialLR(
            opt,
            schedulers=[
                LinearLR(opt,  start_factor=1e-6, end_factor=1.0, total_iters=warmup_steps),
                CosineAnnealingLR(opt, T_max=total_steps-warmup_steps)
            ],
            milestones=[warmup_steps]
        )
        crit = nn.CrossEntropyLoss()

        # train & validate
        for ep in range(NUM_EPOCHS):
            tr_loss, tr_acc = train_one_epoch(model, tr_dl, opt, crit, sched)
            val_acc          = evaluate(model, val_dl)
            print(f"  Ep{ep+1}/{NUM_EPOCHS}: train={tr_acc:.1f}%  val={val_acc:.1f}%")

        # pick best
        if val_acc > best_val:
            best_val = val_acc
            best_cfg = (bs, lr, wd)
            best_model = copy.deepcopy(model)   # store the weights

    print(f"\n>>> Best config: BS={best_cfg[0]}, LR={best_cfg[1]:.1e}, WD={best_cfg[2]} "
          f"→ val={best_val:.1f}%")
    return best_cfg, best_model

# ─── PHASE 2: LINEAR PROBE ───────────────────────────────────────────────────────
def linear_probe(frozen_model, train_dl, test_dl, lr, wd):
    # freeze backbone
    for p in frozen_model.parameters():
        p.requires_grad = False
    # new head
    n_in = frozen_model.fc.in_features
    n_out = frozen_model.fc.out_features
    frozen_model.fc = nn.Linear(n_in, n_out).to(DEVICE)

    opt = optim.AdamW(frozen_model.fc.parameters(),
                      lr=lr, betas=(0.9,0.98), eps=1e-8, weight_decay=wd)
    crit = nn.CrossEntropyLoss()

    print("\n>>> Running linear probe on frozen backbone")
    for ep in range(NUM_EPOCHS):
        loss, acc = train_one_epoch(frozen_model, train_dl, opt, crit, sched=None)
        print(f"  Probe Ep{ep+1}/{NUM_EPOCHS}: train={acc:.1f}%")
    test_acc = evaluate(frozen_model, test_dl)
    print(f"→ Probe test acc: {test_acc:.1f}%")
    return test_acc

# ─── MAIN ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    best_cfg, best_model = hyperparam_search(pretrained = False)
    # rebuild loaders once more so we have the same splits
    bs, lr, wd = best_cfg
    tr_dl, val_dl, te_dl, _ = get_data_loaders(DATA_DIR, bs)

    # Option A: probe on just the original training split
    probe_acc = linear_probe(best_model, tr_dl, te_dl, lr, wd)



>>> Testing BS=32, LR=2.0e-04, WD=0.1
  Ep1/20: train=61.4%  val=68.3%
  Ep2/20: train=74.3%  val=76.4%
  Ep3/20: train=80.4%  val=76.4%
  Ep4/20: train=85.9%  val=76.6%
  Ep5/20: train=89.1%  val=79.1%
  Ep6/20: train=91.8%  val=82.6%
  Ep7/20: train=93.9%  val=81.3%
  Ep8/20: train=95.3%  val=83.9%
  Ep9/20: train=96.0%  val=81.2%
  Ep10/20: train=97.0%  val=83.5%
  Ep11/20: train=97.8%  val=86.0%
  Ep12/20: train=98.4%  val=86.3%
  Ep13/20: train=98.8%  val=87.1%
  Ep14/20: train=99.2%  val=85.6%
  Ep15/20: train=99.4%  val=87.8%
  Ep16/20: train=99.7%  val=84.6%
  Ep17/20: train=99.8%  val=87.4%
  Ep18/20: train=99.9%  val=88.2%
  Ep19/20: train=99.9%  val=88.3%
  Ep20/20: train=99.9%  val=88.5%

>>> Testing BS=32, LR=1.9e-04, WD=0.5
  Ep1/20: train=60.7%  val=67.4%
  Ep2/20: train=73.9%  val=79.4%
  Ep3/20: train=80.3%  val=76.3%
  Ep4/20: train=85.2%  val=77.5%
  Ep5/20: train=88.2%  val=83.2%
  Ep6/20: train=90.9%  val=78.5%
  Ep7/20: train=92.7%  val=80.7%
  Ep8/20: train=94.3

In [3]:
# Option B (train head on train+val):
merged = torch.utils.data.ConcatDataset([tr_dl.dataset, val_dl.dataset])
merged_dl = DataLoader(merged, bs, shuffle=True)
probe_acc = linear_probe(best_model, merged_dl, te_dl, lr, wd)



>>> Running linear probe on frozen backbone
  Probe Ep1/20: train=95.3%
  Probe Ep2/20: train=96.2%
  Probe Ep3/20: train=96.3%
  Probe Ep4/20: train=96.4%
  Probe Ep5/20: train=96.4%
  Probe Ep6/20: train=96.2%
  Probe Ep7/20: train=96.3%
  Probe Ep8/20: train=96.3%
  Probe Ep9/20: train=96.3%
  Probe Ep10/20: train=96.3%
  Probe Ep11/20: train=96.4%
  Probe Ep12/20: train=96.3%
  Probe Ep13/20: train=96.3%
  Probe Ep14/20: train=96.4%
  Probe Ep15/20: train=96.4%
  Probe Ep16/20: train=96.3%
  Probe Ep17/20: train=96.3%
  Probe Ep18/20: train=96.4%
  Probe Ep19/20: train=96.3%
  Probe Ep20/20: train=96.4%
→ Probe test acc: 97.0%


## Linear probing with scikit learn

In [4]:
import torch
import numpy as np
from tqdm import tqdm

def extract_embeddings(model, loader, device):
    model.eval()
    # remove last classifier layer
    backbone = torch.nn.Sequential(*list(model.children())[:-1])
    backbone.to(device)
    all_feats, all_labels = [], []
    with torch.no_grad():
        for xb, yb in tqdm(loader, desc="Extracting"):
            xb = xb.to(device)
            feats = backbone(xb)           # shape: (B, C, 1, 1)
            feats = feats.view(feats.size(0), -1)  # (B, C)
            all_feats.append(feats.cpu().numpy())
            all_labels.append(yb.numpy())
    return np.vstack(all_feats), np.concatenate(all_labels)

# 1) Extract embeddings from frozen best_model
X_train, y_train = extract_embeddings(best_model, tr_dl, DEVICE)
X_test,  y_test  = extract_embeddings(best_model, te_dl, DEVICE)

# 2) Fit a scikit‑learn “linear probe” (logistic regression)
from sklearn.linear_model    import LogisticRegression
from sklearn.preprocessing   import StandardScaler
from sklearn.metrics         import accuracy_score, classification_report

# scale features
scaler  = StandardScaler().fit(X_train)
X_tr_s  = scaler.transform(X_train)
X_te_s  = scaler.transform(X_test)

# C ≃ 1/weight_decay — try a small grid
clf = LogisticRegression(
    penalty='l2',
    C=1.0,
    solver='saga',
    multi_class='multinomial',
    max_iter=200
).fit(X_tr_s, y_train)

# 3) Evaluate
preds = clf.predict(X_te_s)
acc   = accuracy_score(y_test, preds)
print(f"sklearn probe test accuracy: {acc*100:.2f}%")
print(classification_report(y_test, preds, digits=4))


Extracting: 100%|██████████| 507/507 [00:12<00:00, 39.64it/s]
Extracting: 100%|██████████| 169/169 [00:03<00:00, 42.36it/s]


sklearn probe test accuracy: 96.83%
              precision    recall  f1-score   support

           0     0.9722    0.9655    0.9688       579
           1     0.9914    0.9897    0.9905       580
           2     0.9527    0.9363    0.9444       581
           3     0.9572    0.9400    0.9485       500
           4     0.9579    0.9701    0.9640       469
           5     0.9523    0.9744    0.9632       430
           6     0.9365    0.9420    0.9392       517
           7     0.9871    0.9919    0.9895       617
           8     0.9652    0.9727    0.9689       513
           9     0.9967    0.9935    0.9951       614

    accuracy                         0.9683      5400
   macro avg     0.9669    0.9676    0.9672      5400
weighted avg     0.9684    0.9683    0.9683      5400



