# Seed initialization (to make results reproducible)

In [None]:
import torch
import random
import numpy as np

# Seed every RNG the notebook relies on so that repeated runs start from the same state.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# Ask PyTorch for deterministic kernels; with warn_only=True an operation without a
# deterministic implementation produces a warning instead of an error.
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
# Fixed typo ("benchark"): the misspelled name only created an unused attribute and
# left the real cuDNN autotuner flag untouched.
torch.backends.cudnn.benchmark = False

def seed_worker(worker_id):
    """Re-seed NumPy and the `random` module from the current torch seed.

    The notebook passes this function as ``worker_init_fn`` to its DataLoaders.

    Args:
        worker_id: Index handed over by the caller; it is not used here.
    """
    # Both libraries expect a 32-bit seed, hence the reduction modulo 2**32.
    seed = torch.initial_seed() % (1 << 32)
    random.seed(seed)
    np.random.seed(seed)

# Dedicated, seeded generator that the DataLoaders receive through their `generator` argument.
g = torch.Generator().manual_seed(0)

# Environment setup (execute for any Step)

## Package install

In [None]:
# Install the extra packages imported by the next cell (wget, requests, gdown, fvcore, torchmetrics).
!pip install wget
!pip install requests gdown
!pip install fvcore
!pip install torchmetrics

## Import

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.patches as mpatches
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import matplotlib.colors as mcolors
from pathlib import Path
import wget
import requests
import gdown
from torchvision import transforms
from torchvision.datasets import VisionDataset
from torch.utils.data import Subset, DataLoader
from enum import Enum
from google.colab import drive
import os
import torch.nn.functional as F
import time
from fvcore.nn import FlopCountAnalysis, flop_count_table
import torch.optim.lr_scheduler as lr_scheduler
from torch.backends import cudnn
from statistics import mean
import cv2
from torchmetrics.segmentation import MeanIoU
from torch.utils.data import ConcatDataset
import random
from torch.nn.utils import clip_grad
import albumentations as A

## Variables

In [None]:
# --- Dataset layout: one folder and one zip archive per split ---
DATA_DIR = 'loveDA_dataset'
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    f'{DATA_DIR}/{split}' for split in ('train', 'validation', 'test')
)
TRAIN_ZIP, VAL_ZIP, TEST_ZIP = (
    f'{split_dir}.zip' for split_dir in (TRAIN_DIR, VAL_DIR, TEST_DIR)
)

# Sub-folder names found inside every split
RURAL_PATH, URBAN_PATH = "Rural", "Urban"
IMG_PATH, MASK_PATH = "images_png", "masks_png"

# --- Pretrained weights ---
PRETRAINED_WEIGHTS_DIR = 'pretrained_weights'
DEEPLAB_V2_WEIGHTS = PRETRAINED_WEIGHTS_DIR + '/DeepLab_resnet_pretrained_imagenet.pth'
STDC1_WEIGHTS = PRETRAINED_WEIGHTS_DIR + "/STDC1_pretrained_weights.pth"

# Label value that the loss is told to skip
IGNORE_INDEX = -1

# PIL mode strings used when loading images and masks
RGB = 'RGB'
grayscale = 'L'

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Domain(Enum):
    """Data domain selector; member names mirror the `RURAL_PATH` / `URBAN_PATH` folder constants."""
    RURAL = 0
    URBAN = 1

class ModelType(Enum):
    """Network families named in this notebook; the metric helpers use it to decide how to call a model."""
    DEEPLAB = 0
    PIDNET = 1
    BISENET = 2
    STDC = 3

# Class name -> (sort key, RGB colour used when plotting masks).
# The sort keys look like the raw mask id divided by 255 (1/255 ... 7/255) - TODO confirm.
# NOTE(review): the LoveDA label spec lists background as id 1 and barren as id 5;
# verify that the BARREN / BG names are not swapped here.
categories = {
    'BARREN': (0.003921568859368563, (159, 129, 183)),       # lilac
    'AGRICULTURE': (0.027450980618596077, (255, 195, 128)),  # orange
    'BUILDING': (0.007843137718737125, (255, 0, 0)),         # red
    'WATER': (0.01568627543747425, (0, 0, 255)),             # blue
    'ROAD': (0.0117647061124444, (255, 255, 0)),             # yellow
    'BG': (0.019607843831181526, (255, 255, 255)),           # white
    'FOREST': (0.0235294122248888, (0, 255, 0))              # green
}

# Re-order by sort key: a class's position in the dict is then used as its label index.
categories = dict(sorted(categories.items(), key=lambda entry: entry[1][0]))

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

num_classes = len(categories)

## Dataset: LoveDA

### Download dataset

#### Without Google Drive

In [None]:
# Disabled cell: the whole download/extract script below is wrapped in a string literal,
# so executing this cell runs none of it.
"""
download_directory = Path(DATA_DIR)
if not download_directory.exists():
    download_directory.mkdir(exist_ok=True)

# Zip download

train_zip = Path(TRAIN_ZIP)
if not train_zip.exists():
    !wget -O {TRAIN_ZIP} 'https://zenodo.org/record/5706578/files/Train.zip?download=1'

val_zip = Path(VAL_ZIP)
if not val_zip.exists():
    !wget -O {VAL_ZIP} 'https://zenodo.org/records/5706578/files/Val.zip?download=1'

test_zip = Path(TEST_ZIP)
if not test_zip.exists():
    !wget -O {TEST_ZIP} 'https://zenodo.org/records/5706578/files/Test.zip?download=1'

# Zip extraction

## I suppose to not cancel the original zip since who knows

train_dir = Path(TRAIN_DIR)
if not train_dir.exists():
    !unzip -q {TRAIN_ZIP} -d {DATA_DIR}
    !mv {DATA_DIR}/Train {TRAIN_DIR}

val_dir = Path(VAL_DIR)
if not val_dir.exists():
    !unzip -q {VAL_ZIP} -d {DATA_DIR}
    !mv {DATA_DIR}/Val {VAL_DIR}

test_dir = Path(TEST_DIR)
if not test_dir.exists():
    !unzip -q {TEST_ZIP} -d {DATA_DIR}
    !mv {DATA_DIR}/Test {TEST_DIR}
"""

#### With Google Drive

In [None]:
def download_to_gdrive():
    drive_path_dir = '/content/drive'
    mydrive_path_dir = f'{drive_path_dir}/MyDrive'
    data_path_dir = f'{mydrive_path_dir}/{DATA_DIR}'
    train_path_zip = f'{mydrive_path_dir}/{TRAIN_ZIP}'
    val_path_zip = f'{mydrive_path_dir}/{VAL_ZIP}'
    test_path_zip = f'{mydrive_path_dir}/{TEST_ZIP}'
    train_path_dir = f'{mydrive_path_dir}/{TRAIN_DIR}'
    val_path_dir = f'{mydrive_path_dir}/{VAL_DIR}'
    test_path_dir = f'{mydrive_path_dir}/{TEST_DIR}'

    from google.colab import drive
    drive.mount(drive_path_dir)

    download_directory = Path(data_path_dir)
    if not download_directory.exists():
        download_directory.mkdir(exist_ok=True)

    train_zip = Path(train_path_zip)
    if not train_zip.exists():
        !wget -O {train_path_zip} 'https://zenodo.org/record/5706578/files/Train.zip?download=1'

    val_zip = Path(val_path_zip)
    if not val_zip.exists():
        !wget -O {val_path_zip} 'https://zenodo.org/records/5706578/files/Val.zip?download=1'

    test_zip = Path(test_path_zip)
    if not test_zip.exists():
        !wget -O {test_path_zip} 'https://zenodo.org/records/5706578/files/Test.zip?download=1'

def extract_from_gdrive():

    drive_path_dir = '/content/drive'
    mydrive_path_dir = f'{drive_path_dir}/MyDrive'
    data_path_dir = f'{mydrive_path_dir}/{DATA_DIR}'
    train_path_zip = f'{mydrive_path_dir}/{TRAIN_ZIP}'
    val_path_zip = f'{mydrive_path_dir}/{VAL_ZIP}'
    test_path_zip = f'{mydrive_path_dir}/{TEST_ZIP}'
    train_path_dir = f'{mydrive_path_dir}/{TRAIN_DIR}'
    val_path_dir = f'{mydrive_path_dir}/{VAL_DIR}'
    test_path_dir = f'{mydrive_path_dir}/{TEST_DIR}'

    from google.colab import drive
    drive.mount(drive_path_dir)

    train_dir = Path(TRAIN_DIR)
    if not train_dir.exists():
        !unzip -q {train_path_zip} -d {DATA_DIR}
        !mv {DATA_DIR}/Train {TRAIN_DIR}

    val_dir = Path(VAL_DIR)
    if not val_dir.exists():
        !unzip -q {val_path_zip} -d {DATA_DIR}
        !mv {DATA_DIR}/Val {VAL_DIR}

    #test_dir = Path(TEST_DIR)
    #if not test_dir.exists():
    #    !unzip -q {test_path_zip} -d {DATA_DIR}
    #    !mv {DATA_DIR}/Test {TEST_DIR}

In [None]:
def copy_to_gdrive():
    """Mount Google Drive and copy the locally stored zip archives onto it.

    Archives go to ``MyDrive/<DATA_DIR>``; one that is already present on Drive
    is left untouched. A line is printed for every archive that gets copied.
    """
    import shutil  # only needed by this helper
    from google.colab import drive

    drive_path_dir = '/content/drive'
    mydrive_path_dir = f'{drive_path_dir}/MyDrive'

    drive.mount(drive_path_dir)

    # Make sure the destination folder exists on Drive.
    os.makedirs(f'{mydrive_path_dir}/{DATA_DIR}', exist_ok=True)

    for local_zip in (TRAIN_ZIP, VAL_ZIP, TEST_ZIP):
        drive_zip = f'{mydrive_path_dir}/{local_zip}'
        if os.path.exists(drive_zip):
            continue
        shutil.copy(local_zip, drive_zip)
        print(f"Copied {local_zip} to {drive_zip}")

In [None]:
# Only the extraction step is active; the download and copy helpers are left commented out.
#download_to_gdrive()
#copy_to_gdrive()
extract_from_gdrive()

### Dataset construction

In [None]:
def pil_loader(path, codify):
    """Open the image stored at `path` and return it converted to the PIL mode `codify`."""
    with open(path, 'rb') as stream:
        picture = Image.open(stream)
        # Convert while the file is still open so the pixel data gets read.
        converted = picture.convert(codify)
    return converted

def load_images(root_path, directory, img, mask):
    """List the file names shared by the image and mask folders of one directory.

    Args:
        root_path: `pathlib.Path` of the split root.
        directory: Sub-folder of `root_path` to scan.
        img: Name of the image folder inside `directory`.
        mask: Name of the mask folder inside `directory`.

    Returns:
        The file names, sorted alphabetically. `iterdir()` yields entries in
        arbitrary order, so sorting keeps the sample order stable between runs.

    Raises:
        RuntimeError: If either folder is missing, or the two folders do not
            contain exactly the same file names.
    """
    directory_path = root_path / directory
    img_path = directory_path / img
    mask_path = directory_path / mask
    if not img_path.is_dir() or not mask_path.is_dir():
        raise RuntimeError("folder structure different from expected")

    images = {item.name for item in img_path.iterdir()}
    masks = {item.name for item in mask_path.iterdir()}

    if images != masks:
        raise RuntimeError("images and masks do not match")

    return sorted(images)

def generate_bd(mask, edge_pad=False, is_flip=False, edge_size=2):
    """Build a binary boundary map from a label mask.

    Edges are found with `cv2.Canny` (thresholds 0 and 8), optionally blanked in a
    6-pixel frame around the border, then thickened by dilation.

    Args:
        mask: Label mask handed to `cv2.Canny`.
        edge_pad: When true, zero out a 6-pixel frame around the edge map.
        is_flip: Accepted but not used by this function.
        edge_size: Side length of the square dilation kernel.

    Returns:
        Float array holding 1.0 on boundary pixels and 0.0 elsewhere.
    """
    pad_y = pad_x = 6

    edge = cv2.Canny(mask, 0, 8)

    if edge_pad:
        # Crop the frame away, then pad it back with zeros.
        inner = edge[pad_y:-pad_y, pad_x:-pad_x]
        edge = np.pad(inner, ((pad_y, pad_y), (pad_x, pad_x)), mode='constant')

    structuring = np.ones((edge_size, edge_size), np.uint8)
    dilated = cv2.dilate(edge, structuring, iterations=1)

    return (dilated > 50) * 1.0

class LoveDA(VisionDataset):
    """Dataset of (image, mask) pairs read from ``root/<directory>/<img|mask>/<file>``.

    Args:
        root: Split root folder.
        img: Name of the image sub-folder inside every directory.
        mask: Name of the mask sub-folder inside every directory.
        directories: One directory name or a collection of them.
        transforms: Optional callable invoked as ``transforms(image=..., mask=...)``
            that returns a mapping with the keys ``'image'`` and ``'mask'``.
        bd: When true, every sample also carries a boundary map from `generate_bd`.

    Raises:
        RuntimeError: If `root` is not a directory, `directories` is None, or a
            directory does not have the expected image/mask layout.
    """

    def __init__(self, root, img, mask, directories=None, transforms=None, bd=False):
        super(LoveDA, self).__init__(root)

        root_path = Path(root)

        if not root_path.is_dir():
            raise RuntimeError("root should be a directory")

        self.root = root
        self.img_path = img
        self.mask_path = mask
        self.transforms = transforms

        # List of (directory, file name) pairs, one per sample
        self.image_names = []

        self.bd = bd

        if directories is None:
            raise RuntimeError("at least one directory must be passed")

        directories = [directories] if isinstance(directories, str) else directories

        for d in directories:
            image_names = load_images(root_path, d, img, mask)
            self.image_names.extend([(d, image_name) for image_name in image_names])

    def __getitem__(self, index):
        """Return ``(image, mask)``, or ``(image, mask, boundary)`` when `bd` is set.

        The image goes through torchvision's `ToTensor`. The mask is returned as an
        integer tensor shifted down by one, so raw value 0 becomes -1, which is the
        value of `IGNORE_INDEX`.
        """
        # `subdir` instead of `dir`, to avoid shadowing the builtin.
        subdir, image_name = self.image_names[index]
        image_path = f'{self.root}/{subdir}/{self.img_path}/{image_name}'
        mask_path = f'{self.root}/{subdir}/{self.mask_path}/{image_name}'

        image = np.array(pil_loader(image_path, RGB))
        mask = np.array(pil_loader(mask_path, grayscale))

        if self.transforms is not None:
            data = self.transforms(image=image, mask=mask)
            image = data['image']
            mask = data['mask']

        image = transforms.ToTensor()(image)

        # Convert the label values directly. The previous route
        # (uint8 -> float/255 -> PIL -> uint8) truncates after a float round trip,
        # which is only exact for some values.
        mask = torch.as_tensor(mask, dtype=torch.long) - 1

        if self.bd:
            bd = generate_bd(mask.numpy().astype(np.uint8))

            return image, mask, bd

        return image, mask

    def __len__(self):
        """Number of samples across all selected directories."""
        return len(self.image_names)

### Statistics and metrics

#### Average, Standard deviation

In [None]:
def compute_avg_std(dataset, dataloader, device):
    """Compute per-channel mean and standard deviation over all batches of a dataloader.

    Args:
        dataset: Unused; kept so existing calls keep working.
        dataloader: Iterable of ``(inputs, labels)`` batches supporting `len()`;
            `inputs` must be a 4-D tensor with 3 channels on axis 1.
        device: Device the statistics are accumulated on.

    Returns:
        ``(number_of_samples, mean_per_channel, std_per_channel)`` with the last
        two as lists of three floats.
    """
    assert len(dataloader) > 0, "Dataloader must contain some data"

    with torch.no_grad():
        # Accumulate in float64: float32 sums over many pixels lose precision.
        channel_sum = torch.zeros(3, dtype=torch.float64, device=device)
        channel_sq_sum = torch.zeros(3, dtype=torch.float64, device=device)
        data_len = 0
        tot_pixels = 0

        for inputs, _ in dataloader:  # labels are not needed for the statistics
            inputs = inputs.to(device)

            b, _, h, w = inputs.shape

            data_len += b
            tot_pixels += b * h * w
            channel_sum += torch.sum(inputs, dim=(0, 2, 3), dtype=torch.float64)
            channel_sq_sum += torch.sum(inputs * inputs, dim=(0, 2, 3), dtype=torch.float64)

        avg = channel_sum / tot_pixels
        # E[x^2] - E[x]^2 can come out slightly negative through rounding; clamp so
        # the square root never produces NaN.
        variance = (channel_sq_sum / tot_pixels - avg * avg).clamp_(min=0)
        std = torch.sqrt(variance)

        return data_len, avg.tolist(), std.tolist()

#### IoU

In [None]:
def calculate_iou(outputs, masks, num_classes):
    """Compute per-class IoU and their mean for one batch.

    Args:
        outputs: Class scores with the class axis at dim 1.
        masks: Integer ground-truth labels; values outside ``[0, num_classes)``
            (for example the ignore label) are excluded from the metric.
        num_classes: Number of classes.

    Returns:
        ``(miou, iou_per_class)``. `iou_per_class` has `num_classes` entries and
        holds 0 for a class that appears in neither prediction nor ground truth.
        `miou` averages only the classes that do appear, and is 0.0 when none does.
    """
    # Predicted class per pixel
    preds = outputs.argmax(dim=1)

    # Predictions on pixels without a valid label must not count as false positives.
    valid = (masks >= 0) & (masks < num_classes)

    iou_per_class = torch.zeros(num_classes, dtype=torch.float32, device=outputs.device)
    present = torch.zeros(num_classes, dtype=torch.bool, device=outputs.device)

    for i in range(num_classes):
        pred_mask = (preds == i) & valid
        label_mask = masks == i

        intersection = torch.logical_and(pred_mask, label_mask).sum().float()
        union = torch.logical_or(pred_mask, label_mask).sum().float()

        if union > 0:
            iou_per_class[i] = intersection / union
            present[i] = True

    # Average over the classes that occur; absent classes would otherwise drag
    # the mean down with an artificial 0.
    if present.any():
        miou = iou_per_class[present].mean()
    else:
        miou = torch.tensor(0.0, device=outputs.device)

    return miou, iou_per_class



#### Latency, FPS

In [None]:
def calculate_latency_fps(model, device, height, width, iterations, model_type: ModelType):
    """Time repeated forward passes on a random single-image batch.

    Args:
        model: Network to time; it is called under `torch.no_grad()`.
        device: Device the dummy inputs are created on.
        height, width: Spatial size of the dummy image.
        iterations: Number of timed forward passes.
        model_type: `ModelType.DEEPLAB` models are called as ``model(image)``, every
            other type as ``model(image, mask, boundary)``. Dummy `mask` / `boundary`
            tensors are only built for `ModelType.PIDNET`; otherwise they are None.

    Returns:
        ``(mean_latency_ms, std_latency_ms, mean_fps, std_fps)``.
    """
    image = torch.randn(1, 3, height, width).to(device)
    mask = None
    boundary = None

    if model_type == ModelType.PIDNET:
        mask = torch.randint(0, num_classes, (1, height, width), dtype=torch.int64).to(device)
        boundary = torch.randint(0, 2, (1, height, width), dtype=torch.float64).to(device)

    # CUDA kernels run asynchronously. Without synchronising before and after the
    # forward pass the clock would mostly measure the kernel launch, not the work.
    on_cuda = torch.device(device).type == 'cuda'

    latency = []
    FPS = []

    with torch.no_grad():
        for _ in range(iterations):
            if on_cuda:
                torch.cuda.synchronize(device)
            start = time.perf_counter()  # monotonic, high-resolution clock

            if model_type == ModelType.DEEPLAB:
                _ = model(image)
            else:
                _ = model(image, mask, boundary)

            if on_cuda:
                torch.cuda.synchronize(device)
            latency_i = time.perf_counter() - start

            latency.append(latency_i)
            FPS.append(1 / latency_i)

    meanLatency = np.mean(latency) * 1000  # milliseconds
    stdLatency = np.std(latency) * 1000
    meanFPS = np.mean(FPS)
    stdFPS = np.std(FPS)

    return meanLatency, stdLatency, meanFPS, stdFPS

#### FLOPs, Params

In [None]:
def calculate_flops_params(model, device, height, width, model_type: ModelType):
    """Print fvcore's FLOP count table for `model` on an all-zero input.

    Args:
        model: Network to analyse; it is moved to `device`.
        device: Device used for the model and the dummy inputs.
        height, width: Spatial size of the dummy image.
        model_type: For `ModelType.PIDNET` the model is traced with
            ``(image, mask, boundary)``; otherwise with the image alone.
    """
    model = model.to(device)
    dummy_image = torch.zeros(1, 3, height, width).to(device)

    if model_type != ModelType.PIDNET:
        inputs = dummy_image
    else:
        dummy_mask = torch.zeros(1, height, width, dtype=torch.int64).to(device)
        dummy_boundary = torch.zeros(1, height, width, dtype=torch.float64).to(device)
        inputs = (dummy_image, dummy_mask, dummy_boundary)

    flops = FlopCountAnalysis(model, inputs)
    print(flop_count_table(flops))

### Plot losses and mious

In [None]:
def plot_losses_mious(train_losses, eval_losses, miou_scores, num_epochs):
    """Draw the loss curves and the mIoU curve side by side.

    Args:
        train_losses: Per-epoch training losses.
        eval_losses: Per-epoch validation losses.
        miou_scores: Per-epoch mIoU values.
        num_epochs: Number of epochs; x ticks are labelled 1..num_epochs.
    """
    # One figure with two axes placed next to each other
    fig, (loss_ax, miou_ax) = plt.subplots(1, 2, figsize=(20, 5))

    tick_positions = range(0, num_epochs)
    tick_labels = range(1, num_epochs + 1)

    # Left axis: training and validation loss
    loss_ax.plot(train_losses, label='Training Loss')
    loss_ax.plot(eval_losses, label='Validation Loss')

    # Right axis: mIoU
    miou_ax.plot(miou_scores, label='mIoU')

    # Shared decoration for both axes
    panels = (
        (loss_ax, 'Loss', 'Training and Validation Loss'),
        (miou_ax, 'mIoU', 'mIoU'),
    )
    for axis, y_label, title in panels:
        axis.set_xlabel('Epochs')
        axis.set_ylabel(y_label)
        axis.set_xticks(tick_positions, tick_labels)
        axis.set_title(title)
        axis.legend()
        axis.grid()

    # Show the figure
    plt.show()

In [None]:
def plot_mious_per_category(miou_scores, num_epochs):
    """Plot one mIoU curve per class over the epochs.

    Args:
        miou_scores: Mapping from class name to a sequence of `num_epochs` values.
        num_epochs: Number of epochs; x ticks are labelled 1..num_epochs.
    """
    epochs = range(num_epochs)

    plt.figure(figsize=(10, 6))
    for class_name in miou_scores:
        plt.plot(epochs, miou_scores[class_name], label=class_name)

    plt.xlabel('Epoch')
    plt.ylabel('mIoU (%)')
    plt.xticks(epochs, range(1, num_epochs + 1))
    plt.title('mIoU per Class over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

### Checkpoint resume

In [None]:
def resume_checkpoint(resume_path, model, optimizer=None, scheduler=None):
    """Restore model (and optionally optimizer/scheduler) state from a checkpoint file.

    Args:
        resume_path: File written by `save_checkpoint`.
        model: Module whose weights are overwritten in place.
        optimizer: Optional optimizer to restore.
        scheduler: Optional scheduler to restore.

    Returns:
        ``(next_iteration, model, optimizer, scheduler)`` where `next_iteration`
        is the stored iteration plus one.
    """
    # Load onto the CPU first so a checkpoint written on a GPU can also be read on
    # a CPU-only runtime; load_state_dict copies the values to wherever the
    # model/optimizer already live.
    checkpoint = torch.load(resume_path, map_location='cpu')
    iteration = checkpoint['iteration'] + 1
    model.load_state_dict(checkpoint['model'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])
    return iteration, model, optimizer, scheduler

In [None]:
def save_checkpoint(path, iteration, model, optimizer, scheduler):
    """Write model, optimizer and scheduler state to `path`.

    Args:
        path: Destination file; missing parent folders are created.
        iteration: Counter stored under the ``'iteration'`` key.
        model, optimizer, scheduler: Objects whose `state_dict()` is stored.
    """
    # Create the target folder up front so the save cannot fail on a missing path.
    parent = os.path.dirname(os.fspath(path))
    if parent:
        os.makedirs(parent, exist_ok=True)

    checkpoint = {
        'iteration': iteration,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    torch.save(checkpoint, path)

### Image visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches

def plot_tensor_mask(mask_tensor, categories):
    """Show a label mask as a colour image with a class legend.

    Args:
        mask_tensor: Tensor of class indices; singleton dimensions are squeezed
            away. It may live on any device.
        categories: Mapping ``name -> (sort key, RGB colour)``. After sorting by
            the sort key, the i-th entry colours the pixels whose value is i.
            Pixels matching no index stay black.
    """
    categories = dict(sorted(categories.items(), key=lambda item: item[1][0]))

    # detach()/cpu() let tensors on the GPU or with gradient tracking be plotted too.
    mask_array = mask_tensor.detach().cpu().squeeze().numpy()

    # Paint every class with its colour
    colored_mask = np.zeros((mask_array.shape[0], mask_array.shape[1], 3), dtype=np.uint8)
    for i, (label, (value, color)) in enumerate(categories.items()):
        mask = mask_array == i
        colored_mask[mask] = color

    # Display the coloured mask
    plt.figure(figsize=(8, 5))
    plt.imshow(colored_mask)
    plt.axis("off")

    # Legend: one patch per class; matplotlib expects colours in [0, 1]
    legend_patches = [mpatches.Patch(color=np.array(color)/255, label=label) for label, (_, color) in categories.items()]
    plt.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.show()


# Step 2a: Testing classic semantic segmentation network

### Download pre-trained weights

In [None]:
# Create the local folder for pretrained weights if it is missing.
weights_dir = Path(PRETRAINED_WEIGHTS_DIR)
if not weights_dir.exists():
    weights_dir.mkdir(exist_ok=True)

# Download the DeepLab weights with gdown unless the file is already present.
deeplab_v2_weights = Path(DEEPLAB_V2_WEIGHTS)
if not deeplab_v2_weights.exists():
    # Google Drive file ID of the weights file (replace it if the file moves)
    file_id = '1ZX0UCXvJwqd2uBGCX7LI2n-DfMg3t74v'
    gdown.download(id=file_id, output=str(deeplab_v2_weights), quiet=False)

## Model: DeepLabv2


### Implementation

In [None]:
# Passed as `affine` to every BatchNorm2d of the DeepLab model
affine_par = True


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv, dilated 3x3 conv, 1x1 conv to ``planes * 4``.

    The parameters of the three batch-norm layers are frozen (``requires_grad = False``).

    Args:
        inplanes: Input channels.
        planes: Channels of the two inner convolutions; the output has ``planes * 4``.
        stride: Stride of the first 1x1 convolution.
        dilation: Dilation (and padding) of the 3x3 convolution.
        downsample: Optional module applied to the input to form the residual.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        # padding == dilation keeps the spatial size of the 3x3 convolution unchanged
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)

        # Freeze the batch-norm scale/shift parameters
        for frozen_bn in (self.bn1, self.bn2, self.bn3):
            for param in frozen_bn.parameters():
                param.requires_grad = False

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/BN stages, add the residual, apply the final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        residual = x if self.downsample is None else self.downsample(x)
        out += residual

        return self.relu(out)


class ClassifierModule(nn.Module):
    """Sum of parallel dilated 3x3 convolutions, each mapping `inplanes` to `num_classes` channels.

    Args:
        inplanes: Input channels.
        dilation_series: One dilation rate per branch.
        padding_series: One padding value per branch, paired with `dilation_series`.
        num_classes: Output channels of every branch.
    """

    def __init__(self, inplanes, dilation_series, padding_series, num_classes):
        super(ClassifierModule, self).__init__()
        branches = [
            nn.Conv2d(inplanes, num_classes, kernel_size=3, stride=1,
                      padding=padding, dilation=dilation, bias=True)
            for dilation, padding in zip(dilation_series, padding_series)
        ]
        self.conv2d_list = nn.ModuleList(branches)

        # Small Gaussian initialisation for the branch weights
        for branch in self.conv2d_list:
            branch.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Apply every branch to `x` and add the results together."""
        remaining = iter(self.conv2d_list)
        out = next(remaining)(x)
        for branch in remaining:
            out += branch(x)
        return out


class ResNetMulti(nn.Module):
    """ResNet-style backbone with dilated stages and a `ClassifierModule` head.

    Args:
        block: Residual block class exposing an `expansion` attribute and accepting
            ``(inplanes, planes, stride, dilation=..., downsample=...)``.
        layers: Number of blocks in each of the four stages.
        num_classes: Number of output channels of the classifier.
    """

    def __init__(self, block, layers, num_classes):
        self.inplanes = 64
        super(ResNetMulti, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)

        # Batch-norm parameters are kept frozen
        for i in self.bn1.parameters():
            i.requires_grad = False

        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # The last two stages keep the resolution and use dilation instead of stride
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        self.layer6 = ClassifierModule(2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)

        # Weight initialisation for all convolutions and batch norms
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of `blocks` residual blocks and update `self.inplanes`."""
        downsample = None
        if (stride != 1
                or self.inplanes != planes * block.expansion
                or dilation == 2
                or dilation == 4):
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
            # Freeze the batch norm of the projection. This sits inside the `if`
            # because `downsample` stays None when no projection is needed.
            for i in downsample[1].parameters():
                i.requires_grad = False
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores resized to the input resolution.

        In training mode the result is the tuple ``(scores, None, None)``; in
        evaluation mode only `scores` is returned.
        """
        _, _, H, W = x.size()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer6(x)

        x = torch.nn.functional.interpolate(x, size=(H, W), mode='bilinear')

        if self.training:
            return x, None, None

        return x

    def get_1x_lr_params_no_scale(self):
        """
        Yield every trainable parameter of the net except those of the last
        classification layer. Batch-norm parameters have requires_grad set to
        False and are therefore not returned. Each parameter is yielded once.
        """
        backbone = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]

        # Walking `modules()` and then `parameters()` of each module would yield the
        # parameters of nested modules several times; track what was already emitted.
        seen = set()
        for part in backbone:
            for param in part.parameters():
                if param.requires_grad and id(param) not in seen:
                    seen.add(id(param))
                    yield param

    def get_10x_lr_params(self):
        """
        Yield the parameters of the last layer of the net, which classifies
        pixels into classes.
        """
        b = []
        # `multi_level` / `layer5` are not defined by this class; only use them
        # when a subclass or caller has set them.
        if getattr(self, 'multi_level', False):
            b.append(self.layer5.parameters())
        b.append(self.layer6.parameters())

        for group in b:
            for param in group:
                yield param

    def optim_parameters(self, lr):
        """Return two optimizer parameter groups: backbone at `lr`, classifier at ``10 * lr``."""
        return [{'params': self.get_1x_lr_params_no_scale(), 'lr': lr},
                {'params': self.get_10x_lr_params(), 'lr': 10 * lr}]


def get_deeplab_v2(num_classes=19, pretrain=True, pretrain_model_path='DeepLab_resnet_pretrained_imagenet.pth'):
    """Build the DeepLabV2 network, optionally initialised from a pretrained state dict.

    Args:
        num_classes: Number of output classes.
        pretrain: When true, load weights from `pretrain_model_path`.
        pretrain_model_path: File containing the pretrained state dict.

    Returns:
        The `ResNetMulti` model.
    """
    model = ResNetMulti(Bottleneck, [3, 4, 23, 3], num_classes)

    # Pretraining loading
    if pretrain:
        print('Deeplab pretraining loading...')
        # map_location='cpu' lets weights saved on a GPU load on a CPU-only runtime;
        # load_state_dict copies them to the model's own device.
        saved_state_dict = torch.load(pretrain_model_path, map_location='cpu', weights_only=True)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Drop the first dotted component of the stored key before matching
            stripped_key = '.'.join(key.split('.')[1:])
            new_params[stripped_key] = saved_state_dict[key]
        # strict=False: keys without a counterpart in the model are skipped
        model.load_state_dict(new_params, strict=False)

    return model

## Run

### Parameters

In [None]:
# Change in case of resume training
RESUME_TRAINING = False
RESUME_PATH = f"/content/drive/MyDrive/loveDA_dataset/Model training/DeepLab/DeepLabV2_{num_epochs}_{learning_rate}_{step_size}_{gamma}_{resize}_{w_decay}_epoch{epoch}.pth.tar"

num_epochs = 20
BATCH_SIZE = 6
learning_rate = 1e-3
step_size = 10
gamma = 0.1
resize = 512
w_decay = 1e-3

### Dataset preprocessing

#### Normalization metrics

In [None]:
#preprocessing_dataset = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH)
# No shuffle (waste of time), no drop last (we lose some data)

# Loader worker processes are only used when running on a GPU
num_workers = 2 if device.type == 'cuda' else 0

#preprocessing_dataloader = DataLoader(preprocessing_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=num_workers)
#_, avg, std = compute_avg_std(preprocessing_dataset, preprocessing_dataloader, device)

# The model is pretrained on ImageNet, so the ImageNet mean and standard deviation are used
avg = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

#### Training set

In [None]:
# Normalise with the chosen statistics, then resize image and mask to `resize` x `resize`.
# `p=1` already applies a transform on every call; the deprecated `always_apply`
# flag is therefore dropped.
train_transform = A.Compose([
    A.Normalize(mean=avg, std=std, max_pixel_value=255, p=1),
    A.Resize(resize, resize, p=1)
])
train_dataset = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH, transforms=train_transform)
# Shuffled with the seeded generator; the last incomplete batch is dropped.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)

#### Validation set

In [None]:
# Validation images are only normalised (no resize).
# `p=1` already applies the transform on every call; the deprecated `always_apply`
# flag is therefore dropped.
val_transform = A.Compose([
    A.Normalize(mean=avg, std=std, max_pixel_value=255, p=1)
])
val_dataset = LoveDA(VAL_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH, transforms=val_transform)
# Fixed order, all samples kept.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)

### Training process

#### Model engine

In [None]:
# DeepLabV2 initialised from the pretrained weights file and moved to the selected device
model = get_deeplab_v2(num_classes=num_classes, pretrain=True, pretrain_model_path=DEEPLAB_V2_WEIGHTS).to(device)

# Adam with weight decay; StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=w_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# Loss function: cross-entropy that skips pixels labelled IGNORE_INDEX
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

#### Model handling

#### Training loop

In [None]:
# Mean training loss of every epoch run in this session
train_losses = []

if RESUME_TRAINING:
    start_epoch, model, optimizer, scheduler = resume_checkpoint(RESUME_PATH, model, optimizer, scheduler)
else:
    start_epoch = 0

for epoch in range(start_epoch, num_epochs):
    print("### Training mode")
    model.train()
    running_loss = 0.0
    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        # In training mode the model returns a 3-tuple; only the first item is used
        outputs, _, _ = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (i + 1) % 25 == 0:
            print(f"Processed {i + 1} batches, loss: {running_loss / (i+1)}")

    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}")

    # Step the scheduler BEFORE checkpointing. The saved scheduler/optimizer state
    # then already belongs to the next epoch, so a resumed run (which restarts at
    # epoch + 1) follows the same learning-rate schedule as an uninterrupted one.
    scheduler.step()

    path = f"/content/drive/MyDrive/loveDA_dataset/Model training/DeepLab/DeepLabV2_{num_epochs}_{learning_rate}_{step_size}_{gamma}_{resize}_{w_decay}_epoch{epoch}.pth.tar"
    save_checkpoint(path, epoch, model, optimizer, scheduler)

#### Evaluation loop

In [None]:
# Requires saving the models for each epoch

start_epoch_eval = 0

eval_losses = []
mious = []

for epoch in range(start_epoch_eval, num_epochs):

    model = get_deeplab_v2(num_classes=num_classes, pretrain=False)  # Assuming get_deeplab_v2 is defined

    path = f"/content/drive/MyDrive/loveDA_dataset/Model training/DeepLab/DeepLabV2_{num_epochs}_{learning_rate}_{step_size}_{gamma}_{resize}_{w_decay}_epoch{epoch}.pth.tar"
    _, model, _, _ = resume_checkpoint(path, model)

    model.to(device)
    print("### Evaluation mode")
    miou = 0.0 # Accumulator for mIoU
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for i, (images, masks) in enumerate(val_loader):
            images = images.to(device)
            masks = masks.to(device)

            # loss
            outputs= model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

            # mIoU
            iou, _ = calculate_iou(outputs, masks, num_classes)
            miou += iou

            if (i + 1) % 25 == 0:
                print(f"Processed {i + 1} batches: loss {val_loss / (i+1)}, mIoU: {miou / (i+1)}")

    val_loss /= len(val_loader)

    eval_losses.append(val_loss)

    miou /= len(val_loader)

    mious.append(miou)

    print(f"Epoch: [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}, mIoU: {(miou * 100):.2f}%")

### Metric calculation

In [None]:
# Latency and FLOP table of the current `model` on 1024x1024 inputs.
# NOTE(review): `num_epochs` is what gets passed as the number of timing iterations.
model.eval()
latency_stats = calculate_latency_fps(
    model, device, height=1024, width=1024,
    iterations=num_epochs, model_type=ModelType.DEEPLAB,
)
mean_latency = latency_stats[0]
print(f"Mean Latency: {mean_latency:.2f} ms")

calculate_flops_params(model, device, height=1024, width=1024, model_type=ModelType.DEEPLAB)

# PIDNet Implementation (execute for any Step from 2b on)

## Model: PIDNet


### Implementation

#### Definitions

In [None]:
# Normalisation layer class used by the PIDNet building blocks
BatchNorm2d = nn.BatchNorm2d
# Passed as `momentum` when the blocks create their BatchNorm2d layers
bn_mom = 0.1
# Passed as `align_corners` to F.interpolate
algc = False

#### BasicBlock

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        inplanes: Input channels.
        planes: Output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to form the residual.
        no_relu: When true, the sum is returned without the final ReLU.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        """Return ``conv path + residual``, passed through ReLU unless `no_relu` is set."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return out if self.no_relu else self.relu(out)

#### Bottleneck

In [None]:
class Bottleneck(nn.Module):
    """PIDNet bottleneck: 1x1 conv, 3x3 conv, 1x1 conv to ``planes * expansion`` channels.

    NOTE: this definition rebinds the module-level name `Bottleneck`, which the
    DeepLab section also uses for its own (different) block class.

    Args:
        inplanes: Input channels.
        planes: Channels of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to form the residual.
        no_relu: When true (the default), the sum is returned without a final ReLU.
    """
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1,
                               bias=False)
        self.bn3 = BatchNorm2d(out_planes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        """Return ``conv path + residual``, passed through ReLU unless `no_relu` is set."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return out if self.no_relu else self.relu(out)

#### SegmentHead

In [None]:
class segmenthead(nn.Module):
    """Prediction head: BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv.

    When ``scale_factor`` is not ``None`` the logits are bilinearly resized to
    ``scale_factor`` times the spatial size of the intermediate feature map.
    """

    def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
        super().__init__()
        self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
        self.conv1 = nn.Conv2d(inplanes, interplanes, kernel_size=3, padding=1, bias=False)
        self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(interplanes, outplanes, kernel_size=1, padding=0, bias=True)
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return the logits, upsampled only if a scale factor was configured."""
        hidden = self.conv1(self.relu(self.bn1(x)))
        logits = self.conv2(self.relu(self.bn2(hidden)))

        if self.scale_factor is None:
            return logits

        # Target size is derived from the hidden map (same H x W as the logits).
        target_hw = [hidden.shape[-2] * self.scale_factor,
                     hidden.shape[-1] * self.scale_factor]
        return F.interpolate(logits, size=target_hw,
                             mode='bilinear', align_corners=algc)

#### DAPPM

In [None]:
class DAPPM(nn.Module):
    """Pyramid pooling module with cascaded branch refinement.

    The input is average-pooled at three fixed scales plus a global pooling,
    each pooled map is projected to ``branch_planes`` channels, resized back
    to the input resolution, added to the previous branch output and refined
    by a 3x3 convolution. All five branch outputs are concatenated, compressed
    to ``outplanes`` channels and summed with a 1x1 shortcut of the input.
    """

    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        bn_mom = 0.1

        def bn_relu_conv(in_ch, out_ch, kernel_size=1, padding=0, pool=None):
            # [optional pooling] -> BN -> ReLU -> bias-free convolution.
            stack = [] if pool is None else [pool]
            stack.append(BatchNorm(in_ch, momentum=bn_mom))
            stack.append(nn.ReLU(inplace=True))
            stack.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                                   padding=padding, bias=False))
            return nn.Sequential(*stack)

        # Attribute names and creation order are unchanged, so state_dict keys
        # and seeded weight initialisation match the previous layout.
        self.scale1 = bn_relu_conv(inplanes, branch_planes,
                                   pool=nn.AvgPool2d(kernel_size=5, stride=2, padding=2))
        self.scale2 = bn_relu_conv(inplanes, branch_planes,
                                   pool=nn.AvgPool2d(kernel_size=9, stride=4, padding=4))
        self.scale3 = bn_relu_conv(inplanes, branch_planes,
                                   pool=nn.AvgPool2d(kernel_size=17, stride=8, padding=8))
        self.scale4 = bn_relu_conv(inplanes, branch_planes,
                                   pool=nn.AdaptiveAvgPool2d((1, 1)))
        self.scale0 = bn_relu_conv(inplanes, branch_planes)

        self.process1 = bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1)
        self.process2 = bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1)
        self.process3 = bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1)
        self.process4 = bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1)

        self.compression = bn_relu_conv(branch_planes * 5, outplanes)
        self.shortcut = bn_relu_conv(inplanes, outplanes)

    def forward(self, x):
        """Aggregate the multi-scale context of ``x`` at its own resolution."""
        full_size = [x.shape[-2], x.shape[-1]]

        def restore(t):
            return F.interpolate(t, size=full_size,
                                 mode='bilinear', align_corners=algc)

        branches = [self.scale0(x)]
        cascade = (
            (self.scale1, self.process1),
            (self.scale2, self.process2),
            (self.scale3, self.process3),
            (self.scale4, self.process4),
        )
        for pooled_proj, refine in cascade:
            # Each stage builds on the output of the previous one.
            branches.append(refine(restore(pooled_proj(x)) + branches[-1]))

        return self.compression(torch.cat(branches, 1)) + self.shortcut(x)

#### PAPPM

In [None]:
class PAPPM(nn.Module):
    """Pyramid pooling module whose pooled branches are refined in parallel.

    Each pooled-and-projected map is resized to the input resolution and
    added to the un-pooled projection. The four sums are concatenated and
    refined together by one grouped 3x3 convolution (``groups=4``). The result
    is concatenated with the un-pooled projection, compressed to ``outplanes``
    channels and summed with a 1x1 shortcut of the input.
    """

    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        bn_mom = 0.1

        def bn_relu_conv(in_ch, out_ch, kernel_size=1, padding=0, groups=1, pool=None):
            # [optional pooling] -> BN -> ReLU -> bias-free convolution.
            stack = [] if pool is None else [pool]
            stack.append(BatchNorm(in_ch, momentum=bn_mom))
            stack.append(nn.ReLU(inplace=True))
            stack.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                                   padding=padding, groups=groups, bias=False))
            return nn.Sequential(*stack)

        # Attribute names and creation order are unchanged, so state_dict keys
        # and seeded weight initialisation match the previous layout.
        self.scale1 = bn_relu_conv(inplanes, branch_planes,
                                   pool=nn.AvgPool2d(kernel_size=5, stride=2, padding=2))
        self.scale2 = bn_relu_conv(inplanes, branch_planes,
                                   pool=nn.AvgPool2d(kernel_size=9, stride=4, padding=4))
        self.scale3 = bn_relu_conv(inplanes, branch_planes,
                                   pool=nn.AvgPool2d(kernel_size=17, stride=8, padding=8))
        self.scale4 = bn_relu_conv(inplanes, branch_planes,
                                   pool=nn.AdaptiveAvgPool2d((1, 1)))

        self.scale0 = bn_relu_conv(inplanes, branch_planes)

        self.scale_process = bn_relu_conv(branch_planes * 4, branch_planes * 4,
                                          kernel_size=3, padding=1, groups=4)

        self.compression = bn_relu_conv(branch_planes * 5, outplanes)

        self.shortcut = bn_relu_conv(inplanes, outplanes)

    def forward(self, x):
        """Aggregate the multi-scale context of ``x`` at its own resolution."""
        full_size = [x.shape[-2], x.shape[-1]]

        base = self.scale0(x)
        pooled_sums = [
            F.interpolate(branch(x), size=full_size,
                          mode='bilinear', align_corners=algc) + base
            for branch in (self.scale1, self.scale2, self.scale3, self.scale4)
        ]

        refined = self.scale_process(torch.cat(pooled_sums, 1))

        return self.compression(torch.cat([base, refined], 1)) + self.shortcut(x)

#### PagFM

In [None]:
class PagFM(nn.Module):
    """Attention-guided fusion of two feature maps.

    ``x`` and ``y`` are each projected by a 1x1 conv + BN. Their element-wise
    product is turned into a sigmoid gate, which is one channel unless
    ``with_channel`` is true, in which case it has ``in_channels`` channels.
    The gate blends ``x`` with ``y`` resized to the spatial size of ``x``.
    With ``after_relu=True`` both inputs go through an in-place ReLU first.
    """

    def __init__(self, in_channels, mid_channels, after_relu=False, with_channel=False, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        self.with_channel = with_channel
        self.after_relu = after_relu

        def conv_bn(src, dst):
            # Bias-free 1x1 projection followed by batch norm.
            return nn.Sequential(
                nn.Conv2d(src, dst, kernel_size=1, bias=False),
                BatchNorm(dst),
            )

        self.f_x = conv_bn(in_channels, mid_channels)
        self.f_y = conv_bn(in_channels, mid_channels)
        if with_channel:
            self.up = conv_bn(mid_channels, in_channels)
        if after_relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x, y):
        """Blend ``y`` into ``x`` using the learned similarity gate."""
        target_hw = [x.size(2), x.size(3)]

        if self.after_relu:
            y = self.relu(y)
            x = self.relu(x)

        query = F.interpolate(self.f_y(y), size=target_hw,
                              mode='bilinear', align_corners=False)
        key = self.f_x(x)
        affinity = key * query

        if self.with_channel:
            gate = torch.sigmoid(self.up(affinity))
        else:
            gate = torch.sigmoid(torch.sum(affinity, dim=1).unsqueeze(1))

        y_resized = F.interpolate(y, size=target_hw,
                                  mode='bilinear', align_corners=False)
        return (1 - gate) * x + gate * y_resized

#### LightBag

In [None]:
class Light_Bag(nn.Module):
    """Fuse three feature maps ``p``, ``i`` and ``d`` with two 1x1 conv + BN paths.

    ``sigmoid(d)`` acts as a gate: one path sees ``p`` plus the ``(1 - gate)``
    share of ``i``, the other sees ``i`` plus the ``gate`` share of ``p``. The
    two projected results are summed.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        def conv_bn():
            # Bias-free 1x1 projection followed by batch norm.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                BatchNorm(out_channels),
            )

        self.conv_p = conv_bn()
        self.conv_i = conv_bn()

    def forward(self, p, i, d):
        """Return ``conv_p((1-g)*i + p) + conv_i(i + g*p)`` with ``g = sigmoid(d)``."""
        gate = torch.sigmoid(d)

        from_p = self.conv_p((1 - gate) * i + p)
        from_i = self.conv_i(i + gate * p)

        return from_p + from_i

#### DDFMv2

In [None]:
class DDFMv2(nn.Module):
    """Gated fusion of ``p``, ``i`` and ``d`` using pre-activated 1x1 projections.

    The mixing is the same as in ``forward`` of a plain gated sum: with
    ``g = sigmoid(d)``, one path receives ``(1 - g) * i + p`` and the other
    ``i + g * p``. Each path is BN -> ReLU -> 1x1 conv -> BN and the two
    outputs are added.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        def bn_relu_conv_bn():
            return nn.Sequential(
                BatchNorm(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                BatchNorm(out_channels),
            )

        self.conv_p = bn_relu_conv_bn()
        self.conv_i = bn_relu_conv_bn()

    def forward(self, p, i, d):
        """Return the sum of the two gated, projected mixtures."""
        gate = torch.sigmoid(d)

        from_p = self.conv_p((1 - gate) * i + p)
        from_i = self.conv_i(i + gate * p)

        return from_p + from_i

#### Bag

In [None]:
class Bag(nn.Module):
    """Gate-weighted blend of ``p`` and ``i`` followed by BN -> ReLU -> 3x3 conv.

    ``sigmoid(d)`` weights ``p`` and its complement weights ``i``.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        fuse_layers = [
            BatchNorm(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
        ]
        self.conv = nn.Sequential(*fuse_layers)

    def forward(self, p, i, d):
        """Return ``conv(g * p + (1 - g) * i)`` with ``g = sigmoid(d)``."""
        gate = torch.sigmoid(d)
        blended = gate * p + (1 - gate) * i
        return self.conv(blended)

#### PIDNet

In [None]:
class PIDNet(nn.Module):
    """Three-branch segmentation network (branches are labelled I, P and D below).

    Args:
        m: number of blocks in ``layer1``/``layer2`` and in the P-branch layers;
            ``m == 2`` also selects the PAPPM + Light_Bag variant, any other
            value selects DAPPM + Bag with a wider D branch.
        n: number of blocks in ``layer3`` and ``layer4``.
        num_classes: number of output channels of the segmentation heads.
        planes: base channel width; all other widths are multiples of it.
        ppm_planes: branch width passed to the pyramid pooling module.
        head_planes: hidden width of the segmentation heads.
        augment: when true, two auxiliary heads are created and ``forward``
            returns three outputs instead of one.

    ``forward`` returns ``[x_extra_p, x_, x_extra_d]`` when ``augment`` is true
    and the single tensor ``x_`` otherwise. All outputs are produced at the
    1/8-resolution grid computed at the top of ``forward``.
    """

    def __init__(self, m=2, n=3, num_classes=19, planes=64, ppm_planes=96, head_planes=128, augment=True):
        super(PIDNet, self).__init__()
        self.augment = augment

        # I Branch
        # Stem: two stride-2 3x3 convolutions, i.e. 1/4 of the input resolution.
        self.conv1 =  nn.Sequential(
                          nn.Conv2d(3,planes,kernel_size=3, stride=2, padding=1),
                          BatchNorm2d(planes, momentum=bn_mom),
                          nn.ReLU(inplace=True),
                          nn.Conv2d(planes,planes,kernel_size=3, stride=2, padding=1),
                          BatchNorm2d(planes, momentum=bn_mom),
                          nn.ReLU(inplace=True),
                      )

        self.relu = nn.ReLU(inplace=True)
        # layer2..layer5 each halve the resolution again (stride=2).
        self.layer1 = self._make_layer(BasicBlock, planes, planes, m)
        self.layer2 = self._make_layer(BasicBlock, planes, planes * 2, m, stride=2)
        self.layer3 = self._make_layer(BasicBlock, planes * 2, planes * 4, n, stride=2)
        self.layer4 = self._make_layer(BasicBlock, planes * 4, planes * 8, n, stride=2)
        self.layer5 =  self._make_layer(Bottleneck, planes * 8, planes * 8, 2, stride=2)

        # P Branch
        # compression3/4 reduce I-branch features to planes*2 channels so they
        # can be fused into the P branch by pag3/pag4.
        self.compression3 = nn.Sequential(
                                          nn.Conv2d(planes * 4, planes * 2, kernel_size=1, bias=False),
                                          BatchNorm2d(planes * 2, momentum=bn_mom),
                                          )

        self.compression4 = nn.Sequential(
                                          nn.Conv2d(planes * 8, planes * 2, kernel_size=1, bias=False),
                                          BatchNorm2d(planes * 2, momentum=bn_mom),
                                          )
        self.pag3 = PagFM(planes * 2, planes)
        self.pag4 = PagFM(planes * 2, planes)

        # The P-branch layers use stride 1, so they keep the layer2 resolution.
        self.layer3_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer4_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer5_ = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # D Branch
        # diff3/diff4 project I-branch features so they can be added to the
        # D branch; the channel widths differ between the two variants.
        if m == 2:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes)
            self.layer4_d = self._make_layer(Bottleneck, planes, planes, 1)
            self.diff3 = nn.Sequential(
                                        nn.Conv2d(planes * 4, planes, kernel_size=3, padding=1, bias=False),
                                        BatchNorm2d(planes, momentum=bn_mom),
                                        )
            self.diff4 = nn.Sequential(
                                     nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                                     BatchNorm2d(planes * 2, momentum=bn_mom),
                                     )
            self.spp = PAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Light_Bag(planes * 4, planes * 4)
        else:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.layer4_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.diff3 = nn.Sequential(
                                        nn.Conv2d(planes * 4, planes * 2, kernel_size=3, padding=1, bias=False),
                                        BatchNorm2d(planes * 2, momentum=bn_mom),
                                        )
            self.diff4 = nn.Sequential(
                                     nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                                     BatchNorm2d(planes * 2, momentum=bn_mom),
                                     )
            self.spp = DAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Bag(planes * 4, planes * 4)

        self.layer5_d = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # Prediction Head
        # Auxiliary heads exist only when augment is true: seghead_p predicts
        # num_classes channels, seghead_d predicts a single channel.
        if self.augment:
            self.seghead_p = segmenthead(planes * 2, head_planes, num_classes)
            self.seghead_d = segmenthead(planes * 2, planes, 1)

        self.final_layer = segmenthead(planes * 4, head_planes, num_classes)


        # Weight initialisation for every conv / batch-norm in the network.
        # NOTE: the loop variable rebinds the constructor argument `m`; this is
        # harmless only because `m` is not read again after this point.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride or the channel
        count changes, a 1x1 conv + BN projection for its shortcut; it is built
        with the block class's default ``no_relu``. Of the remaining blocks the
        last one is built with ``no_relu=True`` and the others with
        ``no_relu=False``.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        # Every following block consumes the expanded channel count.
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            if i == (blocks-1):
                layers.append(block(inplanes, planes, stride=1, no_relu=True))
            else:
                layers.append(block(inplanes, planes, stride=1, no_relu=False))

        return nn.Sequential(*layers)

    def _make_single_layer(self, block, inplanes, planes, stride=1):
        """Build one ``block`` with ``no_relu=True`` (returned bare, not in a Sequential).

        A 1x1 conv + BN shortcut projection is added when the stride or the
        channel count changes.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layer = block(inplanes, planes, stride, downsample, no_relu=True)

        return layer

    def forward(self, x):
        """Run the three branches and return the prediction(s).

        Returns ``[x_extra_p, x_, x_extra_d]`` when ``self.augment`` is true,
        otherwise only ``x_``.
        """

        # Target grid for everything that is resized onto the P/D branches.
        width_output = x.shape[-1] // 8
        height_output = x.shape[-2] // 8

        x = self.conv1(x)
        x = self.layer1(x)
        x = self.relu(self.layer2(self.relu(x)))
        # Branch split: x_ is the P branch, x_d the D branch, x stays the I branch.
        x_ = self.layer3_(x)
        x_d = self.layer3_d(x)

        x = self.relu(self.layer3(x))
        # Stage 3 exchange: I features are fused into P (pag3) and added to D (diff3).
        x_ = self.pag3(x_, self.compression3(x))
        x_d = x_d + F.interpolate(
                        self.diff3(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            # Kept for the auxiliary head seghead_p.
            temp_p = x_

        x = self.relu(self.layer4(x))
        x_ = self.layer4_(self.relu(x_))
        x_d = self.layer4_d(self.relu(x_d))

        # Stage 4 exchange, same pattern as stage 3.
        x_ = self.pag4(x_, self.compression4(x))
        x_d = x_d + F.interpolate(
                        self.diff4(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            # Kept for the auxiliary head seghead_d.
            temp_d = x_d

        x_ = self.layer5_(self.relu(x_))
        x_d = self.layer5_d(self.relu(x_d))
        # I branch: last stage + pyramid pooling, resized to the P/D grid.
        x = F.interpolate(
                        self.spp(self.layer5(x)),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)

        # Fuse the three branches (dfm receives P, I, D in that order) and predict.
        x_ = self.final_layer(self.dfm(x_, x, x_d))

        if self.augment:
            x_extra_p = self.seghead_p(temp_p)
            x_extra_d = self.seghead_d(temp_d)
            return [x_extra_p, x_, x_extra_d]
        else:
            return x_

#### SemanticLoss

##### CrossEntropy

In [None]:
class CrossEntropy(nn.Module):
    """Cross-entropy loss over one or two prediction maps.

    Args:
        ignore_label: target value excluded from the loss.
        weight: optional per-class weights forwarded to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_label
        )

    def _forward(self, score, target):
        """Cross-entropy of a single prediction map against ``target``."""
        return self.criterion(score, target)

    def forward(self, score, target):
        """Return the (weighted) loss of ``score`` against ``target``.

        Args:
            score: a single logits tensor, or a list/tuple of one or two
                logits tensors.
            target: class-index tensor shared by all prediction maps.

        Raises:
            ValueError: if ``score`` holds neither one nor two maps.
        """
        # A bare tensor must be wrapped first: otherwise ``len(score)`` is the
        # batch size and a batch of 2 would be zipped sample-by-sample against
        # the balance weights. The other semantic losses in this file already
        # do this.
        if not isinstance(score, (list, tuple)):
            score = [score]

        # From original configs
        balance_weights = [0.4, 1.0]
        sb_weights = 1.0

        if len(balance_weights) == len(score):
            return sum(w * self._forward(x, target)
                       for (w, x) in zip(balance_weights, score))
        elif len(score) == 1:
            return sb_weights * self._forward(score[0], target)

        else:
            raise ValueError("lengths of prediction and target are not identical!")

##### OHEM Cross Entropy

In [None]:
class OhemCrossEntropy(nn.Module):
    """Cross-entropy with online hard example mining (OHEM).

    Args:
        ignore_label: target value excluded from the loss.
        thres: pixels whose probability for the true class is below this value
            are always treated as hard.
        min_kept: index into the ascending list of true-class probabilities
            used to raise the threshold when too few pixels are below
            ``thres``; clamped to at least 1.
        weight: optional per-class weights forwarded to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, ignore_label=-1, thres=0.7,
                 min_kept=100000, weight=None):
        super(OhemCrossEntropy, self).__init__()
        self.thresh = thres
        self.min_kept = max(1, min_kept)
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_label,
            reduction='none'
        )

    def _ce_forward(self, score, target):
        """Plain per-pixel cross-entropy averaged over all pixels."""
        loss = self.criterion(score, target)

        return loss.mean()

    def _ohem_forward(self, score, target, **kwargs):
        """Average the cross-entropy over the hard (low-confidence) pixels only."""
        pred = F.softmax(score, dim=1)
        pixel_losses = self.criterion(score, target).contiguous().view(-1)
        mask = target.contiguous().view(-1) != self.ignore_label

        if not bool(mask.any()):
            # Nothing is labelled in this batch. Indexing the empty sorted
            # tensor below would raise IndexError, so return a zero loss
            # instead; ignored pixels contribute exactly 0 to
            # ``pixel_losses``, and the sum stays attached to the graph.
            return pixel_losses.sum()

        # Probability assigned to the true class (ignored pixels are mapped to
        # class 0 only so that ``gather`` gets a valid index; they are masked out).
        tmp_target = target.clone()
        tmp_target[tmp_target == self.ignore_label] = 0
        pred = pred.gather(1, tmp_target.unsqueeze(1))
        pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort()

        min_value = pred[min(self.min_kept, pred.numel() - 1)]

        threshold = max(min_value, self.thresh)

        # Reorder the losses like the sorted probabilities, then keep the hard ones.
        pixel_losses = pixel_losses[mask][ind]
        pixel_losses = pixel_losses[pred < threshold]
        return pixel_losses.mean()

    def forward(self, score, target):
        """Return the (weighted) loss of ``score`` against ``target``.

        With two prediction maps the first uses plain cross-entropy and the
        second uses OHEM; a single map uses OHEM.

        Raises:
            ValueError: if ``score`` holds neither one nor two maps.
        """
        if not isinstance(score, (list, tuple)):
            score = [score]

        balance_weights = [0.4, 1.0]
        sb_weights = 1.0

        if len(balance_weights) == len(score):
            functions = [self._ce_forward] * (len(balance_weights) - 1) + [self._ohem_forward]
            return sum(w * func(x, target)
                       for (w, x, func) in zip(balance_weights, score, functions))

        elif len(score) == 1:
            return sb_weights * self._ohem_forward(score[0], target)

        else:
            raise ValueError("lengths of prediction and target are not identical!")


##### Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Focal loss built on the batch-averaged cross-entropy.

    The modulating factor ``(1 - pt) ** gamma`` is computed from the mean
    cross-entropy of the batch (``pt = exp(-ce)``) and multiplies that same
    cross-entropy, so ``gamma=0`` reduces to ordinary cross-entropy.

    Args:
        gamma: focusing exponent.
        alpha: optional scalar multiplier applied to the loss.
        ignore_label: target value excluded from the loss.
        weight: optional per-class weights forwarded to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, gamma=0, alpha=None, ignore_label=-1, weight=None):
        super(FocalLoss, self).__init__()
        self.ignore_label = ignore_label
        self.gamma = gamma
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_label
        )

    def _forward(self, score, target):
        """Focal loss of a single prediction map against ``target``."""
        ce_loss = self.criterion(score, target)

        pt = torch.exp(-ce_loss)
        # The modulating factor must scale the cross-entropy term. Returning
        # the factor alone gives a constant 1 (and no gradient) for gamma=0.
        focal_loss = torch.pow(1 - pt, self.gamma) * ce_loss

        if self.alpha is not None:
            return self.alpha * focal_loss
        return focal_loss

    def forward(self, score, target):
        """Return the (weighted) loss of ``score`` against ``target``.

        Raises:
            ValueError: if ``score`` holds neither one nor two maps.
        """
        if not isinstance(score, (list, tuple)):
            score = [score]

        # From original configs
        balance_weights = [0.4, 1.0]
        sb_weights = 1.0

        if len(balance_weights) == len(score):
            return sum(w * self._forward(x, target)
                       for (w, x) in zip(balance_weights, score))
        elif len(score) == 1:
            return sb_weights * self._forward(score[0], target)

        else:
            raise ValueError("lengths of prediction and target are not identical!")

#### BoundaryLoss

In [None]:
def weighted_bce(bd_pre, target):
    """Class-balanced binary cross-entropy on logits.

    Positions where ``target == 1`` are weighted by the fraction of
    ``target == 0`` positions and vice versa; positions holding any other
    value get weight 0. The weighted losses are averaged over all positions.

    Args:
        bd_pre: 4-D logits tensor; it is moved to channels-last and flattened.
        target: tensor with as many elements as ``bd_pre``.
    """
    _n, _c, _h, _w = bd_pre.size()  # unpacking requires a 4-D input
    logits = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
    labels = target.view(1, -1)

    is_positive = labels == 1
    is_negative = labels == 0

    positives = is_positive.sum()
    negatives = is_negative.sum()
    counted = positives + negatives

    # The rarer class receives the larger weight.
    weight = torch.zeros_like(logits)
    weight[is_positive] = negatives * 1.0 / counted
    weight[is_negative] = positives * 1.0 / counted

    return F.binary_cross_entropy_with_logits(logits, labels, weight, reduction='mean')

class BondaryLoss(nn.Module):
    """Boundary loss: ``weighted_bce`` scaled by a constant coefficient.

    Args:
        coeff_bce: multiplier applied to the weighted BCE value.
    """

    def __init__(self, coeff_bce = 20.0):
        super().__init__()
        self.coeff_bce = coeff_bce

    def forward(self, bd_pre, bd_gt):
        """Return ``coeff_bce * weighted_bce(bd_pre, bd_gt)``."""
        return self.coeff_bce * weighted_bce(bd_pre, bd_gt)

#### FullModel

In [None]:
class FullModel(nn.Module):
    """Wrap a segmentation network together with its semantic and boundary losses.

    Args:
        model: network called as ``model(inputs, *args, **kwargs)``.
        sem_loss: callable ``(list_of_logits, labels) -> loss``.
        bd_loss: callable ``(boundary_logits, boundary_gt) -> loss``.
    """

    def __init__(self, model, sem_loss, bd_loss):
        super(FullModel, self).__init__()
        self.model = model
        self.sem_loss = sem_loss
        self.bd_loss = bd_loss

    def pixel_acc(self, pred, label):
        """Fraction of non-ignored pixels whose arg-max class equals ``label``."""
        _, preds = torch.max(pred, dim=1)
        valid = (label != IGNORE_INDEX).long()
        acc_sum = torch.sum(valid * (preds == label).long())
        pixel_sum = torch.sum(valid)
        # The epsilon keeps the ratio finite when every pixel is ignored.
        acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
        return acc

    def forward(self, inputs, labels, bd_gt, *args, **kwargs):
        """Run the network and, when ground truth is given, compute the losses.

        Args:
            inputs: image batch; its last two dimensions give the target size
                when ``labels`` is ``None``.
            labels: class-index masks or ``None``.
            bd_gt: boundary ground truth; ``None`` selects inference mode.

        Returns:
            ``(None, outputs, None, None)`` when ``bd_gt`` is ``None``;
            otherwise ``(loss, outputs, acc, [loss_s, loss_b, loss_sb])`` where
            ``loss`` has shape ``(1,)``. ``outputs`` is always a list of
            tensors resized to the target size.
        """
        outputs = self.model(inputs, *args, **kwargs)
        # Normalise to a list: a bare tensor (single-output model) would be
        # indexed along its batch axis below, and a tuple cannot be rebuilt
        # in place.
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
        else:
            outputs = list(outputs)

        if labels is None:
            h, w = inputs.size(2), inputs.size(3)
        else:
            h, w = labels.size(1), labels.size(2)

        ph, pw = outputs[0].size(2), outputs[0].size(3)
        if ph != h or pw != w:
            outputs = [
                F.interpolate(out, size=(h, w), mode='bilinear',
                              align_corners=True)     #from original configs
                for out in outputs
            ]

        if bd_gt is None:
            return None, outputs, None, None

        # Layout of outputs: [..., main prediction, boundary prediction].
        acc = self.pixel_acc(outputs[-2], labels)
        loss_s = self.sem_loss(outputs[:-1], labels)

        loss_b = self.bd_loss(outputs[-1], bd_gt)

        # Extra semantic term restricted to pixels the boundary head is
        # confident about (sigmoid > 0.8); all other pixels are ignored.
        filler = torch.ones_like(labels) * IGNORE_INDEX       #from original configs
        bd_label = torch.where(torch.sigmoid(outputs[-1][:, 0, :, :]) > 0.8, labels, filler)

        loss_sb = self.sem_loss([outputs[-2]], bd_label)

        loss = loss_s + loss_b + loss_sb

        return torch.unsqueeze(loss, 0), outputs, acc, [loss_s, loss_b, loss_sb]

#### Other functions

In [None]:
def get_seg_model(model_name, num_classes, pretrained_weights, imgnet_pretrained):
    """Build a training-mode PIDNet (``augment=True``) and load matching weights.

    The variant is picked by substring: a name containing ``'s'`` gives the
    small model, otherwise one containing ``'m'`` gives the medium model, and
    anything else gives the large one.

    With ``imgnet_pretrained`` the checkpoint must contain a ``'state_dict'``
    entry whose keys are used as-is. Otherwise the optional ``'state_dict'``
    wrapper is unwrapped and the first six characters of every key are
    dropped before matching. Only tensors whose name and shape match the
    model are loaded.

    NOTE(review): ``torch.load`` unpickles the file, so the checkpoint must
    come from a trusted source.
    """
    if 's' in model_name:
        arch = dict(m=2, n=3, planes=32, ppm_planes=96, head_planes=128)
    elif 'm' in model_name:
        arch = dict(m=2, n=3, planes=64, ppm_planes=96, head_planes=128)
    else:
        arch = dict(m=3, n=4, planes=64, ppm_planes=112, head_planes=256)
    model = PIDNet(num_classes=num_classes, augment=True, **arch)

    checkpoint = torch.load(pretrained_weights, map_location='cpu')
    if imgnet_pretrained:
        source_state = checkpoint['state_dict']
        prefix_len = 0
    else:
        source_state = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        prefix_len = 6

    target_state = model.state_dict()
    matched = {}
    for key, tensor in source_state.items():
        name = key[prefix_len:]
        if name in target_state and tensor.shape == target_state[name].shape:
            matched[name] = tensor

    print('Attention!!!')
    print('Loaded {} parameters!'.format(len(matched)))
    print('Over!!!')

    target_state.update(matched)
    model.load_state_dict(target_state, strict = False)

    return model

def get_pred_model(name, num_classes):
    """Build an inference-mode PIDNet (``augment=False``) without loading weights.

    The variant is picked by substring: a name containing ``'s'`` gives the
    small model, otherwise one containing ``'m'`` gives the medium model, and
    anything else gives the large one.
    """
    if 's' in name:
        arch = dict(m=2, n=3, planes=32, ppm_planes=96, head_planes=128)
    elif 'm' in name:
        arch = dict(m=2, n=3, planes=64, ppm_planes=96, head_planes=128)
    else:
        arch = dict(m=3, n=4, planes=64, ppm_planes=112, head_planes=256)

    return PIDNet(num_classes=num_classes, augment=False, **arch)

## Download pre-trained weights

In [None]:
# Make sure the local folder for pretrained checkpoints exists.
weights_dir = Path(PRETRAINED_WEIGHTS_DIR)
if not weights_dir.exists():
    weights_dir.mkdir(exist_ok=True)

# Local path of the PIDNet-S checkpoint.
PIDNET_S_WEIGHTS = weights_dir / 'pidnet_s_imagenet_pretrained.pth'

# Download only when the file is not already on disk.
pidnet_s_weights = Path(PIDNET_S_WEIGHTS)
if not pidnet_s_weights.exists():
    # Google Drive file ID handed to gdown; replace it if it does not point
    # to the intended checkpoint.

    file_id = '1hIBp_8maRr60-B3PF0NVtaA6TYBvO4y-'
    gdown.download(id=file_id, output=str(pidnet_s_weights), quiet=False)

# Step 2b: Real-time semantic segmentation network

## Run

### Parameters

In [None]:
# Side length used by the training-set resize transform.
resize = 512
# Batch size shared by the training and validation loaders.
BATCH_SIZE = 6
# Number of passes over the training set.
num_epochs = 20

# SGD settings (the constant name MOMEUNTUM is referenced with this spelling
# by the optimizer cell, so it must not be renamed here alone).
LR = 1e-3
MOMEUNTUM = 0.9
WEIGHT_DECAY = 1e-2
# StepLR settings: the learning rate is multiplied by GAMMA every STEP_SIZE
# scheduler steps.
STEP_SIZE = 10
GAMMA = 0.1

# The training loop prints the losses every `log_frequency` iterations.
log_frequency = 25

### Dataset preprocessing

#### Normalization metrics

In [None]:
# Use background data-loading workers only when running on a CUDA device.
num_workers = 2 if device.type == 'cuda' else 0

# Since the model is pretrained on ImageNet, the ImageNet per-channel mean
# and standard deviation are used for normalization.
avg = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

#### Training set

In [None]:
# Training images are normalized and then resized to resize x resize.
train_transform = A.Compose([
    A.Normalize(mean=avg, std=std, p=1, always_apply=True, max_pixel_value=255),
    A.Resize(resize, resize, p=1, always_apply=True)
])
# Only the Urban split is used. NOTE(review): bd=True presumably makes the
# dataset also return a boundary map (the loops unpack three items per batch);
# confirm against the LoveDA class.
train_dataset = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH, transforms=train_transform, bd=True)
# Shuffled, incomplete last batch dropped; worker seeding and the shared
# generator keep the sample order reproducible.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)

#### Validation set

In [None]:
# Resizing is applied to the training set only.
# Normalization, on the other hand, can be applied here as well.

val_transform = A.Compose([
    A.Normalize(mean=avg, std=std, p=1, always_apply=True, max_pixel_value=255)
])
val_dataset = LoveDA(VAL_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH, transforms=val_transform, bd=True)
# No shuffling and no dropped batch, so every validation sample is evaluated once.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)

### Training process

#### Model engine

In [None]:
# PIDNet-S initialised from the downloaded checkpoint, wrapped together with
# its semantic (cross-entropy) and boundary losses.
pidnet = get_seg_model("pidnet_s", num_classes, PIDNET_S_WEIGHTS,imgnet_pretrained=True)
model = FullModel(pidnet, sem_loss=CrossEntropy(ignore_label=IGNORE_INDEX), bd_loss=BondaryLoss())
model = model.to(device)

# SGD with momentum and weight decay; StepLR scales the learning rate by GAMMA
# every STEP_SIZE calls to scheduler.step().
optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum = MOMEUNTUM, weight_decay = WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

#### Evaluate function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device) -> tuple:
    """Evaluate `model` on `dataloader` and return sample-weighted averages.

    The model is switched to eval mode. For every batch the loss and the IoU
    returned by `calculate_iou` (computed on `outputs[1]`) are weighted by the
    batch size; the totals are divided by the number of samples seen.

    Args:
        model: callable as `model(inputs, masks, boundaries)` returning a
            4-tuple whose first two items are the loss and the outputs.
        dataloader: iterable of `(inputs, masks, boundaries)` batches.
        device: device every batch tensor is moved to.

    Returns:
        `(loss, mIoU)` averaged over all samples.
    """
    model.eval()

    loss_sum = 0.0
    iou_sum = 0.0
    n_samples = 0

    for batch in dataloader:
        inputs, masks, boundaries = (t.to(device) for t in batch)
        batch_size = inputs.size(0)
        n_samples += batch_size

        # Forward pass; pixel accuracy and the per-term loss list are not needed here.
        loss, outputs, _, _ = model(inputs, masks, boundaries)
        loss_sum += loss.item() * batch_size

        # IoU of this batch, weighted by the number of samples it contains.
        iou, _ = calculate_iou(outputs[1], masks, num_classes)
        iou_sum += iou * batch_size

    mIoU = iou_sum / n_samples
    loss = loss_sum / n_samples

    return loss, mIoU

#### Training/evaluation loop

In [None]:
# Per-epoch history containers.
val_losses, val_accuracies = [], []
train_losses, train_accuracies = [], []
miou_scores = []
best_mIoU = -1
best_num_epochs = None

for epoch in range(num_epochs):

    model.train()
    train_loss = 0.0

    for current_step, batch in enumerate(train_loader):
        inputs, masks, boundaries = (t.to(device) for t in batch)

        # Forward pass: the model returns the total loss plus its three components.
        optimizer.zero_grad()
        loss, outputs, pixel_acc, [loss_s, loss_b, loss_sb] = model(inputs, masks, boundaries)
        train_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()

        if current_step % log_frequency == 0:
            print(f"Epoch {epoch+1}, Iteration {current_step}, Loss: {loss.item():.3f} Loss_s: {loss_s.item():.3f} Loss_b: {loss_b.item():.3f} Loss_sb: {loss_sb.item():.3f}")

    # Mean loss per batch over the epoch.
    train_loss /= len(train_loader)

    print(f"End of Epoch {epoch+1}")
    print(f"Training loss: {train_loss:.5f}")

    val_loss, val_mean_iou = evaluate(model, val_loader, device)
    print(f"Validation mIoU: {val_mean_iou*100:.3f}%, Validation loss: {val_loss:.5f}")

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    miou_scores.append(val_mean_iou)
    val_accuracies.append(val_mean_iou.cpu().item())

    print()
    # Scheduler is None if learning rate is constant
    if scheduler is not None:
        scheduler.step()

### Metric calculation

In [None]:
# Switch to evaluation mode before measuring latency and FLOPs.
model.eval()
# Only the first of the four returned values (the mean latency) is used.
# NOTE(review): num_epochs is passed as the fifth positional argument — presumably
# it acts as the number of timed runs; confirm against calculate_latency_fps.
mean_latency, _, _, _ = calculate_latency_fps(model, device, 1024, 1024, num_epochs, ModelType.PIDNET)
print(f"Mean Latency: {mean_latency:.2f} ms")
# NOTE(review): the two 1024 arguments are presumably the input height/width — confirm.
calculate_flops_params(model, device, 1024, 1024, ModelType.PIDNET)

# Step 3a: Evaluating the domain shift problem in Semantic Segmentation

## Run

### Parameters

In [None]:
# --- Data ---
BATCH_SIZE = 6
resize = 512           # side length passed to A.Resize for training images

# --- Schedule ---
num_epochs = 20
STEP_SIZE = 10         # StepLR: epochs between learning-rate decays
GAMMA = 0.1            # StepLR: multiplicative decay factor

# --- SGD optimizer ---
LR = 1e-3
WEIGHT_DECAY = 1e-2
MOMEUNTUM = 0.9        # (sic) name is referenced as-is by the optimizer cell

# Print the running losses every `log_frequency` training iterations.
log_frequency = 25

### Dataset preprocessing

#### Normalization metrics

In [None]:
# Background DataLoader workers are only used when running on a CUDA device.
if device.type == 'cuda':
    num_workers = 2
else:
    num_workers = 0

# The model is pretrained on ImageNet, so the ImageNet mean and std are used.
avg = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

#### Training set

In [None]:
# Training pipeline: normalize with `avg`/`std`, then resize to resize x resize.
train_transform = A.Compose(
    [
        A.Normalize(mean=avg, std=std, max_pixel_value=255, always_apply=True, p=1),
        A.Resize(resize, resize, always_apply=True, p=1),
    ]
)

# Source domain for this step: the Urban split of the training directory.
train_dataset = LoveDA(
    TRAIN_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=URBAN_PATH,
    transforms=train_transform,
    bd=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

#### Validation set

In [None]:
# Validation images are only normalized (no resize).
val_transform = A.Compose(
    [A.Normalize(mean=avg, std=std, max_pixel_value=255, always_apply=True, p=1)]
)

# Validation uses the Rural split, i.e. a different domain from the Urban training data.
val_dataset = LoveDA(
    VAL_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=RURAL_PATH,
    transforms=val_transform,
    bd=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

### Training process

#### Model engine

In [None]:
# PIDNet-S backbone wrapped by FullModel, which receives the semantic and boundary losses.
pidnet = get_seg_model("pidnet_s", num_classes, PIDNET_S_WEIGHTS, imgnet_pretrained=True)
model = FullModel(
    pidnet,
    sem_loss=CrossEntropy(ignore_label=IGNORE_INDEX),
    bd_loss=BondaryLoss(),
)
model = model.to(device)

# SGD with momentum; StepLR multiplies the learning rate by GAMMA every STEP_SIZE epochs.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMEUNTUM,
    weight_decay=WEIGHT_DECAY,
)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

#### Evaluate function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device) -> tuple:
    """Evaluate `model` on `dataloader` and return sample-weighted averages.

    The model is switched to eval mode. For every batch the loss, the IoU and
    the per-class IoU returned by `calculate_iou` (computed on `outputs[1]`)
    are weighted by the batch size; the totals are divided by the number of
    samples seen.

    Args:
        model: callable as `model(inputs, masks, boundaries)` returning a
            4-tuple whose first two items are the loss and the outputs.
        dataloader: iterable of `(inputs, masks, boundaries)` batches.
        device: device every batch tensor is moved to.

    Returns:
        `(loss, mIoU, ious_per_class)`; `ious_per_class` is a CPU tensor with
        `num_classes` entries.
    """
    model.eval()

    loss_sum = 0.0
    iou_sum = 0.0
    n_samples = 0
    ious_per_class = torch.zeros(num_classes)

    for batch in dataloader:
        inputs, masks, boundaries = (t.to(device) for t in batch)
        batch_size = inputs.size(0)
        n_samples += batch_size

        # Forward pass; pixel accuracy and the per-term loss list are not needed here.
        loss, outputs, _, _ = model(inputs, masks, boundaries)
        loss_sum += loss.item() * batch_size

        # Batch IoU (overall and per class), weighted by the batch size.
        iou, iou_per_class = calculate_iou(outputs[1], masks, num_classes)
        iou_sum += iou * batch_size
        ious_per_class += iou_per_class.cpu() * batch_size

    mIoU = iou_sum / n_samples
    loss = loss_sum / n_samples
    ious_per_class /= n_samples

    return loss, mIoU, ious_per_class

#### Training/evaluation loop

In [None]:
# Per-epoch history containers.
val_losses, val_accuracies = [], []
train_losses, train_accuracies = [], []
miou_scores = []
best_mIoU = -1
best_num_epochs = None


for epoch in range(num_epochs):

    model.train()
    train_loss = 0.0

    for current_step, batch in enumerate(train_loader):
        inputs, masks, boundaries = (t.to(device) for t in batch)

        # Forward pass: the model returns the total loss plus its three components.
        optimizer.zero_grad()
        loss, outputs, pixel_acc, [loss_s, loss_b, loss_sb] = model(inputs, masks, boundaries)
        train_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()

        if current_step % log_frequency == 0:
            print(f"Epoch {epoch+1}, Iteration {current_step}, Loss: {loss.item():.3f} Loss_s: {loss_s.item():.3f} Loss_b: {loss_b.item():.3f} Loss_sb: {loss_sb.item():.3f}")

    # Mean loss per batch over the epoch.
    train_loss /= len(train_loader)

    print(f"End of Epoch {epoch+1}")
    print(f"Training loss: {train_loss:.5f}")

    val_loss, val_mean_iou, ious_per_class = evaluate(model, val_loader, device)

    print(f"Validation mIoU: {val_mean_iou*100:.3f}%, Validation loss: {val_loss:.5f}")

    # Per-category IoU, in the iteration order of `categories`.
    for i, cat in enumerate(categories.keys()):
        print(f"{cat} mIoU: {ious_per_class[i]*100:.3f}")

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    miou_scores.append(val_mean_iou)

    print()
    # Scheduler is None if learning rate is constant
    if scheduler is not None:
        scheduler.step()

# Convert the stored mIoU tensors to plain Python numbers.
miou_scores = [score.item() for score in miou_scores]

### Evaluate using saved models

In [None]:
# Re-evaluate the checkpoint of every epoch on the validation loader.
validation_losses = []
miou_scores = []

starting_epoch = 0

for epoch in range(starting_epoch, num_epochs):
    # The file name encodes the hyper-parameters of the run (GAMMA appears twice);
    # it has to match the name used when the checkpoints were written.
    model_path = f"/content/drive/MyDrive/loveDA_dataset/Model training/PIDNet_{num_epochs}_{LR}_{STEP_SIZE}_{GAMMA}_{resize}_{WEIGHT_DECAY}_{MOMEUNTUM}_{GAMMA}_epoch{epoch}.pth"

    pidnet = get_seg_model("pidnet_s", num_classes, PIDNET_S_WEIGHTS,imgnet_pretrained=True)
    model = FullModel(pidnet, sem_loss=CrossEntropy(ignore_label=IGNORE_INDEX), bd_loss=BondaryLoss())
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU runtime can still be loaded on a CPU-only one.
    model.load_state_dict(torch.load(model_path, map_location=device))

    model = model.to(device)

    val_loss, val_mean_iou, ious_per_class = evaluate(model, val_loader, device)

    validation_losses.append(val_loss)
    miou_scores.append(val_mean_iou)

    print(f"Epoch {epoch+1}")
    print(f"Validation mIoU: {val_mean_iou*100:.3f}%, Validation loss: {val_loss:.5f}")

    # Per-category IoU, in the iteration order of `categories`.
    for i, cat in enumerate(categories.keys()):
        print(f"{cat} mIoU: {ious_per_class[i]*100:.3f}")

    print()




# Step 3b: Data augmentations to reduce the domain shift

### Parameters

In [None]:
# --- Data ---
resize = 512           # crop size and final side length of the training images
BATCH_SIZE = 6
num_epochs = 20

# --- SGD optimizer / StepLR scheduler ---
LR = 1e-3
MOMEUNTUM = 0.9        # (sic) name is referenced as-is by the optimizer cell
WEIGHT_DECAY = 1e-2
STEP_SIZE = 10
GAMMA = 0.1

# The training loop of this step prints every `log_frequency` iterations.
# Defined here (same value as steps 2b and 3a) so the step does not depend on
# a parameter cell of another step having been executed first.
log_frequency = 25

### Dataset preprocessing

#### Normalization metrics

In [None]:
# Background DataLoader workers are only used when running on a CUDA device.
if device.type == 'cuda':
    num_workers = 2
else:
    num_workers = 0

# The model is pretrained on ImageNet, so the ImageNet mean and std are used.
avg = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

#### Augmentations

In [None]:
# Probability with which the composed augmentation block is applied to a sample.
aug_prob = 0.5

# Candidate augmentations. Each one always fires (p=1) once the block is applied.
# The list positions matter: `selected_indices` refers to them.
augmentations = [
    A.ShiftScaleRotate(p=1),                                   # 0
    A.GridDistortion(p=1),                                     # 1
    A.RandomCrop(height=resize, width=resize, p=1),            # 2
    A.HorizontalFlip(p=1),                                     # 3
    A.GaussianBlur(p=1),                                       # 4
    A.GridDropout(p=1),                                        # 5
    A.ColorJitter(p=1),                                        # 6
    A.GaussNoise(var_limit=(0.2, 0.3), p=1),                   # 7
    A.ChannelDropout(p=1),                                     # 8
    A.RandomSizedCrop(                                         # 9
        min_max_height=(resize//8, resize),
        height=resize,
        width=resize,
        p=1,
    ),
]

# Indices of the augmentations used in this run.
selected_indices = [2]

selected_augmentations = A.Compose(
    [augmentations[idx] for idx in selected_indices],
    p=aug_prob,
)

#### Training set

In [None]:
# Training pipeline: normalize, apply the selected augmentations, then resize
# so every sample ends up resize x resize.
train_transform = A.Compose(
    [
        A.Normalize(mean=avg, std=std, max_pixel_value=255, always_apply=True, p=1),
        selected_augmentations,
        A.Resize(resize, resize, always_apply=True, p=1),
    ]
)

# Source domain: the Urban split of the training directory.
train_dataset = LoveDA(
    TRAIN_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=URBAN_PATH,
    transforms=train_transform,
    bd=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

#### Validation set

In [None]:
# Validation images are only normalized (no resize, no augmentation).
val_transform = A.Compose(
    [A.Normalize(mean=avg, std=std, max_pixel_value=255, always_apply=True, p=1)]
)

# Validation uses the Rural split of the validation directory.
val_dataset = LoveDA(
    VAL_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=RURAL_PATH,
    transforms=val_transform,
    bd=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

### Training process

#### Model engine

In [None]:
# PIDNet-S backbone wrapped by FullModel, which receives the semantic and boundary losses.
pidnet = get_seg_model("pidnet_s", num_classes, PIDNET_S_WEIGHTS, imgnet_pretrained=True)
model = FullModel(
    pidnet,
    sem_loss=CrossEntropy(ignore_label=IGNORE_INDEX),
    bd_loss=BondaryLoss(),
)
model = model.to(device)
# SGD with momentum; StepLR multiplies the learning rate by GAMMA every STEP_SIZE epochs.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMEUNTUM,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

#### Evaluate function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device) -> tuple:
    """Evaluate `model` on `dataloader` and return sample-weighted averages.

    The model is switched to eval mode. For every batch the loss, the IoU and
    the per-class IoU returned by `calculate_iou` (computed on `outputs[1]`)
    are weighted by the batch size; the totals are divided by the number of
    samples seen.

    Args:
        model: callable as `model(inputs, masks, boundaries)` returning a
            4-tuple whose first two items are the loss and the outputs.
        dataloader: iterable of `(inputs, masks, boundaries)` batches.
        device: device every batch tensor is moved to.

    Returns:
        `(loss, mIoU, ious_per_class)`; `ious_per_class` is a CPU tensor with
        `num_classes` entries.
    """
    model.eval()

    loss_sum = 0.0
    iou_sum = 0.0
    n_samples = 0
    ious_per_class = torch.zeros(num_classes)

    for batch in dataloader:
        inputs, masks, boundaries = (t.to(device) for t in batch)
        batch_size = inputs.size(0)
        n_samples += batch_size

        # Forward pass; pixel accuracy and the per-term loss list are not needed here.
        loss, outputs, _, _ = model(inputs, masks, boundaries)
        loss_sum += loss.item() * batch_size

        # Batch IoU (overall and per class), weighted by the batch size.
        iou, iou_per_class = calculate_iou(outputs[1], masks, num_classes)
        iou_sum += iou * batch_size
        ious_per_class += iou_per_class.cpu() * batch_size

    mIoU = iou_sum / n_samples
    loss = loss_sum / n_samples
    ious_per_class /= n_samples

    return loss, mIoU, ious_per_class

#### Training/evaluation loop

In [None]:
# Per-epoch history containers.
val_losses, val_accuracies = [], []
train_losses, train_accuracies = [], []
miou_scores = []
miou_per_category = dict()   # category name -> list of per-epoch IoU values
best_mIoU = -1
best_num_epochs = None


for epoch in range(num_epochs):

    model.train()
    train_loss = 0.0

    for current_step, batch in enumerate(train_loader):
        inputs, masks, boundaries = (t.to(device) for t in batch)

        # Forward pass: the model returns the total loss plus its three components.
        optimizer.zero_grad()
        loss, outputs, pixel_acc, [loss_s, loss_b, loss_sb] = model(inputs, masks, boundaries)
        train_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()

        if current_step % log_frequency == 0:
            print(f"Epoch {epoch+1}, Iteration {current_step}, Loss: {loss.item():.3f} Loss_s: {loss_s.item():.3f} Loss_b: {loss_b.item():.3f} Loss_sb: {loss_sb.item():.3f}")

    # Mean loss per batch over the epoch.
    train_loss /= len(train_loader)

    print(f"End of Epoch {epoch+1}")
    print(f"Training loss: {train_loss:.5f}")

    val_loss, val_mean_iou, ious_per_class = evaluate(model, val_loader, device)

    print(f"Validation mIoU: {val_mean_iou*100:.3f}%, Validation loss: {val_loss:.5f}")

    # Record and print the IoU of every category for this epoch.
    for i, cat in enumerate(categories.keys()):
        miou_per_category.setdefault(cat, []).append(ious_per_class[i].item())
        print(f"{cat} mIoU: {ious_per_class[i]*100:.3f}%")

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    miou_scores.append(val_mean_iou)

    print()
    # Scheduler is None if learning rate is constant
    if scheduler is not None:
        scheduler.step()

# Convert the stored mIoU tensors to plain Python numbers.
miou_scores = [score.item() for score in miou_scores]

# Step 4a: Adversarial Domain Adaptation

### FC discriminator

In [None]:
class FCDiscriminator(nn.Module):
	"""Fully-convolutional discriminator that returns raw (un-normalized) scores.

	Four stride-2 4x4 convolutions, each followed by a LeakyReLU with slope
	0.2, widen the channels from `num_classes` to `ndf * 8`; a fifth stride-2
	convolution maps them to a single-channel score map.

	Args:
		num_classes: number of channels of the input tensor.
		ndf: number of filters of the first convolution.
	"""

	def __init__(self, num_classes, ndf = 64):
		super().__init__()

		# Shared settings of every convolution in the stack.
		conv_kwargs = dict(kernel_size=4, stride=2, padding=1)

		self.conv1 = nn.Conv2d(num_classes, ndf, **conv_kwargs)
		self.conv2 = nn.Conv2d(ndf, ndf*2, **conv_kwargs)
		self.conv3 = nn.Conv2d(ndf*2, ndf*4, **conv_kwargs)
		self.conv4 = nn.Conv2d(ndf*4, ndf*8, **conv_kwargs)
		self.classifier = nn.Conv2d(ndf*8, 1, **conv_kwargs)

		self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
		# These two modules are created but not applied in forward().
		self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
		self.sigmoid = nn.Sigmoid()

	def forward(self, x):
		"""Return the score map for `x`; no upsampling or sigmoid is applied."""
		for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
			x = self.leaky_relu(conv(x))
		return self.classifier(x)

## Run

### Parameters

In [None]:
# --- Data ---
BATCH_SIZE = 6
resize = 512           # crop size and final side length of the training images

# --- Segmentation model: SGD optimizer / StepLR scheduler ---
num_epochs = 20
LR = 1e-3
WEIGHT_DECAY = 1e-2
MOMEUNTUM = 0.9        # (sic) name is referenced as-is by the optimizer cell
STEP_SIZE = 10
GAMMA = 0.1

# --- Adversarial training ---
LAMBDA = 1e-3          # multiplies the adversarial loss in the segmentation-network step
LR_D = 1e-5            # discriminator (Adam) learning rate
WEIGHT_DECAY_D = 0     # discriminator weight decay

# Print the running losses every `log_frequency` training iterations.
log_frequency = 50

### Dataset preprocessing

#### Normalization metrics

In [None]:
# Background DataLoader workers are only used when running on a CUDA device.
if device.type == 'cuda':
    num_workers = 2
else:
    num_workers = 0

# The model is pretrained on ImageNet, so the ImageNet mean and std are used.
avg = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

#### Training set

In [None]:
# Source pipeline: normalize, randomly crop half of the samples, then resize
# so every sample ends up resize x resize.
train_transform = A.Compose(
    [
        A.Normalize(mean=avg, std=std, max_pixel_value=255, always_apply=True, p=1),
        A.RandomCrop(height=resize, width=resize, p=0.5),
        A.Resize(resize, resize, always_apply=True, p=1),
    ]
)

# Source domain: the Urban split of the training directory.
train_dataset = LoveDA(
    TRAIN_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=URBAN_PATH,
    transforms=train_transform,
    bd=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

In [None]:
# Target pipeline: normalize and resize (the random crop is left disabled here).
train_transform_target = A.Compose([
    A.Normalize(mean=avg, std=std, p=1, always_apply=True, max_pixel_value=255),
    #A.RandomCrop(height=resize, width=resize, p=0.5),
    A.Resize(resize, resize, p=1, always_apply=True)
])

# Target domain: the Rural split of the training directory.
train_dataset_target = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=RURAL_PATH, bd=True, transforms=train_transform_target)
# The loader must wrap the *target* dataset. It previously wrapped `train_dataset`
# (the Urban source data), so the discriminator never saw any Rural image.
train_loader_target = DataLoader(train_dataset_target, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)

#### Validation set

In [None]:
# Validation images are only normalized (no resize, no crop).
val_transform = A.Compose(
    [A.Normalize(mean=avg, std=std, max_pixel_value=255, always_apply=True, p=1)]
)

# Validation uses the Rural split of the validation directory.
val_dataset = LoveDA(
    VAL_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=RURAL_PATH,
    transforms=val_transform,
    bd=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

### Training process

#### Model engine

In [None]:
# Segmentation network: PIDNet-S wrapped by FullModel with the semantic and boundary losses.
pidnet = get_seg_model("pidnet_s", num_classes, PIDNET_S_WEIGHTS,imgnet_pretrained=True)
model = FullModel(pidnet, sem_loss=CrossEntropy(ignore_label=IGNORE_INDEX), bd_loss=BondaryLoss())
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum = MOMEUNTUM, weight_decay = WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = STEP_SIZE, gamma = GAMMA)

# The discriminator outputs raw scores, hence the logits variant of BCE.
domain_criterion = nn.BCEWithLogitsLoss()

# The discriminator is fed the softmax of the segmentation output, so its input
# channel count has to follow `num_classes` rather than a hard-coded literal.
model_domain = FCDiscriminator(num_classes=num_classes)
model_domain = model_domain.to(device)
domain_optimizer = torch.optim.Adam(model_domain.parameters(), lr = LR_D, weight_decay = WEIGHT_DECAY_D)

#### Evaluate function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device) -> tuple:
    """Evaluate `model` on `dataloader` and return sample-weighted averages.

    The model is switched to eval mode. For every batch the loss, the IoU and
    the per-class IoU returned by `calculate_iou` (computed on `outputs[1]`)
    are weighted by the batch size; the totals are divided by the number of
    samples seen.

    Args:
        model: callable as `model(inputs, masks, boundaries)` returning a
            4-tuple whose first two items are the loss and the outputs.
        dataloader: iterable of `(inputs, masks, boundaries)` batches.
        device: device every batch tensor is moved to.

    Returns:
        `(loss, mIoU, ious_per_class)`; `ious_per_class` is a CPU tensor with
        `num_classes` entries.
    """
    model.eval()

    loss_sum = 0.0
    iou_sum = 0.0
    n_samples = 0
    ious_per_class = torch.zeros(num_classes)

    for batch in dataloader:
        inputs, masks, boundaries = (t.to(device) for t in batch)
        batch_size = inputs.size(0)
        n_samples += batch_size

        # Forward pass; only the loss and the outputs are needed.
        loss, outputs, _, _ = model(inputs, masks, boundaries)
        loss_sum += loss.item() * batch_size

        # Batch IoU (overall and per class), weighted by the batch size.
        iou, iou_per_class = calculate_iou(outputs[1], masks, num_classes)
        iou_sum += iou * batch_size
        ious_per_class += iou_per_class.cpu() * batch_size

    mIoU = iou_sum / n_samples
    loss = loss_sum / n_samples
    ious_per_class /= n_samples

    return loss, mIoU, ious_per_class

#### Training/evaluation loop

In [None]:
# Per-epoch history containers.
val_losses, val_accuracies = [], []
train_losses, train_accuracies = [], []
miou_scores = []
best_mIoU = -1
best_num_epochs = None

for epoch in range(num_epochs):

    current_step = 0
    running_source_loss_seg = 0.0

    # Running sums of the adversarial loss (segmentation side, "G") and of the
    # discriminator loss ("D").
    loss_G, loss_D = 0, 0

    model.train()
    model_domain.train()


    # zip pairs one source batch with one target batch and stops at the shorter loader.
    for (inputs, masks, boundaries), (target_inputs, _, _) in zip(train_loader, train_loader_target):

        inputs = inputs.to(device)
        masks = masks.to(device)
        boundaries = boundaries.to(device)
        target_inputs = target_inputs.to(device)

        # Train G
        # Freeze the discriminator so this phase produces no gradients for it.
        for param in model_domain.parameters():
            param.requires_grad = False

        optimizer.zero_grad()
        domain_optimizer.zero_grad()

        ## train with source
        source_loss, [_, source_PIDNET_output, _], _, _ = model(inputs, masks, boundaries)
        source_loss.backward()
        running_source_loss_seg += source_loss.item()

        ## train with target
        _, [_, target_PIDNET_output, _], _, _ = model(target_inputs, None, None)
        preds = F.softmax(target_PIDNET_output, dim=1)
        D_out = model_domain(preds)

        # Target predictions are labelled 0 (the source label used below), so the
        # segmentation network is pushed to make them look like source predictions.
        domain_loss = LAMBDA * domain_criterion(D_out, torch.zeros_like(D_out))
        domain_loss.backward()
        loss_G += domain_loss.item()

        # Train D

        for param in model_domain.parameters():
            param.requires_grad = True

        ## train with source
        # Detached: the discriminator update must not back-propagate into the segmentation model.
        source_PIDNET_output = source_PIDNET_output.detach()
        preds = F.softmax(source_PIDNET_output, dim=1)
        D_out = model_domain(preds)

        # Domain label 0 = source; each half of the discriminator loss is scaled by 1/2.
        domain_loss = domain_criterion(D_out, torch.zeros_like(D_out))
        domain_loss = domain_loss / 2
        domain_loss.backward()
        loss_D += domain_loss.item()

        ## train with target
        target_PIDNET_output = target_PIDNET_output.detach()
        preds = F.softmax(target_PIDNET_output, dim=1)
        D_out = model_domain(preds)

        # Domain label 1 = target.
        domain_loss = domain_criterion(D_out, torch.ones_like(D_out))
        domain_loss = domain_loss / 2
        domain_loss.backward()
        loss_D += domain_loss.item()

        # Clip both networks' gradients (L2 norm, max 35) before stepping.
        clip_grad.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), max_norm=35, norm_type=2)
        clip_grad.clip_grad_norm_(filter(lambda p: p.requires_grad, model_domain.parameters()), max_norm=35, norm_type=2)
        optimizer.step()
        domain_optimizer.step()


        if current_step % log_frequency == 0:
            # Running means so far; format spec fixed to `.5f` and the parenthesis closed.
            print(f"Epoch {epoch+1}, Iteration {current_step}, Source loss: {running_source_loss_seg/(current_step+1):.5f}, Domain loss: {loss_G/(current_step+1):.5f} ({loss_D/(current_step+1):.5f})")
        current_step += 1

    train_loss = running_source_loss_seg/len(train_loader)
    train_domain_loss_G = loss_G/len(train_loader)
    train_domain_loss_D = loss_D/len(train_loader)

    print(f"End of Epoch {epoch+1}")
    print(f"Training loss: {train_loss:.5f}")
    print(f"Domain loss G: {train_domain_loss_G:.5f}")
    print(f"Domain loss D: {train_domain_loss_D:.5f}")

    val_loss, val_mean_iou, ious_per_class = evaluate(model, val_loader, device)

    # Save the segmentation model after every epoch; the file name encodes the hyper-parameters.
    path=f"/content/drive/MyDrive/loveDA_dataset/Model training/PIDNet/PIDNet_{num_epochs}_{LR}_{STEP_SIZE}_{GAMMA}_{resize}_{WEIGHT_DECAY}_{MOMEUNTUM}_{GAMMA}_epoch{epoch}_DA.pth"
    torch.save(model.state_dict(), path)
    print(f"Validation mIoU: {val_mean_iou*100:.3f}%, Validation loss: {val_loss:.5f}")

    for i, cat in enumerate(categories.keys()):
        print(f"{cat} mIoU: {ious_per_class[i]*100:.3f}")

    val_losses.append(val_loss)
    train_losses.append(train_loss)
    miou_scores.append(val_mean_iou)

    print()
    # Scheduler is None if learning rate is constant
    if scheduler is not None:
        scheduler.step()

# Convert the stored mIoU tensors to percentages as plain Python numbers.
miou_scores = list(map(lambda x: x.item()*100, miou_scores))

# Step 4b: Image-to-Image Domain Adaptation




## Mix & EMA model

In [None]:
def oneMix(mask, data = None, target = None):
    """Blend the first two entries of `data` and/or `target` using `mask[0]`.

    Where the mask is 1 the result takes entry 0, where it is 0 it takes
    entry 1. Each blended result gets a new leading dimension of size 1.
    Arguments passed as None are returned as None.

    Returns:
        `(data, target)` after mixing.
    """
    def _blend(pair):
        # Expand the mask to the shape of one entry before combining the two.
        keep, _ = torch.broadcast_tensors(mask[0], pair[0])
        return (keep*pair[0]+(1-keep)*pair[1]).unsqueeze(0)

    if data is not None:
        data = _blend(data)
    if target is not None:
        target = _blend(target)
    return data, target


def generate_class_mask(pred, classes):
    """Count, for every element of `pred`, how many entries of `classes` equal it.

    `pred` gets a leading axis and `classes` two trailing axes so that they
    broadcast against each other; the matches are summed over the class axis.
    With distinct class ids the result is a 0/1 mask.
    """
    matches = pred[None].eq(classes[:, None, None])
    return matches.sum(0)


def mix(parameters, data=None, target=None):
    """Mix `data` and/or `target` with the mask stored under `parameters["Mix"]`.

    At least one of `data` and `target` must be given. The work is delegated
    to `oneMix`, whose `(data, target)` pair is returned.
    """
    assert not (data is None and target is None)
    mixed_data, mixed_target = oneMix(mask=parameters["Mix"], data=data, target=target)
    return mixed_data, mixed_target

def update_ema_variables(ema_model, model, alpha_teacher, iteration):
    """Move the parameters of `ema_model` towards those of `model`.

    Each EMA parameter becomes `a * ema + (1 - a) * param`, where `a` is the
    smaller of `alpha_teacher` and `1 - 1 / (iteration + 1)`. For early
    iterations this gives a plain running average (at iteration 0 the EMA
    parameters are simply overwritten).

    Returns:
        The updated `ema_model` (modified in place).
    """
    # Use the "true" average until the exponential average is more correct
    keep = min(1 - 1 / (iteration + 1), alpha_teacher)
    fresh = 1 - keep

    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data[:] = keep * ema_param.data + fresh * param.data
    return ema_model

def create_ema_model(model):
    """Build a second PIDNet-S `FullModel` whose parameters copy those of `model`.

    The new model's parameters are detached in place and then overwritten,
    position by position, with clones of the parameters of `model`.

    Returns:
        The newly created model (not moved to any device here).
    """
    backbone = get_seg_model("pidnet_s", num_classes, PIDNET_S_WEIGHTS,imgnet_pretrained=True)
    ema_model = FullModel(backbone, sem_loss=CrossEntropy(ignore_label=IGNORE_INDEX), bd_loss=BondaryLoss())

    # Detach every parameter of the copy in place.
    for ema_param in ema_model.parameters():
        ema_param.detach_()

    # Copy the values over, matching parameters by position.
    source_params = list(model.parameters())
    ema_params = list(ema_model.parameters())
    for idx in range(len(source_params)):
        ema_params[idx].data[:] = source_params[idx].data[:].clone()
    return ema_model

## Unlabeled loss

In [None]:
def calc_U_loss(outputs):
  """Return semantic + boundary + boundary-aware loss for the mixed batch.

  Reads the notebook-level names `targets_u` (labels of the mixed batch),
  `sem_loss`, `bd_loss`, `generate_bd`, `device` and `IGNORE_INDEX`; only
  `outputs` is passed explicitly. The last entry of `outputs` is used as the
  boundary prediction, the others as semantic predictions.
  """
  # Semantic loss on every output except the last one.
  semantic_term = sem_loss(outputs[:-1], targets_u)

  # Boundary ground truth generated from each label map of the mixed batch.
  labels_np = targets_u.cpu().numpy()
  bd_gt = np.zeros_like(labels_np, dtype=np.float32)
  for idx in range(len(labels_np)):
    bd_gt[idx] = generate_bd(labels_np[idx].astype(np.uint8))
  bd_gt = torch.from_numpy(bd_gt).to(device)

  boundary_term = bd_loss(outputs[-1], bd_gt)

  # Keep the labels only where the predicted boundary probability exceeds 0.8;
  # everything else is set to IGNORE_INDEX (from original configs).
  ignored = torch.ones_like(targets_u) * IGNORE_INDEX
  boundary_prob = F.sigmoid(outputs[-1][:,0,:,:])
  bd_label = torch.where(boundary_prob>0.8, targets_u, ignored)

  boundary_aware_term = sem_loss([outputs[-2]], bd_label)

  return semantic_term + boundary_term + boundary_aware_term

## Run

### Parameters

In [None]:
# --- Data ---
BATCH_SIZE = 6
resize = 512           # crop size and final side length of the training images

# --- SGD optimizer ---
LR = 1e-3
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-2

# --- Schedule ---
num_epochs = 20
STEP_SIZE = 10         # StepLR: epochs between learning-rate decays
GAMMA = 0.1            # StepLR: multiplicative decay factor

# Weighting of the unlabeled loss; the training loop handles
# "threshold_uniform", "threshold" and False.
pixel_weight = "threshold_uniform"
#pixel_weight = False

### Dataset preprocessing

#### Normalization metrics

In [None]:
# Background DataLoader workers are only used when running on a CUDA device.
if device.type == 'cuda':
    num_workers = 2
else:
    num_workers = 0

# The model is pretrained on ImageNet, so the ImageNet mean and std are used.
avg = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

#### Training set

In [None]:
# Source pipeline: normalize, randomly crop half of the samples, then resize
# so every sample ends up resize x resize.
train_transform = A.Compose([
    A.Normalize(mean=avg, std=std, p=1, always_apply=True, max_pixel_value=255),
    A.RandomCrop(height=resize, width=resize, p=0.5),
    A.Resize(resize, resize, p=1, always_apply=True)
])

# Target pipeline: same steps as the source pipeline.
train_transform_target = A.Compose([
    A.Normalize(mean=avg, std=std, p=1, always_apply=True, max_pixel_value=255),
    A.RandomCrop(height=resize, width=resize, p=0.5),
    A.Resize(resize, resize, p=1, always_apply=True)
])

# Source domain: Urban split of the training directory.
train_dataset = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH, bd=True, transforms=train_transform)
source_trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)
# Target domain: Rural split of the training directory.
train_dataset_target = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=RURAL_PATH, bd=True, transforms=train_transform_target)
# The loader must wrap the *target* dataset. It previously wrapped `train_dataset`
# (the Urban source data), so source images were mixed with other source images.
target_trainloader = DataLoader(train_dataset_target, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)

#### Validation set

In [None]:
# Resizing is applied to the training set only;
# normalization, on the other hand, can be applied here as well.
val_transform = A.Compose(
    [A.Normalize(mean=avg, std=std, max_pixel_value=255, always_apply=True, p=1)]
)

# Validation uses the Rural split of the validation directory.
val_dataset = LoveDA(
    VAL_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=RURAL_PATH,
    transforms=val_transform,
    bd=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

### Training process

#### Model engine

In [None]:
# Student network: PIDNet-S wrapped by FullModel with the semantic and boundary losses.
pidnet = get_seg_model("pidnet_s", num_classes, PIDNET_S_WEIGHTS, imgnet_pretrained=True)
model = FullModel(
    pidnet,
    sem_loss=CrossEntropy(ignore_label=IGNORE_INDEX),
    bd_loss=BondaryLoss(),
)
model.to(device)

# EMA copy of the student, created from its current parameters.
ema_model = create_ema_model(model)
ema_model = ema_model.to(device)

# SGD with momentum; StepLR multiplies the learning rate by GAMMA every STEP_SIZE epochs.
optimizer = optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

# Stand-alone loss objects read by calc_U_loss.
sem_loss = CrossEntropy(ignore_label=IGNORE_INDEX)
bd_loss = BondaryLoss()


#### Evaluate function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device) -> tuple:
    """Evaluate `model` on `dataloader` and return sample-weighted averages.

    The model is switched to eval mode. For every batch the loss, the IoU and
    the per-class IoU returned by `calculate_iou` (computed on `outputs[1]`)
    are weighted by the batch size; the totals are divided by the number of
    samples seen.

    Args:
        model: callable as `model(inputs, masks, boundaries)` returning a
            4-tuple whose first two items are the loss and the outputs.
        dataloader: iterable of `(inputs, masks, boundaries)` batches.
        device: device every batch tensor is moved to.

    Returns:
        `(loss, mIoU, ious_per_class)`; `ious_per_class` is a CPU tensor with
        `num_classes` entries.
    """
    model.eval()

    loss_sum = 0.0
    iou_sum = 0.0
    n_samples = 0
    ious_per_class = torch.zeros(num_classes)

    for batch in dataloader:
        inputs, masks, boundaries = (t.to(device) for t in batch)
        batch_size = inputs.size(0)
        n_samples += batch_size

        # Forward pass; only the loss and the outputs are needed.
        loss, outputs, _, _ = model(inputs, masks, boundaries)
        loss_sum += loss.item() * batch_size

        # Batch IoU (overall and per class), weighted by the batch size.
        iou, iou_per_class = calculate_iou(outputs[1], masks, num_classes)
        iou_sum += iou * batch_size
        ious_per_class += iou_per_class.cpu() * batch_size

    mIoU = iou_sum / n_samples
    loss = loss_sum / n_samples
    ious_per_class /= n_samples

    return loss, mIoU, ious_per_class

#### Training/evaluation loop

In [None]:
ema_model.train()

# Per-epoch history of the labeled (source) loss, the unlabeled (mixed) loss,
# their sum, and the validation metrics.
accumulated_loss_l = []
accumulated_loss_u = []

miou_scores = []
training_losses = []
validation_losses = []

for epoch in range(num_epochs):
    model.train()

    loss_u_value = 0
    loss_l_value = 0

    # Number of batches processed in this epoch.
    n = 0
    # zip pairs one source batch with one target batch and stops at the shorter loader.
    for (src_images, src_labels, src_bd), (tgt_images, _, _) in zip(source_trainloader, target_trainloader):
        optimizer.zero_grad()

        src_images = src_images.to(device)
        src_labels = src_labels.to(device)
        tgt_images = tgt_images.to(device)
        src_bd = src_bd.to(device)

        # Supervised loss on the labeled source batch.
        L_l, [_, pred, _], _, _ = model(src_images, src_labels, src_bd)

        # Pseudo-labels for the target batch come from the model itself
        # (the EMA-model variant is left disabled).
        # _, [_, logits_u_w, _], _, _ = ema_model(tgt_images, None, None)
        _, [_, logits_u_w, _], _, _ = model(tgt_images, None, None)

        pseudo_label = torch.softmax(logits_u_w.detach(), dim=1)
        max_probs, targets_u_w = torch.max(pseudo_label, dim=1)

        inputs_u_s = []
        targets_u = []
        pixel_weights = []
        # One mixing mask per sample, kept so the weight mixing below can reuse
        # exactly the mask that produced the corresponding mixed image.
        mix_masks = []

        for i in range(len(src_images)):
            # Randomly pick half (rounded up) of the classes present in the source label.
            classes = torch.unique(src_labels[i])
            nclasses = classes.shape[0]
            classes = (classes[torch.Tensor(np.random.choice(nclasses, int((nclasses+nclasses%2)/2),replace=False)).long()]).to(device)
            MixMask_i = generate_class_mask(src_labels[i], classes).unsqueeze(0).to(device)
            mix_masks.append(MixMask_i)

            strong_parameters = {"Mix": MixMask_i}

            # Selected source pixels are pasted onto the target image / pseudo-label.
            inputs_u_si, _ = mix(strong_parameters, data = torch.cat((src_images[i].unsqueeze(0),tgt_images[i].unsqueeze(0))))
            inputs_u_s.append(inputs_u_si)

            _, targets_ui = mix(strong_parameters, target = torch.cat((src_labels[i].unsqueeze(0),targets_u_w[i].unsqueeze(0))))
            targets_u.append(targets_ui)

        inputs_u_s = torch.cat(inputs_u_s)
        _, outputs, _, _ = model(inputs_u_s, None, None)
        logits_u_s = outputs[1]

        # calc_U_loss reads `targets_u` from the notebook namespace.
        targets_u = torch.cat(targets_u).long().to(device)


        if pixel_weight == "threshold_uniform":
            # Same weight everywhere: the fraction of pseudo-labels with confidence >= 0.968.
            unlabeled_weight = torch.sum(max_probs.ge(0.968).long() == 1).item() / np.size(np.array(targets_u.cpu()))
            pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape).to(device)
        elif pixel_weight == "threshold":
            pixelWiseWeight = max_probs.ge(0.968).float().to(device)
        elif pixel_weight == False:
            pixelWiseWeight = torch.ones(max_probs.shape).to(device)
        else:
            # Previously an unknown value left pixelWiseWeight undefined (or stale).
            raise ValueError(f"Unsupported pixel_weight: {pixel_weight!r}")


        # Source pixels get weight 1, target pixels keep their pseudo-label weight.
        # Each sample is mixed with its own mask and its own weight maps; the
        # previous code reused the last sample's mask and index 0 for every sample.
        onesWeights = torch.ones((pixelWiseWeight.shape)).to(device)
        for i in range(len(src_images)):
            strong_parameters = {"Mix": mix_masks[i]}
            _, pixelWiseWeight_i = mix(strong_parameters, target = torch.cat((onesWeights[i].unsqueeze(0),pixelWiseWeight[i].unsqueeze(0))))
            pixel_weights.append(pixelWiseWeight_i)


        pixel_weights = torch.cat(pixel_weights).to(device)


        # Unlabeled loss on the mixed batch, scaled by the mean pixel weight.
        L_u = calc_U_loss(outputs)
        L_u *= torch.mean(pixel_weights)

        loss = L_l + L_u

        loss_l_value += L_l.item()
        loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()


        if n %25 == 0:
          print('\tProcessed {0:d} batches, loss_l = {1:.3f}, loss_u = {2:.3f} loss = {3:.3f}'.format(n, loss_l_value/(n+1), loss_u_value/(n+1),(loss_l_value+loss_u_value)/(n+1)))

        n+=1

    # Average over the batches actually processed: zip stops at the shorter
    # loader, so neither loader length is reliable once the two differ.
    num_batches = max(n, 1)
    loss_l_value /= num_batches
    loss_u_value /= num_batches

    accumulated_loss_l.append(loss_l_value)
    accumulated_loss_u.append(loss_u_value)
    training_losses.append(loss_l_value+loss_u_value)

    # Update learning rate
    scheduler.step()

    # update Mean teacher network (once per epoch, with the epoch index as iteration)
    alpha_teacher = 0.99
    ema_model = update_ema_variables(ema_model = ema_model, model = model, alpha_teacher=alpha_teacher, iteration=epoch)

    print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f} loss = {4:.3f}'.format(epoch+1, num_epochs, loss_l_value, loss_u_value, loss_l_value+loss_u_value))



    val_loss, val_mean_iou, ious_per_class = evaluate(model, val_loader, device)
    print(f"Validation mIoU: {val_mean_iou*100:.3f}%, Validation loss: {val_loss:.5f}")
    for i, cat in enumerate(categories.keys()):
        print(f"{cat} mIoU: {ious_per_class[i]*100:.3f}")
    print()

    validation_losses.append(val_loss)
    miou_scores.append(val_mean_iou)

In [None]:
# Show, for every sample of the last training batch, the source image, the
# target image and the mixed image (normalization undone with std/avg),
# followed by the source labels, the pseudo-labels and the mixed labels.
std_t = torch.tensor(std)
avg_t = torch.tensor(avg)

for i in range(BATCH_SIZE):
  for image_batch in (src_images, tgt_images, inputs_u_s):
    plt.imshow(image_batch[i].permute(1,2,0).cpu()*std_t+avg_t)
    plt.show()

  for label_batch in (src_labels, targets_u_w, targets_u):
    plot_tensor_mask(label_batch[i].cpu(), categories)

# Step 5: Improving the results

### Calculate class distribution

In [None]:
# Datasets and loaders over the Urban and Rural training splits.
urban_dataset = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH)
urban_loader = DataLoader(urban_dataset, batch_size=64, worker_init_fn=seed_worker, generator=g)

rural_dataset = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=RURAL_PATH)
rural_loader = DataLoader(rural_dataset, batch_size=64, worker_init_fn=seed_worker, generator=g)

# category name -> number of mask pixels equal to that category's position in `categories`
urban_classes = dict()
rural_classes = dict()

# Urban first, then Rural: accumulate the per-category pixel counts batch by batch.
for loader, class_counts in ((urban_loader, urban_classes), (rural_loader, rural_classes)):
    for (_, masks) in loader:

        masks = masks.to(device)

        for i, cat in enumerate(categories.keys()):
            pixel_count = torch.count_nonzero(masks == i)
            if cat not in class_counts:
                class_counts[cat] = pixel_count
            else:
                class_counts[cat] += pixel_count


In [None]:
# One colour per category: the second item of each sorted `categories` value, scaled by 1/255.
colors= [np.array(color)/255 for _, color in sorted(categories.values())]

# Draw one pie chart per domain and print the percentages shown on its wedges.
for class_counts, chart_title, print_label in (
    (urban_classes, 'Urban Dataset', "urban_percentage = "),
    (rural_classes, "Rural dataset", "rural_percentage = "),
):
    wedges, texts, autotexts = plt.pie(
        [v.cpu().numpy() for v in class_counts.values()],
        labels=class_counts.keys(),
        colors=colors,
        autopct='%1.1f%%',
        pctdistance=0.85,
        labeldistance=1.1,
        startangle=90,
    )
    for text in texts:
        text.set_fontsize(12)
    for autotext in autotexts:
        autotext.set_fontsize(9)
    plt.title(chart_title, fontsize=16)
    plt.tight_layout()
    plt.show()

    # The percentages are read back from the wedge labels (one decimal place).
    print(print_label, [float(autotext.get_text().strip('%')) for autotext in autotexts])



### Calculate class weights

In [None]:
import numpy as np

def calc_weights(percentages, alpha=0.5):
    """Derive several inverse-frequency class-weight variants from class percentages.

    Parameters
    ----------
    percentages : sequence of float
        Share of each class expressed in percent (e.g. ``48.5`` for 48.5 %).
        Every entry must be strictly positive because each weight divides by it.
    alpha : float, optional
        Exponent of the softened weights, ``1 / proportion ** alpha``.
        ``alpha=1`` reproduces the plain inverse frequency; smaller values
        flatten the weights. Defaults to ``0.5``.

    Returns
    -------
    tuple of four lists (one entry per class, in input order)
        ``class_weights``          -- ``1 / proportion``
        ``normalized_weights``     -- ``class_weights`` rescaled to mean 1
        ``softened_weights``       -- ``1 / proportion ** alpha`` rescaled to mean 1
        ``normalized_weights_v2``  -- ``class_weights`` rescaled so the largest is 1

    Raises
    ------
    ValueError
        If ``percentages`` is empty or contains a value that is not strictly
        positive (such a share would yield an infinite or NaN weight).
    """
    percentages = np.asarray(percentages, dtype=float)
    # `not all(> 0)` also rejects NaN entries, which `any(<= 0)` would let through.
    if percentages.size == 0 or not np.all(percentages > 0):
        raise ValueError("percentages must be a non-empty sequence of strictly positive values")

    proportions = percentages / 100  # percentages -> fractions

    # Weights inversely proportional to the class frequency.
    class_weights = 1 / proportions

    # Same weights rescaled so that their mean is 1.
    normalized_weights = class_weights / np.mean(class_weights)

    # Softened variant: the exponent alpha < 1 reduces the spread between rare and frequent classes.
    softened_weights = 1 / (proportions ** alpha)
    softened_weights /= np.mean(softened_weights)

    # Same weights rescaled so that the rarest class gets weight 1.
    normalized_weights_v2 = class_weights / np.max(class_weights)

    return list(class_weights), list(normalized_weights), list(softened_weights), list(normalized_weights_v2)



# Class shares in percent, one entry per category.
# NOTE(review): presumably pasted from the output printed by the pie-chart cell -- confirm they are still current.
urban_percentage = [48.5, 21.2, 9.3, 3.7, 7.6, 7.9, 1.9]
rural_percentage = [42.9, 3.7, 2.6, 11.6, 3.6, 5.0, 30.5]

# calc_weights returns four weight variants per domain.
(
    urban_class_weights,
    urban_normalized_weights,
    urban_softened_weights,
    urban_normalized_weights_v2,
) = calc_weights(urban_percentage)
(
    rural_class_weights,
    rural_normalized_weights,
    rural_softened_weights,
    rural_normalized_weights_v2,
) = calc_weights(rural_percentage)

# One line per variant, urban block first, then an empty line, then the rural block.
print(
    f"urban_class_weights = {urban_class_weights}",
    f"urban_normalized_weights = {urban_normalized_weights}",
    f"urban_softened_weights = {urban_softened_weights}",
    f"urban_normalized_weights_v2 = {urban_normalized_weights_v2}",
    sep="\n",
)

print()

print(
    f"rural_class_weights = {rural_class_weights}",
    f"rural_normalized_weights = {rural_normalized_weights}",
    f"rural_softened_weights = {rural_softened_weights}",
    f"rural_normalized_weights_v2 = {rural_normalized_weights_v2}",
    sep="\n",
)

We then passed **urban_softened_weights** to the Cross Entropy loss and trained the model. The training loop is not reported again.

In [None]:
# PIDNet-S network built by get_seg_model from the PIDNET_S_WEIGHTS checkpoint.
pidnet = get_seg_model("pidnet_s", num_classes, PIDNET_S_WEIGHTS,imgnet_pretrained=True)
# urban_softened_weights is a list of NumPy scalars: the dtype is pinned to float32 so the
# class weights do not depend on torch's dtype inference (assumes the model runs in the
# default float32 -- adjust if it is trained in another precision).
model = FullModel(
    pidnet,
    sem_loss=CrossEntropy(
        ignore_label=IGNORE_INDEX,
        weight=torch.tensor(urban_softened_weights, dtype=torch.float32),
    ),
    bd_loss=BondaryLoss(),
)
model = model.to(device)

## OHEM Loss

In [None]:
class OhemCrossEntropy(nn.Module):
    """Cross-entropy with Online Hard Example Mining (OHEM).

    Only "hard" pixels contribute to the OHEM loss: those whose predicted
    probability for the ground-truth class lies below a threshold. The
    threshold is the larger of ``thres`` and the probability found at sorted
    position ``min_kept`` (clamped to the last valid pixel).

    Parameters
    ----------
    ignore_label : int
        Target value excluded from the loss.
    thres : float
        Lower bound of the probability threshold.
    min_kept : int
        Sorted position used to raise the threshold; values below 1 are
        replaced by 1.
    weight : Tensor or None
        Optional per-class weights forwarded to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, ignore_label=-1, thres=0.7,
                 min_kept=100000, weight=None):
        super(OhemCrossEntropy, self).__init__()
        self.thresh = thres
        self.min_kept = max(1, min_kept)
        self.ignore_label = ignore_label
        # reduction='none' keeps one loss value per pixel so that hard pixels can be selected.
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_label,
            reduction='none'
        )

    def _ce_forward(self, score, target):
        """Plain cross-entropy averaged over every pixel of ``target``.

        Ignored pixels have a per-pixel loss of 0 but still count in the mean.
        """
        loss = self.criterion(score, target)

        return loss.mean()

    def _ohem_forward(self, score, target, **kwargs):
        """Mean cross-entropy over the hard pixels only.

        Returns a zero loss (still attached to the graph) when there is nothing
        to average: either every pixel is ignored, or no valid pixel falls
        below the threshold. ``kwargs`` is accepted for call compatibility and
        not used.
        """
        pixel_losses = self.criterion(score, target).contiguous().view(-1)
        mask = target.contiguous().view(-1) != self.ignore_label

        # All pixels ignored: indexing the empty sorted tensor below would raise.
        if not mask.any():
            return pixel_losses.sum() * 0.0

        pred = F.softmax(score, dim=1)

        # gather needs valid class indices, so ignored pixels are temporarily mapped
        # to class 0; they are dropped again by `mask` right after.
        tmp_target = target.clone()
        tmp_target[tmp_target == self.ignore_label] = 0
        pred = pred.gather(1, tmp_target.unsqueeze(1))
        # Ground-truth probabilities of the valid pixels, ascending (hardest first).
        pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort()

        min_value = pred[min(self.min_kept, pred.numel() - 1)]

        threshold = max(min_value, self.thresh)

        pixel_losses = pixel_losses[mask][ind]
        hard_losses = pixel_losses[pred < threshold]
        # No pixel under the threshold: the mean of an empty tensor would be NaN.
        if hard_losses.numel() == 0:
            return pixel_losses.sum() * 0.0
        return hard_losses.mean()

    def forward(self, score, target):
        """Combine the losses of one or two prediction maps.

        ``score`` may be a single tensor or a list/tuple of tensors. With two
        maps, the first gets the plain cross-entropy (weight 0.4) and the last
        the OHEM loss (weight 1.0); with one map only the OHEM loss is used.

        Raises
        ------
        ValueError
            If ``score`` holds neither one nor two prediction maps.
        """
        if not isinstance(score, (list, tuple)):
            score = [score]

        balance_weights = [0.4, 1.0]
        sb_weights = 1.0

        if len(balance_weights) == len(score):
            functions = [self._ce_forward] * (len(balance_weights) - 1) + [self._ohem_forward]
            return sum([w * func(x, target) for (w, x, func) in zip(balance_weights, score, functions)])

        elif len(score) == 1:
            return sb_weights * self._ohem_forward(score[0], target)

        else:
            raise ValueError("lengths of prediction and target are not identical!")


## Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Focal loss for dense prediction.

    For every non-ignored pixel the cross-entropy ``ce = -log(pt)`` is scaled
    by the modulating factor ``(1 - pt) ** gamma``, where ``pt`` is the
    predicted probability of the ground-truth class. With ``gamma=0`` the loss
    equals the (optionally class-weighted) mean cross-entropy.

    Parameters
    ----------
    gamma : float
        Focusing exponent; larger values down-weight well-classified pixels.
    alpha : float or None
        Optional global scale applied to the final loss.
    ignore_label : int
        Target value excluded from the loss.
    weight : Tensor or None
        Optional per-class weights. The result is normalised by the summed
        weights of the valid pixels, like ``nn.CrossEntropyLoss`` with
        ``reduction='mean'``.
    """

    def __init__(self, gamma=0, alpha=None, ignore_label=-1, weight=None):
        super(FocalLoss, self).__init__()
        self.ignore_label = ignore_label
        self.gamma = gamma
        self.alpha = alpha
        # Kept as a sub-module so that `weight` is registered as a buffer and follows
        # the module across devices; _forward reads the weights back from it.
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_label
        )

    def _forward(self, score, target):
        """Focal loss of a single prediction map; zero when every pixel is ignored."""
        valid = target != self.ignore_label

        # Unweighted per-pixel cross-entropy of the valid pixels, i.e. -log(pt).
        ce_loss = F.cross_entropy(
            score, target, ignore_index=self.ignore_label, reduction='none'
        )[valid]
        if ce_loss.numel() == 0:
            # Sum of an empty tensor: a zero that is still attached to the graph.
            return ce_loss.sum()

        pt = torch.exp(-ce_loss)
        # The modulating factor must multiply the cross-entropy itself.
        modulated = torch.pow(1 - pt, self.gamma) * ce_loss

        class_weight = self.criterion.weight
        if class_weight is None:
            focal_loss = modulated.mean()
        else:
            pixel_weight = class_weight[target[valid]]
            focal_loss = (pixel_weight * modulated).sum() / pixel_weight.sum()

        if self.alpha is not None:
            return self.alpha * focal_loss
        return focal_loss

    def forward(self, score, target):
        """Combine the focal losses of one or two prediction maps.

        ``score`` may be a single tensor or a list/tuple of tensors. Two maps
        are weighted 0.4 and 1.0; a single map is weighted 1.0.

        Raises
        ------
        ValueError
            If ``score`` holds neither one nor two prediction maps.
        """
        if not isinstance(score, (list, tuple)):
            score = [score]

        # From original configs
        balance_weights = [0.4, 1.0]
        sb_weights = 1.0

        if len(balance_weights) == len(score):
            return sum([w * self._forward(x, target) for (w, x) in zip(balance_weights, score)])
        elif len(score) == 1:
            return sb_weights * self._forward(score[0], target)

        else:
            raise ValueError("lengths of prediction and target are not identical!")

# Step 5: BiSeNet

### Context path

In [None]:
import torch
from torchvision import models


class resnet18(torch.nn.Module):
    """torchvision ResNet-18 exposed as a context-path backbone.

    ``forward`` returns the outputs of ``layer3`` and ``layer4`` plus ``tail``,
    the spatial average of the ``layer4`` output.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        self.features = backbone
        # Stem and residual stages are re-exposed as direct attributes.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        feature2 = self.layer2(self.layer1(stem))  # layer1: 1 / 4, layer2: 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # Global average pooling: mean over the last axis, then over the one before it.
        tail = feature4.mean(3, keepdim=True).mean(2, keepdim=True)
        return feature3, feature4, tail


class resnet101(torch.nn.Module):
    """torchvision ResNet-101 exposed as a context-path backbone.

    ``forward`` returns the outputs of ``layer3`` and ``layer4`` plus ``tail``,
    the spatial average of the ``layer4`` output.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.resnet101(pretrained=pretrained)
        self.features = backbone
        # Stem and residual stages are re-exposed as direct attributes.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        feature2 = self.layer2(self.layer1(stem))  # layer1: 1 / 4, layer2: 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # Global average pooling: mean over the last axis, then over the one before it.
        tail = feature4.mean(3, keepdim=True).mean(2, keepdim=True)
        return feature3, feature4, tail


def build_contextpath(name):
    """Instantiate the pretrained context-path backbone called ``name``.

    Only the requested backbone is constructed, so asking for one network does
    not build (or fetch the weights of) the other.

    Parameters
    ----------
    name : str
        ``'resnet18'`` or ``'resnet101'``.

    Raises
    ------
    KeyError
        If ``name`` is not one of the supported backbones.
    """
    builders = {
        'resnet18': resnet18,
        'resnet101': resnet101,
    }
    return builders[name](pretrained=True)

### Model

In [None]:
import torch
from torch import nn
import warnings
warnings.filterwarnings(action='ignore')


class ConvBlock(torch.nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and ReLU.

    With the default ``kernel_size=3, stride=2, padding=1`` the block halves the
    spatial resolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        normalized = self.bn(self.conv1(input))
        return self.relu(normalized)


class Spatial_path(torch.nn.Module):
    """Spatial path: three ConvBlocks in sequence mapping 3 -> 64 -> 128 -> 256 channels.

    Each ConvBlock is built with its default stride of 2.
    """

    def __init__(self):
        super().__init__()
        self.convblock1 = ConvBlock(in_channels=3, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock3 = ConvBlock(in_channels=128, out_channels=256)

    def forward(self, input):
        return self.convblock3(self.convblock2(self.convblock1(input)))


class AttentionRefinementModule(torch.nn.Module):
    """Channel attention: rescale the input by a per-channel gate.

    The gate is ``sigmoid(bn(conv1x1(global_avg_pool(input))))`` and is
    multiplied with the input, so ``out_channels`` has to match the input
    channels for the product to broadcast.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        # One value per channel after global average pooling.
        pooled = self.avgpool(input)
        assert self.in_channels == pooled.size(1), 'in_channels and out_channels should all be {}'.format(pooled.size(1))
        attention = self.sigmoid(self.bn(self.conv(pooled)))
        # The gate is broadcast over the spatial dimensions of the input.
        return torch.mul(input, attention)


class FeatureFusionModule(torch.nn.Module):
    """Fuse two feature maps into ``num_classes`` channels with a channel gate.

    The inputs are concatenated along the channel axis, reduced by a stride-1
    ConvBlock, and the result is re-weighted by a gate computed from its global
    average; the gated features are added back to the ungated ones.
    """

    def __init__(self, num_classes, in_channels):
        super().__init__()
        # in_channels is the channel count of the concatenated inputs:
        # resnet101 3328 = 256(from spatial path) + 1024(from context path) + 2048(from context path)
        # resnet18  1024 = 256(from spatial path) + 256(from context path) + 512(from context path)
        self.in_channels = in_channels

        self.convblock = ConvBlock(in_channels=self.in_channels, out_channels=num_classes, stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        fused = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == fused.size(1), 'in_channels of ConvBlock should be {}'.format(fused.size(1))
        feature = self.convblock(fused)

        # Channel gate: pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid.
        gate = self.avgpool(feature)
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(gate))))

        # feature * gate + feature
        return torch.add(torch.mul(feature, gate), feature)


class BiSeNet(torch.nn.Module):
    """BiSeNet: a spatial path and a ResNet context path merged by a FeatureFusionModule.

    Parameters
    ----------
    num_classes : int
        Number of output channels of the segmentation maps.
    context_path : str
        Backbone name, ``'resnet18'`` or ``'resnet101'``.

    In training mode ``forward`` returns ``(result, cx1_sup, cx2_sup)``, the
    main prediction and two auxiliary predictions taken from the context path;
    in eval mode it returns ``result`` only. All returned maps have the spatial
    size of the input.

    Raises
    ------
    ValueError
        If ``context_path`` names a backbone this class has no channel
        configuration for.
    """

    # (channels of the 1/16 feature, channels of the 1/32 feature) per supported backbone.
    _CONTEXT_CHANNELS = {
        'resnet101': (1024, 2048),
        'resnet18': (256, 512),
    }
    # Channels produced by Spatial_path.
    _SPATIAL_CHANNELS = 256

    def __init__(self, num_classes, context_path):
        super().__init__()
        # build spatial path (the attribute spelling is part of the state_dict keys, so it is kept)
        self.saptial_path = Spatial_path()

        # build context path
        self.context_path = build_contextpath(name=context_path)

        if context_path not in self._CONTEXT_CHANNELS:
            raise ValueError('unsupported context_path network: {!r}'.format(context_path))
        channels_16, channels_32 = self._CONTEXT_CHANNELS[context_path]

        # build attention refinement modules
        self.attention_refinement_module1 = AttentionRefinementModule(channels_16, channels_16)
        self.attention_refinement_module2 = AttentionRefinementModule(channels_32, channels_32)
        # supervision blocks (auxiliary heads, only used in training mode)
        self.supervision1 = nn.Conv2d(in_channels=channels_16, out_channels=num_classes, kernel_size=1)
        self.supervision2 = nn.Conv2d(in_channels=channels_32, out_channels=num_classes, kernel_size=1)
        # build feature fusion module: 3328 channels for resnet101, 1024 for resnet18
        self.feature_fusion_module = FeatureFusionModule(
            num_classes, self._SPATIAL_CHANNELS + channels_16 + channels_32
        )

        # build final convolution
        self.conv = nn.Conv2d(in_channels=num_classes, out_channels=num_classes, kernel_size=1)

        self.init_weight()

        # Every sub-module except the context path.
        # NOTE(review): presumably the modules meant to receive a multiplied learning rate;
        # the list is not used inside this class -- confirm against the optimizer setup.
        self.mul_lr = []
        self.mul_lr.append(self.saptial_path)
        self.mul_lr.append(self.attention_refinement_module1)
        self.mul_lr.append(self.attention_refinement_module2)
        self.mul_lr.append(self.supervision1)
        self.mul_lr.append(self.supervision2)
        self.mul_lr.append(self.feature_fusion_module)
        self.mul_lr.append(self.conv)

    def init_weight(self):
        """Initialise every module whose name does not contain ``'context_path'``.

        Conv2d weights get Kaiming-normal initialisation; BatchNorm2d layers get
        ``eps=1e-5``, ``momentum=0.1``, weight 1 and bias 0.
        """
        for name, m in self.named_modules():
            if 'context_path' not in name:
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-5
                    m.momentum = 0.1
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, input):
        input_size = input.size()[-2:]

        # output of spatial path
        sx = self.saptial_path(input)

        # output of context path
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)
        # upsampling to the resolution of the spatial path
        cx1 = torch.nn.functional.interpolate(cx1, size=sx.size()[-2:], mode='bilinear')
        cx2 = torch.nn.functional.interpolate(cx2, size=sx.size()[-2:], mode='bilinear')
        cx = torch.cat((cx1, cx2), dim=1)

        if self.training:
            cx1_sup = self.supervision1(cx1)
            cx2_sup = self.supervision2(cx2)
            cx1_sup = torch.nn.functional.interpolate(cx1_sup, size=input_size, mode='bilinear')
            cx2_sup = torch.nn.functional.interpolate(cx2_sup, size=input_size, mode='bilinear')

        # output of feature fusion module
        result = self.feature_fusion_module(sx, cx)

        # Upsample to the input size, like the auxiliary heads, so that the prediction
        # matches the target even when the input size is not a multiple of 8.
        result = torch.nn.functional.interpolate(result, size=input_size, mode='bilinear')
        result = self.conv(result)

        if self.training:
            return result, cx1_sup, cx2_sup

        return result

## Run

### Parameters

In [None]:
# Training length and batch size
num_epochs = 20
BATCH_SIZE = 6

# Optimizer settings (learning rate and weight decay)
learning_rate = 0.001
w_decay = 0.0001

# Learning-rate schedule: step length in epochs and multiplicative decay factor
step_size = 10
gamma = 0.1

# Side length used when resizing the training images
resize = 512

### Dataset preprocessing

#### Normalization metrics

In [None]:
# Worker processes are used for data loading only when running on CUDA.
num_workers = 0 if device.type != 'cuda' else 2

# The model is pretrained on ImageNet, so the ImageNet per-channel mean and standard deviation are used.
avg = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

#### Training set

In [None]:
# Training images are normalised with `avg` / `std`, then resized to `resize` x `resize`.
train_transform = A.Compose(
    [
        A.Normalize(mean=avg, std=std, max_pixel_value=255, p=1, always_apply=True),
        A.Resize(resize, resize, p=1, always_apply=True),
    ]
)
# Training uses the Urban split.
train_dataset = LoveDA(
    TRAIN_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=URBAN_PATH,
    transforms=train_transform,
)
# Shuffled batches; the last incomplete batch is dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

#### Validation set

In [None]:
# Validation images are only normalised with `avg` / `std`; no resize is applied.
val_transform = A.Compose(
    [
        A.Normalize(mean=avg, std=std, max_pixel_value=255, p=1, always_apply=True),
    ]
)
# Validation uses the Rural split; pass directories=URBAN_PATH instead to validate on the Urban split.
val_dataset = LoveDA(
    VAL_DIR,
    IMG_PATH,
    MASK_PATH,
    directories=RURAL_PATH,
    transforms=val_transform,
)
# Fixed order, every sample kept.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

### Training process

#### Model engine

In [None]:
# BiSeNet with a ResNet-101 context path, moved to the training device.
model = BiSeNet(num_classes, 'resnet101')
model = model.to(device)

# Loss function: cross-entropy that skips pixels labelled IGNORE_INDEX.
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

# Adam over all model parameters, with the learning rate decayed by StepLR.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=w_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

#### Evaluate function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device) -> tuple:
    """Run ``model`` over ``dataloader`` in eval mode and average loss and IoU.

    Relies on the module-level ``criterion``, ``num_classes`` and
    ``calculate_iou``. Every batch statistic is weighted by the batch size, so
    the results are per-sample averages of the per-batch values.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(inputs)``; its output is passed to ``criterion`` and
        ``calculate_iou`` together with the masks.
    dataloader : iterable of (inputs, masks)
    device : torch.device
        Device the batches are moved to.

    Returns
    -------
    tuple
        ``(loss, mIoU, ious_per_class)`` where ``ious_per_class`` is a CPU
        tensor with ``num_classes`` entries.
    """
    model.eval()

    sample_count = 0
    loss_total = 0.0
    iou_total = 0.0
    ious_per_class = torch.zeros(num_classes)

    for inputs, masks in dataloader:
        batch_samples = inputs.size(0)
        sample_count += batch_samples

        inputs, masks = inputs.to(device), masks.to(device)

        # Forward pass
        outputs = model(inputs)
        loss_total += criterion(outputs, masks).item() * batch_samples

        # Per-batch IoU, overall and per class
        iou, iou_per_class = calculate_iou(outputs, masks, num_classes)
        iou_total += iou * batch_samples
        ious_per_class += iou_per_class.cpu() * batch_samples

    mIoU = iou_total / sample_count
    mean_loss = loss_total / sample_count
    ious_per_class /= sample_count

    return mean_loss, mIoU, ious_per_class

#### Training loop

In [None]:
# Per-epoch history: mean training loss, validation loss and validation mIoU.
train_losses = []
eval_losses = []
mious = []

for epoch in range(num_epochs):
    print("### Training mode")
    model.train()
    running_loss = 0.0
    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        # In training mode the model returns the main prediction plus two auxiliary ones.
        outputs, outputs16, outputs32 = model(images)
        loss1 = criterion(outputs, masks)
        loss2 = criterion(outputs16, masks)
        loss3 = criterion(outputs32, masks)
        # Main and auxiliary losses are summed with equal weight.
        loss = loss1 + loss2 + loss3
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Progress report every 25 batches: mean loss of the epoch so far.
        if (i + 1) % 25 == 0:
            print(f"Processed {i + 1} batches, loss: {running_loss / (i+1)}")

    # Mean loss per batch over the whole epoch.
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}")

    print("### Evaluation mode")
    # evaluate() switches the model to eval mode; model.train() above restores training mode next epoch.
    val_loss, miou, ious_per_class = evaluate(model, val_loader, device)

    print(f"Validation mIoU: {miou*100:.3f}%, Validation loss: {val_loss:.5f}")
    # Per-class IoU, in the order of categories.keys().
    for i, cat in enumerate(categories.keys()):
        print(f"{cat} mIoU: {ious_per_class[i]*100:.3f}")
    print()

    eval_losses.append(val_loss)

    mious.append(miou)

    print(f"Epoch: [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}, mIoU: {(miou * 100):.2f}%")




    # The learning-rate schedule advances once per epoch.
    scheduler.step()

### Metric calculation

In [None]:
# Switch to eval mode before measuring.
model.eval()
# Only the first of the four returned values (the mean latency) is used.
# NOTE(review): num_epochs is passed as the fifth argument; presumably it acts as the
# number of timing iterations -- confirm against calculate_latency_fps.
mean_latency, _, _, _ = calculate_latency_fps(model, device, 1024, 1024, num_epochs, ModelType.BISENET)
print(f"Mean Latency: {mean_latency:.2f} ms")

# Called with the same 1024 x 1024 size arguments as the latency measurement.
calculate_flops_params(model, device, 1024, 1024, ModelType.BISENET)

# Step 5: STDC

### Download pre-trained weights

In [None]:
# Create the pretrained-weights directory if it is missing.
weights_dir = Path(PRETRAINED_WEIGHTS_DIR)
if not weights_dir.exists():
    weights_dir.mkdir(exist_ok=True)

# Download the STDC1 checkpoint only when it is not already on disk.
stdc1_weights = Path(STDC1_WEIGHTS)
if not stdc1_weights.exists():
    # Google Drive file ID handed to gdown.
    # NOTE(review): the previous comment asked to replace this with the correct ID --
    # confirm it points at the intended STDC1 checkpoint.
    file_id = "1DFoXcV42zy-apUcMh5P8WhsXMRJofgl8"
    gdown.download(id=file_id, output=str(stdc1_weights), quiet=False)

### Nets

In [None]:
import torch
import torch.nn as nn
from torch.nn import init
import math



class ConvX(nn.Module):
    """Bias-free Conv2d -> BatchNorm2d -> in-place ReLU.

    The padding is ``kernel // 2``, so an odd kernel keeps the spatial size
    when ``stride`` is 1.
    """

    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel,
            stride=stride,
            padding=kernel // 2,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return self.relu(normalized)


class AddBottleneck(nn.Module):
    """STDC block whose concatenated branch outputs are added to a shortcut of the input.

    ``block_num`` ConvX layers are chained; the output of every layer is kept
    and all of them are concatenated along the channel axis. The layers emit
    ``out_planes/2, out_planes/4, ...`` channels, the last one repeating the
    width of its input. The concatenation is then added to the input, which
    with ``stride == 2`` first goes through the strided ``skip`` branch.

    Parameters
    ----------
    in_planes, out_planes : int
        Input and output channel counts.
    block_num : int
        Number of ConvX layers; must be larger than 1.
    stride : int
        With ``2`` the block downsamples: a depthwise strided convolution
        (``avd_layer``) follows the first ConvX and the shortcut is strided too.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(AddBottleneck, self).__init__()
        # A plain string message, so that a failing check reports its reason.
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
                nn.BatchNorm2d(out_planes//2),
            )
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes, bias=False),
                nn.BatchNorm2d(in_planes),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            # Downsampling is done by avd_layer / skip, so the ConvX layers run with stride 1.
            stride = 1

        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride))
            elif idx < block_num - 1:
                self.conv_list.append(ConvX(out_planes // 2 ** idx, out_planes // 2 ** (idx + 1)))
            else:
                # Last layer keeps the channel count of its input.
                self.conv_list.append(ConvX(out_planes // 2 ** idx, out_planes // 2 ** idx))

    def forward(self, x):
        out_list = []
        out = x

        for idx, conv in enumerate(self.conv_list):
            if idx == 0 and self.stride == 2:
                # The first branch output is stored after downsampling.
                out = self.avd_layer(conv(out))
            else:
                out = conv(out)
            out_list.append(out)

        if self.stride == 2:
            x = self.skip(x)

        return torch.cat(out_list, dim=1) + x



class CatBottleneck(nn.Module):
    """STDC block that concatenates the outputs of all its ConvX layers.

    ``block_num`` ConvX layers are chained and emit ``out_planes/2,
    out_planes/4, ...`` channels, the last one repeating the width of its
    input. The output of the first layer (average-pooled when ``stride == 2``)
    is concatenated with the outputs of the following layers along the channel
    axis; there is no residual addition.

    Parameters
    ----------
    in_planes, out_planes : int
        Input and output channel counts.
    block_num : int
        Number of ConvX layers; must be larger than 1.
    stride : int
        With ``2`` the block downsamples: a depthwise strided convolution
        (``avd_layer``) feeds the second ConvX, and the first output is
        average-pooled with stride 2 before concatenation.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(CatBottleneck, self).__init__()
        # A plain string message, so that a failing check reports its reason.
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
                nn.BatchNorm2d(out_planes//2),
            )
            self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            # Downsampling is done by avd_layer / skip, so the ConvX layers run with stride 1.
            stride = 1

        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride))
            elif idx < block_num - 1:
                self.conv_list.append(ConvX(out_planes // 2 ** idx, out_planes // 2 ** (idx + 1)))
            else:
                # Last layer keeps the channel count of its input.
                self.conv_list.append(ConvX(out_planes // 2 ** idx, out_planes // 2 ** idx))

    def forward(self, x):
        out_list = []
        out1 = self.conv_list[0](x)

        for idx, conv in enumerate(self.conv_list[1:]):
            if idx == 0:
                if self.stride == 2:
                    out = conv(self.avd_layer(out1))
                else:
                    out = conv(out1)
            else:
                out = conv(out)
            out_list.append(out)

        # The first output joins the concatenation at the resolution of the others.
        if self.stride == 2:
            out1 = self.skip(out1)
        out_list.insert(0, out1)

        out = torch.cat(out_list, dim=1)
        return out

#STDC2Net
class STDCNet1446(nn.Module):
    """STDC2 backbone (stage depths 4, 5, 3 by default).

    ``forward`` returns the feature maps after each of the five sub-networks
    ``x2, x4, x8, x16, x32``. ``forward_impl`` runs the classification head
    instead.

    Parameters
    ----------
    base : int
        Base channel width.
    layers : list of int
        Number of bottleneck blocks per stage.
    block_num : int
        Number of ConvX layers inside every bottleneck block.
    type : str
        ``"cat"`` for CatBottleneck or ``"add"`` for AddBottleneck.
    num_classes : int
        Output size of the classification head.
    dropout : float
        Dropout probability of the classification head.
    pretrain_model : str
        Path of a checkpoint to load; when empty the parameters are
        initialised by ``init_params``.
    use_conv_last : bool
        Whether ``forward`` applies ``conv_last`` to the last feature map.

    Raises
    ------
    ValueError
        If ``type`` is neither ``"cat"`` nor ``"add"``.
    """

    def __init__(self, base=64, layers=[4,5,3], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False):
        super(STDCNet1446, self).__init__()
        if type == "cat":
            block = CatBottleneck
        elif type == "add":
            block = AddBottleneck
        else:
            raise ValueError('type must be "cat" or "add", got {!r}'.format(type))
        self.use_conv_last = use_conv_last
        self.features = self._make_layers(base, layers, block_num, block)
        self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False)
        # Not applied in forward_impl, but kept because it is part of the state_dict.
        self.bn = nn.BatchNorm1d(max(1024, base*16))
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False)

        # Stage boundaries inside `features`: two stem convolutions, then 4, 5 and 3 blocks.
        # These fixed indices match the default `layers=[4,5,3]` only.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:6])
        self.x16 = nn.Sequential(self.features[6:11])
        self.x32 = nn.Sequential(self.features[11:])

        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

    def init_weight(self, pretrain_model):
        """Load the ``"state_dict"`` entry of the checkpoint at ``pretrain_model``.

        Checkpoint tensors overwrite the entries of the model's own state dict
        with the same key before it is loaded back.
        """
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine;
        # load_state_dict then copies the values into the existing parameters.
        # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
        state_dict = torch.load(pretrain_model, map_location='cpu')["state_dict"]
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def init_params(self):
        """Initialise Conv2d (Kaiming normal), BatchNorm2d (1 / 0) and Linear (normal, std 0.001) layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_layers(self, base, layers, block_num, block):
        """Build the stem (two stride-2 ConvX) followed by the bottleneck stages.

        The first block of every stage has stride 2; stage ``i`` outputs
        ``base * 2 ** (i + 2)`` channels.
        """
        features = []
        features += [ConvX(3, base//2, 3, 2)]
        features += [ConvX(base//2, base, 3, 2)]

        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base*4, block_num, 2))
                elif j == 0:
                    features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2))
                else:
                    features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1))

        return nn.Sequential(*features)

    def forward(self, x):
        feat2 = self.x2(x)
        feat4 = self.x4(feat2)
        feat8 = self.x8(feat4)
        feat16 = self.x16(feat8)
        feat32 = self.x32(feat16)
        if self.use_conv_last:
           feat32 = self.conv_last(feat32)

        return feat2, feat4, feat8, feat16, feat32

    def forward_impl(self, x):
        """Classification head: features -> conv_last -> square -> pooling -> fc -> ReLU -> dropout -> linear."""
        out = self.features(x)
        out = self.conv_last(out).pow(2)
        out = self.gap(out).flatten(1)
        out = self.fc(out)
        # out = self.bn(out)
        out = self.relu(out)
        # out = self.relu(self.bn(self.fc(out)))
        out = self.dropout(out)
        out = self.linear(out)
        return out

# STDC1Net
class STDCNet813(nn.Module):
    """STDC1 backbone (stage depths 2, 2, 2 by default).

    ``forward`` returns the feature maps after each of the five sub-networks
    ``x2, x4, x8, x16, x32``. ``forward_impl`` runs the classification head
    instead.

    Parameters
    ----------
    base : int
        Base channel width.
    layers : list of int
        Number of bottleneck blocks per stage.
    block_num : int
        Number of ConvX layers inside every bottleneck block.
    type : str
        ``"cat"`` for CatBottleneck or ``"add"`` for AddBottleneck.
    num_classes : int
        Output size of the classification head.
    dropout : float
        Dropout probability of the classification head.
    pretrain_model : str
        Path of a checkpoint to load; when empty the parameters are
        initialised by ``init_params``.
    use_conv_last : bool
        Whether ``forward`` applies ``conv_last`` to the last feature map.

    Raises
    ------
    ValueError
        If ``type`` is neither ``"cat"`` nor ``"add"``.
    """

    def __init__(self, base=64, layers=[2,2,2], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False):
        super(STDCNet813, self).__init__()
        if type == "cat":
            block = CatBottleneck
        elif type == "add":
            block = AddBottleneck
        else:
            raise ValueError('type must be "cat" or "add", got {!r}'.format(type))
        self.use_conv_last = use_conv_last
        self.features = self._make_layers(base, layers, block_num, block)
        self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False)
        # Not applied in forward_impl, but kept because it is part of the state_dict.
        self.bn = nn.BatchNorm1d(max(1024, base*16))
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False)

        # Stage boundaries inside `features`: two stem convolutions, then 2, 2 and 2 blocks.
        # These fixed indices match the default `layers=[2,2,2]` only.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:4])
        self.x16 = nn.Sequential(self.features[4:6])
        self.x32 = nn.Sequential(self.features[6:])

        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

    def init_weight(self, pretrain_model):
        """Load the ``"state_dict"`` entry of the checkpoint at ``pretrain_model``.

        Checkpoint tensors overwrite the entries of the model's own state dict
        with the same key before it is loaded back.
        """
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine;
        # load_state_dict then copies the values into the existing parameters.
        # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
        state_dict = torch.load(pretrain_model, map_location='cpu')["state_dict"]
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def init_params(self):
        """Initialise Conv2d (Kaiming normal), BatchNorm2d (1 / 0) and Linear (normal, std 0.001) layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_layers(self, base, layers, block_num, block):
        """Build the stem (two stride-2 ConvX) followed by the bottleneck stages.

        The first block of every stage has stride 2; stage ``i`` outputs
        ``base * 2 ** (i + 2)`` channels.
        """
        features = []
        features += [ConvX(3, base//2, 3, 2)]
        features += [ConvX(base//2, base, 3, 2)]

        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base*4, block_num, 2))
                elif j == 0:
                    features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2))
                else:
                    features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1))

        return nn.Sequential(*features)

    def forward(self, x):
        feat2 = self.x2(x)
        feat4 = self.x4(feat2)
        feat8 = self.x8(feat4)
        feat16 = self.x16(feat8)
        feat32 = self.x32(feat16)
        if self.use_conv_last:
           feat32 = self.conv_last(feat32)

        return feat2, feat4, feat8, feat16, feat32

    def forward_impl(self, x):
        """Classification head: features -> conv_last -> square -> pooling -> fc -> ReLU -> dropout -> linear."""
        out = self.features(x)
        out = self.conv_last(out).pow(2)
        out = self.gap(out).flatten(1)
        out = self.fc(out)
        # out = self.bn(out)
        out = self.relu(out)
        # out = self.relu(self.bn(self.fc(out)))
        out = self.dropout(out)
        out = self.linear(out)
        return out


### BN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as functional

try:
    from queue import Queue
except ImportError:
    from Queue import Queue


class ABN(nn.Module):
    """Activated Batch Normalization.

    Bundles a batch-normalization step (``functional.batch_norm``) and an
    activation function in a single module.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor used when updating the running statistics.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation function, one of: `relu`, `leaky_relu`, `elu` or `none`.
            Any other value behaves like `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(ABN, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            # Registered as None so that `self.weight` / `self.bias` still exist.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics to (0, 1) and, if affine, scale/shift to (1, 0)."""
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if self.affine:
            nn.init.constant_(self.weight, 1)
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        """Normalize ``x`` and apply the configured activation.

        ``batch_norm`` returns a fresh tensor, so the in-place activations
        below never modify the caller's input.
        """
        x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                                  self.training, self.momentum, self.eps)

        # The activation is selected by its literal name (the same spelling
        # __repr__ checks), so this class does not depend on ACT_* constants
        # that are neither defined nor imported in this cell.
        if self.activation == "relu":
            return functional.relu(x, inplace=True)
        elif self.activation == "leaky_relu":
            return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
        elif self.activation == "elu":
            return functional.elu(x, inplace=True)
        else:
            return x

    def __repr__(self):
        """Return a constructor-like summary; ``slope`` is shown only for leaky_relu."""
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)


class InPlaceABN(ABN):
    """In-place variant of Activated Batch Normalization.

    Shares parameters and buffers with ``ABN`` and only swaps the forward
    computation for a call to ``inplace_abn``.
    NOTE(review): ``inplace_abn`` is not defined or imported in this cell;
    confirm it is available in the notebook before calling ``forward``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Build the module; every argument is forwarded unchanged to ``ABN``.

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor for the running statistics.
        affine : bool
            If `True`, a learned scale and shift are created.
        activation : str
            Name of the activation function (e.g. `leaky_relu`, `elu`, `none`).
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super().__init__(num_features, eps, momentum, affine, activation, slope)

    def forward(self, x):
        """Delegate normalization and activation to ``inplace_abn``."""
        affine_and_stats = (self.weight, self.bias, self.running_mean, self.running_var)
        settings = (self.training, self.momentum, self.eps, self.activation, self.slope)
        return inplace_abn(x, *affine_and_stats, *settings)


class InPlaceABNSync(ABN):
    """In-place Activated Batch Normalization with cross-GPU synchronization.

    Assumes replication across GPUs through the same mechanism as
    `nn.DistributedDataParallel`.
    NOTE(review): ``inplace_abn_sync`` is not defined or imported in this
    cell; confirm it is available before calling ``forward``.
    """

    def forward(self, x):
        """Delegate normalization and activation to ``inplace_abn_sync``."""
        return inplace_abn_sync(
            x,
            self.weight, self.bias,
            self.running_mean, self.running_var,
            self.training, self.momentum, self.eps,
            self.activation, self.slope,
        )

    def __repr__(self):
        """Constructor-like summary; ``slope`` is appended only for leaky_relu."""
        template = '{name}({num_features}, eps={eps}, momentum={momentum},' \
                   ' affine={affine}, activation={activation}'
        template += ', slope={slope})' if self.activation == "leaky_relu" else ')'
        return template.format(name=type(self).__name__, **vars(self))

### Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# Normalization layer used by the modules defined in this cell (they all
# instantiate and type-check against this alias). To switch to the
# synchronized in-place variant, enable the next line instead:
#BatchNorm2d = InPlaceABNSync
BatchNorm2d = nn.BatchNorm2d

class ConvBNReLU(nn.Module):
    """Bias-free ``Conv2d`` followed by batch normalization and ReLU.

    Extra positional/keyword arguments are accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, ks, stride, padding, bias=False)
        self.bn = BatchNorm2d(out_chan)
        self.relu = nn.ReLU()
        self.init_weight()

    def forward(self, x):
        """Apply conv -> bn -> relu."""
        return self.relu(self.bn(self.conv(x)))

    def init_weight(self):
        """Kaiming-normal init (``a=1``) for direct ``Conv2d`` children; biases set to 0."""
        direct_convs = [child for child in self.children() if isinstance(child, nn.Conv2d)]
        for conv in direct_convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)


class BiSeNetOutput(nn.Module):
    """Prediction head: a 3x3 ``ConvBNReLU`` followed by a bias-free 1x1 conv
    that maps ``mid_chan`` channels to ``n_classes`` channels.
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return the per-class score map at the resolution of ``x``."""
        return self.conv_out(self.conv(x))

    def init_weight(self):
        """Kaiming-normal init (``a=1``) for direct ``Conv2d`` children; biases set to 0."""
        direct_convs = [child for child in self.children() if isinstance(child, nn.Conv2d)]
        for conv in direct_convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def get_params(self):
        """Split parameters into two lists.

        Returns
        -------
        tuple of (list, list)
            First list: weights of every ``Linear``/``Conv2d`` sub-module.
            Second list: their biases plus all ``BatchNorm2d`` parameters.
            (The original variable names suggest "with / without weight decay".)
        """
        decayed, undecayed = [], []
        for mod in self.modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                decayed.append(mod.weight)
                if mod.bias is not None:
                    undecayed.append(mod.bias)
            elif isinstance(mod, BatchNorm2d):
                undecayed.extend(mod.parameters())
        return decayed, undecayed


class AttentionRefinementModule(nn.Module):
    """Channel attention: features are re-weighted by a sigmoid gate computed
    from their own global average (1x1 conv + batch norm on the pooled vector).
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return ``conv(x)`` scaled channel-wise by the attention gate."""
        feat = self.conv(x)
        # Pooling window equals the full spatial extent -> one value per channel.
        pooled = F.avg_pool2d(feat, feat.shape[2:])
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return feat * gate

    def init_weight(self):
        """Kaiming-normal init (``a=1``) for direct ``Conv2d`` children; biases set to 0."""
        direct_convs = [child for child in self.children() if isinstance(child, nn.Conv2d)]
        for conv in direct_convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)


class ContextPath(nn.Module):
    """Context branch: an STDC backbone followed by attention refinement of
    its two deepest feature maps and a top-down nearest-neighbour fusion.

    Parameters
    ----------
    backbone : str
        ``'STDCNet1446'`` or ``'STDCNet813'``. Any other value (including the
        default) raises ``ValueError``.
    pretrain_model : str
        Forwarded to the backbone constructor.
    use_conv_last : bool
        Forwarded to the backbone constructor.

    Raises
    ------
    ValueError
        If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, *args, **kwargs):
        super(ContextPath, self).__init__()

        self.backbone_name = backbone
        # Resolved lazily with if/elif so that only the requested backbone
        # class has to exist in the notebook namespace.
        if backbone == 'STDCNet1446':
            backbone_cls = STDCNet1446
        elif backbone == 'STDCNet813':
            backbone_cls = STDCNet813
        else:
            # A wrong name is a configuration error: raise instead of
            # print + exit(0), which reports success and tears down the session.
            raise ValueError(
                "backbone is not in backbone lists: {!r} "
                "(expected 'STDCNet1446' or 'STDCNet813')".format(backbone)
            )

        self.backbone = backbone_cls(pretrain_model=pretrain_model, use_conv_last=use_conv_last)

        # Both supported backbones use the same widths here: 512 channels are
        # fed to arm16 and 1024 to arm32 / conv_avg, regardless of use_conv_last.
        inplanes = 1024
        self.arm16 = AttentionRefinementModule(512, 128)
        self.arm32 = AttentionRefinementModule(inplanes, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        """Return ``(feat2, feat4, feat8, feat16, feat16_up, feat32_up)``.

        The first four entries are the backbone maps passed through unchanged.
        ``feat32_up`` has the spatial size of ``feat16`` and ``feat16_up`` the
        spatial size of ``feat8``.
        """
        feat2, feat4, feat8, feat16, feat32 = self.backbone(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: average the deepest map over its full extent,
        # project it, and broadcast it back to the feat32 resolution.
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        # Refine the deepest map, add the global context, up-sample to feat16 size.
        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        # Refine feat16, merge the up-sampled deeper path, up-sample to feat8 size.
        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat2, feat4, feat8, feat16, feat16_up, feat32_up # x8, x16

    def init_weight(self):
        """Kaiming-normal init (``a=1``) for direct ``Conv2d`` children; biases set to 0.

        Only direct children are inspected, so nested modules keep the
        initialization they performed themselves.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into two lists.

        Returns
        -------
        tuple of (list, list)
            First list: weights of every ``Linear``/``Conv2d`` sub-module.
            Second list: their biases plus all ``BatchNorm2d`` parameters.
            (The variable names suggest "with / without weight decay".)
        """
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    """Fuse two feature maps: concatenate along channels, project with a 1x1
    ``ConvBNReLU``, then add a channel-attention-weighted copy of the result
    (bottleneck of ``out_chan // 4`` channels, sigmoid gate).
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        squeezed = out_chan // 4
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan, squeezed, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(squeezed, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Return ``feat * gate + feat`` where ``feat = convblk(cat(fsp, fcp))``."""
        feat = self.convblk(torch.cat([fsp, fcp], dim=1))
        # Pooling window equals the full spatial extent -> one value per channel.
        pooled = F.avg_pool2d(feat, feat.shape[2:])
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(pooled))))
        return feat * gate + feat

    def init_weight(self):
        """Kaiming-normal init (``a=1``) for direct ``Conv2d`` children; biases set to 0."""
        direct_convs = [child for child in self.children() if isinstance(child, nn.Conv2d)]
        for conv in direct_convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def get_params(self):
        """Split parameters into two lists.

        Returns
        -------
        tuple of (list, list)
            First list: weights of every ``Linear``/``Conv2d`` sub-module.
            Second list: their biases plus all ``BatchNorm2d`` parameters.
            (The original variable names suggest "with / without weight decay".)
        """
        decayed, undecayed = [], []
        for mod in self.modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                decayed.append(mod.weight)
                if mod.bias is not None:
                    undecayed.append(mod.bias)
            elif isinstance(mod, BatchNorm2d):
                undecayed.extend(mod.parameters())
        return decayed, undecayed


class BiSeNet(nn.Module):
    """Segmentation network: ``ContextPath`` backbone + feature fusion + heads.

    Parameters
    ----------
    backbone : str
        ``'STDCNet1446'`` or ``'STDCNet813'``; anything else raises ``ValueError``.
    n_classes : int
        Number of channels produced by the three segmentation heads.
    pretrain_model : str
        Forwarded to ``ContextPath``.
    use_boundary_2, use_boundary_4, use_boundary_8 : bool
        Each enabled flag appends the matching single-channel detail output
        to the tuple returned by ``forward``.
    use_boundary_16 : bool
        Stored on the module; ``forward`` does not read it.
    use_conv_last : bool
        Forwarded to ``ContextPath``.
    heat_map : bool
        Accepted for interface compatibility; not used.

    Raises
    ------
    ValueError
        If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone, n_classes, pretrain_model='', use_boundary_2=False, use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False, heat_map=False, *args, **kwargs):
        super(BiSeNet, self).__init__()

        # Fail fast, before any sub-module is built. A wrong name is a
        # configuration error, so raise instead of print + exit(0).
        if backbone not in ('STDCNet1446', 'STDCNet813'):
            raise ValueError(
                "backbone is not in backbone lists: {!r} "
                "(expected 'STDCNet1446' or 'STDCNet813')".format(backbone)
            )

        self.use_boundary_2 = use_boundary_2
        self.use_boundary_4 = use_boundary_4
        self.use_boundary_8 = use_boundary_8
        self.use_boundary_16 = use_boundary_16
        self.cp = ContextPath(backbone, pretrain_model, use_conv_last=use_conv_last)

        # Channel widths are identical for both supported backbones.
        conv_out_inplanes = 128
        sp2_inplanes = 32
        sp4_inplanes = 64
        sp8_inplanes = 256
        sp16_inplanes = 512
        inplane = sp8_inplanes + conv_out_inplanes

        self.ffm = FeatureFusionModule(inplane, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)

        # Single-channel heads applied directly to the backbone features.
        self.conv_out_sp16 = BiSeNetOutput(sp16_inplanes, 64, 1)
        self.conv_out_sp8 = BiSeNetOutput(sp8_inplanes, 64, 1)
        self.conv_out_sp4 = BiSeNetOutput(sp4_inplanes, 64, 1)
        self.conv_out_sp2 = BiSeNetOutput(sp2_inplanes, 64, 1)
        self.init_weight()

    def forward(self, x):
        """Return the segmentation logits plus any enabled detail outputs.

        Returns
        -------
        tuple
            ``(feat_out, feat_out16, feat_out32)``, all bilinearly resized to
            the spatial size of ``x``, followed by ``feat_out_sp2``,
            ``feat_out_sp4`` and ``feat_out_sp8`` (in that order, at their
            native resolution) for each ``use_boundary_*`` flag that is set.
        """
        H, W = x.size()[2:]

        feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x)

        # All detail heads are evaluated unconditionally, as before, so their
        # batch-norm statistics keep being updated in training mode. The
        # sp16 result is never returned.
        feat_out_sp2 = self.conv_out_sp2(feat_res2)
        feat_out_sp4 = self.conv_out_sp4(feat_res4)
        feat_out_sp8 = self.conv_out_sp8(feat_res8)
        feat_out_sp16 = self.conv_out_sp16(feat_res16)

        feat_fuse = self.ffm(feat_res8, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)

        # Build the result from the flags instead of enumerating combinations:
        # the previous chain returned None for any combination it did not list
        # (e.g. use_boundary_2 without use_boundary_4).
        outputs = [feat_out, feat_out16, feat_out32]
        if self.use_boundary_2:
            outputs.append(feat_out_sp2)
        if self.use_boundary_4:
            outputs.append(feat_out_sp4)
        if self.use_boundary_8:
            outputs.append(feat_out_sp8)
        return tuple(outputs)

    def init_weight(self):
        """Kaiming-normal init (``a=1``) for direct ``Conv2d`` children; biases set to 0.

        Only direct children are inspected, so nested modules keep the
        initialization they performed themselves.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Collect the children's parameter groups into four lists.

        Every direct child must provide ``get_params()`` returning two lists.
        The lists of ``FeatureFusionModule`` and ``BiSeNetOutput`` children
        go into the two ``lr_mul_*`` lists, all others into the plain ones.

        Returns
        -------
        tuple
            ``(wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)``.
        """
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


## Run

### Parameters

In [None]:
# Hyper-parameters for this run; each one is consumed by the cells below.
num_epochs = 20        # number of passes of the training loop
BATCH_SIZE = 6         # batch size of both the training and validation DataLoader
learning_rate = 1e-3   # initial learning rate given to Adam
step_size = 10         # StepLR: number of scheduler steps between two decays
gamma = 0.1            # StepLR: multiplicative decay factor
resize = 512           # side length (height and width) of the resized training samples
w_decay = 1e-4         # weight_decay given to Adam

### Dataset preprocessing

#### Normalization metrics

In [None]:
# Use background DataLoader workers only when running on a CUDA device.
num_workers = 2 if device.type == 'cuda' else 0

# Since the model is pretrained on ImageNet, the ImageNet per-channel mean and
# standard deviation are used for normalization.
avg = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

#### Training set

In [None]:
# Training pipeline: normalize with the statistics defined above, then resize
# to resize x resize. Note that normalization runs before the resize.
train_transform = A.Compose([
    A.Normalize(mean=avg, std=std, p=1, always_apply=True, max_pixel_value=255),
    A.Resize(resize, resize, p=1, always_apply=True)
])

# Training uses the Urban split only.
# NOTE(review): LoveDA is defined in an earlier cell; the meaning of its
# arguments is taken from their names here and should be confirmed there.
train_dataset = LoveDA(TRAIN_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH, transforms=train_transform)
# Shuffled, incomplete last batch dropped; seed_worker and the generator g come
# from the seeding cell at the top of the notebook.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)

#### Validation set

In [None]:
# Validation pipeline: normalization only. Unlike training, samples are not resized.
val_transform = A.Compose([
    A.Normalize(mean=avg, std=std, p=1, always_apply=True, max_pixel_value=255),
])

# Validation runs on the Rural split while training uses Urban; switch to the
# commented line below to validate on Urban instead.
#val_dataset = LoveDA(VAL_DIR, IMG_PATH, MASK_PATH, directories=URBAN_PATH, transforms=val_transform)
val_dataset = LoveDA(VAL_DIR, IMG_PATH, MASK_PATH, directories=RURAL_PATH, transforms=val_transform)
# Fixed order and no dropped samples, so every validation image is scored.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=num_workers, worker_init_fn=seed_worker, generator=g)

### Training process

#### Model engine

In [None]:
# BiSeNet with the STDCNet813 backbone; num_classes and stdc1_weights are
# defined in earlier cells. No use_boundary_* flag is set, so the model
# returns three outputs.
model = BiSeNet(n_classes=num_classes,backbone='STDCNet813', pretrain_model=stdc1_weights).to(device)

# Adam over all model parameters; the learning rate is multiplied by gamma
# every step_size scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=w_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# Loss function: cross-entropy that skips targets equal to IGNORE_INDEX.
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

#### Evaluate function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device, ) -> tuple:
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Reads ``criterion``, ``num_classes`` and ``calculate_iou`` from the
    notebook's global namespace. The model is switched to eval mode and left
    in that mode.

    Parameters
    ----------
    model : nn.Module
        Returns either a single tensor or a tuple/list whose first element is
        the main prediction; any further outputs are ignored.
    dataloader : iterable
        Yields ``(inputs, masks)`` batches.
    device : torch.device
        Device each batch is moved to.

    Returns
    -------
    tuple
        ``(loss, mIoU, ious_per_class)``, each a sample-weighted average of
        the per-batch values.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no samples.
    """
    model.eval()

    running_loss = 0.0
    data_len = 0
    iou_scores = 0.0
    ious_per_class = torch.zeros(num_classes)

    for inputs, masks in dataloader:
        batch_size = inputs.size(0)
        data_len += batch_size

        inputs = inputs.to(device)
        masks = masks.to(device)

        # Forward pass. Only the main head is scored; taking element 0 keeps
        # this working when the model returns more than three outputs.
        outputs = model(inputs)
        if isinstance(outputs, (tuple, list)):
            outputs = outputs[0]
        loss = criterion(outputs, masks)

        # Weight by batch size so a smaller final batch does not skew the mean.
        running_loss += loss.item() * batch_size

        # Calculate mIoU
        iou, iou_per_class = calculate_iou(outputs, masks, num_classes)
        iou_scores += iou * batch_size
        ious_per_class += iou_per_class.cpu() * batch_size

    if data_len == 0:
        # Explicit error instead of a ZeroDivisionError in the averages below.
        raise ValueError("evaluate() received a dataloader that yields no samples")

    mIoU = iou_scores / data_len
    loss = running_loss / data_len
    ious_per_class /= data_len

    return loss, mIoU, ious_per_class

#### Training loop

In [None]:
# Per-epoch history of training loss, validation loss and validation mIoU.
train_losses = []
eval_losses = []
mious = []

for epoch in range(num_epochs):
    print("### Training mode")
    # evaluate() leaves the model in eval mode, so training mode is
    # re-enabled at the start of every epoch.
    model.train()
    running_loss = 0.0
    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        # The model returns the main prediction and two auxiliary predictions.
        outputs, outputs16, outputs32 = model(images)
        loss1 = criterion(outputs, masks)
        loss2 = criterion(outputs16, masks)
        loss3 = criterion(outputs32, masks)
        # Unweighted sum of the three losses.
        loss = loss1 + loss2 + loss3
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Progress report: running mean of the summed loss every 25 batches.
        if (i + 1) % 25 == 0:
            print(f"Processed {i + 1} batches, loss: {running_loss / (i+1)}")

    # Mean per-batch training loss. It sums three heads, whereas the
    # validation loss below scores the main head only, so the two values are
    # not directly comparable.
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}")

    print("### Evaluation mode")
    val_loss, miou, ious_per_class = evaluate(model, val_loader, device)

    print(f"Validation mIoU: {miou*100:.3f}%, Validation loss: {val_loss:.5f}")
    # One line per category; assumes the key order of `categories` matches the
    # class indices of ious_per_class — TODO confirm against its definition.
    for i, cat in enumerate(categories.keys()):
        print(f"{cat} mIoU: {ious_per_class[i]*100:.3f}")
    print()

    eval_losses.append(val_loss)
    mious.append(miou)

    print(f"Epoch: [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}, mIoU: {(miou * 100):.2f}%")

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

### Metric calculation

In [None]:
# Benchmark the trained model in eval mode.
model.eval()
# calculate_latency_fps is defined in an earlier cell; only the first of its
# four return values is used here.
# NOTE(review): num_epochs is passed as the fifth argument — presumably the
# number of timing iterations; confirm this reuse is intentional.
mean_latency, _, _, _ = calculate_latency_fps(model, device, 1024, 1024, num_epochs, ModelType.STDC)
print(f"Mean Latency: {mean_latency:.2f} ms")

# FLOPs / parameter report, called with the same 1024, 1024 arguments as above.
calculate_flops_params(model, device, 1024, 1024, ModelType.STDC)