In [10]:
import sys, os
import torch
import torch.nn as nn
from pathlib import Path
import numpy as np

In [3]:
# Compute absolute path to the `src/` folder
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
SRC_PATH     = os.path.join(PROJECT_ROOT, "src")

if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)

from utils import get_dataloaders, load_model, evaluate_model, print_metrics, plot_confusion_matrix, show_sample_predictions

In [5]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: cuda


In [6]:
COUNTRIES = ["Albania","Andorra","Argentina","Australia","Austria","Bangladesh","Belgium","Bhutan","Bolivia","Botswana","Brazil","Bulgaria","Cambodia","Canada","Chile","Colombia","Croatia","Czechia","Denmark","Dominican Republic","Ecuador","Estonia","Eswatini","Finland","France","Germany","Ghana","Greece","Greenland","Guatemala","Hungary","Iceland","Indonesia","Ireland","Israel","Italy","Japan","Jordan","Kenya","Kyrgyzstan","Latvia","Lesotho","Lithuania","Luxembourg","Malaysia","Mexico","Mongolia","Montenegro","Netherlands","New Zealand","Nigeria","North Macedonia","Norway","Palestine","Peru","Philippines","Poland","Portugal","Romania","Russia","Senegal","Serbia","Singapore","Slovakia","Slovenia","South Africa","South Korea","Spain","Sri Lanka","Sweden","Switzerland","Taiwan","Thailand","Turkey","Ukraine","United Arab Emirates","United Kingdom","United States","Uruguay"]
num_classes = len(COUNTRIES)
project_root   = Path().resolve().parent

### Data

In [7]:
test_root  = project_root/ "datasets" / "segmented" / "road" / "final_datasets" / "test"
test_loader = get_dataloaders(test_root, batch_size=32)

### Load model

In [8]:
model = load_model(model_path=project_root / "models" / "resnet_finetuned_road" / "main.pth", device=device)

  model.load_state_dict(torch.load(model_path, map_location=device))


### Evaluation

In [9]:
criterion = nn.CrossEntropyLoss()

In [13]:
country_to_continent = {
    "Albania": "Europe",
    "Andorra": "Europe",
    "Argentina": "South America",
    "Australia": "Oceania",
    "Austria": "Europe",
    "Bangladesh": "Asia",
    "Belgium": "Europe",
    "Bhutan": "Asia",
    "Bolivia": "South America",
    "Botswana": "Africa",
    "Brazil": "South America",
    "Bulgaria": "Europe",
    "Cambodia": "Asia",
    "Canada": "North America",
    "Chile": "South America",
    "Colombia": "South America",
    "Croatia": "Europe",
    "Czechia": "Europe",
    "Denmark": "Europe",
    "Dominican Republic": "North America",
    "Ecuador": "South America",
    "Estonia": "Europe",
    "Eswatini": "Africa",
    "Finland": "Europe",
    "France": "Europe",
    "Germany": "Europe",
    "Ghana": "Africa",
    "Greece": "Europe",
    "Greenland": "North America",
    "Guatemala": "North America",
    "Hungary": "Europe",
    "Iceland": "Europe",
    "Indonesia": "Asia",
    "Ireland": "Europe",
    "Israel": "Asia",
    "Italy": "Europe",
    "Japan": "Asia",
    "Jordan": "Asia",
    "Kenya": "Africa",
    "Kyrgyzstan": "Asia",
    "Latvia": "Europe",
    "Lesotho": "Africa",
    "Lithuania": "Europe",
    "Luxembourg": "Europe",
    "Malaysia": "Asia",
    "Mexico": "North America",
    "Mongolia": "Asia",
    "Montenegro": "Europe",
    "Netherlands": "Europe",
    "New Zealand": "Oceania",
    "Nigeria": "Africa",
    "North Macedonia": "Europe",
    "Norway": "Europe",
    "Palestine": "Asia",
    "Peru": "South America",
    "Philippines": "Asia",
    "Poland": "Europe",
    "Portugal": "Europe",
    "Romania": "Europe",
    "Russia": "Europe",
    "Senegal": "Africa",
    "Serbia": "Europe",
    "Singapore": "Asia",
    "Slovakia": "Europe",
    "Slovenia": "Europe",
    "South Africa": "Africa",
    "South Korea": "Asia",
    "Spain": "Europe",
    "Sri Lanka": "Asia",
    "Sweden": "Europe",
    "Switzerland": "Europe",
    "Taiwan": "Asia",
    "Thailand": "Asia",
    "Turkey": "Asia",
    "Ukraine": "Europe",
    "United Arab Emirates": "Asia",
    "United Kingdom": "Europe",
    "United States": "North America",
    "Uruguay": "South America"
}


In [16]:
avg_loss, top1_acc, all_targets, all_preds, all_probs = evaluate_model(model, test_loader, criterion, device)

# Get top 5 predictions for each image
top5_indices = np.argsort(all_probs, axis=1)[:, -3:]  # Get indices of top 5 probabilities
top5_countries = [[COUNTRIES[idx] for idx in row] for row in top5_indices]
top5_continents = [[country_to_continent[country] for country in row] for row in top5_countries]

# Get most frequent continent from top 5 predictions
predicted_continents = []
for continents in top5_continents:
    # Count occurrences of each continent
    continent_counts = {}
    for continent in continents:
        continent_counts[continent] = continent_counts.get(continent, 0) + 1
    # Get continent with highest count
    most_frequent = max(continent_counts.items(), key=lambda x: x[1])[0]
    predicted_continents.append(most_frequent)

# Get true continents
true_countries = [COUNTRIES[idx] for idx in all_targets]
true_continents = [country_to_continent[country] for country in true_countries]

# Calculate continent accuracy
correct_continent = sum(1 for pred, true in zip(predicted_continents, true_continents) if pred == true)
continent_accuracy = correct_continent / len(true_continents)

print(f"\nContinent Classification Results (based on most frequent continent in top 5 predictions):")
print(f"Continent Accuracy: {continent_accuracy:.2%}")


Continent Classification Results (based on most frequent continent in top 5 predictions):
Continent Accuracy: 46.48%


In [None]:
print_metrics(all_targets, all_preds, all_probs, COUNTRIES)

Top-3 Accuracy: 0.4175
Top-5 Accuracy: 0.5063

Classification Report:

                      precision    recall  f1-score   support

             Albania       0.24      0.27      0.25        45
             Andorra       0.23      0.48      0.31        44
           Argentina       0.14      0.07      0.09        45
           Australia       0.12      0.27      0.17        44
             Austria       0.36      0.38      0.37        45
          Bangladesh       0.10      0.07      0.08        43
             Belgium       0.09      0.02      0.04        43
              Bhutan       0.31      0.26      0.29        42
             Bolivia       0.28      0.29      0.28        45
            Botswana       0.42      0.29      0.34        45
              Brazil       0.08      0.11      0.09        45
            Bulgaria       0.18      0.07      0.10        45
            Cambodia       0.36      0.22      0.27        41
              Canada       0.22      0.04      0.07        4