# Semantic Segmentation on BDD10k and Dark Zurich using the U-net ResNet-34 Model

This notebook contains the code for fine-tuning the pre-trained U-net model on the BDD10k training images, and evaluating on Dark Zurich and BDD10k validation images.

NOTE: I had to clear the cell outputs of this notebook to fit the file-size limit for uploading to GitHub. You can still view the cell outputs in my Google Colab notebook: https://colab.research.google.com/drive/1fe95l4P-zbXTtlqoaFwwRri-nlRx3SzN?usp=sharing

NOTE: Much of the code in this notebook was borrowed or adapted from this repository: https://github.com/ronigold/sem-seg-bdd100k/tree/main




## Setup

In [None]:
import sys
from pathlib import Path
from PIL import Image
import numpy as np
from fastai.vision.all import *
import random

def seed_everything(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the default generator
    and all CUDA generators), then switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Seed value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once, with the default seed, before any data or model code runs.
seed_everything()

# Use the GPU when PyTorch reports one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

%matplotlib inline

In [None]:
# Mount Google Drive at /content/drive; the dataset paths defined below live under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Root Drive folder that every dataset path in this notebook is built from.
BASE_PATH = Path('/content/drive/MyDrive/')
# Input images; the mask lookup and the splitter below treat any path containing 'val' as validation.
INPUT_PATH = BASE_PATH/'bdd10k_images'
# Segmentation masks, read from the 'train' and 'val' subfolders.
TARGET_PATH = BASE_PATH/'bdd100k_seg_maps/labels/'

# Map each mask pixel value (class id) to its class name.
# Id 19 ("unknown") is the value that raw 255 pixels are remapped to when a mask is loaded.
vocab_dict = {
    0: "road",
    1: "sidewalk",
    2: "building",
    3: "wall",
    4: "fence",
    5: "pole",
    6: "traffic light",
    7: "traffic sign",
    8: "vegetation",
    9: "terrain",
    10: "sky",
    11: "person",
    12: "rider",
    13: "car",
    14: "truck",
    15: "bus",
    16: "train",
    17: "motorcycle",
    18: "bicycle",
    19: "unknown"
}

# Training the UNET Model
def find_unique_classes_in_masks(mask_dir):
    """Collect every distinct pixel value found in the mask images of a folder.

    Only regular files directly inside ``mask_dir`` whose suffix is exactly
    '.png' or '.jpg' are read; subfolders and other files are ignored.

    Args:
        mask_dir (pathlib.Path): Directory holding the mask images.

    Returns:
        set: The distinct pixel values seen across all masks.
    """
    image_suffixes = ['.png', '.jpg']
    seen_values = set()
    for entry in mask_dir.iterdir():
        # Skip anything that is not a mask image file.
        if not entry.is_file() or entry.suffix not in image_suffixes:
            continue
        pixels = np.array(Image.open(entry))
        seen_values.update(np.unique(pixels))
    return seen_values


In [None]:
# Sanity check: scan the raw validation masks and confirm that the number of
# distinct pixel values matches the size of vocab_dict.
# NOTE(review): the scan runs on the raw files, before 255 is remapped to 19, so the
# count presumably matches because 255 stands in for "unknown" — confirm from the printed values.
mask_dir = TARGET_PATH/'val'

# Find unique classes
unique_classes = find_unique_classes_in_masks(mask_dir)

print(f"Unique classes found: {sorted(unique_classes)}")
print(f"Total number of unique classes: {len(unique_classes)}")
assert(len(unique_classes) == len(vocab_dict))

# Repeat the same check on the training masks. `unique_classes` is overwritten
# here and is read again later when the learner is created.
mask_dir = TARGET_PATH/'train'

# Find unique classes
unique_classes = find_unique_classes_in_masks(mask_dir)

print(f"Unique classes found: {sorted(unique_classes)}")
print(f"Total number of unique classes: {len(unique_classes)}")
assert(len(unique_classes) == len(vocab_dict))


In [None]:
def get_adjusted_mask_file_path(x, mask_source_dir=None):
    """Load the mask that belongs to an input image and remap 255 to class 19.

    The mask file is named ``<image stem>_train_id.png``.

    Args:
        x (pathlib.Path): Path to the input image file.
        mask_source_dir (pathlib.Path, optional): Directory containing the mask.
            When omitted, the mask is read from ``TARGET_PATH/'val'`` if the
            string 'val' occurs anywhere in the image path, otherwise from
            ``TARGET_PATH/'train'``.

    Returns:
        PIL.Image: The mask with every 255 pixel replaced by 19.
    """
    mask_filename = f"{x.stem}_train_id.png"

    if mask_source_dir:
        mask_dir = mask_source_dir
    else:
        # Fall back to the train/val folder layout under TARGET_PATH.
        split_name = "val" if "val" in str(x) else "train"
        mask_dir = TARGET_PATH / split_name

    pixels = np.array(Image.open(mask_dir / mask_filename))
    pixels[pixels == 255] = 19             # Remap unknown → class 19
    return Image.fromarray(pixels)


In [None]:
def custom_splitter(file_path):
    """Return True when an item belongs to the validation set.

    An item counts as validation when the substring 'val' occurs anywhere in
    its path; everything else goes to training.
    """
    return 'val' in str(file_path)

# Build the DataBlock: images as inputs, masks as targets, split by folder name.
segmentation_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=None)), # no class codes are attached to the masks here
                                   get_items=get_image_files,
                                   splitter=FuncSplitter(custom_splitter),
                                   get_y=get_adjusted_mask_file_path,
                                   item_tfms=Resize(460),
                                   batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])

# Create the train/valid DataLoaders from the image folder.
dls = segmentation_datablock.dataloaders(INPUT_PATH, path=BASE_PATH, bs = 64, num_workers=11, pin_memory=True,
    prefetch_factor=4)

# Visual check of one batch.
dls.show_batch()


In [None]:
# Display the training dataset object.
dls.train_ds


In [None]:
# Display the validation dataset object.
dls.valid_ds

In [None]:
# U-Net learner on a ResNet-34 backbone. n_out is taken from `unique_classes`, which the
# class-scan cell above computes and asserts to have the same length as vocab_dict.
learn = unet_learner(dls, resnet34, n_out=len(unique_classes))


## Model Training (DO NOT RUN EVERY TIME)

In [None]:

# Run the learning-rate finder and keep all four suggestions; `lrs.valley` is used for training below.
lrs = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))


In [None]:
# Display the 'valley' learning-rate suggestion.
lrs.valley


In [None]:
# Fine-tune for 25 epochs at the valley learning rate, plotting losses live and logging metrics to log.csv.
learn.fine_tune(25, base_lr = lrs.valley, cbs=[ShowGraphCallback(), CSVLogger(fname='log.csv')])


In [None]:
# Save the trained weights under the name 'unet_segmentation_model' so later sessions can skip training.
learn.save('unet_segmentation_model')

## Re-Load Trained Model from File

In [None]:
# Load the saved weights into `learn`; the learner must already have been created by the cell above.
learn.load('unet_segmentation_model')

In [None]:

# Show up to 12 validation samples with their predictions.
learn.show_results(dl=dls.valid, max_n=12)

In [None]:
# Predict on the validation set, keeping the inputs and the decoded predictions.
# `decoded_preds` and `targets` feed the evaluation and plotting cells below.
inputs, preds, targets, decoded_preds = learn.get_preds(dl=dls.valid, with_input=True, with_decoded = True)

# Evaluation

In [None]:
import numpy as np
import pandas as pd
import torch

def iou(preds, targs, num_classes=20):
    """Compute per-class Intersection over Union between two label tensors.

    Every class id in ``range(num_classes)`` is scored, including the last one
    ('unknown' in this notebook); callers that want to ignore a class must
    drop it from the result themselves.

    Args:
        preds: Tensor of predicted class ids, any shape.
        targs: Tensor of ground-truth class ids with the same number of
            elements as ``preds``.
        num_classes: Number of class ids to score, starting at 0.

    Returns:
        list[tuple[int, float]]: ``(class_id, iou)`` pairs in class order.
        The IoU is ``nan`` for a class that occurs in neither tensor.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or sliced tensors.
    preds = preds.reshape(-1)
    targs = targs.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_inds = preds == cls
        target_inds = targs == cls
        intersection = (pred_inds[target_inds]).long().sum().item()  # True positives
        union = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection
        if union > 0:
            ious.append((cls, float(intersection) / float(union)))
        else:
            # Class absent from both tensors: IoU is undefined.
            ious.append((cls, float('nan')))
    return ious

def precision_recall(preds, targs, num_classes=20):
    """Compute per-class precision and recall between two label tensors.

    Args:
        preds: Tensor of predicted class ids, any shape.
        targs: Tensor of ground-truth class ids with the same number of
            elements as ``preds``.
        num_classes: Number of class ids to score, starting at 0.

    Returns:
        tuple[list, list]: ``(precisions, recalls)``, each a list of
        ``(class_id, value)`` pairs in class order. Precision is ``nan`` when
        the class is never predicted; recall is ``nan`` when the class never
        occurs in ``targs``.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or sliced tensors.
    preds = preds.reshape(-1)
    targs = targs.reshape(-1)

    precisions = []
    recalls = []
    for cls in range(num_classes):
        pred_inds = preds == cls
        target_inds = targs == cls
        tp = (pred_inds[target_inds]).long().sum().item()  # True positives
        fp = (pred_inds[~target_inds]).long().sum().item()  # False positives
        fn = (~pred_inds[target_inds]).long().sum().item()  # False negatives

        precision = tp / (tp + fp) if (tp + fp) > 0 else float('nan')
        recall = tp / (tp + fn) if (tp + fn) > 0 else float('nan')

        precisions.append((cls, precision))
        recalls.append((cls, recall))

    return precisions, recalls

def calculate_pixel_accuracy(preds, targs):
    """Return the fraction of elements where prediction and target agree.

    Named ``calculate_pixel_accuracy`` to avoid a collision with
    ``fastai.metrics.accuracy``.

    Args:
        preds: Tensor of predicted class ids, any shape.
        targs: Tensor of ground-truth class ids with the same number of
            elements as ``preds``.

    Returns:
        A scalar tensor holding the accuracy; call ``.item()`` for a float.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or sliced tensors.
    preds = preds.reshape(-1)
    targs = targs.reshape(-1)
    correct = (preds == targs).float().sum()
    return correct / preds.shape[0]

def evaluate_segmentation_model(preds, targs, vocab_dict):
    """Build a per-class metrics table and print the overall pixel accuracy.

    Args:
        preds: Tensor of predicted class ids.
        targs: Tensor of ground-truth class ids with the same number of
            elements as ``preds``.
        vocab_dict: Mapping from class id (0 .. len-1) to class name.

    Returns:
        pandas.DataFrame: Indexed by class name, with columns 'IoU',
        'Precision', 'Recall' and 'Percentage (%)' (share of target pixels
        that belong to the class). Rows follow class-id order.
    """
    num_classes = len(vocab_dict)

    # Flatten once with reshape so non-contiguous inputs (slices, transposes)
    # work here and in the helpers called below.
    preds = preds.reshape(-1)
    targs = targs.reshape(-1)

    ious = iou(preds, targs, num_classes)
    precisions, recalls = precision_recall(preds, targs, num_classes)
    acc = calculate_pixel_accuracy(preds, targs).item() # Call the custom accuracy function

    # bincount needs an integer tensor
    targs_int = targs.long()
    class_frequencies = targs_int.bincount(minlength=num_classes)
    total_pixels = class_frequencies.sum().item()
    class_percentages = (class_frequencies / total_pixels * 100).tolist()

    # Compile metrics into a DataFrame for neat presentation
    metrics = [
        {
            'Class': vocab_dict[cls],
            'IoU': ious[cls][1],
            'Precision': precisions[cls][1],
            'Recall': recalls[cls][1],
            'Percentage (%)': class_percentages[cls]
        }
        for cls in range(num_classes)
    ]

    metrics_df = pd.DataFrame(metrics).set_index('Class')
    print(f"Overall Accuracy: {acc:.4f}")
    return metrics_df

In [None]:
# Per-class metrics on the BDD10k validation predictions (also prints the overall pixel accuracy).
evaluation_df = evaluate_segmentation_model(decoded_preds, targets, vocab_dict)
evaluation_df

In [None]:
# Mean IoU over the labelled classes. The trailing 'unknown' row is dropped so this
# number is comparable with the dark-image evaluations below, which also exclude it.
evaluation_df['IoU'][:-1].mean()

# Plotting

In [None]:
import matplotlib.pyplot as plt
import torch
from matplotlib.colors import ListedColormap

# One RGB colour (0-255) per class id, listed in the same order as vocab_dict.
BDD_COLORS = [
    (128,  64, 128),  # road
    (244,  35, 232),  # sidewalk
    (70,   70,  70),  # building
    (102, 102, 156),  # wall
    (190, 153, 153),  # fence
    (153, 153, 153),  # pole
    (250, 170,  30),  # traffic light
    (220, 220,   0),  # traffic sign
    (107, 142,  35),  # vegetation
    (152, 251, 152),  # terrain
    (70,  130, 180),  # sky
    (220,  20,  60),  # person
    (255,   0,   0),  # rider
    (0,     0, 142),  # car
    (0,     0,  70),  # truck
    (0,    60, 100),  # bus
    (0,    80, 100),  # train
    (0,     0, 230),  # motorcycle
    (119,  11,  32),  # bicycle
    (0,     0,   0),  # unknown
]

# Rescale to the 0-1 float range that matplotlib colormaps expect.
BDD_COLORS = np.array(BDD_COLORS) / 255.0

# Discrete colormap with one entry per class; used for every mask plot below.
cmap = ListedColormap(BDD_COLORS)

def insert_line_breaks(text, max_chars_per_line=30):
    """Wrap a ', '-separated list onto several lines for use in plot titles.

    Items are never split. An item starts a new line when appending it (plus
    the two-character separator) would push the current line past
    ``max_chars_per_line``. A single item longer than the limit gets a line
    of its own.

    Args:
        text: Items joined by ', ' (for example "road, car, sky").
        max_chars_per_line: Soft upper bound on the length of each line.

    Returns:
        str: The same items, with some ', ' separators replaced by newlines.
    """
    lines = []
    current_line = ''
    for word in text.split(', '):
        if len(current_line) + len(word) + 2 > max_chars_per_line:
            # Only flush a line that has content; otherwise an over-long first
            # item would produce a stray leading blank line.
            if current_line:
                lines.append(current_line)
            current_line = word
        elif current_line:
            current_line += ', ' + word
        else:
            current_line = word
    lines.append(current_line)  # add the last line
    return '\n'.join(lines)

def denormalized_image(img):
    """Convert a channels-first image tensor into a channels-last image for plotting.

    The ImageNet normalisation is undone only when the tensor contains a
    negative value; this is a heuristic, so a normalised image whose values
    all happen to be >= 0 is returned without denormalisation.

    Args:
        img: Image tensor with three channels in the first dimension.

    Returns:
        Float tensor with channels last and values clamped to [0, 1], on the
        same device as the input.
    """
    img = img.float()
    # Create the statistics on the image's own device so that CUDA inputs do
    # not fail with a device mismatch.
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).view(3, 1, 1)
    if torch.min(img) < 0:
        img = img * std + mean  # Denormalize
    img = img.permute(1, 2, 0)  # Rearrange channels for plotting
    img = img.clamp(0, 1)  # Clamp values to ensure they are within [0, 1] range
    return img

def get_classes_from_mask(mask, vocab_dict):
    """List the names of the classes that appear in a mask.

    Args:
        mask: Tensor of class ids.
        vocab_dict: Mapping from class id to class name.

    Returns:
        str: Names of the ids present in ``mask``, in the order returned by
        ``torch.unique``, joined with ', '. Ids missing from ``vocab_dict``
        are left out.
    """
    present_ids = torch.unique(mask).tolist()
    known_names = (vocab_dict[class_id] for class_id in present_ids if class_id in vocab_dict)
    return ', '.join(known_names)

def visualize_segmentation_batch(inputs, true_masks=None, pred_masks=None, vocab_dict=None, num_imgs=3):
    """Plot input images next to their ground-truth and/or predicted masks.

    Each row shows one sample: the input image, then the true mask (when
    ``true_masks`` is given), then the predicted mask (when ``pred_masks`` is
    given). Mask titles list the class names found in the mask when
    ``vocab_dict`` is given.

    Args:
        inputs: Indexable batch of image tensors accepted by ``denormalized_image``.
        true_masks: Optional indexable batch of ground-truth class-id tensors.
        pred_masks: Optional indexable batch of predicted class-id tensors.
        vocab_dict: Optional mapping from class id to class name for the titles.
        num_imgs: Maximum number of rows to draw.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Never index past the shortest of the provided batches.
    effective_num_imgs = len(inputs)
    if true_masks is not None:
        effective_num_imgs = min(effective_num_imgs, len(true_masks))
    if pred_masks is not None:
        effective_num_imgs = min(effective_num_imgs, len(pred_masks))
    num_to_plot = min(num_imgs, effective_num_imgs)

    print(f"Plotting {num_to_plot} images.") # Added for debugging

    # plt.subplots rejects a grid with zero rows, so stop early.
    if num_to_plot <= 0:
        return

    cols = 2 if true_masks is None else 3
    fig, axes = plt.subplots(num_to_plot, cols, figsize=(20, 5*num_to_plot), squeeze=False)

    # Hide every axis up front so cells that stay empty do not show a frame.
    for ax in axes.ravel():
        ax.axis('off')

    # Pin the colour normalisation so class id k always maps to colour k.
    # Without vmin/vmax, imshow rescales each mask to its own min/max and the
    # same class gets different colours in different panels.
    mask_kwargs = dict(cmap=cmap, interpolation='nearest', vmin=-0.5, vmax=cmap.N - 0.5)

    def _mask_title(label, mask):
        # Append the class names only when a vocabulary was supplied.
        if vocab_dict is None:
            return label
        return f'{label}\n{insert_line_breaks(get_classes_from_mask(mask, vocab_dict))}'

    for i in range(num_to_plot):
        ax0 = axes[i][0]
        ax0.imshow(denormalized_image(inputs[i].cpu()))  # plotting needs CPU data
        ax0.set_title('Input Image')

        # Show true mask if provided
        if true_masks is not None:
            ax1 = axes[i][1]
            true_mask = true_masks[i].cpu() # Move to CPU for plotting
            ax1.imshow(true_mask, **mask_kwargs)
            ax1.set_title(_mask_title('True Mask', true_mask))

        # Show predicted mask if provided
        if pred_masks is not None:
            ax_pred = axes[i][1] if true_masks is None else axes[i][2]
            pred_mask = pred_masks[i].cpu() # Move to CPU for plotting
            ax_pred.imshow(pred_mask, **mask_kwargs)
            ax_pred.set_title(_mask_title('Predicted Mask', pred_mask))

    plt.tight_layout()
    plt.show()

# Plot up to 50 BDD10k validation samples: input, ground truth and prediction.
visualize_segmentation_batch(inputs, targets, decoded_preds, vocab_dict=vocab_dict, num_imgs=50)


# Evaluate specifically on the dark images:

In [None]:
# Splitter for the dark-image DataBlocks below. It repeats the rule used for training:
# an item is a validation item when its path contains 'val'.
# NOTE(review): every dark image only lands in validation if each path contains 'val' — confirm
# against the folder layout.
def custom_splitter(file_path):
    """Return True when the substring 'val' occurs anywhere in the item's path."""
    return 'val' in str(file_path)

In [None]:
#DARK_VAL_MASKS_DIR = BASE_PATH/'bdd100k_seg_maps_dark'
def get_adjusted_mask_file_path(x, mask_source_dir=None):
    """Load the mask for an input image and remap 255 to class 19.

    Same behaviour as the definition used for training earlier in the
    notebook. The mask file is named ``<image stem>_train_id.png``.

    Args:
        x (pathlib.Path): Path to the input image file.
        mask_source_dir (pathlib.Path, optional): Directory containing the mask.
            When omitted, the mask is read from ``TARGET_PATH/'val'`` if the
            string 'val' occurs anywhere in the image path, otherwise from
            ``TARGET_PATH/'train'``.

    Returns:
        PIL.Image: The mask with every 255 pixel replaced by 19.
    """
    mask_filename = f"{x.stem}_train_id.png"

    if mask_source_dir:
        mask_dir = mask_source_dir
    else:
        # Fall back to the train/val folder layout under TARGET_PATH.
        split_name = "val" if "val" in str(x) else "train"
        mask_dir = TARGET_PATH / split_name

    pixels = np.array(Image.open(mask_dir / mask_filename))
    pixels[pixels == 255] = 19             # Remap unknown → class 19
    return Image.fromarray(pixels)


In [None]:
from functools import partial

# Folder with the dark validation images
DARK_VAL_IMAGES_DIR = BASE_PATH/'bdd10k_images_dark'
# Folder with the matching masks; note it sits directly under BASE_PATH, not under TARGET_PATH
DARK_VAL_MASKS_DIR = BASE_PATH/'bdd100k_seg_maps_dark'

# Bind the mask folder so the DataBlock can call get_y with just the image path
get_y_dark_val = partial(get_adjusted_mask_file_path, mask_source_dir=DARK_VAL_MASKS_DIR)

# An earlier variant of this DataBlock (since removed) used Resize(224, method='crop'), no
# aug_transforms, and a splitter that placed every item in a single split.

# Same pipeline as the training DataBlock, but with masks taken from DARK_VAL_MASKS_DIR.
# NOTE(review): aug_transforms is attached to this evaluation pipeline too; confirm that the
# random augmentations are not applied to the validation split.
dark_val_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=None)), # no class codes are attached to the masks here
                                   get_items=get_image_files,
                                   splitter=FuncSplitter(custom_splitter),
                                   get_y=get_y_dark_val,
                                   item_tfms=Resize(460),
                                   batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])


# Create DataLoaders for the dark validation set (batch size 4, 11 worker processes)
dark_val_dls = dark_val_datablock.dataloaders(DARK_VAL_IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=11, pin_memory=True, prefetch_factor=4)




In [None]:
# Print both splits to check where the dark images ended up; they are expected in the validation split.
print(dark_val_dls.train_ds)
print(dark_val_dls.valid_ds)

In [None]:
# Predict on the dark validation split, keeping the inputs and the decoded predictions.
inputs_dark, preds_dark, targets_dark, decoded_dark = learn.get_preds(
    dl=dark_val_dls.valid,
    with_input=True,
    with_decoded=True
)

In [None]:
# Per-class metrics on the dark images. `evaluation_df` is reused, so the earlier table is overwritten.
evaluation_df = evaluate_segmentation_model(decoded_dark, targets_dark, vocab_dict)
display(evaluation_df)
# Mean IoU over all rows except the last one ('unknown').
print(evaluation_df['IoU'][:-1].mean())

In [None]:
# Plot up to 80 dark validation samples: input, ground truth and prediction.
visualize_segmentation_batch(inputs_dark, targets_dark, decoded_dark, vocab_dict=vocab_dict, num_imgs=80)

# Evaluate on enhanced dark images from bdd10k

In [None]:
from functools import partial

# Folder with the enhanced images for this section ('bdd10k_images_gamma';
# presumably gamma-corrected versions of the dark images — confirm)
enhanced_IMAGES_DIR = BASE_PATH/'bdd10k_images_gamma'
# The masks are the same ones used for the unenhanced dark images
DARK_VAL_MASKS_DIR = BASE_PATH/'bdd100k_seg_maps_dark'

# Bind the mask folder so the DataBlock can call get_y with just the image path
get_y_dark_val = partial(get_adjusted_mask_file_path, mask_source_dir=DARK_VAL_MASKS_DIR)

# An earlier variant of this DataBlock (since removed) used Resize(224, method='crop'), no
# aug_transforms, and a splitter that placed every item in a single split.

# Same pipeline as the dark-image evaluation above; only the image folder differs.
dark_val_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=None)), # no class codes are attached to the masks here
                                   get_items=get_image_files,
                                   splitter=FuncSplitter(custom_splitter),
                                   get_y=get_y_dark_val,
                                   item_tfms=Resize(460),
                                   batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])

# Create DataLoaders for the enhanced images
dark_val_dls_enhanced = dark_val_datablock.dataloaders(enhanced_IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=11, pin_memory=True, prefetch_factor=4)

# Predict on the validation split. The *_dark_enhanced names are reused by the CLAHE and
# Retinex sections below, so they always hold the most recently run variant.
inputs_dark_enhanced, preds_dark_enhanced, targets_dark_enhanced, decoded_dark_enhanced = learn.get_preds(
    dl=dark_val_dls_enhanced.valid,
    with_input=True,
    with_decoded=True
)

# check dataset sizes
print(dark_val_dls_enhanced.train_ds)
print(dark_val_dls_enhanced.valid_ds)


In [None]:
# Per-class metrics for the images loaded in the previous cell (overwrites `evaluation_df`).
evaluation_df = evaluate_segmentation_model(decoded_dark_enhanced, targets_dark_enhanced, vocab_dict)
display(evaluation_df)
# Mean IoU over all rows except the last one ('unknown').
print(evaluation_df['IoU'][:-1].mean())


In [None]:
# Plot up to 80 samples from the most recent enhanced-image predictions.
visualize_segmentation_batch(inputs_dark_enhanced, targets_dark_enhanced, decoded_dark_enhanced, vocab_dict=vocab_dict, num_imgs=80)

# CLAHE enhanced


In [None]:
from functools import partial

# Folder with the enhanced images for this section ('bdd10k_images_dark_clahe')
enhanced_IMAGES_DIR = BASE_PATH/'bdd10k_images_dark_clahe'
# The masks are the same ones used for the unenhanced dark images
DARK_VAL_MASKS_DIR = BASE_PATH/'bdd100k_seg_maps_dark'

# Bind the mask folder so the DataBlock can call get_y with just the image path
get_y_dark_val = partial(get_adjusted_mask_file_path, mask_source_dir=DARK_VAL_MASKS_DIR)

# An earlier variant of this DataBlock (since removed) used Resize(224, method='crop'), no
# aug_transforms, and a splitter that placed every item in a single split.

# Same pipeline as the dark-image evaluation above; only the image folder differs.
dark_val_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=None)), # no class codes are attached to the masks here
                                   get_items=get_image_files,
                                   splitter=FuncSplitter(custom_splitter),
                                   get_y=get_y_dark_val,
                                   item_tfms=Resize(460),
                                   batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])


# Create DataLoaders for the enhanced images
dark_val_dls_enhanced = dark_val_datablock.dataloaders(enhanced_IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=11, pin_memory=True, prefetch_factor=4)

# Predict on the validation split; this overwrites the *_dark_enhanced results of the previous section.
inputs_dark_enhanced, preds_dark_enhanced, targets_dark_enhanced, decoded_dark_enhanced = learn.get_preds(
    dl=dark_val_dls_enhanced.valid,
    with_input=True,
    with_decoded=True
)

# Check dataset sizes
print(dark_val_dls_enhanced.train_ds)
print(dark_val_dls_enhanced.valid_ds)


In [None]:
# Per-class metrics for the images loaded in the previous cell (overwrites `evaluation_df`).
evaluation_df = evaluate_segmentation_model(decoded_dark_enhanced, targets_dark_enhanced, vocab_dict)
display(evaluation_df)
# Mean IoU over all rows except the last one ('unknown').
print(evaluation_df['IoU'][:-1].mean())

In [None]:
# Plot up to 80 samples from the most recent enhanced-image predictions.
visualize_segmentation_batch(inputs_dark_enhanced, targets_dark_enhanced, decoded_dark_enhanced, vocab_dict=vocab_dict, num_imgs=80)

# Retinex-Enhanced

In [None]:
from functools import partial

# Folder with the enhanced images for this section ('bdd10k_images_dark_retinex')
enhanced_IMAGES_DIR = BASE_PATH/'bdd10k_images_dark_retinex'
# The masks are the same ones used for the unenhanced dark images
DARK_VAL_MASKS_DIR = BASE_PATH/'bdd100k_seg_maps_dark'

# Bind the mask folder so the DataBlock can call get_y with just the image path
get_y_dark_val = partial(get_adjusted_mask_file_path, mask_source_dir=DARK_VAL_MASKS_DIR)

# An earlier variant of this DataBlock (since removed) used Resize(224, method='crop'), no
# aug_transforms, and a splitter that placed every item in a single split.

# Same pipeline as the dark-image evaluation above; only the image folder differs.
dark_val_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=None)), # no class codes are attached to the masks here
                                   get_items=get_image_files,
                                   splitter=FuncSplitter(custom_splitter),
                                   get_y=get_y_dark_val,
                                   item_tfms=Resize(460),
                                   batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])

# Create DataLoaders for the enhanced images
dark_val_dls_enhanced = dark_val_datablock.dataloaders(enhanced_IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=11, pin_memory=True, prefetch_factor=4)

# Predict on the validation split; this overwrites the *_dark_enhanced results of the previous section.
inputs_dark_enhanced, preds_dark_enhanced, targets_dark_enhanced, decoded_dark_enhanced = learn.get_preds(
    dl=dark_val_dls_enhanced.valid,
    with_input=True,
    with_decoded=True
)

# Check dataset sizes
print(dark_val_dls_enhanced.train_ds)
print(dark_val_dls_enhanced.valid_ds)

In [None]:
# Per-class metrics for the images loaded in the previous cell (overwrites `evaluation_df`).
evaluation_df = evaluate_segmentation_model(decoded_dark_enhanced, targets_dark_enhanced, vocab_dict)
display(evaluation_df)
# Mean IoU over all rows except the last one ('unknown').
print(evaluation_df['IoU'][:-1].mean())

In [None]:
# Plot up to 80 samples from the most recent enhanced-image predictions.
visualize_segmentation_batch(inputs_dark_enhanced, targets_dark_enhanced, decoded_dark_enhanced, vocab_dict=vocab_dict, num_imgs=80)

# Evaluation on Dark Zurich

## Attempt 1: Using LabelIds

In [None]:
def get_adjusted_mask_file_path(x, mask_source_dir=None):
    """Load the Dark Zurich ``labelIds`` mask for an image, mapping ids above 18 to 19.

    The mask name is the image stem with a trailing "_rgb_anon" removed,
    followed by "_gt_labelIds.png". This redefinition replaces the BDD
    version of the helper for the cells that follow.

    Args:
        x (pathlib.Path): Path to the input image file.
        mask_source_dir (pathlib.Path): Directory containing the masks. It is
            required here despite the ``None`` default, because the path is
            always built from it.

    Returns:
        PIL.Image: The mask with every value greater than 18 (including 255)
        replaced by 19 ('unknown'). When the mask file is missing, a warning
        is printed and an all-zero mask of the image's size is returned
        instead; zero is the id of 'road' in ``vocab_dict``.

    Raises:
        FileNotFoundError: The mask is missing and the placeholder could not
            be built (for example because the image itself cannot be opened).
    """
    stem = x.stem.removesuffix("_rgb_anon")
    mask_path = mask_source_dir / (stem + "_gt_labelIds.png")

    if mask_path.exists():
        labels = np.array(Image.open(mask_path))
        # Collapse everything outside the 0-18 range into the 'unknown' id.
        labels[labels > 18] = 19
        return Image.fromarray(labels)

    # Missing mask: warn and fall back to an all-zero placeholder.
    print(f"WARNING: Mask file not found for image: {x}. Expected at: {mask_path}")
    try:
        width, height = Image.open(x).size
        # uint8 keeps the placeholder compatible with Image.fromarray
        placeholder = np.zeros((height, width), dtype=np.uint8)
        return Image.fromarray(placeholder)
    except Exception as e:
        print(f"Error creating dummy mask for {x}: {e}")
        raise FileNotFoundError(f"Mask {mask_path} not found and failed to create dummy mask for {x}.") from e


In [None]:
# Folder with the Dark Zurich validation images
IMAGES_DIR = BASE_PATH/'dark_zurich_val'
# Folder with the matching labelIds masks
DZ_MASKS_DIR = BASE_PATH/'dark_zurich_val_segmaps_labelIds'

# Bind the mask folder so the DataBlock can call get_y with just the image path
get_y_dz = partial(get_adjusted_mask_file_path, mask_source_dir=DZ_MASKS_DIR)

# DataBlock for Dark Zurich. No splitter is passed (the line is commented out), so the
# DataBlock's default splitting applies.
dz_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=list(vocab_dict.keys()))),
                               get_items=get_image_files,
                               #splitter=lambda x: [list(range(len(x)))], # No splitting needed for a single eval set
                               get_y=get_y_dz,
                               item_tfms=Resize(224, method='crop'), # resize/crop every item to 224x224, no random augmentation
                               batch_tfms=[Normalize.from_stats(*imagenet_stats)]) # ImageNet normalisation only

# Create DataLoaders for Dark Zurich (loaded in the main process: num_workers=0)
dz_dls = dz_datablock.dataloaders(IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=0, pin_memory=True, prefetch_factor=4) # no explicit device is passed in this cell

# NOTE(review): `dl` receives the whole DataLoaders object rather than a single loader such as
# dz_dls.valid — confirm this call works; the next cell replaces it with a manual loop.
inputs_dz, preds_dz, targets_dz, decoded_dz = learn.get_preds(
    dl=dz_dls,
    with_input=True,
    with_decoded=True
)

In [None]:
from functools import partial

# Folder with the Dark Zurich validation images
IMAGES_DIR = BASE_PATH/'dark_zurich_val'
# Folder with the matching labelIds masks
DZ_MASKS_DIR = BASE_PATH/'dark_zurich_val_segmaps_labelIds'

# Bind the mask folder so the DataBlock can call get_y with just the image path
get_y_dz = partial(get_adjusted_mask_file_path, mask_source_dir=DZ_MASKS_DIR)

# DataBlock for Dark Zurich. The splitter returns a single index list, so every item
# lands in the first (training) split.
dz_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=list(vocab_dict.keys()))),
                               get_items=get_image_files,
                               splitter=lambda x: [list(range(len(x)))], # No splitting needed for a single eval set
                               get_y=get_y_dz,
                               item_tfms=Resize(224, method='crop'), # Changed to deterministic resize and center-crop to 224x224
                               batch_tfms=[Normalize.from_stats(*imagenet_stats, cuda=False)]) # Explicitly set cuda=False for debugging

# Create DataLoaders for Dark Zurich, with batches kept on the CPU
dz_dls = dz_datablock.dataloaders(IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=0, pin_memory=True, prefetch_factor=4, device='cpu') # Explicitly set device to 'cpu'

# Manual prediction loop used in place of learn.get_preds; results are accumulated per batch.
all_inputs = []
all_preds = []
all_targets = []
all_decoded = []

learn.model.eval() # Set model to evaluation mode
with torch.no_grad(): # Disable gradient calculation for inference
    # NOTE(review): this iterates the *training* loader because it is the only split. Confirm it
    # neither shuffles nor drops a final incomplete batch, or some images may be missing from the metrics.
    for i, (xb, yb) in enumerate(dz_dls.train): # Changed dz_dls to dz_dls.train
        print(f"Processing batch {i}.")
        # Batches arrive on the CPU (device='cpu' above); copy them to the model's device.
        xb_gpu = xb.to(device)
        yb_gpu = yb.to(device)  # not used below

        # Get raw model outputs
        raw_preds = learn.model(xb_gpu)

        # Class id per pixel = argmax over the class dimension (dim=1).
        # Note: `preds` and `decoded_preds` overwrite the notebook-level variables of the same
        # names that hold the BDD validation predictions.
        preds = raw_preds.argmax(dim=1)

        # No separate decoding step is applied: the decoded predictions are the argmax result.
        decoded_preds = preds # Defaulting to preds as decoded_preds if no specific decoder is needed

        all_inputs.append(xb.cpu()) # Store on CPU
        all_preds.append(preds.cpu()) # Store on CPU
        all_targets.append(yb.cpu()) # Store on CPU
        all_decoded.append(decoded_preds.cpu()) # Store on CPU

# Concatenate the per-batch results along the batch dimension.
inputs_dz = torch.cat(all_inputs)
preds_dz = torch.cat(all_preds)
targets_dz = torch.cat(all_targets)
decoded_dz = torch.cat(all_decoded)

print("Manual prediction loop completed.")


In [None]:
# Per-class metrics on Dark Zurich (overwrites `evaluation_df`).
evaluation_df = evaluate_segmentation_model(decoded_dz, targets_dz, vocab_dict)
display(evaluation_df)
# Mean IoU over every row. NOTE(review): unlike the dark-image cells above, this mean
# includes the 'unknown' row — confirm that is intended.
print(evaluation_df['IoU'].mean())

In [None]:
# Plot up to 50 Dark Zurich samples: input, ground truth and prediction.
visualize_segmentation_batch(inputs_dz, targets_dz, decoded_dz, vocab_dict=vocab_dict, num_imgs=50)

## Attempt 2: TrainIds

In [None]:
def get_adjusted_mask_file_path(x, mask_source_dir=None):
    """
    Load the ground-truth mask that belongs to an input image and clamp its labels.

    The mask file name is derived from the image name: a trailing ``_rgb_anon``
    is removed from the stem and ``_gt_labelTrainIds.png`` is appended. All class
    IDs outside the 0-18 range (including 255) are remapped to 19 ('unknown').

    If the mask file does not exist, a warning is printed and a placeholder mask
    with the same size as the input image is returned. The placeholder is filled
    with 19 ('unknown') rather than 0, because 0 is a real class ('road') and
    would silently be scored as ground truth.

    Args:
    - x (Pathlib.Path): Path to the input image file.
    - mask_source_dir (Pathlib.Path): Directory containing the corresponding
                                      masks. Required, despite the ``None``
                                      default kept for backward compatibility.

    Returns:
    - PIL.Image: The adjusted mask image.

    Raises:
    - ValueError: If ``mask_source_dir`` is not provided.
    - FileNotFoundError: If the mask is missing and the placeholder could not be
                         created (e.g. the input image itself cannot be opened).
    """
    if mask_source_dir is None:
        # Fail early with a clear message instead of a TypeError from `None / str`.
        raise ValueError("mask_source_dir must be provided (directory containing the mask files).")

    unknown_id = 19      # label used for 'unknown'
    max_known_id = 18    # highest regular class id

    x = Path(x)
    base = x.stem.removesuffix("_rgb_anon")
    mask_filename = base + "_gt_labelTrainIds.png"

    mask_path = Path(mask_source_dir) / mask_filename

    # Explicitly check for mask existence for debugging
    if not mask_path.exists():
        print(f"WARNING: Mask file not found for image: {x}. Expected at: {mask_path}")
        try:
            # Only the image size is needed; close the file handle right away.
            with Image.open(x) as img:
                width, height = img.size
        except Exception as e:
            print(f"Error creating dummy mask for {x}: {e}")
            raise FileNotFoundError(f"Mask {mask_path} not found and failed to create dummy mask for {x}.") from e
        # uint8 placeholder, (rows, cols) = (height, width), every pixel 'unknown'
        dummy_mask = np.full((height, width), unknown_id, dtype=np.uint8)
        return Image.fromarray(dummy_mask)

    # np.array() forces the pixel data to be read, so the file can be closed afterwards.
    with Image.open(mask_path) as mask_img:
        mask = np.array(mask_img)

    # Remap all values > 18 (including 255) to 19 (unknown)
    mask[mask > max_known_id] = unknown_id

    return Image.fromarray(mask)


In [None]:
# Build DataLoaders for the images under 'dark_zurich_images' with the trainId
# masks and run the model on the validation split.

# Define the directory containing the dark validation images
IMAGES_DIR = BASE_PATH/'dark_zurich_images'
# Define the directory containing the corresponding masks for dark validation images
DZ_MASKS_DIR = BASE_PATH/'dark_zurich_val_segmaps_trainIds'

# Create a partial function for get_adjusted_mask_file_path for the dark validation set
get_y_dz = partial(get_adjusted_mask_file_path, mask_source_dir=DZ_MASKS_DIR)

# Create a new DataBlock for the dark validation dataset
# NOTE(review): aug_transforms is part of batch_tfms although this DataBlock is
# only used for evaluation here -- confirm that no random augmentation is
# applied to dz_dls.valid.
dz_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=list(vocab_dict.keys()))),
                               get_items=get_image_files,
                               splitter=FuncSplitter(custom_splitter),
                               get_y=get_y_dz,
                              item_tfms=Resize(460),
                              batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])

# Create DataLoaders for the dark validation set.
# Note: no `device` argument is passed in this call.
dz_dls = dz_datablock.dataloaders(IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=11, pin_memory=True, prefetch_factor=4, shuffle_train=False)

# Print both splits so the outcome of custom_splitter can be checked
print(dz_dls.train_ds)
print(dz_dls.valid_ds)

# Run inference on the validation DataLoader; with_input / with_decoded also
# return the inputs and the decoded predictions alongside preds and targets.
inputs_dz, preds_dz, targets_dz, decoded_dz = learn.get_preds(
    dl=dz_dls.valid,
    with_input=True,
    with_decoded=True
)

In [None]:
# Per-class evaluation of the decoded predictions against the ground-truth masks
evaluation_df = evaluate_segmentation_model(decoded_dz, targets_dz, vocab_dict)
display(evaluation_df)
# Mean IoU over every row except the last one
# (presumably the 'unknown' class row -- TODO confirm the row order)
print(evaluation_df['IoU'].iloc[:-1].mean())

In [None]:
# Visualise inputs, ground-truth masks and decoded predictions (num_imgs=50 is passed to the helper)
visualize_segmentation_batch(inputs_dz, targets_dz, decoded_dz, vocab_dict=vocab_dict, num_imgs=50)

# Gamma enhancement on Dark Zurich

In [None]:

# Build DataLoaders for the images under 'dark_zurich_images_gamma' with the
# trainId masks and run the model on the validation split.

# Define the directory containing the gamma-enhanced dark validation images
IMAGES_DIR = BASE_PATH/'dark_zurich_images_gamma'
# Define the directory containing the corresponding masks for dark validation images
DZ_MASKS_DIR = BASE_PATH/'dark_zurich_val_segmaps_trainIds'

# Create a partial function for get_adjusted_mask_file_path for the dark validation set
get_y_dz = partial(get_adjusted_mask_file_path, mask_source_dir=DZ_MASKS_DIR)

# Create a new DataBlock for the dark validation dataset
# NOTE(review): aug_transforms is part of batch_tfms although this DataBlock is
# only used for evaluation here -- confirm that no random augmentation is
# applied to dz_dls.valid.
dz_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=list(vocab_dict.keys()))),
                               get_items=get_image_files,
                               splitter=FuncSplitter(custom_splitter),
                               get_y=get_y_dz,
                              item_tfms=Resize(460),
                              batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])

# Create DataLoaders for the dark validation set.
# Note: no `device` argument is passed in this call.
dz_dls = dz_datablock.dataloaders(IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=4, pin_memory=True, prefetch_factor=4)

# Print both splits so the outcome of custom_splitter can be checked
print(dz_dls.train_ds)
print(dz_dls.valid_ds)

# Run inference on the validation DataLoader; with_input / with_decoded also
# return the inputs and the decoded predictions alongside preds and targets.
inputs_dz, preds_dz, targets_dz, decoded_dz = learn.get_preds(
    dl=dz_dls.valid,
    with_input=True,
    with_decoded=True
)

In [None]:
# Per-class evaluation of the decoded predictions against the ground-truth masks
evaluation_df = evaluate_segmentation_model(decoded_dz, targets_dz, vocab_dict)
display(evaluation_df)
# Mean IoU over every row except the last one
# (presumably the 'unknown' class row -- TODO confirm the row order)
print(evaluation_df['IoU'].iloc[:-1].mean())

In [None]:
# Visualise inputs, ground-truth masks and decoded predictions (num_imgs=50 is passed to the helper)
visualize_segmentation_batch(inputs_dz, targets_dz, decoded_dz, vocab_dict=vocab_dict, num_imgs=50)

# Dark Zurich CLAHE-enhanced

In [None]:
# Build DataLoaders for the images under 'dark_zurich_images_clahe' with the
# trainId masks and run the model on the validation split.

# Define the directory containing the CLAHE-enhanced dark validation images
IMAGES_DIR = BASE_PATH/'dark_zurich_images_clahe'
# Define the directory containing the corresponding masks for dark validation images
DZ_MASKS_DIR = BASE_PATH/'dark_zurich_val_segmaps_trainIds'

# Create a partial function for get_adjusted_mask_file_path for the dark validation set
get_y_dz = partial(get_adjusted_mask_file_path, mask_source_dir=DZ_MASKS_DIR)

# Create a new DataBlock for the dark validation dataset
# NOTE(review): aug_transforms is part of batch_tfms although this DataBlock is
# only used for evaluation here -- confirm that no random augmentation is
# applied to dz_dls.valid.
dz_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=list(vocab_dict.keys()))),
                               get_items=get_image_files,
                               splitter=FuncSplitter(custom_splitter),
                               get_y=get_y_dz,
                              item_tfms=Resize(460),
                              batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])
# Create DataLoaders for the dark validation set.
# Note: no `device` argument is passed in this call.
dz_dls = dz_datablock.dataloaders(IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=11, pin_memory=True, prefetch_factor=4)

# Print both splits so the outcome of custom_splitter can be checked
print(dz_dls.train_ds)
print(dz_dls.valid_ds)

# Run inference on the validation DataLoader; with_input / with_decoded also
# return the inputs and the decoded predictions alongside preds and targets.
inputs_dz, preds_dz, targets_dz, decoded_dz = learn.get_preds(
    dl=dz_dls.valid,
    with_input=True,
    with_decoded=True
)

In [None]:
# Per-class evaluation of the decoded predictions against the ground-truth masks
evaluation_df = evaluate_segmentation_model(decoded_dz, targets_dz, vocab_dict)
display(evaluation_df)
# Mean IoU over every row except the last one
# (presumably the 'unknown' class row -- TODO confirm the row order)
print(evaluation_df['IoU'].iloc[:-1].mean())

In [None]:
# Visualise inputs, ground-truth masks and decoded predictions (num_imgs=50 is passed to the helper)
visualize_segmentation_batch(inputs_dz, targets_dz, decoded_dz, vocab_dict=vocab_dict, num_imgs=50)

# Dark Zurich Retinex Enhanced

In [None]:
# Build DataLoaders for the images under 'dark_zurich_images_retinex' with the
# trainId masks and run the model on the validation split.

# Define the directory containing the Retinex-enhanced dark validation images
IMAGES_DIR = BASE_PATH/'dark_zurich_images_retinex'
# Define the directory containing the corresponding masks for dark validation images
DZ_MASKS_DIR = BASE_PATH/'dark_zurich_val_segmaps_trainIds'

# Create a partial function for get_adjusted_mask_file_path for the dark validation set
get_y_dz = partial(get_adjusted_mask_file_path, mask_source_dir=DZ_MASKS_DIR)

# Build the DataBlock in this cell instead of relying on whatever `dz_datablock`
# an earlier cell left in the kernel. The configuration is the same one used by
# the other enhancement sections, so results are unchanged under Run All, but
# the cell no longer depends on hidden state and actually uses `get_y_dz` above.
dz_datablock = DataBlock(blocks=(ImageBlock, MaskBlock(codes=list(vocab_dict.keys()))),
                         get_items=get_image_files,
                         splitter=FuncSplitter(custom_splitter),
                         get_y=get_y_dz,
                         item_tfms=Resize(460),
                         batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])

# Create DataLoaders for the dark validation set.
# Note: no `device` argument is passed in this call.
dz_dls = dz_datablock.dataloaders(IMAGES_DIR, path=BASE_PATH, bs=4, num_workers=4, pin_memory=True, prefetch_factor=4)

# Print both splits so the outcome of custom_splitter can be checked
print(dz_dls.train_ds)
print(dz_dls.valid_ds)

# Run inference on the validation DataLoader; with_input / with_decoded also
# return the inputs and the decoded predictions alongside preds and targets.
inputs_dz, preds_dz, targets_dz, decoded_dz = learn.get_preds(
    dl=dz_dls.valid,
    with_input=True,
    with_decoded=True
)

In [None]:
# Per-class evaluation of the decoded predictions against the ground-truth masks
evaluation_df = evaluate_segmentation_model(decoded_dz, targets_dz, vocab_dict)
display(evaluation_df)
# Mean IoU over every row except the last one
# (presumably the 'unknown' class row -- TODO confirm the row order)
print(evaluation_df['IoU'].iloc[:-1].mean())

In [None]:
# Visualise inputs, ground-truth masks and decoded predictions (num_imgs=50 is passed to the helper)
visualize_segmentation_batch(inputs_dz, targets_dz, decoded_dz, vocab_dict=vocab_dict, num_imgs=50)

# Debugging

In [None]:
# Sanity check: load one Dark Zurich mask through get_y_dz and verify that
# every label it contains lies in the 0-19 range.

# All image files of the currently selected Dark Zurich directory
dz_image_files = get_image_files(IMAGES_DIR)

if len(dz_image_files) == 0:
    print(f"No image files found in {IMAGES_DIR}. Cannot perform mask inspection.")
else:
    # Only the first image is inspected
    sample_image_path = dz_image_files[0]

    # get_y_dz returns the remapped mask as a PIL image; convert it to NumPy
    processed_mask_pil = get_y_dz(sample_image_path)
    processed_mask_np = np.array(processed_mask_pil)

    # Distinct label values present in the remapped mask
    unique_classes_processed = np.unique(processed_mask_np)

    print(f"Sample Image: {sample_image_path.name}")
    print(f"Unique classes in PROCESSED mask: {sorted(unique_classes_processed)}")

    # Any value below 0 or above 19 means the remapping did not work
    if np.any((unique_classes_processed < 0) | (unique_classes_processed > 19)):
        print("WARNING: Some unique classes in processed mask are still outside the 0-19 range. Remapping issue persists.")
    else:
        print("All unique classes in processed mask are within the 0-19 range. Remapping appears successful.")


In [None]:
# Directory holding the Dark Zurich validation masks stored under 'dark_zurich_val_segmaps_labelIds'
mask_dir = BASE_PATH / 'dark_zurich_val_segmaps_labelIds'

# Distinct class values reported by find_unique_classes_in_masks for that directory
unique_classes = find_unique_classes_in_masks(mask_dir)

print(f"Unique classes found: {sorted(unique_classes)}")
print(f"Total number of unique classes: {len(unique_classes)}")
# The masks are expected to contain exactly as many classes as vocab_dict defines
assert len(vocab_dict) == len(unique_classes)

In [None]:
# deleted section
# Leftover cell kept for reference. It passes the validation DataLoader
# explicitly, exactly like the other get_preds cells in this notebook:
# `dl` is meant to be a single DataLoader, not the whole DataLoaders object.
inputs_dz, preds_dz, targets_dz, decoded_dz = learn.get_preds(
    dl=dz_dls.valid,
    with_input=True,
    with_decoded=True
)