## 0 - Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as A
from tqdm import tqdm
from itertools import cycle
from typing import Tuple, List
#from config import CITYSCAPES, GTA, DEEPLABV2_PATH, CITYSCAPES_PATH, GTA5_PATH
#from datasets import CityScapes, GTA5
#from models import BiSeNet, get_deeplab_v2, FCDiscriminator
#from utils import *
import warnings
warnings.filterwarnings("ignore")
# Seed the RNGs used by this notebook so runs are reproducible.
# torch.manual_seed seeds the CPU generator *and* every CUDA device. The previous
# torch.cuda.manual_seed(42) left the CPU generator unseeded, and the CPU generator
# is what modules constructed on CPU and DataLoader shuffling draw from.
torch.manual_seed(42)
# NumPy's global RNG is seeded too. Whether the augmentation pipeline draws from it
# depends on the albumentations version -- TODO confirm.
np.random.seed(42)


## 1 - Dataset

### CityScapes

In [None]:
from torch.utils.data import Dataset
import torch
from PIL import Image
import numpy as np
import os
from typing import Optional, Tuple
from albumentations import Compose


class CityScapes(Dataset):
    """
    A dataset class for loading and processing the CityScapes dataset.

    Directory layout read by this class::

        <root_dir>/images/<split>/<city>/<stem>_leftImg8bit<ext>
        <root_dir>/gtFine/<split>/<city>/<stem>_gtFine_labelTrainIds<ext>

    Each sample is an ``(image, label)`` pair. ``image`` is a float tensor of
    shape (3, H, W), divided by 255. ``label`` is an int64 tensor of shape (H, W).
    """
    def __init__(self,
                 root_dir: str,
                 split: str = 'train',
                 transform: Optional["Compose"] = None):
        """
        Initializes the CityScapes dataset.

        Args:
            root_dir (str): Root directory of the dataset.
            split (str, optional): Dataset split to use ('train', 'val', 'test'). Defaults to 'train'.
            transform (Optional[Compose], optional): albumentations transform applied as
                ``transform(image=..., mask=...)`` to images and labels. Defaults to None.
        """
        super(CityScapes, self).__init__()

        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        # Build (image_path, label_path) pairs. The label path is assembled from the
        # components *below* root_dir. The previous version ran str.replace on the
        # whole path, which also rewrote any "images" substring inside root_dir.
        # Listings are sorted so that sample indices are stable across filesystems.
        self.data = []
        image_root = os.path.join(self.root_dir, 'images', split)
        label_root = os.path.join(self.root_dir, 'gtFine', split)
        for city in sorted(os.listdir(image_root)):
            city_dir = os.path.join(image_root, city)
            for filename in sorted(os.listdir(city_dir)):
                image = os.path.join(city_dir, filename)
                label_name = filename.replace('_leftImg8bit', '_gtFine_labelTrainIds')
                label = os.path.join(label_root, city, label_name)
                self.data.append((image, label))

    def __len__(self) -> int:
        """
        Returns the total number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generates one sample of data.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The image as a float tensor (C, H, W)
            divided by 255, and the label as an int64 tensor (H, W).
        """
        image_path, label_path = self.data[idx]

        # Load image and label. `with` releases the underlying file handles promptly.
        with Image.open(image_path) as img:
            image = np.array(img.convert('RGB'))
        with Image.open(label_path) as lbl:
            label = np.array(lbl.convert('L'))

        if self.transform:
            transformed = self.transform(image=image, mask=label)
            image, label = transformed['image'], transformed['mask']

        # HWC uint8 -> CHW float
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
        label = torch.from_numpy(label).long()

        return image, label

### GTA5

In [None]:
from torch.utils.data import Dataset
import torch
from PIL import Image
import numpy as np
import os
from typing import Optional, Tuple
from albumentations import Compose
#from utils import get_color_to_id


class GTA5(Dataset):
    """
    A dataset class for loading and processing the GTA5 dataset.

    Images are read from ``<root_dir>/images``. Labels come from
    ``<root_dir>/labels`` (RGB colour-coded, converted with ``color_to_id``) when
    ``compute_mask`` is True. Otherwise they come from ``<root_dir>/masks``
    (single-channel id masks). Image and label files share the same file name.
    """

    # Label written for pixels whose RGB colour is not in ``color_to_id``.
    # The old zero-fill made such pixels indistinguishable from class 0. 255 is the
    # conventional "ignore" trainId for Cityscapes-style labels -- confirm it matches
    # the loss's ignore_index.
    IGNORE_INDEX = 255

    def __init__(self,
                 root_dir: str,
                 compute_mask: bool = False,
                 transform: Optional["Compose"] = None):
        """
        Initializes the GTA5 dataset.

        Args:
            root_dir (str): Root directory of the dataset.
            compute_mask (bool, optional): Whether to compute the mask from RGB labels. Defaults to False.
            transform (Optional[Compose], optional): albumentations transform applied as
                ``transform(image=..., mask=...)``. Defaults to None.
        """
        super(GTA5, self).__init__()

        self.root_dir = root_dir
        self.compute_mask = compute_mask
        self.transform = transform
        if self.compute_mask:
            self.color_to_id = get_color_to_id()

        # Pair each image with the same-named file in the label directory.
        # Sorted so that sample indices are stable across filesystems.
        self.data = []
        image_dir = os.path.join(self.root_dir, 'images')
        label_dir = os.path.join(self.root_dir, 'labels' if self.compute_mask else 'masks')

        for filename in sorted(os.listdir(image_dir)):
            image = os.path.join(image_dir, filename)
            label = os.path.join(label_dir, filename)
            self.data.append((image, label))

    def _rgb_to_label(self, image: "Image.Image") -> np.ndarray:
        """
        Converts an RGB colour-coded label image to a 2-D array of class ids.

        Args:
            image (Image.Image): The RGB label image (anything ``np.array`` turns into (H, W, 3)).

        Returns:
            np.ndarray: uint8 array of shape (H, W). Pixels whose colour is not in
            ``self.color_to_id`` are set to ``IGNORE_INDEX``.
        """
        image_np = np.array(image)
        label = np.full((image_np.shape[0], image_np.shape[1]), self.IGNORE_INDEX, dtype=np.uint8)

        for color, class_id in self.color_to_id.items():
            mask = np.all(image_np == color, axis=-1)
            label[mask] = class_id

        return label

    def __len__(self) -> int:
        """
        Returns the total number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generates one sample of data.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The image as a float tensor (C, H, W)
            divided by 255, and the label/mask as an int64 tensor (H, W).
        """
        image_path, label_path = self.data[idx]

        # Load image and label/mask. `with` releases the file handles promptly.
        with Image.open(image_path) as img:
            image = np.array(img.convert('RGB'))
        with Image.open(label_path) as lbl:
            if self.compute_mask:
                label = self._rgb_to_label(lbl.convert('RGB'))
            else:
                label = np.array(lbl.convert('L'))

        if self.transform:
            transformed = self.transform(image=image, mask=label)
            image, label = transformed['image'], transformed['mask']

        # HWC uint8 -> CHW float
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
        label = torch.from_numpy(label).long()
        return image, label

## 2 - Models

### DeepLabV2

In [None]:
import torch
import torch.nn as nn

# Passed as `affine=` to every BatchNorm2d built by the DeepLabV2 classes below.
# Those BN parameters are also frozen (requires_grad = False) where they are created.
affine_par = True


class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> dilated 3x3 -> 1x1) with frozen BatchNorm layers.

    The block outputs ``planes * expansion`` channels. The stride is applied by the
    first 1x1 convolution. The 3x3 convolution uses ``padding == dilation``, so it
    keeps the spatial size. When ``downsample`` is given, it maps ``x`` onto the
    residual branch's shape.
    """
    expansion = 4

    @staticmethod
    def _frozen_bn(num_features):
        # BatchNorm2d whose affine parameters are excluded from gradient updates.
        norm = nn.BatchNorm2d(num_features, affine=affine_par)
        for param in norm.parameters():
            param.requires_grad = False
        return norm

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        # Submodules are created and registered in the original order. Conv
        # initialisation draws from the RNG, and the registration order fixes the
        # parameter order and the state_dict keys.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = self._frozen_bn(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn2 = self._frozen_bn(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = self._frozen_bn(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)


class ClassifierModule(nn.Module):
    """Multi-branch classifier head: parallel dilated 3x3 convolutions whose outputs are summed.

    One branch is built per ``(dilation, padding)`` pair. Each branch maps
    ``inplanes`` channels to ``num_classes`` channels. Branch weights are
    re-initialised from N(0, 0.01^2), and biases keep PyTorch's default init.
    """
    def __init__(self, inplanes, dilation_series, padding_series, num_classes):
        super(ClassifierModule, self).__init__()
        branches = [
            nn.Conv2d(inplanes, num_classes, kernel_size=3, stride=1, padding=pad,
                      dilation=dil, bias=True)
            for dil, pad in zip(dilation_series, padding_series)
        ]
        self.conv2d_list = nn.ModuleList(branches)

        # All branches are built before any re-initialisation, so the sequence of
        # RNG draws matches the construct-then-init order.
        for branch in self.conv2d_list:
            branch.weight.data.normal_(0, 0.01)

    def forward(self, x):
        logits = self.conv2d_list[0](x)
        for branch in self.conv2d_list[1:]:
            logits += branch(x)
        return logits


class ResNetMulti(nn.Module):
    """DeepLabV2: dilated ResNet backbone followed by a 4-branch dilated classifier head.

    conv1, maxpool and layer2 each use stride 2. layer3 and layer4 use stride 1
    with dilation 2 and 4, so the backbone output is 1/8 of the input resolution.
    ``forward`` bilinearly upsamples the logits back to the input size. BatchNorm
    parameters created in this class are frozen (requires_grad = False).

    Args:
        block: residual block class (e.g. ``Bottleneck``); must expose ``expansion``.
        layers: number of blocks in each of the four stages.
        num_classes: number of output logit channels.
    """
    def __init__(self, block, layers, num_classes):
        self.inplanes = 64
        super(ResNetMulti, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        for i in self.bn1.parameters():
            i.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        self.layer6 = ClassifierModule(2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one ResNet stage of ``blocks`` blocks; only the first block gets the shortcut projection."""
        downsample = None
        if (stride != 1
                or self.inplanes != planes * block.expansion
                or dilation == 2
                or dilation == 4):
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
            # Freeze the shortcut BatchNorm like the other BNs. This loop used to run
            # unconditionally and raised AttributeError when no projection was built.
            for param in downsample[1].parameters():
                param.requires_grad = False
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        _, _, H, W = x.size()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer6(x)

        # Restore the input resolution.
        x = torch.nn.functional.interpolate(x, size=(H, W), mode='bilinear')

        return x

    def get_1x_lr_params_no_scale(self):
        """
        Yield every trainable parameter of the backbone (everything except the
        classification head), each exactly once.

        BatchNorm parameters are frozen (requires_grad = False) where they are
        created, so the requires_grad filter excludes them.
        """
        backbone = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]
        for module in backbone:
            # Module.parameters() already recurses into children. Walking .modules()
            # and calling .parameters() on each one, as before, yielded nested
            # parameters several times, which duplicated them in the optimizer group.
            for param in module.parameters():
                if param.requires_grad:
                    yield param

    def get_10x_lr_params(self):
        """
        Yield the parameters of the classification head (``layer6``, plus
        ``layer5`` if a ``multi_level`` variant defines it).
        """
        heads = []
        # `multi_level` is never set in __init__. Reading it directly raised
        # AttributeError and made optim_parameters() unusable.
        if getattr(self, 'multi_level', False):
            heads.append(self.layer5.parameters())
        heads.append(self.layer6.parameters())

        for params in heads:
            yield from params

    def optim_parameters(self, lr):
        """Optimizer parameter groups: backbone at ``lr``, classifier head at ``10 * lr``."""
        return [{'params': self.get_1x_lr_params_no_scale(), 'lr': lr},
                {'params': self.get_10x_lr_params(), 'lr': 10 * lr}]


def get_deeplab_v2(num_classes=19, pretrain=True, pretrain_model_path='DeepLab_resnet_pretrained_imagenet.pth'):
    """Build a DeepLabV2 model (ResNet-101 stage layout [3, 4, 23, 3]) and optionally load pretrained weights.

    Args:
        num_classes: number of output classes.
        pretrain: if True, load weights from ``pretrain_model_path``.
        pretrain_model_path: path to a saved state dict. The first dotted
            component of each key is dropped before matching.

    Returns:
        The ``ResNetMulti`` model. Loading uses ``strict=False``, so keys that
        do not match are ignored.
    """
    model = ResNetMulti(Bottleneck, [3, 4, 23, 3], num_classes)

    # Pretraining loading
    if pretrain:
        print('Deeplab pretraining loading...')
        # map_location='cpu' lets a checkpoint saved from a GPU load on CPU-only hosts.
        # load_state_dict then copies the values into the model's own parameters, so
        # the device of the loaded tensors does not matter.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        saved_state_dict = torch.load(pretrain_model_path, map_location='cpu')

        new_params = model.state_dict().copy()
        for key, value in saved_state_dict.items():
            # Drop the leading key component (presumably a wrapper-module prefix) so
            # names line up with this model's state_dict -- confirm against the file.
            new_params['.'.join(key.split('.')[1:])] = value
        model.load_state_dict(new_params, strict=False)

    return model


### BiSeNet

In [None]:
import torch
from torch import nn
#from .build_contextpath import build_contextpath
#import warnings
#warnings.filterwarnings(action='ignore')


class ConvBlock(torch.nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU.

    With the defaults (3x3 kernel, stride 2, padding 1), an input of size H
    produces an output of size ceil(H / 2).
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        normalized = self.bn(self.conv1(input))
        return self.relu(normalized)


class Spatial_path(torch.nn.Module):
    """BiSeNet spatial path: three ConvBlocks (default stride 2), 3 -> 64 -> 128 -> 256 channels."""
    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256)
        # Registers convblock1..convblock3 in order (same names as state_dict keys).
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, 'convblock{}'.format(idx),
                    ConvBlock(in_channels=c_in, out_channels=c_out))

    def forward(self, input):
        features = input
        for stage in (self.convblock1, self.convblock2, self.convblock3):
            features = stage(features)
        return features


class AttentionRefinementModule(torch.nn.Module):
    """Channel attention: ``input * sigmoid(BN(conv1x1(global_avg_pool(input))))``.

    ``in_channels`` must equal the channel count of the incoming tensor; this is
    checked with an assert in ``forward``. The resulting weights must broadcast
    against ``input`` for the final product.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        # One value per channel: spatial dims pooled down to 1x1.
        pooled = self.avgpool(input)
        assert self.in_channels == pooled.size(1), 'in_channels and out_channels should all be {}'.format(pooled.size(1))
        weights = self.sigmoid(self.bn(self.conv(pooled)))
        return torch.mul(input, weights)


class FeatureFusionModule(torch.nn.Module):
    """Fuse spatial-path and context-path features (BiSeNet FFM).

    The two inputs are concatenated along channels and projected to
    ``num_classes`` channels by a stride-1 ConvBlock, giving ``f``. The output is
    ``f + f * gate(f)``, where ``gate`` is global-average-pool -> 1x1 conv ->
    ReLU -> 1x1 conv -> sigmoid. ``in_channels`` must equal the concatenated
    channel count (checked with an assert in ``forward``).
    """
    def __init__(self, num_classes, in_channels):
        super().__init__()
        # resnet101 3328 = 256(from spatial path) + 1024(from context path) + 2048(from context path)
        # resnet18  1024 = 256(from spatial path) + 256(from context path) + 512(from context path)
        self.in_channels = in_channels

        self.convblock = ConvBlock(in_channels=self.in_channels, out_channels=num_classes, stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        stacked = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == stacked.size(1), 'in_channels of ConvBlock should be {}'.format(stacked.size(1))
        feature = self.convblock(stacked)
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(self.avgpool(feature)))))
        return torch.add(torch.mul(feature, gate), feature)


class BiSeNet(torch.nn.Module):
    """BiSeNet: a spatial path and a ResNet context path fused by a FeatureFusionModule.

    Args:
        num_classes: number of output logit channels.
        context_path: backbone name, ``'resnet18'`` or ``'resnet101'``.

    Raises:
        ValueError: if ``context_path`` is not supported.
    """

    # Per backbone: (channels of the layer3 feature, channels of the layer4 feature,
    # FFM input channels = 256 from the spatial path + both context features).
    _CONTEXT_CHANNELS = {
        'resnet18': (256, 512, 1024),
        'resnet101': (1024, 2048, 3328),
    }

    def __init__(self, num_classes, context_path):
        super().__init__()
        # Validate before building anything. Previously an unknown name reached
        # build_contextpath (KeyError), which left the "unsupported" branch dead.
        if context_path not in self._CONTEXT_CHANNELS:
            raise ValueError('Unsupported context_path {!r}; expected one of {}'.format(
                context_path, sorted(self._CONTEXT_CHANNELS)))
        mid_ch, tail_ch, fused_ch = self._CONTEXT_CHANNELS[context_path]

        # build spatial path (attribute name, typo included, is a state_dict key)
        self.saptial_path = Spatial_path()

        # build context path
        self.context_path = build_contextpath(name=context_path)

        # attention refinement modules, supervision heads and feature fusion
        # (registered in the original order: conv init below draws from the RNG)
        self.attention_refinement_module1 = AttentionRefinementModule(mid_ch, mid_ch)
        self.attention_refinement_module2 = AttentionRefinementModule(tail_ch, tail_ch)
        self.supervision1 = nn.Conv2d(in_channels=mid_ch, out_channels=num_classes, kernel_size=1)
        self.supervision2 = nn.Conv2d(in_channels=tail_ch, out_channels=num_classes, kernel_size=1)
        self.feature_fusion_module = FeatureFusionModule(num_classes, fused_ch)

        # build final convolution
        self.conv = nn.Conv2d(in_channels=num_classes, out_channels=num_classes, kernel_size=1)

        self.init_weight()

        # Modules outside the pretrained context path (presumably trained with a
        # larger learning rate by the caller -- confirm).
        self.mul_lr = [
            self.saptial_path,
            self.attention_refinement_module1,
            self.attention_refinement_module2,
            self.supervision1,
            self.supervision2,
            self.feature_fusion_module,
            self.conv,
        ]

    def init_weight(self):
        """Kaiming-init convs and reset BNs everywhere except the (pretrained) context path."""
        for name, m in self.named_modules():
            if 'context_path' not in name:
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-5
                    m.momentum = 0.1
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, input):
        # output of spatial path
        sx = self.saptial_path(input)

        # output of context path, refined and scaled by the global tail
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)
        # upsample both context features to the spatial-path resolution
        cx1 = torch.nn.functional.interpolate(cx1, size=sx.size()[-2:], mode='bilinear')
        cx2 = torch.nn.functional.interpolate(cx2, size=sx.size()[-2:], mode='bilinear')
        cx = torch.cat((cx1, cx2), dim=1)

        # NOTE: supervision1/2 are built (and listed in mul_lr) but not used here.

        # output of feature fusion module
        result = self.feature_fusion_module(sx, cx)

        # upsampling by 8 (the spatial path is at 1/8 resolution)
        result = torch.nn.functional.interpolate(result, scale_factor=8, mode='bilinear')
        result = self.conv(result)

        return result

In [None]:
import torch
from torchvision import models


class resnet18(torch.nn.Module):
    """ResNet-18 context path for BiSeNet, backed by torchvision's ``resnet18``.

    ``forward`` returns ``(feature3, feature4, tail)``: the layer3 and layer4
    outputs, plus ``tail``, the mean of layer4 over both spatial dims
    (kept as size-1 dims).
    """
    # (attribute on self, attribute on the torchvision model), registered in this order.
    _STAGES = (('conv1', 'conv1'), ('bn1', 'bn1'), ('relu', 'relu'), ('maxpool1', 'maxpool'),
               ('layer1', 'layer1'), ('layer2', 'layer2'), ('layer3', 'layer3'), ('layer4', 'layer4'))

    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.resnet18(pretrained=pretrained)
        for own_name, backbone_name in self._STAGES:
            setattr(self, own_name, getattr(self.features, backbone_name))

    def forward(self, input):
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        feature3 = self.layer3(self.layer2(self.layer1(stem)))  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # global average pooling to build tail (width first, then height)
        tail = torch.mean(torch.mean(feature4, 3, keepdim=True), 2, keepdim=True)
        return feature3, feature4, tail


class resnet101(torch.nn.Module):
    """ResNet-101 context path for BiSeNet, backed by torchvision's ``resnet101``.

    ``forward`` returns ``(feature3, feature4, tail)``: the layer3 and layer4
    outputs, plus ``tail``, the mean of layer4 over both spatial dims
    (kept as size-1 dims).
    """
    # (attribute on self, attribute on the torchvision model), registered in this order.
    _STAGES = (('conv1', 'conv1'), ('bn1', 'bn1'), ('relu', 'relu'), ('maxpool1', 'maxpool'),
               ('layer1', 'layer1'), ('layer2', 'layer2'), ('layer3', 'layer3'), ('layer4', 'layer4'))

    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.resnet101(pretrained=pretrained)
        for own_name, backbone_name in self._STAGES:
            setattr(self, own_name, getattr(self.features, backbone_name))

    def forward(self, input):
        out = self.conv1(input)
        out = self.maxpool1(self.relu(self.bn1(out)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        feature3 = out  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # global average pooling to build tail (width first, then height)
        tail = torch.mean(torch.mean(feature4, 3, keepdim=True), 2, keepdim=True)
        return feature3, feature4, tail


def build_contextpath(name):
    """Instantiate the BiSeNet context-path backbone called ``name``.

    Args:
        name: ``'resnet18'`` or ``'resnet101'``.

    Returns:
        The backbone module, built with ``pretrained=True``.

    Raises:
        KeyError: if ``name`` is not a supported backbone.
    """
    # Map names to constructors, not instances. The previous dict literal built
    # (and loaded pretrained weights for) *both* backbones on every call, even
    # though only one is returned.
    builders = {
        'resnet18': lambda: resnet18(pretrained=True),
        'resnet101': lambda: resnet101(pretrained=True),
    }
    return builders[name]()

### Discriminator

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class FCDiscriminator(nn.Module):
    """Fully-convolutional discriminator over ``num_classes``-channel maps.

    Five 4x4 stride-2 convolutions (num_classes -> ndf -> 2*ndf -> 4*ndf -> 8*ndf -> 1)
    with LeakyReLU(0.2) after the first four. The raw single-channel output is
    returned (no sigmoid, no upsampling), at 1/32 of the input size for inputs
    divisible by 32.
    """

    def __init__(self, num_classes, ndf = 64):
        super(FCDiscriminator, self).__init__()

        widths = [num_classes, ndf, ndf * 2, ndf * 4, ndf * 8]
        # Registers conv1..conv4 in order, then the classifier (same names and RNG order).
        for idx in range(1, len(widths)):
            setattr(self, 'conv{}'.format(idx),
                    nn.Conv2d(widths[idx - 1], widths[idx], kernel_size=4, stride=2, padding=1))
        self.classifier = nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=1)

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.leaky_relu(conv(x))
        return self.classifier(x)

## 3 - Utils

### Checkpoint

In [None]:
import os
import torch
from typing import List, Dict, Tuple, Optional
#from utils import get_id_to_label
#from config import CHECKPOINT_ROOT

def save_results(model_results: List[List[float]], 
                 filename: str,
                 project_step: str,
                 model_params_flops: Dict[str, float],
                 model_latency_fps: Dict[str, float]) -> None:
    """
    Saves the model results to ``<CHECKPOINT_ROOT>/<project_step>/<filename>.txt``.

    Args:
        model_results (List[List[float]]): A list containing model results.
            - model_results[0]: List of training losses (last value is written).
            - model_results[1]: List of validation losses (last value is written).
            - model_results[2]: List of training mIoU scores (last value is written).
            - model_results[3]: List of validation mIoU scores (last value is written).
            - model_results[4]: Training IoU score for each class id.
            - model_results[5]: Validation IoU score for each class id.
        filename (str): The name of the file to save the results in (without extension).
        project_step (str): The current project step, used for directory naming.
        model_params_flops (Dict[str, float]): Dictionary containing model parameters and FLOPS.
            - 'Parameters': Number of parameters.
            - 'FLOPS': Floating-point operation count (presumably a count, not per second -- confirm with the producer).
        model_latency_fps (Dict[str, float]): Dictionary containing model latency and FPS information.
            - 'mean_latency': Mean latency.
            - 'std_latency': Standard deviation of latency.
            - 'mean_fps': Mean FPS.
            - 'std_fps': Standard deviation of FPS.
    """
    # Construct the checkpoint path; exist_ok avoids the exists()/makedirs() race.
    checkpoint_path = f'{CHECKPOINT_ROOT}/{project_step}'
    os.makedirs(checkpoint_path, exist_ok=True)

    # The id -> class-name mapping is loop-invariant: fetch it once instead of per class.
    id_to_label = get_id_to_label()

    # Assemble everything first, so a malformed input does not leave a half-written file.
    lines = [
        f"Parameters: {model_params_flops['Parameters']}\n",
        f"FLOPS: {model_params_flops['FLOPS']}\n",
        "Latency:\n",
        f"\tmean: {model_latency_fps['mean_latency']}\n",
        f"\tstd: {model_latency_fps['std_latency']}\n",
        "FPS:\n",
        f"\tmean: {model_latency_fps['mean_fps']}\n",
        f"\tstd: {model_latency_fps['std_fps']}\n",
        "Loss:\n",
        f"\ttrain: {model_results[0][-1]}\n",
        f"\tval: {model_results[1][-1]}\n",
        "mIoU:\n",
        f"\ttrain: {model_results[2][-1]}\n",
        f"\tval: {model_results[3][-1]}\n",
        "Training IoU for class:\n",
    ]
    lines.extend(f"{id_to_label[i]}: {iou}\n" for i, iou in enumerate(model_results[4]))
    lines.append("Validation IoU for class:\n")
    lines.extend(f"{id_to_label[i]}: {iou}\n" for i, iou in enumerate(model_results[5]))

    with open(f"{checkpoint_path}/{filename}.txt", 'w') as file:
        file.writelines(lines)
            
def save_checkpoint(checkpoint_root: str,
                    project_step: str, 
                    adversarial: bool,
                    model: torch.nn.Module, 
                    model_D: Optional[torch.nn.Module], 
                    optimizer: torch.optim.Optimizer, 
                    optimizer_D: Optional[torch.optim.Optimizer], 
                    epoch: int,
                    train_loss_list: List[float], 
                    train_miou_list: List[float],
                    train_iou: List[float],
                    val_loss_list: List[float],
                    val_miou_list: List[float],
                    val_iou: List[float],
                    verbose: bool)->None:
    """
    Saves the current state of the training process to
    ``<checkpoint_root>/<project_step>/checkpoint.pth``, creating the directory if needed.

    Args:
        checkpoint_root (str): The root directory where the checkpoint will be saved.
        project_step (str): The current project step or phase, used as the sub-directory name.
        adversarial (bool): If True, the discriminator and its optimizer are saved too.
        model (torch.nn.Module): The main model whose state is to be saved.
        model_D (Optional[torch.nn.Module]): The discriminator; only read when ``adversarial``.
        optimizer (torch.optim.Optimizer): The optimizer for the main model.
        optimizer_D (Optional[torch.optim.Optimizer]): The discriminator optimizer; only read when ``adversarial``.
        epoch (int): The current epoch number; ``epoch + 1`` is stored.
        train_loss_list (List[float]): List of training losses over epochs.
        train_miou_list (List[float]): List of training mean Intersection over Union (mIoU) scores over epochs.
        train_iou (List[float]): List of training IoU scores for each class.
        val_loss_list (List[float]): List of validation losses over epochs.
        val_miou_list (List[float]): List of validation mIoU scores over epochs.
        val_iou (List[float]): List of validation IoU scores for each class.
        verbose (bool): If True, prints a message confirming the checkpoint has been saved.

    Returns:
        None
    """
    checkpoint_dir = f'{checkpoint_root}/{project_step}'
    checkpoint_path = f'{checkpoint_dir}/checkpoint.pth'
    # torch.save does not create parent directories, so a fresh project_step used to fail.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Same keys (and key order) as before; the discriminator entries exist only
    # for adversarial training.
    state = {'model': model.state_dict()}
    if adversarial:
        state['model_D'] = model_D.state_dict()
    state['optimizer'] = optimizer.state_dict()
    if adversarial:
        state['optimizer_D'] = optimizer_D.state_dict()
    state.update({
        # Stored as epoch + 1, i.e. the epoch to resume from.
        'epoch': epoch + 1,
        'train_loss_list': train_loss_list,
        'train_miou_list': train_miou_list,
        'train_iou': train_iou,
        'val_loss_list': val_loss_list,
        'val_miou_list': val_miou_list,
        'val_iou': val_iou,
    })
    torch.save(state, checkpoint_path)

    if verbose:
        print(f"Checkpoint saved in {checkpoint_path}")
    
def load_checkpoint(checkpoint_root: str,
                    project_step: str, 
                    adversarial: bool,
                    model: torch.nn.Module, 
                    model_D: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    optimizer_D: torch.optim.Optimizer) -> Tuple[bool, Optional[int], Optional[List[float]], Optional[List[float]], Optional[List[float]], Optional[List[float]], Optional[List[float]], Optional[List[float]]]:
    """
    Restore model/optimizer state and training history from ``{checkpoint_root}/{project_step}/checkpoint.pth``.

    If that file does not exist, the directory ``{checkpoint_root}/{project_step}`` is created
    (so a later save can write into it) and a "start from scratch" tuple is returned.

    The checkpoint is deserialized onto the CPU first (``map_location='cpu'``). A checkpoint
    written on a GPU machine can therefore be resumed on a machine with a different GPU layout
    or with no GPU at all. ``model.load_state_dict`` copies the weights into the model's
    existing parameters wherever they live, and ``Optimizer.load_state_dict`` moves the
    optimizer state to the device of the matching parameters.

    Args:
        checkpoint_root (str): The root directory where the checkpoint is stored.
        project_step (str): The current project step or phase, used for constructing the checkpoint file path.
        adversarial (bool): Whether to use adversarial training. When True, the discriminator
            and its optimizer are restored too.
        model (torch.nn.Module): The main model to load the state dictionary into.
        model_D (torch.nn.Module): The discriminator model to load the state dictionary into.
            Only touched when ``adversarial`` is True, so it may be None otherwise.
        optimizer (torch.optim.Optimizer): The optimizer for the main model to load the state dictionary into.
        optimizer_D (torch.optim.Optimizer): The optimizer for the discriminator. Only touched
            when ``adversarial`` is True, so it may be None otherwise.

    Returns:
        Tuple[bool, Optional[int], Optional[List[float]], Optional[List[float]], Optional[List[float]], Optional[List[float]], Optional[List[float]], Optional[List[float]]]:
            - bool: True to start training from scratch (no checkpoint), False to resume.
            - Optional[int]: The epoch to resume from, if a checkpoint is found.
            - Optional[List[float]]: List of training losses over epochs, if a checkpoint is found.
            - Optional[List[float]]: List of training mean Intersection over Union (mIoU) scores over epochs, if a checkpoint is found.
            - Optional[List[float]]: Training IoU scores for each class, if a checkpoint is found.
            - Optional[List[float]]: List of validation losses over epochs, if a checkpoint is found.
            - Optional[List[float]]: List of validation mIoU scores over epochs, if a checkpoint is found.
            - Optional[List[float]]: Validation IoU scores for each class, if a checkpoint is found.
            When no checkpoint exists, every element after the first is None.
    """
    # Same path layout as the one save_checkpoint writes to.
    directory = f'{checkpoint_root}/{project_step}'
    checkpoint_path = f'{directory}/checkpoint.pth'

    if not os.path.exists(checkpoint_path):
        # exist_ok makes this a no-op when the directory is already there and avoids a
        # check-then-create race.
        os.makedirs(directory, exist_ok=True)
        print(f"No checkpoint found in {directory}. Starting from scratch.")
        return (True, None, None, None, None, None, None, None)

    # NOTE(review): since PyTorch 2.6, torch.load defaults to weights_only=True. If the saved
    # metric entries contain NumPy arrays (e.g. per-class IoU), loading may require
    # weights_only=False for trusted checkpoints. Confirm against the installed torch version.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Restore the networks and their optimizers.
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if adversarial:
        model_D.load_state_dict(checkpoint['model_D'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])

    # Restore the training history.
    start_epoch = checkpoint['epoch']
    train_loss_list = checkpoint['train_loss_list']
    train_miou_list = checkpoint['train_miou_list']
    train_iou = checkpoint['train_iou']
    val_loss_list = checkpoint['val_loss_list']
    val_miou_list = checkpoint['val_miou_list']
    val_iou = checkpoint['val_iou']

    print(f"Checkpoint found. Resuming from epoch {start_epoch}.")

    return (False, start_epoch, train_loss_list, train_miou_list, train_iou, val_loss_list, val_miou_list, val_iou)
  

### Computations

In [None]:
! pip install -U fvcore

In [None]:
from typing import Dict, Tuple
import numpy as np
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
import time

def compute_flops(model: torch.nn.Module, 
                  height: int = 512, 
                  width: int = 1024) -> Dict[str, str]:
    """
    Count the parameters and FLOPs of ``model`` with fvcore on one all-zeros image.

    The model is moved to the CPU in place (``model.cpu()``) and is left there afterwards.

    Args:
        model (torch.nn.Module): The network to analyse.
        height (int, optional): Height of the dummy input image. Defaults to 512.
        width (int, optional): Width of the dummy input image. Defaults to 1024.

    Returns:
        Dict[str, str]: ``{'Parameters': ..., 'FLOPS': ...}``, both as the formatted strings
        that appear in fvcore's table rather than as numbers.
    """
    dummy_batch = torch.zeros((1, 3, height, width))
    analysis = FlopCountAnalysis(model.cpu(), dummy_batch)

    # Row index 2 of the rendered table is taken as the whole-model row (presumably
    # 0 = header, 1 = separator). Its '|'-separated fields 2 and 3 hold the parameter
    # and FLOP totals.
    model_row = flop_count_table(analysis).split('\n')[2]
    columns = model_row.split('|')
    return {'Parameters': columns[2].strip(),
            'FLOPS': columns[3].strip()
            }


def compute_latency_and_fps(model: torch.nn.Module, 
                            height: int = 512, 
                            width: int = 1024, 
                            iterations: int = 1000, 
                            device: str = 'cuda',
                            warmup: int = 10) ->  Dict[str, float]:
    """
    Measure per-image inference latency and FPS of ``model`` on an all-zeros input.

    The model is switched to eval mode and moved to ``device`` in place. CUDA kernels run
    asynchronously, so on a CUDA device the timer is read only after
    ``torch.cuda.synchronize``. Without that, the measurement would capture kernel-launch
    overhead rather than the real inference time. ``warmup`` untimed passes run first so
    one-off start-up costs do not skew the statistics.

    Args:
        model (torch.nn.Module): The neural network model to evaluate.
        height (int, optional): The height of the input image. Defaults to 512.
        width (int, optional): The width of the input image. Defaults to 1024.
        iterations (int, optional): Number of timed forward passes. Defaults to 1000.
        device (str, optional): Device to run inference ('cpu' or 'cuda'). Defaults to 'cuda'.
        warmup (int, optional): Number of untimed forward passes run before timing. Defaults to 10.

    Returns:
        Dict[str, float]: ``mean_latency`` and ``std_latency`` (seconds per forward pass),
        and ``mean_fps`` and ``std_fps`` (the per-iteration reciprocal of the latency).
    """
    model.eval()
    model = model.to(device)
    on_cuda = torch.device(device).type == 'cuda'

    def _sync() -> None:
        # Block until all queued CUDA work has finished; a no-op on CPU.
        if on_cuda:
            torch.cuda.synchronize(device)

    # The input never changes, so build it once instead of once per iteration.
    image = torch.zeros((1, 3, height, width), device=device)

    latencies = []
    with torch.no_grad():
        for _ in range(warmup):
            model(image)
        _sync()

        for _ in range(iterations):
            start_time = time.perf_counter()  # monotonic, high-resolution clock
            model(image)
            _sync()
            latencies.append(time.perf_counter() - start_time)

    fps_records = [1 / latency for latency in latencies]

    return {'mean_latency': np.mean(latencies),
            'std_latency': np.std(latencies),
            'mean_fps': np.mean(fps_records),
            'std_fps': np.std(fps_records)}

### Data processing

In [None]:
import numpy as np
from PIL import Image
import PIL
#from config import GTA
import albumentations as A

def get_color_to_id() -> dict:
    """
    Build the inverse of ``get_id_to_color``: RGB colour -> class ID.

    Returns:
        dict: Keys are RGB tuples; values are the integer class IDs that use them.
    """
    return {rgb: class_id for class_id, rgb in get_id_to_color().items()}

def get_id_to_color() -> dict:
    """
    Map each of the 19 semantic class IDs (0-18) to its RGB display colour.

    A fresh dictionary is returned on every call, so callers may modify it freely.

    Returns:
        dict: Integer class ID -> ``(R, G, B)`` tuple.
    """
    palette = (
        (128, 64, 128),    # 0  road
        (244, 35, 232),    # 1  sidewalk
        (70, 70, 70),      # 2  building
        (102, 102, 156),   # 3  wall
        (190, 153, 153),   # 4  fence
        (153, 153, 153),   # 5  pole
        (250, 170, 30),    # 6  light
        (220, 220, 0),     # 7  sign
        (107, 142, 35),    # 8  vegetation
        (152, 251, 152),   # 9  terrain
        (70, 130, 180),    # 10 sky
        (220, 20, 60),     # 11 person
        (255, 0, 0),       # 12 rider
        (0, 0, 142),       # 13 car
        (0, 0, 70),        # 14 truck
        (0, 60, 100),      # 15 bus
        (0, 80, 100),      # 16 train
        (0, 0, 230),       # 17 motorcycle
        (119, 11, 32),     # 18 bicycle
    )
    # The position in the tuple is the class ID.
    return dict(enumerate(palette))

def get_id_to_label() -> dict:
    """
    Map class IDs to human-readable class names.

    IDs 0-18 are the semantic classes; 255 is the ignore/void ID, named 'unlabeled'.

    Returns:
        dict: Integer class ID -> class name.
    """
    class_names = (
        'road', 'sidewalk', 'building', 'wall', 'fence',
        'pole', 'light', 'sign', 'vegetation', 'terrain',
        'sky', 'person', 'rider', 'car', 'truck',
        'bus', 'train', 'motorcycle', 'bicycle',
    )
    id_to_label = dict(enumerate(class_names))
    id_to_label[255] = 'unlabeled'
    return id_to_label

def label_to_rgb(label:np.ndarray)->PIL.Image:
    """
    Colourise a 2-D map of class IDs using the palette from ``get_id_to_color``.

    Pixels whose ID is 255, or any ID missing from the palette, are left black.

    Args:
        label (np.ndarray): 2-D array of class IDs.
    Returns:
        PIL.Image.Image: An RGB image with the same height and width as ``label``.
    """
    rows, cols = label.shape[0], label.shape[1]
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)

    palette = get_id_to_color()
    for class_id in palette:
        canvas[label == class_id] = palette[class_id]

    # Void pixels are painted black explicitly.
    canvas[label == 255] = (0, 0, 0)

    return Image.fromarray(canvas, 'RGB')


def get_augmented_data(augmentedType: str) -> A.Compose:
    """
    Build the Albumentations pipeline named ``augmentedType``.

    Every named pipeline resizes to the GTA5 resolution and applies a random horizontal
    flip, followed by pipeline-specific transforms. An unrecognised name prints the accepted
    values and falls back to a resize-only pipeline.

    Args:
        augmentedType (str): One of 'transform1', 'transform2', 'transform3', 'transform4'.

    Returns:
        A.Compose: The selected augmentation pipeline.
    """
    target_size = (GTA['height'], GTA['width'])

    def resize_flip(*extra) -> A.Compose:
        # Shared prefix (resize + 50% horizontal flip) followed by the given transforms.
        return A.Compose([A.Resize(*target_size), A.HorizontalFlip(p=0.5), *extra])

    # NOTE(review): 'transform2' and 'transform3' are identical. Confirm whether
    # 'transform3' was meant to use a different augmentation.
    pipelines = {
        'transform1': resize_flip(A.ColorJitter(p=0.5)),
        'transform2': resize_flip(A.GaussianBlur(p=0.5)),
        'transform3': resize_flip(A.GaussianBlur(p=0.5)),
        'transform4': resize_flip(
            A.GaussianBlur(p=0.5),
            A.ColorJitter(p=0.5),
            A.RandomResizedCrop(height=target_size[0],
                                width=target_size[1],
                                scale=(0.5, 1.0),
                                ratio=(0.75, 1.333),
                                p=0.5),
        ),
    }

    if augmentedType not in tuple(pipelines):
        print('Transformation accepted: [transform1, transform2, transform3, transform4]')
        return A.Compose([A.Resize(*target_size)])

    return pipelines[augmentedType]

### Poly LR scheduler

In [None]:
import torch

def poly_lr_scheduler(optimizer: torch.optim.Optimizer, 
                      init_lr: float, 
                      iter: int, 
                      max_iter:int = 300, 
                      power: float = 0.9) -> float:
    """
    Polynomial learning-rate decay: ``lr = init_lr * (1 - iter / max_iter) ** power``.

    Only the first parameter group of ``optimizer`` is updated. Once ``iter`` reaches or
    passes ``max_iter``, the learning rate is 0.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer object whose first param group is updated.
        init_lr (float): Initial learning rate.
        iter (int): Current iteration number.
        max_iter (int, optional): Maximum number of iterations (default is 300).
        power (float, optional): Power factor (default is 0.9).

    Returns:
        float: The learning rate that was written into the optimizer.
    """
    # Clamp at 0. Past max_iter the base would be negative, and in Python a negative float
    # raised to a fractional power is a complex number, which would corrupt the optimizer.
    remaining = max(0.0, 1 - iter / max_iter)
    lr = init_lr * remaining ** power

    optimizer.param_groups[0]['lr'] = lr
    return lr

### Statistics

In [None]:
import numpy as np

def fast_hist(a, b, n):
    """
    Build an ``n x n`` confusion matrix from ground-truth labels and predictions.

    Args:
        a: Ground-truth label array. Entries outside ``[0, n)`` (e.g. the ignore ID 255)
            are skipped.
        b: Prediction array with the same shape as ``a``.
        n: Number of classes.

    Returns:
        np.ndarray: ``hist[i, j]`` counts pixels with label ``i`` predicted as ``j``.
    """
    valid = np.logical_and(a >= 0, a < n)
    # Encode each (label, prediction) pair as a single index in [0, n*n).
    pair_index = n * a[valid].astype(int) + b[valid]
    counts = np.bincount(pair_index, minlength=n ** 2)
    return counts.reshape(n, n)

def per_class_iou(hist):
    """
    Per-class Intersection over Union from a confusion matrix.

    IoU_c = TP_c / (row_sum_c + col_sum_c - TP_c). A small epsilon in the denominator turns
    classes that never appear into 0 instead of a division by zero.

    Args:
        hist: Square confusion matrix (rows = labels, columns = predictions).

    Returns:
        np.ndarray: One IoU value per class.
    """
    true_positives = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - true_positives
    return true_positives / (union + 1e-5)


### Visualization

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
import random
from collections import OrderedDict
from typing import Tuple
#from utils import get_id_to_label, label_to_rgb
#from datasets import CityScapes
#from models import get_deeplab_v2, BiSeNet
#from config import CHECKPOINT_ROOT, CITYSCAPES_PATH, DEEPLABV2_PATH

def print_stats(epoch:int, 
                train_loss:float,
                val_loss:float, 
                train_miou:float, 
                val_miou:float, 
                verbose:bool)->None:
    """
    Print a three-line summary of the epoch's losses and mIoU when ``verbose`` is truthy.

    Args:
        epoch (int): Current epoch number.
        train_loss (float): Training loss value.
        val_loss (float): Validation loss value.
        train_miou (float): Training mean IoU value.
        val_miou (float): Validation mean IoU value.
        verbose (bool): When falsy, nothing is printed.

    Returns:
        None
    """
    if not verbose:
        return

    report = (
        f'Epoch: {epoch}',
        f'\tTrain Loss: {train_loss}, Validation Loss: {val_loss}',
        f'\tTrain mIoU: {train_miou}, Validation mIoU: {val_miou}',
    )
    print('\n'.join(report))
    
def plot_loss(model_results:list, 
              model_name:str, 
              project_step:str, 
              train_dataset:str, 
              validation_dataset:str)->None:
    """
    Plot the training and validation loss curves, display them, and save them as a PNG.

    The image is written to ``{CHECKPOINT_ROOT}/{project_step}/{model_name}_{project_step}_loss.png``,
    creating the directory if needed. The figure is closed afterwards so repeated calls do
    not accumulate open matplotlib figures.

    Args:
        model_results (list): Sequence whose element 0 holds the per-epoch training losses
            and element 1 the per-epoch validation losses.
        model_name (str): Name of the model (used in the title and file name).
        project_step (str): Project step or phase (used in the output path).
        train_dataset (str): Name of the training dataset (legend label).
        validation_dataset (str): Name of the validation dataset (legend label).

    Returns:
        None
    """
    train_losses = model_results[0]
    validation_losses = model_results[1]
    epochs = range(len(train_losses))

    # Create the plot
    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)  
    ax.set_title(f'Train vs. Validation Loss for {model_name}', fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch', fontsize=14)
    ax.set_ylabel('Loss', fontsize=14)
    ax.plot(epochs, train_losses, 'o-', color='tab:blue', label=f"train loss ({train_dataset})", linewidth=2, markersize=5)
    ax.plot(epochs, validation_losses, '^-', color='tab:red', label=f"validation loss ({validation_dataset})", linewidth=2, markersize=5)
    ax.legend(loc='upper right', fontsize=12, frameon=True, shadow=True)
    ax.grid(True, which='both', linewidth=0.5)
    ax.tick_params(axis='both', which='major', labelsize=12)

    plt.tight_layout()
    plt.show()

    # Save the plot. exist_ok makes directory creation idempotent, so one save path suffices.
    checkpoint_path = f'{CHECKPOINT_ROOT}/{project_step}'
    os.makedirs(checkpoint_path, exist_ok=True)
    fig.savefig(f"{checkpoint_path}/{model_name}_{project_step}_loss.png", format='png')

    # Release the figure; otherwise each call leaves one more figure registered with pyplot.
    plt.close(fig)
     

    
def plot_miou(model_results:list, 
              model_name:str, 
              project_step:str, 
              train_dataset:str, 
              validation_dataset:str) -> None:
    """
    Plot the training and validation mIoU curves, display them, and save them as a PNG.

    The image is written to ``{CHECKPOINT_ROOT}/{project_step}/{model_name}_{project_step}_miou.png``,
    creating the directory if needed. The figure is closed afterwards so repeated calls do
    not accumulate open matplotlib figures.

    Args:
        model_results (list): Sequence whose element 2 holds the per-epoch training mIoU
            and element 3 the per-epoch validation mIoU.
        model_name (str): Name of the model (used in the title and file name).
        project_step (str): Project step or phase (used in the output path).
        train_dataset (str): Name of the training dataset (legend label).
        validation_dataset (str): Name of the validation dataset (legend label).

    Returns:
        None
    """
    train_mIoU = model_results[2]
    validation_mIoU = model_results[3]
    epochs = range(len(train_mIoU))

    # Create the plot
    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)  
    ax.set_title(f'Train vs. Validation mIoU for {model_name} over Epochs', fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch', fontsize=14)
    ax.set_ylabel('Mean Intersection over Union (mIoU)', fontsize=14)
    ax.plot(epochs, train_mIoU, 'o-', color='tab:blue', label=f"train mIoU ({train_dataset})", linewidth=2, markersize=5)
    ax.plot(epochs, validation_mIoU, '^-', color='tab:red', label=f"validation mIoU ({validation_dataset})", linewidth=2, markersize=5)
    ax.legend(loc='upper left', fontsize=12, frameon=True, shadow=True)
    ax.grid(True, which='both', linewidth=0.5)
    ax.tick_params(axis='both', which='major', labelsize=12)

    plt.tight_layout()
    plt.show()

    # Save the plot. exist_ok makes directory creation idempotent, so one save path suffices.
    checkpoint_path = f'{CHECKPOINT_ROOT}/{project_step}'
    os.makedirs(checkpoint_path, exist_ok=True)
    fig.savefig(f"{checkpoint_path}/{model_name}_{project_step}_miou.png", format='png')

    # Release the figure; otherwise each call leaves one more figure registered with pyplot.
    plt.close(fig)

def plot_iou(model_results:list, 
              model_name:str, 
              project_step:str, 
              train_dataset:str, 
              validation_dataset:str) -> None:
    """
    Plot side-by-side training and validation IoU bars for the 19 classes, display them,
    and save them as a PNG.

    The image is written to ``{CHECKPOINT_ROOT}/{project_step}/{model_name}_{project_step}_iou.png``,
    creating the directory if needed. The figure is closed afterwards so repeated calls do
    not accumulate open matplotlib figures.

    Args:
        model_results (list): Sequence with at least six entries, where
                              - model_results[4] holds the training IoU for each class (indexable 0-18), and
                              - model_results[5] holds the validation IoU for each class (indexable 0-18).
        model_name (str): Name of the model (used in the title and file name).
        project_step (str): Project step or phase (used in the output path).
        train_dataset (str): Name of the training dataset (legend label).
        validation_dataset (str): Name of the validation dataset (legend label).

    Returns:
        None
    """
    num_classes = 19
    id_to_label = get_id_to_label()  # build the mapping once, not once per class
    class_names = [id_to_label[i] for i in range(num_classes)]
    train_iou = [model_results[4][i] for i in range(num_classes)]
    val_iou = [model_results[5][i] for i in range(num_classes)]

    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)  
    bar_width = 0.35
    index = np.arange(num_classes)

    ax.bar(index, train_iou, bar_width, label=f'train IoU ({train_dataset})', color='tab:blue', alpha=0.7)
    ax.bar(index + bar_width, val_iou, bar_width, label=f'validation IoU ({validation_dataset})', color='tab:red', alpha=0.7)

    ax.set_xlabel('Classes', fontsize=14)
    ax.set_ylabel('IoU', fontsize=14)
    ax.set_title(f'Training and Validation IoU for Each Class ({model_name})', fontsize=16, fontweight='bold')
    # Centre each class label between its two bars.
    ax.set_xticks(index + bar_width / 2)
    ax.set_xticklabels(class_names, rotation=45, ha="right", fontsize=12)
    ax.legend(loc='upper right', fontsize=12, frameon=True, shadow=True)
    ax.grid(True, which='both', linewidth=0.5, axis='y')

    plt.tight_layout()
    plt.show()

    # Save the plot. exist_ok makes directory creation idempotent, so one save path suffices.
    checkpoint_path = f'{CHECKPOINT_ROOT}/{project_step}'
    os.makedirs(checkpoint_path, exist_ok=True)
    fig.savefig(f"{checkpoint_path}/{model_name}_{project_step}_iou.png", format='png')

    # Release the figure; otherwise each call leaves one more figure registered with pyplot.
    plt.close(fig)

def _load_visualization_model(model_root: str,
                              model_type: Tuple,
                              device: str) -> torch.nn.Module:
    """Build the architecture named by ``model_type[0]``, load the weights at ``model_root``, and return it in eval mode.

    'DeepLabV2' selects DeepLabV2; any other name selects BiSeNet with a ResNet-18 context
    path. If the plain load fails, the ``module.`` prefix (added to every key when a model
    wrapped in ``torch.nn.DataParallel`` is saved) is stripped and loading is retried.
    """
    checkpoint = torch.load(model_root, map_location=torch.device(device))

    if model_type[0] == 'DeepLabV2':
        model = get_deeplab_v2(num_classes=19, pretrain=True, pretrain_model_path=DEEPLABV2_PATH).to(device)
    else:
        model = BiSeNet(num_classes=19, context_path="resnet18").to(device)

    try:
        model.load_state_dict(checkpoint["model"])
        print(f"{model_type[0]} model loaded successfully.")
    except RuntimeError as e:
        print(f"Error: Failed to load {model_type[0]} model state dictionary with error: {e}")
        print(f"Attempting to adjust the state dictionary for {model_type[0]}...")
        prefix = "module."
        new_state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in checkpoint['model'].items()
        )
        model.load_state_dict(new_state_dict)
        print(f"Adjusted state dictionary for {model_type[0]} loaded successfully.")

    model.eval()
    return model


def plot_segmented_images(model_roots: list,
                          model_types: list[Tuple],
                          n_images: int = 5,
                          device: str = "cpu") -> None:
    """Visualize the segmentation of several models on random Cityscapes validation images.

    Each row shows one randomly chosen validation image, its ground truth, and one
    prediction column per model.

    Args:
        model_roots (list): Paths to the model checkpoints (each a dict with a "model" state dict).
        model_types (list): One ``(architecture, column_title)`` tuple per checkpoint, where
            architecture is 'DeepLabV2' or anything else for BiSeNet.
        n_images (int): Number of random images to visualize (rows).
        device (str): Device to run the models on ('cpu' or 'cuda').
    """
    cityscapes_dataset = CityScapes(root_dir=CITYSCAPES_PATH, split='val')

    # Select n_images random samples, fetching each sample only once.
    selected_indices = random.sample(range(len(cityscapes_dataset)), n_images)
    samples = [cityscapes_dataset[i] for i in selected_indices]
    selected_images = [sample[0] for sample in samples]
    ground_truths = [sample[1] for sample in samples]

    models = [_load_visualization_model(model_root, model_type, device)
              for model_root, model_type in zip(model_roots, model_types)]

    # Predict a colourised segmentation for every (image, model) pair.
    outputs = []
    with torch.no_grad():
        for image in selected_images:
            # The models live on `device`, so the input batch must be moved there as well.
            batch = image.unsqueeze(0).to(device)
            model_outputs = []
            for model in models:
                output = model(batch)
                output = torch.argmax(torch.softmax(output, dim=1), dim=1)
                # Drop the batch dim and bring the prediction to host memory: label_to_rgb
                # indexes a NumPy array with it.
                output = output.squeeze(0).cpu().numpy()
                model_outputs.append(label_to_rgb(output))
            outputs.append(model_outputs)

    # Convert images and labels for plotting.
    # NOTE(review): assumes dataset images are float (C, H, W) tensors scaled to [0, 1]. Confirm
    # against CityScapes.__getitem__.
    selected_images_np = [(image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for image in selected_images]
    ground_truths_np = [label_to_rgb(gt.squeeze().cpu().numpy()) for gt in ground_truths]

    n_cols = len(models) + 2
    # squeeze=False keeps `axes` 2-D even when n_images == 1.
    _, axes = plt.subplots(n_images, n_cols, figsize=(23, 2 * n_images), squeeze=False)

    # Column titles go on the first row only.
    column_titles = ["Target Image", "Ground Truth"] + [model_type[1] for model_type in model_types]
    for col, title in enumerate(column_titles[:n_cols]):
        axes[0, col].set_title(title, fontsize=16, fontweight='bold')

    for row in range(n_images):
        panels = [selected_images_np[row], ground_truths_np[row], *outputs[row]]
        for col, panel in enumerate(panels):
            axes[row, col].imshow(panel)
            axes[row, col].axis("off")

    plt.tight_layout()
    plt.show()

## 4 - Pipeline

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as A
from tqdm import tqdm
from itertools import cycle
from typing import Tuple, List, Union
#from config import CITYSCAPES, GTA, DEEPLABV2_PATH, CITYSCAPES_PATH, GTA5_PATH
#from datasets import CityScapes, GTA5
#from models import BiSeNet, get_deeplab_v2, FCDiscriminator
#from utils import *
import warnings
warnings.filterwarnings("ignore")
# torch.manual_seed seeds the CPU generator and every CUDA device. torch.cuda.manual_seed
# alone left CPU-side randomness unseeded: weight initialisation (models are built before
# `.to(device)`) and DataLoader shuffling. As a result, runs were not reproducible.
torch.manual_seed(42)


def get_core(model_name: str, 
             n_classes: int,
             device: str,
             parallelize: bool,
             optimizer_name: str, 
             lr: float,
             momentum: float,
             weight_decay: float,
             loss_fn_name: str,
             ignore_index: int,
             adversarial: bool) -> Tuple[torch.nn.Module, torch.optim.Optimizer, torch.nn.Module, torch.nn.Module, torch.optim.Optimizer, torch.nn.Module]:
    """
    Assemble the model, optimizer and loss (plus the discriminator trio when adversarial).

    Args:
    - model_name (str): 'DeepLabV2' (pretrained from DEEPLABV2_PATH) or 'BiSeNet' (ResNet-18 context path).
    - n_classes (int): Number of segmentation classes.
    - device (str): Device the networks are moved to ('cpu' or 'cuda').
    - parallelize (bool): Wrap networks in DataParallel when device == 'cuda' and more than one GPU is visible.
    - optimizer_name (str): 'Adam' or 'SGD'.
    - lr (float): Learning rate of the segmentation optimizer.
    - momentum (float): SGD momentum (not used by Adam).
    - weight_decay (float): SGD weight decay (not used by Adam).
    - loss_fn_name (str): Only 'CrossEntropyLoss' is supported.
    - ignore_index (int): Label value ignored by the cross-entropy loss.
    - adversarial (bool): Also build an FCDiscriminator, its Adam optimizer (lr=1e-3,
      betas=(0.9, 0.99)) and a BCEWithLogitsLoss.

    Raises:
    - ValueError: For an unsupported model_name, optimizer_name or loss_fn_name (checked in that order).

    Returns:
    - Tuple (model, optimizer, loss_fn, model_D, optimizer_D, loss_D). The last three are
      None unless adversarial is True.
    """

    def _maybe_parallel(net: torch.nn.Module) -> torch.nn.Module:
        # DataParallel only when requested and there is more than one CUDA device to use.
        if parallelize and device == 'cuda' and torch.cuda.device_count() > 1:
            return torch.nn.DataParallel(net).to(device)
        return net

    model_builders = {
        'DeepLabV2': lambda: get_deeplab_v2(num_classes=n_classes, pretrain=True, pretrain_model_path=DEEPLABV2_PATH),
        'BiSeNet': lambda: BiSeNet(num_classes=n_classes, context_path="resnet18"),
    }
    if model_name not in model_builders:
        raise ValueError('Model accepted: [DeepLabV2, BiSeNet]')
    model = _maybe_parallel(model_builders[model_name]().to(device))

    # NOTE(review): the Adam branch passes only lr, so `momentum` and `weight_decay` have no
    # effect when optimizer_name == 'Adam'. Confirm this is intended.
    if optimizer_name == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    elif optimizer_name == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        raise ValueError('Optimizer accepted: [Adam, SGD]')

    if loss_fn_name != 'CrossEntropyLoss':
        raise ValueError('Loss function accepted: [CrossEntropyLoss]')
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

    model_D = optimizer_D = loss_D = None
    if adversarial:
        model_D = _maybe_parallel(FCDiscriminator(num_classes=n_classes).to(device))
        optimizer_D = torch.optim.Adam(model_D.parameters(), lr=1e-3, betas=(0.9, 0.99))
        loss_D = torch.nn.BCEWithLogitsLoss()

    return model, optimizer, loss_fn, model_D, optimizer_D, loss_D

def get_loaders(train_dataset_name: str, 
                val_dataset_name: str, 
                augmented: bool,
                augmentedType: str,
                batch_size: int,
                n_workers: int,
                adversarial: bool) -> Tuple[Union[DataLoader, Tuple[DataLoader, DataLoader]], DataLoader, int, int]:
    """
    Build the training and validation DataLoaders.

    In adversarial mode the training "loader" is a ``(GTA5 source, CityScapes-train target)``
    pair and ``train_dataset_name`` is not consulted. Otherwise a single loader over the named
    dataset is returned. Augmentation (``augmented``/``augmentedType``) only affects the
    GTA5 pipeline; CityScapes is only resized.

    Args:
    - train_dataset_name (str): 'CityScapes' or 'GTA5' (ignored when adversarial).
    - val_dataset_name (str): Must be 'CityScapes'.
    - augmented (bool): Use ``get_augmented_data(augmentedType)`` for GTA5 instead of a plain resize.
    - augmentedType (str): Name of the augmentation pipeline.
    - batch_size (int): Batch size of every loader.
    - n_workers (int): Worker processes per loader.
    - adversarial (bool): Build the source/target pair for adversarial training.

    Raises:
    - ValueError: Unsupported train dataset (non-adversarial only) or validation dataset.

    Returns:
    - Tuple (train_loader, val_loader, data_height, data_width). The height and width are
      the CityScapes resize dimensions.
    """

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)

    transform_cityscapes = A.Compose([A.Resize(CITYSCAPES['height'], CITYSCAPES['width'])])
    if augmented:
        transform_gta5 = get_augmented_data(augmentedType)
    else:
        transform_gta5 = A.Compose([A.Resize(GTA['height'], GTA['width'])])

    if adversarial:
        source_dataset = GTA5(root_dir=GTA5_PATH, transform=transform_gta5)
        target_dataset = CityScapes(root_dir=CITYSCAPES_PATH, split='train', transform=transform_cityscapes)
        train_loader = (_make_loader(source_dataset, shuffle=True),
                        _make_loader(target_dataset, shuffle=True))
    elif train_dataset_name == 'CityScapes':
        train_loader = _make_loader(CityScapes(root_dir=CITYSCAPES_PATH, split='train', transform=transform_cityscapes), shuffle=True)
    elif train_dataset_name == 'GTA5':
        train_loader = _make_loader(GTA5(root_dir=GTA5_PATH, transform=transform_gta5), shuffle=True)
    else:
        raise ValueError('Train datasets accepted: [CityScapes, GTA5]')

    if val_dataset_name != 'CityScapes':
        raise ValueError('Val datasets accepted: [CityScapes]')
    val_loader = _make_loader(CityScapes(root_dir=CITYSCAPES_PATH, split='val', transform=transform_cityscapes), shuffle=False)

    return train_loader, val_loader, CITYSCAPES['height'], CITYSCAPES['width']

def _repeat_loader(loader):
    """
    Yield batches from `loader` indefinitely, starting a fresh pass each time it is exhausted.

    Unlike itertools.cycle, this does not keep a copy of every yielded batch in memory.
    Each new pass calls iter(loader) again, so a DataLoader built with shuffle=True is
    reshuffled on every pass. If `loader` yields nothing, the generator stops instead of
    spinning forever.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            return


def adversarial_train_step(model: torch.nn.Module, 
                           model_D: torch.nn.Module, 
                           loss_fn: torch.nn.Module, 
                           loss_D: torch.nn.Module, 
                           optimizer: torch.optim.Optimizer, 
                           optimizer_D: torch.optim.Optimizer, 
                           loaders: Tuple[DataLoader,DataLoader], 
                           device: str, 
                           n_classes: int = 19,
                           lambda_adv: float = 0.001)-> Tuple[float, float, np.ndarray]:
    """
    Perform one epoch of adversarial (output-space) training for semantic segmentation.

    The epoch length is set by the source loader. The target loader is re-iterated
    (and reshuffled, if its DataLoader shuffles) whenever it runs out.

    Args:
    - model (torch.nn.Module): Segmentation model (the generator).
    - model_D (torch.nn.Module): Discriminator model.
    - loss_fn (torch.nn.Module): Segmentation loss function (applied to source predictions).
    - loss_D (torch.nn.Module): Adversarial loss function used for both generator and discriminator.
    - optimizer (torch.optim.Optimizer): Optimizer for segmentation model.
    - optimizer_D (torch.optim.Optimizer): Optimizer for discriminator model.
    - loaders (Tuple[DataLoader,DataLoader]): (source_loader, target_loader); the target
      loader's labels are ignored.
    - device (str): Device on which to run the models ('cuda' or 'cpu').
    - n_classes (int, optional): Number of classes for segmentation. Default is 19.
    - lambda_adv (float, optional): Weight of the adversarial term in the generator loss.
      Default is 0.001.

    Returns:
    - Tuple containing:
        - epoch_loss (float): Average segmentation loss on the source batches.
        - epoch_miou (float): Mean Intersection over Union (mIoU) on the source batches.
        - epoch_iou (np.ndarray): Array of per-class IoU values on the source batches.

    Raises:
    - ZeroDivisionError: if the source loader yields no batches.
    """
    # Discriminator targets for each domain.
    source_label, target_label = 0.0, 1.0

    model_G = model.to(device)
    model_D = model_D.to(device)  # D must live on the same device as G's outputs
    optimizer_G = optimizer
    ce_loss = loss_fn
    bce_loss = loss_D

    # Upsample logits to each dataset's resize resolution before computing losses.
    # align_corners=False is the implicit default; stating it silences the warning.
    interp_source = nn.Upsample(size=(GTA['height'], GTA['width']), mode='bilinear', align_corners=False)
    interp_target = nn.Upsample(size=(CITYSCAPES['height'], CITYSCAPES['width']), mode='bilinear', align_corners=False)

    total_loss = 0
    total_miou = 0
    total_iou = np.zeros(n_classes)
    iterations = 0

    model_G.train()
    model_D.train()

    source_loader, target_loader = loaders
    target_batches = _repeat_loader(target_loader)

    for (source_data, source_labels), (target_data, _) in zip(source_loader, target_batches):
        iterations += 1

        source_data = source_data.to(device)
        # CrossEntropyLoss needs integer class targets (train_step casts the same way).
        source_labels = source_labels.to(device=device, dtype=torch.long)
        target_data = target_data.to(device)

        optimizer_G.zero_grad()
        optimizer_D.zero_grad()

        # ---- Train the segmentation network (generator) ----
        # Freeze D so the adversarial term only produces gradients for G.
        for param in model_D.parameters():
            param.requires_grad = False

        # Supervised segmentation loss on labelled source images.
        output_source = interp_source(model_G(source_data))
        segmentation_loss = ce_loss(output_source, source_labels)
        segmentation_loss.backward()

        # Adversarial loss on target images: G is pushed to make D classify its
        # target predictions as coming from the source domain.
        output_target = interp_target(model_G(target_data))
        d_out_target = model_D(torch.nn.functional.softmax(output_target, dim=1))
        adversarial_loss = bce_loss(d_out_target, torch.full_like(d_out_target, source_label))
        (lambda_adv * adversarial_loss).backward()

        # ---- Train the discriminator ----
        for param in model_D.parameters():
            param.requires_grad = True

        # Detach so D's losses do not backpropagate into G.
        output_source = output_source.detach()
        output_target = output_target.detach()

        d_out_source = model_D(torch.nn.functional.softmax(output_source, dim=1))
        bce_loss(d_out_source, torch.full_like(d_out_source, source_label)).backward()

        d_out_target = model_D(torch.nn.functional.softmax(output_target, dim=1))
        bce_loss(d_out_target, torch.full_like(d_out_target, target_label)).backward()

        optimizer_G.step()
        optimizer_D.step()

        # ---- Metrics on the source batch ----
        total_loss += segmentation_loss.item()

        prediction_source = torch.argmax(torch.softmax(output_source, dim=1), dim=1)
        hist = fast_hist(source_labels.cpu().numpy(), prediction_source.cpu().numpy(), n_classes)
        running_iou = np.array(per_class_iou(hist)).flatten()
        total_miou += running_iou.sum()
        total_iou += running_iou

    # Drop the suspended target-loader iterator (and any worker processes) right away.
    target_batches.close()

    epoch_loss = total_loss / iterations
    epoch_miou = total_miou / (iterations * n_classes)
    epoch_iou = total_iou / iterations

    return epoch_loss, epoch_miou, epoch_iou

def train_step(model: torch.nn.Module, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer, 
               dataloader: DataLoader, 
               device: str, 
               n_classes: int = 19)-> Tuple[float, float, float]:
    """
    Run one supervised training epoch over `dataloader`.

    Args:
    - model (torch.nn.Module): Segmentation model, put in train mode here.
    - loss_fn (torch.nn.Module): Segmentation loss applied to (logits, integer labels).
    - optimizer (torch.optim.Optimizer): Optimizer stepped once per batch.
    - dataloader (DataLoader): Yields (image, label) batches.
    - device (str): Device on which to run the model ('cuda' or 'cpu').
    - n_classes (int, optional): Number of segmentation classes. Default is 19.

    Returns:
    - Tuple containing:
        - epoch_loss (float): Loss averaged over batches.
        - epoch_miou (float): Per-class IoU summed over batches, divided by (batches * n_classes).
        - epoch_iou (np.ndarray): Per-class IoU averaged over batches.
    """
    running_loss = 0
    miou_sum = 0
    iou_sum = np.zeros(n_classes)

    model.train()

    for images, targets in dataloader:
        images = images.to(device)
        targets = targets.type(torch.LongTensor).to(device)

        logits = model(images)
        batch_loss = loss_fn(logits, targets)

        # Standard backward / update sequence.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

        predicted = torch.softmax(logits, dim=1).argmax(dim=1)
        confusion = fast_hist(targets.cpu().numpy(), predicted.cpu().numpy(), n_classes)
        class_iou = np.array(per_class_iou(confusion)).flatten()
        miou_sum += class_iou.sum()
        iou_sum += class_iou

    n_batches = len(dataloader)
    return running_loss / n_batches, miou_sum / (n_batches * n_classes), iou_sum / n_batches

def val_step(model: torch.nn.Module,  
             loss_fn: torch.nn.Module, 
             dataloader: DataLoader, 
             device: str, 
             n_classes: int = 19) -> Tuple[float, float, float]:
    """
    Evaluate the model for one pass over `dataloader` without updating weights.

    Args:
    - model (torch.nn.Module): Segmentation model, put in eval mode here.
    - loss_fn (torch.nn.Module): Segmentation loss applied to (logits, integer labels).
    - dataloader (DataLoader): Yields (image, label) validation batches.
    - device (str): Device on which to run the model ('cuda' or 'cpu').
    - n_classes (int, optional): Number of segmentation classes. Default is 19.

    Returns:
    - Tuple containing:
        - epoch_loss (float): Loss averaged over batches.
        - epoch_miou (float): Per-class IoU summed over batches, divided by (batches * n_classes).
        - epoch_iou (np.ndarray): Per-class IoU averaged over batches.
    """
    running_loss = 0
    miou_sum = 0
    iou_sum = np.zeros(n_classes)

    model.eval()

    # inference_mode behaves like no_grad and additionally skips autograd bookkeeping.
    with torch.inference_mode():
        for images, targets in dataloader:
            images = images.to(device)
            targets = targets.type(torch.LongTensor).to(device)

            logits = model(images)
            running_loss += loss_fn(logits, targets).item()

            predicted = torch.softmax(logits, dim=1).argmax(dim=1)
            confusion = fast_hist(targets.cpu().numpy(), predicted.cpu().numpy(), n_classes)
            class_iou = np.array(per_class_iou(confusion)).flatten()
            miou_sum += class_iou.sum()
            iou_sum += class_iou

    n_batches = len(dataloader)
    return running_loss / n_batches, miou_sum / (n_batches * n_classes), iou_sum / n_batches
    
def _initial_lr(optimizer: torch.optim.Optimizer) -> float:
    """Return the base LR of the optimizer's first param group, recording it on first use."""
    group = optimizer.param_groups[0]
    # 'initial_lr' is the same key torch's built-in LR schedulers use. Because it lives in
    # param_groups, it is part of optimizer.state_dict().
    return group.setdefault('initial_lr', group['lr'])


def train(model: torch.nn.Module, 
          model_D: torch.nn.Module, 
          optimizer: torch.optim.Optimizer, 
          optimizer_D: torch.optim.Optimizer, 
          loss_fn: torch.nn.Module, 
          loss_D: torch.nn.Module, 
          train_loader: "Union[DataLoader, Tuple[DataLoader, DataLoader]]",
          val_loader: DataLoader, 
          epochs: int, 
          device: str, 
          checkpoint_root: str,
          project_step: str,
          verbose: bool,
          n_classes: int = 19,
          power: float = 0.9,
          adversarial: bool = False) -> Tuple[List[float], List[float], List[float], List[float], List[float], List[float]]:
    """
    Train a semantic segmentation model with optional adversarial training.

    Resumes from a checkpoint when load_checkpoint finds one, validates after every
    epoch, decays the learning rate with poly_lr_scheduler and saves a checkpoint
    after every epoch.

    Args:
        model (torch.nn.Module): Semantic segmentation model.
        model_D (torch.nn.Module): Discriminator model for adversarial training.
        optimizer (torch.optim.Optimizer): Optimizer for the segmentation model.
        optimizer_D (torch.optim.Optimizer): Optimizer for the discriminator.
        loss_fn (torch.nn.Module): Loss function for segmentation.
        loss_D (torch.nn.Module): Loss function for adversarial training.
        train_loader (Union[DataLoader, Tuple[DataLoader,DataLoader]]): A single DataLoader,
            or a (source, target) pair when adversarial=True.
        val_loader (DataLoader): DataLoader for validation data.
        epochs (int): Total number of epochs (including any already completed in a checkpoint).
        device (str): Device on which to run computations ('cuda' or 'cpu').
        checkpoint_root (str): Root directory to save checkpoints.
        project_step (str): Name/id of the project or step.
        verbose (bool): Whether to print verbose training statistics.
        n_classes (int, optional): Number of classes for segmentation. Defaults to 19.
        power (float, optional): Power parameter for learning rate scheduler. Defaults to 0.9.
        adversarial (bool, optional): Whether to use adversarial training. Defaults to False.

    Returns:
        Tuple containing:
        - train_loss_list (List[float]): List of training losses per epoch.
        - val_loss_list (List[float]): List of validation losses per epoch.
        - train_miou_list (List[float]): List of training mIoU per epoch.
        - val_miou_list (List[float]): List of validation mIoU per epoch.
        - train_iou: Per-class training IoU from the last epoch run.
        - val_iou: Per-class validation IoU from the last epoch run.
    """
    # (The train_loader annotation is quoted so this def does not fail if `Union`
    #  is not imported in this notebook.)

    # Load or initialize checkpoint
    (no_checkpoint, start_epoch,
     train_loss_list, train_miou_list, train_iou,
     val_loss_list, val_miou_list, val_iou) = load_checkpoint(checkpoint_root=checkpoint_root,
                                                              project_step=project_step,
                                                              adversarial=adversarial,
                                                              model=model,
                                                              model_D=model_D,
                                                              optimizer=optimizer,
                                                              optimizer_D=optimizer_D)

    if no_checkpoint:
        train_loss_list, train_miou_list = [], []
        val_loss_list, val_miou_list = [], []
        start_epoch = 0

    # The poly schedule must always be computed from the *initial* LR. Passing the
    # current (already decayed) LR every epoch would compound the decay, assuming
    # poly_lr_scheduler computes init_lr * (1 - iter/max_iter) ** power (its
    # definition is not shown here). If save/load_checkpoint round-trip
    # optimizer.state_dict(), the recorded value also survives resumes. Older
    # checkpoints without it fall back to their stored LR.
    init_lr = _initial_lr(optimizer)
    init_lr_D = _initial_lr(optimizer_D) if adversarial else None

    for epoch in tqdm(range(start_epoch, epochs)):

        # Perform training step
        if adversarial:
            train_loss, train_miou, train_iou = adversarial_train_step(model=model,
                                                                       model_D=model_D,
                                                                       loss_fn=loss_fn,
                                                                       loss_D=loss_D,
                                                                       optimizer=optimizer,
                                                                       optimizer_D=optimizer_D,
                                                                       loaders=train_loader,
                                                                       device=device,
                                                                       n_classes=n_classes)
        else:
            # train_step's loader parameter is named `dataloader`.
            train_loss, train_miou, train_iou = train_step(model=model,
                                                           loss_fn=loss_fn,
                                                           optimizer=optimizer,
                                                           dataloader=train_loader,
                                                           device=device,
                                                           n_classes=n_classes)

        # Perform validation step (val_step's loader parameter is also `dataloader`).
        val_loss, val_miou, val_iou = val_step(model=model,
                                               loss_fn=loss_fn,
                                               dataloader=val_loader,
                                               device=device,
                                               n_classes=n_classes)

        # Append metrics to lists
        train_loss_list.append(train_loss)
        train_miou_list.append(train_miou)
        val_loss_list.append(val_loss)
        val_miou_list.append(val_miou)

        # Print statistics if verbose
        print_stats(epoch=epoch,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_miou=train_miou,
                    val_miou=val_miou,
                    verbose=verbose)

        # Adjust learning rate
        poly_lr_scheduler(optimizer=optimizer,
                          init_lr=init_lr,
                          iter=epoch,
                          max_iter=epochs,
                          power=power)
        if adversarial:
            poly_lr_scheduler(optimizer=optimizer_D,
                              init_lr=init_lr_D,
                              iter=epoch,
                              max_iter=epochs,
                              power=power)

        # Save checkpoint after each epoch
        save_checkpoint(checkpoint_root=checkpoint_root,
                        project_step=project_step,
                        adversarial=adversarial,
                        model=model,
                        model_D=model_D,
                        optimizer=optimizer,
                        optimizer_D=optimizer_D,
                        epoch=epoch,
                        train_loss_list=train_loss_list,
                        train_miou_list=train_miou_list,
                        train_iou=train_iou,
                        val_loss_list=val_loss_list,
                        val_miou_list=val_miou_list,
                        val_iou=val_iou,
                        verbose=verbose)

    return train_loss_list, val_loss_list, train_miou_list, val_miou_list, train_iou, val_iou
    
def pipeline(model_name: str, 
             train_dataset_name: str, 
             val_dataset_name: str,
             n_classes:int,
             epochs: int,
             augmented: bool,
             augmentedType:str,
             optimizer_name: str,
             lr:float,
             momentum:float,
             weight_decay:float,
             loss_fn_name: str,
             ignore_index:int,
             batch_size: int,
             n_workers: int,
             device:str,
             parallelize:bool,
             project_step:str,
             verbose: bool,
             checkpoint_root:str,
             power:float,
             evalIterations:int,
             adversarial:bool
             )->None:
    """
    Run one full experiment: build, train, benchmark, plot and save results.

    Args:
        model_name (str): Model architecture name, forwarded to get_core and used in output names.
        train_dataset_name (str): Training dataset name, forwarded to get_loaders and the plots.
        val_dataset_name (str): Validation dataset name, forwarded to get_loaders and the plots.
        n_classes (int): Number of segmentation classes.
        epochs (int): Number of training epochs.
        augmented (bool): Whether get_loaders should apply the augmentation pipeline.
        augmentedType (str): Which augmentation pipeline get_loaders should use.
        optimizer_name (str): Optimizer name passed to get_core.
        lr (float): Learning rate.
        momentum (float): Momentum (used by optimizers that support it).
        weight_decay (float): Weight decay for the optimizer.
        loss_fn_name (str): Loss function name passed to get_core.
        ignore_index (int): Label value ignored by the loss.
        batch_size (int): Batch size for training and validation loaders.
        n_workers (int): DataLoader worker count.
        device (str): Device to run on ('cuda' or 'cpu').
        parallelize (bool): Whether get_core should parallelize the model across GPUs.
        project_step (str): Experiment identifier used for checkpoints, plots and result files.
        verbose (bool): Whether to print per-epoch statistics.
        checkpoint_root (str): Root directory for checkpoints.
        power (float): Exponent of the polynomial LR schedule.
        evalIterations (int): Iterations used when measuring latency and FPS.
        adversarial (bool): Whether to run adversarial (source/target) training.

    Returns:
        None
    """
    # 1) Model, optimizer and loss, plus discriminator counterparts when adversarial.
    core = get_core(model_name, n_classes, device, parallelize, optimizer_name,
                    lr, momentum, weight_decay, loss_fn_name, ignore_index,
                    adversarial)
    model, optimizer, loss_fn, model_D, optimizer_D, loss_D = core

    # 2) Data loaders, plus the validation resolution used for benchmarking below.
    loaders = get_loaders(train_dataset_name, val_dataset_name, augmented,
                          augmentedType, batch_size, n_workers, adversarial)
    train_loader, val_loader, data_height, data_width = loaders

    # 3) Training with per-epoch validation and checkpointing.
    model_results = train(model=model,
                          model_D=model_D,
                          optimizer=optimizer,
                          optimizer_D=optimizer_D,
                          loss_fn=loss_fn,
                          loss_D=loss_D,
                          train_loader=train_loader,
                          val_loader=val_loader,
                          epochs=epochs,
                          device=device,
                          checkpoint_root=checkpoint_root,
                          project_step=project_step,
                          verbose=verbose,
                          n_classes=n_classes,
                          power=power,
                          adversarial=adversarial)

    # 4) Model cost: parameters/FLOPs, then latency/FPS.
    model_params_flops = compute_flops(model=model, height=data_height, width=data_width)
    model_latency_fps = compute_latency_and_fps(model=model,
                                                height=data_height,
                                                width=data_width,
                                                iterations=evalIterations,
                                                device=device)

    # 5) Plots: every plot function takes the same positional arguments.
    for plot_fn in (plot_loss, plot_miou, plot_iou):
        plot_fn(model_results, model_name, project_step, train_dataset_name, val_dataset_name)

    # 6) Persist the collected metrics.
    save_results(model_results,
                 filename=f"{model_name}_metrics_{project_step}",
                 project_step=project_step,
                 model_params_flops=model_params_flops,
                 model_latency_fps=model_latency_fps)

## 5 - Steps

In [None]:
# Kaggle input locations for the datasets and pretrained weights, plus the output directory.
CITYSCAPES_PATH = '/kaggle/input/cityscapes/Cityscapes/Cityspaces'
GTA5_PATH = '/kaggle/input/gta5-with-mask/GTA5_with_mask'  # GTA5 with mask
# Local alternative: 'models/deeplab_resnet_pretrained_imagenet.pth'
DEEPLABV2_PATH = '/kaggle/input/deeplab_v2_model/pytorch/model_weight/1/deeplab_resnet_pretrained_imagenet.pth'
CHECKPOINT_ROOT = '/kaggle/working'

# Resize targets in pixels, read by get_loaders and adversarial_train_step.
CITYSCAPES = dict(width=1024, height=512)
GTA = dict(width=1280, height=720)

### Step 2.1

#### SGD

In [None]:
# for kaggle: Step 2.1, DeepLabV2 on CityScapes with SGD (lr 1e-3)
config = dict(
    model_name='DeepLabV2',           # [DeepLabV2, BiSeNet]
    train_dataset_name='CityScapes',  # [CityScapes, GTA5]
    val_dataset_name='CityScapes',    # [CityScapes]
    n_classes=19,
    epochs=50,
    augmented=False,
    augmentedType='transform1',       # [transform1, transform2, transform3, transform4]
    optimizer_name='SGD',             # [SGD, Adam]; try both
    lr=1e-3,
    momentum=0.9,
    weight_decay=5e-4,
    loss_fn_name='CrossEntropyLoss',  # [CrossEntropyLoss]
    ignore_index=255,
    batch_size=4,                     # [2, 4, 8]
    n_workers=4,                      # [0, 2, 4]
    device='cuda',
    parallelize=True,
    project_step='Step2_1_SGD_1e_3',  # [Step2_1, Step2_2, Step3_1, Step3_2, Step4]
    verbose=False,
    checkpoint_root=CHECKPOINT_ROOT,
    power=0.9,                        # exponent of the poly LR schedule
    evalIterations=100,
    adversarial=False,
)

#### Adam

In [None]:
# for kaggle: Step 2.1, DeepLabV2 on CityScapes with Adam
config = dict(
    model_name='DeepLabV2',           # [DeepLabV2, BiSeNet]
    train_dataset_name='CityScapes',  # [CityScapes, GTA5]
    val_dataset_name='CityScapes',    # [CityScapes]
    n_classes=19,
    epochs=50,
    augmented=False,
    augmentedType='transform1',       # [transform1, transform2, transform3, transform4]
    optimizer_name='Adam',            # [SGD, Adam]; try both
    lr=2.5e-4,
    momentum=0.9,
    weight_decay=5e-4,
    loss_fn_name='CrossEntropyLoss',  # [CrossEntropyLoss]
    ignore_index=255,
    batch_size=4,                     # [2, 4, 8]
    n_workers=4,                      # [0, 2, 4]
    device='cuda',
    parallelize=True,
    project_step='Step2_1_Adam',      # [Step2_1, Step2_2, Step3_1, Step3_2, Step4]
    verbose=False,
    checkpoint_root=CHECKPOINT_ROOT,
    power=0.9,                        # exponent of the poly LR schedule
    evalIterations=100,
    adversarial=False,
)

### Step 2.2

#### SGD

In [None]:
# for kaggle: Step 2.2, BiSeNet on CityScapes with SGD
config = dict(
    model_name='BiSeNet',             # [DeepLabV2, BiSeNet]
    train_dataset_name='CityScapes',  # [CityScapes, GTA5]
    val_dataset_name='CityScapes',    # [CityScapes]
    n_classes=19,
    epochs=50,
    augmented=False,
    augmentedType='transform1',       # [transform1, transform2, transform3, transform4]
    optimizer_name='SGD',             # [SGD, Adam]
    lr=2.5e-4,
    momentum=0.9,
    weight_decay=5e-4,
    loss_fn_name='CrossEntropyLoss',  # [CrossEntropyLoss]
    ignore_index=255,
    batch_size=8,                     # [2, 4, 8]
    n_workers=4,                      # [0, 2, 4]
    device='cuda',
    parallelize=True,
    project_step='Step2_2_SGD',       # [Step2_1, Step2_2, Step3_1, Step3_2, Step4]
    verbose=False,
    checkpoint_root=CHECKPOINT_ROOT,
    power=0.9,                        # exponent of the poly LR schedule
    evalIterations=100,
    adversarial=False,
)

#### Adam

In [None]:
# for kaggle: Step 2.2, BiSeNet on CityScapes with Adam
config = dict(
    model_name='BiSeNet',             # [DeepLabV2, BiSeNet]
    train_dataset_name='CityScapes',  # [CityScapes, GTA5]
    val_dataset_name='CityScapes',    # [CityScapes]
    n_classes=19,
    epochs=50,
    augmented=False,
    augmentedType='transform1',       # [transform1, transform2, transform3, transform4]
    optimizer_name='Adam',            # [SGD, Adam]
    lr=2.5e-4,
    momentum=0.9,
    weight_decay=5e-4,
    loss_fn_name='CrossEntropyLoss',  # [CrossEntropyLoss]
    ignore_index=255,
    batch_size=8,                     # [2, 4, 8]
    n_workers=4,                      # [0, 2, 4]
    device='cuda',
    parallelize=True,
    project_step='Step2_2_Adam',      # [Step2_1, Step2_2, Step3_1, Step3_2, Step4]
    verbose=False,
    checkpoint_root=CHECKPOINT_ROOT,
    power=0.9,                        # exponent of the poly LR schedule
    evalIterations=100,
    adversarial=False,
)

### Step 3.1

#### SGD

In [None]:
# for kaggle: Step 3.1, BiSeNet trained on GTA5, validated on CityScapes, SGD
config = dict(
    model_name='BiSeNet',             # [DeepLabV2, BiSeNet]
    train_dataset_name='GTA5',        # [CityScapes, GTA5]
    val_dataset_name='CityScapes',    # [CityScapes]
    n_classes=19,
    epochs=50,
    augmented=False,
    augmentedType='transform1',       # [transform1, transform2, transform3, transform4]
    optimizer_name='SGD',             # [SGD, Adam]
    lr=2.5e-4,
    momentum=0.9,
    weight_decay=5e-4,
    loss_fn_name='CrossEntropyLoss',  # [CrossEntropyLoss]
    ignore_index=255,
    batch_size=8,                     # [2, 4, 8]
    n_workers=4,                      # [0, 2, 4]
    device='cuda',
    parallelize=True,
    project_step='Step3_1_SGD',       # [Step2_1, Step2_2, Step3_1, Step3_2, Step4]
    verbose=True,
    checkpoint_root=CHECKPOINT_ROOT,
    power=0.9,                        # exponent of the poly LR schedule
    evalIterations=100,
    adversarial=False,
)

#### Adam

In [None]:
# for kaggle: Step 3.1, BiSeNet trained on GTA5, validated on CityScapes, Adam
config = dict(
    model_name='BiSeNet',             # [DeepLabV2, BiSeNet]
    train_dataset_name='GTA5',        # [CityScapes, GTA5]
    val_dataset_name='CityScapes',    # [CityScapes]
    n_classes=19,
    epochs=50,
    augmented=False,
    augmentedType='transform1',       # [transform1, transform2, transform3, transform4]
    optimizer_name='Adam',            # [SGD, Adam]
    lr=2.5e-4,
    momentum=0.9,
    weight_decay=5e-4,
    loss_fn_name='CrossEntropyLoss',  # [CrossEntropyLoss]
    ignore_index=255,
    batch_size=8,                     # [2, 4, 8]
    n_workers=4,                      # [0, 2, 4]
    device='cuda',
    parallelize=True,
    project_step='Step3_1_Adam',      # [Step2_1, Step2_2, Step3_1, Step3_2, Step4]
    verbose=True,
    checkpoint_root=CHECKPOINT_ROOT,
    power=0.9,                        # exponent of the poly LR schedule
    evalIterations=100,
    adversarial=False,
)

### Step 3.2

#### The optimizer should be chosen based on the results obtained in Step 3.1

In [None]:
# for kaggle: Step 3.2, BiSeNet on GTA5 with augmentation 'transform1', SGD
config = dict(
    model_name='BiSeNet',             # [DeepLabV2, BiSeNet]
    train_dataset_name='GTA5',        # [CityScapes, GTA5]
    val_dataset_name='CityScapes',    # [CityScapes]
    n_classes=19,
    epochs=50,
    augmented=True,
    augmentedType='transform1',       # [transform1, transform2, transform3, transform4]
    optimizer_name='SGD',             # [SGD, Adam]
    lr=2.5e-4,
    momentum=0.9,
    weight_decay=5e-4,
    loss_fn_name='CrossEntropyLoss',  # [CrossEntropyLoss]
    ignore_index=255,
    batch_size=8,                     # [2, 4, 8]
    n_workers=4,                      # [0, 2, 4]
    device='cuda',
    parallelize=True,
    project_step='Step3_2_SGD_transform1',  # [Step2_1, Step2_2, Step3_1, Step3_2, Step4]
    verbose=False,
    checkpoint_root=CHECKPOINT_ROOT,
    power=0.9,                        # exponent of the poly LR schedule
    evalIterations=100,
    adversarial=False,
)

### Step 4

#### TODO

In [None]:
# for kaggle: Step 4, adversarial domain adaptation with BiSeNet.
# With adversarial=True, get_loaders always uses GTA5 as the labelled source and
# CityScapes (train split) as the unlabelled target. train_dataset_name is then only
# used to label plots, so it names the supervised source dataset: GTA5.
config = {
    'model_name': 'BiSeNet', # [DeepLabV2, BiSeNet]
    'train_dataset_name': 'GTA5', # [CityScapes, GTA5]
    'val_dataset_name': 'CityScapes', # [CityScapes]
    'n_classes': 19,
    'epochs': 5,
    'augmented': True,
    'augmentedType': 'transform3', # [transform1,transform2,transform3,transform4]
    'optimizer_name': 'Adam', # [SGD, Adam]
    'lr': 2.5e-4,
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'loss_fn_name': 'CrossEntropyLoss', # [CrossEntropyLoss]
    'ignore_index': 255,
    'batch_size': 8, # [2,4,8]
    'n_workers': 0, # [0,2,4]
    'device': 'cuda',
    'parallelize': True,
    'project_step': 'Step4_Adam', # [Step2_1,Step2_2,Step3_1,Step3_2,Step4]
    'verbose': False,
    'checkpoint_root': CHECKPOINT_ROOT,
    'power': 0.9,
    'evalIterations': 100,
    'adversarial': True
}

## Main

In [None]:
# Free GPU memory between runs. First let Python's garbage collector reclaim
# unreachable objects (including tensors held only by reference cycles). Then ask
# PyTorch's CUDA caching allocator to release the blocks that just became unused.
# The reverse order leaves the memory freed by gc.collect() cached.
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
if __name__ == '__main__':

    # for kaggle: forward exactly these keys of the active `config` cell to pipeline.
    pipeline(**{
        arg_name: config[arg_name]
        for arg_name in (
            'model_name', 'train_dataset_name', 'val_dataset_name', 'n_classes',
            'epochs', 'augmented', 'augmentedType', 'optimizer_name', 'lr',
            'momentum', 'weight_decay', 'loss_fn_name', 'ignore_index',
            'batch_size', 'n_workers', 'device', 'parallelize', 'project_step',
            'verbose', 'checkpoint_root', 'power', 'evalIterations', 'adversarial',
        )
    })

# Kaggle directory setup

In [None]:
# Only for reset
#! rm /kaggle/working/dir_name/checkpoint.pth
#! rm -r /kaggle/working/dir_name/images
#! rm -r /kaggle/working/dir_name/logs

In [None]:
! zip -r dir_name.zip /kaggle/working/dir_name

In [None]:
# Display a clickable download link for the zipped working directory.
from IPython.display import FileLink

FileLink(path='dir_name.zip')