In [None]:
!pip -q install timm captum torchvision torchmetrics scikit-learn opencv-python


In [None]:
import os
import random

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import EuroSAT as TVEuroSAT

from src.preprocess import get_transforms
from src.eurosat_dataset import EuroSAT
from src.model_loader import load_vit_base_model
from src.training import train_and_evaluate
from src.model_utils import calculate_f1_score, calculate_confusion_matrix, calculate_balanced_accuracy
from src.XAI import attribution_maps

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
# Kept as a plain string because later cells compare it against "cuda".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device


In [None]:
# Experiment configuration: every value a reader might want to tune lives here.
SEED = 42
BATCH_SIZE = 64
IMG_SIZE = 224            # passed to get_transforms; matches the "_224" in MODEL_NAME
LR = 3e-4                 # AdamW learning rate
EPOCHS = 20               # passed to train_and_evaluate as n_epochs
EARLY_STOP = 5            # passed to train_and_evaluate as early_stopping_threshold
MODEL_NAME = "vit_base_patch16_224"
SAVE_PATH = "best_eurosat_vit.pth"
NUM_CLASSES = 10

# Seed every RNG the notebook may touch. Python's `random` and NumPy were previously left
# unseeded even though sample selection for the XAI cell draws from `random`.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda":
    torch.cuda.manual_seed_all(SEED)


In [None]:
# Load EuroSAT through torchvision and carve a seeded train/test split.
root = "./data"
TRAIN_FRACTION = 0.8  # share of images used for training; the remainder is held out
train_transform, test_transform = get_transforms(IMG_SIZE)

# Load without a transform so the split is made once on raw samples; each side then gets
# its own transform through the project's EuroSAT wrapper.
full_ds = TVEuroSAT(root=root, download=True, transform=None)
n_total = len(full_ds)
n_train = int(TRAIN_FRACTION * n_total)
n_test = n_total - n_train
train_raw, test_raw = random_split(
    full_ds,
    [n_train, n_test],
    generator=torch.Generator().manual_seed(SEED),  # dedicated generator -> identical split every run
)

train_ds = EuroSAT(train_raw, transform=train_transform)
test_ds = EuroSAT(test_raw, transform=test_transform)

# Pinned host memory only speeds up host->GPU copies; on a CPU-only run it brings no benefit
# (and recent PyTorch versions warn about it), so enable it only when training on CUDA.
use_pinned_memory = device == "cuda"
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=use_pinned_memory
)
# NOTE(review): test_loader is also passed to train_and_evaluate in the training cell. If it is
# used there for early stopping / checkpoint selection, the "test" metrics are optimistically
# biased; a separate validation split carved from train_raw would avoid that.
test_loader = DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_pinned_memory
)

len(train_ds), len(test_ds)


In [None]:
# Build the ViT classifier with a NUM_CLASSES-way head through the project loader, which also
# returns a parameter count and the index <-> class-name mappings used for reporting later.
vit_loader_kwargs = {
    "model_name": MODEL_NAME,
    "pretrained": True,
    "device": device,
    "num_classes": NUM_CLASSES,
}
model, n_params, index_to_class, class_to_index = load_vit_base_model(**vit_loader_kwargs)

# Multi-class cross-entropy objective; AdamW over every tensor yielded by model.parameters().
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=LR)

n_params


In [None]:
# Train the model; the project helper receives the epoch budget, the early-stopping threshold
# and the checkpoint path, and returns (history, final_metrics, total_time).
# NOTE(review): test_loader is passed as the evaluation loader. If train_and_evaluate uses it to
# decide early stopping or which checkpoint to save, the held-out metrics below are not an
# unbiased estimate. Also confirm whether `model` in memory ends up holding the saved best
# weights or the last-epoch weights — later cells evaluate the in-memory model.
training_options = {
    "n_epochs": EPOCHS,
    "early_stopping_threshold": EARLY_STOP,
    "final_model_path": SAVE_PATH,
}
history, final_metrics, total_time = train_and_evaluate(
    model,
    train_loader,
    test_loader,
    criterion,
    optimizer,
    device,
    **training_options,
)

final_metrics, total_time


In [None]:
# Held-out evaluation of the in-memory model. Each helper receives the same
# (model, test_loader, device) triple; presumably each runs its own pass over test_loader.
f1, bal_acc, cm = (
    calculate_f1_score(model, test_loader, device),
    calculate_balanced_accuracy(model, test_loader, device),
    calculate_confusion_matrix(model, test_loader, device),
)

for caption, score in (("Weighted F1:", f1), ("Balanced Acc:", bal_acc)):
    print(caption, score)

cm


In [None]:
# Explain a single held-out prediction with the project's attribution_maps helper.
model.eval()

# Dedicated RNG seeded from SEED: the same test image is chosen on every Restart & Run All,
# independent of how much global `random` state earlier cells have consumed.
sample_rng = random.Random(SEED)
sample_idx = sample_rng.randrange(len(test_raw))

# Index the raw test Subset directly. It yields (PIL image, label) drawn from the test split
# only, without relying on the internal attributes of the project's EuroSAT wrapper.
pil_img, label_idx = test_raw[sample_idx]

# Same test_transform that test_ds applies, plus a leading batch dimension.
x = test_transform(pil_img).unsqueeze(0)

# NOTE(review): label_idx follows torchvision's EuroSAT class ordering (full_ds.classes);
# confirm index_to_class from load_vit_base_model uses the same ordering, otherwise the
# class names printed below will be mislabelled.
svg_path, pred_idx, pred_score = attribution_maps(
    input_image=x, model=model, true_label_idx=label_idx, device=device,
    index_to_class=index_to_class, output_dir="xai_outputs",
    alpha=0.6, threshold_percentile=99.0, blur_radius=5,
)
print("Saved:", svg_path)
print(f"Pred: {index_to_class[pred_idx]} ({pred_score:.2f}) | True: {index_to_class[label_idx]}")
