## IMPORTS

In [None]:
#import sklearn
import os
from PIL import Image
import torch
import torch.nn as nn
from tqdm import tqdm
import cv2
import numpy as np
import random
import shutil
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models.segmentation import deeplabv3_resnet50
#import segmentation_models_pytorch as smp
import torch.optim as optim


from utils import visualize_uos_with_conformal, nonconformity_score, fix_lostandfound, fix_cityscapes, CityscapesTrainEvalDataset, CityscapesTestDataset, visualize_one_hot_vertical, visualize_erosion_mask, visualize_dilation_mask, visualize_boundary_mask, LostAndFoundTrainEvalDataset, MultiLabelDeepLabV3, BoundaryAwareBCELoss, get_boundary_mask_batch, BoundaryAwareBCELossFineTuning, pixel_accuracy, mean_iou, dice_score, precision_recall, unknown_objectness_score, uos_heatmap

## GLOBALS

In [None]:
# Names of the 8 output channels, listed in channel order.
# The last entry, "object", is the auxiliary objectness channel.
class_names_8 = ["road", "flat", "human", "vehicle", "construction", "background", "pole", "object"]

# Macro class index mapping: name -> position of that name in class_names_8
MACRO_CLASSES = {name: index for index, name in enumerate(class_names_8)}

# Macro classes whose pixels are additionally flagged as objects (is_object=True)
_OBJECT_MACRO_CLASSES = ("human", "vehicle", "pole")

# Original label ID -> macro class name (trailing comment = original class)
_LABEL_ID_TO_MACRO = {
    7: "road",           # road
    8: "flat",           # sidewalk
    11: "construction",  # building
    12: "construction",  # wall
    13: "construction",  # fence
    17: "pole",          # pole
    19: "pole",          # traffic sign
    20: "pole",          # traffic light
    21: "background",    # vegetation
    22: "flat",          # terrain
    23: "background",    # sky
    24: "human",         # person
    25: "human",         # rider
    26: "vehicle",       # car
    27: "vehicle",       # truck
    28: "vehicle",       # bus
    31: "vehicle",       # train
    32: "vehicle",       # motorcycle
    33: "vehicle",       # bicycle
}

# Map from original label ID to (macro class, is_object)
CLASS_MAPPING = {
    label_id: (macro, macro in _OBJECT_MACRO_CLASSES)
    for label_id, macro in _LABEL_ID_TO_MACRO.items()
}

# Prefix prepended to every dataset path used in this notebook
relative_path = '../../'

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Environment report: library versions and CUDA availability.
print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version PyTorch was built with: {torch.version.cuda}")

# The private helper queried below may be missing from the installed build;
# in that case report a CPU-only setup instead of failing.
try:
    compiled_cuda_version = torch._C._cuda_getCompiledVersion()
except AttributeError:
    print("CUDA is not available, running on CPU.")
else:
    print(f"CUDA runtime version: {compiled_cuda_version}")

In [None]:
def set_seed(seed=1996734):
    """Seed every random number generator used in this notebook.

    Args:
        seed: Integer passed, in turn, to Python's `random`, NumPy,
            PyTorch (CPU) and PyTorch (CUDA) seeding functions.
    """
    seeders = (
        random.seed,             # Python
        np.random.seed,          # NumPy
        torch.manual_seed,       # PyTorch CPU
        torch.cuda.manual_seed,  # PyTorch GPU
    )
    for seed_rng in seeders:
        seed_rng(seed)


# Seed once here, before the dataset splits and the model are created below
set_seed(1996734)

## DATA

**Fixing the Datasets**

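The cells below convert the two datasets, as downloaded from the official Cityscapes and LostAndFound websites, from their original directory layout into the layout expected by the dataset classes used in this notebook. The conversion is performed by `fix_cityscapes` and `fix_lostandfound` (imported from `utils`), which define the target structure.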

In [None]:
# Convert the raw Cityscapes download only when the converted copy is missing.
# The existence of the target directory is used as the "already fixed" marker,
# so a fresh checkout runs the conversion automatically and later runs skip it
# (no manual editing of a flag is needed).
is_fixed = os.path.isdir(relative_path + 'cityscapes_f')
if not is_fixed:
    print("Fixing cityscapes dataset...")
    # Fix the dataset: read from 'cityscapes', write the new layout to 'cityscapes_f'
    fix_cityscapes(relative_path + 'cityscapes', relative_path + 'cityscapes_f')

In [None]:
# Convert the raw LostAndFound download only when the converted copy is missing.
# The existence of the target directory is used as the "already fixed" marker,
# so a fresh checkout runs the conversion automatically and later runs skip it
# (no manual editing of a flag is needed).
# NOTE(review): the raw data is read from 'datasets/lostandfound', whereas the
# Cityscapes cell reads from 'cityscapes' directly under relative_path —
# confirm both source locations.
is_fixed = os.path.isdir(relative_path + 'lostandfound_f')
if not is_fixed:
    print("Fixing lostandfound dataset...")
    # Fix the dataset: write the new layout to 'lostandfound_f'
    fix_lostandfound(relative_path + 'datasets/lostandfound', relative_path + 'lostandfound_f')

**Loading the Datasets**

In [None]:
# Cityscapes datasets: the 'train' folders feed training / validation /
# calibration, the 'val' folders are kept aside as the benchmark set.
train_set = CityscapesTrainEvalDataset(
    relative_path + 'cityscapes_f/img/train',
    relative_path + 'cityscapes_f/mask/train',
)
benchmark_set = CityscapesTrainEvalDataset(
    relative_path + 'cityscapes_f/img/val',
    relative_path + 'cityscapes_f/mask/val',
)

# Hold out 20% of the training data, then halve the held-out part:
# one half for calibration (~10% of the original set), the rest for validation.
train_size = len(train_set)
cal_and_val_size = int(0.2 * train_size)
cal_size = int(0.5 * cal_and_val_size)
train_set, calib_and_val_set = random_split(
    train_set,
    [train_size - cal_and_val_size, cal_and_val_size],
)
cal_set, val_set = random_split(
    calib_and_val_set,
    [cal_size, cal_and_val_size - cal_size],
)

batch_size = 5

# Every loader drops the last incomplete batch so all batches have the same
# size; only the training loader shuffles.
_eval_loader_kwargs = dict(batch_size=batch_size, shuffle=False, drop_last=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_set, **_eval_loader_kwargs)
cal_loader = DataLoader(cal_set, **_eval_loader_kwargs)
benchmark_loader = DataLoader(benchmark_set, **_eval_loader_kwargs)

# Sanity check on the first validation batch only.
# Expected shapes: images [B, 3, H, W], masks [B, 8, H, W], original mask [B, H, W]
for imgs, masks, original_mask in val_loader:
    for _caption, _tensor in (
        ("Batch of images shape:", imgs),
        ("Batch of masks shape:", masks),
        ("Original mask shape:", original_mask),
    ):
        print(_caption, _tensor.shape)
    break

In [None]:
# LostAndFound dataset, split 80% / 20% into training and validation subsets.
dataset_lostandfound = LostAndFoundTrainEvalDataset(
    relative_path + 'lostandfound_f/img/train',
    relative_path + 'lostandfound_f/mask/train',
)

batch_size = 5

# Split sizes: the validation subset takes whatever the training subset leaves
train_ratio = 0.8
_n_lostandfound = len(dataset_lostandfound)
train_size = int(train_ratio * _n_lostandfound)
val_size = _n_lostandfound - train_size

train_subset, val_subset = random_split(dataset_lostandfound, [train_size, val_size])

# Training batches are shuffled and the incomplete last batch is dropped;
# validation keeps every sample in a fixed order.
_train_loader_kwargs = dict(batch_size=batch_size, shuffle=True, drop_last=True)
_val_loader_kwargs = dict(batch_size=batch_size, shuffle=False, drop_last=False)
train_loader_lostandfound = DataLoader(train_subset, **_train_loader_kwargs)
val_loader_lostandfound = DataLoader(val_subset, **_val_loader_kwargs)

# Sanity check on the first training batch only.
# Expected shapes: images [B, 3, H, W], masks [B, 8, H, W], original mask [B, H, W]
for imgs, masks, original_mask in train_loader_lostandfound:
    for _caption, _tensor in (
        ("Batch of images shape:", imgs),
        ("Batch of masks shape:", masks),
        ("Original mask shape:", original_mask),
    ):
        print(_caption, _tensor.shape)
    break

**Some Visualization**

In [None]:
# Pass sample 1 of the current `masks` batch to visualize_one_hot_vertical,
# using class_names_8 as the channel names.
# NOTE(review): `masks` is whatever batch the most recently executed
# shape-check loop left behind (hidden notebook state) — confirm this is the
# sample that is meant to be shown.
visualize_one_hot_vertical(masks[1], class_names=class_names_8)

In [None]:
# Load the sample training mask once and close the file right away, instead of
# re-opening the same PNG for every visualization and leaving the handles open.
with Image.open(relative_path + 'cityscapes_f/mask/train/train1_m.png') as _mask_file:
    _sample_mask = np.array(_mask_file)

# Each helper receives its own copy, exactly as when the file was re-read per
# call, so none of them can affect the input of the others.
visualize_erosion_mask(_sample_mask.copy())
visualize_dilation_mask(_sample_mask.copy())
visualize_boundary_mask(_sample_mask.copy(), iterations=7)

## NETWORK

In [None]:
# ---- Hyper-parameters ----
initial_lr = 0.01          # starting learning rate of the poly schedule
momentum = 0.9
weight_decay = 0.0001
power = 0.9                # exponent of the poly schedule
num_epochs = 2
boundary_iterations = 7    # `iterations` argument of get_boundary_mask_batch
patience = 3               # early stopping: epochs without improvement tolerated

# ---- Model, loss and optimizer ----
model = MultiLabelDeepLabV3(n_classes=8).to(device)
criterion = BoundaryAwareBCELoss(lambda_weight=3.0)
optimizer = optim.SGD(
    model.parameters(),
    lr=initial_lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

# ---- Poly schedule bookkeeping: one scheduler step per training batch ----
max_iter = num_epochs * len(train_loader)
current_iter = 0

# ---- Early stopping state ----
best_val_loss = float('inf')
counter = 0
early_stop = False

## TRAIN

In [None]:
# Training loop: one pass over train_loader per epoch with a per-iteration poly
# learning-rate schedule, followed by a validation pass, best-model
# checkpointing and early stopping.
# NOTE: model, criterion, optimizer, current_iter, best_val_loss, counter and
# the hyper-parameters come from the NETWORK cell; re-run that cell first to
# restart training from scratch.

best_checkpoint_path = 'weights/new_weights/model_boundary_7_epoch_2.pth'
# torch.save does not create missing directories, so make sure the target exists
os.makedirs(os.path.dirname(best_checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for imgs, masks, original_mask in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        preds = model(imgs)
        # Boundary masks are built from the original (non one-hot) masks
        boundary_masks = get_boundary_mask_batch(original_mask, iterations=boundary_iterations).detach().to(device)
        loss = criterion(preds, masks, boundary_masks)
        loss.backward()
        optimizer.step()

        # Poly LR update. The base is clamped at 0: if this cell is re-run
        # without resetting current_iter, current_iter exceeds max_iter and a
        # negative base raised to a fractional power would yield a complex number.
        current_iter += 1
        lr = initial_lr * max(0.0, 1 - current_iter / max_iter) ** power
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)

    # ---- VALIDATION STEP ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, masks, original_mask in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            boundary_masks = get_boundary_mask_batch(original_mask, iterations=boundary_iterations).detach().to(device)
            loss = criterion(preds, masks, boundary_masks)
            val_loss += loss.item()
        torch.cuda.empty_cache()
    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {lr:.6f}")

    # ---- EARLY STOPPING LOGIC ----
    if avg_val_loss < best_val_loss:
        # Improvement: reset the patience counter and save the best model
        best_val_loss = avg_val_loss
        counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            'current_iter': current_iter,
        }, best_checkpoint_path)
    else:
        # No improvement: stop after `patience` consecutive such epochs
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch+1}. Best val loss: {best_val_loss:.4f}")
            early_stop = True
            break

## FINE-TUNING

In [None]:
# Fine-tuning setup: restore pretrained weights, then build a fresh optimizer,
# loss and schedule state for training on LostAndFound.
# MultiLabelDeepLabV3, BoundaryAwareBCELossFineTuning and
# get_boundary_mask_batch must already be imported (see the imports cell).

# ---- Fine-tuning hyper-parameters ----
fine_tune_lr = 1e-4
momentum = 0.9
weight_decay = 0.0001
power = 0.9                # for poly LR schedule
num_finetune_epochs = 2
boundary_iterations = 2    # Number of iterations for boundary mask computation
patience = 3               # early stopping patience

# ---- Model restored from the pretraining checkpoint ----
# NOTE(review): this file name differs from the one written by the training
# cell ('model_boundary_7_epoch_2.pth') — confirm which checkpoint is intended.
model = MultiLabelDeepLabV3(n_classes=8).to(device)
checkpoint = torch.load('weights/new_weights/model_boundary_5_lambda_3.0_epochs_2.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# ---- Fresh optimizer and fine-tuning loss ----
# The optimizer state stored in the checkpoint is deliberately not restored;
# to reuse it, call optimizer.load_state_dict(checkpoint['optimizer_state_dict']).
optimizer = optim.SGD(
    model.parameters(),
    lr=fine_tune_lr,
    momentum=momentum,
    weight_decay=weight_decay,
)
criterion = BoundaryAwareBCELossFineTuning(lambda_weight=3.0)

# ---- Resume metadata and early-stopping state ----
start_epoch = checkpoint['epoch'] + 1
best_val_loss = float('inf')  # Initialize best validation loss
counter = 0
early_stop = False

# ---- Poly schedule bookkeeping: one scheduler step per fine-tuning batch ----
max_iter = num_finetune_epochs * len(train_loader_lostandfound)
current_iter = 0

In [None]:
# Fine-tuning loop on LostAndFound: same structure as the Cityscapes training
# loop (per-iteration poly LR, validation pass, best-model checkpointing,
# early stopping).
# NOTE: model, criterion, optimizer, current_iter, best_val_loss, counter and
# the hyper-parameters come from the fine-tuning setup cell; re-run that cell
# first to restart fine-tuning from scratch.

finetune_checkpoint_path = 'weights/new_weights/fine_tuned_model_boundary_5_lambda_3.0_boundaryft_2_epochsft_2.pth'
# torch.save does not create missing directories, so make sure the target exists
os.makedirs(os.path.dirname(finetune_checkpoint_path), exist_ok=True)

for epoch in range(0, num_finetune_epochs):
    model.train()
    running_loss = 0.0

    for imgs, masks, original_mask in tqdm(train_loader_lostandfound, desc=f"Fine-tune Epoch {epoch+1}"):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()

        preds = model(imgs)
        # Boundary masks are built from the original (non one-hot) masks
        boundary_masks = get_boundary_mask_batch(original_mask, iterations=boundary_iterations).detach().to(device)

        loss = criterion(preds, masks, boundary_masks)
        loss.backward()
        optimizer.step()

        # Poly learning rate update. The base is clamped at 0: if this cell is
        # re-run without resetting current_iter, current_iter exceeds max_iter
        # and a negative base raised to a fractional power would yield a
        # complex number.
        current_iter += 1
        lr = fine_tune_lr * max(0.0, 1 - current_iter / max_iter) ** power
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader_lostandfound)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, masks, original_mask in val_loader_lostandfound:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            boundary_masks = get_boundary_mask_batch(original_mask, iterations=boundary_iterations).detach().to(device)
            loss = criterion(preds, masks, boundary_masks)
            val_loss += loss.item()
        torch.cuda.empty_cache()

    avg_val_loss = val_loss / len(val_loader_lostandfound)
    print(f"[Fine-tune Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {lr:.6f}")

    # ---- Early stopping / best-model checkpoint ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # Reset the patience counter on improvement (as the Cityscapes training
        # loop does) so that only consecutive non-improving epochs count.
        counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            'current_iter': current_iter,
        }, finetune_checkpoint_path)

    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch+1}. Best val loss: {best_val_loss:.4f}")
            early_stop = True
            break

## TEST

In [None]:
#### NEW METHOD ####
# Conformal calibration: run the model over the calibration loader, collect the
# nonconformity score of every pixel labelled as a known object, and store the
# sorted scores in "calibration_scores.npy".

# Define the macro-classes considered as known objects
KNOWN_OBJECT_CLASSES = [2, 3, 6]  # human, vehicle, pole

model = MultiLabelDeepLabV3(n_classes=8).to(device)

# TODO: 'inserire peso' ("insert weights") is a placeholder — replace it with
# the path of the checkpoint to calibrate before running this cell.
checkpoint = torch.load('inserire peso', map_location = device)

model.load_state_dict(checkpoint['model_state_dict'])

model.eval()

# One NumPy array per batch. Extending a Python list pixel by pixel would
# create one Python object per selected pixel, which is slow and memory-hungry.
_score_chunks = []
for images_batch, labels_batch, _ in cal_loader:
    images_batch = images_batch.to(device)
    labels_batch = labels_batch.to(device)  # shape: (B, C, H, W)

    with torch.no_grad():
        output = model(images_batch)  # shape: (B, 8, H, W)
        # Per-pixel scores; the mask indexing below requires shape (B, H, W)
        nonconformity_scores = nonconformity_score(output)

    # (B, H, W) boolean mask: True where any known-object channel is set
    known_object_mask = labels_batch[:, KNOWN_OBJECT_CLASSES].any(dim=1)

    # Boolean indexing keeps sample-major order, i.e. the same order as
    # selecting the scores sample by sample.
    _score_chunks.append(nonconformity_scores[known_object_mask].cpu().numpy())

# Save sorted calibration scores (an empty array when nothing was collected)
if _score_chunks:
    calibration_scores = np.sort(np.concatenate(_score_chunks))
else:
    calibration_scores = np.sort(np.array([]))
np.save("calibration_scores.npy", calibration_scores)


In [None]:
# Visualize the unknown-objectness heatmap superposed on a single test image,
# then the visualization obtained with the conformal threshold.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Recreate the model architecture and move it to GPU or CPU as appropriate
model = MultiLabelDeepLabV3(n_classes=8)
model.to(device)

# Model input size
resized_height = 512
resized_width = 1024

# Tensor conversion + normalization (mean/std values as hard-coded here).
# Resizing is done on the PIL image below, so it is not part of this pipeline.
transform = T.Compose([
    T.ToTensor(),  # converts to [0, 1], shape [3, H, W]
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

path = "../../obstacles12_rocks5.png"

with Image.open(path) as _raw_img:
    img = _raw_img.convert("RGB")  # Ensure it's in RGB mode
    # PIL expects (width, height)
    img = img.resize((resized_width, resized_height), resample=Image.BILINEAR)
    img_tensor = transform(img)  # [3, H, W] on the CPU; the batch dimension is added at inference time


weight_name = "fine_tuned_model_boundary_5_lambda_3.0_epochs_5_boundaryft_1_epochsft_3.pth"  # Change this to the desired model name

checkpoint = torch.load('weights/new_weights/' + weight_name, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

def denormalize(tensor):
    """Undo the T.Normalize step of `transform` and clamp the result to [0, 1].

    Args:
        tensor: Normalized image tensor whose channel axis broadcasts against
            shape [3, 1, 1] (e.g. [3, H, W]), on the CPU.

    Returns:
        Tensor of the same shape with values clamped to [0, 1].
    """
    mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
    std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
    return (tensor * std + mean).clamp(0, 1)

with torch.no_grad():
    # Add the batch dimension, run the model, keep the score map of the single image
    uos = unknown_objectness_score(model(img_tensor.unsqueeze(0).to(device)))[0]

img_test = denormalize(img_tensor)
uos_heatmap(img_test, uos, threshold=0.6, alpha_val=0.5)

# NOTE(review): hard-coded value; presumably derived from the saved calibration
# scores — confirm it matches the checkpoint loaded above.
threshold = 0.2024 # Set your conformal threshold here
visualize_uos_with_conformal(model, img_test, device, threshold)

In [None]:
# Segmentation metrics on a single batch of the validation set.
imgs, masks, original_mask = next(iter(val_loader))  # Get a batch from the validation set

imgs = imgs.to(device)

# Inference only: force eval mode and disable autograd so that no computation
# graph is built (and kept in memory) for this forward pass.
model.eval()
with torch.no_grad():
    # NOTE(review): the raw model output is passed to the metric helpers —
    # confirm whether they expect probabilities or logits.
    preds = model(imgs)
preds = preds.detach().cpu()
masks = masks.cpu()

print("Pixel Accuracy:", pixel_accuracy(preds, masks))
print("Mean IoU:", mean_iou(preds, masks))
print("Dice per class:", dice_score(preds, masks))
prec, rec = precision_recall(preds, masks)
print("Precision per class:", prec)
print("Recall per class:", rec)

### BENCHMARK