## Code based on the papers listed in the project references

### Imports

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models import ResNet50_Weights
from PIL import Image
import numpy as np
from tqdm import tqdm
import time
from dataset import CityscapesFineDataset, LostAndFoundDataset, COLORS, PALETTE2ID, CITYSCAPES_19_TO_7_MACRO, rgb_to_id
from torchmetrics.classification import MulticlassJaccardIndex
import matplotlib.pyplot as plt

### Globals

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

# Training Parameters
batch_size = 8  # Images per optimization step
model_output_classes = 8  # 7 known macro classes (IDs 0-6) + 1 unknown-obstacle class (ID 7)
epochs = 1  # Epochs to run in this session (counted on top of a resumed checkpoint)
resize = (256, 512)  # (height, width) applied to every image and mask
checkpoint_path = 'deeplabv3_cityscapes_fine.pth'
best_miou = 0.0  # Best validation mIoU so far; a checkpoint is written only when it improves

# Optimizer Parameters (inspired by the "Road Obstacle Detection based on Unknown Objectness Scores" paper)
initial_lr = 0.01
momentum = 0.9
weight_decay = 0.0001
poly_power = 0.9  # Exponent of the "poly" learning-rate policy

# Conformal Prediction Parameters (from the "A Gentle Introduction to Conformal Prediction..." paper)
alpha = 0.1  # Target miscoverage: prediction sets aim for 1 - alpha coverage (0.1 -> 90%)

# Unknown Obstacle Detection Parameters
UNKNOWN_OBSTACLE_ID = 7  # Class ID of the unknown obstacle: one past the 7 known macro classes (0-6)
# Per-class RGB colours from dataset.py, indexed by class ID
# (assumes COLORS is a NumPy array, since .tolist() is called on it).
ALL_COLORS_FOR_VISUALIZATION = COLORS.tolist()
anomaly_threshold_uos = 0.01  # Pixels whose Unknown Objectness Score exceeds this are flagged as unknown obstacles

# Example image/label pair for the qualitative inference below.
# (For the Lost and Found dataset the rgb_to_id label mapping has to be adapted.)
example_image_path = 'datasets/realcityscapes/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png'
example_ground_truth_path = 'datasets/realcityscapes/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png'

### Utils

In [None]:
def poly_lr_scheduler(optimizer, initial_lr, current_iter, total_iter, power=0.9, min_lr=1e-8):
    """
    Apply the "poly" learning-rate policy to every parameter group of ``optimizer``.

        lr = initial_lr * (1 - current_iter / total_iter) ** power

    The progress ratio is clamped to [0, 1]. This keeps the base non-negative;
    a negative float raised to a fractional power would give a complex number.
    The result is floored at ``min_lr`` so training never runs with a zero LR.

    Args:
        optimizer: Object exposing ``param_groups`` (e.g. a ``torch.optim.Optimizer``).
        initial_lr (float): Learning rate at the start of the schedule.
        current_iter (int): Current iteration of the schedule.
        total_iter (int): Iteration at which the schedule ends. A non-positive
            value is treated as "schedule already finished" instead of dividing by zero.
        power (float): Exponent of the polynomial decay.
        min_lr (float): Lower bound on the learning rate.

    Returns:
        float: The learning rate written into every parameter group.
    """
    if total_iter > 0:
        # Clamp to [0, 1]: past the end of the schedule the LR stays at its floor.
        ratio = min(max(current_iter / total_iter, 0.0), 1.0)
    else:
        ratio = 1.0

    lr = max(float(initial_lr * (1.0 - ratio) ** power), min_lr)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


In [None]:
def decode_segmap(pred_mask, colors_list=ALL_COLORS_FOR_VISUALIZATION, ignore_id=255):
    """
    Turn a 2-D map of class IDs into an RGB picture.

    Args:
        pred_mask (np.array): (H, W) array of integer class IDs.
        colors_list (list): Colour for each class ID; entry ``k`` paints ID ``k``.
            Defaults to ALL_COLORS_FOR_VISUALIZATION.
        ignore_id (int): ID painted black (e.g. 255 for Cityscapes void pixels).

    Returns:
        np.array: (H, W, 3) uint8 image. IDs that are neither ``ignore_id`` nor a
        valid index into ``colors_list`` are painted blue and reported with a warning.
    """
    height, width = pred_mask.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    n_colors = len(colors_list)

    for cls in np.unique(pred_mask):
        region = pred_mask == cls
        if cls == ignore_id:
            colour = [0, 0, 0]  # Black for ignored pixels
        elif cls >= 0 and cls < n_colors:
            colour = colors_list[cls]
        else:
            # Unexpected ID (should not happen if UNKNOWN_OBSTACLE_ID has a colour in colors_list)
            print(f"Warning: Unexpected label ID {cls} found in prediction mask. Mapping to blue.")
            colour = [0, 0, 255]  # Blue for unmapped IDs
        rgb[region] = colour
    return rgb

In [None]:
def evaluate_model(model, dataloader, device, num_classes_for_metric, ignore_index=255, criterion=None):
    """
    Evaluate ``model`` on ``dataloader`` and return (average loss, mean IoU).

    The loss is computed on the full model output (all channels). The mIoU
    covers only the first ``num_classes_for_metric`` classes:
    - predictions are the argmax over those channels only;
    - target pixels whose ID lies outside ``[0, num_classes_for_metric)`` are
      treated as ``ignore_index``.

    Without this restriction, an extra output channel (e.g. the unknown-obstacle
    class, ID 7, when evaluating the 7 known classes) would produce out-of-range
    IDs. ``MulticlassJaccardIndex(validate_args=False)`` does not validate these,
    so they would corrupt or break the IoU computation.

    Args:
        model: Segmentation model whose forward returns a dict whose 'out' entry
            holds (B, C, H, W) logits.
        dataloader: Yields (images, masks) batches of integer class-ID masks.
        device: Device the batches are moved to.
        num_classes_for_metric (int): Number of leading classes included in the mIoU.
        ignore_index (int): Label ignored by the default loss and by the metric.
        criterion: Loss ``criterion(logits, masks)``. Defaults to
            ``nn.CrossEntropyLoss(ignore_index=ignore_index)``.

    Returns:
        tuple: (average per-batch loss as float, macro-averaged IoU as float).
        The loss is NaN when the dataloader yields no batches.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)

    metric = MulticlassJaccardIndex(
        num_classes=num_classes_for_metric,
        ignore_index=ignore_index,
        average='macro',
        validate_args=False,
    ).to(device)

    was_training = model.training
    model.eval()  # Evaluation mode for the forward passes
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for images, masks in tqdm(dataloader, desc="Evaluating"):
                images = images.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)

                outputs = model(images)['out']
                total_loss += criterion(outputs, masks).item()
                num_batches += 1

                # Keep predictions and targets inside [0, num_classes_for_metric).
                preds = outputs[:, :num_classes_for_metric].argmax(dim=1)
                out_of_range = (masks < 0) | (masks >= num_classes_for_metric)
                metric_targets = masks.masked_fill(out_of_range, ignore_index)
                metric.update(preds, metric_targets)
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    avg_loss = total_loss / num_batches if num_batches else float('nan')
    mean_iou = metric.compute()
    return avg_loss, mean_iou.item()

### Data

In [None]:
# Transformations for input images
transform = transforms.Compose([
    # ColorJitter is for images only, keep it here.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(), # Convert PIL Image to Tensor
    # Normalization for ImageNet pre-trained models (DeepLabV3 uses ResNet50 backbone)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
print("Loading datasets...")

# Validation (and the conformal calibration that later reuses val_loader) must be
# deterministic. Drop the random ColorJitter augmentation from `transform` for
# the validation split and keep only its deterministic steps.
eval_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, transforms.ColorJitter)]
)

# Using CityscapesFineDataset from dataset.py
dataset_root = 'datasets/realcityscapes'
train_ds = CityscapesFineDataset(root=dataset_root, split='train', transform=transform, resize=resize, unknown_obstacle_id=UNKNOWN_OBSTACLE_ID)
val_ds = CityscapesFineDataset(root=dataset_root, split='val', transform=eval_transform, resize=resize, unknown_obstacle_id=UNKNOWN_OBSTACLE_ID)

# Loader settings shared by both splits
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=4,  # Adjust to the number of available CPU cores
    pin_memory=True,  # Page-locked host memory for faster host-to-GPU copies
    persistent_workers=True,  # Keep worker processes alive between epochs
    prefetch_factor=4,  # Batches loaded in advance by each worker
)

train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)  # No need to shuffle validation data
print("Datasets loaded.")

### Network

In [None]:
# DeepLabV3 whose ResNet50 backbone starts from ImageNet weights (transfer
# learning usually converges faster and performs better), with
# model_output_classes output channels. The model is moved to the target
# device and switched to the channels_last memory layout, which tends to speed
# up convolutions on recent NVIDIA GPUs.
model = deeplabv3_resnet50(
    weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
    num_classes=model_output_classes,
)
model = model.to(device=device, memory_format=torch.channels_last)

### Train

In [None]:
# SGD with momentum and weight decay (hyper-parameters from the Globals cell)
optimizer = optim.SGD(
    model.parameters(),
    lr=initial_lr,
    momentum=momentum,
    weight_decay=weight_decay,
)
# Pixels labelled 255 (unlabeled / void) do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=255)
# Loss scaler for mixed-precision (autocast) training
scaler = torch.cuda.amp.GradScaler()

# Resume state; overwritten below when a checkpoint file exists
start_epoch = 0
current_iteration = 0  # Global iteration counter driving the poly LR schedule

if not os.path.exists(checkpoint_path):
    print("No checkpoint found. Starting training from zero.")
else:
    ck = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ck['model'])
    optimizer.load_state_dict(ck['optimizer'])
    scaler.load_state_dict(ck['scaler'])
    start_epoch = ck['epoch'] + 1
    # Older checkpoints may lack these keys; keep the current values then
    best_miou = ck.get('best_miou', best_miou)
    current_iteration = ck.get('current_iteration', current_iteration)
    print(f"Restarting from epoch {start_epoch}, best mIoU until now: {best_miou:.4f}")

In [None]:
# Training Loop
print("Starting training...")
# The poly schedule must cover every epoch up to the last one of this session.
# On resume, current_iteration is restored from the checkpoint and already counts
# the start_epoch finished epochs. Sizing the schedule with `epochs` alone would
# make current_iteration >= total_iterations at once and pin the LR at its floor.
total_iterations = (start_epoch + epochs) * len(train_loader)

patience = 5  # Epochs without mIoU improvement tolerated before early stopping
epochs_no_improve = 0  # Consecutive epochs without improvement

final_epoch = start_epoch + epochs
for epoch in range(start_epoch, final_epoch):
    start_time = time.time()

    model.train()
    running_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{final_epoch}")
    for step, (images, masks) in enumerate(loop, start=1):
        current_iteration += 1
        current_lr = poly_lr_scheduler(optimizer, initial_lr, current_iteration, total_iterations, power=poly_power)

        # channels_last matches the memory format the model was converted to
        images = images.to(device, memory_format=torch.channels_last, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad()
        # Mixed precision: forward pass under autocast, backward through the GradScaler
        with torch.cuda.amp.autocast():
            out = model(images)['out']
            loss = criterion(out, masks)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        # Average over the batches seen so far. Use `step` rather than tqdm's
        # `loop.n`, which is only refreshed when the bar redraws and can lag.
        loop.set_postfix(loss=running_loss / step, lr=f"{current_lr:.6f}")

    train_elapsed = time.time() - start_time
    avg_train_loss = running_loss / len(train_loader)
    print(f"\nEpoch {epoch+1}/{final_epoch} - Avg Loss: {avg_train_loss:.4f} - Time: {train_elapsed:.2f} sec")

    # Validation: mIoU over the known macro classes only (IDs below UNKNOWN_OBSTACLE_ID, i.e. 0-6)
    val_loss, val_miou = evaluate_model(model, val_loader, device, num_classes_for_metric=UNKNOWN_OBSTACLE_ID, ignore_index=255)
    print(f"Validation Loss: {val_loss:.4f}, Validation mIoU: {val_miou:.4f}")

    # Save a checkpoint only when the validation mIoU improves
    if val_miou > best_miou:
        best_miou = val_miou
        epochs_no_improve = 0  # Reset early stopping counter
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'epoch': epoch,
            'best_miou': best_miou,
            'current_iteration': current_iteration
        }, checkpoint_path)
        print(f"Model saved in {checkpoint_path} with improved mIoU: {best_miou:.4f}")
    else:
        epochs_no_improve += 1
        epoch_word = "epoch" if epochs_no_improve == 1 else "epochs"
        print(f"No improvement in mIoU for {epochs_no_improve} {epoch_word}. Best current: {best_miou:.4f}")
        # Early Stopping
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {patience} epochs without improvement.")
            break  # Exiting the training loop

print("\nTraining completed!")

### Evaluation (here combined with inference)

In [None]:
print("\nStarting Inference and Uncertainty Quantification")

# The remaining cells need a trained checkpoint, so fail loudly without one.
# A bare exit() would end a script with a *success* status (or shut down a
# notebook kernel) instead of reporting the error.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(
        f"Checkpoint model not found in {checkpoint_path}. Impossible to proceed with inference."
    )

# Same architecture as the trained model. The backbone is not initialised from
# ImageNet here: load_state_dict (strict by default) overwrites every parameter
# and buffer with the checkpoint, so downloading those weights would be wasted.
model_inf = deeplabv3_resnet50(weights_backbone=None, num_classes=model_output_classes).to(device)
checkpoint_inf = torch.load(checkpoint_path, map_location=device)
model_inf.load_state_dict(checkpoint_inf['model'])
print("Model loaded for inference.")

model_inf.eval()  # Set model to evaluation mode

In [None]:
# Conformal Prediction Calibration Step (split conformal, validation set used as calibration data).
# Every labelled validation pixel gets the score s = 1 - p(true class). q_hat is the
# finite-sample-corrected (1 - alpha) quantile of these scores.
print("\nConformal Prediction: Calibration phase for Inference")
score_chunks = []  # One 1-D array of pixel scores per batch

with torch.no_grad():
    for images, masks in tqdm(val_loader, desc="Calibration"):  # val_loader doubles as the calibration set
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)  # Ground-truth IDs (B, H, W)

        outputs = model_inf(images)['out']  # Logits (B, C, H, W)
        # Match the mask resolution (normally a no-op, inputs are already resized)
        outputs = torch.nn.functional.interpolate(outputs, size=masks.shape[1:], mode='bilinear', align_corners=False)

        # Per-class sigmoid scores. Coverage holds for any score function as long
        # as the same one is used at test time (the inference cell also uses sigmoid).
        probabilities = torch.sigmoid(outputs)  # (B, C, H, W)

        valid_pixels_mask = masks != 255  # Void pixels have no ground truth to cover
        # torch.gather needs in-range indices: point void pixels at class 0.
        # Their scores are discarded via valid_pixels_mask below.
        safe_masks = masks.masked_fill(~valid_pixels_mask, 0)
        true_class_probs = torch.gather(probabilities, 1, safe_masks.unsqueeze(1)).squeeze(1)  # (B, H, W)

        pixel_scores = (1 - true_class_probs)[valid_pixels_mask]
        # Keep whole arrays. Extending a Python list element by element would create
        # one Python object per pixel (tens of millions over the validation set).
        score_chunks.append(pixel_scores.cpu().numpy())

calibration_scores = np.sort(np.concatenate(score_chunks)) if score_chunks else np.empty(0, dtype=np.float32)
n_calib = len(calibration_scores)

# q_hat = ceil((n+1)(1-alpha))/n empirical quantile (Angelopoulos & Bates),
# i.e. the k-th smallest score with k = ceil((n+1)(1-alpha)) (0-based index k-1).
k = max(int(np.ceil((n_calib + 1) * (1 - alpha))), 1)
if k <= n_calib:
    q_hat = calibration_scores[k - 1]
else:
    # Too few calibration points for this alpha: the valid q_hat is +infinity.
    # Scores lie in [0, 1], so 1.0 already admits every class into the set.
    q_hat = 1.0
print(f"Calibration completed. Number of valid pixels for calibration: {n_calib}, q_hat: {q_hat:.4f}")


In [None]:
# Perform Inference on an Example Image
print("\nInference Execution on Sample Image")

# Load the example image and its ground-truth label map (labelTrainIds)
image = Image.open(example_image_path).convert('RGB')
ground_truth_mask_ids = Image.open(example_ground_truth_path)

# Resize to the training resolution. PIL takes (width, height) while `resize` is
# (height, width). Labels use nearest-neighbour so no new IDs appear at edges.
image_resized = image.resize((resize[1], resize[0]), Image.BILINEAR)
ground_truth_mask_ids_resized = ground_truth_mask_ids.resize((resize[1], resize[0]), Image.NEAREST)

# Map the Cityscapes train IDs to the macro classes with a 256-entry lookup table
# (mirroring the mapping in CityscapesFineDataset). IDs without a macro class
# become UNKNOWN_OBSTACLE_ID; the void label 255 stays 255.
cityscapes_macro_mapping_array = np.full(256, UNKNOWN_OBSTACLE_ID, dtype=np.uint8)
for orig_id, macro_id in CITYSCAPES_19_TO_7_MACRO.items():
    if 0 <= orig_id < 255:
        cityscapes_macro_mapping_array[orig_id] = macro_id
cityscapes_macro_mapping_array[255] = 255  # Ensure ignore_index remains 255
mapped_ground_truth_mask_id = cityscapes_macro_mapping_array[np.array(ground_truth_mask_ids_resized)]

# Deterministic preprocessing: `transform` minus its random ColorJitter
# augmentation, so the prediction and the uncertainty maps are reproducible.
inference_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, transforms.ColorJitter)]
)
input_tensor = inference_transform(image_resized).unsqueeze(0).to(device)

with torch.no_grad():
    output_logits = model_inf(input_tensor)['out']
    # Ensure output matches visualization size
    output_logits = torch.nn.functional.interpolate(output_logits, size=resize, mode='bilinear', align_corners=False)

    # Per-class sigmoid scores, the same score used during calibration (so q_hat applies)
    probabilities = torch.sigmoid(output_logits).squeeze(0)  # Shape: (C, H, W)
    # Channels of the known macro classes (0-6). With model_output_classes = 8,
    # the remaining last channel is UNKNOWN_OBSTACLE_ID (7).
    known_probabilities = probabilities[:UNKNOWN_OBSTACLE_ID]

    # 1. Predicted Segmentation (Known Classes ONLY for this visualization)
    # Most probable known class for each pixel.
    predicted_known_only_argmax = torch.argmax(known_probabilities, dim=0).cpu().numpy()

    # Pixels whose overall argmax (over all channels 0-7) is UNKNOWN_OBSTACLE_ID
    overall_prediction = torch.argmax(probabilities, dim=0)
    is_overall_unknown = (overall_prediction == UNKNOWN_OBSTACLE_ID).cpu().numpy()

    # Mark those pixels with the ignore ID 255 so they are drawn black in the
    # known-classes plot, i.e. shown as "not classified as a known class".
    prediction_for_known_viz = predicted_known_only_argmax.copy()
    prediction_for_known_viz[is_overall_unknown] = 255

    # Unknown Objectness Score (UOS) Approximation (Noguchi et al.):
    #   UOS = p_O * prod_k (1 - p_k), with k running over the known classes and
    #   p_O the summed score of the "object" macro classes.
    # Object macro classes per CITYSCAPES_19_TO_7_MACRO:
    #   2 = Human (person, rider), 3 = Vehicle Group, 5 = Objects (pole, traffic light, traffic sign)
    object_class_ids = [2, 3, 5]
    # Keep only valid known-class IDs: UOS is computed over *known* objects.
    object_class_ids = [idx for idx in object_class_ids if idx is not None and 0 <= idx < UNKNOWN_OBSTACLE_ID]

    # p_O (objectness): sum of the object-class scores
    if object_class_ids:
        objectness_score = torch.sum(probabilities[object_class_ids, :, :], dim=0).cpu().numpy()
    else:
        objectness_score = np.zeros(probabilities.shape[1:], dtype=np.float32)
        print("Warning: No object class IDs defined for UOS calculation. Objectness score will be zero.")

    # Product of (1 - p_k) over the known classes; the unknown channel is excluded
    product_term = torch.prod(1 - known_probabilities, dim=0).cpu().numpy()

    unknown_objectness_score = objectness_score * product_term

    # 2. Combined Prediction (Known + Unknown Obstacle): overall argmax (0-7),
    #    overridden with UNKNOWN_OBSTACLE_ID wherever the UOS exceeds the threshold.
    combined_prediction = overall_prediction.cpu().numpy().copy()
    is_anomaly_pixel = unknown_objectness_score > anomaly_threshold_uos
    combined_prediction[is_anomaly_pixel] = UNKNOWN_OBSTACLE_ID

    # Conformal Prediction Sets (Varisco Heatmap - Mossina et al.):
    #   C(x) = {y' | P(y'|x) >= 1 - q_hat}, restricted to the known classes.
    # The heatmap holds the prediction-set size for each pixel.
    p_threshold_conformal = 1 - q_hat
    varisco_heatmap = torch.sum(known_probabilities >= p_threshold_conformal, dim=0).cpu().numpy()

In [None]:
# Enhanced Visualization: a 2x3 grid with the input, the known-class prediction,
# the ground truth, the UOS anomaly map, the combined prediction and the
# conformal prediction-set size.
def _show_panel(position, title, data, cmap=None, colorbar_label=None):
    """Draw one panel of the 2x3 result grid, with an optional labelled colour bar."""
    plt.subplot(2, 3, position)
    plt.title(title)
    plt.imshow(data, cmap=cmap)
    if colorbar_label is not None:
        plt.colorbar(label=colorbar_label)
    plt.axis('off')


plt.figure(figsize=(20, 12))

known_colors = ALL_COLORS_FOR_VISUALIZATION[:UNKNOWN_OBSTACLE_ID]  # Colours for IDs 0-6 only

_show_panel(1, "Original Image", image_resized)

# Pixels predicted as unknown carry ID 255 and are drawn black
_show_panel(
    2,
    "Predicted Segmentation (Known Classes Only)",
    decode_segmap(prediction_for_known_viz, colors_list=known_colors),
)

_show_panel(
    3,
    "Ground Truth Segmentation",
    decode_segmap(mapped_ground_truth_mask_id, colors_list=known_colors),
)

# Unknown Objectness Score map (Noguchi et al.)
_show_panel(4, "Anomaly Map (Unknown Objectness Score)", unknown_objectness_score,
            cmap='magma', colorbar_label='UOS')

# Full palette, including the colour for UNKNOWN_OBSTACLE_ID
_show_panel(
    5,
    f"Combined Prediction (Known + Unknown Obstacle ID {UNKNOWN_OBSTACLE_ID})",
    decode_segmap(combined_prediction, colors_list=ALL_COLORS_FOR_VISUALIZATION),
)

# Conformal prediction-set size (Varisco heatmap, Mossina et al.)
_show_panel(6, f"Conformal Prediction Set Size (α={alpha})", varisco_heatmap,
            cmap='viridis', colorbar_label='Number of classes in the set')

plt.tight_layout()
plt.show()