In [19]:
import os
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    top_k_accuracy_score,
)
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm
from transformers import (
    SegformerFeatureExtractor,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)

softmax = torch.nn.Softmax(dim=1)

In [72]:
# Make `<project root>/src` importable so the shared helpers in `utils` resolve.
# Assumes the kernel's working directory is one level below the project root.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
SRC_PATH = os.path.join(PROJECT_ROOT, "src")

# Prepend (not append) so the project's `utils` takes precedence on sys.path;
# the membership check keeps re-running this cell from stacking duplicates.
if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)

from utils import (
    get_dataloaders,
    load_model,
    evaluate_model,
    print_metrics,
    plot_confusion_matrix,
    show_sample_predictions,
    plot_random_image_with_label_and_prediction,
)

In [14]:
# Classifier preprocessing: fixed 224x224 resize, PIL -> tensor, then per-channel
# normalisation with the commonly used ImageNet mean/std (presumably matching
# the backbone's pretraining — confirm against the training pipeline).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [9]:
# Prefer an accelerator when one is present: CUDA first, then Apple-silicon MPS,
# falling back to CPU. (Previously CUDA was never considered, so NVIDIA machines
# silently ran on CPU.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: mps


In [10]:
# Class labels, indexed in the same order as the classifier's outputs: the
# prediction cell maps the argmax index straight into this list.
COUNTRIES = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bhutan", "Bolivia", "Botswana",
    "Brazil", "Bulgaria", "Cambodia", "Canada", "Chile",
    "Colombia", "Croatia", "Czechia", "Denmark", "Dominican Republic",
    "Ecuador", "Estonia", "Eswatini", "Finland", "France",
    "Germany", "Ghana", "Greece", "Greenland", "Guatemala",
    "Hungary", "Iceland", "Indonesia", "Ireland", "Israel",
    "Italy", "Japan", "Jordan", "Kenya", "Kyrgyzstan",
    "Latvia", "Lesotho", "Lithuania", "Luxembourg", "Malaysia",
    "Mexico", "Mongolia", "Montenegro", "Netherlands", "New Zealand",
    "Nigeria", "North Macedonia", "Norway", "Palestine", "Peru",
    "Philippines", "Poland", "Portugal", "Romania", "Russia",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia",
    "South Africa", "South Korea", "Spain", "Sri Lanka", "Sweden",
    "Switzerland", "Taiwan", "Thailand", "Turkey", "Ukraine",
    "United Arab Emirates", "United Kingdom", "United States", "Uruguay",
]
num_classes = len(COUNTRIES)

# Parent of the (symlink-resolved) working directory, as a Path.
project_root = Path.cwd().resolve().parent

### Data

In [11]:
# Test split of the medium processed dataset (one sub-folder per country, as
# the prediction cell below relies on).
test_root = project_root.joinpath("data", "processed_data", "medium_dataset", "test")
test_loader = get_dataloaders(test_root, batch_size=32)

In [12]:
# Display the loader built in the previous cell rather than constructing a
# second, identical (and discarded) DataLoader.
test_loader

<torch.utils.data.dataloader.DataLoader at 0x33ade0040>

### Load model

In [73]:
# Fine-tuned ResNet country classifier; `load_model` comes from src/utils and
# receives the target `device` for loading the weights.
model = load_model(
    model_path=project_root.joinpath("models", "resnet_finetuned", "main.pth"),
    device=device,
)

  model.load_state_dict(torch.load(model_path, map_location=device))


### Base-model (fine-tuned ResNet) prediction probabilities for a random test image

In [None]:
def get_prob(true_lbl, probs_i, class_names, n=5):
    """Print the true label next to the ``n`` most probable predicted classes.

    Args:
        true_lbl: Ground-truth label (a string; it is left-aligned in a
            20-character field).
        probs_i: 1-D array-like of class probabilities for a single sample,
            indexed consistently with ``class_names``.
        class_names: Sequence of class names, one per probability.
        n: Number of top predictions to show. If ``n`` exceeds the number of
            classes, all classes are shown.

    Returns:
        None. The summary line is printed to stdout.
    """
    probs_i = np.asarray(probs_i)
    names = np.asarray(class_names)
    # argsort is ascending, so reverse it to list the highest probability first.
    topk = probs_i.argsort()[::-1][:n]
    topk_str = ", ".join(f"{names[k]} ({probs_i[k]:.2f})" for k in topk)
    print(f"True: {true_lbl:20s}  ↔  Pred Top-{len(topk)}: {topk_str}")

In [None]:
# Pick a random test image: a random country folder, then a random .jpg in it.
# Sorting makes the candidate lists independent of filesystem iteration order,
# so the draw is reproducible whenever `random` has been seeded.
all_countries = sorted(d for d in test_root.iterdir() if d.is_dir())
country = random.choice(all_countries).name
img_files = sorted((test_root / country).glob("*.jpg"))
img_path = random.choice(img_files)

# Open via a context manager so the file handle is released; convert() returns
# an independent in-memory copy that stays valid after the file is closed.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB")

# Preprocess and predict.
# NOTE: `model` must still be the ResNet classifier here — the segmentation
# section further down rebinds `model` to the SegFormer network.
input_tensor = transform(img).unsqueeze(0).to(device)  # add batch dimension
with torch.no_grad():
    outputs = model(input_tensor)
    pred_idx = outputs.argmax(dim=1).item()
    pred_label = COUNTRIES[pred_idx]
    probs = softmax(outputs).cpu().numpy()

# Show the top-3 predictions (matches the recorded output of this cell).
get_prob(country, np.squeeze(probs), COUNTRIES, n=3)

True: South Africa          ↔  Pred Top-3: United States (0.81), United Kingdom (0.08), New Zealand (0.06)


### Semantic segmentation with SegFormer (work in progress)

In [6]:
# Class index -> Cityscapes class name for the 19 classes. Presumably matches the
# label order of the SegFormer Cityscapes checkpoint — verify against
# `model.config.id2label`.
CITYSCAPES_ID2LABEL = dict(enumerate([
    "road", "sidewalk", "building", "wall", "fence",
    "pole", "traffic_light", "traffic_sign", "vegetation", "terrain",
    "sky", "person", "rider", "car", "truck",
    "bus", "train", "motorcycle", "bicycle",
]))

In [5]:
# Names of the Cityscapes classes of interest for segmentation (a subset of the
# 19 Cityscapes class names).
segments = {
    "bicycle", "building", "car", "fence",
    "person", "pole", "road", "sidewalk",
    "terrain", "traffic_light", "traffic_sign", "vegetation",
}

In [None]:
MODEL_NAME = "nvidia/segformer-b0-finetuned-cityscapes-768-768"

# `SegformerImageProcessor` is the non-deprecated replacement for
# `SegformerFeatureExtractor`: it is called the same way
# (`processor(images=..., return_tensors="pt")`). The variable name is kept so
# cells that use `feature_extractor` continue to work.
feature_extractor = SegformerImageProcessor.from_pretrained(MODEL_NAME)

# NOTE: this rebinds `model`, replacing the ResNet classifier loaded earlier.
# Re-run the "Load model" cell before re-running any classification cell.
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME).eval()



In [7]:
def segment_and_return(image_path, seg_model=None, processor=None):
    """Segment an image and return one masked copy per detected class of interest.

    Runs the segmentation model on the image and upsamples its logits to the
    input resolution. It then takes the per-pixel argmax. For every class in
    ``segments`` that appears in the prediction, it returns a copy of the image
    in which all pixels not belonging to that class are set to 0 (black).

    Args:
        image_path: Path to an image file readable by PIL.
        seg_model: Segmentation model to run. Defaults to the notebook-global
            ``model`` (the SegFormer network). Pass it explicitly so this
            function does not depend on whatever ``model`` is currently bound to.
        processor: Preprocessor producing the model inputs. Defaults to the
            notebook-global ``feature_extractor``.

    Returns:
        dict[str, PIL.Image.Image]: ``{class_name: masked_image}`` for each
        class in ``segments`` with at least one predicted pixel. The dict is
        empty if none of those classes is present.
    """
    if seg_model is None:
        seg_model = model
    if processor is None:
        processor = feature_extractor

    # 1) Load & preprocess (context manager releases the file handle).
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    image_np = np.array(image)
    inputs = processor(images=image, return_tensors="pt")
    # Keep inputs on the same device as the model's weights.
    model_device = next(seg_model.parameters()).device
    inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}

    # 2) Inference + upsample logits back to the original image resolution.
    with torch.no_grad():
        logits = seg_model(**inputs).logits  # (1, num_labels, h', w') at reduced resolution
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # PIL size is (W, H); interpolate expects (H, W)
            mode="bilinear",
            align_corners=False,
        )
        predicted = upsampled_logits.argmax(dim=1)[0].cpu().numpy()  # (H, W) class ids

    # 3) Build the result dict. `segments` is a set of class *names*, so map the
    #    predicted ids to names via CITYSCAPES_ID2LABEL and keep only those
    #    requested (calling `.items()` on the set raised AttributeError).
    results = {}
    for class_idx, class_name in CITYSCAPES_ID2LABEL.items():
        if class_name not in segments:
            continue
        mask = predicted == class_idx
        if mask.any():
            masked_np = image_np.copy()
            masked_np[~mask] = 0  # zero out every pixel outside this class
            results[class_name] = Image.fromarray(masked_np)

    return results