In [None]:
from collections import namedtuple

In [None]:
import torch

import cv2

import os

import pandas as pd

import matplotlib.pyplot as plt

import json

from PIL import Image

from matplotlib.patches import Polygon

In [None]:
from torchvision.transforms import Compose
import torch.nn as nn

# Create File Path Dataframes



- The DataFrames hold file paths rather than the loaded images; each train row also carries the matching label-mask and polygon-JSON paths.

In [None]:
#--------------------------------------------------------------------------------

# Definitions

#--------------------------------------------------------------------------------



# a label and all meta information

# Fields, in order:
#   name         - identifier of the label, e.g. 'car', 'person'; used to uniquely name a class.
#   id           - integer ID representing the label in ground-truth images; -1 means the label
#                  has no ID and is ignored when creating ground-truth images (e.g. license plate).
#   trainId      - integer ID that replaces `id` when creating ground-truth images for training.
#                  Several labels may share a trainId and are then merged into one class; for the
#                  inverse mapping the label defined first in the list is used.
#   category     - name of the category this label belongs to.
#   categoryId   - ID of that category, used for category-level ground truth.
#   hasInstances - whether this label distinguishes between single instances.
#   ignoreInEval - whether pixels with this ground-truth label are ignored during evaluation.
#   color        - display colour of this label.
Label = namedtuple('Label', 'name id trainId category categoryId hasInstances ignoreInEval color')





#--------------------------------------------------------------------------------

# A list of all labels

#--------------------------------------------------------------------------------



# Please adapt the train IDs as appropriate for you approach.

# Note that you might want to ignore labels with ID 255 during training.

# Make sure to provide your results using the original IDs and not the training IDs.

# Note that many IDs are ignored in evaluation and thus you never need to predict these!



# One Label per Cityscapes class. Rows whose trainId is 255 are collapsed into the
# ignore value for training (see the note above about ignoring ID 255).
# NOTE(review): 'license plate' has trainId 19 here even though ignoreInEval is True;
# confirm whether it should be mapped to 255 like the other ignored classes. With
# trainId 19 it becomes a 20th training class.
labels = [

    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color

    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),

    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),

    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),

    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),

    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),

    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),

    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),

    Label(  'road'                 ,  7 ,        0 , 'ground'          , 1       , False        , False        , (128, 64,128) ),

    Label(  'sidewalk'             ,  8 ,        1 , 'ground'          , 1       , False        , False        , (244, 35,232) ),

    Label(  'parking'              ,  9 ,      255 , 'ground'          , 1       , False        , True         , (250,170,160) ),

    Label(  'rail track'           , 10 ,      255 , 'ground'          , 1       , False        , True         , (230,150,140) ),

    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),

    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),

    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),

    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),

    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),

    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),

    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),

    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),

    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),

    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),

    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),

    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),

    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),

    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),

    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),

    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),

    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),

    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),

    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),

    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),

    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),

    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),

    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),

    Label(  'license plate'        , 34 ,       19 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),

]

In [None]:
# Function to generate file paths for train dataset

def get_train_file_paths(images_dir, labels_dir):
    """Build a DataFrame of train-set file paths, one row per image.

    Walks ``images_dir/<city>/*_leftImg8bit.png``. For each image it derives the
    matching label-mask and polygon-JSON paths under ``labels_dir/<city>``. A city
    is visited only when its folder exists in both directories. The derived label
    and polygon paths are not checked for existence.

    Cities and files are visited in sorted order. ``os.listdir`` returns entries in
    arbitrary order, and without sorting, positional lookups such as
    ``train_df.iloc[1000]`` would differ between machines and runs.

    Args:
        images_dir: root directory of the train images (one sub-folder per city).
        labels_dir: root directory of the train labels (same city sub-folders).

    Returns:
        DataFrame with columns ``image_path``, ``image_label_path`` and
        ``image_polygons_path``.
    """
    image_suffix = '_leftImg8bit.png'
    rows = []

    for city in sorted(os.listdir(images_dir)):
        city_image_dir = os.path.join(images_dir, city)
        city_label_dir = os.path.join(labels_dir, city)
        if not (os.path.isdir(city_image_dir) and os.path.isdir(city_label_dir)):
            continue

        for file in sorted(os.listdir(city_image_dir)):
            if not file.endswith(image_suffix):
                continue
            # Strip only the trailing suffix (str.replace would also remove inner matches).
            image_name = file[:-len(image_suffix)]
            rows.append([
                os.path.join(city_image_dir, file),
                os.path.join(city_label_dir, f'{image_name}_gtFine_labelTrainIds.png'),
                os.path.join(city_label_dir, f'{image_name}_gtFine_polygons.json'),
            ])

    return pd.DataFrame(rows, columns=['image_path', 'image_label_path', 'image_polygons_path'])



# Function to generate file paths for test dataset

def get_test_file_paths(test_dir):
    """Build a DataFrame with one ``image_path`` row per test image.

    Walks ``test_dir/<city>/*_leftImg8bit.png``. Non-directory entries at the city
    level are skipped. Cities and files are visited in sorted order, because
    ``os.listdir`` order is arbitrary and the row order would otherwise not be
    reproducible.

    Args:
        test_dir: root directory of the test images (one sub-folder per city).

    Returns:
        DataFrame with a single ``image_path`` column.
    """
    rows = []
    for city in sorted(os.listdir(test_dir)):
        city_dir = os.path.join(test_dir, city)
        if not os.path.isdir(city_dir):
            continue
        rows.extend(
            [os.path.join(city_dir, file)]
            for file in sorted(os.listdir(city_dir))
            if file.endswith('_leftImg8bit.png')
        )
    return pd.DataFrame(rows, columns=['image_path'])



# Paths to images and labels directories

# All dataset folders live under one Kaggle input root:
#   <root>/images/{train,test}/<city>/...   and   <root>/labels/train/<city>/...
dataset_root = '/kaggle/input/cityscapes-segmentation'
images_dir_train = f'{dataset_root}/images/train'
labels_dir_train = f'{dataset_root}/labels/train'
images_dir_test = f'{dataset_root}/images/test'

# Path-only DataFrames; the images themselves are loaded lazily later.
train_df = get_train_file_paths(images_dir=images_dir_train, labels_dir=labels_dir_train)
test_df = get_test_file_paths(test_dir=images_dir_test)

## Example paths

In [None]:
# Show the three paths stored for one training sample (row label 100).
example_row = train_df.loc[100]
for column in ('image_path', 'image_label_path', 'image_polygons_path'):
    print(example_row[column])

# Visualization

In [None]:
def visualize_train_row(row):
    """Show an image next to its train-ID segmentation mask.

    Args:
        row: mapping (e.g. a DataFrame row) with ``image_path`` and
            ``image_label_path`` entries.
    """
    import numpy as np  # local: the notebook's numpy import only appears in a later cell

    # Image.open is lazy and keeps the file open; copy/convert inside `with` so the
    # handles are closed once the pixels are loaded.
    with Image.open(row['image_path']) as img:
        image = img.copy()
    with Image.open(row['image_label_path']) as label_img:
        label_ids = np.array(label_img)

    # Hide the 255 "ignore" pixels so the colormap spans only the real train IDs
    # instead of being stretched all the way up to 255.
    label_display = np.ma.masked_equal(label_ids, 255)

    fig, ax = plt.subplots(1, 2, figsize=(15, 7))

    ax[0].imshow(image)
    ax[0].set_title('Original Image')
    ax[0].axis('off')

    # Nearest-neighbour keeps label IDs discrete when the mask is downsampled for
    # display; smoothing interpolation would blend neighbouring IDs into fake classes.
    ax[1].imshow(label_display, interpolation='nearest')
    ax[1].set_title('Segmentation Label Image')
    ax[1].axis('off')

    plt.show()


# Example usage with a row from train_df
visualize_train_row(train_df.iloc[1000])


In [None]:
def visualize_train_row_with_polygons(row):
    """Draw every annotated polygon from the row's JSON file over the original image.

    Args:
        row: mapping (e.g. a DataFrame row) with ``image_path`` and
            ``image_polygons_path`` entries. The JSON is expected to have an
            ``objects`` list whose items carry a ``polygon`` list of (x, y) points.
    """
    # Copy inside `with` so the lazily-opened image file is closed again.
    with Image.open(row['image_path']) as img:
        image = img.copy()

    with open(row['image_polygons_path'], 'r') as f:
        polygons_data = json.load(f)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)

    # One unfilled red outline per annotated object.
    for obj in polygons_data['objects']:
        ax.add_patch(
            Polygon(obj['polygon'], closed=True, edgecolor='red', fill=False, linewidth=1)
        )

    ax.set_title('Original Image with Polygons')
    ax.axis('off')
    plt.show()


# Example usage with a row from train_df
visualize_train_row_with_polygons(train_df.iloc[1000])


# Train/Validation Split and Saving to CSV

In [None]:
# Split the labelled training data into train / validation subsets.

SPLIT_SEED = 42  # fixed so the split (and the CSVs below) are reproducible across runs
val_size = 380   # number of samples held out for validation

print("Number of Samples before Split: ", len(train_df))

# Guard against an empty/too-small frame (e.g. wrong paths) and val_size == 0,
# for which train_df[:-0] would silently yield an EMPTY training set.
if not 0 < val_size < len(train_df):
    raise ValueError(
        f"val_size={val_size} must be between 1 and {len(train_df) - 1} "
        f"(train_df has {len(train_df)} rows)"
    )

# Shuffle into a new frame rather than overwriting train_df, so re-running this
# cell (or earlier cells that index train_df by position) stays reproducible.
train_df_shuffled = train_df.sample(frac=1, random_state=SPLIT_SEED).reset_index(drop=True)

# Re-index both halves from 0 so label-based lookups such as df['image_path'][i]
# work for every i in range(len(df)); otherwise the validation slice would start
# at len(train_df) - val_size.
train_df_final = train_df_shuffled.iloc[:-val_size].reset_index(drop=True)
val_df_final = train_df_shuffled.iloc[-val_size:].reset_index(drop=True)

print("Train Samples (After Split): ", len(train_df_final))
print("Val Samples (After Split): ", len(val_df_final))

# Persist the split so later runs can reuse exactly the same partition.
train_df_final.to_csv("train_data.csv", index=False)
val_df_final.to_csv("val_data.csv", index=False)
test_df.to_csv("test_data.csv", index=False)

In [None]:
#PRINT SAMPLE FROM CSV_FILE

# Data Analysis



### Key Observations



- Images have shape (1024, 2048, 3) (height, width, channels).

- Masks have shape (1024, 2048).

- Both are stored as uint8.

In [None]:
#TODO: write your own data analysis techniques

### Plotting Pixel Class Frequency

In [None]:
from collections import Counter
import numpy as np

def map_labels_to_trainIds(labels):
    """Map each training ID to a label name, skipping ignored IDs.

    Labels with trainId 255 (or -1) are not training classes and are left out.
    When several labels share a trainId, the label defined first wins. This
    follows the inverse-mapping convention documented on the Label fields.

    Args:
        labels: iterable of objects with ``trainId`` and ``name`` attributes.

    Returns:
        dict ``{trainId: name}``.
    """
    trainId_to_label = {}
    for label in labels:
        if label.trainId in (255, -1):
            continue
        trainId_to_label.setdefault(label.trainId, label.name)
    return trainId_to_label

def pixel_class_frequency(train_df: pd.DataFrame, mask_labels):
    """Count pixels per class and images containing each class.

    Args:
        train_df: DataFrame with an ``image_label_path`` column of train-ID masks.
        mask_labels: Label definitions used to translate train IDs to names.

    Returns:
        ``(pixel_counts, class_frequency)``: Counters keyed by class name, holding
        the total pixel count and the number of masks containing that class.
        Pixels whose value is not a known train ID (e.g. 255) are not counted.

    Raises:
        FileNotFoundError: if a mask cannot be read (cv2.imread returns None).
    """
    trainId_to_label = map_labels_to_trainIds(mask_labels)

    pixel_counts = Counter()
    class_frequency = Counter()

    for mask_path in train_df["image_label_path"]:
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # np.unique(None) would silently produce no counts for this image.
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        unique, counts = np.unique(mask, return_counts=True)

        # .tolist() turns numpy scalars into plain ints for the Counters.
        for train_id, count in zip(unique.tolist(), counts.tolist()):
            name = trainId_to_label.get(train_id)
            if name is not None:
                pixel_counts[name] += count
                class_frequency[name] += 1

    return pixel_counts, class_frequency

In [None]:
# Pixel counts and per-image occurrence counts for every training class.
pixel_class_distribution, class_distribution = pixel_class_frequency(train_df=train_df, mask_labels=labels)
for distribution in (pixel_class_distribution, class_distribution):
    print(distribution)

In [None]:
from collections import Counter
import torch
import torch.nn as nn

# Per-class pixel counts over the training masks, hard-coded from a previous run of
# pixel_class_frequency so the expensive pass over every mask need not be repeated.
pixel_class_distribution = Counter({
    'road': 1626667099,
    'building': 1010874745,
    'vegetation': 701460788,
    'car': 307246339,
    'sidewalk': 270602244,
    'sky': 175475731,
    'pole': 53754015,
    'person': 52224101,
    'terrain': 51298793,
    'fence': 38197514,
    'wall': 30679368,
    'traffic sign': 24204440,
    'bicycle': 18262872,
    'truck': 11426619,
    'train': 11260221,
    'bus': 11056793,
    'traffic light': 9125136,
    'rider': 5898636,
    'motorcycle': 4282304
})

total_pixels = sum(pixel_class_distribution.values())
num_classes = len(pixel_class_distribution)

# Inverse-frequency ("balanced") weights: total / (num_classes * count). A class
# holding exactly 1/num_classes of all pixels gets weight 1.0. (The original used
# the Counter itself as the multiplier, which raises TypeError.)
class_weights = {
    cls: total_pixels / (num_classes * count)
    for cls, count in pixel_class_distribution.items()
}

# nn.CrossEntropyLoss applies weight[c] to targets of class index c, i.e. trainId c,
# so order the tensor by trainId, not by the Counter's frequency-sorted key order.
weighted_class_names = [
    label.name
    for label in sorted(labels, key=lambda label: label.trainId)
    if label.name in class_weights
]
class_weights_tensor = torch.FloatTensor([class_weights[name] for name in weighted_class_names])

# NOTE(review): this tensor has one entry per counted class (19). A loss for a model
# with more output channels (UNet defaults to 20) needs a weight for every channel.
print(class_weights)

In [None]:
import torch
import torch.nn as nn

# nn.CrossEntropyLoss applies weight[c] to targets of class index c (the trainId),
# so the weights must be ordered by trainId. Listing them by pixel frequency would
# silently give e.g. the 'sidewalk' class (trainId 1) the weight of 'building'.
train_id_ordered_names = [
    label.name
    for label in sorted(labels, key=lambda label: label.trainId)
    if label.name in class_weights
]
class_weights_tensor = torch.FloatTensor([class_weights[name] for name in train_id_ordered_names])

# The masks mark void pixels with 255; ignore them so they are excluded from the
# loss instead of being treated as an (out-of-range) class index.
# NOTE(review): the training cell below uses its own unweighted `loss`, not this criterion.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, ignore_index=255)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

def plot_class_distribution(distribution, title, y_label):
    """Bar-plot a ``{class_name: count}`` mapping, one bar per class.

    Args:
        distribution: mapping of class name to count (e.g. a Counter).
        title: figure title.
        y_label: label for the count axis.
    """
    frame = pd.DataFrame({
        'Class': list(distribution.keys()),
        'Count': list(distribution.values()),
    })

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.barplot(data=frame, x='Class', y='Count', palette='coolwarm', ax=ax)

    plt.setp(ax.get_xticklabels(), rotation=90, ha='right')
    ax.set_xlabel('Class')
    ax.set_ylabel(y_label)
    ax.set_title(title)

    fig.tight_layout()
    plt.show()

plot_class_distribution(pixel_class_distribution, title="Pixel-Wise Class Distribution", y_label="Pixel Count")
plot_class_distribution(class_distribution, title="Images Class Distribution", y_label="Images Count")

### Image Brightness Analysis

In [None]:
def calculate_luminance(train_df: pd.DataFrame):
    """Return the mean luminance of every image in ``train_df``, in row order.

    Each image is read with cv2 (channel order B, G, R), scaled to [0, 1] and
    reduced per pixel with the Rec. 601 weights 0.299 R + 0.587 G + 0.114 B. The
    weights are applied to the stored (gamma-encoded) values, so this is luma
    rather than linear-light luminance.

    Args:
        train_df: DataFrame with an ``image_path`` column.

    Returns:
        list with one average-luminance value per row.

    Raises:
        FileNotFoundError: if an image cannot be read (cv2.imread returns None).
    """
    images_lum = []
    for image_path in train_df["image_path"]:
        image = cv2.imread(image_path)
        if image is None:
            # Otherwise this fails later with an opaque "'NoneType' has no attribute 'astype'".
            raise FileNotFoundError(f"Could not read image: {image_path}")

        image = image.astype('float32') / 255.0

        # cv2 stores channels as B, G, R.
        R, G, B = image[:, :, 2], image[:, :, 1], image[:, :, 0]

        luminance = 0.299 * R + 0.587 * G + 0.114 * B
        images_lum.append(np.mean(luminance))
    return images_lum

In [None]:
def plot_luminance_frequency(lum_vals):
    """Histogram (15 bins plus KDE) of per-image average luminance.

    Non-finite values (NaN / inf) are dropped first; the x-axis is fixed to [0, 1].

    Args:
        lum_vals: sequence of average-luminance values.
    """
    values = np.array(lum_vals)
    finite_mask = np.isfinite(values)
    frame = pd.DataFrame({'Luminance': values[finite_mask]})

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.histplot(frame['Luminance'], bins=15, kde=True, ax=ax)
    ax.set_title('Histogram of Average Luminance Values')
    ax.set_xlabel('Average Luminance')
    ax.set_ylabel('Frequency')
    ax.set_xlim(0, 1)
    plt.show()

In [None]:
images_lum = calculate_luminance(train_df=train_df)  # reads every training image once

In [None]:
plot_luminance_frequency(lum_vals=images_lum)

# Preprocessing

In [None]:
# Example for the desired interface 



class LoadImage:
    """Placeholder transform illustrating the dict-in / dict-out interface.

    Currently an identity: every listed key is looked up (raising KeyError when
    missing) and written back unchanged. The same dict object is returned.

    Args:
        keys: names of the sample entries this transform touches.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, sample):
        sample.update({key: sample[key] for key in self.keys})
        return sample

In [None]:
# Preprocessing pipelines. Every step receives and returns the sample dict built by
# Dataset.__getitem__: 'image' is an (H, W, C) float tensor read by cv2 (B, G, R
# order, scaled to [0, 1]) and 'mask' an (H, W) tensor of train IDs.

from torchvision import transforms


class ApplyToImage:
    """Apply a torchvision image transform to ``sample['image']`` only.

    torchvision tensor transforms expect (C, H, W) images in R, G, B order, so the
    image is converted on the way in and restored to (H, W, C) B, G, R on the way
    out. The mask is never touched, so use this for photometric augmentations.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        image = sample['image']  # assumes (H, W, 3), BGR, float in [0, 1]
        image_rgb_chw = image.flip(-1).permute(2, 0, 1)
        image_rgb_chw = self.transform(image_rgb_chw)
        sample['image'] = image_rgb_chw.permute(1, 2, 0).flip(-1).contiguous()
        return sample


class JointRandomHorizontalFlip:
    """Horizontally flip image and mask together with probability ``p``.

    Image and mask share a single random draw, so each pixel keeps its label.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        if torch.rand(1).item() < self.p:
            sample['image'] = sample['image'].flip(1)  # (H, W, C): width is dim 1
            if 'mask' in sample:
                sample['mask'] = sample['mask'].flip(1)  # (H, W): width is dim 1
        return sample


# Photometric augmentations (image only).
# NOTE(review): both ColorJitter brightness ranges are >= 1.2, so every augmented
# training image is brighter than the un-augmented validation/test images; consider
# ranges that straddle 1.0.
_photometric_augmentations = transforms.Compose([
    transforms.RandomChoice([
        transforms.ColorJitter(brightness=(1.2, 1.8), contrast=(0.5, 1.5), saturation=(1.2, 1.5), hue=0.2),
        transforms.ColorJitter(brightness=(1.2, 1.5), contrast=(1.2, 2), saturation=(0.5, 0.8), hue=0.2),
    ]),  # Chooses one of the color space transforms
    transforms.RandomAdjustSharpness(sharpness_factor=0.3),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.3),
])

# TRAIN DATA PREPROCESSING PIPELINE
# No ToTensor: the Dataset already supplies tensors, and ToTensor cannot take a dict.
Train_data_transform = transforms.Compose([
    LoadImage(['image', 'mask']),
    ApplyToImage(_photometric_augmentations),
    # Geometric augmentation must move image and mask together.
    JointRandomHorizontalFlip(p=0.8),
])

# VALIDATION DATA PREPROCESSING PIPELINE
# No augmentation for evaluation; scaling to [0, 1] and HWC -> CHW happen in the Dataset.
Valid_data_transform = Compose([
    LoadImage(['image', 'mask']),
])

# TEST DATA PREPROCESSING PIPELINE
Test_data_transform = Compose([
    LoadImage(['image']),  # No mask since it is a test image
])

In [None]:
import re

from torch.utils.data import DataLoader, Dataset

import pandas as pd

from typing import List

from torchvision.transforms import Compose



class Dataset(Dataset):
    """Segmentation dataset backed by a DataFrame of file paths.

    NOTE: this class shadows ``torch.utils.data.Dataset``, which it subclasses; after
    this cell, ``Dataset`` refers to this class.

    Each item is a dict with:
      - ``'image'``: float tensor (C, H, W), pixel values divided by 255, channel
        order as read by cv2 (B, G, R for colour PNGs).
      - ``'mask'`` (only when ``has_labels``): int64 tensor (H, W) holding the
        values stored in the label PNG.

    Args:
        input_dataframe: frame with ``image_path`` (and ``image_label_path`` when
            ``has_labels``) columns. Rows are addressed by position, so the index
            does not need to start at 0.
        data_transform: callable applied to the sample dict before the final
            channel permutation.
        has_labels: whether to load masks (False for the test set).
    """

    def __init__(self, input_dataframe: pd.DataFrame, data_transform: 'Compose', has_labels=True):
        self.input_dataframe = input_dataframe
        self.data_transform = data_transform
        self.has_labels = has_labels

    @staticmethod
    def _read(path):
        """Read a file unchanged with cv2, raising instead of returning None."""
        array = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if array is None:
            raise FileNotFoundError(f"Could not read file: {path}")
        return array

    def __getitem__(self, item: int):
        # Positional lookup: a label lookup would raise KeyError for frames whose
        # index does not start at 0 (e.g. the tail slice of a shuffled frame).
        row = self.input_dataframe.iloc[item]

        sample = {'image': torch.from_numpy(self._read(row['image_path'])) / 255}
        if self.has_labels:
            sample['mask'] = torch.from_numpy(self._read(row['image_label_path']))

        sample = self.data_transform(sample)

        # Bring channel dimension first: (H, W, C) -> (C, H, W).
        sample['image'] = torch.permute(sample['image'], (2, 0, 1))
        if self.has_labels:
            # Class-index targets for nn.CrossEntropyLoss must be int64.
            sample['mask'] = sample['mask'].to(torch.int64)

        return sample

    def __len__(self):
        return len(self.input_dataframe)

In [None]:
#TODO : Initilize your datasets



# reset_index(drop=True) guarantees a 0-based RangeIndex whichever way the Dataset looks rows up.
ds_train = Dataset(input_dataframe=train_df_final.reset_index(drop=True),
                   data_transform=Train_data_transform)

# Validation set: labelled, no augmentation.
ds_val = Dataset(input_dataframe=val_df_final.reset_index(drop=True),
                 data_transform=Valid_data_transform)


In [None]:
# Training loader: 2 images per batch, reshuffled every epoch, 4 loader workers.
# NOTE(review): prefetch_factor is per worker, so up to 4 x 8 = 32 full-resolution
# batches may be buffered in memory at once.
dl_train = DataLoader(
    dataset=ds_train,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    prefetch_factor=8,
)

In [None]:
#TODO : Show samples from your data loaders

def print_batch_info(data_loader, num_batches=1):
    """Print shape, dtype and device of the image/mask tensors of the first batches.

    Args:
        data_loader: iterable of dict batches with an ``'image'`` tensor and an
            optional ``'mask'`` tensor (test loaders carry no masks).
        num_batches: how many batches to inspect (default 1).
    """
    import itertools

    # islice stops without fetching an extra batch (each batch loads full images).
    for batch in itertools.islice(data_loader, num_batches):
        images = batch['image']
        masks = batch.get('mask')

        print("Batch Images shape:", images.shape)
        if masks is not None:
            print("Batch Masks shape:", masks.shape)

        print("Batch Images dtype:", images.dtype)
        if masks is not None:
            print("Batch Masks dtype:", masks.dtype)

        print("Batch Image device:", images.device)
        if masks is not None:
            print("Batch Masks device:", masks.device)

        print("\n")



# Example usage: inspect the first batch of the training loader created above
# (re-creating an identical DataLoader here is unnecessary).
print_batch_info(dl_train)

# Model (U-Net)

In [None]:
#TODO : Write the model you are going to use (Pytorch)


class UNet(nn.Module):
    """Four-level U-Net producing per-pixel class logits.

    Encoder widths are 16-32-64-128-256. Each level applies two 3x3 convolutions
    with ReLU, and levels are separated by 2x2 max-pooling. The decoder upsamples
    with 2x2 transposed convolutions and concatenates the matching encoder features
    (skip connections) before another double convolution. Because of the four
    pooling steps, input height and width must be divisible by 16.

    Submodule attribute names (and their creation order) are the state_dict keys,
    so they are kept stable for checkpoint compatibility.

    Args:
        num_classes: number of output channels (class logits per pixel).
    """

    def double_convolution(self, in_channels, out_channels):
        """Return (3x3 conv -> ReLU) twice; padding=1 keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self, num_classes=20):
        super().__init__()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path.
        self.down_convolution_1 = self.double_convolution(3, 16)
        self.down_convolution_2 = self.double_convolution(16, 32)
        self.down_convolution_3 = self.double_convolution(32, 64)
        self.down_convolution_4 = self.double_convolution(64, 128)
        self.down_convolution_5 = self.double_convolution(128, 256)

        # Expanding path: each up-convolution sees [skip, upsampled], hence doubled input channels.
        self.up_transpose_1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_convolution_1 = self.double_convolution(256, 128)

        self.up_transpose_2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_convolution_2 = self.double_convolution(128, 64)

        self.up_transpose_3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
        self.up_convolution_3 = self.double_convolution(64, 32)

        self.up_transpose_4 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2)
        self.up_convolution_4 = self.double_convolution(32, 16)

        # 1x1 projection to class logits.
        self.out = nn.Conv2d(in_channels=16, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        # e.g. x: [batch_size, 3, 1024, 2048] -> logits: [batch_size, num_classes, 1024, 2048]

        # Encoder: keep each level's features for the skip connections.
        skip_1 = self.down_convolution_1(x)
        skip_2 = self.down_convolution_2(self.max_pool2d(skip_1))
        skip_3 = self.down_convolution_3(self.max_pool2d(skip_2))
        skip_4 = self.down_convolution_4(self.max_pool2d(skip_3))
        bottleneck = self.down_convolution_5(self.max_pool2d(skip_4))

        # Decoder: upsample, concatenate [skip, upsampled] on channels, fuse.
        x = self.up_convolution_1(torch.cat([skip_4, self.up_transpose_1(bottleneck)], 1))
        x = self.up_convolution_2(torch.cat([skip_3, self.up_transpose_2(x)], 1))
        x = self.up_convolution_3(torch.cat([skip_2, self.up_transpose_3(x)], 1))
        x = self.up_convolution_4(torch.cat([skip_1, self.up_transpose_4(x)], 1))

        return self.out(x)



    
    
# # Test for the model
# model = UNet()
# output = None
# inp = None
# label = None
# for batch in dl_train:
#     inp = batch['image']
#     label = batch['mask']
#     output = model(batch['image'])
#     break



# Loss

In [None]:
#TODO : Write the loss function you are going to use
import torch.optim as optim

# Network, training criterion and optimizer consumed by train_unet_model below.
model = UNet()
# Train-ID masks mark void pixels with 255; they are excluded from the loss.
# (This criterion is unweighted; the class-weighted `criterion` from the analysis
# section is not used here.)
loss = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# # Test for the loss
# print(output.shape)
# print(inp.shape)
# print(label.shape)
# print(output.dtype)
# print(inp.dtype)
# print(label.dtype)


# print(loss(output,label))

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
#TODO : Write the evaluation metrics you are going to use

import torch
import numpy as np
from tqdm import tqdm

def evaluate_unet_model(model, dataloader, device, num_classes=19, ignore_index=255):
    """Compute pixel accuracy and mean IoU of a segmentation model over a dataloader.

    Pixels whose ground-truth label equals ``ignore_index`` are excluded from every
    metric. Per-class IoU is computed from intersections and unions summed over the
    whole dataset. Classes that never occur in predictions or ground truth
    (union == 0) get NaN and are left out of the mean, instead of counting as 0.

    Args:
        model: module mapping an image batch to logits of shape (B, C, H, W).
        dataloader: iterable of dicts with ``'image'`` and ``'mask'`` (B, H, W) tensors.
        device: device to evaluate on. The model is moved there and left in eval mode.
        num_classes: classes ``0..num_classes-1`` scored for IoU (default 19).
        ignore_index: ground-truth value to skip (default 255, the void train ID).

    Returns:
        dict with ``'pixel_accuracy'`` (float), ``'mean_iou'`` (float) and
        ``'class_iou'`` (np.ndarray of length ``num_classes``, NaN for absent classes).
    """
    model = model.to(device)
    model.eval()

    total_correct = 0
    total_pixels = 0
    class_intersection = np.zeros(num_classes)
    class_union = np.zeros(num_classes)

    with torch.no_grad():
        for batch in tqdm(dataloader):
            inputs = batch['image'].to(device)
            masks = batch['mask'].to(device)

            preds = torch.argmax(model(inputs), dim=1)
            valid = masks != ignore_index

            # Pixel accuracy over labelled pixels only.
            total_correct += ((preds == masks) & valid).sum().item()
            total_pixels += valid.sum().item()

            for c in range(num_classes):
                pred_c = (preds == c) & valid
                true_c = (masks == c) & valid
                class_intersection[c] += (pred_c & true_c).sum().item()
                class_union[c] += (pred_c | true_c).sum().item()

    pixel_accuracy = total_correct / total_pixels if total_pixels else float('nan')

    with np.errstate(divide='ignore', invalid='ignore'):
        class_iou = np.where(class_union > 0, class_intersection / class_union, np.nan)
    mean_iou = float(np.nanmean(class_iou)) if np.any(class_union > 0) else float('nan')

    print(f'Pixel Accuracy: {pixel_accuracy:.4f}')
    print(f'Mean IoU: {mean_iou:.4f}')

    return {'pixel_accuracy': pixel_accuracy, 'mean_iou': mean_iou, 'class_iou': class_iou}

# Example usage on the validation split. Before training this reports the untrained baseline.
# `device` is otherwise first defined in the training cell below, so define it here
# too; the validation loader is built here because no other cell creates one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dl_val = DataLoader(
    dataset=Dataset(input_dataframe=val_df_final.reset_index(drop=True),
                    data_transform=Valid_data_transform),
    batch_size=2,
    shuffle=False,
)
evaluate_unet_model(model, dl_val, device)


In [None]:
import torch
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_unet_model(model, dataloader, criterion, optimizer, num_epochs=10, device=device):
    """Train ``model`` in place and return the average training loss of each epoch.

    Args:
        model: module mapping image batches to logits (B, C, H, W).
        dataloader: sized iterable of dicts with ``'image'`` and ``'mask'`` tensors.
        criterion: loss taking ``(logits, masks)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        device: device to train on (defaults to CUDA when available).

    Returns:
        list of per-epoch mean losses (length ``num_epochs``), e.g. for plotting.

    Raises:
        ValueError: if ``dataloader`` is empty, because the per-epoch average would
            otherwise divide by zero.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader yields no batches")

    model = model.to(device)
    model.train()  # Set model to training mode

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0

        # Progress bar for the epoch
        pbar = tqdm(enumerate(dataloader), total=num_batches)

        for i, batch in pbar:
            inputs = batch['image'].to(device)
            masks = batch['mask'].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_description(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / (i+1):.4f}")

        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {epoch_loss:.4f}")

    print("Training complete")
    return epoch_losses


In [None]:
# Train for 10 epochs using the ignore-255 `loss` criterion defined in the Loss section.
train_unet_model(model, dl_train, loss, optimizer, num_epochs=10)

In [None]:
#TODO : Plot losses and metrics graphs

In [None]:
#TODO : Test your model and show some samples