## Imports

In [1]:
import sys, os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
softmax = torch.nn.Softmax(dim=1)
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
    top_k_accuracy_score,
    classification_report,
    confusion_matrix
)
import random
from pathlib import Path
from tqdm import tqdm
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation, SwinForImageClassification, SwinConfig
from typing import List, Tuple
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Compute absolute path to the `src/` folder
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
SRC_PATH     = os.path.join(PROJECT_ROOT, "src")

if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)

from utils import get_dataloaders, load_model, evaluate_model, print_metrics, plot_confusion_matrix, show_sample_predictions, plot_random_image_with_label_and_prediction, load_vit_model

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

In [4]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: mps


In [5]:
COUNTRIES = ['Albania', 'Andorra', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bhutan', 'Bolivia', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'Colombia', 'Croatia', 'Czechia', 'Denmark', 'Dominican Republic', 'Ecuador', 'Estonia', 'Eswatini', 'Finland', 'France', 'Germany', 'Greece', 'Guatemala', 'Hungary', 'Iceland', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Malaysia', 'Mexico', 'Montenegro', 'Netherlands', 'New Zealand', 'North Macedonia', 'Norway', 'Palestine', 'Peru', 'Poland', 'Portugal', 'Romania', 'Russia', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sweden', 'Switzerland', 'Taiwan', 'Thailand', 'Turkey', 'United Arab Emirates', 'United Kingdom', 'United States']
num_classes = len(COUNTRIES)
project_root   = Path().resolve().parent

## Data

In [6]:
train_root = project_root/ "data" / "final_datasets" / "train"
train_loader = get_dataloaders(train_root, batch_size=32)

val_root = project_root/ "data" / "final_datasets" / "val"
val_loader = get_dataloaders(val_root, batch_size=32)

test_root = project_root/ "data" / "final_datasets" / "test"
test_loader = get_dataloaders(test_root, batch_size=32)

## Load models

In [7]:
base_model = load_model(model_path=project_root / "models" / "resnet_finetuned_new" / "main.pth", device=device)
road = load_model(model_path=project_root / "models" / "resnet_finetuned_road_new" / "main.pth", device=device)
terrain = load_model(model_path=project_root / "models" / "resnet_finetuned_terrain_new" / "main.pth", device=device)
vegetation = load_model(model_path=project_root / "models" / "resnet_finetuned_vegetation_new" / "main.pth", device=device)

  model.load_state_dict(torch.load(model_path, map_location=device))


In [8]:
base_vit = load_vit_model(model_path=project_root / "models" / "swin_b_finetuned" / "swin_b_finetuned", device=device)
vit_road = load_vit_model(model_path=project_root / "models" / "swin_b_finetuned" / "swin_b_finetuned_road", device=device)
vit_terrain = load_vit_model(model_path=project_root / "models" / "swin_b_finetuned" / "swin_b_finetuned_terrain", device=device)
vit_vegetation = load_vit_model(model_path=project_root / "models" / "swin_b_finetuned" / "swin_b_finetuned_vegetation", device=device)

In [9]:
MODEL_NAME = "nvidia/segformer-b0-finetuned-cityscapes-768-768"

feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_NAME)
seg_model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME).eval()



In [10]:
CITYSCAPES_LABELS = {
    0: 'road', 
    8: 'vegetation',  9: 'terrain'
}

TARGET_CLASSES = {'road','terrain','vegetation'}

## Ensemble

### Utils

In [11]:
# Utility to fetch softmax probs from a pretrained submodel
def get_probs(model, img_tensor, device):
    model.eval()
    with torch.no_grad():
        out = model(img_tensor.to(device))
        probs = nn.functional.softmax(out, dim=1).cpu().squeeze(0).numpy()
    return probs

In [12]:
def cache_submodel_outputs(base_ds, submodels, device, cache_dir):
    """
    Runs each submodel once over base_ds and writes:
      - feats.npy: shape (N, num_models * num_classes)
      - labels.npy: shape (N,)
    under cache_dir.
    """
    N = len(base_ds)
    num_classes = submodels['base'].fc.out_features
    num_models  = len(submodels)
    feats  = np.zeros((N, num_models * num_classes), dtype=np.float32)
    labels = np.zeros(N, dtype=np.int64)

    for i in range(N):
        print(f"Printing {i}/{N}")
        img, lbl = base_ds[i]                 # load and transform image
        labels[i] = lbl
        x = img.unsqueeze(0).to(device)

        vecs = []
        for name, m in submodels.items():
            m.eval()
            with torch.no_grad():
                out = m(x)
                p   = softmax(out).cpu().numpy().squeeze(0)
            vecs.append(p)
        feats[i] = np.concatenate(vecs)

    # Persist to disk once, not every epoch
    np.save(cache_dir / "feats.npy", feats)
    np.save(cache_dir / "labels.npy", labels)

In [13]:
class CachedEnsembleDataset(Dataset):
    def __init__(self, cache_dir):
        """
        Loads feats.npy and labels.npy once into memory.
        """
        self.feats  = np.load(cache_dir / "feats.npy")
        self.labels = np.load(cache_dir / "labels.npy")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        # returns (feature_vector, label) as torch tensors
        return torch.from_numpy(self.feats[i]), int(self.labels[i])

In [39]:
# Ensemble Network: one hidden layer
class EnsembleNet(nn.Module):
    def __init__(self, in_dim, hid_dim, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, num_classes)
        )

    def forward(self, x):
        return self.net(x)

In [15]:
# Ensemble Network: no hidden layer
class EnsembleNet0(nn.Module):
    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, num_classes)
        )

    def forward(self, x):
        return self.net(x)

In [16]:
def train_epoch(model, loader, loss_fn, optimizer, device, epoch, log_every=5):
    """
    Runs one epoch of training, printing updates every `log_every` batches.

    Args:
        model       (nn.Module):      the network to train
        loader      (DataLoader):     training data loader
        loss_fn     (callable):       loss function
        optimizer   (torch.optim.Optimizer)
        device      (torch.device)
        epoch       (int):            current epoch number (for prints)
        log_every   (int):            how many batches between prints

    Returns:
        avg_loss (float), avg_acc (float)
    """
    model.train()
    running_loss = 0.0
    running_correct = 0
    total_samples = 0

    for batch_idx, (imgs, labels) in enumerate(loader, start=1):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Metrics
        preds = outputs.argmax(dim=1)
        batch_correct = (preds == labels).sum().item()
        batch_size = imgs.size(0)

        running_loss    += loss.item() * batch_size
        running_correct += batch_correct
        total_samples   += batch_size


    avg_loss = running_loss / total_samples
    avg_acc  = running_correct / total_samples
    return avg_loss, avg_acc


In [17]:
def eval_epoch(model, loader, loss_fn, device):
    model.eval()
    total_loss = total_correct = 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss   = loss_fn(logits, y)
            preds  = logits.argmax(dim=1)
            total_correct += (preds==y).sum().item()
            total_loss    += loss.item() * X.size(0)
    return total_loss/len(loader.dataset), total_correct/len(loader.dataset)

In [18]:
def train_with_early_stopping(
    model, train_loader, dev_loader,
    loss_fn, optimizer, device,
    ckpt_path,
    max_epochs=100, patience=5
):
    best_val_loss = float('inf')
    patience_ctr  = 0

    for epoch in range(1, max_epochs+1):
        # 1) Train epoch
        tr_loss, tr_acc = train_epoch(
            model, train_loader, loss_fn, optimizer, device, epoch
        )

        # 2) Eval on dev
        vl_loss, vl_acc = eval_epoch(model, dev_loader, loss_fn, device)
        print(f"Epoch {epoch}: train {tr_loss:.3f}/{tr_acc:.3f} | "
              f"val   {vl_loss:.3f}/{vl_acc:.3f}")

        # 3) Check for improvement
        if vl_loss < best_val_loss:
            best_val_loss = vl_loss
            patience_ctr  = 0
            torch.save(model.state_dict(), ckpt_path)
            print("  ↳ New best val loss; checkpoint saved.")
        else:
            patience_ctr += 1
            print(f"  ↳ No improvement. Patience {patience_ctr}/{patience}.")
            if patience_ctr >= patience:
                print("Early stopping triggered.")
                break

    # 4) Load best model before returning
    model.load_state_dict(torch.load(ckpt_path))
    return model

### Build cache (to speed up training)

In [19]:
submodels = {
    'base':     base_model,
    'road':     road,
    'terrain':  terrain,
    'vegetation': vegetation,
    'base_vit':    base_vit,
    'vit_road':    vit_road,
    'vit_terrain': vit_terrain,
    'vit_vegetation': vit_vegetation
}

In [20]:
train_cache_dir = project_root / "cache" / "cache_train"
dev_cache_dir = project_root / "cache" / "cache_dev"
test_cache_dir = project_root / "cache" / "cache_test"

In [None]:
# Train split caching
train_cache_dir.mkdir(exist_ok=True)
cache_submodel_outputs(train_loader.dataset, submodels, device, train_cache_dir)

# Dev split caching
dev_cache_dir.mkdir(exist_ok=True)
cache_submodel_outputs(val_loader.dataset, submodels, device, dev_cache_dir)

# Test split caching
test_cache_dir.mkdir(exist_ok=True)
cache_submodel_outputs(test_loader.dataset, submodels, device, test_cache_dir)

### Training

In [21]:
train_cached_ds = CachedEnsembleDataset(train_cache_dir)
train_cached_loader = DataLoader(train_cached_ds, batch_size=32, shuffle=True, num_workers=0)

dev_cached_ds = CachedEnsembleDataset(dev_cache_dir)
dev_cached_loader = DataLoader(dev_cached_ds, batch_size=32, shuffle=False, num_workers=0)

test_cached_ds = CachedEnsembleDataset(test_cache_dir)
test_cached_loader = DataLoader(test_cached_ds, batch_size=32, shuffle=False, num_workers=0)

#### Linear model

In [36]:
# Instantiate linear_model
epochs=100
lr=1e-3

in_dim = num_classes * len(submodels)
linear_model = EnsembleNet0(in_dim, num_classes).to(device)
loss_fn = nn.CrossEntropyLoss()
opt     = optim.Adam(linear_model.parameters(), lr=lr)

best_val_loss = float('inf')
ckpt_path = project_root / "models" / "ensemble" / "linear.pth"

In [37]:
linear_model = train_with_early_stopping(
    linear_model,
    train_cached_loader,
    dev_cached_loader,
    loss_fn,
    opt,
    device,
    ckpt_path,
    max_epochs=100,
    patience=5
)

Epoch 1: train 3.528/0.649 | val   3.147/0.602
  ↳ New best val loss; checkpoint saved.
Epoch 2: train 2.461/0.823 | val   2.461/0.606
  ↳ New best val loss; checkpoint saved.
Epoch 3: train 1.749/0.843 | val   2.050/0.612
  ↳ New best val loss; checkpoint saved.
Epoch 4: train 1.312/0.852 | val   1.814/0.616
  ↳ New best val loss; checkpoint saved.
Epoch 5: train 1.047/0.860 | val   1.675/0.621
  ↳ New best val loss; checkpoint saved.
Epoch 6: train 0.879/0.865 | val   1.589/0.621
  ↳ New best val loss; checkpoint saved.
Epoch 7: train 0.765/0.869 | val   1.534/0.627
  ↳ New best val loss; checkpoint saved.
Epoch 8: train 0.685/0.873 | val   1.499/0.624
  ↳ New best val loss; checkpoint saved.
Epoch 9: train 0.626/0.877 | val   1.475/0.627
  ↳ New best val loss; checkpoint saved.
Epoch 10: train 0.580/0.879 | val   1.459/0.622
  ↳ New best val loss; checkpoint saved.
Epoch 11: train 0.544/0.883 | val   1.450/0.625
  ↳ New best val loss; checkpoint saved.
Epoch 12: train 0.515/0.885 | 

  model.load_state_dict(torch.load(ckpt_path))


In [38]:
linear_model.eval()
test_loss, test_acc = eval_epoch(linear_model, test_cached_loader, loss_fn, device)

print(f"Test   Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")


Test   Loss: 1.3891 | Test Acc: 0.6404


#### Non-Linear model

In [43]:
# Instantiate linear_model
hid_dim=128
epochs=100
lr=1e-3

in_dim = num_classes * len(submodels)
model = EnsembleNet(in_dim, hid_dim, num_classes).to(device)
loss_fn = nn.CrossEntropyLoss()
opt     = optim.Adam(model.parameters(), lr=lr)

best_val_loss = float('inf')
ckpt_path = project_root / "models" / "ensemble" / "main.pth"

In [44]:
model = train_with_early_stopping(
    model,
    train_cached_loader,
    dev_cached_loader,
    loss_fn,
    opt,
    device,
    ckpt_path,
    max_epochs=100,
    patience=5
)

Epoch 1: train 1.519/0.782 | val   1.526/0.612
  ↳ New best val loss; checkpoint saved.
Epoch 2: train 0.481/0.881 | val   1.579/0.616
  ↳ No improvement. Patience 1/5.
Epoch 3: train 0.427/0.889 | val   1.625/0.616
  ↳ No improvement. Patience 2/5.
Epoch 4: train 0.400/0.895 | val   1.653/0.615
  ↳ No improvement. Patience 3/5.
Epoch 5: train 0.381/0.898 | val   1.677/0.618
  ↳ No improvement. Patience 4/5.
Epoch 6: train 0.364/0.903 | val   1.714/0.613
  ↳ No improvement. Patience 5/5.
Early stopping triggered.


  model.load_state_dict(torch.load(ckpt_path))


In [45]:
model.eval()
test_loss, test_acc = eval_epoch(model, test_cached_loader, loss_fn, device)

print(f"Test   Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")

Test   Loss: 1.4552 | Test Acc: 0.6323


### Weights visualization

In [46]:
def inspect_ensemble_weights(model, submodel_names, COUNTRIES):

    W1 = model.net[0].weight.data.cpu().numpy()  
    num_classes = len(COUNTRIES)

    block_norms = []
    for i, name in enumerate(submodel_names):
        start = i * num_classes
        end   = (i+1) * num_classes
        block = W1[:, start:end]          
        norm  = np.linalg.norm(block)      
        block_norms.append((name, norm))

    print("First‐layer L2 norms per submodel block:")
    for name, norm in block_norms:
        print(f"  {name:12s}: {norm:.2f}")


In [47]:
submodel_names = [
    'res_base', 'res_road', 'res_terrain', 'res_veg',
    'vit_base','vit_road','vit_terrain','vit_vegetation'
]

inspect_ensemble_weights(linear_model, submodel_names, COUNTRIES)


First‐layer L2 norms per submodel block:
  res_base    : 164.62
  res_road    : 68.41
  res_terrain : 38.97
  res_veg     : 50.85
  vit_base    : 95.04
  vit_road    : 22.06
  vit_terrain : 16.11
  vit_vegetation: 45.79


In [48]:
inspect_ensemble_weights(model, submodel_names, COUNTRIES)


First‐layer L2 norms per submodel block:
  res_base    : 23.93
  res_road    : 9.78
  res_terrain : 4.71
  res_veg     : 8.25
  vit_base    : 12.00
  vit_road    : 8.90
  vit_terrain : 9.08
  vit_vegetation: 8.98


In [66]:
inspect_ensemble_weights(model, COUNTRIES)

First‐layer L2 norms per submodel block:
  base: 171.35
  road: 69.65
  terrain: 40.79
  veg: 51.57
