In [None]:
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import (
    top_k_accuracy_score,
    classification_report,
    confusion_matrix
)
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

In [2]:
# Use the Apple Metal (MPS) backend when this PyTorch build can reach it; otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Using device: mps


## Test data

In [None]:
# Class names for the classifier. A country's position in this list is the integer
# label CountryImageDataset assigns to images in the matching sub-directory.
# NOTE(review): presumably this order must match the one used at training time — confirm.
COUNTRIES = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bhutan", "Bolivia", "Botswana", "Brazil", "Bulgaria",
    "Cambodia", "Canada", "Chile", "Colombia", "Croatia", "Czechia",
    "Denmark", "Dominican Republic",
    "Ecuador", "Estonia", "Eswatini",
    "Finland", "France",
    "Germany", "Ghana", "Greece", "Greenland", "Guatemala",
    "Hungary",
    "Iceland", "Indonesia", "Ireland", "Israel", "Italy",
    "Japan", "Jordan",
    "Kenya", "Kyrgyzstan",
    "Latvia", "Lesotho", "Lithuania", "Luxembourg",
    "Malaysia", "Mexico", "Mongolia", "Montenegro",
    "Netherlands", "New Zealand", "Nigeria", "North Macedonia", "Norway",
    "Palestine", "Peru", "Philippines", "Poland", "Portugal",
    "Romania", "Russia",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia", "South Africa",
    "South Korea", "Spain", "Sri Lanka", "Sweden", "Switzerland",
    "Taiwan", "Thailand", "Turkey",
    "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay",
]
num_classes = len(COUNTRIES)

In [None]:
class CountryImageDataset(Dataset):
    """Image dataset laid out as ``root_dir/<country>/<image file>``.

    Every entry of ``COUNTRIES`` must exist as a sub-directory of ``root_dir``.
    An image's label is the index of its country in ``COUNTRIES``. Only files
    with a ``.jpg`` / ``.jpeg`` / ``.png`` suffix (case-insensitive) are kept.

    Args:
        root_dir: Dataset root, given as a ``str`` or ``pathlib.Path``.
        transform: Optional callable applied to each RGB ``PIL.Image``.
    """

    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        # Accept plain strings too: ``str / str`` would raise TypeError below.
        root_dir = Path(root_dir)
        self.samples = []
        self.transform = transform
        for idx, country in enumerate(COUNTRIES):
            country_dir = root_dir / country
            # iterdir() order is filesystem-dependent; sort so that sample
            # indices (and any order-dependent output) are reproducible.
            for img_file in sorted(country_dir.iterdir()):
                if img_file.suffix.lower() in self.IMAGE_SUFFIXES:
                    self.samples.append((img_file, idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        """Return ``(image, label)`` for sample ``i``; the image is transformed if a transform is set."""
        path, label = self.samples[i]
        # The context manager closes the source file as soon as the RGB copy exists.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

In [None]:
# Evaluation preprocessing:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. standardize each channel with the ImageNet mean / std.
# NOTE(review): presumably mirrors the training-time pipeline — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
def get_dataloaders(root_dir, batch_size=32, shuffle=False):
    """Build a ``DataLoader`` over ``CountryImageDataset(root_dir, transform)``.

    Uses the module-level ``transform`` for preprocessing.

    Args:
        root_dir: Dataset root (``str`` or ``pathlib.Path``).
        batch_size: Samples per batch.
        shuffle: Reshuffle samples every epoch. Defaults to ``False`` because this
            notebook builds its test loader with it; a fixed order keeps
            per-sample outputs reproducible. Pass ``True`` for training loaders.

    Returns:
        A single ``DataLoader`` (despite the plural function name).
    """
    dataset = CountryImageDataset(Path(root_dir), transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
# Root of the held-out test split, laid out as <root>/<country>/<image>.
# The location is read from the environment rather than hard-coded, and the
# cell fails with an actionable message when it has not been configured.
TEST_DATA_DIR = os.environ.get("COUNTRY_TEST_DIR", "").strip()
if not TEST_DATA_DIR:
    raise ValueError(
        "Set the COUNTRY_TEST_DIR environment variable to the test-split root "
        "(one sub-directory per country)."
    )
test_loader = get_dataloaders(Path(TEST_DATA_DIR), batch_size=32)

## Load model

In [None]:
def load_model(model_path, num_classes=79):
    """Rebuild the ResNet-50 classifier and load fine-tuned weights from disk.

    Args:
        model_path: Path to a ``state_dict`` saved with ``torch.save``.
        num_classes: Output size of the replacement ``fc`` head. It must match
            the checkpoint; 79 equals ``len(COUNTRIES)``.

    Returns:
        The model moved to the module-level ``device`` and switched to eval mode.
    """
    # No pretrained download: strict load_state_dict below overwrites every
    # parameter and buffer, so fetching ImageNet weights would be wasted work.
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # weights_only=True limits unpickling to tensors / primitive containers, so a
    # tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model

## Evaluation

In [None]:
# Turns per-sample logits into class probabilities (classes on dim 1).
softmax = torch.nn.Softmax(dim=1)

def evaluate_model(model, data_loader, criterion, device):
    """Run ``model`` over ``data_loader`` without gradients and collect metrics.

    Args:
        model: Classifier whose output holds one logit per class on dim 1.
        data_loader: Iterable of ``(images, labels)`` tensor batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, top1_acc, all_targets, all_preds, all_probs)``:
          - avg_loss: float, per-sample average loss
          - top1_acc: float, fraction of argmax predictions equal to the label
          - all_targets: np.ndarray, shape (N,)
          - all_preds:   np.ndarray, shape (N,)
          - all_probs:   np.ndarray, shape (N, num_classes)

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    all_preds, all_targets, all_probs = [], [], []

    with torch.no_grad():
        for imgs, labels in data_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            # Assumes criterion averages over the batch (the default reduction);
            # re-weighting by batch size keeps a short final batch from skewing
            # the per-sample mean.
            batch_size = imgs.size(0)
            total_loss += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += batch_size

            # Keep per-batch arrays for the detailed metrics computed later.
            all_probs.append(softmax(outputs).cpu().numpy())
            all_preds.append(preds.cpu().numpy())
            all_targets.append(labels.cpu().numpy())

    if total_samples == 0:
        raise ValueError("evaluate_model: data_loader produced no samples")

    avg_loss = total_loss / total_samples
    top1_acc = total_correct / total_samples

    return (
        avg_loss,
        top1_acc,
        np.concatenate(all_targets),
        np.concatenate(all_preds),
        np.vstack(all_probs),
    )

In [None]:
def print_metrics(all_targets, all_preds, all_probs, class_names):
    """Print Top-3 / Top-5 accuracy and a per-class classification report.

    The confusion matrix is not drawn here; see ``plot_confusion_matrix``.

    Args:
        all_targets: Integer true labels, shape (N,).
        all_preds: Integer predicted labels, shape (N,).
        all_probs: Class scores, shape (N, num_classes); column j scores class j.
        class_names: Display name for each class index.
    """
    # Pin the full label sets explicitly. Without them sklearn infers the classes
    # from y_true (and y_pred). If an evaluation set lacks some classes, that
    # inferred set disagrees with all_probs' columns and with class_names, and
    # both calls raise.
    score_labels = np.arange(all_probs.shape[1])
    report_labels = np.arange(len(class_names))

    top3 = top_k_accuracy_score(all_targets, all_probs, k=3, labels=score_labels)
    top5 = top_k_accuracy_score(all_targets, all_probs, k=5, labels=score_labels)
    print(f"Top-3 Accuracy: {top3:.4f}")
    print(f"Top-5 Accuracy: {top5:.4f}\n")

    report = classification_report(
        all_targets, all_preds,
        labels=report_labels,
        target_names=class_names,
        zero_division=0
    )
    print("Classification Report:\n")
    print(report)

In [None]:
def plot_confusion_matrix(all_targets, all_preds, class_names):
    """Plot the confusion matrix normalized over true labels (one row per true class).

    Args:
        all_targets: Integer true labels, shape (N,).
        all_preds: Integer predicted labels, shape (N,).
        class_names: Display name for each class index; sets the axis ticks.
    """
    # Explicit labels keep the matrix len(class_names) x len(class_names) even
    # when some classes never appear. Otherwise the matrix shrinks and no longer
    # lines up with the tick labels set below.
    labels = np.arange(len(class_names))
    cm = confusion_matrix(all_targets, all_preds, labels=labels, normalize='true')

    fig, ax = plt.subplots(figsize=(12, 12))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title("Normalized Confusion Matrix")
    fig.colorbar(im, ax=ax)
    ax.set_xticks(labels)
    ax.set_yticks(labels)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_yticklabels(class_names)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

In [None]:
def show_sample_predictions(all_targets, all_probs, class_names, n=5, rng=None):
    """Print up to ``n`` random examples of true label vs. top-3 predictions.

    Args:
        all_targets: Integer true labels, shape (N,).
        all_probs: Class probabilities, shape (N, num_classes).
        class_names: Display name for each class index.
        n: Number of examples to show. It is clamped to N, because sampling
            without replacement cannot draw more items than exist.
        rng: Optional ``numpy.random.Generator`` for reproducible sampling;
            when ``None`` the global ``np.random`` state is used.
    """
    total = len(all_targets)
    n = min(n, total)
    print(f"\nSample predictions ({n} examples):\n")
    choose = np.random.choice if rng is None else rng.choice
    for i in choose(total, size=n, replace=False):
        true_lbl = class_names[all_targets[i]]
        probs_i = all_probs[i]
        # Indices of the three highest-probability classes, best first.
        top3 = np.argsort(probs_i)[::-1][:3]
        top3_str = ", ".join(f"{class_names[k]} ({probs_i[k]:.2f})" for k in top3)
        print(f"True: {true_lbl:20s}  ↔  Pred Top-3: {top3_str}")