In [1]:
import wandb
wandb.login()  # Opens a browser once to authenticate
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import datasets, transforms
from torchvision.models import resnet50
from itertools import product
import numpy as np
import random
import copy
import os, ssl, zipfile, urllib
from sklearn.model_selection import StratifiedShuffleSplit
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from sklearn.metrics import confusion_matrix
import seaborn as sns
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR, SequentialLR, MultiStepLR
from torch.utils.data import ConcatDataset, DataLoader


[34m[1mwandb[0m: Currently logged in as: [33manaliju[0m ([33manaliju-paris[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:

LOCAL_OR_COLAB = "LOCAL"
SEED           = 42
NUM_EPOCHS     = 34
DEVICE         = torch.device("cuda" if torch.cuda.is_available() else "cpu")

TRAIN_FRAC = 0.8
VAL_FRAC   = 0.1
TEST_FRAC  = 0.1

# hyperparameter grid
# BATCH_SIZES = [64, 128, 256]
BATCH_SIZES = [512]  # Using a single batch size for simplicity
LRS = [1e-4, 3e-4]

GRID = product(
    [0.1, 0.01],    # learning rate
    [0.01, 0.0001]  # weight decay
)

TRAINING_SCHEDULES = {
    "short": {"p": [750, 1500, 2250, 2500], "w": 200, "unit": "steps"},
    "medium": {"p": [3000, 6000, 9000, 10000], "w": 500, "unit": "steps"},
    "long": {"p": [30, 60, 80, 90], "w": 5, "unit": "epochs"}
}

# BETAS=(0.9,0.98)
# EPS = 1e-8

if LOCAL_OR_COLAB == "LOCAL":
    DATA_DIR = "/users/c/carvalhj/datasets/EuroSAT_RGB/"
else:
    data_root = "/content/EuroSAT_RGB"
    zip_path  = "/content/EuroSAT.zip"
    if not os.path.exists(data_root):
        ssl._create_default_https_context = ssl._create_unverified_context
        urllib.request.urlretrieve(
            "https://madm.dfki.de/files/sentinel/EuroSAT.zip", zip_path
        )
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall("/content")
        os.rename("/content/2750", data_root)
    DATA_DIR = data_root

NUM_WORKERS = 4 

In [3]:
def build_lr_scheduler(optimizer, total_training_steps, schedule_cfg, steps_per_epoch):
    """
    Builds the learning rate scheduler based on the specified schedule configuration.

    Args:
        optimizer: The PyTorch optimizer.
        total_training_steps: Total number of optimization steps for the entire training.
        schedule_cfg: Dictionary containing 'p', 'w', and 'unit' for the schedule.
        steps_per_epoch: Number of optimization steps in one epoch.
    """
    warmup_iters = schedule_cfg["w"]
    milestones = [] # Points at which LR drops

    if schedule_cfg["unit"] == "steps":
        milestones = schedule_cfg["p"]
    elif schedule_cfg["unit"] == "epochs":
        # Convert epoch milestones to step milestones
        milestones = [m * steps_per_epoch for m in schedule_cfg["p"]]
        warmup_iters = schedule_cfg["w"] * steps_per_epoch # Convert warmup epochs to steps

    # Linear warm-up scheduler
    warmup_scheduler = LinearLR(optimizer, start_factor=1e-6, end_factor=1.0, total_iters=warmup_iters)

    # Step decay scheduler
    # The image states "decrease the learning rate by 10 per each learning phase p"
    # This means multiplying current LR by 0.1 at each milestone.
    decay_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    # Combine them sequentially: warmup first, then decay
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, decay_scheduler],
        milestones=[warmup_iters]
    )
    return scheduler

def hyperparam_search(pretrained=True):
    best_val = -1.0
    best_cfg = None
    best_model = None

    # Iterate over batch sizes, learning rates, weight decays, and training schedules
    for bs, (lr, wd), schedule_name in product(BATCH_SIZES, GRID, TRAINING_SCHEDULES.keys()):

        print(f"\n>>> Testing BS={bs}, LR={lr:.1e}, WD={wd:.1e}, Schedule={schedule_name}")

        tr_dl, val_dl, te_dl, n_cls = get_data_loaders(DATA_DIR, bs) # Assuming get_data_loaders is adapted for preprocessing


        steps_per_epoch = len(tr_dl)

        schedule_cfg = TRAINING_SCHEDULES[schedule_name]

        if schedule_cfg["unit"] == "steps":

            # Let's set total_steps to the last 'p' value for simplicity, or slightly more.
            # A common approach is to set total_steps = max(schedule_cfg['p'])
            total_steps = max(schedule_cfg["p"]) # This is the total number of steps for the scheduler's milestones.
            # We need to ensure NUM_EPOCHS is large enough to cover these steps.
            NUM_EPOCHS_FOR_RUN = int(np.ceil(total_steps / steps_per_epoch)) + 1 # Add a buffer epoch
        else: # schedule_cfg["unit"] == "epochs"
            total_epochs_from_schedule = max(schedule_cfg["p"]) + schedule_cfg["w"] # max 'p' + warmup epochs
            NUM_EPOCHS_FOR_RUN = total_epochs_from_schedule # Total epochs to run
            total_steps = NUM_EPOCHS_FOR_RUN * steps_per_epoch


        # Build model (ResNet50 v2, assuming build_model handles this)
        # Note: The document states "ResNet50 v2 architecture (He et al., 2016)".
        # PyTorch's `torchvision.models.resnet50` is ResNet v1.
        # ResNet v2 typically involves pre-activation. You might need a custom `build_model`
        # or a specific implementation like from `timm` library if you want ResNet50 v2 exactly.
        # For now, assuming your `build_model` can handle it or you are okay with standard ResNet.
        model = build_model(n_cls, pretrained=pretrained)
        model.to(DEVICE) # Move model to device

        # Optimizer: SGD with momentum set to 0.9
        opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
        crit = nn.CrossEntropyLoss()

        # Build the learning rate scheduler based on the current schedule
        sched = build_lr_scheduler(opt, total_steps, schedule_cfg, steps_per_epoch)

        # Start a W&B run
        wandb_run = wandb.init(
            project="eurosat-supervised-scratch-grid-search-lrsched",
            name=f"BS{bs}_LR{lr:.0e}_WD{wd:.0e}_Sched_{schedule_name}",
            config={
                "batch_size": bs,
                "learning_rate": lr,
                "weight_decay": wd,
                "schedule_name": schedule_name,
                "total_epochs_for_run": NUM_EPOCHS_FOR_RUN,
                "pretrained": pretrained,
                "optimizer": "SGD_momentum_0.9",
                "scheduler_type": "LinearWarmup_MultiStepLR",
                "warmup_steps_or_epochs": schedule_cfg["w"],
                "decay_milestones": schedule_cfg["p"],
                "decay_unit": schedule_cfg["unit"]
            }
        )

        for ep in range(NUM_EPOCHS_FOR_RUN):
            tr_loss, tr_acc = train_one_epoch(model, tr_dl, opt, crit, sched, DEVICE) # Pass DEVICE to train_one_epoch
            # Compute validation loss & accuracy
            model.eval()
            val_loss, corr, tot = 0.0, 0, 0
            with torch.no_grad():
                for xb, yb in val_dl:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    logits = model(xb)
                    loss = crit(logits, yb)
                    val_loss += loss.item()
                    preds = logits.argmax(dim=1)
                    corr += (preds == yb).sum().item()
                    tot  += yb.size(0)
            val_loss /= len(val_dl)
            val_acc = 100.0 * corr / tot

            print(f"  Ep{ep+1}/{NUM_EPOCHS_FOR_RUN}: train_acc={tr_acc:.1f}%  train_loss={tr_loss:.4f}, "
                  f"val_acc={val_acc:.1f}%, val_loss={val_loss:.4f}")

            wandb.log({
                "epoch":       ep + 1,
                "train_loss":  tr_loss,
                "train_acc":   tr_acc,
                "val_loss":    val_loss,
                "val_acc":     val_acc,
                "learning_rate": opt.param_groups[0]['lr'] # Log current LR
            })

        wandb_run.finish()

        # Only use val_acc to pick best
        if val_acc > best_val:
            best_val   = val_acc
            best_cfg   = (bs, lr, wd, schedule_name)
            best_model = copy.deepcopy(model)

    print(f"\n>>> Best config: BS={best_cfg[0]}, LR={best_cfg[1]:.1e}, WD={best_cfg[2]:.1e}, Schedule={best_cfg[3]}, val_acc={best_val:.1f}%")

    return best_cfg, best_model

def train_one_epoch(model, dataloader, optimizer, criterion, scheduler, device):
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for batch_idx, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step() # <--- IMPORTANT: Step the scheduler after each batch

        total_loss += loss.item() * inputs.size(0) # Accumulate weighted by batch size
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

    avg_loss = total_loss / total_samples
    accuracy = 100 * correct_predictions / total_samples
    return avg_loss, accuracy

In [4]:

def compute_mean_std(dataset, batch_size):
    loader = DataLoader(dataset, batch_size, shuffle=False, num_workers=2)
    mean = 0.0
    std = 0.0
    n_samples = 0

    for data, _ in loader:
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)  # (B, C, H*W)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        n_samples += batch_samples

    mean /= n_samples
    std /= n_samples
    return mean.tolist(), std.tolist()

def get_data_loaders(data_dir, batch_size):

    base_tf = transforms.ToTensor()
    ds_all = datasets.ImageFolder(root=data_dir, transform=base_tf)
    labels = np.array(ds_all.targets)   # numpy array of shape (N,)
    num_classes = len(ds_all.classes)
    total_count = len(ds_all)
    print(f"Total samples in folder: {total_count}, classes: {ds_all.classes}")

    train_idx, val_idx, test_idx = get_split_indexes(labels, total_count)

    train_subset_for_stats = Subset(ds_all, train_idx)
    mean, std = compute_mean_std(train_subset_for_stats, batch_size)
    print(f"Computed mean: {mean}")
    print(f"Computed std:  {std}")

    tf_final = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    #  full ImageFolder but now with normalization baked in
    ds_all_norm = datasets.ImageFolder(root=data_dir, transform=tf_final)

    train_ds = Subset(ds_all_norm, train_idx)
    val_ds   = Subset(ds_all_norm, val_idx)
    test_ds  = Subset(ds_all_norm, test_idx)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=NUM_WORKERS, generator=torch.Generator().manual_seed(SEED))
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, generator=torch.Generator().manual_seed(SEED))
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, generator=torch.Generator().manual_seed(SEED))

    print(f"Train/Val/Test splits: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}")

    return train_loader, val_loader, test_loader, num_classes

def get_proportion(num_classes, dataset):
    return np.bincount(np.array(dataset.dataset.targets)[dataset.indices], minlength=num_classes) / len(dataset)

def get_split_indexes(labels, total_count):
    n_train = int(np.floor(TRAIN_FRAC * total_count))
    n_temp = total_count - n_train   # this is val + test

    sss1 = StratifiedShuffleSplit(
        n_splits=1,
        train_size=n_train,
        test_size=n_temp,
        random_state=SEED
    )
    # Train and temp(val+test) indices
    train_idx, temp_idx = next(sss1.split(np.zeros(total_count), labels))

    n_val = int(np.floor(VAL_FRAC * total_count))
    n_test = total_count - n_train - n_val
    assert n_temp == n_val + n_test, "Fractions must sum to 1."

    labels_temp = labels[temp_idx]

    sss2 = StratifiedShuffleSplit(
        n_splits=1,
        train_size=n_val,
        test_size=n_test,
        random_state=SEED
    )
    val_idx_in_temp, test_idx_in_temp = next(sss2.split(np.zeros(len(temp_idx)), labels_temp))

    val_idx = temp_idx[val_idx_in_temp]
    test_idx = temp_idx[test_idx_in_temp]

    assert len(train_idx) == n_train
    assert len(val_idx) == n_val
    assert len(test_idx) == n_test

    print(f"Stratified split sizes: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")
    return train_idx,val_idx,test_idx



In [5]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark     = False

def build_model(n_cls, pretrained=False):
    m = resnet50(weights=None if not pretrained else "DEFAULT")
    m.fc = nn.Linear(m.fc.in_features, n_cls)
    return m.to(DEVICE)

# def train_one_epoch(model, loader, opt, crit, sched=None):
#     model.train()
#     tot_loss, corr, tot = 0.0, 0, 0
#     for xb, yb in loader:
#         xb, yb = xb.to(DEVICE), yb.to(DEVICE)
#         opt.zero_grad()
#         logits = model(xb)

#         loss   = crit(logits, yb)
#         loss.backward()
#         opt.step()
#         if sched: sched.step()
#         tot_loss += loss.item()
#         preds    = logits.argmax(dim=1)
#         corr    += (preds==yb).sum().item()
#         tot     += yb.size(0)
#         avg_loss = tot_loss / len(loader)

#     avg_loss = tot_loss / len(loader)
#     acc = 100.0 * corr / tot
#     return avg_loss, acc

def evaluate(model, loader, num_classes):
    model.eval()

    total_correct = 0
    total_samples = 0

    correct_per_class = torch.zeros(num_classes, dtype=torch.int64)
    total_per_class   = torch.zeros(num_classes, dtype=torch.int64)

    all_labels = []
    all_preds  = []

    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = model(xb)
            preds  = logits.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(yb.cpu().numpy())

            total_correct += (preds == yb).sum().item()
            total_samples += yb.size(0)

            for c in range(num_classes):
                # mask of samples in this batch whose true label == c
                class_mask = (yb == c)
                if class_mask.sum().item() == 0:
                    continue

                total_per_class[c] += class_mask.sum().item()

                correct_per_class[c] += ((preds == yb) & class_mask).sum().item()

    overall_acc = 100.0 * total_correct / total_samples

    acc_per_class = {}
    for c in range(num_classes):
        if total_per_class[c].item() > 0:
            acc = 100.0 * correct_per_class[c].item() / total_per_class[c].item()
        else:
            acc = 0.0
        acc_per_class[c] = acc

    return overall_acc, acc_per_class, all_labels, all_preds

def plot_confusion_matrix_from_preds(y_true, y_pred, class_names):
    cm = confusion_matrix(y_true, y_pred)
    # normalize by true-label counts (row‐wise) to get percentages
    cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
    
    plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.colorbar()
    
    ticks = np.arange(len(class_names))
    plt.xticks(ticks, class_names, rotation=90)
    plt.yticks(ticks, class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    
    # threshold for text color
    thresh = cm.max() / 2.0
    for i in range(len(class_names)):
        for j in range(len(class_names)):
            pct = cm_norm[i, j] * 100
            plt.text(
                j, i,
                f"{cm[i, j]}\n{pct:.1f}%",
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black"
            )
    
    plt.tight_layout()
    plt.show()

def plot_class_acc_prop(te_dl, acc_vals, class_proportions_test):
    classes = te_dl.dataset.dataset.classes
    x = np.arange(len(classes))

    acc   = acc_vals
    prop  = class_proportions_test * 100

    fig, ax1 = plt.subplots(figsize=(12,6))
    bars = ax1.bar(x, acc, color='C0', alpha=0.7)
    ax1.set_ylabel('Accuracy (%)', color='C0')
    ax1.set_ylim(0, 100)
    ax1.tick_params(axis='y', labelcolor='C0')

    for bar in bars:
        h = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2, h + 1, f'{h:.1f}%', ha='center', va='bottom', color='C0')

    ax2 = ax1.twinx()
    line = ax2.plot(x, prop, color='C1', marker='o', linewidth=2)
    ax2.set_ylabel('Test Proportion (%)', color='C1')
    ax2.set_ylim(0, max(prop)*1.2)
    ax2.tick_params(axis='y', labelcolor='C1')

    for xi, yi in zip(x, prop):
        ax2.text(xi, yi + max(prop)*0.02, f'{yi:.1f}%', ha='center', va='bottom', color='C1')

    ax1.set_xticks(x)
    ax1.set_xticklabels(classes, rotation=45, ha='right')
    plt.title('Per-class Accuracy vs. Test Proportion')
    plt.tight_layout()
    plt.show()


# def hyperparam_search(pretrained=False):
#     best_val = -1.0
#     best_cfg = None
#     best_model = None

#     for bs, (lr, wd) in product(BATCH_SIZES, GRID):

#         print(f"\n>>> Testing BS={bs}, LR={lr:.1e}")
        
#         tr_dl, val_dl, te_dl, n_cls = get_data_loaders(DATA_DIR, bs)
#         model = build_model(n_cls, pretrained=pretrained)
        
#         total_steps  = NUM_EPOCHS * len(tr_dl)
#         warmup_steps = len(tr_dl)
#         opt = optim.AdamW(model.parameters(), lr=lr, betas=BETAS, eps=float(EPS), weight_decay=wd)
#         sched = SequentialLR(
#             opt,
#             schedulers=[
#                 LinearLR(opt,  start_factor=1e-6, end_factor=1.0, total_iters=warmup_steps),
#                 CosineAnnealingLR(opt, T_max=total_steps-warmup_steps)
#             ],
#             milestones=[warmup_steps]
#         )
#         crit  = nn.CrossEntropyLoss()

#         # Start a W&B run
#         wandb_run = wandb.init(
#             project="eurosat-supervised-scratch-grid-search",
#             name=f"BS{bs}_LR{lr:.0e}_TR{TRAIN_FRAC}",
#             config={
#                 "batch_size": bs,
#                 "learning_rate": lr,
#                 "epochs": NUM_EPOCHS,
#                 "pretrained": pretrained,
#             }
#         )

#         for ep in range(NUM_EPOCHS):
#             tr_loss, tr_acc = train_one_epoch(model, tr_dl, opt, crit, sched)
#             # Compute validation loss & accuracy in one pass
#             model.eval()
#             val_loss, corr, tot = 0.0, 0, 0
#             with torch.no_grad():
#                 for xb, yb in val_dl:
#                     xb, yb = xb.to(DEVICE), yb.to(DEVICE)
#                     logits = model(xb)
#                     loss = crit(logits, yb)
#                     val_loss += loss.item()
#                     preds = logits.argmax(dim=1)
#                     corr += (preds == yb).sum().item()
#                     tot  += yb.size(0)
#             val_loss /= len(val_dl)
#             val_acc = 100.0 * corr / tot

#             print(f"  Ep{ep+1}/{NUM_EPOCHS}: train_acc={tr_acc:.1f}%  train_loss={tr_loss:.4f}, "
#                   f"val_acc={val_acc:.1f}%, val_loss={val_loss:.4f}")

#             wandb.log({
#                 "epoch":       ep + 1,
#                 "train_loss":  tr_loss,
#                 "train_acc":   tr_acc,
#                 "val_loss":    val_loss,
#                 "val_acc":     val_acc
#             })

#         wandb_run.finish()

#         # Only use val_acc to pick best
#         if val_acc > best_val:
#             best_val   = val_acc
#             best_cfg   = (bs, lr, wd)
#             best_model = copy.deepcopy(model)

#     print(f"\n>>> Best config: BS={best_cfg[0]}, LR={best_cfg[1]:.1e}, val_acc={best_val:.1f}%")
    
#     return best_cfg, best_model




# Perform Hyperparameter Search, Retrain on Train + Validation Set, Evaluate on Test Set

In [None]:

# Assuming build_lr_scheduler and TRAINING_SCHEDULES are defined as before

# Alternative (better) make_optimizer_scheduler using the helper
def make_optimizer_scheduler_reused(params, lr, wd, schedule_name, steps_per_epoch):
    """
    Builds the SGD optimizer and the specific learning rate scheduler
    by reusing the build_lr_scheduler function.

    Args:
        params: Model parameters.
        lr: Learning rate.
        wd: Weight decay.
        schedule_name: Name of the training schedule ('short', 'medium', 'long').
        steps_per_epoch: Number of optimization steps in one epoch.
    """
    opt = optim.SGD(params, lr=lr, momentum=0.9, weight_decay=wd)
    schedule_cfg = TRAINING_SCHEDULES[schedule_name]

    # We need to provide a `total_training_steps` to build_lr_scheduler,
    # though MultiStepLR doesn't strictly use it beyond its milestones.
    # For consistency, we can pass the max step from milestones or a very large number.
    # Let's pass the max of the milestones as an effective 'total_steps' for the schedule.
    # The actual NUM_EPOCHS for training will be determined by the schedule logic.
    total_steps_for_scheduler_config = max(schedule_cfg['p']) if schedule_cfg['unit'] == 'steps' else max(schedule_cfg['p']) * steps_per_epoch

    scheduler = build_lr_scheduler(opt, total_steps_for_scheduler_config, schedule_cfg, steps_per_epoch)
    return opt, scheduler


# Assuming build_model, train_one_epoch, DEVICE, and TRAINING_SCHEDULES are defined

def retrain_final_model(tr_dl, val_dl, n_cls, bs, lr, wd, schedule_name): # Added schedule_name

    print("\n>>> Retraining final model on TRAIN+VAL combined with best hyperparameters")
    combined_ds = ConcatDataset([tr_dl.dataset, val_dl.dataset])

    combined_dl = DataLoader(combined_ds, batch_size=bs, shuffle=True, num_workers=4) # Assuming 4 workers

    model = build_model(n_cls, pretrained=False)
    model.to(DEVICE) # Move model to device

    # Determine total epochs for this specific schedule
    steps_per_epoch = len(combined_dl)
    schedule_cfg = TRAINING_SCHEDULES[schedule_name]

    if schedule_cfg["unit"] == "steps":
        total_steps_for_run = max(schedule_cfg["p"]) # Run for at least the last milestone
        num_epochs_for_run = int(np.ceil(total_steps_for_run / steps_per_epoch)) + 1 # Add buffer
    else: # schedule_cfg["unit"] == "epochs"
        num_epochs_for_run = max(schedule_cfg["p"]) + schedule_cfg["w"]


    # Use the new make_optimizer_scheduler
    optimizer, scheduler = make_optimizer_scheduler_reused( # Changed function name
        model.parameters(), lr, wd, schedule_name, steps_per_epoch
    )
    criterion = nn.CrossEntropyLoss()

    for ep in range(num_epochs_for_run): # Changed num_epochs to num_epochs_for_run
        loss, acc = train_one_epoch(model, combined_dl, optimizer, criterion, scheduler, DEVICE) # Pass DEVICE
        print(f"  Ep {ep+1}/{num_epochs_for_run}: train_acc={acc:.1f}%")
    return model, combined_ds

#  Evaluate & log to wandb
def evaluate_and_log(final_model, te_dl, combined_ds, n_cls, bs, lr):
    """
    Evaluate on test set, print per-class stats, log to wandb, and plot.
    """
    final_test_acc, acc_per_class, y_true, y_pred = evaluate(final_model, te_dl, n_cls)
    plot_confusion_matrix_from_preds(y_true, y_pred, te_dl.dataset.dataset.classes)

    test_targs = np.array(te_dl.dataset.dataset.targets)[te_dl.dataset.indices]
    prop_test = np.bincount(test_targs, minlength=n_cls) / len(test_targs)

    combined_targs = np.concatenate([
        np.array(ds.dataset.targets)[ds.indices] for ds in combined_ds.datasets
    ])
    prop_trainval = np.bincount(combined_targs, minlength=n_cls) / len(combined_targs)

    acc_vals = np.array([acc_per_class[c] for c in range(n_cls)])
    weighted_acc = (acc_vals * prop_test).sum()

    print("\n>>> Final Test Accuracy:")
    print(f"  Overall:             {final_test_acc:5.1f}%")
    print(f"  Weighted class acc.: {weighted_acc:5.1f}%\n")
    hdr = f"{'Class':20s}  {'Acc':>6s}   {'Train+Val':>9s}   {'Test':>6s}"
    print(hdr); print("-"*len(hdr))
    for c, name in enumerate(te_dl.dataset.dataset.classes):
        print(f"{name:20s}  {acc_vals[c]:6.1f}%   {prop_trainval[c]*100:8.0f}%   {prop_test[c]*100:6.0f}%")

    wandb.init(
        project="eurosat-supervised-scratch-final-lrsched",
        name=f"BS{bs}_LR{lr:.0e}_final",
        config={
            "batch_size": bs, "learning_rate": lr, "epochs": NUM_EPOCHS,
            "pretrained": False, "final_retrain": True
        }
    )
    wandb.log({
        "final_test_acc":     final_test_acc,
        "weighted_class_acc": weighted_acc,
        "per_class_acc":      acc_vals
    })
    wandb.finish()

    plot_class_acc_prop(te_dl, acc_vals, prop_test)


In [7]:
# Main
set_seed(SEED)

best_cfg, _    = hyperparam_search(pretrained=False)
bs, lr, wd     = best_cfg
tr_dl, val_dl, te_dl, n_cls = get_data_loaders(DATA_DIR, bs)

# Retrain on TRAIN+VAL
final_model, combined_ds = retrain_final_model(tr_dl, val_dl, n_cls, bs, lr, wd, NUM_EPOCHS)

evaluate_and_log(final_model, te_dl, combined_ds, n_cls, bs, lr)

final_path = f"models/eurosat_supervised_final_bs{bs}_lr{lr:.0e}_epcs{NUM_EPOCHS}.pth"
torch.save(final_model.state_dict(), final_path)
print(f"Final model saved to {final_path}")



>>> Testing BS=512, LR=1.0e-01, WD=1.0e-02, Schedule=short
Total samples in folder: 27000, classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
Stratified split sizes: train=21600, val=2700, test=2700


Computed mean: [0.3441457152366638, 0.3800985515117645, 0.40766361355781555]
Computed std:  [0.09299741685390472, 0.06464490294456482, 0.05413917079567909]
Train/Val/Test splits: 21600/2700/2700


  Ep1/60: train_acc=25.8%  train_loss=2.0838, val_acc=31.2%, val_loss=33.5076
  Ep2/60: train_acc=59.7%  train_loss=1.1405, val_acc=35.0%, val_loss=1.9872
  Ep3/60: train_acc=71.4%  train_loss=0.8384, val_acc=21.1%, val_loss=2.8427
  Ep4/60: train_acc=65.9%  train_loss=1.0572, val_acc=21.6%, val_loss=2.1569




  Ep5/60: train_acc=68.7%  train_loss=0.8811, val_acc=21.5%, val_loss=2.2104
  Ep6/60: train_acc=72.8%  train_loss=0.7819, val_acc=21.5%, val_loss=2.3453
  Ep7/60: train_acc=71.2%  train_loss=0.8449, val_acc=36.1%, val_loss=1.6493
  Ep8/60: train_acc=74.2%  train_loss=0.7450, val_acc=31.9%, val_loss=1.9516
  Ep9/60: train_acc=71.8%  train_loss=0.8444, val_acc=32.7%, val_loss=1.8682
  Ep10/60: train_acc=73.2%  train_loss=0.7777, val_acc=43.1%, val_loss=1.6699
  Ep11/60: train_acc=73.7%  train_loss=0.7708, val_acc=32.6%, val_loss=3.5736
  Ep12/60: train_acc=73.0%  train_loss=0.7876, val_acc=41.1%, val_loss=2.0222
  Ep13/60: train_acc=74.5%  train_loss=0.7356, val_acc=61.5%, val_loss=1.0721
  Ep14/60: train_acc=75.9%  train_loss=0.6984, val_acc=42.1%, val_loss=1.8844
  Ep15/60: train_acc=78.1%  train_loss=0.6420, val_acc=52.3%, val_loss=1.3692
  Ep16/60: train_acc=76.7%  train_loss=0.6840, val_acc=37.3%, val_loss=2.2963
  Ep17/60: train_acc=78.1%  train_loss=0.6389, val_acc=47.8%, val_los

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
learning_rate,▂▄▆▇███████████▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc,▁▄▅▅▅▅▆▅▆▆▆▆▆▆▅▇▇▇▇▇▇▇▇▇▇███████████████
train_loss,█▆▆▅▆▆▅▅▅▅▅▅▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▂▂▁▁▁▃▂▃▅▃▃▁▄▁▃▇▇▇▇▇▇▅▇▇▇███████████████
val_loss,█▂▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,60.0
learning_rate,0.0001
train_acc,99.59259
train_loss,0.04921
val_acc,93.03704
val_loss,0.22493



>>> Testing BS=512, LR=1.0e-01, WD=1.0e-02, Schedule=medium
Total samples in folder: 27000, classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
Stratified split sizes: train=21600, val=2700, test=2700
Computed mean: [0.3441457152366638, 0.3800985515117645, 0.40766361355781555]
Computed std:  [0.09299741685390472, 0.06464490294456482, 0.05413917079567909]
Train/Val/Test splits: 21600/2700/2700


  Ep1/234: train_acc=22.7%  train_loss=2.0879, val_acc=21.2%, val_loss=2.8882
  Ep2/234: train_acc=48.1%  train_loss=1.3973, val_acc=52.1%, val_loss=1.6670
  Ep3/234: train_acc=68.0%  train_loss=0.9018, val_acc=37.7%, val_loss=1.8779
  Ep4/234: train_acc=73.7%  train_loss=0.7480, val_acc=23.4%, val_loss=3.5837
  Ep5/234: train_acc=75.7%  train_loss=0.7064, val_acc=29.6%, val_loss=2.2233
  Ep6/234: train_acc=77.9%  train_loss=0.6502, val_acc=28.7%, val_loss=2.1823
  Ep7/234: train_acc=76.4%  train_loss=0.6760, val_acc=23.7%, val_loss=2.3774
  Ep8/234: train_acc=78.7%  train_loss=0.6253, val_acc=24.6%, val_loss=2.3478
  Ep9/234: train_acc=77.4%  train_loss=0.6782, val_acc=26.4%, val_loss=2.1369
  Ep10/234: train_acc=76.2%  train_loss=0.7009, val_acc=29.0%, val_loss=2.2424
  Ep11/234: train_acc=78.3%  train_loss=0.6517, val_acc=30.5%, val_loss=1.9619




  Ep12/234: train_acc=76.6%  train_loss=0.6944, val_acc=24.0%, val_loss=2.7840
  Ep13/234: train_acc=76.8%  train_loss=0.6880, val_acc=35.0%, val_loss=1.9749
  Ep14/234: train_acc=76.5%  train_loss=0.6940, val_acc=44.4%, val_loss=1.7644
  Ep15/234: train_acc=80.3%  train_loss=0.5831, val_acc=36.7%, val_loss=2.0977
  Ep16/234: train_acc=75.9%  train_loss=0.7209, val_acc=45.4%, val_loss=1.8836
  Ep17/234: train_acc=79.5%  train_loss=0.6056, val_acc=46.4%, val_loss=1.8234
  Ep18/234: train_acc=77.0%  train_loss=0.6780, val_acc=36.5%, val_loss=2.7140
  Ep19/234: train_acc=79.2%  train_loss=0.6252, val_acc=61.3%, val_loss=1.1872
  Ep20/234: train_acc=79.9%  train_loss=0.5995, val_acc=32.3%, val_loss=2.9145
  Ep21/234: train_acc=79.4%  train_loss=0.6238, val_acc=54.5%, val_loss=1.4266
  Ep22/234: train_acc=81.1%  train_loss=0.5712, val_acc=35.5%, val_loss=1.9005
  Ep23/234: train_acc=74.6%  train_loss=0.7605, val_acc=66.4%, val_loss=0.9400
  Ep24/234: train_acc=80.7%  train_loss=0.5727, val_

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇███
learning_rate,▆██████████████▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc,▂▄▄▅▅▄▄▁▃▃▃▃▃▄▆▆▆▆▆▇▇▇▇▇████████████████
train_loss,█▅▄▅▄▅▅▄▄▅▅▄▅▅▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▄▂▄▅▄▃▄▃▃▄▄▇▆▇▇▆▇▅▆▆▆▅▇███████████████
val_loss,▂▂▂▂▁▁▂█▁▁▂▂▂▂▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,234.0
learning_rate,0.0001
train_acc,99.98148
train_loss,0.02307
val_acc,91.0
val_loss,0.33006



>>> Testing BS=512, LR=1.0e-01, WD=1.0e-02, Schedule=long
Total samples in folder: 27000, classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
Stratified split sizes: train=21600, val=2700, test=2700
Computed mean: [0.3441457152366638, 0.3800985515117645, 0.40766361355781555]
Computed std:  [0.09299741685390472, 0.06464490294456482, 0.05413917079567909]
Train/Val/Test splits: 21600/2700/2700


  Ep1/95: train_acc=26.0%  train_loss=2.0657, val_acc=26.5%, val_loss=5.8963
  Ep2/95: train_acc=57.2%  train_loss=1.2006, val_acc=39.0%, val_loss=2.0896
  Ep3/95: train_acc=70.8%  train_loss=0.8438, val_acc=17.4%, val_loss=3.1725
  Ep4/95: train_acc=56.3%  train_loss=1.4307, val_acc=16.4%, val_loss=2.9891




  Ep5/95: train_acc=69.4%  train_loss=0.8638, val_acc=21.5%, val_loss=2.6482
  Ep6/95: train_acc=69.0%  train_loss=0.9159, val_acc=22.3%, val_loss=2.3186
  Ep7/95: train_acc=70.9%  train_loss=0.8563, val_acc=24.6%, val_loss=2.2777
  Ep8/95: train_acc=75.4%  train_loss=0.7192, val_acc=34.3%, val_loss=2.2257
  Ep9/95: train_acc=71.6%  train_loss=0.8345, val_acc=39.1%, val_loss=1.8986
  Ep10/95: train_acc=68.0%  train_loss=0.9247, val_acc=42.7%, val_loss=2.0064
  Ep11/95: train_acc=74.6%  train_loss=0.7395, val_acc=45.1%, val_loss=1.6581
  Ep12/95: train_acc=73.1%  train_loss=0.7757, val_acc=32.2%, val_loss=3.8916
  Ep13/95: train_acc=71.4%  train_loss=0.8380, val_acc=61.9%, val_loss=1.2716
  Ep14/95: train_acc=76.0%  train_loss=0.6970, val_acc=42.7%, val_loss=1.5920
  Ep15/95: train_acc=76.3%  train_loss=0.6882, val_acc=39.7%, val_loss=2.3777
  Ep16/95: train_acc=77.0%  train_loss=0.6659, val_acc=47.3%, val_loss=1.6255
  Ep17/95: train_acc=78.7%  train_loss=0.6272, val_acc=43.7%, val_los

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▇▇▇▇▇▇██
learning_rate,▄▇████████████▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc,▁▁▄▃▃▄▄▄▅▄▅▄▅▄▄▄▄▆▆▆▇▇▇▇▇▇▇▇▇▇██████████
train_loss,█▄▆▄▄▃▄▃▃▃▃▃▃▃▄▄▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▂▁▁▂▂▃▄▂▃▄▄▄▃▃▅▃▄▂▄▄▇▇▆▇▆▆▆▇▇▇██████████
val_loss,█▃▅▄▄▃▃▃▆▃▃▄▆▃▃▂▁▁▁▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,95.0
learning_rate,1e-05
train_acc,99.88889
train_loss,0.03044
val_acc,92.81481
val_loss,0.24943



>>> Testing BS=512, LR=1.0e-01, WD=1.0e-04, Schedule=short
Total samples in folder: 27000, classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
Stratified split sizes: train=21600, val=2700, test=2700
Computed mean: [0.3441457152366638, 0.3800985515117645, 0.40766361355781555]
Computed std:  [0.09299741685390472, 0.06464490294456482, 0.05413917079567909]
Train/Val/Test splits: 21600/2700/2700


  Ep1/60: train_acc=24.1%  train_loss=2.1721, val_acc=21.4%, val_loss=70.9680
  Ep2/60: train_acc=55.7%  train_loss=1.2217, val_acc=48.7%, val_loss=2.7230
  Ep3/60: train_acc=69.9%  train_loss=0.8667, val_acc=61.5%, val_loss=1.0913
  Ep4/60: train_acc=69.4%  train_loss=0.9786, val_acc=42.0%, val_loss=12.7844




  Ep5/60: train_acc=75.6%  train_loss=0.6980, val_acc=65.1%, val_loss=1.2664
  Ep6/60: train_acc=81.2%  train_loss=0.5370, val_acc=76.2%, val_loss=0.7054
  Ep7/60: train_acc=83.4%  train_loss=0.4704, val_acc=81.6%, val_loss=0.4924
  Ep8/60: train_acc=87.0%  train_loss=0.3782, val_acc=78.8%, val_loss=0.6390
  Ep9/60: train_acc=89.8%  train_loss=0.2923, val_acc=82.4%, val_loss=0.5696
  Ep10/60: train_acc=90.5%  train_loss=0.2663, val_acc=60.7%, val_loss=3.1739
  Ep11/60: train_acc=90.9%  train_loss=0.2584, val_acc=64.8%, val_loss=4.8810
  Ep12/60: train_acc=91.5%  train_loss=0.2406, val_acc=83.2%, val_loss=0.5728
  Ep13/60: train_acc=94.5%  train_loss=0.1574, val_acc=78.0%, val_loss=0.9296
  Ep14/60: train_acc=94.7%  train_loss=0.1549, val_acc=83.9%, val_loss=0.5315
  Ep15/60: train_acc=96.2%  train_loss=0.1130, val_acc=85.3%, val_loss=0.5495
  Ep16/60: train_acc=96.5%  train_loss=0.0965, val_acc=82.4%, val_loss=0.6820
  Ep17/60: train_acc=95.4%  train_loss=0.1386, val_acc=75.1%, val_los

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
learning_rate,▄▇█████████████▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc,▁▄▅▆▆▇▇▇▇▇█▆▆▆▇▇▇███████████████████████
train_loss,█▆▅▄▄▃▂▂▂▂▂▂▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▄▁▄▆▇▄▄▇▆▇▅▆▅▆██████████████████████████
val_loss,█▁▁▂▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,60.0
learning_rate,0.0001
train_acc,99.89352
train_loss,0.00662
val_acc,88.11111
val_loss,0.60835



>>> Testing BS=512, LR=1.0e-01, WD=1.0e-04, Schedule=medium
Total samples in folder: 27000, classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
Stratified split sizes: train=21600, val=2700, test=2700
Computed mean: [0.3441457152366638, 0.3800985515117645, 0.40766361355781555]
Computed std:  [0.09299741685390472, 0.06464490294456482, 0.05413917079567909]
Train/Val/Test splits: 21600/2700/2700


  Ep1/234: train_acc=23.1%  train_loss=2.0663, val_acc=20.8%, val_loss=7.6332
  Ep2/234: train_acc=46.4%  train_loss=1.4680, val_acc=43.7%, val_loss=3.4499
  Ep3/234: train_acc=67.5%  train_loss=0.9260, val_acc=44.0%, val_loss=2.3131
  Ep4/234: train_acc=73.4%  train_loss=0.7557, val_acc=53.0%, val_loss=1.8635
  Ep5/234: train_acc=77.7%  train_loss=0.6357, val_acc=56.3%, val_loss=3.0373
  Ep6/234: train_acc=80.3%  train_loss=0.5666, val_acc=58.4%, val_loss=1.4696
  Ep7/234: train_acc=82.8%  train_loss=0.4984, val_acc=71.1%, val_loss=0.8551
  Ep8/234: train_acc=84.8%  train_loss=0.4334, val_acc=69.6%, val_loss=0.9583
  Ep9/234: train_acc=86.9%  train_loss=0.3814, val_acc=65.1%, val_loss=1.5002
  Ep10/234: train_acc=86.5%  train_loss=0.3977, val_acc=52.2%, val_loss=3.1940
  Ep11/234: train_acc=79.6%  train_loss=0.6554, val_acc=52.1%, val_loss=32.0678




  Ep12/234: train_acc=84.9%  train_loss=0.4436, val_acc=80.8%, val_loss=0.5640
  Ep13/234: train_acc=88.9%  train_loss=0.3210, val_acc=83.6%, val_loss=0.5362
  Ep14/234: train_acc=90.3%  train_loss=0.2874, val_acc=83.0%, val_loss=0.5162
  Ep15/234: train_acc=61.9%  train_loss=1.4221, val_acc=30.5%, val_loss=7541.6324
  Ep16/234: train_acc=67.2%  train_loss=0.9244, val_acc=57.6%, val_loss=1.1676
  Ep17/234: train_acc=73.2%  train_loss=0.7587, val_acc=74.0%, val_loss=0.7297
  Ep18/234: train_acc=76.9%  train_loss=0.6428, val_acc=64.9%, val_loss=1.0262
  Ep19/234: train_acc=80.4%  train_loss=0.5535, val_acc=80.7%, val_loss=0.5325
  Ep20/234: train_acc=82.9%  train_loss=0.4856, val_acc=77.1%, val_loss=0.6481
  Ep21/234: train_acc=83.5%  train_loss=0.4711, val_acc=74.2%, val_loss=0.9843
  Ep22/234: train_acc=85.1%  train_loss=0.4200, val_acc=83.4%, val_loss=0.4672
  Ep23/234: train_acc=87.1%  train_loss=0.3636, val_acc=71.7%, val_loss=0.9703
  Ep24/234: train_acc=88.9%  train_loss=0.3140, v

KeyboardInterrupt: 