In [1]:
import wandb
wandb.login()  # Opens a browser once to authenticate


[34m[1mwandb[0m: Currently logged in as: [33manaliju[0m ([33manaliju-paris[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.models import resnet50
from itertools import product
import numpy as np
import random
import copy
import os, ssl, zipfile, urllib

LOCAL_OR_COLAB = "LOCAL"
SEED           = 42
NUM_EPOCHS     = 100
DEVICE         = torch.device("cuda" if torch.cuda.is_available() else "cpu")

TRAIN_FRAC = 0.8
VAL_FRAC   = 0.1
TEST_FRAC  = 0.1

# hyperparameter grid
BATCH_SIZES = [256, 512] 
LRS = [1e-4]

if LOCAL_OR_COLAB == "LOCAL":
    DATA_DIR = "/users/c/carvalhj/datasets/EuroSAT_RGB/"
else:
    data_root = "/content/EuroSAT_RGB"
    zip_path  = "/content/EuroSAT.zip"
    if not os.path.exists(data_root):
        ssl._create_default_https_context = ssl._create_unverified_context
        urllib.request.urlretrieve(
            "https://madm.dfki.de/files/sentinel/EuroSAT.zip", zip_path
        )
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall("/content")
        os.rename("/content/2750", data_root)
    DATA_DIR = data_root



In [None]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark     = False

def compute_mean_std(dataset, batch_size):
    loader = DataLoader(dataset, batch_size, shuffle=False, num_workers=2)
    mean = 0.0
    std = 0.0
    n_samples = 0

    for data, _ in loader:
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)  # (B, C, H*W)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        n_samples += batch_samples

    mean /= n_samples
    std /= n_samples
    return mean.tolist(), std.tolist()

def get_data_loaders(data_dir, batch_size):
    base_tf = transforms.ToTensor()
    ds_all = datasets.ImageFolder(root=data_dir, transform=base_tf)

    n = len(ds_all)
    n_train = int(TRAIN_FRAC * n)
    n_val = int(VAL_FRAC * n)
    n_test = n - n_train - n_val
    train_ds, val_ds, test_ds = random_split(ds_all, [n_train, n_val, n_test])

    print("Computing mean and std from training set...")
    mean, std = compute_mean_std(train_ds, batch_size)
    print(f"Computed mean: {mean}")
    print(f"Computed std:  {std}")

    tf_final = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # Redefine datasets with transform
    ds_all = datasets.ImageFolder(root=data_dir, transform=tf_final)
    train_ds, val_ds, test_ds = random_split(ds_all, [n_train, n_val, n_test])

    return (
        DataLoader(train_ds, batch_size, shuffle=True),
        DataLoader(val_ds, batch_size, shuffle=False),
        DataLoader(test_ds, batch_size, shuffle=False),
        len(ds_all.classes)
    )


def build_model(n_cls, pretrained=False):
    m = resnet50(weights=None if not pretrained else "DEFAULT")
    m.fc = nn.Linear(m.fc.in_features, n_cls)
    return m.to(DEVICE)

def train_one_epoch(model, loader, opt, crit, sched=None):
    model.train()
    tot_loss, corr, tot = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad()
        logits = model(xb)

        loss   = crit(logits, yb)
        loss.backward()
        opt.step()
        if sched: sched.step()
        tot_loss += loss.item()
        preds    = logits.argmax(dim=1)
        corr    += (preds==yb).sum().item()
        tot     += yb.size(0)
    return tot_loss/len(loader), 100*corr/tot

def evaluate(model, loader):
    model.eval()
    corr, tot = 0,0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            preds = model(xb).argmax(dim=1)
            corr += (preds==yb).sum().item()
            tot  += yb.size(0)
    return 100 * corr / tot

def hyperparam_search(pretrained=True):
    best_val = -1.0
    best_cfg = None
    best_model = None

    for bs, lr in product(BATCH_SIZES, LRS):
        print(f"\n>>> Testing BS={bs}, LR={lr:.1e}")
        set_seed(SEED)
        tr_dl, val_dl, te_dl, n_cls = get_data_loaders(DATA_DIR, bs)
        model = build_model(n_cls, pretrained=pretrained)

        opt = optim.Adam(model.parameters(), lr=lr)
        sched = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.9)

        crit = nn.CrossEntropyLoss()

        # Start a W&B run here
        wandb_run = wandb.init(
            project="eurosat-supervised-scratch-grid-search",
            name=f"BS{bs}_LR{lr:.0e}",
            config={
                "batch_size": bs,
                "learning_rate": lr,
                "epochs": NUM_EPOCHS,
                "pretrained": pretrained,
            }
        )

        for ep in range(NUM_EPOCHS):
            tr_loss, tr_acc = train_one_epoch(model, tr_dl, opt, crit, sched)
            val_acc = evaluate(model, val_dl)
            print(f"  Ep{ep+1}/{NUM_EPOCHS}: train={tr_acc:.1f}%  val={val_acc:.1f}%")

            # Log metrics to W&B
            wandb.log({
                "epoch": ep + 1,
                "train_loss": tr_loss,
                "train_acc": tr_acc,
                "val_acc": val_acc
            })

        wandb_run.finish()

        if val_acc > best_val:
            best_val = val_acc
            best_cfg = (bs, lr)
            best_model = copy.deepcopy(model)

    print(f"\n>>> Best config: BS={best_cfg[0]}, LR={best_cfg[1]:.1e}, val={best_val:.1f}%")
    return best_cfg, best_model

def linear_probe(frozen_model, train_dl, test_dl, lr):
    for p in frozen_model.parameters():
        p.requires_grad = False
    # new head
    n_in = frozen_model.fc.in_features
    n_out = frozen_model.fc.out_features
    frozen_model.fc = nn.Linear(n_in, n_out).to(DEVICE)

    opt = optim.Adam(frozen_model.fc.parameters(), lr=lr)
    sched = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.9)

    crit = nn.CrossEntropyLoss()

    print("\n>>> Running linear probe on frozen backbone")
    for ep in range(NUM_EPOCHS):
        loss, acc = train_one_epoch(frozen_model, train_dl, opt, crit, sched=None)
        print(f"  Probe Ep{ep+1}/{NUM_EPOCHS}: train={acc:.1f}%")
    test_acc = evaluate(frozen_model, test_dl)
    print(f"Probe test acc: {test_acc:.1f}%")
    wandb.init(
        project="eurosat-supervised-scratch-linear-probe",
        name=f"BS{train_dl.batch_size}_LR{lr:.0e}_probe",
        config={
            "batch_size": train_dl.batch_size,
            "learning_rate": lr,
            "epochs": NUM_EPOCHS,
            "pretrained": False,
            "probe": True
        }
    )
    wandb.log({"probe_test_acc": test_acc})

    return test_acc

# ─── MAIN ───────────────────────────────────────────────────────────────────────
best_cfg, best_model = hyperparam_search(pretrained = False)
# rebuild loaders once more so we have the same splits
bs, lr = best_cfg
tr_dl, val_dl, te_dl, _ = get_data_loaders(DATA_DIR, bs)

# Option A: probe on just the original training split
probe_acc = linear_probe(best_model, tr_dl, te_dl, lr)


>>> Testing BS=256, LR=1.0e-04
Computing mean and std from training set...
Computed mean: [0.3438493311405182, 0.38001248240470886, 0.4077288508415222]
Computed std:  [0.09294864535331726, 0.06473352760076523, 0.05418824777007103]


  Ep1/100: train=48.1%  val=60.2%
  Ep2/100: train=63.7%  val=66.7%
  Ep3/100: train=68.8%  val=67.8%
  Ep4/100: train=70.8%  val=68.3%
  Ep5/100: train=71.7%  val=68.4%
  Ep6/100: train=71.8%  val=69.4%
  Ep7/100: train=72.2%  val=68.9%
  Ep8/100: train=72.4%  val=68.2%
  Ep9/100: train=72.4%  val=68.9%
  Ep10/100: train=72.4%  val=68.9%
  Ep11/100: train=72.3%  val=69.0%
  Ep12/100: train=72.1%  val=68.8%
  Ep13/100: train=72.2%  val=69.3%
  Ep14/100: train=72.3%  val=68.6%
  Ep15/100: train=72.0%  val=68.6%
  Ep16/100: train=72.4%  val=69.5%
  Ep17/100: train=72.0%  val=69.0%
  Ep18/100: train=72.4%  val=69.1%
  Ep19/100: train=72.1%  val=68.9%
  Ep20/100: train=72.0%  val=68.8%
  Ep21/100: train=72.4%  val=69.0%
  Ep22/100: train=72.3%  val=68.4%
  Ep23/100: train=72.2%  val=68.8%
  Ep24/100: train=72.1%  val=69.1%
  Ep25/100: train=71.8%  val=68.8%
  Ep26/100: train=72.3%  val=69.1%
  Ep27/100: train=72.1%  val=68.9%
  Ep28/100: train=72.5%  val=68.9%
  Ep29/100: train=72.3%  val=

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇█████
train_acc,▁███████████████████████████████████████
train_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▆▇▇█▇█▇█▇█▇▇▇▇█▇▇▇▇▇█▇▇▇▇████▇█▇▇▇▇█▇██

0,1
epoch,100.0
train_acc,72.10648
train_loss,0.78771
val_acc,68.55556



>>> Testing BS=512, LR=1.0e-04
Computing mean and std from training set...
Computed mean: [0.3438493609428406, 0.3800123929977417, 0.4077288508415222]
Computed std:  [0.09294863790273666, 0.06473352760076523, 0.05418826639652252]


  Ep1/100: train=43.6%  val=51.1%
  Ep2/100: train=61.1%  val=63.7%
  Ep3/100: train=67.9%  val=67.6%
  Ep4/100: train=72.6%  val=70.1%
  Ep5/100: train=75.0%  val=71.3%
  Ep6/100: train=77.6%  val=71.4%
  Ep7/100: train=78.8%  val=71.6%
  Ep8/100: train=79.6%  val=71.5%
  Ep9/100: train=80.0%  val=71.5%
  Ep10/100: train=80.5%  val=72.1%
  Ep11/100: train=80.7%  val=71.5%
  Ep12/100: train=80.8%  val=72.0%
  Ep13/100: train=80.5%  val=71.5%
  Ep14/100: train=81.1%  val=72.0%
  Ep15/100: train=81.2%  val=71.8%
  Ep16/100: train=81.5%  val=71.8%
  Ep17/100: train=80.5%  val=71.9%
  Ep18/100: train=81.0%  val=71.8%
  Ep19/100: train=80.7%  val=71.7%
  Ep20/100: train=81.3%  val=72.0%
  Ep21/100: train=81.1%  val=72.0%
  Ep22/100: train=80.9%  val=71.9%
  Ep23/100: train=80.8%  val=71.4%
  Ep24/100: train=81.1%  val=71.7%
  Ep25/100: train=80.7%  val=71.7%
  Ep26/100: train=81.1%  val=71.5%
  Ep27/100: train=81.0%  val=71.8%
  Ep28/100: train=81.5%  val=71.6%
  Ep29/100: train=80.8%  val=

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇████
train_acc,▁▇▇▇████████████████████████████████████
train_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▆▇███████████████████████████▇█████████

0,1
epoch,100.0
train_acc,80.34722
train_loss,0.56053
val_acc,71.37037



>>> Best config: BS=512, LR=1.0e-04, val=71.4%
Computing mean and std from training set...
Computed mean: [0.3444792926311493, 0.3804510235786438, 0.40790361166000366]
Computed std:  [0.09314047545194626, 0.06477275490760803, 0.054232291877269745]

>>> Running linear probe on frozen backbone
  Probe Ep1/100: train=55.1%
  Probe Ep2/100: train=73.0%
  Probe Ep3/100: train=74.7%
  Probe Ep4/100: train=75.5%
  Probe Ep5/100: train=75.8%
  Probe Ep6/100: train=76.0%
  Probe Ep7/100: train=77.0%
  Probe Ep8/100: train=77.1%
  Probe Ep9/100: train=77.0%
  Probe Ep10/100: train=77.4%
  Probe Ep11/100: train=77.2%
  Probe Ep12/100: train=77.4%
  Probe Ep13/100: train=77.7%
  Probe Ep14/100: train=77.7%
  Probe Ep15/100: train=78.1%
  Probe Ep16/100: train=78.3%
  Probe Ep17/100: train=78.0%
  Probe Ep18/100: train=78.2%
  Probe Ep19/100: train=78.2%
  Probe Ep20/100: train=78.3%
  Probe Ep21/100: train=78.3%
  Probe Ep22/100: train=78.5%
  Probe Ep23/100: train=78.2%
  Probe Ep24/100: train=7

In [4]:
best_cfg, best_model 


((512, 0.0001),
 ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): Bottleneck(
       (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (downsample): Sequential(
         (0): Conv2d(64, 256,

In [5]:
torch.save(best_model.state_dict(), "best_model.pt")
artifact = wandb.Artifact("best_model", type="model")
artifact.add_file("best_model.pt")
wandb.log_artifact(artifact)


<Artifact best_model>