## TODO:

* TRAINING


* RECOMMENDED: (training speed-up) Precompute and cache each submodel’s softmax outputs once, store them on disk, and then train the meta‐learner on those saved vectors. This removes the four forward passes during each epoch.


* OPTIONAL: chatgpt recommends to not use the same train set for training ensemble model --> check this


* OPTIONAL: look at the difference between using one-hidden layer or not (BIG DIFFERENCE)


* OPTIONAL: improve trainig loop (with optional interrupt key and saving model)

In [1]:
import sys, os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
softmax = torch.nn.Softmax(dim=1)
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
    top_k_accuracy_score,
    classification_report,
    confusion_matrix
)
import random
from pathlib import Path
from tqdm import tqdm
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from typing import List, Tuple
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [30]:
# Compute absolute path to the `src/` folder
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
SRC_PATH     = os.path.join(PROJECT_ROOT, "src")

if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)

from utils import get_dataloaders, load_model, evaluate_model, print_metrics, plot_confusion_matrix, show_sample_predictions, plot_random_image_with_label_and_prediction

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

In [4]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: mps


In [5]:
COUNTRIES = ["Albania","Andorra","Argentina","Australia","Austria","Bangladesh","Belgium","Bhutan","Bolivia","Botswana","Brazil","Bulgaria","Cambodia","Canada","Chile","Colombia","Croatia","Czechia","Denmark","Dominican Republic","Ecuador","Estonia","Eswatini","Finland","France","Germany","Ghana","Greece","Greenland","Guatemala","Hungary","Iceland","Indonesia","Ireland","Israel","Italy","Japan","Jordan","Kenya","Kyrgyzstan","Latvia","Lesotho","Lithuania","Luxembourg","Malaysia","Mexico","Mongolia","Montenegro","Netherlands","New Zealand","Nigeria","North Macedonia","Norway","Palestine","Peru","Philippines","Poland","Portugal","Romania","Russia","Senegal","Serbia","Singapore","Slovakia","Slovenia","South Africa","South Korea","Spain","Sri Lanka","Sweden","Switzerland","Taiwan","Thailand","Turkey","Ukraine","United Arab Emirates","United Kingdom","United States","Uruguay"]
num_classes = len(COUNTRIES)
project_root   = Path().resolve().parent

### Data

In [6]:
train_root = project_root/ "data" / "final_datasets" / "train"
train_loader = get_dataloaders(train_root, batch_size=32)

val_root = project_root/ "data" / "final_datasets" / "val"
val_loader = get_dataloaders(val_root, batch_size=32)

test_root = project_root/ "data" / "final_datasets" / "test"
test_loader = get_dataloaders(test_root, batch_size=32)

### Load models

In [8]:
base_model = load_model(model_path=project_root / "models" / "resnet_finetuned" / "main.pth", device=device)
road = load_model(model_path=project_root / "models" / "resnet_finetuned_road" / "main.pth", device=device)
terrain = load_model(model_path=project_root / "models" / "resnet_finetuned_terrain" / "main.pth", device=device)
vegetation = load_model(model_path=project_root / "models" / "resnet_finetuned_vegetation" / "main.pth", device=device)

  model.load_state_dict(torch.load(model_path, map_location=device))


In [9]:
MODEL_NAME = "nvidia/segformer-b0-finetuned-cityscapes-768-768"

feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_NAME)
seg_model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME).eval()



In [14]:
CITYSCAPES_LABELS = {
    0: 'road', 
    8: 'vegetation',  9: 'terrain'
}

TARGET_CLASSES = {'road','terrain','vegetation'}

## Ensemble

In [45]:
# Utility to fetch softmax probs from a pretrained submodel
def get_probs(model, img_tensor, device):
    model.eval()
    with torch.no_grad():
        out = model(img_tensor.to(device))
        probs = nn.functional.softmax(out, dim=1).cpu().squeeze(0).numpy()
    return probs

In [46]:
# Ensemble Dataset: wraps CountryImageDataset, runs submodels to produce features
class EnsembleDataset(Dataset):
    def __init__(self, base_ds, submodels, device):
        """
        base_ds: CountryImageDataset
        submodels: dict {'base': base_model, 'road':road_model, ...}
        Each model returns a softmax vector of length num_classes.
        """
        self.base_ds   = base_ds
        self.submodels = submodels
        self.device    = device

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        img, label = self.base_ds[idx]
        x = img.unsqueeze(0)  # add batch dim
        # 1) Collect each model's probs
        feats = []
        for name, m in self.submodels.items():
            p = get_probs(m, x, self.device)  # shape (num_classes,)
            feats.append(p)
        # 2) Concatenate into one feature vector
        feature_vector = np.concatenate(feats).astype(np.float32)
        return torch.from_numpy(feature_vector), label

In [47]:
# Ensemble Network: one hidden layer
class EnsembleNet(nn.Module):
    def __init__(self, in_dim, hid_dim, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, num_classes)
        )

    def forward(self, x):
        return self.net(x)

In [48]:
def train_epoch(model, loader, loss_fn, optimizer, device, epoch, log_every=5):
    """
    Runs one epoch of training, printing updates every `log_every` batches.

    Args:
        model       (nn.Module):      the network to train
        loader      (DataLoader):     training data loader
        loss_fn     (callable):       loss function
        optimizer   (torch.optim.Optimizer)
        device      (torch.device)
        epoch       (int):            current epoch number (for prints)
        log_every   (int):            how many batches between prints

    Returns:
        avg_loss (float), avg_acc (float)
    """
    model.train()
    running_loss = 0.0
    running_correct = 0
    total_samples = 0

    for batch_idx, (imgs, labels) in enumerate(loader, start=1):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Metrics
        preds = outputs.argmax(dim=1)
        batch_correct = (preds == labels).sum().item()
        batch_size = imgs.size(0)

        running_loss    += loss.item() * batch_size
        running_correct += batch_correct
        total_samples   += batch_size

        # Log every N batches
        if batch_idx % log_every == 0 or batch_idx == len(loader):
            batch_loss = running_loss / total_samples
            batch_acc  = running_correct / total_samples
            print(f"Epoch {epoch} [{batch_idx}/{len(loader)}]  "
                  f"Loss: {batch_loss:.4f}  Acc: {batch_acc:.4f}")

    avg_loss = running_loss / total_samples
    avg_acc  = running_correct / total_samples
    return avg_loss, avg_acc


In [25]:
def eval_epoch(model, loader, loss_fn, device):
    model.eval()
    total_loss = total_correct = 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss   = loss_fn(logits, y)
            preds  = logits.argmax(dim=1)
            total_correct += (preds==y).sum().item()
            total_loss    += loss.item() * X.size(0)
    return total_loss/len(loader.dataset), total_correct/len(loader.dataset)

In [49]:
# a) Wrap submodels
submodels = {
    'base':     base_model,
    'road':     road,
    'terrain':  terrain,
    'vegetation': vegetation
}

# b) Build ensemble datasets
train_ds = EnsembleDataset(train_loader.dataset, submodels, device)
val_ds   = EnsembleDataset(val_loader.dataset,   submodels, device)
test_ds  = EnsembleDataset(test_loader.dataset,  submodels, device)

train_el = DataLoader(train_ds, batch_size=32, shuffle=True)
val_el   = DataLoader(val_ds,   batch_size=32)
test_el  = DataLoader(test_ds,  batch_size=32)

In [50]:
# Instantiate meta‐model
project_root = Path().resolve().parent

hid_dim=128
epochs=10
lr=1e-3

in_dim = num_classes * len(submodels)
model = EnsembleNet(in_dim, hid_dim, num_classes).to(device)
loss_fn = nn.CrossEntropyLoss()
opt     = optim.Adam(model.parameters(), lr=lr)

best_val_loss = float('inf')
ckpt_path = project_root / "models" / "ensemble" / "main.pth"

In [51]:
for epoch in range(1, epochs+1):
    tr_loss, tr_acc = train_epoch(model, train_el, loss_fn, opt, device, epoch)
    vl_loss, vl_acc = eval_epoch(model, val_el,   loss_fn, None, device)
    print(f"Epoch {epoch}: train {tr_loss:.3f}/{tr_acc:.3f} | val {vl_loss:.3f}/{vl_acc:.3f}")

    # Save best
    if vl_loss < best_val_loss:
        best_val_loss = vl_loss
        torch.save(model.state_dict(), ckpt_path)

Epoch 1 [5/889]  Loss: 4.3709  Acc: 0.0312
Epoch 1 [10/889]  Loss: 4.3681  Acc: 0.0187
Epoch 1 [15/889]  Loss: 4.3612  Acc: 0.0479
Epoch 1 [20/889]  Loss: 4.3545  Acc: 0.0766
Epoch 1 [25/889]  Loss: 4.3507  Acc: 0.0875
Epoch 1 [30/889]  Loss: 4.3437  Acc: 0.1042
Epoch 1 [35/889]  Loss: 4.3378  Acc: 0.1179
Epoch 1 [40/889]  Loss: 4.3305  Acc: 0.1375
Epoch 1 [45/889]  Loss: 4.3239  Acc: 0.1569
Epoch 1 [50/889]  Loss: 4.3155  Acc: 0.1831
Epoch 1 [55/889]  Loss: 4.3088  Acc: 0.2028
Epoch 1 [60/889]  Loss: 4.2989  Acc: 0.2286
Epoch 1 [65/889]  Loss: 4.2883  Acc: 0.2505
Epoch 1 [70/889]  Loss: 4.2781  Acc: 0.2665


KeyboardInterrupt: 

With hidden layers in 2 min 

Epoch 1 [5/889]  Loss: 4.3709  Acc: 0.0312
Epoch 1 [10/889]  Loss: 4.3681  Acc: 0.0187
Epoch 1 [15/889]  Loss: 4.3612  Acc: 0.0479
Epoch 1 [20/889]  Loss: 4.3545  Acc: 0.0766
Epoch 1 [25/889]  Loss: 4.3507  Acc: 0.0875
Epoch 1 [30/889]  Loss: 4.3437  Acc: 0.1042
Epoch 1 [35/889]  Loss: 4.3378  Acc: 0.1179
Epoch 1 [40/889]  Loss: 4.3305  Acc: 0.1375
Epoch 1 [45/889]  Loss: 4.3239  Acc: 0.1569
Epoch 1 [50/889]  Loss: 4.3155  Acc: 0.1831
Epoch 1 [55/889]  Loss: 4.3088  Acc: 0.2028
Epoch 1 [60/889]  Loss: 4.2989  Acc: 0.2286
Epoch 1 [65/889]  Loss: 4.2883  Acc: 0.2505
Epoch 1 [70/889]  Loss: 4.2781  Acc: 0.2665

No hidden layer in 15 min (started from 0 actually)

Epoch 1 [5/889]  Loss: 4.2254  Acc: 0.3000
Epoch 1 [10/889]  Loss: 4.2103  Acc: 0.3375
Epoch 1 [15/889]  Loss: 4.2114  Acc: 0.3229
Epoch 1 [20/889]  Loss: 4.2089  Acc: 0.3328
Epoch 1 [25/889]  Loss: 4.2062  Acc: 0.3337
Epoch 1 [30/889]  Loss: 4.2012  Acc: 0.3458
Epoch 1 [35/889]  Loss: 4.1964  Acc: 0.3696
Epoch 1 [40/889]  Loss: 4.1932  Acc: 0.3742
Epoch 1 [45/889]  Loss: 4.1904  Acc: 0.3792
Epoch 1 [50/889]  Loss: 4.1871  Acc: 0.3812
Epoch 1 [55/889]  Loss: 4.1824  Acc: 0.3937
Epoch 1 [60/889]  Loss: 4.1785  Acc: 0.4005
Epoch 1 [65/889]  Loss: 4.1748  Acc: 0.4101
Epoch 1 [70/889]  Loss: 4.1718  Acc: 0.4170
Epoch 1 [75/889]  Loss: 4.1685  Acc: 0.4217
Epoch 1 [80/889]  Loss: 4.1657  Acc: 0.4242
Epoch 1 [85/889]  Loss: 4.1624  Acc: 0.4290
Epoch 1 [90/889]  Loss: 4.1595  Acc: 0.4323
Epoch 1 [95/889]  Loss: 4.1570  Acc: 0.4352
Epoch 1 [100/889]  Loss: 4.1532  Acc: 0.4409
Epoch 1 [105/889]  Loss: 4.1505  Acc: 0.4458
Epoch 1 [110/889]  Loss: 4.1477  Acc: 0.4472
Epoch 1 [115/889]  Loss: 4.1455  Acc: 0.4473
Epoch 1 [120/889]  Loss: 4.1422  Acc: 0.4523
Epoch 1 [125/889]  Loss: 4.1387  Acc: 0.4585
Epoch 1 [130/889]  Loss: 4.1355  Acc: 0.4620
Epoch 1 [135/889]  Loss: 4.1324  Acc: 0.4644
Epoch 1 [140/889]  Loss: 4.1296  Acc: 0.4656
Epoch 1 [145/889]  Loss: 4.1267  Acc: 0.4690
Epoch 1 [150/889]  Loss: 4.1238  Acc: 0.4700
Epoch 1 [155/889]  Loss: 4.1206  Acc: 0.4736
Epoch 1 [160/889]  Loss: 4.1174  Acc: 0.4770
Epoch 1 [165/889]  Loss: 4.1144  Acc: 0.4797
Epoch 1 [170/889]  Loss: 4.1113  Acc: 0.4825
Epoch 1 [175/889]  Loss: 4.1082  Acc: 0.4855
Epoch 1 [180/889]  Loss: 4.1053  Acc: 0.4884
Epoch 1 [185/889]  Loss: 4.1025  Acc: 0.4900
Epoch 1 [190/889]  Loss: 4.0994  Acc: 0.4914
Epoch 1 [195/889]  Loss: 4.0966  Acc: 0.4933
Epoch 1 [200/889]  Loss: 4.0933  Acc: 0.4980
Epoch 1 [205/889]  Loss: 4.0902  Acc: 0.5012
Epoch 1 [210/889]  Loss: 4.0867  Acc: 0.5045
Epoch 1 [215/889]  Loss: 4.0839  Acc: 0.5061
Epoch 1 [220/889]  Loss: 4.0807  Acc: 0.5078
Epoch 1 [225/889]  Loss: 4.0776  Acc: 0.5108
Epoch 1 [230/889]  Loss: 4.0745  Acc: 0.5130
Epoch 1 [235/889]  Loss: 4.0717  Acc: 0.5160
Epoch 1 [240/889]  Loss: 4.0684  Acc: 0.5188
Epoch 1 [245/889]  Loss: 4.0654  Acc: 0.5213
Epoch 1 [250/889]  Loss: 4.0622  Acc: 0.5234
Epoch 1 [255/889]  Loss: 4.0599  Acc: 0.5241
Epoch 1 [260/889]  Loss: 4.0567  Acc: 0.5266
Epoch 1 [265/889]  Loss: 4.0542  Acc: 0.5270
Epoch 1 [270/889]  Loss: 4.0509  Acc: 0.5297
Epoch 1 [275/889]  Loss: 4.0480  Acc: 0.5307
Epoch 1 [280/889]  Loss: 4.0449  Acc: 0.5338
Epoch 1 [285/889]  Loss: 4.0418  Acc: 0.5359
Epoch 1 [290/889]  Loss: 4.0385  Acc: 0.5387
Epoch 1 [295/889]  Loss: 4.0354  Acc: 0.5405
Epoch 1 [300/889]  Loss: 4.0329  Acc: 0.5418
Epoch 1 [305/889]  Loss: 4.0296  Acc: 0.5440
Epoch 1 [310/889]  Loss: 4.0261  Acc: 0.5464
Epoch 1 [315/889]  Loss: 4.0232  Acc: 0.5484
Epoch 1 [320/889]  Loss: 4.0199  Acc: 0.5504
Epoch 1 [325/889]  Loss: 4.0175  Acc: 0.5516
Epoch 1 [330/889]  Loss: 4.0149  Acc: 0.5532
Epoch 1 [335/889]  Loss: 4.0119  Acc: 0.5552
Epoch 1 [340/889]  Loss: 4.0089  Acc: 0.5567
Epoch 1 [345/889]  Loss: 4.0059  Acc: 0.5587
