---
### Project - Road Image Segmentation for AD/ADAS related applications
---

In [1]:
# Project group members and group number.
NAME1, NAME2 = "Mathanesh Vellingiri Ramasamy", "Dimas Rizky Kurniasalim"
PROJECT_GROUP = "80"

### Check Python version

In [2]:
from platform import python_version_tuple

# The course Conda environment pins Python 3.11; fail fast on any other interpreter.
_py_major, _py_minor = python_version_tuple()[:2]
assert (_py_major, _py_minor) == ("3", "11"), "You are not running Python 3.11. Make sure to run Python through the course Conda environment."

### Install required packages

In [3]:
# pip install numpy matplotlib pandas 
# pip install torch==2.3.1 torchvision --index-url https://download.pytorch.org/whl/cpu
# pip install torch==2.3.1+cu118 torchvision --index-url https://download.pytorch.org/whl/cu118
# pip install torchinfo 
# pip install opencv-python opencv-contrib-python 
# pip install scikit-learn 
# pip install segmentation-models-pytorch

### Verify the package requirements

In [4]:
# Check that every third-party package used in this notebook can be located.
# NOTE: `find_spec` only locates packages, it does not import them. A warning about a
# SciPy/NumPy version mismatch may still be printed when they are imported; it does
# not affect the functionality we require.
from importlib.util import find_spec

# Import names, not pip names: cv2 = opencv-python, sklearn = scikit-learn, PIL = Pillow.
packages_to_find = ["numpy", "scipy", "matplotlib", "pandas", "seaborn", "PIL", "torch", "torchinfo",
                    "torchvision", "cv2", "sklearn", "segmentation_models_pytorch", "tqdm"]

# Collect every missing package before failing, so they can all be installed in one go.
missing_packages = [pkg for pkg in packages_to_find if find_spec(pkg) is None]
all_found = not missing_packages

if not all_found:
    raise ImportError(
        f"There was an error importing: [{', '.join(missing_packages)}]\n Please make sure you followed all instructions correctly."
    )

print("All modules have been imported successfully.")


All modules have been imported successfully.


---
### [0] Imports
---

In [5]:
# Standard Libraries and dealing with files
import os
import random
import glob
import shutil     
import numpy as np
import pandas as pd
import scipy
from scipy import misc
from pathlib import Path
from itertools import chain
from typing import Callable, Union

# Data Visualization and Image Processing
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
from PIL import Image, ImageChops

# PyTorch Modules
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from torchinfo import summary

# Sci-kit-Learn Modules
from sklearn.model_selection import train_test_split                    # For splitting the data
from sklearn.datasets import fetch_openml                               # For accessing datasets from the OpenML database
from sklearn.preprocessing import StandardScaler                        # For Pre-processing
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay    # For Confusion Matrix

# Others
import segmentation_models_pytorch as smp
from tqdm.auto import tqdm   
import warnings
warnings.filterwarnings("ignore")


---
### [1] Data: Pre-processing and Visualization
---

## Extracting the file path and sorting the data

In [6]:
# def extract_file_path(img_path, mask_path, img_format = None, label_format = None, label_filter = None):

#     # Load/Sort - Images and Masks/Labels
#     img_files = sorted(list(Path(img_path).glob(img_format)))
#     label_files = sorted(list(Path(mask_path).glob(label_format)))
    
#     # When images and labels count mis-matches: Filter for specific labels to match images and masks ('road' in the Kitti dataset)
#     if label_filter:
#         label_files = [file for file in label_files if label_filter in os.path.basename(file)]
#     return img_files, label_files

### One-hot encoding (multi-class) and binary road/non-road labels

In [7]:
# def one_hot_vector_encode(label_array, img, dataset = None):

#     # Choosing the Dataset
#     if dataset == 'sample':
#         num_classes = 3
#         road_label = np.array([0, 128, 255])    # class: road
#         non_road_label = np.array([255, 0, 0])  # class: non-road
#         other_road_label = np.array([0, 0, 0])  # class: other/opposite road

#         # One-Hot-Vector-Encoding (ohve)
#         ohve_label = np.zeros_like(label_array)
#         # ohve_label = np.zeros((label_array.shape[0], label_array.shape[1], num_classes), dtype=np.uint8)
#         ohve_label[:, :, 0] = np.all(label_array == road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 1] = np.all(label_array == other_road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 2] = np.all(label_array == non_road_label, axis = 2).astype(np.uint8)

#         # Binary Encoding (Road/Non-road)
#         # binary_label = (1 - np.all(label_array == non_road_label, axis = 2)).astype(np.uint8)
#         binary_label = ((ohve_label[:, :, 0] == 1)).astype(np.uint8)  # Road is 1, all else is 0
        
#     elif dataset == 'kitti':
#         num_classes = 3
#         road_label = np.array([255, 0, 255])    # class: road
#         non_road_label = np.array([255, 0, 0])  # class: non-road
#         other_road_label = np.array([0, 0, 0])  # class: other/opposite road

#         # One-Hot-Vector-Encoding (ohve)
#         ohve_label = np.zeros_like(label_array)
#         # ohve_label = np.zeros((label_array.shape[0], label_array.shape[1], num_classes), dtype=np.uint8)
#         ohve_label[:, :, 0] = np.all(label_array == road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 1] = np.all(label_array == other_road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 2] = np.all(label_array == non_road_label, axis = 2).astype(np.uint8)

#         # Binary Encoding (Road/Non-road)
#         # binary_label = (1 - np.all(label_array == non_road_label, axis = 2)).astype(np.uint8)
#         binary_label = ((ohve_label[:, :, 0] == 1)).astype(np.uint8)  # Road == 1
        
#     elif dataset == 'comma10k':
#         #402020 [64, 32, 32] - road (all parts, anywhere nobody would look at you funny for driving)
#         #00ff66 [0, 255, 102] - movable (vehicles and people/animals)
        
#         #ff0000 [255, 0, 0] - lane markings (don't include non lane markings like turn arrows and crosswalks)
#         #808060 [128, 128, 96] - undrivable
#         #cc00ff [204, 0, 255] - my car (and anything inside it, including wires, mounts, etc. No reflections)
#         #00ccff [0, 204, 255] - movable in my car (people inside the car, optional: imgsd only)

#         num_classes = 6
#         road_label = np.array([64, 32, 32])          # class: road
#         non_road_label = np.array([0, 255, 102])     # class: non-road
#         other_road_label = np.array([128, 128, 96])  # class: undrivable surface
#         movable_label = np.array([0, 255, 102])      # class: movable
#         lane_markings_label = np.array([255, 0, 0])  # class: lane markings
#         my_car_label = np.array([204, 0, 255])       # class: my car

#         # One-Hot-Vector-Encoding (ohve)
#         # ohve_label = np.zeros_like(label_array)
#         ohve_label = np.zeros((label_array.shape[0], label_array.shape[1], num_classes), dtype=np.uint8)
#         ohve_label[:, :, 0] = np.all(label_array == road_label, axis=2).astype(np.uint8)           # road
#         ohve_label[:, :, 1] = np.all(label_array == other_road_label, axis=2).astype(np.uint8)     # undrivable surface
#         ohve_label[:, :, 2] = np.all(label_array == non_road_label, axis=2).astype(np.uint8)       # non-road
#         ohve_label[:, :, 3] = np.all(label_array == movable_label, axis=2).astype(np.uint8)        # movable
#         ohve_label[:, :, 4] = np.all(label_array == lane_markings_label, axis=2).astype(np.uint8)  # lane markings
#         ohve_label[:, :, 5] = np.all(label_array == my_car_label, axis=2).astype(np.uint8)         # my car
        
#         # Binary Encoding (Road/Non-road)
#         # binary_label = (1 - np.all(label_array == non_road_label, axis = 2)).astype(np.uint8)
#         binary_label = ((ohve_label[:, :, 0] == 1) | (ohve_label[:, :, 4] == 1)).astype(np.uint8)  # Road and Lane Marking == 1
        
#     elif dataset == 'bdd100k':
#         num_classes = 3
#         road_label = np.array([128, 64, 128])   # class: road
#         non_road_label = np.array([0, 0, 0])    # class: non-road
#         other_road_label = np.array([0, 0, 0])  # class: others

#         # One-Hot-Vector-Encoding (ohve)
#         ohve_label = np.zeros_like(label_array)
#         # ohve_label = np.zeros((label_array.shape[0], label_array.shape[1], num_classes), dtype=np.uint8)
#         ohve_label[:, :, 0] = np.all(label_array == road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 1] = np.all(label_array == other_road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 2] = np.all(label_array == non_road_label, axis = 2).astype(np.uint8)

#         # Binary Encoding (Road/Non-road)
#         # binary_label = (1 - np.all(label_array == non_road_label, axis = 2)).astype(np.uint8)
#         binary_label = ((ohve_label[:, :, 0] == 1)).astype(np.uint8)  # Road == 1
        
#     elif dataset == 'cityscapes':
#         num_classes = 3
#         road_label = np.array([128, 64, 128])   # class: road
#         non_road_label = np.array([0, 0, 0])    # class: non-road
#         other_road_label = np.array([0, 0, 0])  # class: others
        
#         # One-Hot-Vector-Encoding (ohve)
#         ohve_label = np.zeros_like(label_array)
#         # ohve_label = np.zeros((label_array.shape[0], label_array.shape[1], num_classes), dtype=np.uint8)
#         ohve_label[:, :, 0] = np.all(label_array == road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 1] = np.all(label_array == other_road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 2] = np.all(label_array == non_road_label, axis = 2).astype(np.uint8)

#         # Binary Encoding (Road/Non-road)
#         # binary_label = (1 - np.all(label_array == non_road_label, axis = 2)).astype(np.uint8)
#         binary_label = ((ohve_label[:, :, 0] == 1)).astype(np.uint8)  # Road == 1
            
#     elif dataset == 'nuscenes':
#         num_classes = 3
#         road_label = np.array([128, 64, 128])   # class: road
#         non_road_label = np.array([0, 0, 0])    # class: non-road
#         other_road_label = np.array([0, 0, 0])  # class: others
        
#         # One-Hot-Vector-Encoding (ohve)
#         ohve_label = np.zeros_like(label_array)
#         # ohve_label = np.zeros((label_array.shape[0], label_array.shape[1], num_classes), dtype=np.uint8)
#         ohve_label[:, :, 0] = np.all(label_array == road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 1] = np.all(label_array == other_road_label, axis = 2).astype(np.uint8)
#         ohve_label[:, :, 2] = np.all(label_array == non_road_label, axis = 2).astype(np.uint8)

#         # Binary Encoding (Road/Non-road)
#         # binary_label = (1 - np.all(label_array == non_road_label, axis = 2)).astype(np.uint8)
#         binary_label = ((ohve_label[:, :, 0] == 1)).astype(np.uint8)  # Road == 1
    
#     # Plotting
#     plt.figure(figsize = (15, 5))
    
#     # Plot the One-Hot-Vector-Encoded Image (Multi-class)
#     plt.subplot(1, 3, 1)
#     plt.imshow(ohve_label * 255)
#     plt.title("One-Hot-Vector Encoding (Multi-class)")
#     plt.axis('off')
    
#     # Plot the Binary Mask/Label (Road/Non-road)
#     plt.subplot(1, 3, 2)
#     plt.imshow(binary_label * 255, cmap = 'gray')
#     plt.title("Binary Classification: Road/Non-road")
#     plt.axis('off')
    
#     # Overlay Binary Mask
#     binary_label_img = Image.fromarray(binary_label * 255)
#     overlay_binary = ImageChops.add(img, binary_label_img.convert("RGB"), scale=1.5)
#     plt.subplot(1, 3, 3)
#     plt.imshow(overlay_binary)
#     plt.title("Binary Mask Overlay")
#     plt.axis('off')
    
#     plt.tight_layout()
#     plt.show()

### Visualization

In [8]:
# def visualize_data(img_files, mask_files, overlay = False, one_hot_encode = False, dataset = None):
#     idx = random.randint(0, len(img_files) - 1)
    
#     img = Image.open(img_files[idx])
#     mask = Image.open(mask_files[idx])
    
#     print(f"Image Dimension: {img.size}")
#     print(f"Mask/Label Dimension: {mask.size}")
    
#     img_np = np.array(img)
#     mask_np = np.array(mask)

#     plt.figure(figsize=(15, 5))
    
#     # Plot original image
#     plt.subplot(1, 3, 1)
#     plt.imshow(img_np)
#     plt.title("Raw Original Image")
#     plt.axis('off')
    
#     # Plot mask image
#     plt.subplot(1, 3, 2)
#     plt.imshow(mask_np)
#     plt.title("Mask/Label")
#     plt.axis('off')
    
#     # To overlay mask onto the original image
#     if overlay:
#         mask = mask.resize(img.size)  # Ensure: mask dimension == image dimension
#         overlay_img = ImageChops.add(img, mask, scale = 1.5)
#         plt.subplot(1, 3, 3)
#         plt.imshow(overlay_img)
#         plt.title("Overlayed Mask")
#         plt.axis('off')
    
#     plt.tight_layout()
#     plt.show()
    
#     # To perform one-hot encoding
#     if one_hot_encode:
#         one_hot_vector_encode(mask_np, img, dataset = dataset)

### Processing the dataset

In [9]:
# def process_dataset(img_dir, mask_dir, img_format = "*.png", label_format = "*.png", 
#                     label_filter = None, overlay = False, one_hot_encode = False, dataset = None):
#     img_files, label_files = extract_file_path(img_dir, mask_dir, img_format, label_format, label_filter)  # Get image and label paths
#     print(f"Total Images = {len(img_files)}")           # Print Total Image Count
#     print(f"Total Masks/Labels = {len(label_files)}")   # Print Total Mask/Label Count
#     visualize_data(img_files, label_files, overlay = overlay, one_hot_encode = one_hot_encode, dataset = dataset) # Visualize data

### Full code: consolidated pre-processing and visualization functions

In [None]:
# Extracting the file path and sorting the data (supporting .png and .jpg formats):
def extract_file_path(img_path, mask_path, img_format = "*.png", label_format = "*.png", label_filter = None):
    """Return the sorted image paths and the sorted mask/label paths of a dataset.

    Args:
        img_path: Directory containing the images.
        mask_path: Directory containing the masks/labels.
        img_format: Glob pattern for the images (e.g. "*.png", "*.jpg"). Defaults
            to "*.png"; the previous default of None made ``Path.glob`` fail.
        label_format: Glob pattern for the masks/labels. Defaults to "*.png".
        label_filter: Optional substring. When given, only label files whose name
            contains it are kept. This is used when there are more labels than
            images (e.g. 'road' in the KITTI dataset, which also ships lane labels).

    Returns:
        tuple: ``(img_files, label_files)``, two lists of ``pathlib.Path`` sorted by
        path. Pairs are formed by position only; the counts are not checked here.
    """
    img_files = sorted(Path(img_path).glob(img_format))
    label_files = sorted(Path(mask_path).glob(label_format))

    if label_filter:
        label_files = [file for file in label_files if label_filter in file.name]
    return img_files, label_files

# Update the one-hot encoding to handle different datasets (adjusting for comma10k dataset)
def one_hot_vector_encode(label_array, img, dataset = None):
    """One-hot encode a colour mask, derive a binary road mask and plot both.

    Each channel of the one-hot encoding marks the pixels whose RGB colour equals
    one class colour of the chosen dataset. The binary mask is 1 on the "road"
    channels (road, plus lane markings for comma10k) and 0 elsewhere.

    Args:
        label_array: Colour mask as an array of shape (H, W, 3) or (H, W, 4). Only
            the first three channels are compared, so RGBA masks are accepted.
        img: The matching image (PIL image or array), used for the overlay plot.
        dataset: One of 'sample', 'kitti', 'comma10k', 'bdd100k', 'cityscapes',
            'nuscenes'.

    Returns:
        tuple: ``(ohve_label, binary_label)``, uint8 arrays of shape
        (H, W, num_classes) and (H, W).

    Raises:
        ValueError: If ``dataset`` is not supported or ``label_array`` is not a
            colour mask.
    """
    # Colour table shared by the datasets that only distinguish the road colour.
    road_only = [
        ("road", [128, 64, 128]),
        ("other", [0, 0, 0]),
        ("non_road", [0, 0, 0]),
    ]
    # dataset -> (ordered (class name, RGB colour) list, channel indices counted as road)
    dataset_specs = {
        'sample': ([
            ("road", [0, 128, 255]),
            ("other_road", [0, 0, 0]),
            ("non_road", [255, 0, 0]),
        ], (0,)),
        'kitti': ([
            ("road", [255, 0, 255]),
            ("other_road", [0, 0, 0]),
            ("non_road", [255, 0, 0]),
        ], (0,)),
        # comma10k colour codes: #402020 road, #808060 undrivable, #00ccff movable in
        # my car, #00ff66 movable, #ff0000 lane markings, #cc00ff my car.
        'comma10k': ([
            ("road", [64, 32, 32]),
            ("undrivable", [128, 128, 96]),
            ("movable_in_car", [0, 204, 255]),
            ("movable", [0, 255, 102]),
            ("lane_markings", [255, 0, 0]),
            ("my_car", [204, 0, 255]),
        ], (0, 4)),  # lane markings lie on the road, so they count as road
        'bdd100k': (road_only, (0,)),
        'cityscapes': (road_only, (0,)),
        'nuscenes': (road_only, (0,)),
    }
    if dataset not in dataset_specs:
        raise ValueError(f"Dataset {dataset} is not supported")
    class_colors, road_channels = dataset_specs[dataset]

    label_array = np.asarray(label_array)
    if label_array.ndim != 3 or label_array.shape[2] < 3:
        raise ValueError(f"Expected a colour mask of shape (H, W, 3) or (H, W, 4), got {label_array.shape}")
    rgb = label_array[:, :, :3]  # drop a possible alpha channel before comparing colours

    # One-Hot-Vector-Encoding (ohve): one channel per class colour
    ohve_label = np.zeros((rgb.shape[0], rgb.shape[1], len(class_colors)), dtype=np.uint8)
    for channel, (_, color) in enumerate(class_colors):
        ohve_label[:, :, channel] = np.all(rgb == np.array(color), axis = 2).astype(np.uint8)

    # Binary Encoding (Road/Non-road): 1 where any road channel is set
    binary_label = ohve_label[:, :, list(road_channels)].any(axis = 2).astype(np.uint8)

    # Plotting
    plt.figure(figsize = (15, 5))

    # Plot the One-Hot-Vector-Encoded Image (Multi-class)
    plt.subplot(1, 3, 1)
    if ohve_label.shape[2] == 3:
        # Three classes map directly onto the R, G and B channels.
        plt.imshow(ohve_label * 255)
    else:
        # imshow cannot show more than 4 channels: show a class-index map instead
        # (0 = pixel matches no listed colour, k = class k-1).
        class_map = np.where(ohve_label.any(axis = 2), ohve_label.argmax(axis = 2) + 1, 0)
        plt.imshow(class_map, cmap = 'tab10', vmin = 0, vmax = 9, interpolation = 'nearest')
    plt.title("One-Hot-Vector Encoding (Multi-class)")
    plt.axis('off')

    # Plot the Binary Mask/Label (Road/Non-road)
    plt.subplot(1, 3, 2)
    plt.imshow(binary_label * 255, cmap = 'gray')
    plt.title("Binary Classification: Road/Non-road")
    plt.axis('off')

    # Overlay Binary Mask. ImageChops.add needs two PIL images with the same bands,
    # so both are converted to RGB and the mask is resized to the image size.
    base_img = img if isinstance(img, Image.Image) else Image.fromarray(np.asarray(img))
    base_img = base_img.convert("RGB")
    binary_label_img = Image.fromarray(binary_label * 255).convert("RGB").resize(base_img.size, Image.NEAREST)
    overlay_binary = ImageChops.add(base_img, binary_label_img, scale=1.5)
    plt.subplot(1, 3, 3)
    plt.imshow(overlay_binary)
    plt.title("Binary Mask Overlay")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    return ohve_label, binary_label

# Pre-process the dataset
def process_dataset(img_dir, mask_dir, img_format = "*.png", label_format = "*.png", 
                    label_filter = None, overlay = False, one_hot_encode = False, dataset = None):
    """Print image and label counts for a dataset, then visualize one random pair.

    A thin wrapper: the file-selection arguments go to ``extract_file_path`` and
    the display arguments go to ``visualize_data``, unchanged.
    """
    images, labels = extract_file_path(img_dir, mask_dir, img_format, label_format, label_filter)

    for caption, files in (("Total Images", images), ("Total Masks/Labels", labels)):
        print(f"{caption} = {len(files)}")

    visualize_data(images, labels, overlay=overlay, one_hot_encode=one_hot_encode, dataset=dataset)

# Visualizing the dataset
def visualize_data(img_files, mask_files, overlay = False, one_hot_encode = False, dataset = None):
    """Plot a random image, its mask and, optionally, the mask overlaid on the image.

    Args:
        img_files: Image paths, paired with ``mask_files`` by position.
        mask_files: Mask/label paths.
        overlay: When True, also plot the mask added on top of the image.
        one_hot_encode: When True, also call ``one_hot_vector_encode`` on the mask.
        dataset: Dataset name forwarded to ``one_hot_vector_encode``.

    Raises:
        ValueError: If no image files are given or the two lists differ in length.
    """
    if not img_files:
        raise ValueError("No image files to visualize")
    if len(img_files) != len(mask_files):
        raise ValueError(f"Got {len(img_files)} images but {len(mask_files)} masks/labels")

    idx = random.randint(0, len(img_files) - 1)

    # Open both files in a `with` block so their handles are closed after loading.
    with Image.open(img_files[idx]) as img_file, Image.open(mask_files[idx]) as mask_file:
        print(f"Image Dimension: {img_file.size}")
        print(f"Mask/Label Dimension: {mask_file.size}")
        # Convert to RGB: masks can be palette ('P') or RGBA PNGs. The one-hot encoding
        # compares RGB colours, and ImageChops.add needs both images to have the same bands.
        img = img_file.convert("RGB")
        mask = mask_file.convert("RGB")

    img_np = np.array(img)
    mask_np = np.array(mask)

    plt.figure(figsize=(15, 5))

    # Plot original image
    plt.subplot(1, 3, 1)
    plt.imshow(img_np)
    plt.title("Raw Original Image")
    plt.axis('off')

    # Plot mask image
    plt.subplot(1, 3, 2)
    plt.imshow(mask_np)
    plt.title("Mask/Label")
    plt.axis('off')

    # To overlay mask onto the original image
    if overlay:
        # Ensure mask dimension == image dimension; nearest keeps label colours unblended
        overlay_mask = mask.resize(img.size, Image.NEAREST)
        overlay_img = ImageChops.add(img, overlay_mask, scale = 1.5)
        plt.subplot(1, 3, 3)
        plt.imshow(overlay_img)
        plt.title("Overlayed Mask")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

    # To perform one-hot encoding
    if one_hot_encode:
        one_hot_vector_encode(mask_np, img, dataset = dataset)


In [10]:
def extract_file_path(img_path, mask_path, img_format="*.png", label_format="*.png", label_filter=None):
    """Collect the sorted image and label paths and check that they pair one-to-one.

    Args:
        img_path: Directory holding the images.
        mask_path: Directory holding the masks/labels.
        img_format: Glob pattern for the images.
        label_format: Glob pattern for the masks/labels.
        label_filter: Optional substring that label file names must contain
            (e.g. 'road' for KITTI, whose label folder also has lane masks).

    Returns:
        tuple: ``(img_files, label_files)``, both sorted lists of ``pathlib.Path``.

    Raises:
        ValueError: When the number of images and labels differs.
    """
    def _sorted_matches(folder, pattern):
        # Every file in `folder` matching `pattern`, in sorted order.
        return sorted(Path(folder).glob(pattern))

    img_files = _sorted_matches(img_path, img_format)
    label_files = _sorted_matches(mask_path, label_format)

    if label_filter:
        label_files = [path for path in label_files if label_filter in path.name]

    # Images and labels are paired by sorted position, so the counts must agree.
    if len(label_files) != len(img_files):
        raise ValueError("Mismatch in the number of images and labels")

    return img_files, label_files

def one_hot_vector_encode(label_array, img, dataset=None):
    """One-hot encode a colour mask, derive a binary road mask and plot both.

    Args:
        label_array: Colour mask as an array of shape (H, W, 3) or (H, W, 4). Only
            the first three channels are compared, so RGBA masks are accepted.
        img: The matching image, either a PIL image or an array; used for the overlay.
        dataset: One of 'sample', 'kitti', 'comma10k', 'bdd100k', 'cityscapes',
            'nuscenes'.

    Returns:
        tuple: ``(ohve_label, binary_label)``, uint8 arrays of shape
        (H, W, num_classes) and (H, W).

    Raises:
        ValueError: If ``dataset`` is not supported or ``label_array`` is not a
            colour mask.
    """
    # Per dataset: ordered class colours (one one-hot channel each) and the classes
    # counted as "road" in the binary mask.
    dataset_classes = {
        'sample': {
            'class_labels': {
                'road': [0, 128, 255],
                'non_road': [255, 0, 0],
                'other_road': [0, 0, 0],
            },
            'binary_classes': ('road',),
        },
        'kitti': {
            'class_labels': {
                'road': [255, 0, 255],
                'non_road': [255, 0, 0],
                'other_road': [0, 0, 0],
            },
            'binary_classes': ('road',),
        },
        'comma10k': {
            'class_labels': {
                'road': [64, 32, 32],
                'movable': [0, 255, 102],
                'lane_markings': [255, 0, 0],
                'undrivable': [128, 128, 96],
                'my_car': [204, 0, 255],
                'movable_in_car': [0, 204, 255],
            },
            # Lane markings are painted on the road surface, so they count as road.
            'binary_classes': ('road', 'lane_markings'),
        },
        'bdd100k': {
            'class_labels': {
                'road': [128, 64, 128],
                'non_road': [0, 0, 0],
                'other': [0, 0, 0],
            },
            'binary_classes': ('road',),
        },
        'cityscapes': {
            'class_labels': {
                'road': [128, 64, 128],
                'non_road': [0, 0, 0],
                'other': [0, 0, 0],
            },
            'binary_classes': ('road',),
        },
        'nuscenes': {
            'class_labels': {
                'road': [128, 64, 128],
                'non_road': [0, 0, 0],
                'other': [0, 0, 0],
            },
            'binary_classes': ('road',),
        }
    }

    if dataset not in dataset_classes:
        raise ValueError(f"Dataset {dataset} is not supported")

    class_labels = dataset_classes[dataset]['class_labels']
    binary_classes = dataset_classes[dataset]['binary_classes']
    num_classes = len(class_labels)

    label_array = np.asarray(label_array)
    if label_array.ndim != 3 or label_array.shape[2] < 3:
        raise ValueError(f"Expected a colour mask of shape (H, W, 3) or (H, W, 4), got {label_array.shape}")
    rgb = label_array[:, :, :3]  # drop a possible alpha channel before comparing colours

    # Create the one-hot encoded label array, one channel per class colour
    ohve_label = np.zeros((rgb.shape[0], rgb.shape[1], num_classes), dtype=np.uint8)
    for i, color in enumerate(class_labels.values()):
        ohve_label[:, :, i] = np.all(rgb == np.array(color), axis=2).astype(np.uint8)

    # Binary label: 1 where any of the dataset's road classes is present
    class_names = list(class_labels)
    road_channels = [class_names.index(name) for name in binary_classes]
    binary_label = ohve_label[:, :, road_channels].any(axis=2).astype(np.uint8)

    plt.figure(figsize=(15, 5))

    # Plot the One-Hot-Vector-Encoded Image (Multi-class)
    plt.subplot(1, 3, 1)
    if num_classes == 3:
        # Three classes map directly onto the R, G and B channels.
        plt.imshow(ohve_label * 255)
    else:
        # imshow cannot show more than 4 channels: show a class-index map instead
        # (0 = pixel matches no listed colour, k = class k-1).
        class_map = np.where(ohve_label.any(axis=2), ohve_label.argmax(axis=2) + 1, 0)
        plt.imshow(class_map, cmap='tab10', vmin=0, vmax=9, interpolation='nearest')
    plt.title("One-Hot-Vector Encoding (Multi-class)")
    plt.axis('off')

    # Plot the Binary Mask/Label (Road/Non-road)
    plt.subplot(1, 3, 2)
    plt.imshow(binary_label * 255, cmap='gray')
    plt.title("Binary Classification: Road/Non-road")
    plt.axis('off')

    # Overlay Binary Mask. ImageChops.add needs two PIL images with the same bands,
    # so an array image is converted first, and the mask is resized to the image size.
    base_img = img if isinstance(img, Image.Image) else Image.fromarray(np.asarray(img))
    base_img = base_img.convert("RGB")
    binary_label_img = Image.fromarray(binary_label * 255).convert("RGB").resize(base_img.size, Image.NEAREST)
    overlay_binary = ImageChops.add(base_img, binary_label_img, scale=1.5)
    plt.subplot(1, 3, 3)
    plt.imshow(overlay_binary)
    plt.title("Binary Mask Overlay")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    return ohve_label, binary_label

def process_dataset(img_dir, mask_dir, img_format="*.png", label_format="*.png", 
                    label_filter=None, overlay=False, one_hot_encode=False, dataset=None,
                    num_samples=1):
    """Report the size of a dataset and visualize a few random image/label pairs.

    Args:
        img_dir: Directory holding the images.
        mask_dir: Directory holding the masks/labels.
        img_format: Glob pattern for the images.
        label_format: Glob pattern for the masks/labels.
        label_filter: Optional substring that label file names must contain.
        overlay: When True, also plot each mask overlaid on its image.
        one_hot_encode: When True, also plot the one-hot and binary road encodings.
        dataset: Dataset name used to pick the class colours for the encoding.
        num_samples: Number of random pairs to visualize (default 1).

    Raises:
        ValueError: Propagated from ``extract_file_path`` when images and labels
            do not pair up.
    """
    img_files, label_files = extract_file_path(img_dir, mask_dir, img_format, label_format, label_filter)
    print(f"Total Images = {len(img_files)}")
    print(f"Total Masks/Labels = {len(label_files)}")

    if not img_files:
        print("No images found - nothing to visualize.")
        return

    # Visualize only a handful of random pairs. Looping over the whole dataset would
    # open one figure per image (thousands for comma10k/bdd100k). visualize_data also
    # hands one_hot_vector_encode a PIL image, which ImageChops requires.
    for _ in range(num_samples):
        visualize_data(img_files, label_files, overlay=overlay,
                       one_hot_encode=one_hot_encode, dataset=dataset)


## [1.1] Small Sample Dataset
To verify the data pre-processing and visualization functions above, we first load and visualize a small sample dataset.

In [11]:
# Small sample dataset, used to sanity-check the helpers above.
small_img_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/Road_Segment/images"
small_mask_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/Road_Segment/masks"

process_dataset(
    small_img_dir,
    small_mask_dir,
    img_format="*.png",
    label_format="*.png",
    label_filter=None,
    overlay=False,
    one_hot_encode=False,
    dataset='sample',
)


### [1.2] KITTI Road Dataset

In [None]:
# KITTI road benchmark: keep only the '*road*' ground truths (the folder also has lane labels).
kitti_img_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/KITTI_ROAD/data_road/training/image_2"
kitti_mask_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/KITTI_ROAD/data_road/training/gt_image_2"

process_dataset(
    kitti_img_dir,
    kitti_mask_dir,
    img_format="*.png",
    label_format="*.png",
    label_filter="road",
    overlay=True,
    one_hot_encode=True,
    dataset='kitti',
)


### [1.3] comma10k Dataset 

In [None]:
# comma10k: one colour mask per image, same file names in both folders.
comma10k_img_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/comma10k/imgs"
comma10k_mask_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/comma10k/masks"

process_dataset(
    comma10k_img_dir,
    comma10k_mask_dir,
    img_format="*.png",
    label_format="*.png",
    label_filter=None,
    overlay=True,
    one_hot_encode=True,
    dataset='comma10k',
)

### [1.4] bdd100k Dataset

In [None]:
# bdd100k: JPEG images paired with PNG colour labels.
bdd100k_img_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/bdd100k/images/train"
bdd100k_mask_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/bdd100k/color_labels/train"

process_dataset(
    bdd100k_img_dir,
    bdd100k_mask_dir,
    img_format="*.jpg",
    label_format="*.png",
    label_filter=None,
    overlay=True,
    one_hot_encode=True,
    dataset='bdd100k',
)


### [1.5] Cityscapes Dataset

In [None]:
# Cityscapes: each gtFine folder has several PNGs per image (colour, instance-id and
# label-id maps). Select only the colour maps; their palette (road = [128, 64, 128])
# is what the 'cityscapes' encoding expects. With "*.png", the image/label counts
# would not match.
cityscapes_img_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/Cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train/aachen"
cityscapes_mask_dir = "/SCM_CHALMERS/GitHub/Deep-Machine-Learning-Project-SSY340/Datasets/Cityscapes/gtFine_trainvaltest/gtFine/train/aachen"

process_dataset(cityscapes_img_dir, cityscapes_mask_dir, img_format = "*.png", label_format = "*_gtFine_color.png", 
                label_filter = None, overlay = True, one_hot_encode = True, dataset = 'cityscapes')