In [None]:
"""
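This notebook runs our best BDD100K-trained Mask R-CNN model on the Cityscapes dataset: it first visualizes predictions on a few images, then computes a class-agnostic mask mAP as a domain-gap check.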
"""

In [1]:
# ==============================================================================
# CELL 1: SETUP & PATHS
# ==============================================================================
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from google.colab import drive
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from glob import glob
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# 1. Mount Drive (skipped when already mounted, so the cell can be re-run)
if not os.path.exists('/content/drive'):
    print("Mounting Google Drive...")
    drive.mount('/content/drive')

# 2. Configuration
# Always a torch.device (never a bare "cpu" string) so DEVICE has a single type
# everywhere it is used (model.to, torch.load map_location, metric.to).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running Inference on: {DEVICE}")

# Paths (Adjust if your folder structure is different)
DRIVE_ROOT = "/content/drive/MyDrive/Latitude_AI_Team/data"
CITYSCAPES_ROOT = os.path.join(DRIVE_ROOT, "cityscapes")
# Cityscapes split to run on; the BDD100K model never saw Cityscapes, so any split
# measures the domain gap. Change to 'val' to use the validation split instead.
CITYSCAPES_SPLIT = 'train'
CITYSCAPES_IMG_DIR = os.path.join(CITYSCAPES_ROOT, 'leftImg8bit', CITYSCAPES_SPLIT)
CITYSCAPES_LABEL_DIR = os.path.join(CITYSCAPES_ROOT, 'gtFine', CITYSCAPES_SPLIT)

# Path to your BEST BDD100K Model
MODEL_PATH = "/content/drive/MyDrive/Latitude_AI_Team/models/best_model_checkpoint_full_run.pth"

# Verify paths (warn only, so later cells can still be inspected without the data)
if not os.path.exists(CITYSCAPES_IMG_DIR):
    print(f" Cityscapes Images not found at: {CITYSCAPES_IMG_DIR}")
else:
    print(" Cityscapes Images found.")

if not os.path.exists(MODEL_PATH):
    print(f" Model Checkpoint not found at: {MODEL_PATH}")
else:
    print("Model Checkpoint found.")

🔌 Mounting Google Drive...
Mounted at /content/drive
🚀 Running Inference on: cuda
 Cityscapes Images found.
Model Checkpoint found.


In [2]:
# DATASET & MODEL DEFINITIONS

# CITYSCAPES DATASET CLASS
class CityscapesDataset(Dataset):
    """Cityscapes frames paired with their gtFine instance-ID ground truth.

    Each item is ``(img, target)`` in the torchvision detection format consumed by
    Mask R-CNN: ``target`` holds ``boxes`` (N, 4) float32 xyxy pixel coordinates,
    ``labels`` (N,) int64, ``masks`` (N, H, W) uint8 and ``image_id`` (1,) int64.

    Args:
        img_dir: Folder searched recursively for ``*leftImg8bit.png`` images
            (Cityscapes has one sub-folder per city, e.g. 'aachen', 'bochum').
        label_dir: Folder searched recursively for ``*gtFine_instanceIds.png``.
        transforms: Optional callable applied to the PIL image.
        limit: If not None, keep only the first ``limit`` image/label pairs in
            sorted image-path order (``limit=0`` yields an empty dataset).
        label_mapping: Optional dict from Cityscapes class ID
            (``instance_id // 1000``) to model label. When given, instances whose
            class is not a key are dropped. When None (default) every instance
            gets the placeholder label 1.

    Images and labels are paired by the file-name prefix they share
    (e.g. 'aachen_000000_000019_') rather than by position in two separately
    sorted lists, so a missing label file cannot shift later images onto the
    wrong ground truth. Images without a matching label are skipped.
    """

    _IMG_SUFFIX = 'leftImg8bit.png'
    _LABEL_SUFFIX = 'gtFine_instanceIds.png'

    def __init__(self, img_dir, label_dir, transforms=None, limit=None, label_mapping=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.label_mapping = label_mapping

        # Recursive search for files (Cityscapes has subfolders like 'aachen', 'bochum')
        image_paths = sorted(glob(os.path.join(img_dir, '**', '*' + self._IMG_SUFFIX), recursive=True))
        label_paths = glob(os.path.join(label_dir, '**', '*' + self._LABEL_SUFFIX), recursive=True)
        label_by_frame = {self._frame_key(p, self._LABEL_SUFFIX): p for p in label_paths}

        self.image_paths = []
        self.label_paths = []
        for img_path in image_paths:
            label_path = label_by_frame.get(self._frame_key(img_path, self._IMG_SUFFIX))
            if label_path is None:
                continue  # no ground truth for this frame; skip rather than misalign
            self.image_paths.append(img_path)
            self.label_paths.append(label_path)

        skipped = len(image_paths) - len(self.image_paths)
        if skipped:
            print(f"CityscapesDataset: skipped {skipped} image(s) with no matching instanceIds label.")

        # `is not None` so that limit=0 means "no images" rather than "no limit".
        if limit is not None:
            self.image_paths = self.image_paths[:limit]
            self.label_paths = self.label_paths[:limit]

    @staticmethod
    def _frame_key(path, suffix):
        """Strip the file-type suffix, leaving the frame prefix shared by image and label."""
        return os.path.basename(path)[:-len(suffix)]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context managers close the underlying file handles promptly.
        with Image.open(self.image_paths[idx]) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(self.label_paths[idx]) as label_img:
            instance_ids_np = np.array(label_img, dtype=np.int32)

        # Cityscapes instanceIds store countable objects as class_id * 1000 + instance
        # index; values below 1000 are bare class IDs (stuff, or crowd/group regions
        # without per-instance annotation), so only IDs >= 1000 are individual objects.
        obj_ids = np.unique(instance_ids_np)
        obj_ids = obj_ids[obj_ids >= 1000]

        masks = []
        boxes = []
        labels = []

        for obj_id in obj_ids:
            if self.label_mapping is None:
                label = 1  # Placeholder label (only visual overlap is evaluated for now)
            else:
                class_id = int(obj_id) // 1000
                if class_id not in self.label_mapping:
                    continue  # class the model has no equivalent for
                label = self.label_mapping[class_id]

            mask_instance = (instance_ids_np == obj_id)
            if mask_instance.sum() < 10:
                continue  # filter noise: too few pixels to be a real object

            rows, cols = np.where(mask_instance)
            xmin, xmax = cols.min(), cols.max()
            ymin, ymax = rows.min(), rows.max()

            if (xmax - xmin) < 5 or (ymax - ymin) < 5:
                continue  # degenerate sliver

            boxes.append([xmin, ymin, xmax, ymax])
            masks.append(mask_instance)
            labels.append(label)

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            # Empty-but-well-shaped tensors so downstream code needs no special case.
            H, W = instance_ids_np.shape
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            masks = torch.zeros((0, H, W), dtype=torch.uint8)

        target = {"boxes": boxes, "labels": labels, "masks": masks, "image_id": torch.tensor([idx])}

        if self.transforms:
            img = self.transforms(img)

        return img, target

# --- 2. MODEL DEFINITION (Must match BDD Training) ---
# Category IDs the BDD100K model was trained on; the list (and its order) must stay
# in sync with the training run or the checkpoint's heads will not line up.
BDD_IDS = [6, 7, 11, 13, 12, 14, 15, 17, 18]
NUM_CLASSES = len(BDD_IDS) + 1  # one extra slot for the background class

def get_model(num_classes):
    """Build the Mask R-CNN (ResNet-50 FPN) architecture used for BDD100K training.

    The box and mask predictor heads are replaced by ones sized for
    ``num_classes`` (background included), so a checkpoint saved from the
    BDD100K run can be restored with ``load_state_dict``.

    Note: ``weights=None`` only skips the COCO detection weights. ``weights_backbone``
    is left at torchvision's default, which in current torchvision downloads
    ImageNet ResNet-50 weights (and selects frozen batch-norm in the backbone);
    those values are then overwritten by the loaded checkpoint.
    NOTE(review): presumably the default is kept so the backbone's norm layers match
    the training-time model — confirm before passing ``weights_backbone=None``.
    """
    detector = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=None)
    heads = detector.roi_heads

    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_dim = 256  # width of the mask head's hidden layer; must match training
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_dim, num_classes)

    return detector

# --- 3. LOADERS ---
def _collate_detection_batch(batch):
    """Transpose a list of (image, target) pairs into (images, targets) tuples.

    Detection targets vary in size per image, so the default stacking collate
    cannot be used.
    """
    return tuple(zip(*batch))


def get_cityscapes_dataloader(limit=5):
    """Return a deterministic (unshuffled), batch-size-1 loader over Cityscapes.

    Images are converted to tensors with ``transforms.ToTensor()``; ``limit`` is
    forwarded to ``CityscapesDataset`` to cap the number of frames.
    """
    dataset = CityscapesDataset(
        CITYSCAPES_IMG_DIR,
        CITYSCAPES_LABEL_DIR,
        transforms=transforms.Compose([transforms.ToTensor()]),
        limit=limit,
    )
    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=_collate_detection_batch,
    )

In [3]:


# 1. Load the Model with Trained Weights
print(f" Loading model from: {MODEL_PATH}")
model = get_model(NUM_CLASSES)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only point MODEL_PATH at checkpoints you trust.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
trained_weights = checkpoint['model_state_dict']  # the checkpoint dict wraps the state dict
model.load_state_dict(trained_weights)

# Move to the inference device and switch to eval mode (the detector then returns
# predictions instead of training losses).
model = model.to(DEVICE).eval()
print("Model loaded successfully.")

# 2. Visualization Function
# Model label index -> display name. Index 0 is the background class and has no entry.
LABEL_TO_NAME = dict(enumerate(
    [
        "Traffic Light", "Traffic Sign", "Person", "Car", "Rider",
        "Truck", "Bus", "Motorcycle", "Bicycle",
    ],
    start=1,
))

def visualize_inference(image, prediction, conf_threshold=0.5, mask_threshold=0.5,
                        title="Cityscapes Inference (BDD Model)"):
    """Show ``image`` with the detections that score at least ``conf_threshold``.

    Each kept detection is drawn as a red box, a "<class>: <score>" tag and a
    semi-transparent red mask overlay.

    Args:
        image: (C, H, W) image tensor, e.g. as produced by ``transforms.ToTensor()``.
        prediction: list whose first element is the model's output dict for
            ``image`` (keys 'boxes', 'labels', 'scores', 'masks'), i.e. the list
            returned by ``model([image])``.
        conf_threshold: detections scoring below this are not drawn.
        mask_threshold: soft-mask values below this are left transparent.
        title: axes title.
    """
    pred = prediction[0]
    # detach() so this also works on tensors produced outside torch.no_grad().
    boxes = pred['boxes'].detach().cpu().numpy()
    labels = pred['labels'].detach().cpu().numpy()
    scores = pred['scores'].detach().cpu().numpy()
    masks = pred['masks'].detach().cpu().numpy()

    fig, ax = plt.subplots(1, 1, figsize=(16, 8))
    ax.imshow(image.permute(1, 2, 0).detach().cpu().numpy())
    ax.set_title(title, fontsize=16)
    ax.axis('off')

    for i, score in enumerate(scores):
        if score < conf_threshold:
            continue

        # Draw Box (Red)
        x1, y1, x2, y2 = boxes[i]
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
        ax.add_patch(rect)

        # Draw Label; int() so the lookup uses a plain Python int key.
        label_id = int(labels[i])
        label_name = LABEL_TO_NAME.get(label_id, f"ID {label_id}")
        ax.text(x1, y1 - 5, f"{label_name}: {score:.2f}", color='red', fontsize=10,
                weight='bold', bbox=dict(facecolor='white', alpha=0.5))

        # Draw Mask (Red overlay); masks is indexed [i, 0], i.e. one channel per detection.
        mask = masks[i, 0]
        masked_image = np.ma.masked_where(mask < mask_threshold, mask)
        ax.imshow(masked_image, cmap='Reds', alpha=0.5)

    plt.show()
    # Release the figure: this is called once per image in a loop, and open figures
    # otherwise accumulate in pyplot's figure manager.
    plt.close(fig)

# 3. Run Loop on Cityscapes Data
VIS_LIMIT = 5              # number of Cityscapes images to visualize
VIS_CONF_THRESHOLD = 0.5   # hide detections scoring below this
test_loader = get_cityscapes_dataloader(limit=VIS_LIMIT)

print(" Starting Inference on Cityscapes Data...")
with torch.no_grad():
    for images, _targets in test_loader:  # ground truth is not needed for visualization
        images = [img.to(DEVICE) for img in images]

        # Run Inference
        predictions = model(images)

        # Visualize every image in the batch (not just the first), so this stays
        # correct if the loader's batch_size is ever raised above 1.
        for img, pred in zip(images, predictions):
            visualize_inference(img, [pred], conf_threshold=VIS_CONF_THRESHOLD)

Output hidden; open in https://colab.research.google.com to view.

In [5]:

# CELL 4: QUANTITATIVE EVALUATION ON CITYSCAPES (THE "DOMAIN GAP" CHECK)
!pip install torchmetrics


import torch
from torchmetrics.detection import MeanAveragePrecision
from tqdm import tqdm



CITYSCAPES_MAPPING = {
    # Cityscapes Train ID : BDD Model ID
    26: 4, # Car
    24: 3, # Person
    25: 5, # Rider
    27: 6, # Truck
    28: 7, # Bus
    32: 8, # Motorcycle
    33: 9, # Bicycle
    7: 1,  # Traffic Light (based on standard Cityscapes ID)
    8: 2   # Traffic Sign
}


print("ðŸš€ Starting FINAL Evaluation (Class Agnostic Mode)...")
# This ignores the specific label ID and just checks if BOXES and MASKS match.
metric = MeanAveragePrecision(box_format="xyxy", iou_type="segm", class_metrics=False).to(DEVICE)

model.eval()

with torch.no_grad():
    for i, (images, targets) in tqdm(enumerate(test_loader), total=len(test_loader)):
        images = [img.to(DEVICE) for img in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        outputs = model(images)

        processed_preds = []
        processed_targets = []

        # Set all labels to '1' for both Preds and Targets
        # This forces the metric to ignore class mismatch and just check detection quality
        for pred in outputs:
            if 'masks' in pred:
                pred['masks'] = (pred['masks'] > 0.5).squeeze(1).to(torch.uint8)
            # FORCE LABEL 1
            pred['labels'] = torch.ones_like(pred['labels'])
            processed_preds.append(pred)

        for tgt in targets:
            # FORCE LABEL 1
            tgt['labels'] = torch.ones_like(tgt['labels'])
            processed_targets.append(tgt)

        metric.update(processed_preds, processed_targets)

# 2. COMPUTE
print("\nComputing Class-Agnostic Score (Pure Detection Quality)...")
results = metric.compute()

print("\n" + "="*50)
print("CITYSCAPES ROBUSTNESS SCORE (Class Agnostic)")
print("="*50)
print(f"AP @[.50:.95]: {results['map'].item():.4f}")
print(f"AP @ .50:      {results['map_50'].item():.4f}")
print(f"AP @ .75:      {results['map_75'].item():.4f}")
print("="*50)

🚀 Starting FINAL Evaluation (Class Agnostic Mode)...


100%|██████████| 100/100 [01:55<00:00,  1.16s/it]



Computing Class-Agnostic Score (Pure Detection Quality)...

CITYSCAPES ROBUSTNESS SCORE (Class Agnostic)
AP @[.50:.95]: 0.1034
AP @ .50:      0.2281
AP @ .75:      0.0828
