# Preparations

Imports

In [None]:
import random
import numpy as np

import torch
from torch import nn
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import torch.optim as optim

import wandb

from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt

import os
import zipfile
from tqdm.notebook import tqdm
import time
import cv2
import shutil

# !pip -q install -U fvcore
from fvcore.nn import FlopCountAnalysis #, flop_count_table

from dataclasses import dataclass
from typing import Tuple

import warnings
warnings.filterwarnings(action='ignore')

from utils import *


## Utils

### Variables

#### Dataset settings

In [None]:
# Crop transform of size (512, 1024); the training/validation loops call it as
# ``randomcrop(images, labels)`` (labels may be None for unlabeled target data).
randomcrop = RandomCrop((512, 1024))

# Photometric augmentation per domain; ``s`` is the ColorJitter magnitude.
s = 0.4
source_augmentation = v2.Compose([
    v2.RandomApply([v2.GaussianBlur(7)],p=0.5),
    v2.RandomApply([v2.ColorJitter(brightness=s, contrast=s, saturation=s, hue=s)],p=0.5),
    # v2.RandomApply([v2.GaussianNoise(mean=0.0, sigma=0.4)],p=0.5)
])

s = 0.1
target_augmentation = v2.Compose([
    v2.RandomApply([v2.GaussianBlur(3)],p=0.5),
    v2.RandomApply([v2.ColorJitter(brightness=s, contrast=s, saturation=s, hue=s)],p=0.5),
    # v2.RandomApply([v2.GaussianNoise(mean=0.0, sigma=0.03)],p=0.5)
])

s = 0.25
# Assign only ``mixed_augmentation``. A chained assignment here
# (``mixed_augmentation = target_augmentation = ...``) would silently replace
# the milder target pipeline defined just above.
mixed_augmentation = v2.Compose([
    v2.RandomApply([v2.GaussianBlur(7)],p=0.5),
    v2.RandomApply([v2.ColorJitter(brightness=s, contrast=s, saturation=s, hue=s)],p=0.5),
    # v2.RandomApply([v2.GaussianNoise(mean=0.0, sigma=0.03)],p=0.5)
])

# Augmentation is currently disabled: these overrides replace all three
# pipelines above with identity ops. Delete them to re-enable augmentation.
source_augmentation = nn.Identity()
target_augmentation = nn.Identity()
mixed_augmentation = nn.Identity()

# (H, W) each domain is resized to: GTA5 uses the *source* transforms and
# Cityscapes the *target* transforms in the pipeline.
resolution_source_dataset = (720, 1280)
resolution_target_dataset = (512, 1024)

# NOTE(review): v2.ToTensor rescales uint8 images to [0, 1] (and is deprecated
# in favour of ToImage + ToDtype); confirm label images keep their class ids.
toTensor = v2.ToTensor()

# Per-channel RGB statistics (the values match the common ImageNet statistics).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

normalize = v2.Normalize(mean=mean, std=std)


def _tensor_then_resize(size, **resize_kwargs):
    """Return ``Compose([toTensor, Resize(size, **resize_kwargs)])``.

    Normalization is deliberately left out of these pipelines: the training
    and validation loops apply ``normalize`` to each batch themselves.
    """
    return v2.Compose([toTensor, v2.Resize(size, **resize_kwargs)])


# Images use Resize's default interpolation; labels use NEAREST so no new
# (interpolated) label values are created.
img_transform_target = _tensor_then_resize(resolution_target_dataset)
label_transform_target = _tensor_then_resize(resolution_target_dataset, interpolation=v2.InterpolationMode.NEAREST)

img_transform_source = _tensor_then_resize(resolution_source_dataset)
label_transform_source = _tensor_then_resize(resolution_source_dataset, interpolation=v2.InterpolationMode.NEAREST)


#### Logging

In [None]:
# Output switches for the train/validate loops.
ENABLE_PRINT: bool = False
ENABLE_WANDB_LOG: bool = True
# Approximate number of intermediate log points per epoch
# (loops log every ``num_batch // log_per_epoch + 1`` batches).
log_per_epoch: int = 20
# Number of segmentation classes.
n_classes: int = 19

# Global step counters used as the wandb "train/step" and "validate/step" axes.
train_step: int = 0
val_step: int = 0

#### Device

In [None]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
print(device)

# Downloads

## CityScapes download

In [None]:
# !pip install -q gdown

# file_id = "1MI8QsvjW0R6WDJiL49L7sDGpPWYAQB6O"
# !gdown https://drive.google.com/uc?id={file_id}

# pretty_extract("Cityscapes.zip", ".")


## GTA5 download

In [None]:
# !pip install -q gdown

# file_id = "1PWavqXDxuifsyYvs2PFua9sdMl0JG8AE"
# !gdown https://drive.google.com/uc?id={file_id}

# pretty_extract("Gta5_extended.zip", "./Gta5_extended")

## DeepLabV2 model weights

In [None]:
# !pip install -q gdown

# file_id = "1KgYgBTmvq7UcBwKui2b4TomnbTmzJMBf"
# !gdown https://drive.google.com/uc?id={file_id}

# Dataset

In [None]:
from datasets.cityscapes import CityScapes
from datasets.gta5 import GTA5_dataset_splitter

# Model

In [None]:
from models.bisenet.build_bisenet import *
from models.deeplabv2.deeplabv2 import *
from models.hrda.build_hrda import *

# Loss


## Bisenet

In [None]:
import torch.nn as nn

class BiSeNetloss(nn.Module):
    """Cross-entropy loss for BiSeNet with optional auxiliary outputs.

    The result is ``CE(outputs) + CE(cx1_sup) + CE(cx2_sup)``, where an
    auxiliary term is skipped when that output is ``None``.

    Args:
        ignore_index: label value excluded from the loss (e.g. void pixels).
    """

    def __init__(self, ignore_index=255):
        super().__init__()
        self.ignore_index = ignore_index
        # ignore_index must reach CrossEntropyLoss; otherwise labels equal to
        # it are treated as (out-of-range) class ids and the loss fails.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, outputs, cx1_sup, cx2_sup, label):
        """Return the summed cross entropy of the main and auxiliary logits.

        Args:
            outputs: main logits, presumably ``[B, C, H, W]``.
            cx1_sup: first auxiliary logits (same layout) or ``None``.
            cx2_sup: second auxiliary logits (same layout) or ``None``.
            label: class-index targets, presumably ``[B, H, W]``.
        """
        loss = self.ce_loss(outputs, label)
        # NOTE: the auxiliary losses may not be used (model may return None).
        for aux in (cx1_sup, cx2_sup):
            if aux is not None:
                loss = loss + self.ce_loss(aux, label)
        return loss


## DACS

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: threshold = 0.968, from https://github.com/vikolss/DACS/blob/cc6a87f23b1c81ae32edad767b9772258774a974/trainUDA.py#L467
class DACSloss(nn.Module):
    """DACS objective: supervised source loss plus confidence-weighted mixed loss.

    ``loss = CE(source) + lambda_ * CE(mixed)``, where each CE term also adds
    the auxiliary outputs when they are provided and ``lambda_`` is the
    fraction of pixels whose max softmax probability on the mixed prediction
    is at least ``threshold``.

    Args:
        threshold: confidence threshold used to compute ``lambda_``.
        ignore_index: label value excluded from the cross entropy.
    """

    def __init__(self, threshold = 0.968, ignore_index=255):
        super().__init__()
        self.ignore_index = ignore_index
        self.threshold = threshold
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def _segmentation_loss(self, outputs, cx1_sup, cx2_sup, label):
        """Cross entropy of the main logits plus any non-None auxiliary logits."""
        loss = self.ce_loss(outputs, label)
        for aux in (cx1_sup, cx2_sup):
            if aux is not None:
                loss = loss + self.ce_loss(aux, label)
        return loss

    def forward(self, outputs_source, cx1_sup_src, cx2_sup_src, label_source,
                outputs_mixed, cx1_sup_mixed, cx2_sup_mixed, label_mixed):
        """Return ``(loss, lambda_)``.

        outputs_*: logits, presumably [B, C, H, W]; cx*_sup_*: auxiliary
        logits or None; label_*: class-index targets, presumably [B, H, W].
        """
        # CE must be computed on the logits themselves: CE over argmax class ids
        # is both a dtype/shape error and non-differentiable.
        l1 = self._segmentation_loss(outputs_source, cx1_sup_src, cx2_sup_src, label_source)
        l2 = self._segmentation_loss(outputs_mixed, cx1_sup_mixed, cx2_sup_mixed, label_mixed)

        # lambda_ is used as a plain float weight, so no autograd graph is needed.
        # NOTE(review): the reference DACS code derives this weight from the
        # target pseudo-label confidences; here it comes from the mixed-image
        # prediction — confirm this is intended.
        with torch.no_grad():
            max_probs = torch.softmax(outputs_mixed, 1).max(1).values
            confident = (max_probs >= self.threshold).float()
            lambda_ = confident.mean().item() if confident.numel() > 0 else 0.0

        loss = l1 + lambda_ * l2

        return loss, lambda_


## FDA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def charbonnier_penalty(x, eta = 0.5):
    """Return the robust penalty ``(x**2 + 0.001**2) ** eta``.

    With the default ``eta=0.5`` this is ``sqrt(x**2 + 1e-6)``, a smooth
    approximation of ``|x|``. Works on Python numbers and tensors alike.
    """
    return (x**2 + 0.001**2)**eta

class FDAloss(nn.Module):
    """Source cross entropy plus a penalised entropy term on target predictions.

    ``loss = CE(outputs_src, targets)
            + lambda_entropy * sum_b charbonnier_penalty(sum_pixels H_b, eta)``
    where ``H_b`` is the per-pixel prediction entropy of target image ``b``.

    Args:
        lambda_entropy: weight of the entropy term.
        ignore_index: label value excluded from the cross entropy (and used to
            mask the entropy map when shapes match).
        eta: exponent forwarded to ``charbonnier_penalty`` (default keeps the
            previous behaviour).
    """

    def __init__(self, lambda_entropy=0.01, ignore_index=255, eta=0.5):
        super().__init__()
        self.lambda_entropy = lambda_entropy
        self.ignore_index = ignore_index
        self.eta = eta
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, outputs_src, targets, outputs_tgt):
        """
        outputs_src: [B, C, H, W]
        targets: [B, H, W]
        outputs_tgt: [B, C, H, W]
        """
        # Cross entropy on source
        loss_ce = self.ce_loss(outputs_src, targets)

        # Per-pixel entropy of the target predictions
        probs = F.softmax(outputs_tgt, dim=1)                    # [B, C, H, W]
        log_probs = F.log_softmax(outputs_tgt, dim=1)            # [B, C, H, W]
        entropy = -torch.sum(probs * log_probs, dim=1)           # [B, H, W]

        if targets.shape == entropy.shape:
            # NOTE(review): ``targets`` are source labels while ``entropy`` comes
            # from target-domain predictions; this mask only makes sense if the
            # two batches are pixel-aligned — confirm intended.
            valid_mask = (targets != self.ignore_index).float()  # [B, H, W]
            entropy = entropy * valid_mask

        # flatten (unlike view) also works on non-contiguous tensors.
        entropy_per_image = charbonnier_penalty(entropy.flatten(1).sum(dim=1), self.eta)  # [B]

        # Total: sum over the batch
        loss_ent = entropy_per_image.sum()
        total_loss = loss_ce + self.lambda_entropy * loss_ent

        return total_loss


# Train/Val loops

## Train Loop

### Train step 2/3

In [None]:
def train3(model:nn.Module, train_loader:DataLoader, criterion:nn.Module, optimizer:optim.Optimizer) -> tuple[float, float, torch.Tensor, torch.Tensor]:
    """Run one supervised training epoch (steps 2/3) over ``train_loader``.

    HRDA models go through ``hrda_crop``/``hrda_forward``/``hrda_loss``. Other
    models are expected to return ``(outputs, cx1_sup, cx2_sup)``. The 4-argument
    criterion is used when both auxiliary outputs are present; otherwise
    ``criterion(outputs, targets)`` is used.

    Returns:
        ``(train_loss, train_mIou, train_iou_class, train_hist)``: mean loss per
        sample, mean IoU over classes with IoU > 0, per-class IoU from
        ``per_class_iou_cuda`` and the accumulated ``fast_hist_cuda`` histogram.
    """
    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global train_step
    global log_per_epoch

    model.train()

    num_batch = len(train_loader)
    chunk_batch = num_batch//log_per_epoch+1

    num_sample = len(train_loader.dataset)
    seen_sample = 0

    train_loss = 0.0
    train_hist = torch.zeros((n_classes,n_classes), device=device)

    for batch_idx, (inputs, _, targets) in tqdm(enumerate(train_loader), total=num_batch):
        batch_size = inputs.size(0)
        seen_sample += batch_size

        # Labels presumably arrive as [B, 1, H, W]. squeeze(1) drops only the
        # channel dim; a bare squeeze() would also drop the batch dim when the
        # last batch holds a single sample.
        inputs, targets = inputs.to(device), targets.squeeze(1).to(device)

        if isinstance(model, BiSeNetWithHRDA):
            lr_img, lr_label, hr_img, hr_label, coords = model.hrda_crop(normalize(inputs), targets)
            outputs, (lr_out, hr_out) = model.hrda_forward(lr_img, hr_img, coords)
            loss = model.hrda_loss(criterion, outputs, targets, lr_out, lr_label, hr_out, hr_label)
        else:
            outputs, cx1_sup, cx2_sup = model(normalize(inputs))
            if cx1_sup is not None and cx2_sup is not None:
                loss = criterion(outputs, cx1_sup, cx2_sup, targets)
            else:
                loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predicted = outputs.argmax(1)

        hist_batch = fast_hist_cuda(targets.view(-1).detach(), predicted.view(-1).detach(), n_classes)

        train_loss += loss.item() * batch_size
        train_hist += hist_batch

        if ((batch_idx+1) % chunk_batch) == 0:
            iou_batch = per_class_iou_cuda(hist_batch)
            batch_mIou = 100.*iou_batch[iou_batch > 0].mean().item()
            if ENABLE_PRINT:
                print(f'Training [{seen_sample}/{num_sample} ({100. * seen_sample / num_sample:.0f}%)]')
                print(f'\tLoss: {loss.item():.6f}')
                print(f"\tmIoU: {batch_mIou:.4f}")

            if ENABLE_WANDB_LOG:
                wandb.log({
                        "train/step": train_step,
                        "train/batch_loss": loss.item(),
                        "train/batch_mIou": batch_mIou
                    },
                    commit=True,
                )
                train_step += 1

    train_loss = train_loss / seen_sample

    train_iou_class = per_class_iou_cuda(train_hist)
    train_mIou = train_iou_class[train_iou_class > 0].mean().item()

    return train_loss, train_mIou, train_iou_class, train_hist


### Train step 4

In [None]:
def train_4A(model:nn.Module, source_loader:DataLoader, target_loader:DataLoader , criterion:FDAloss, optimizer:optim.Optimizer, beta: float = 0.01, n:int = 1, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -> tuple[float, float, torch.Tensor, torch.Tensor]:
    """Run one FDA training epoch (step 4A) over paired source/target batches.

    Each (source, target) batch pair is transformed with
    ``FDA(inputs_src, inputs_tgt, beta, n, mean, std)`` and both results are
    fed to ``model``. ``criterion(source_logits, source_labels, target_logits)``
    is applied to the main output and to every auxiliary output that is not
    None. Iteration stops at the end of the shorter loader (``zip``).

    Returns:
        ``(train_loss, train_mIou, train_iou_class, train_hist)``. This is the
        same order as the other training loops and the order the pipeline
        unpacks (it was previously ``hist`` before ``iou_class``).
    """
    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global train_step
    global log_per_epoch

    model.train()

    num_batch = len(source_loader)
    chunk_batch = num_batch//log_per_epoch+1

    num_sample = len(source_loader.dataset)
    seen_sample = 0

    train_loss = 0.0
    train_hist = torch.zeros((n_classes,n_classes), device=device)

    paired_batches = enumerate(zip(source_loader, target_loader))
    # Each zip item is a (source_batch, target_batch) pair, so both dataset
    # tuples must be unpacked inside one parenthesised group.
    for batch_idx, ((inputs_src, _, targets_src), (inputs_tgt, _, _)) in tqdm(paired_batches, total=min(len(source_loader), len(target_loader))):
        inputs_src, inputs_tgt = FDA(inputs_src, inputs_tgt, beta, n, mean, std)

        batch_size = inputs_src.size(0)
        seen_sample += batch_size

        inputs_src, targets_src = inputs_src.to(device), targets_src.squeeze(1).to(device)
        inputs_tgt = inputs_tgt.to(device)

        outputs_src, cx1_sup_src, cx2_sup_src = model(inputs_src)
        outputs_tgt, cx1_sup_tgt, cx2_sup_tgt = model(inputs_tgt)

        loss = criterion(outputs_src, targets_src, outputs_tgt)
        # Auxiliary outputs may be None; add only the terms that exist.
        for aux_src, aux_tgt in ((cx1_sup_src, cx1_sup_tgt), (cx2_sup_src, cx2_sup_tgt)):
            if aux_src is not None and aux_tgt is not None:
                loss = loss + criterion(aux_src, targets_src, aux_tgt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predicted_src = outputs_src.argmax(1)

        assert(predicted_src.shape == targets_src.shape)
        hist_batch = fast_hist_cuda(targets_src.view(-1).detach(), predicted_src.view(-1).detach(), n_classes)

        train_loss += loss.item() * batch_size
        train_hist += hist_batch

        if ((batch_idx+1) % chunk_batch) == 0:
            iou_batch = per_class_iou_cuda(hist_batch)
            batch_mIou = 100.*iou_batch[iou_batch > 0].mean().item()
            if ENABLE_PRINT:
                print(f'Training [{seen_sample}/{num_sample} ({100. * seen_sample / num_sample:.0f}%)]')
                print(f'\tLoss: {loss.item():.6f}')
                print(f"\tmIoU: {batch_mIou:.4f}")

            if ENABLE_WANDB_LOG:
                wandb.log({
                        "train/step": train_step,
                        "train/batch_loss": loss.item(),
                        "train/batch_mIou": batch_mIou
                    },
                    commit=True,
                )
                train_step += 1

    train_loss = train_loss / seen_sample

    train_iou_class = per_class_iou_cuda(train_hist)
    train_mIou = train_iou_class[train_iou_class > 0].mean().item()

    return train_loss, train_mIou, train_iou_class, train_hist

def train_4B(model:nn.Module, source_loader:DataLoader, target_loader:DataLoader, criterion:DACSloss, optimizer:optim.Optimizer, classmixer:ClassMixer) -> tuple[float, float, torch.Tensor, torch.Tensor]:
    """Run one DACS training epoch (step 4B) over paired source/target batches.

    For each pair:
    - pseudo-labels are predicted for the target batch without gradients;
    - ``classmixer`` builds a mixed image/label;
    - ``criterion`` combines the source and mixed predictions into
      ``(loss, lambda_)``.

    Returns:
        ``(train_loss, train_mIou, train_iou_class, train_hist)``.
    """
    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global train_step
    global log_per_epoch

    model.train()

    num_batch = len(source_loader)
    chunk_batch = num_batch//log_per_epoch+1

    # NOTE(review): seen_sample counts source + target images, while
    # num_sample is the source dataset size, so the printed % can exceed 100.
    num_sample = len(source_loader.dataset)
    seen_sample = 0

    train_loss = 0.0
    train_hist = torch.zeros((n_classes,n_classes), device=device)

    for batch_idx, ((inputs_src, _, label_source), (inputs_target, _, _)) in tqdm(enumerate(zip(source_loader, target_loader)), total=min(len(source_loader), len(target_loader))):
        # squeeze(1) (not squeeze()) so a single-sample batch keeps its batch dim.
        inputs_src, label_source = inputs_src.to(device), label_source.squeeze(1).to(device)
        inputs_target = inputs_target.to(device)

        inputs_target = target_augmentation(inputs_target)

        inputs_src, label_source = randomcrop(inputs_src, label_source)
        inputs_target, _ = randomcrop(inputs_target, None)

        B = inputs_src.size(0) + inputs_target.size(0)
        seen_sample += B

        # Pseudo-labels for the target batch (first model output, no gradients).
        with torch.no_grad():
            outputs_target = model(normalize(inputs_target))[0]
        label_target = outputs_target.argmax(1)

        inputs_mixed, label_mixed = classmixer(inputs_src, inputs_target, label_source, label_target)

        inputs_src = normalize(source_augmentation(inputs_src))
        inputs_mixed = normalize(mixed_augmentation(inputs_mixed))

        outputs_source, cx1_sup_src, cx2_sup_src = model(inputs_src.detach())
        outputs_mixed, cx1_sup_mixed, cx2_sup_mixed = model(inputs_mixed.detach())

        # Loss calculation:
        loss, lambda_ = criterion(outputs_source, cx1_sup_src, cx2_sup_src, label_source, outputs_mixed, cx1_sup_mixed, cx2_sup_mixed, label_mixed)

        predicted_source = outputs_source.argmax(1)
        predicted_mixed = outputs_mixed.argmax(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        y = torch.cat((label_source, label_mixed)).reshape(-1)
        y_pred = torch.cat((predicted_source, predicted_mixed)).reshape(-1)

        hist_batch = fast_hist_cuda(y, y_pred, n_classes)

        train_loss += loss.item() * B
        train_hist += hist_batch

        if ((batch_idx+1) % chunk_batch) == 0:
            iou_batch = per_class_iou_cuda(hist_batch)
            batch_mIou = 100.*iou_batch[iou_batch > 0].mean().item()
            if ENABLE_PRINT:
                print(f'Training [{seen_sample}/{num_sample} ({100. * seen_sample / num_sample:.0f}%)]')
                print(f'\tLoss: {loss.item():.6f}')
                print(f"\tmIoU: {batch_mIou:.4f}")
                print(f"\tunsupervised_confidence: {lambda_}")

            if ENABLE_WANDB_LOG:
                wandb.log({
                        "train/step": train_step,
                        "train/batch_loss": loss.item(),
                        "train/batch_mIou": batch_mIou,
                        "train/unsupervised_confidence": lambda_
                    }
                )

                train_step += 1

    train_loss = train_loss / seen_sample

    train_iou_class = per_class_iou_cuda(train_hist)
    train_mIou = train_iou_class[train_iou_class > 0].mean().item()

    return train_loss, train_mIou, train_iou_class, train_hist


### Train step 5

In [None]:
def train5_1(model:BiSeNetWithHRDA, source_loader:DataLoader, criterion:nn.CrossEntropyLoss, optimizer:optim.Optimizer) -> tuple[float, float, torch.Tensor, torch.Tensor]:
    """Run one supervised HRDA training epoch (step 5.1) on the source loader.

    Each batch is randomly cropped, augmented and normalized, then passed
    through ``model.hrda_crop``/``hrda_forward``; the loss comes from
    ``model.hrda_loss(criterion, ...)``.

    Returns:
        ``(train_loss, train_mIou, train_iou_class, train_hist)``.
    """
    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global train_step
    global log_per_epoch

    model.train()

    num_batch = len(source_loader)
    chunk_batch = num_batch//log_per_epoch+1

    num_sample = len(source_loader.dataset)
    seen_sample = 0

    train_loss = 0.0
    train_hist = torch.zeros((n_classes,n_classes), device=device)

    for batch_idx, (inputs_src, _, label_source) in tqdm(enumerate(source_loader), total=num_batch):
        inputs_src, label_source = randomcrop(inputs_src, label_source)

        B = inputs_src.size(0)
        seen_sample += B

        # squeeze(1) (not squeeze()) so a single-sample batch keeps its batch dim.
        inputs_src, label_source = inputs_src.to(device), label_source.squeeze(1).to(device)

        inputs_src = normalize(source_augmentation(inputs_src))

        lr_img_src, lr_label_src, hr_img_src, hr_label_src, coords_src = model.hrda_crop(inputs_src, label_source)
        outputs_source, (lr_out_src, hr_out_src) = model.hrda_forward(lr_img_src, hr_img_src, coords_src)

        loss = model.hrda_loss(criterion, outputs_source, label_source, lr_out_src, lr_label_src, hr_out_src, hr_label_src)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predicted_source = outputs_source.argmax(1)

        y = label_source.detach().reshape(-1)
        y_pred = predicted_source.detach().reshape(-1)

        hist_batch = fast_hist_cuda(y, y_pred, n_classes)

        train_loss += loss.item() * B
        train_hist += hist_batch

        if ((batch_idx+1) % chunk_batch) == 0:
            iou_batch = per_class_iou_cuda(hist_batch)
            batch_mIou = 100.*iou_batch[iou_batch > 0].mean().item()
            if ENABLE_PRINT:
                print(f'Training [{seen_sample}/{num_sample} ({100. * seen_sample / num_sample:.0f}%)]')
                print(f'\tLoss: {loss.item():.6f}')
                print(f"\tmIoU: {batch_mIou:.4f}")

            if ENABLE_WANDB_LOG:
                wandb.log({
                        "train/step": train_step,
                        "train/batch_loss": loss.item(),
                        "train/batch_mIou": batch_mIou,
                    }
                )

                train_step += 1

    train_loss = train_loss / seen_sample

    train_iou_class = per_class_iou_cuda(train_hist)
    train_mIou = train_iou_class[train_iou_class > 0].mean().item()

    return train_loss, train_mIou, train_iou_class, train_hist



In [None]:
def train5_2(model:BiSeNetWithHRDA, source_loader:DataLoader, target_loader:DataLoader, criterion:nn.CrossEntropyLoss, optimizer:optim.Optimizer, classmixer:ClassMixer, threshold: float = 0.968) -> tuple[float, float, torch.Tensor, torch.Tensor]:
    """Run one HRDA + DACS-style training epoch (step 5.2).

    Pseudo-labels for the target batch are predicted without gradients and
    mixed with the source batch by ``classmixer``. The loss is
    ``hrda_loss(source) + lambda_ * hrda_loss(mixed)``, where ``lambda_`` is
    the fraction of mixed-prediction pixels whose max softmax probability is
    at least ``threshold``.

    Args:
        threshold: confidence threshold (default 0.968, from
            https://github.com/vikolss/DACS/blob/cc6a87f23b1c81ae32edad767b9772258774a974/trainUDA.py#L467).

    Returns:
        ``(train_loss, train_mIou, train_iou_class, train_hist)``.
    """
    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global train_step
    global log_per_epoch

    model.train()

    num_batch = len(source_loader)
    chunk_batch = num_batch//log_per_epoch+1

    num_sample = len(source_loader.dataset)
    seen_sample = 0

    train_loss = 0.0
    train_hist = torch.zeros((n_classes,n_classes), device=device)

    for batch_idx, ((inputs_src, _, label_source), (inputs_target, _, _)) in tqdm(enumerate(zip(source_loader, target_loader)), total=min(len(source_loader), len(target_loader))):
        # squeeze(1) (not squeeze()) so a single-sample batch keeps its batch dim.
        inputs_src, label_source = inputs_src.to(device), label_source.squeeze(1).to(device)
        inputs_target = inputs_target.to(device)

        inputs_target = target_augmentation(inputs_target)

        inputs_src, label_source = randomcrop(inputs_src, label_source)
        inputs_target, _ = randomcrop(inputs_target, None)

        B = inputs_src.size(0) + inputs_target.size(0)
        seen_sample += B

        with torch.no_grad():
            outputs_target = model(normalize(inputs_target))[0]
        label_target = outputs_target.argmax(1)

        inputs_mixed, label_mixed = classmixer(inputs_src, inputs_target, label_source, label_target)

        inputs_src = normalize(source_augmentation(inputs_src))
        inputs_mixed = normalize(mixed_augmentation(inputs_mixed))

        lr_img_src, lr_label_src, hr_img_src, hr_label_src, coords_src = model.hrda_crop(inputs_src, label_source)
        outputs_source, (lr_out_src, hr_out_src) = model.hrda_forward(lr_img_src, hr_img_src, coords_src)

        lr_img_m, lr_label_m, hr_img_m, hr_label_m, coords_m = model.hrda_crop(inputs_mixed, label_mixed)
        outputs_m, (lr_out_m, hr_out_m) = model.hrda_forward(lr_img_m, hr_img_m, coords_m)

        # Loss calculation:
        l1 = model.hrda_loss(criterion, outputs_source, label_source, lr_out_src, lr_label_src, hr_out_src, hr_label_src)
        l2 = model.hrda_loss(criterion, outputs_m, label_mixed, lr_out_m, lr_label_m, hr_out_m, hr_label_m)

        # lambda_ is used as a plain float weight: no autograd graph needed.
        with torch.no_grad():
            max_probs, predicted_m = torch.softmax(outputs_m, 1).max(1)
            confident = (max_probs >= threshold).float()
            lambda_ = confident.mean().item() if confident.numel() > 0 else 0.0
        predicted_src = torch.argmax(outputs_source, 1)

        loss = l1 + lambda_*l2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        y = torch.cat((label_source, label_mixed)).reshape(-1)
        y_pred = torch.cat((predicted_src, predicted_m)).reshape(-1)

        hist_batch = fast_hist_cuda(y, y_pred, n_classes)

        train_loss += loss.item() * B
        train_hist += hist_batch

        if ((batch_idx+1) % chunk_batch) == 0:
            iou_batch = per_class_iou_cuda(hist_batch)
            batch_mIou = 100.*iou_batch[iou_batch > 0].mean().item()
            if ENABLE_PRINT:
                print(f'Training [{seen_sample}/{num_sample} ({100. * seen_sample / num_sample:.0f}%)]')
                print(f'\tLoss: {loss.item():.6f}')
                print(f"\tmIoU: {batch_mIou:.4f}")
                print(f"\tunsupervised_confidence: {lambda_}")

            if ENABLE_WANDB_LOG:
                wandb.log({
                        "train/step": train_step,
                        "train/batch_loss": loss.item(),
                        "train/batch_mIou": batch_mIou,
                        "train/unsupervised_confidence": lambda_
                    }
                )

                train_step += 1

    train_loss = train_loss / seen_sample

    train_iou_class = per_class_iou_cuda(train_hist)
    train_mIou = train_iou_class[train_iou_class > 0].mean().item()

    return train_loss, train_mIou, train_iou_class, train_hist


## Validation loop

In [None]:
def validate(model:nn.Module, val_loader:DataLoader, criterion:nn.Module) -> tuple[float, float, torch.Tensor, torch.Tensor]:
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Every ``num_batch // log_per_epoch + 1`` batches (and once more for any
    leftover batches), the statistics of the batches since the previous log
    point are printed and/or logged to wandb under ``validate/*``.

    Returns:
        ``(val_loss, val_mIou, val_iou_class, val_hist)``: mean loss per
        sample, mean IoU over classes with IoU > 0, per-class IoU and the
        accumulated ``fast_hist_cuda`` histogram.
    """
    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global val_step
    global log_per_epoch

    model.eval()

    num_batch = len(val_loader)
    chunk_batch = num_batch//log_per_epoch+1
    num_sample = len(val_loader.dataset)

    def _log_chunk(seen, chunk_loss_sum, chunk_count, hist):
        """Print/log the statistics of one chunk of batches."""
        global val_step
        iou_chunk = per_class_iou_cuda(hist)
        chunk_mIou = 100.*iou_chunk[iou_chunk > 0].mean().item()
        chunk_mean_loss = chunk_loss_sum / chunk_count
        if ENABLE_PRINT:
            print(f'Validation [{seen}/{num_sample} ({100. * seen / num_sample:.0f}%)]')
            print(f'\tLoss: {chunk_mean_loss:.6f}')
            print(f"\tmIoU: {chunk_mIou:.4f}")
        if ENABLE_WANDB_LOG:
            wandb.log({
                    "validate/step": val_step,
                    "validate/batch_loss": chunk_mean_loss,
                    "validate/batch_mIou": chunk_mIou
                }
            )
            val_step += 1

    seen_sample = 0
    val_loss = 0.0

    # Statistics of the batches since the last log point.
    chunk_sample = 0
    chunk_loss = 0.0

    with torch.no_grad():
        val_hist = torch.zeros((n_classes,n_classes), device=device)
        chunk_hist = torch.zeros((n_classes,n_classes), device=device)

        for batch_idx, (inputs, _, targets) in tqdm(enumerate(val_loader), total=num_batch):
            # NOTE(review): a random crop during validation makes the scores
            # depend on the RNG state — confirm this is intended.
            inputs, targets = randomcrop(inputs, targets)

            B = inputs.size(0)

            # squeeze(1) (not squeeze()) so a single-sample batch keeps its batch dim.
            inputs, targets = inputs.to(device), targets.squeeze(1).to(device)

            if isinstance(model, BiSeNetWithHRDA):
                lr_img, _, hr_img, _, coords = model.hrda_crop(normalize(inputs), targets)
                outputs, _ = model.hrda_forward(lr_img, hr_img, coords)
            else:
                outputs = model(normalize(inputs))
                # Some models return (main, aux, ...) tuples; evaluate the main output.
                if isinstance(outputs, (tuple, list)):
                    outputs = outputs[0]

            loss = criterion(outputs, targets)

            predicted = outputs.argmax(1)

            hist_batch = fast_hist_cuda(targets.reshape(-1), predicted.reshape(-1), n_classes)

            seen_sample += B
            val_loss += loss.item() * B
            val_hist += hist_batch

            chunk_sample += B
            chunk_loss += loss.item() * B
            chunk_hist += hist_batch

            if ((batch_idx+1) % chunk_batch) == 0:
                _log_chunk(seen_sample, chunk_loss, chunk_sample, chunk_hist)
                chunk_sample = 0
                chunk_loss = 0.0
                chunk_hist.zero_()

        # Flush the batches after the last full chunk.
        if chunk_sample > 0:
            _log_chunk(seen_sample, chunk_loss, chunk_sample, chunk_hist)

    val_loss = val_loss / seen_sample

    val_iou_class = per_class_iou_cuda(val_hist)
    val_mIou = val_iou_class[val_iou_class > 0].mean().item()

    return val_loss, val_mIou, val_iou_class, val_hist

# Machine learning

In [None]:
def pipeline():
    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global train_step
    global val_step
    global log_per_epoch

    ENABLE_PRINT = False
    ENABLE_WANDB_LOG = True
    train_step = 0
    val_step = 0
    log_per_epoch = 20

    models_root_dir = "./models"
    # !rm -r {models_root_dir}
    # !rd /s /q {models_root_dir}
    !mkdir {models_root_dir}

    B = 4
    n_classes = 19

    backbone = "..."
    context_path = "..."
    dataset = "..."

    start_epoch = 0 # <--------- Last epoch that has been completed (in this script i will perform from the start+1 to end)
    end_epoch = 50
    max_epoch = 50

    assert start_epoch < end_epoch <= max_epoch, "Check your start/end/max epoch settings."

    init_lr=2.5e-2
    lr_decay_iter = 1
    momentum=0.9
    weight_decay=5e-4
    n=2
    beta = 0.01
    lambda_entropy = 0.01
    eta = 2
    s = 0.5
    rcs = False

    # Dataset objects
    if dataset == "Cityscapes":
        data_train = CityScapes("./Cityscapes/Cityspaces", split="train", transform=img_transform_target, target_transform=label_transform_target)
        data_val = CityScapes("./Cityscapes/Cityspaces", split="val", transform=img_transform_target, target_transform=label_transform_target)

        train_loader = DataLoader(data_train, batch_size=B, shuffle=True)
        val_loader = DataLoader(data_val, batch_size=B, shuffle=True)
    elif dataset == "GTA5":
        data_train, data_val = GTA5_dataset_splitter("./Gta5_extended", train_split_percent=0.8, split_seed=42, transform=img_transform_source, target_transform=label_transform_source)
        
        train_loader = DataLoader(data_train, batch_size=B, shuffle=True)
        val_loader = DataLoader(data_val, batch_size=B, shuffle=True)
    elif dataset == "Augmentation":
        data_train, _ = GTA5_dataset_splitter("./Gta5_extended", train_split_percent=1, split_seed=42, augment=True, rcs=rcs, transform=img_transform_source, target_transform=label_transform_source)
        data_val = CityScapes("./Cityscapes/Cityspaces", split="val", transform=img_transform_target, target_transform=label_transform_target)

        train_loader = DataLoader(data_train, batch_size=B, shuffle=True)
        val_loader = DataLoader(data_val, batch_size=B, shuffle=True)
    elif dataset == "Mixed":
        city_val = CityScapes("./Cityscapes/Cityspaces", split="val", transform=img_transform_target, target_transform=label_transform_target)
        city_train = CityScapes("./Cityscapes/Cityspaces", split="train", transform=img_transform_target, target_transform=label_transform_target)
        gta_train, _ = GTA5_dataset_splitter("./Gta5_extended", 1.0, split_seed=42, augment=False, rcs=rcs, transform=img_transform_source, target_transform=label_transform_source)

        source_loader = DataLoader(gta_train, batch_size=B, shuffle=True)
        target_loader = DataLoader(city_train, batch_size=B, shuffle=True)
        val_loader = DataLoader(city_val, batch_size=B, shuffle=True)
    else:
        raise Exception("Wrong dataset name")

    # Architecture
    if backbone == "BiSeNet":
        model = BiSeNet(n_classes, context_path).to(device)
        architecture = backbone+"-"+context_path
    elif backbone == "DeepLab":
        model = get_deeplab_v2(num_classes=n_classes, pretrain=True).to(device)
        architecture = backbone
    elif backbone == "BiSeNetHRDA":
        model = BiSeNetWithHRDA(n_classes, context_path, s).to(device)
        architecture = backbone+"-"+context_path
    else:
        raise Exception("Wrong model name")

    # print(model)
    
    classmixer = ClassMixer(n_classes, 0.5, device=device)
    
    criterion_train = nn.CrossEntropyLoss(ignore_index=255)
    # criterion_train = FDAloss(lambda_entropy=lambda_entropy, eta=eta, ignore_index=255)
    # criterion_train = BiSeNetloss()
    # criterion_train = DACSloss()

    criterion_val = nn.CrossEntropyLoss(ignore_index=255)

    optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)

    run_name = f"..."
    run_detail = f"{run_name}_{architecture}_{dataset}"
    run_id = None
    
    # Wandb setup and metrics
    if ENABLE_WANDB_LOG:
        run = wandb.init(
            entity="Machine_learning_and_Deep_learning_labs",
            project="Semantic Segmentation",
            name=run_name,
            id=run_id, # <--------------- If run already created (start_epoch > 0) decomment
            resume="allow", # <----------------  IMPORTANT CONFIG KEY
            config={
                "initial_learning_rate": init_lr,
                "lr_decay_iter": lr_decay_iter,
                "momentum": momentum,
                "weight_decay": weight_decay,
                "architecture": architecture,
                "dataset": dataset,
                "start_epoch": start_epoch,
                "end_epoch": end_epoch,
                "max_epoch": max_epoch,
                "batch": B,
                "lr_scheduler": "poly"
            },
        )
        if run_id is None:
            print(f"\nThe ID for this run is: {run.id} save this if first time creating the run, then use it in run creation: wandb.init(..., id = {run.id}, ...)\n")
        
        wandb.define_metric("epoch/step")
        wandb.define_metric("epoch/*", step_metric="epoch/step")

        wandb.define_metric("train/step")
        wandb.define_metric("train/*", step_metric="train/step")

        wandb.define_metric("validate/step")
        wandb.define_metric("validate/*", step_metric="validate/step")

    # Loading from a starting point
    if start_epoch > 0:
        artifact_path = os.path.join(models_root_dir, f"{run_name}_epoch_{start_epoch}.pth")

        checkpoint = torch.load(artifact_path, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        train_step = checkpoint["train_step"]+1
        val_step = checkpoint["validate_step"]+1

    # Main Loop
    for epoch in range(start_epoch+1, end_epoch+1):
        print("-----------------------------")
        print(f"Epoch {epoch}")

        lr = poly_lr_scheduler(optimizer, init_lr, epoch-1, max_iter=max_epoch)

        print(f"[Poly LR] 100xLR: {100.*lr:.6f}")

        train_loss, train_mIou, train_mIou_class, train_hist = train3(model, train_loader, criterion_train, optimizer)
        train_loss, train_mIou, train_mIou_class, train_hist = train_4A(model, source_loader, target_loader, criterion_train, optimizer, beta, n, mean, std)
        train_loss, train_mIou, train_mIou_class, train_hist = train_4B(model, source_loader, target_loader, criterion_train, optimizer, classmixer)
        train_loss, train_mIou, train_mIou_class, train_hist = train5_1(model, source_loader, criterion_train, optimizer)
        train_loss, train_mIou, train_mIou_class, train_hist = train5_2(model, source_loader, target_loader, criterion_train, optimizer, classmixer)

        print(f'[Train Loss] : {train_loss:.6f} [mIoU]: {100.*train_mIou:.2f}%')
        
        val_loss, val_mIou, val_mIou_class, val_hist = validate(model, val_loader, criterion_val)

        print(f'[Validation Loss] : {val_loss:.6f} [mIoU]: {100.*val_mIou:.2f}%')

        if (epoch % 2) == 0:
            checkpoint = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": train_step,
                "validate_step": val_step,
            }

            file_name = f"{run_name}_epoch_{epoch}.pth"
            
            file_path = os.path.join(models_root_dir, file_name)
            torch.save(checkpoint, file_path)

            print(f"Model saved to {file_path}")


        if epoch % 10 == 0:
            log_confusion_matrix("Confusion Matrix - Train", train_hist.cpu().numpy(), "epoch/train_confusion_matrix", "epoch/step", epoch)
            log_confusion_matrix("Confusion Matrix - Validate", val_hist.cpu().numpy(), "epoch/validate_confusion_matrix", "epoch/step", epoch)

            log_bar_chart_ioU(f"Train IoU per class - epoch {epoch}", [c.name for c in GTA5Labels_TaskCV2017().list_ if c.name != "void"], train_mIou, train_mIou_class.cpu().numpy(), "epoch/train_Iou_class", "epoch/step", epoch)
            log_bar_chart_ioU(f"Validate IoU per class - epoch {epoch}", [c.name for c in GTA5Labels_TaskCV2017().list_ if c.name != "void"], val_mIou, val_mIou_class.cpu().numpy(), "epoch/validate_Iou_class", "epoch/step", epoch)

    if end_epoch == max_epoch:
        mean_latency, std_latency, mean_fps = latency(device, model, H=512, W=1024)

        print(f"[Num Flops]: {num_flops(device, model, 512, 1024)}")
        print(f"[Mean Latency]: {mean_latency}")
        print(f"[Std Latency]: {std_latency}")
        print(f"[Mean FPS]: {mean_fps}")
        print(f"[Num Param]: {num_param(model)}")
        
# Run the pipeline defined above: trains/validates for epochs
# start_epoch+1..end_epoch and saves a checkpoint every 2 epochs.
pipeline()


In [None]:
import wandb
# Close the W&B run opened by pipeline() (when ENABLE_WANDB_LOG was set)
# before later cells call wandb.init() again.
wandb.finish()

# Load another model

In [None]:
def load_model(artifact_name:str, file_prefix:str, device, id=None, epoch=50):
    """Build a BiSeNetWithHRDA model and load its weights from a W&B model artifact.

    Args:
        artifact_name: W&B artifact name; the alias ``epoch_{epoch}`` is appended.
        file_prefix: Prefix of the checkpoint file inside the artifact; the file
            read is ``f"{file_prefix}_epoch_{epoch}.pth"``.
        device: Device the model is moved to and the checkpoint is mapped onto.
        id: Run id forwarded to ``wandb.init(id=..., resume="allow")``.
        epoch: Epoch whose checkpoint/artifact alias is loaded.

    Returns:
        The model with ``checkpoint["model_state_dict"]`` loaded.

    Note:
        The model class is hard-coded below and uses the notebook-level
        ``n_classes``; switch the commented lines to load other architectures.
    """
    # model = BiSeNet(19, "resnet18").to(device)
    # model = get_deeplab_v2(num_classes=n_classes, pretrain=True).to(device)
    model = BiSeNetWithHRDA(n_classes, "resnet18", 0.5).to(device)

    run = wandb.init(id=id, resume="allow")
    try:
        artifact = run.use_artifact(
            f'Machine_learning_and_Deep_learning_labs/Semantic Segmentation/{artifact_name}:epoch_{epoch}',
            type='model',
        )
        # artifact = run.use_artifact(f'Machine_learning_and_Deep_learning_labs/Semantic Segmentation/{artifact_name}:v72', type='model')
        artifact_dir = artifact.download()
        # artifact_dir = "./artifacts/step_3B_BiSeNet-resnet18_GTA5-v42"

        artifact_path = os.path.join(artifact_dir, f"{file_prefix}_epoch_{epoch}.pth")
        checkpoint = torch.load(artifact_path, map_location=device)
    finally:
        # Close the run even if the lookup/download/load raises, so the
        # notebook is not left attached to an open run.
        run.finish()

    model.load_state_dict(checkpoint["model_state_dict"])

    return model


In [None]:
# Identify the trained run to restore. The artifact name follows the
# f"{run_name}_{architecture}_{dataset}" pattern used when training, with
# architecture = backbone-context_path.
device = "cuda"

run_name = "step_5_hrda_dacs_rcs"
backbone, context_path = "BiSeNetHRDA", "resnet18"
dataset, epoch = "Mixed", 50

model = load_model(
    artifact_name=f"{run_name}_{backbone}-{context_path}_{dataset}",
    file_prefix=run_name,
    device=device,
    id="zfe3iniu",
    epoch=epoch,
)
model.eval()  # switch to evaluation mode before inference
None  # last expression of the cell: keeps the notebook from echoing the model

In [None]:
B = 4
H = 512
W = 1024

# Alternative transforms, used by the commented-out dataset variants below.
transform = v2.Compose([
    v2.ToTensor(),
    v2.Resize((H,W)),
    # v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
target_transform = v2.Compose([
    v2.ToTensor(),
    # NEAREST so that label ids are never blended by interpolation.
    v2.Resize((H, W), interpolation=v2.InterpolationMode.NEAREST),
])

# Test-time perturbation applied to the images before inference
# (identity by default; uncomment entries to probe robustness).
# s = 0.5
custom_augmentation = v2.Compose([
    nn.Identity(),
    # v2.GaussianBlur(7),
    # v2.ColorJitter(brightness=s, contrast=s, saturation=s, hue=s),
    # v2.GaussianNoise(mean=0, sigma=0.05)
])

# data = CityScapes("./Cityscapes/Cityspaces", split="train", transform=transform, target_transform=target_transform)
data = CityScapes("./Cityscapes/Cityspaces", split="val", transform=img_transform_target, target_transform=label_transform_target)
# data, _ = GTA5_dataset_splitter("./Gta5_extended", train_split_percent=1.0, split_seed=42, augment=False, transform=transform, target_transform=target_transform)

dataloader = DataLoader(data, batch_size=B, shuffle=True)
# dataloader = DataLoader(data, batch_size=B, shuffle=False)

criterion = nn.CrossEntropyLoss(ignore_index=255)

ENABLE_WANDB_LOG = False
# print(validate(model, dataloader, criterion))

img_tensor, color_tensor, label = next(iter(dataloader))

# img_tensor, label = randomcrop(img_tensor, label)

img_tensor = custom_augmentation(img_tensor).to(device)

# img_tensor = img_tensor.clamp(0,1)

# lr_img, lr_label, hr_img, hr_label, coords = model.hrda_crop(img_tensor, label.squeeze().to(device))
# fused, (lr_out, hr_out) = model.hrda_forward(lr_img, hr_img, coords)

# Pure inference: skip building an autograd graph (saves GPU memory).
with torch.no_grad():
    # fused = model(normalize(img_tensor.detach()))
    # img_tensor is already on `device`, so no extra .to(device) is needed.
    fused, lr_out, hr_out, coords = model.hrda_eval(normalize(img_tensor))
    # fused, lr_out, hr_out, coords = model.hrda_eval(img_tensor)

predicted_labels = fused.argmax(1).cpu()
lr_predicted = lr_out.argmax(1).cpu()
hr_predicted = hr_out.argmax(1).cpu()

# Window returned by hrda_eval, unpacked as (x1, y1, width, height); the
# full-size maps are cropped to it so all panels show the same region.
x1, y1, cw, ch = coords
crop_rows, crop_cols = slice(y1, y1 + ch), slice(x1, x1 + cw)

# Use the real batch size: the drawn batch may hold fewer than B images.
for i in range(predicted_labels.shape[0]):
    predicted_colors = decode_segmap(predicted_labels[i].numpy())
    lr_colors = decode_segmap(lr_predicted[i].numpy())
    hr_colors = decode_segmap(hr_predicted[i].numpy())
    true_colors = decode_segmap(label[i, 0].detach().cpu().numpy())

    fig, axes = plt.subplot_mosaic(
        """
            AB
            CD
        """,
        figsize=(7,7), layout="tight"
    )
    # fig, ax = plt.subplots(2,2, figsize=(10,10), layout="tight")

    panels = [
        ("A", "Original segmentation map", true_colors[crop_rows, crop_cols]),
        ("B", "Predicted segmentation map", predicted_colors[crop_rows, crop_cols]),
        ("C", "Low resolution", lr_colors[crop_rows, crop_cols]),
        # Shown uncropped; presumably hr_out already covers only the HR
        # window — TODO confirm against hrda_eval.
        ("D", "High resolution", hr_colors),
    ]
    for key, panel_title, panel_image in panels:
        axes[key].set_title(panel_title)
        axes[key].imshow(panel_image)
        axes[key].axis('off')

    fig.show()


In [None]:
def log_confusion_matrix2(title:str, hist:np.ndarray):
    """Plot a row-normalized confusion matrix and print diagonal-based metrics.

    Rows of ``hist`` are plotted as "True" and columns as "Predicted". Each row
    is divided by its total and drawn as a percentage heatmap; rows whose total
    is zero are drawn as all zeros.

    Printed metrics:
        * overall accuracy: trace / total (0 when the matrix is empty),
        * mean per-class accuracy: mean of diag/row_total, ignoring empty rows,
        * diagonal dominance: ``1 - ||hist - diag(hist)||_F / ||hist||_F``.

    Args:
        title: Figure title.
        hist: Square confusion matrix of counts.

    Returns:
        dict with keys ``overall_acc``, ``mean_class_acc`` and
        ``diag_dominance_score`` (the same values that are printed).
    """
    hist = np.asarray(hist, dtype=np.float64)
    row_sums = hist.sum(axis=1, keepdims=True)
    # Divide only non-empty rows rather than computing 0/0 and masking the
    # NaNs afterwards (which relied on warnings being silenced globally).
    safe_hist = np.divide(hist, row_sums, out=np.zeros_like(hist), where=row_sums != 0)

    plt.figure(figsize=(10, 8))
    sns.heatmap(100.*safe_hist, fmt=".2f", annot=True, cmap="Blues", annot_kws={'size': 7})
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)

    # --- Diagonal metrics ---
    diag = np.diag(hist)
    total = hist.sum()
    overall_acc = diag.sum() / total if total > 0 else 0
    row_totals = row_sums[:, 0]
    # Empty rows yield NaN so that nanmean skips classes with no samples.
    per_class_acc = np.divide(diag, row_totals, out=np.full_like(diag, np.nan), where=row_totals != 0)
    mean_class_acc = np.nanmean(per_class_acc)

    # Frobenius-norm-based diagonal dominance score (1 = purely diagonal).
    frob_total = np.linalg.norm(hist)
    frob_offdiag = np.linalg.norm(hist - np.diagflat(diag))
    diag_dominance_score = 1 - (frob_offdiag / frob_total) if frob_total > 0 else 0

    print(f"\n📊 Overall Accuracy: {overall_acc:.4f}")
    print(f"📈 Mean Per-Class Accuracy: {mean_class_acc:.4f}")
    print(f"📐 Diagonal Dominance (Frobenius-based): {diag_dominance_score:.4f}")

    return {
        "overall_acc": overall_acc,
        "mean_class_acc": mean_class_acc,
        "diag_dominance_score": diag_dominance_score,
    }

def log_bar_chart_ioU2(title:str, class_names:list, mIou:float, iou_class:np.ndarray):
    """Plot a bar chart of the mean IoU followed by the per-class IoU values.

    Inputs are fractions; bars show percentages rounded to two decimals and
    each bar is annotated with its value.

    Args:
        title: Figure title.
        class_names: Labels of the per-class bars, in the same order as
            ``iou_class`` (any sequence; a tuple works too).
        mIou: Mean IoU, drawn as the first bar labelled "mIoU".
        iou_class: Per-class IoU values.
    """
    all_labels = ["mIoU"] + list(class_names)
    all_values = [round(mIou*100., 2)] + [round(iou*100., 2) for iou in iou_class]

    plt.figure(figsize=(14, 5))
    bars = plt.bar(range(len(all_values)), all_values, color='skyblue')
    plt.xticks(range(len(all_labels)), all_labels, rotation=45, ha="right")

    # Write each value just above its bar.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2.0, height + 1, f'{height:.2f}',
                 ha='center', va='bottom', fontsize=9)

    plt.ylabel("IoU (%)")
    plt.ylim(0, 105)
    plt.title(title)
    plt.tight_layout()

# Both flags are presumably read by validate() (console printing / W&B
# logging) — disable them for this standalone evaluation.
ENABLE_PRINT = ENABLE_WANDB_LOG = False

# One validation pass of the loaded model over `dataloader`.
loss, val_mIou, val_mIou_class, val_hist = validate(
    model, dataloader, criterion
)

In [None]:
# Report the validation results computed by the previous cell.
print(f"Loss: {loss}\nmIou: {val_mIou}\nmIou per classe: {val_mIou_class.tolist()}")

log_confusion_matrix2(f"{run_name} - Confusion Matrix - Validate", val_hist.cpu().numpy())

# Class names drop the "void" entry (same filter as in pipeline's logging);
# presumably this aligns them with val_mIou_class — TODO confirm.
log_bar_chart_ioU2(f"{run_name} - Validate IoU per class - epoch {epoch}", [c.name for c in GTA5Labels_TaskCV2017().list_ if c.name != "void"], val_mIou, val_mIou_class.cpu().numpy())

In [None]:
# data = CityScapes("./Cityscapes/Cityspaces", split="train", transform=None, target_transform=None)
# data = CityScapes("./Cityscapes/Cityspaces", split="val", transform=img_transform_target, target_transform=label_transform_target)
data, _ = GTA5_dataset_splitter("./Gta5_extended", train_split_percent=1.0, split_seed=42, augment=False, transform=None, target_transform=None)

# Number of elements of the first raw (untransformed) image as an array.
# np.ndarray(x) interprets x as a *shape* and fails for an image object;
# np.asarray converts it. Use `.shape` instead if the dimensions are wanted.
np.asarray(next(iter(data))[0]).size

# Class frequency calculation

In [None]:
data, _ = GTA5_dataset_splitter("./Gta5_extended", train_split_percent=1.0, split_seed=42, augment=True, transform=v2.ToTensor(), target_transform=v2.ToTensor())
# NOTE(review): augment=True means the counts include augmented samples —
# confirm this is intended for the frequency statistics.
# Visiting order does not change the totals, so there is no need to shuffle.
dataloader = DataLoader(data, batch_size=1, shuffle=False)

# Accumulate in int64: per-class totals reach billions of pixels, far past
# 2**24, the largest range in which float32 stores integers exactly, so a
# float32 accumulator would silently round the counts.
freqs = torch.zeros(n_classes, dtype=torch.long)

for _, _, label in tqdm(dataloader):
    # Drop ids >= n_classes (e.g. the 255 ignore label) before counting;
    # the slice keeps exactly n_classes bins.
    freqs += torch.bincount(label[label < n_classes].view(-1), minlength=n_classes)[:n_classes]

print(freqs)

In [None]:
objects = GTA5Labels_TaskCV2017().list_

# Emit one semicolon-separated row per class: label ID, name, pixel count.
print("id;name;count")
for i in range(n_classes):
    obj, count = objects[i], freqs[i].item()
    print(f"{obj.ID};{obj.name};{count}")

In [None]:
import torch
import matplotlib.pyplot as plt

# Per-class pixel counts, as printed by the class-frequency cell.
l = [2000279424,157799632,586026240,79183408,31111896,52304620,5818070,5331167,313206784,140294832,734290112,5793555,1213679,100922184,54996872,12020944,15980102,1665942,260148]

a = torch.Tensor(l)

a /= a.sum()  # class frequencies as fractions summing to 1

fig, axs= plt.subplots(2,2)

# One panel per temperature: smaller T sharpens the distribution.
T = [[0.015, 0.02], [0.03, 0.06]]
# NOTE(review): `correction` is reset only once, so each panel accumulates the
# corrections of all previously visited temperatures — confirm intended.
correction = torch.zeros(len(l))
for i in range(2):
    for j in range(2):
        # Temperature softmax over (1 - freq): rarer classes get more weight.
        b = torch.softmax((1 - a) / T[i][j], dim=0)
        axs[i,j].plot(b)

        correction += 1-b/0.6
        axs[i,j].plot((b+correction)/(b+correction).sum())

# `b` here is the distribution for the last temperature (T[1][1]).
print(b.sum())
plt.plot(b)  # plots on the current axes (presumably the last subplot)