# Preparations

## Variables

### Logging

In [None]:
# Switches read by train()/validate(): console printing and Weights & Biases logging.
ENABLE_PRINT = False
ENABLE_WANDB_LOG = True
# train()/validate() derive their logging interval as num_batches // log_per_epoch + 1.
log_per_epoch = 20
# Side length of the confusion matrices accumulated by train()/validate().
n_classes = 19

# Counters logged as "train/step" / "validate/step"; incremented after each wandb.log call.
train_step = 0
val_step = 0

### Device

In [None]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


## Utils

### Label-color correlator

In [None]:
from abc import ABCMeta
from dataclasses import dataclass
from typing import Tuple

@dataclass
class GTA5Label:
    """One segmentation class: display name, numeric label id and mask colour."""
    # Human-readable class name.
    name: str
    # Integer id stored in the label-id masks.
    ID: int
    # (R, G, B) colour of the class in the colour-coded masks.
    color: Tuple[int, int, int]

class GTA5Labels_TaskCV2017():
    """Label table: 19 classes with ids 0-18 plus a ``void`` entry with id 255.

    Each class attribute is a :class:`GTA5Label`; ``list_`` collects all of
    them in id order (``void`` last) for code that iterates over the table.
    """
    road = GTA5Label(name = "road", ID=0, color=(128, 64, 128))
    sidewalk = GTA5Label(name = "sidewalk", ID=1, color=(244, 35, 232))
    building = GTA5Label(name = "building", ID=2, color=(70, 70, 70))
    wall = GTA5Label(name = "wall", ID=3, color=(102, 102, 156))
    fence = GTA5Label(name = "fence", ID=4, color=(190, 153, 153))
    pole = GTA5Label(name = "pole", ID=5, color=(153, 153, 153))
    light = GTA5Label(name = "light", ID=6, color=(250, 170, 30))
    sign = GTA5Label(name = "sign", ID=7, color=(220, 220, 0))
    vegetation = GTA5Label(name = "vegetation", ID=8, color=(107, 142, 35))
    terrain = GTA5Label(name = "terrain", ID=9, color=(152, 251, 152))
    sky = GTA5Label(name = "sky", ID=10, color=(70, 130, 180))
    person = GTA5Label(name = "person", ID=11, color=(220, 20, 60))
    rider = GTA5Label(name = "rider", ID=12, color=(255, 0, 0))
    car = GTA5Label(name = "car", ID=13, color=(0, 0, 142))
    truck = GTA5Label(name = "truck", ID=14, color=(0, 0, 70))
    bus = GTA5Label(name = "bus", ID=15, color=(0, 60, 100))
    train = GTA5Label(name = "train", ID=16, color=(0, 80, 100))
    motocycle = GTA5Label(name = "motocycle", ID=17, color=(0, 0, 230))
    bicycle = GTA5Label(name = "bicycle", ID=18, color=(119, 11, 32))
    # Catch-all entry; its id (255) lies outside the 0-18 class range.
    void = GTA5Label(name = "void", ID=255, color=(0,0,0))

    # Every label above, in declaration order.
    list_ = [
        road,
        sidewalk,
        building,
        wall,
        fence,
        pole,
        light,
        sign,
        vegetation,
        terrain,
        sky,
        person,
        rider,
        car,
        truck,
        bus,
        train,
        motocycle,
        bicycle,
        void
    ]

### Functions

In [None]:
import numpy as np
import torch
import wandb

import seaborn as sns
import matplotlib.pyplot as plt

import zipfile
from tqdm.notebook import tqdm

import time

!pip -q install -U fvcore

from fvcore.nn import FlopCountAnalysis, flop_count_table

def pretty_extract(zip_path:str, extract_to:str) -> None:
    """Extract the ``.png`` members of a zip archive, showing a progress bar.

    :param zip_path: path of the archive to read
    :param extract_to: directory the members are extracted into
    """
    with zipfile.ZipFile(zip_path, 'r') as archive:
        # Collect the PNG members first so the progress bar knows its total.
        png_members = []
        for member in archive.namelist():
            if member.endswith('.png'):
                png_members.append(member)

        for member in tqdm(png_members, desc="Extracting PNGs"):
            archive.extract(member, path=extract_to)

def poly_lr_scheduler(optimizer, init_lr:float, iter:int=0, lr_decay_iter:int=1, max_iter:int=50, power:float=0.9) -> float:
    """Apply polynomial learning-rate decay to the optimizer's first param group.

    The rate is set to ``init_lr * (1 - iter / max_iter) ** power`` whenever
    ``iter`` is a multiple of ``lr_decay_iter`` and does not exceed
    ``max_iter``; otherwise the current rate is left untouched.

    :param optimizer: object exposing ``param_groups`` (only group 0 is updated)
    :param init_lr: base learning rate
    :param iter: current iteration
    :param lr_decay_iter: how frequently decay occurs, default is 1
    :param max_iter: maximum number of iterations
    :param power: polynomial power
    :return: the learning rate of the first param group after the call
    """
    decay_due = (iter % lr_decay_iter) == 0
    if decay_due and iter <= max_iter:
        new_lr = init_lr*(1 - iter/max_iter)**power
        optimizer.param_groups[0]['lr'] = new_lr
        return new_lr

    # Off-schedule or past the last iteration: report the rate unchanged.
    return optimizer.param_groups[0]['lr']

def fast_hist(a:np.ndarray, b:np.ndarray, n:int) -> np.ndarray:
    '''
    Build an n x n confusion matrix (rows: label, columns: prediction).

    a and b are label and prediction respectively
    n is the number of classes
    Positions whose label lies outside [0, n) are ignored.
    '''
    valid = (a >= 0) & (a < n)
    labels = a[valid].astype(int)
    predictions = b[valid]
    # Encode each (label, prediction) pair as one flat index and count them.
    flat_counts = np.bincount(n * labels + predictions, minlength=n ** 2)
    return flat_counts.reshape(n, n)

def fast_hist_torch_cuda(a: torch.Tensor, b: torch.Tensor, n: int) -> torch.Tensor:
    """
    Torch counterpart of ``fast_hist``: an n x n confusion matrix.

    a and b are label and prediction respectively.
    n is the number of classes.
    Positions whose label lies outside [0, n) are ignored. The result is an
    int64 tensor on the same device as the inputs.
    """
    valid = (a >= 0) & (a < n)
    labels = a[valid].to(torch.int64)
    predictions = b[valid].to(torch.int64)

    # One flat index per (label, prediction) pair.
    flat_index = n * labels + predictions

    # Accumulate a count of 1 per pair (the equivalent of a bincount).
    counts = torch.zeros(n * n, dtype=torch.int64, device=labels.device)
    ones = torch.ones_like(flat_index, dtype=torch.int64)
    return counts.scatter_add_(0, flat_index, ones).view(n, n)

def per_class_iou(hist:np.ndarray) -> np.ndarray:
    """Return the per-class IoU computed from a confusion matrix.

    For each class: diagonal / (row sum + column sum - diagonal + 1e-5); the
    small constant keeps classes with an empty union from dividing by zero.
    """
    epsilon = 1e-5
    intersection = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - intersection
    return intersection / (union + epsilon)

def per_class_iou_cuda(hist:torch.Tensor) -> torch.Tensor:
    """Torch counterpart of ``per_class_iou``.

    For each class: diagonal / (row sum + column sum - diagonal + 1e-5).
    """
    epsilon = 1e-5
    intersection = torch.diag(hist)
    union = hist.sum(dim=1) + hist.sum(dim=0) - intersection
    return intersection / (union + epsilon)

# Mapping labelId image to RGB image
def decode_segmap(mask:np.ndarray) -> np.ndarray:
    """Turn a 2-D label-id mask into an (H, W, 3) uint8 colour image.

    Every pixel whose id matches an entry of ``GTA5Labels_TaskCV2017.list_``
    takes that entry's colour; all other pixels stay black.
    """
    height, width = mask.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for entry in GTA5Labels_TaskCV2017().list_:
        selected = mask == entry.ID
        rgb[selected, :] = entry.color

    return rgb

def tensorToImageCompatible(t:torch.Tensor) -> np.ndarray:
    """
    Convert a normalized [C, H, W] tensor into an [H, W, C] uint8 image array,
    so *plt.imshow(tensorToImageCompatible(tensor))* works as expected.

    Intended to undo this transformation:
    - transform = TF.Compose([
        TF.ToTensor(),
        TF.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])])

    The tensor may live on any device and may require grad: it is detached and
    copied to the CPU before the NumPy conversion, which otherwise raises.

    :param t: normalized image tensor of shape [3, H, W]
    :return: uint8 array of shape [H, W, 3] with values in [0, 255]
    """
    # .numpy() only works on CPU tensors that are not part of the autograd graph.
    t = t.detach().cpu()

    mean = torch.tensor([0.485, 0.456, 0.406]).view([-1, 1, 1])
    std = torch.tensor([0.229, 0.224, 0.225]).view([-1, 1, 1])

    unnormalized = t * std + mean

    return (unnormalized.permute(1,2,0).clamp(0,1).numpy()*255).astype(np.uint8)


def log_confusion_matrix(title:str, hist:np.ndarray, tag:str, step_name:str, step_value:int):
    """Log a row-normalized confusion matrix to Weights & Biases as a heatmap.

    Each row of ``hist`` is divided by its sum and shown as a percentage; rows
    that sum to zero are displayed as all zeros.

    :param title: figure title
    :param hist: square confusion matrix (rows: true class, columns: predicted)
    :param tag: wandb key under which the image is logged
    :param step_name: wandb key of the accompanying step value
    :param step_value: value logged under ``step_name``
    """
    row_sums = hist.sum(axis=1, keepdims=True)
    # Divide only where the row is non-empty: avoids the 0/0 NaNs (and the
    # RuntimeWarning) that a plain ``hist / row_sums`` yields for empty rows.
    safe_hist = np.divide(
        hist,
        row_sums,
        out=np.zeros(hist.shape, dtype=float),
        where=row_sums != 0,
    )

    plt.figure(figsize=(10, 8))
    sns.heatmap(100.*safe_hist, fmt=".2f", annot=True, cmap="Blues", annot_kws={'size': 7})
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)

    wandb.log({tag: wandb.Image(plt), step_name:step_value})
    plt.close()

def log_confusion_matrix_cuda(title:str, hist:torch.Tensor, tag:str, step_name:str, step_value:int):
    """Log a confusion matrix held in a torch tensor (any device) to wandb.

    The tensor is detached, copied to the CPU and converted to NumPy, then
    handed to ``log_confusion_matrix`` so both variants render identically.
    Previously the converted array was never used and NumPy routines were
    applied to the tensor itself, which fails for CUDA tensors.

    :param title: figure title
    :param hist: square confusion matrix (rows: true class, columns: predicted)
    :param tag: wandb key under which the image is logged
    :param step_name: wandb key of the accompanying step value
    :param step_value: value logged under ``step_name``
    """
    hist_np = hist.detach().cpu().numpy()  # NumPy/seaborn cannot consume device tensors
    log_confusion_matrix(title, hist_np, tag, step_name, step_value)

def log_bar_chart_ioU(title:str, class_names:list, mIou:float, iou_class:list, tag:str, step_name:str, epoch:int):
    """Log a bar chart of the mIoU followed by the per-class IoUs to wandb.

    Values are shown as percentages rounded to two decimals, with the value
    printed above each bar.

    :param title: figure title
    :param class_names: one label per entry of ``iou_class``
    :param mIou: mean IoU as a fraction
    :param iou_class: per-class IoUs as fractions
    :param tag: wandb key under which the image is logged
    :param step_name: wandb key of the accompanying step value
    :param epoch: value logged under ``step_name``
    """
    class_percents = []
    for iou in iou_class:
        class_percents.append(round(iou*100., 2))
    mean_percent = round(mIou*100., 2)

    # The mIoU bar comes first, then one bar per class.
    labels = ["mIoU"] + class_names
    values = [mean_percent] + class_percents

    plt.figure(figsize=(14, 5))
    bars = plt.bar(range(len(values)), values, color='skyblue')
    plt.xticks(range(len(labels)), labels, rotation=45, ha="right")

    for bar in bars:
        height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2.0
        plt.text(x_center, height + 1, f'{height:.2f}',
                ha='center', va='bottom', fontsize=9)

    plt.ylabel("IoU (%)")
    plt.ylim(0, 105)
    plt.title(title)
    plt.tight_layout()

    wandb.log({tag: wandb.Image(plt), step_name:epoch})

    plt.close()

def num_flops(device, model:torch.nn.Module, H:int, W:int):
    """Return the GFLOPs of the first operator fvcore reports for ``model``.

    The model is put in eval mode and analysed on one all-zero image of shape
    (1, 3, H, W) created on ``device``. The previous version analysed the model
    twice, used a hard-coded 512x1024 input for the second pass (ignoring
    ``H``/``W``) and never returned the value it computed.

    NOTE(review): the original only ever computed the *first* entry of
    ``by_operator()`` (its comment suggests 'conv2d'); confirm whether
    ``flops.total()`` is what callers actually want.

    :param device: device on which the dummy input is allocated
    :param model: network to analyse
    :param H: input height
    :param W: input width
    :return: GFLOPs of the first reported operator, or 0.0 if none is reported
    """
    model.eval()
    img = (torch.zeros((1,3,H,W), device=device),)

    flops = FlopCountAnalysis(model, img)
    flops_by_op = flops.by_operator()
    if not flops_by_op:
        return 0.0

    # Take the first operation (e.g. 'conv2d') and its flops.
    first_op_name = next(iter(flops_by_op))
    return flops_by_op[first_op_name] / 1e9  # in GFLOPs

def latency(device, model:torch.nn.Module, H:int, W:int):
    """Measure the forward-pass latency of ``model`` on a (1, 3, H, W) input.

    Runs 100 timed forward passes under ``torch.no_grad()`` in eval mode. When
    ``device`` is a CUDA device the stream is synchronized before and after
    each pass: CUDA kernels are launched asynchronously, so without it the
    timer would mostly measure launch overhead instead of execution time.

    :param device: device (``torch.device`` or string) on which to run
    :param model: network to benchmark
    :param H: input height
    :param W: input width
    :return: (mean latency in ms, latency standard deviation in ms, mean FPS)
    """
    model.eval()

    img = torch.zeros((1,3,H,W)).to(device)
    iterations = 100
    latency_list = []
    FPS_list  = []

    needs_sync = torch.device(device).type == 'cuda'

    with torch.no_grad():
        for _ in tqdm(range(iterations)):
            if needs_sync:
                torch.cuda.synchronize(device)
            # perf_counter is monotonic and has a finer resolution than time.time().
            start_time = time.perf_counter()
            _ = model(img)
            if needs_sync:
                torch.cuda.synchronize(device)
            elapsed = time.perf_counter() - start_time

            latency_list.append(elapsed)
            FPS_list.append(1.0/elapsed)

    mean_latency = np.mean(latency_list)*1000
    std_latency = np.std(latency_list)*1000
    mean_FPS = np.mean(FPS_list)

    return mean_latency, std_latency, mean_FPS


## CityScapes download

In [None]:
# gdown is the command-line tool used below to fetch the archive from Google Drive.
!pip install -q gdown

# Google Drive file id of the Cityscapes archive.
file_id = "1MI8QsvjW0R6WDJiL49L7sDGpPWYAQB6O"
!gdown https://drive.google.com/uc?id={file_id}

# Extract only the PNG members of the archive into the current directory.
pretty_extract("Cityscapes.zip", ".")


Downloading...
From (original): https://drive.google.com/uc?id=1MI8QsvjW0R6WDJiL49L7sDGpPWYAQB6O
From (redirected): https://drive.google.com/uc?id=1MI8QsvjW0R6WDJiL49L7sDGpPWYAQB6O&confirm=t&uuid=583a3567-d0bc-4404-8093-d0ce47cb55ca
To: /content/Cityscapes.zip
100% 4.97G/4.97G [00:44<00:00, 111MB/s]


Extracting PNGs:   0%|          | 0/6216 [00:00<?, ?it/s]

## Cityscapes implementation

In [None]:
from torch.utils.data import Dataset

import os
# from torchvision.transforms import ToTensor, ToPILImage
import cv2
from PIL import Image

# from torchvision.io import decode_image, read_image
# from torchvision.io.image import ImageReadMode

import numpy as np

class CityScapes(Dataset):
    """Cityscapes dataset yielding (image, colour target, label-id target).

    Expected layout under ``rootdir``::

        <imgdir>/<split>/<city>/<name>_leftImg8bit.png
        <targetdir>/<split>/<city>/<name>_gtFine_color.png
        <targetdir>/<split>/<city>/<name>_gtFine_labelTrainIds.png

    :param rootdir: dataset root directory
    :param split: sub-directory of the split to load (e.g. "train", "val")
    :param targetdir: directory holding the targets, relative to ``rootdir``
    :param imgdir: directory holding the images, relative to ``rootdir``
    :param transform: optional callable applied to the image AND to the colour target
    :param target_transform: optional callable applied to the label-id target
    """

    def __init__(self, rootdir, split="train", targetdir="gtFine", imgdir="images", transform=None, target_transform=None):
        super(CityScapes, self).__init__()

        self.rootdir = rootdir
        self.split = split
        self.targetdir = os.path.join(self.rootdir, targetdir, self.split) # ./gtFine/train/
        self.imgdir = os.path.join(self.rootdir, imgdir, self.split) # ./images/train/
        self.transform = transform
        self.target_transform = target_transform

        self.imgs_path = []
        self.targets_color_path = []
        self.targets_labelIds_path = []

        # Sorted so the sample order does not depend on the filesystem's listing order.
        for city in sorted(os.listdir(self.imgdir)): # frankfurt
            img_city_dir = os.path.join(self.imgdir, city) # ./images/train/frankfurt/
            if not os.path.isdir(img_city_dir):
                # Stray files next to the city folders would make os.listdir fail.
                continue
            target_city_dir = os.path.join(self.targetdir, city) # ./gtFine/train/frankfurt/

            for img_name in sorted(os.listdir(img_city_dir)): # frankfurt_000000_000294_leftImg8bit.png
                if not img_name.endswith(".png"):
                    continue

                self.imgs_path.append(os.path.join(img_city_dir, img_name))

                # Target file names are derived from the image file name.
                target_color_name = img_name.replace("leftImg8bit", "gtFine_color")
                target_labelIds_name = img_name.replace("leftImg8bit", "gtFine_labelTrainIds")

                self.targets_color_path.append(os.path.join(target_city_dir, target_color_name))
                self.targets_labelIds_path.append(os.path.join(target_city_dir, target_labelIds_name))

    def __getitem__(self, idx):
        """Load sample ``idx`` and apply the configured transforms.

        :raises FileNotFoundError: if the label-id image cannot be read
        """
        image = Image.open(self.imgs_path[idx]).convert('RGB')
        target_color = Image.open(self.targets_color_path[idx]).convert('RGB')

        label_path = self.targets_labelIds_path[idx]
        target_labelIds = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        if target_labelIds is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read label image: {label_path}")
        # np.long was removed in NumPy 1.24; np.int64 is explicit and matches GTA5.
        target_labelIds = target_labelIds.astype(np.int64)

        if self.transform is not None:
            image = self.transform(image)
            target_color = self.transform(target_color)
        if self.target_transform is not None:
            target_labelIds = self.target_transform(target_labelIds)

        return image, target_color, target_labelIds

    def __len__(self):
        """Return the number of images found for the split."""
        return len(self.imgs_path)


## GTA5 implementation

In [None]:
from torch.utils.data import Dataset
import numpy as np

import os
import cv2
from PIL import Image

from tqdm import tqdm
import random

import torchvision.transforms as TF

# TODO: decide other augmentations
class GTA5(Dataset):
    """GTA5 dataset yielding (image, colour target, label-id target).

    Expected layout under ``rootdir``::

        <imgdir>/<name>.png
        <targetdir>/<name>.png             (colour-coded target)
        <targetdir>/<name>_labelIds.png    (label-id target, see create_target_img)

    :param rootdir: dataset root directory
    :param file_names: image file names (e.g. "00001.png") belonging to this split
    :param imgdir: directory holding the images, relative to ``rootdir``
    :param targetdir: directory holding the targets, relative to ``rootdir``
    :param augment: when True, ``augment_data`` is applied after the transforms
    :param transform: optional callable applied to the image AND to the colour target
    :param target_transform: optional callable applied to the label-id target
    """

    def __init__(self, rootdir, file_names, imgdir="images", targetdir="labels", augment=False, transform=None, target_transform=None):
        super(GTA5, self).__init__()

        self.rootdir = rootdir

        self.targetdir = os.path.join(self.rootdir, targetdir) # ./labels
        self.imgdir = os.path.join(self.rootdir, imgdir) # ./images

        self.augment = augment

        self.transform = transform
        self.target_transform = target_transform

        self.imgs_path = []
        self.targets_color_path = []
        self.targets_labelIds_path = []

        for image_file in file_names: # 00001.png
            self.imgs_path.append(os.path.join(self.imgdir, image_file)) # ./images/00001.png

            target_color_path = image_file # 00001.png
            target_labelsId_path = image_file.split(".")[0]+"_labelIds.png" # 00001_labelIds.png

            self.targets_color_path.append(os.path.join(self.targetdir, target_color_path)) # ./labels/00001.png
            self.targets_labelIds_path.append(os.path.join(self.targetdir, target_labelsId_path)) # ./labels/00001_labelIds.png

    def create_target_img(self):
        """Generate the ``*_labelIds.png`` files from the colour-coded targets.

        Each pixel whose colour matches an entry of ``GTA5Labels_TaskCV2017``
        gets that entry's id; pixels matching no entry are left at 255.
        """
        list_ = GTA5Labels_TaskCV2017().list_

        path_pairs = zip(self.targets_color_path, self.targets_labelIds_path)
        # total= lets the progress bar show a percentage (a bare zip/enumerate has no length).
        for color_path, label_path in tqdm(path_pairs, total=len(self.targets_color_path)):
            image_numpy = np.asarray(Image.open(color_path).convert('RGB'))

            H, W, _ = image_numpy.shape
            label_image = 255*np.ones((H, W), dtype=np.uint8)

            for label in list_:
                label_image[(image_numpy == label.color).all(axis=-1)] = label.ID

            Image.fromarray(label_image).save(label_path)

    def __getitem__(self, idx):
        """Load sample ``idx``, apply the transforms and, if enabled, the augmentation.

        :raises FileNotFoundError: if the label-id image cannot be read
        """
        image = Image.open(self.imgs_path[idx]).convert('RGB')

        target_color = Image.open(self.targets_color_path[idx]).convert('RGB')

        label_path = self.targets_labelIds_path[idx]
        target_labelIds = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        if target_labelIds is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read label image: {label_path}")
        target_labelIds = target_labelIds.astype(np.int64)

        if self.transform is not None:
            image = self.transform(image)
            target_color = self.transform(target_color)
        if self.target_transform is not None:
            target_labelIds = self.target_transform(target_labelIds)

        if self.augment:
            image, target_color, target_labelIds = self.augment_data(image, target_color, target_labelIds)
        return image, target_color, target_labelIds

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.imgs_path)

    def augment_data(self, image, target_color, target_labelIds):
        """Randomly flip, blur and rescale the intensity of one sample.

        NOTE(review): a single random draw gates all three augmentations, so a
        sample receives either all of them or none - confirm this is intended.
        """
        val = random.random()

        # --- Geometric transformations ---

        # Random horizontal flipping (applied to the image and both targets)
        if val < 0.5:
            image = TF.functional.hflip(image)
            target_color = TF.functional.hflip(target_color)
            target_labelIds = TF.functional.hflip(target_labelIds)

        # --- Photometric transformations (image only) ---

        # Gaussian Blur
        if val < 0.5:
            image = TF.functional.gaussian_blur(image, 5)

        # Multiply
        if val < 0.5:
            factor = random.uniform(0.6, 1.2)
            image *= factor
            # NOTE(review): the [0, 255] clamp presumably targets un-normalized
            # pixel values; on a mean/std-normalized tensor it zeroes every
            # negative value - verify against the transform in use.
            image = torch.clamp(image, 0, 255)

        return image, target_color, target_labelIds


def GTA5_dataset_splitter(rootdir, train_split_percent, split_seed = None, imgdir="images", targetdir="labels", augment=False, transform=None, target_transform=None):
    """Split the GTA5 images into a training and a validation ``GTA5`` dataset.

    Only ``.png`` images that have a matching ``<name>_labelIds.png`` target
    are considered. The file list is sorted before shuffling, so a given
    ``split_seed`` produces the same split regardless of the order in which
    the filesystem lists the directory. Shuffling uses a private RNG, leaving
    the global ``random`` state untouched.

    :param rootdir: dataset root directory
    :param train_split_percent: fraction (0 < x < 1) of the files used for training
    :param split_seed: seed of the shuffle; ``None`` gives a non-reproducible split
    :param imgdir: directory holding the images, relative to ``rootdir``
    :param targetdir: directory holding the targets, relative to ``rootdir``
    :param augment: enable augmentation on the training dataset only
    :param transform: forwarded to both datasets
    :param target_transform: forwarded to both datasets
    :return: (training dataset, validation dataset)
    """
    assert 0.0 < train_split_percent < 1.0, "train_split_percent should be a float between 0 and 1"

    target_path = os.path.join(rootdir, targetdir) # ./labels
    img_path = os.path.join(rootdir, imgdir) # ./images

    # os.listdir returns entries in arbitrary order; sort to make the seeded shuffle reproducible.
    file_names = sorted(
        f for f in os.listdir(img_path)
        if f.endswith(".png") and os.path.exists(os.path.join(target_path, f.split(".")[0]+"_labelIds.png"))
    )

    # random.Random(None) seeds itself from OS entropy, matching the unseeded case.
    random.Random(split_seed).shuffle(file_names)

    split_idx = int(len(file_names) * train_split_percent)

    train_files = file_names[:split_idx]
    val_files = file_names[split_idx:]

    # The validation split is never augmented.
    return GTA5(rootdir, train_files, imgdir, targetdir, augment, transform, target_transform), \
           GTA5(rootdir, val_files, imgdir, targetdir, False, transform, target_transform)


## BiSeNet

In [None]:
import torch
from torch import nn
import warnings
warnings.filterwarnings(action='ignore')

import torch
from torchvision import models

class resnet18(torch.nn.Module):
    """ResNet-18 context-path backbone.

    Wraps ``torchvision.models.resnet18`` and returns the outputs of its last
    two stages together with a globally averaged "tail" of the final stage.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        self.features = backbone
        # Expose the stem and the four stages under their own names.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)`` for an input batch."""
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        feature1 = self.layer1(stem)  # 1 / 4
        feature2 = self.layer2(feature1)  # 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # global average pooling to build tail: width first, then height
        pooled_width = torch.mean(feature4, 3, keepdim=True)
        tail = torch.mean(pooled_width, 2, keepdim=True)
        return feature3, feature4, tail


class resnet101(torch.nn.Module):
    """ResNet-101 context-path backbone.

    Wraps ``torchvision.models.resnet101`` and returns the outputs of its last
    two stages together with a globally averaged "tail" of the final stage.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.resnet101(pretrained=pretrained)
        self.features = backbone
        # Expose the stem and the four stages under their own names.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)`` for an input batch."""
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        feature1 = self.layer1(stem)  # 1 / 4
        feature2 = self.layer2(feature1)  # 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # global average pooling to build tail: width first, then height
        pooled_width = torch.mean(feature4, 3, keepdim=True)
        tail = torch.mean(pooled_width, 2, keepdim=True)
        return feature3, feature4, tail


def build_contextpath(name):
    """Build the pretrained context-path backbone called ``name``.

    Only the requested backbone is instantiated. The previous version built
    both ResNet-18 and ResNet-101 (loading pretrained weights for each) on
    every call and then discarded one of them.

    :param name: 'resnet18' or 'resnet101'
    :return: the corresponding backbone module, created with ``pretrained=True``
    :raises KeyError: if ``name`` is not a supported backbone
    """
    builders = {
        'resnet18': resnet18,
        'resnet101': resnet101,
    }
    return builders[name](pretrained=True)

class ConvBlock(torch.nn.Module):
    """Bias-free convolution followed by batch normalization and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Apply conv -> batch norm -> ReLU to ``input``."""
        normalized = self.bn(self.conv1(input))
        return self.relu(normalized)


class Spatial_path(torch.nn.Module):
    """Spatial path: three stacked ``ConvBlock``s mapping 3 -> 64 -> 128 -> 256 channels.

    Each block uses ``ConvBlock``'s default stride of 2.
    """

    def __init__(self):
        super().__init__()
        self.convblock1 = ConvBlock(in_channels=3, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock3 = ConvBlock(in_channels=128, out_channels=256)

    def forward(self, input):
        """Run ``input`` through the three blocks in order."""
        features = input
        for block in (self.convblock1, self.convblock2, self.convblock3):
            features = block(features)
        return features


class AttentionRefinementModule(torch.nn.Module):
    """Channel attention: re-weights the input by a gate computed from its global average."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        """Return ``input`` multiplied by its per-channel attention weights."""
        # global average pooling -> one value per channel
        pooled = self.avgpool(input)
        assert self.in_channels == pooled.size(1), 'in_channels and out_channels should all be {}'.format(pooled.size(1))
        # 1x1 conv + batch norm + sigmoid produce a gate in (0, 1) per channel
        gate = self.sigmoid(self.bn(self.conv(pooled)))
        # the gate must have as many channels as the input for the product to broadcast
        return torch.mul(input, gate)


class FeatureFusionModule(torch.nn.Module):
    """Fuse the spatial-path and context-path features into ``num_classes`` channels.

    The two inputs are concatenated along the channel axis, reduced by a
    ``ConvBlock`` and then re-weighted by a channel gate with a residual add.

    ``in_channels`` is the channel count of the concatenated input:
      - resnet101: 3328 = 256 (spatial path) + 1024 + 2048 (context path)
      - resnet18:  1024 = 256 (spatial path) + 256 + 512 (context path)
    """

    def __init__(self, num_classes, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.convblock = ConvBlock(in_channels=self.in_channels, out_channels=num_classes, stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        """Return the fused feature map for the two inputs."""
        stacked = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == stacked.size(1), 'in_channels of ConvBlock should be {}'.format(stacked.size(1))
        feature = self.convblock(stacked)

        # Channel gate computed from the globally pooled fused feature.
        gate = self.avgpool(feature)
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))

        # Gated feature plus a residual connection to the fused feature.
        weighted = torch.mul(feature, gate)
        return torch.add(weighted, feature)


class BiSeNet(torch.nn.Module):
    """BiSeNet segmentation network (spatial path + ResNet context path).

    In training mode ``forward`` returns ``(result, cx1_sup, cx2_sup)``: the
    main prediction plus two auxiliary predictions used for supervision. In
    eval mode only ``result`` is returned. All outputs have ``num_classes``
    channels and the spatial size of the input.

    :param num_classes: number of output channels
    :param context_path: 'resnet18' or 'resnet101'
    :raises KeyError: if ``build_contextpath`` does not know ``context_path``
    :raises ValueError: if the backbone could be built but no head configuration
        exists for it (previously this only printed a message and failed later)
    """

    def __init__(self, num_classes, context_path):
        super().__init__()
        # build spatial path
        # NOTE: the misspelled attribute name is kept on purpose - renaming it
        # would change the state_dict keys of this module.
        self.saptial_path = Spatial_path()

        # build context path
        self.context_path = build_contextpath(name=context_path)

        if context_path == 'resnet101':
            # build attention refinement module for resnet 101
            self.attention_refinement_module1 = AttentionRefinementModule(1024, 1024)
            self.attention_refinement_module2 = AttentionRefinementModule(2048, 2048)
            # supervision block
            self.supervision1 = nn.Conv2d(in_channels=1024, out_channels=num_classes, kernel_size=1)
            self.supervision2 = nn.Conv2d(in_channels=2048, out_channels=num_classes, kernel_size=1)
            # build feature fusion module
            self.feature_fusion_module = FeatureFusionModule(num_classes, 3328)

        elif context_path == 'resnet18':
            # build attention refinement module for resnet 18
            self.attention_refinement_module1 = AttentionRefinementModule(256, 256)
            self.attention_refinement_module2 = AttentionRefinementModule(512, 512)
            # supervision block
            self.supervision1 = nn.Conv2d(in_channels=256, out_channels=num_classes, kernel_size=1)
            self.supervision2 = nn.Conv2d(in_channels=512, out_channels=num_classes, kernel_size=1)
            # build feature fusion module
            self.feature_fusion_module = FeatureFusionModule(num_classes, 1024)
        else:
            # Fail here instead of printing and hitting an AttributeError below.
            raise ValueError('Unsupported context_path network: {!r}'.format(context_path))

        # build final convolution
        self.conv = nn.Conv2d(in_channels=num_classes, out_channels=num_classes, kernel_size=1)

        self.init_weight()

        # Every sub-module except the context-path backbone.
        self.mul_lr = []
        self.mul_lr.append(self.saptial_path)
        self.mul_lr.append(self.attention_refinement_module1)
        self.mul_lr.append(self.attention_refinement_module2)
        self.mul_lr.append(self.supervision1)
        self.mul_lr.append(self.supervision2)
        self.mul_lr.append(self.feature_fusion_module)
        self.mul_lr.append(self.conv)

    def init_weight(self):
        """Initialise every conv / batch-norm layer outside the context path."""
        for name, m in self.named_modules():
            # Modules under 'context_path' keep the weights build_contextpath gave them.
            if 'context_path' not in name:
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-5
                    m.momentum = 0.1
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, input):
        """Run the network; see the class docstring for the return values."""
        input_size = input.size()[-2:]

        # output of spatial path
        sx = self.saptial_path(input)

        # output of context path
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)
        # upsampling to the spatial-path resolution
        cx1 = torch.nn.functional.interpolate(cx1, size=sx.size()[-2:], mode='bilinear')
        cx2 = torch.nn.functional.interpolate(cx2, size=sx.size()[-2:], mode='bilinear')
        cx = torch.cat((cx1, cx2), dim=1)

        if self.training:
            # auxiliary predictions, upsampled to the input resolution
            cx1_sup = self.supervision1(cx1)
            cx2_sup = self.supervision2(cx2)
            cx1_sup = torch.nn.functional.interpolate(cx1_sup, size=input_size, mode='bilinear')
            cx2_sup = torch.nn.functional.interpolate(cx2_sup, size=input_size, mode='bilinear')

        # output of feature fusion module
        result = self.feature_fusion_module(sx, cx)

        # Upsample to the input resolution, like the auxiliary heads. This equals
        # the former scale_factor=8 when H and W are multiples of 8, and keeps
        # the output aligned with the input/targets when they are not.
        result = torch.nn.functional.interpolate(result, size=input_size, mode='bilinear')
        result = self.conv(result)

        if self.training:
            return result, cx1_sup, cx2_sup

        return result


## Wandb

In [None]:
!pip -q install wandb

# Train/Val loops

### Train Loop

In [None]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import numpy as np
import wandb

# TODO: maybe last class (void) converted to n_classes-1 due to argmax
# TODO: waiting confirmation for logging 1 batch every X
def train(model:nn.Module, train_loader:DataLoader, criterion:nn.Module, optimizer:optim.Optimizer) -> tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Train ``model`` for one epoch.

    The loader must yield ``(inputs, _, targets)`` batches and the model, in
    training mode, must return ``(outputs, cx1_sup, cx2_sup)``; the loss is the
    sum of ``criterion`` over the three predictions. Roughly ``log_per_epoch``
    times per epoch the current batch loss and mIoU are printed and/or logged
    to wandb, depending on ``ENABLE_PRINT`` / ``ENABLE_WANDB_LOG``.

    :param model: network to train
    :param train_loader: loader over the training set
    :param criterion: loss applied to each prediction against the targets
    :param optimizer: optimizer stepped once per batch
    :return: (mean loss per sample, mIoU over classes with IoU > 0,
              accumulated confusion matrix, per-class IoU)
    """
    # Only train_step is rebound here; the other module-level settings are just read.
    global train_step

    model.train()

    num_batch = len(train_loader)
    chunk_batch = num_batch//log_per_epoch+1

    num_sample = len(train_loader.dataset)
    seen_sample = 0

    train_loss = 0.0
    train_hist = torch.zeros((n_classes,n_classes)).to(device)

    for batch_idx, (inputs, _, targets) in enumerate(train_loader):
        batch_size = inputs.size(0)
        seen_sample += batch_size

        inputs = inputs.to(device)
        # Drop only the channel axis: a bare squeeze() would also drop the
        # batch axis whenever a batch holds a single sample.
        targets = targets.squeeze(1).to(device)

        outputs, cx1_sup, cx2_sup = model(inputs)

        loss = criterion(outputs, targets)
        loss += criterion(cx1_sup, targets)
        loss += criterion(cx2_sup, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predicted = outputs.argmax(1)

        # One histogram call for the whole batch; equal to summing per-sample histograms.
        hist_batch = fast_hist_torch_cuda(targets.detach(), predicted.detach(), n_classes).to(train_hist.dtype)

        train_loss += loss.item() * batch_size
        train_hist += hist_batch

        if ((batch_idx+1) % chunk_batch) == 0:
            iou_batch = per_class_iou_cuda(hist_batch)
            if ENABLE_PRINT:
                print(f'Training [{seen_sample}/{num_sample} ({100. * seen_sample / num_sample:.0f}%)]')
                print(f'\tLoss: {loss.item():.6f}')
                print(f"\tmIoU: {100.*iou_batch[iou_batch > 0].mean():.4f}")

            if ENABLE_WANDB_LOG:
                wandb.log({
                        "train/step": train_step,
                        "train/batch_loss": loss.item(),
                        "train/batch_mIou": 100.*iou_batch[iou_batch > 0].mean()
                    },
                    commit=True,
                )
                train_step += 1

    train_loss = train_loss / seen_sample

    train_iou_class = per_class_iou_cuda(train_hist)
    # Classes with IoU == 0 are excluded from the mean.
    train_mIou = train_iou_class[train_iou_class > 0].mean()

    return train_loss, train_mIou, train_hist, train_iou_class


### Validation loop

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
import wandb

def _log_validation_chunk(seen_sample, num_sample, last_loss, last_hist, chunk_loss, chunk_sample, chunk_hist):
    """Print and/or log to wandb the statistics of one finished validation chunk."""
    global val_step

    if ENABLE_PRINT:
        # Console output reports the most recent batch only.
        iou_batch = per_class_iou_cuda(last_hist)
        print(f'Validation [{seen_sample}/{num_sample} ({100. * seen_sample / num_sample:.0f}%)]')
        print(f'\tLoss: {last_loss.item():.6f}')
        print(f"\tmIoU: {100.*iou_batch[iou_batch > 0].mean():.4f}")

    if ENABLE_WANDB_LOG:
        # wandb receives the statistics aggregated over the whole chunk.
        iou_batch = per_class_iou_cuda(chunk_hist)
        wandb.log({
                "validate/step": val_step,
                "validate/batch_loss": chunk_loss/chunk_sample,
                "validate/batch_mIou": 100.*iou_batch[iou_batch > 0].mean()
            },
            commit=True,
        )

        val_step += 1


def validate(model:nn.Module, val_loader:DataLoader, criterion:nn.Module) -> tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Evaluate ``model`` on the validation set.

    The loader must yield ``(inputs, _, targets)`` batches and the model, in
    eval mode, must return a single prediction tensor. Batches are grouped in
    chunks (roughly ``log_per_epoch`` per pass); each finished chunk, including
    a final partial one, is printed and/or logged to wandb depending on
    ``ENABLE_PRINT`` / ``ENABLE_WANDB_LOG``.

    :param model: network to evaluate
    :param val_loader: loader over the validation set
    :param criterion: loss applied to the prediction against the targets
    :return: (mean loss per sample, mIoU over classes with IoU > 0,
              accumulated confusion matrix, per-class IoU)
    """
    model.eval()

    num_batch = len(val_loader)
    chunk_batch = num_batch//log_per_epoch+1

    num_sample = len(val_loader.dataset)
    seen_sample = 0
    chunk_sample = 0

    val_loss = 0.0
    val_hist = torch.zeros((n_classes,n_classes)).to(device)

    chunk_loss = 0.0
    chunk_hist = torch.zeros((n_classes,n_classes)).to(device)

    with torch.no_grad():
        for batch_idx, (inputs, _, targets) in enumerate(val_loader):
            batch_size = inputs.size(0)

            inputs = inputs.to(device)
            # Drop only the channel axis: a bare squeeze() would also drop the
            # batch axis whenever a batch holds a single sample.
            targets = targets.squeeze(1).to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            predicted = outputs.argmax(1)

            # One histogram call for the whole batch; equal to summing per-sample histograms.
            hist_batch = fast_hist_torch_cuda(targets.detach(), predicted.detach(), n_classes).to(chunk_hist.dtype)

            chunk_sample += batch_size
            chunk_loss += loss.item() * batch_size
            chunk_hist += hist_batch

            if ((batch_idx+1) % chunk_batch) == 0:
                # Chunk complete: fold it into the epoch totals, report it, start a new one.
                seen_sample += chunk_sample
                val_loss += chunk_loss
                val_hist += chunk_hist

                _log_validation_chunk(seen_sample, num_sample, loss, hist_batch, chunk_loss, chunk_sample, chunk_hist)

                chunk_sample = 0
                chunk_loss = 0.0
                chunk_hist = torch.zeros((n_classes, n_classes)).to(device)

        if chunk_sample > 0:
            # Trailing partial chunk left over after the last full one.
            seen_sample += chunk_sample
            val_loss += chunk_loss
            val_hist += chunk_hist

            _log_validation_chunk(seen_sample, num_sample, loss, hist_batch, chunk_loss, chunk_sample, chunk_hist)

    val_loss = val_loss / seen_sample

    val_iou_class = per_class_iou_cuda(val_hist)
    # Classes with IoU == 0 are excluded from the mean.
    val_mIou = val_iou_class[val_iou_class > 0].mean()

    return val_loss, val_mIou, val_hist, val_iou_class

# Machine learning

In [22]:
# TODO: something not working with validate/batch_miou when epoch starts, always a peak
# Possibly since less images -> less classes seen -> lower denominator when calculating mIou
# TODO: something wrong with validate, batch_loss going up and batch_miou going down while epoch-level metrics are fine
# TODO: num of log from validate and train is different?
# Found validate/step went back to 0 -> typo: validate_step instead of val_step

def pipeline():
    """Run the full train/validate loop for one experiment and log it to W&B.

    Steps, in order:
      1. Reset the module-level logging globals (``ENABLE_PRINT``,
         ``ENABLE_WANDB_LOG``, ``train_step``, ``val_step``, ``log_per_epoch``)
         and set ``n_classes``.
      2. Wipe and recreate the local ``./models`` checkpoint directory.
      3. Build the image/label transforms, the datasets and the data loaders.
      4. Build the model (``BiSeNet`` or ``DeepLab``), the cross-entropy
         criterion (label 255 ignored) and the SGD optimizer.
      5. Open a W&B run and declare the ``epoch/*``, ``train/*`` and
         ``validate/*`` metric families with their own step metrics.
      6. If ``start_epoch > 0``, download the matching checkpoint artifact and
         restore model, optimizer and step counters from it.
      7. For every epoch in ``start_epoch+1 .. end_epoch``: update the learning
         rate with ``poly_lr_scheduler``, train, validate, log epoch metrics,
         save a checkpoint artifact every second epoch (and on the last one),
         and log confusion matrices / per-class IoU charts every tenth epoch.
      8. When ``end_epoch == max_epoch``, log FLOPs, latency, FPS and parameter
         count of the final model.
      9. Finish the W&B run.

    All hyper-parameters are local constants defined at the top of the body.

    Returns:
        None.

    Raises:
        AssertionError: if ``start_epoch < end_epoch <= max_epoch`` does not hold.
        Exception: if ``dataset`` or ``backbone`` holds an unsupported name.
    """
    from torch.utils.data import DataLoader
    import torchvision.transforms as TF
    import torch.nn as nn
    import torch.optim as optim
    import wandb
    import os
    import shutil

    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global train_step
    global val_step
    global log_per_epoch

    # Logging configuration shared (through globals) with train()/validate().
    ENABLE_PRINT = False
    ENABLE_WANDB_LOG = True
    train_step = 0
    val_step = 0
    log_per_epoch = 20

    # Start from an empty checkpoint directory. Plain-Python equivalent of
    # `rm -rf` + `mkdir`, so the function does not depend on IPython magics.
    models_root_dir = "./models"
    shutil.rmtree(models_root_dir, ignore_errors=True)
    os.makedirs(models_root_dir, exist_ok=True)

    B = 3
    H = 512
    W = 1024
    n_classes = 19

    backbone = "BiSeNet"
    context_path = "resnet18"

    # FIXME: not work, look under where run config is defined
    start_epoch = 0  # <--------- Last epoch that i completed (in this script i will perform from the start+1 to end)
    end_epoch = 2
    max_epoch = 50

    assert start_epoch < end_epoch <= max_epoch, "Check your start/end/max epoch settings."

    init_lr = 2.5e-3
    lr_decay_iter = 1
    momentum = 0.9
    weight_decay = 1e-4
    dataset = "Cityscapes"

    transform = TF.Compose([
        TF.ToTensor(),
        TF.Resize((H, W)),
        TF.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])
    # Labels are resized with nearest-neighbour interpolation.
    target_transform = TF.Compose([
        TF.ToTensor(),
        TF.Resize((H, W), interpolation=TF.InterpolationMode.NEAREST),
    ])

    # Dataset objects
    if dataset == "Cityscapes":
        data_train = CityScapes("./Cityscapes/Cityspaces", split="train", transform=transform, target_transform=target_transform)
        data_val = CityScapes("./Cityscapes/Cityspaces", split="val", transform=transform, target_transform=target_transform)
    elif dataset == "GTA5":
        data_train, data_val = GTA5_dataset_splitter("./Gta5_extended", train_split_percent=0.8, split_seed=42, transform=transform, target_transform=target_transform)
    else:
        raise Exception("Wrong dataset name")
    train_loader = DataLoader(data_train, batch_size=B, shuffle=True)
    # NOTE(review): the validation loader is shuffled as well, so the content of
    # each logged validation chunk differs from epoch to epoch — confirm intended.
    val_loader = DataLoader(data_val, batch_size=B, shuffle=True)

    # Architecture
    if backbone == "BiSeNet":
        model = BiSeNet(n_classes, context_path).to(device)
        architecture = backbone + "-" + context_path
    elif backbone == "DeepLab":
        model = get_deeplab_v2(num_classes=n_classes, pretrain=True).to(device)
        architecture = backbone
    else:
        raise Exception("Wrong model name")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=momentum, weight_decay=weight_decay)

    # TODO: wandb can't let us reuse the same run_id, need way to manage it
    # Wandb setup and metrics
    run_name = "step_2B_new"
    run_id = f"{run_name}_{architecture}_{dataset}"
    run = wandb.init(
        entity="Machine_learning_and_Deep_learning_labs",
        project="Semantic Segmentation",
        name=run_name,
        id=run_id,
        resume="allow",  # <----------------  IMPORTANT CONFIG KEY
        config={
            "initial_learning_rate": init_lr,
            "lr_decay_iter": lr_decay_iter,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "architecture": architecture,
            "dataset": dataset,
            "start_epoch": start_epoch,
            "end_epoch": end_epoch,
            "max_epoch": max_epoch,
            "batch": B,
            "lr_scheduler": "poly"
        },
    )

    # Each metric family is plotted against its own step counter.
    wandb.define_metric("epoch/step")
    wandb.define_metric("epoch/*", step_metric="epoch/step")

    wandb.define_metric("train/step")
    wandb.define_metric("train/*", step_metric="train/step")

    wandb.define_metric("validate/step")
    wandb.define_metric("validate/*", step_metric="validate/step")

    # FIXME: does not work, problems with run id
    # Loading from a starting point
    if start_epoch > 0:
        artifact = run.use_artifact(f'Machine_learning_and_Deep_learning_labs/Semantic Segmentation/{run_id}:epoch_{start_epoch}', type='model')
        artifact_dir = artifact.download()

        artifact_path = os.path.join(artifact_dir, run_id + f"_epoch_{start_epoch}.pth")

        checkpoint = torch.load(artifact_path, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        # Continue the step counters after the last values stored in the checkpoint.
        train_step = checkpoint["train_step"] + 1
        val_step = checkpoint["validate_step"] + 1

    # Main Loop
    for epoch in range(start_epoch + 1, end_epoch + 1):
        print("-----------------------------")
        print(f"Epoch {epoch}")

        # The scheduler is driven by the zero-based epoch index.
        lr = poly_lr_scheduler(optimizer, init_lr, epoch - 1, max_iter=max_epoch)

        print(f"[Poly LR] 100xLR: {100.*lr:.6f}")

        run.log({
            "epoch/step": epoch,
            "epoch/100xlearning_rate": 100. * lr,
        })

        train_loss, train_mIou, train_hist, train_mIou_class = train(model, train_loader, criterion, optimizer)

        print(f'[Train Loss] : {train_loss:.6f} [mIoU]: {100.*train_mIou:.2f}%')

        # log_confusion_matrix("Confusion Matrix - Train", train_hist, "epoch/train_confusion_matrix", "epoch/step", epoch)
        run.log(
            {
                "epoch/step": epoch,
                "epoch/train_loss": train_loss,
                "epoch/train_mIou": 100 * train_mIou
            },
            commit=True,
        )

        val_loss, val_mIou, val_hist, val_mIou_class = validate(model, val_loader, criterion)

        print(f'[Validation Loss] : {val_loss:.6f} [mIoU]: {100.*val_mIou:.2f}%')

        # log_confusion_matrix("Confusion Matrix - Validate", val_hist, "epoch/validate_confusion_matrix", "epoch/step", epoch)
        run.log(
            {
                "epoch/step": epoch,
                "epoch/val_loss": val_loss,
                "epoch/val_mIou": 100 * val_mIou
            },
            commit=True
        )

        # Checkpoint every second epoch and always on the last epoch of this run.
        if epoch % 2 == 0 or epoch == end_epoch:
            checkpoint = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": train_step,
                "validate_step": val_step,
            }

            file_name = f"{run_id}_epoch_{epoch}.pth"

            # TODO: add some tables to artifact to enable comparisons

            # Saving the progress
            file_path = os.path.join(models_root_dir, file_name)
            torch.save(checkpoint, file_path)

            print(f"Model saved to {file_path}")

            artifact = wandb.Artifact(name=run_id, type="model")
            artifact.add_file(file_path)

            run.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])

        # Heavier visual logs only every tenth epoch.
        if (epoch % 10) == 0:
            log_confusion_matrix_cuda("Confusion Matrix - Train", train_hist, "epoch/train_confusion_matrix", "epoch/step", epoch)
            log_confusion_matrix_cuda("Confusion Matrix - Validate", val_hist, "epoch/validate_confusion_matrix", "epoch/step", epoch)

            # Class names without the "void" label, shared by both bar charts.
            class_names = [c.name for c in GTA5Labels_TaskCV2017().list_ if c.name != "void"]

            log_bar_chart_ioU(f"Train IoU per class - epoch {epoch}", class_names, train_mIou, train_mIou_class, "epoch/train_Iou_class", "epoch/step", epoch)
            log_bar_chart_ioU(f"Validate IoU per class - epoch {epoch}", class_names, val_mIou, val_mIou_class, "epoch/validate_Iou_class", "epoch/step", epoch)

    # TODO: need to check if works
    run.config["end_epoch"] = end_epoch
    run.config["start_epoch"] = start_epoch

    if end_epoch == max_epoch:
        # TODO: fix types, from list and strings to numbers
        # TODO: add model param

        # Profile at the same resolution the model was trained on.
        mean_latency, std_latency, mean_fps = latency(device, model, H=H, W=W)

        run.log({
            "model/flops": num_flops(device, model, H, W),
            "model/latency_mean": mean_latency,
            "model/latency_std": std_latency,
            "model/mean_fps": mean_fps,
            "model/param": num_param(model)
        })

    run.finish()

# Finish whatever W&B run is currently active before starting a new one
# (presumably one left open by an earlier execution of this cell — confirm),
# then launch the whole experiment.
wandb.finish()
pipeline()


0,1
epoch/100xlearning_rate,▁
epoch/step,▁▁
epoch/train_loss,▁
epoch/train_mIou,▁
train/batch_loss,██▅▄▄▅▃▃▄▂▃▃▂▄▁▂▂▂▃
train/batch_mIou,▁▃▆▅▅▄▇▂▆▆▅▃▅▄▄█▆▅▃
train/step,▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
validate/batch_loss,▁
validate/batch_mIou,▁
validate/step,▁

0,1
epoch/100xlearning_rate,0.25
epoch/step,1.0
epoch/train_loss,2.48955
epoch/train_mIou,26.33832
train/batch_loss,2.01681
train/batch_mIou,48.04531
train/step,18.0
validate/batch_loss,0.45048
validate/batch_mIou,46.92132
validate/step,0.0


-----------------------------
Epoch 1
[Poly LR] 100xLR: 0.250000
[Train Loss] : 2.538756 [mIoU]: 26.77%
[Validation Loss] : 0.479447 [mIoU]: 38.61%
-----------------------------
Epoch 2
[Poly LR] 100xLR: 0.245495
[Train Loss] : 1.453502 [mIoU]: 44.45%
[Validation Loss] : 0.413842 [mIoU]: 44.99%
Model saved to ./models/step_2B_new_BiSeNet-resnet18_Cityscapes_epoch_2.pth


0,1
epoch/100xlearning_rate,█▁
epoch/step,▁▁▁███
epoch/train_loss,█▁
epoch/train_mIou,▁█
epoch/val_loss,█▁
epoch/val_mIou,▁█
train/batch_loss,█▇▅▄▃▃▃▃▃▃▂▂▂▂▁▃▁▁▁▂▂▂▂▂▁▁▁▁▁▁▁▁▃▂▁▁▁▂
train/batch_mIou,▁▂▄▆▇▄▅▅▄▅▆▃▃▆▅▅▆▄▃▅▄▅▅▃▄▅▅▆▃▃▃█▂▃▄▄▇▃
train/step,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
validate/batch_loss,▂▄▅▄▄▅▅▆█▄▄▃▄▆▅▆▄▅█▃▁▂▁▁▄▁▃▃▄▄▂▂▂▃▁▁▃▇

0,1
epoch/100xlearning_rate,0.2455
epoch/step,2.0
epoch/train_loss,1.4535
epoch/train_mIou,44.44945
epoch/val_loss,0.41384
epoch/val_mIou,44.98872
train/batch_loss,1.39806
train/batch_mIou,48.4876
train/step,37.0
validate/batch_loss,0.53728
