# Laufzeit vorbereiten

In [None]:
import random

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import classification as metrics
from torchvision.transforms import functional as augment_lib
from torch.optim import lr_scheduler as scheduler

import os
from termcolor import colored
from datetime import datetime
from copy import deepcopy
import shutil
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import GPUtil

In [None]:
if not os.path.exists('Data/Underwater'):
    ! pip install kaggle

    ! mkdir Data / Underwater
    ! kaggle datasets download ashish2001 / semantic-segmentation-of-underwater-imagery-suim
    ! unzip semantic-segmentation-of-underwater-imagery-suim.zip -d Data / Underwater /
    ! rm semantic-segmentation-of-underwater-imagery-suim.zip

In [None]:
# Create the output directory for models on first run.
if not os.path.exists('Modelle'):
    # Standard library call instead of shelling out, so the cell is plain
    # Python and does not depend on a `mkdir` executable.
    os.mkdir('Modelle')

In [None]:
# Pick the compute device once; new float tensors default to that device's
# FloatTensor type from here on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print("using cuda: ", torch.cuda.get_device_name())
else:
    torch.set_default_tensor_type(torch.FloatTensor)
    print("using cpu")

# Params

In [None]:
# Default image shape; used as the dataset's `image_shape` and as the default
# width/height of the plotting helpers.
input_shape = [1024, 1024, 3]

# Mask colour (R, G, B) -> class index.
rgb2classes = {
    (0, 0, 0): 0,        # Background (black)
    (0, 0, 255): 1,      # Human diver (blue)
    (0, 255, 0): 2,      # Plant (green)
    (0, 255, 255): 3,    # Wreck or ruin (sky / cyan)
    (255, 0, 0): 4,      # Robot (red)
    (255, 0, 255): 5,    # Reef or invertebrate (pink)
    (255, 255, 0): 6,    # Fish or vertebrate (yellow)
    (255, 255, 255): 7,  # Sea-floor or rock (white)
}

# Matplotlib colormap with one colour per class, in the dict's insertion order
# and scaled from 0..255 to 0..1.
classColorMap = ListedColormap([tuple(channel / 255 for channel in rgb) for rgb in rgb2classes])

# Util

In [None]:
def free_gpu_cache(tensors, print_out=False):
    """Ask PyTorch to release its cached CUDA memory.

    Args:
        tensors: container of tensors the caller no longer needs. Only this
            function's own reference to it is dropped; memory can only be
            reclaimed once the caller has released its references as well.
        print_out: when True, print the GPU utilisation (via GPUtil) before
            and after clearing.
    """
    import gc

    if print_out:
        print("\n", "=" * 100, "\nBefore Clearing")
        GPUtil.showUtilization()

    # The former per-item `del tensor` only unbound the loop variable and freed
    # nothing. Drop the local reference and collect garbage so tensors that are
    # no longer referenced anywhere are returned to the caching allocator.
    del tensors
    gc.collect()

    torch.cuda.empty_cache()

    if print_out:
        print("\nAfter Clearing")
        GPUtil.showUtilization()


def plot_image(image, mask=None, image_color=True, image_cmap=None, mask_color=True, mask_cmap=None):
    """Show an image with matplotlib, optionally with a half-transparent mask on top.

    Args:
        image: NumPy array or torch tensor (tensors are moved to the CPU and
            converted to NumPy).
        mask: optional NumPy array or torch tensor drawn over the image with
            alpha 0.5.
        image_color: when True, convert the image from BGR to RGB before display.
        image_cmap: optional matplotlib colormap for the image.
        mask_color: when True, convert the mask from BGR to RGB before display.
        mask_cmap: optional matplotlib colormap for the mask.
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()

    if image_color:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # cmap=None is imshow's default, so no separate branch is needed.
    plt.imshow(image, interpolation="nearest", cmap=image_cmap)

    if mask is not None:
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().numpy()

        if mask_color:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

        plt.imshow(mask, alpha=0.5, interpolation="nearest", cmap=mask_cmap)

    plt.show()


def plot_images(images: list, masks: list = None, color=True, image_width=input_shape[0], image_height=input_shape[1], images_per_column=1, images_per_row=6):
    """Show several images in a grid, optionally with a mask over each one.

    Args:
        images: sequence of NumPy arrays or torch tensors. Only the first
            `images_per_column * images_per_row` items are shown.
        masks: optional sequence parallel to `images`; each mask is reshaped
            to (image_height, image_width, 3) and drawn half-transparent.
        color: when True, images are converted from BGR to RGB before display.
        image_width: mask width in pixels.
        image_height: mask height in pixels.
        images_per_column: number of grid rows.
        images_per_row: number of grid columns.
    """
    # squeeze=False always yields a 2-D array of axes, so `.flat` also works
    # for a 1x1 grid.
    fig, ax = plt.subplots(nrows=images_per_column, ncols=images_per_row, figsize=([128, 128]), squeeze=False)

    for index, axi in enumerate(ax.flat):
        # Hide ticks and frame on every panel, including unused ones.
        axi.axis("off")

        if index >= len(images):
            continue

        image = images[index]

        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()

        if color:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        axi.imshow(image, interpolation="nearest")

        if masks is not None:
            mask = masks[index]

            if isinstance(mask, torch.Tensor):
                mask = mask.cpu().numpy()

            # Arrays are row-major: (height, width, channels).
            mask = mask.reshape([image_height, image_width, 3])
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

            # Half-transparent like plot_image, so the image stays visible.
            axi.imshow(mask, alpha=0.5, interpolation="nearest")

    plt.tight_layout()
    plt.show()


def convert_rgb_to_class_feature_map(rgb_feature_map, num_classes=len(rgb2classes)):
    """Convert an RGB colour mask into a one-hot class map.

    Every channel is first snapped to 0 or 255 (values above 126 become 255),
    then the resulting colour is looked up in `rgb2classes`.

    Args:
        rgb_feature_map: array of shape (rows, columns, 3) in RGB order.
        num_classes: size of the one-hot axis.

    Returns:
        Float array of shape (rows, columns, num_classes) holding 1 at the
        pixel's class index and 0 elsewhere.

    Raises:
        KeyError: if a snapped colour has no entry in `rgb2classes`.
    """
    rgb = np.asarray(rgb_feature_map)

    # Encode the three snapped channels as one 3-bit code per pixel.
    bits = rgb > 126
    codes = bits[..., 0] * 4 + bits[..., 1] * 2 + bits[..., 2]

    # 3-bit code -> class index; -1 marks colours missing from the mapping.
    code_to_class = np.full(8, -1, dtype=np.intp)
    for color, class_index in rgb2classes.items():
        if len(color) == 3 and all(channel in (0, 255) for channel in color):
            red, green, blue = color
            code_to_class[(red == 255) * 4 + (green == 255) * 2 + (blue == 255)] = class_index

    class_indices = code_to_class[codes]
    if (class_indices < 0).any():
        raise KeyError("mask contains a colour that is not listed in rgb2classes")

    # Row k of the identity matrix is the one-hot vector of class k.
    return np.eye(num_classes)[class_indices]


def convert_class_to_rgb_feature_map(class_feature_map):
    """Convert a class map of shape (batch, classes, rows, columns) into RGB colours.

    The class of a pixel is the index of its largest value along the class
    axis, which for one-hot input is the position of the 1. The colour of each
    class is taken from `rgb2classes`.

    Args:
        class_feature_map: 4-D array or torch tensor laid out as
            (batch, classes, rows, columns).

    Returns:
        Float array of shape (rows, columns, 3) when batch is 1, otherwise
        (batch, rows, columns, 3).

    Raises:
        ValueError: if the input does not have four axes.
    """
    if isinstance(class_feature_map, torch.Tensor):
        class_feature_map = class_feature_map.detach().cpu().numpy()

    scores = np.asarray(class_feature_map)
    if scores.ndim != 4:
        raise ValueError("class_feature_map must have shape (batch, classes, rows, columns)")

    # Reduce over the class axis directly. The former reshape to
    # (batch, rows, columns, classes) reinterpreted the memory instead of
    # moving the axis and so mixed up pixels and classes.
    class_indices = scores.argmax(axis=1)

    # Class index -> colour table. Iterating in reverse lets the first colour
    # listed for a class win, as a linear search would.
    palette = np.zeros((max(rgb2classes.values()) + 1, 3))
    for rgb_color, class_index in reversed(list(rgb2classes.items())):
        palette[class_index] = rgb_color

    rgb_feature_map = palette[class_indices]

    return rgb_feature_map[0] if rgb_feature_map.shape[0] == 1 else rgb_feature_map

In [None]:
def is_more_rare(factor,current):
    """Return `factor` when it is strictly greater than `current`, otherwise `current`."""
    return factor if factor > current else current


def check_how_rare(rgb_label):
    """Rate how rare the classes in a colour mask are (0 = none found, 5 = rarest).

    The mask is sampled on a grid (every 7th row and column from index 150 up
    to, but excluding, 900; smaller masks are sampled as far as they reach).
    Each sampled pixel is compared against these colour codes:

        (255, 0, 0)     robot -> 5
        (0, 255, 0)     plant -> 4
        (0, 255, 255)   ruin  -> 3
        (0, 0, 255)     human -> 2
        (255, 255, 255) rock  -> 1

    Args:
        rgb_label: array of shape (rows, columns, 3) whose channel order
            matches the colour codes above.

    Returns:
        The highest rating among the sampled pixels, or 0 when none matches.
    """
    rarity_by_color = (
        ((255, 0, 0), 5),      # robot
        ((0, 255, 0), 4),      # plant
        ((0, 255, 255), 3),    # ruin
        ((0, 0, 255), 2),      # human
        ((255, 255, 255), 1),  # rock
    )

    # Slicing clamps to the mask size, so small masks no longer raise IndexError.
    sampled = np.asarray(rgb_label)[150:900:7, 150:900:7]

    factor = 0
    # Ordered from rarest to most common: the first hit is the maximum.
    for color, rarity in rarity_by_color:
        # Exact colour match on all three channels (the robot test used `&`,
        # which could never be true for every channel).
        if (sampled == color).all(axis=-1).any():
            factor = rarity
            break

    if factor == 5:
        print(" The Factor is " + str(factor))
        return factor

    print("The Multiplier in the Preagumention is " + str(factor))
    return factor



class Augmentation():
    """Creates colour-augmented copies of one dataset entry.

    The entry is a 4-item sequence (image, colour mask, class label, edge).
    How many copies are produced depends on the rarity rating that
    `check_how_rare` returns for the mask:

        >= 1: thresholded copy
        >= 2: brightened copy (+22)
        >= 3: second brightened copy (+44)
        >= 4: inverted copy (33 - image)
        >= 5: second inverted copy (66 - image)

    Only the image is altered; mask, class label and edge are passed through.
    """

    def __init__(self,all_image):
        self.all_image = all_image

        # The mask is loaded with cv2 (BGR channel order) while check_how_rare
        # compares pixels against RGB colour codes, so convert before rating.
        factor = check_how_rare(cv2.cvtColor(all_image[1], cv2.COLOR_BGR2RGB))

        self.do_noise = factor >= 1
        self.do_sat = factor >= 2
        self.do_invert = factor >= 4
        # (minimum rating, offset) pairs; an offset is enabled once the rating
        # reaches its minimum.
        self.saturation_options = [offset for minimum, offset in ((2, 22), (3, 44)) if factor >= minimum]
        self.invert_options = [offset for minimum, offset in ((4, 33), (5, 66)) if factor >= minimum]

    # https://github.com/AISangam/Image-Augmentation-Using-OpenCV-and-Python/blob/master/Image%20Augmentaion%20Part1.py
    def saturation_image(self,img):
        """Return one copy per enabled offset with HSV channel 2 raised by that offset.

        Values that would exceed 255 are set to 255.
        """
        result = []
        for x in self.saturation_options:

            image = cv2.cvtColor(img[0].copy(), cv2.COLOR_BGR2HSV)

            v = image[:, :, 2]
            v = np.where(v <= 255 - x, v + x, 255)
            image[:, :, 2] = v

            image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
            result.append((image,img[1],img[2],img[3]))
        return result


    def invert(self,image,channel):
        """Return `channel - image`."""
        # NOTE(review): for uint8 images and channel < 255 this presumably
        # wraps around for pixels brighter than `channel` - confirm intended.
        return (channel-image)


    def invert_image(self,image):
        """Return one copy per enabled offset with the image replaced by `offset - image`."""
        result = []
        for x in self.invert_options:
            tmp = image[0].copy()
            result.append((self.invert(tmp,x),image[1],image[2],image[3]))
        return result

    # https://github.com/AISangam/Image-Augmentation-Using-OpenCV-and-Python/blob/master/Image%20Augmentaion%20Part1.py
    def addeptive_gaussian_noise(self,image):
        """Return a one-item list holding a thresholded copy of the entry.

        Each of the three image channels is passed separately through an
        inverted Gaussian adaptive threshold (block size 11, constant 2). The
        locals are named h/s/v, but no colour conversion happens here.
        """
        h,s,v=cv2.split(image[0])
        s = cv2.adaptiveThreshold(s, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
        h = cv2.adaptiveThreshold(h, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
        v = cv2.adaptiveThreshold(v, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
        new=cv2.merge([h,s,v])
        # Tuple, like the entries produced by the other augmentations.
        return [(new,image[1],image[2],image[3])]


    def start(self):
        """Return the original entry followed by all enabled augmented copies."""
        image = self.all_image
        result  = [image]
        if self.do_noise:
            result += self.addeptive_gaussian_noise(image)
        if self.do_invert:
            result += self.invert_image(image)
        if self.do_sat:
            result += self.saturation_image(image)
        print(" Pre Aug did create " + str(len(result)) + "new Images")
        return result


# Datensatz

In [None]:
class RttsDataset(Dataset):
    """In-memory segmentation dataset.

    Every entry is a 4-item sequence (image, colour mask, class label, edge)
    of NumPy arrays. Images are read from `<directory>/images/<name>.jpg` and
    masks from `<directory>/masks/<name>.bmp`.
    """

    def __init__(self, name, image_shape=input_shape, classes=rgb2classes):
        self.name = name
        self.classes = classes
        self.image_shape = image_shape
        self.entries = []

    def load_entries(self, directory: str, starting_point=0, image_count=-1, augment=False, print_out=True):
        """Load entries for the image files found in `<directory>/images`.

        Args:
            directory: dataset root containing `images` and `masks`.
            starting_point: number of leading files (in sorted order) to skip.
            image_count: stop once this many entries were added (ten times as
                many when `augment` is True); -1 loads everything.
            augment: when True, add augmented copies of every image.
            print_out: when True, print a progress line every 50 files.
        """
        start_image_count = len(self)
        image_count_threshold = image_count * 10 if augment else image_count
        # os.listdir returns files in arbitrary order; sorting makes
        # starting_point / image_count select the same files on every run.
        list_of_image_files = sorted(os.listdir(directory + "/images"))

        print("\n", "=" * 10, "Start loading entries", "=" * 4, "\n")
        for index, filename in enumerate(list_of_image_files):
            if index < starting_point:
                continue

            # ">=" rather than "==": with augmentation one file adds several
            # entries at once, so the count can jump past the threshold.
            if image_count != -1 and (len(self) - start_image_count) >= image_count_threshold:
                break

            self.load_entry(directory, os.path.splitext(filename)[0], augment=augment)

            if print_out and index % 50 == 0 and index != 0:
                print("\n", "=" * 10, "Loaded", (index - starting_point), "entries into", self.name, "=" * 10, "\n")

        print("\n", "=" * 100, "\n", self.name, "contains", len(self), "entries.\n")

    def load_entry(self, directory: str, file_name: str, augment=False):
        """Load one image with its mask and append the resulting entry or entries.

        Nothing is added when the image or the mask cannot be read. The one-hot
        class label is cached as text in `<directory>/class_labels/<name>.txt`
        and re-used on later runs.
        """
        image_file = os.path.join(directory + "/images", file_name + ".jpg")
        image = cv2.imread(image_file)
        if image is None:
            return

        image: numpy.ndarray = cv2.resize(image, self.image_shape[:2])

        rgb_label_file = os.path.join(directory + "/masks", file_name + ".bmp")
        rgb_label = cv2.imread(rgb_label_file)
        if rgb_label is None:
            return

        # NOTE(review): the mask is resized with cv2.resize's default
        # interpolation; blended colours at class borders are snapped back to
        # 0/255 when the class label is computed below.
        rgb_label = cv2.resize(rgb_label, self.image_shape[:2])

        class_label_path = directory + "/class_labels/"
        class_label_filename = class_label_path + file_name + ".txt"

        if os.path.exists(class_label_filename):
            # The cache stores the label flattened to (rows, columns * classes).
            class_label = np.loadtxt(class_label_filename, dtype=int)
            class_label = class_label.reshape([dim for dim in rgb_label.shape[:2]] + [len(rgb2classes)])

        else:
            os.makedirs(class_label_path, exist_ok=True)

            # cv2 loads BGR; the colour -> class mapping is defined in RGB.
            class_label = convert_rgb_to_class_feature_map(cv2.cvtColor(rgb_label, cv2.COLOR_BGR2RGB))
            class_label = class_label.reshape([rgb_label.shape[0], rgb_label.shape[1] * len(rgb2classes)])
            np.savetxt(class_label_filename, class_label, fmt="%d")

        # Edge map: Canny edges of the mask, thickened with a 4x4 dilation and
        # stored as 0.0 / 1.0.
        edge = cv2.Canny(rgb_label, 0.1, 0.2)
        kernel = np.ones((4, 4), np.uint8)
        edge = (cv2.dilate(edge, kernel, iterations=1) > 50) * 1.0

        image = np.array(image)
        rgb_label = np.array(rgb_label)
        class_label = np.array(class_label)
        edge = np.array(edge)

        if augment:
            aug = Augmentation((image, rgb_label, class_label, edge))
            self.entries += aug.start()

        else:
            self.entries.append((image, rgb_label, class_label, edge))

    def augment(self, entry):
        """Append the samples that `augment_entry` produces for `entry`."""
        # NOTE(review): the edge gets a leading axis here and `augment_entry`
        # adds further axes before a 4-axis permute, which looks like it
        # raises; the appended items are also torch tensors, while
        # `__getitem__` uses torch.from_numpy. Confirm whether this method is
        # still in use before relying on it.
        image_np, rgb_label_np, class_label_np, edge_np = entry

        image = torch.from_numpy(image_np)
        rgb_label = torch.from_numpy(rgb_label_np)
        class_label = torch.from_numpy(class_label_np)
        edge = torch.from_numpy(edge_np).unsqueeze(dim=0)

        self.entries += augment_entry(image, rgb_label, class_label, edge)

    def __getitem__(self, index):
        """Return entry `index` as tensors; image and colour mask are divided by 255."""
        image, rgb_label, class_label, edge = self.entries[index]

        image_tensor = torch.Tensor(image) / 255
        rgb_label_tensor = torch.from_numpy(rgb_label) / 255
        class_label_tensor = torch.Tensor(class_label)
        edge_tensor = torch.Tensor(edge)

        return image_tensor, rgb_label_tensor, class_label_tensor, edge_tensor

    def __len__(self):
        return len(self.entries)

    def plot_image(self, index):
        """Plot the image of entry `index`."""
        image_tensor, _, _, _ = self.__getitem__(index)
        plot_image(image_tensor)

    def plot_rgb_label(self, index):
        """Plot the colour mask of entry `index`."""
        _, label, _, _ = self.__getitem__(index)
        plot_image(label)

    def plot_class_label(self, index):
        """Plot every class channel of entry `index` as a separate panel."""
        class_label = self.entries[index][2]
        images = [
            class_label[:, :, i]
            for i in range(class_label.shape[2])
        ]
        plot_images(images=images, color=False, images_per_row=8)

    def plot_edge(self, index):
        """Plot the edge map of entry `index`."""
        _, _, _, edge_tensor = self.__getitem__(index)
        plot_image(edge_tensor, image_color=False)

    def plot_image_with_rgb_label(self, index):
        """Plot the image of entry `index` with its colour mask on top."""
        image_tensor, label, _, _ = self.__getitem__(index)
        plot_image(image_tensor, label)

    def plot_images(self):
        """Plot the images of the stored entries in a grid."""
        images = [image for image, _, _, _ in self.entries]
        plot_images(images)

    def plot_labels(self):
        """Plot the colour masks of the stored entries in a grid."""
        labels = [label for _, label, _, _ in self.entries]
        plot_images(labels)

    def plot_edges(self):
        """Plot the edge maps of the stored entries in a grid."""
        edges = [edge for _, _, _, edge in self.entries]
        plot_images(edges, color=False)

    def plot_images_with_rgb_label(self):
        """Plot the images of the stored entries with their colour masks on top."""
        images = [image for image, _, _, _ in self.entries]
        labels = [label for _, label, _, _ in self.entries]
        plot_images(images, labels)

In [None]:
# Training / validation split.
train_dataset_path = "Data/Underwater/train_val"
train_set = RttsDataset("Train_Set")
train_set.load_entries(train_dataset_path)

# Held-out test split.
test_dataset_path = "Data/Underwater/TEST"
test_set = RttsDataset("Test_Set")
test_set.load_entries(test_dataset_path)

In [None]:
image_index = 0

# Show every available view of one training sample, in this order: image,
# colour mask, class channels, edge map, image with mask on top.
for show_sample in (
    train_set.plot_image,
    train_set.plot_rgb_label,
    train_set.plot_class_label,
    train_set.plot_edge,
    train_set.plot_image_with_rgb_label,
):
    show_sample(image_index)

# Augmentation

In [None]:
def _crop_and_resize(tensor, top, left, size, height, width):
    """Cut a square window out of a (batch, channels, rows, columns) tensor and scale it to (height, width)."""
    return F.interpolate(
        input=augment_lib.crop(tensor, top=top, left=left, width=size, height=size),
        size=[height, width],
        mode='bilinear',
        align_corners=False
    )


def augment_entry(image_tensor, rgb_label_tensor, class_label_tensor, edge_tensor):
    """Create augmented variants of one sample.

    Args:
        image_tensor: image laid out as (rows, columns, channels), or with a
            leading batch axis. When it has three axes, a batch axis is added
            to all four inputs.
        rgb_label_tensor: colour mask, same layout as the image.
        class_label_tensor: class map, same layout with one channel per class.
        edge_tensor: edge map without a channel axis.

    Returns:
        List of (image, colour mask, class label, edge) tuples, each laid out
        as (batch, rows, columns, channels): random square crops scaled back
        to full size first, followed by brightness, contrast, blur, colour
        inversion, horizontal flip, vertical flip and both flips. Labels and
        edges are only changed by the geometric operations.
    """
    def channels_last(*tensors):
        # (batch, channels, rows, columns) -> (batch, rows, columns, channels)
        return tuple(tensor.permute(0, 2, 3, 1) for tensor in tensors)

    result = []

    if len(image_tensor.shape) == 3:
        image_tensor = image_tensor.unsqueeze(dim=0)
        rgb_label_tensor = rgb_label_tensor.unsqueeze(dim=0)
        class_label_tensor = class_label_tensor.unsqueeze(dim=0)
        edge_tensor = edge_tensor.unsqueeze(dim=0)

    # torchvision and F.interpolate work on channels-first tensors.
    image_tensor = image_tensor.permute(0, 3, 1, 2)
    rgb_label_tensor = rgb_label_tensor.permute(0, 3, 1, 2)
    class_label_tensor = class_label_tensor.permute(0, 3, 1, 2)
    edge_tensor = edge_tensor.unsqueeze(dim=3).permute(0, 3, 1, 2)

    factor = random.random()

    brightness_image = augment_lib.adjust_brightness(image_tensor, factor)
    contrast_image = augment_lib.adjust_contrast(image_tensor, factor)
    blured_image = augment_lib.gaussian_blur(image_tensor, kernel_size=random.choice([1,3,5,7,9]), sigma=factor)
    color_invert_image = augment_lib.invert(image_tensor)

    originals = (image_tensor, rgb_label_tensor, class_label_tensor, edge_tensor)
    h_flipped = tuple(augment_lib.hflip(tensor) for tensor in originals)
    v_flipped = tuple(augment_lib.vflip(tensor) for tensor in originals)
    b_flipped = tuple(augment_lib.hflip(augment_lib.vflip(tensor)) for tensor in originals)

    # More crops when one of the classes 1-4 occurs. The class map holds one
    # channel per class, so presence is checked on channels 1..4; testing
    # `1 in class_label_tensor` is true for any labelled one-hot map.
    has_rare_class = bool((class_label_tensor[:, 1:5] != 0).any())
    repetitions = 3 if has_rare_class else 1

    height = image_tensor.shape[2]
    width = image_tensor.shape[3]

    for _ in range(repetitions):
        # Never ask for a window larger than the image itself.
        size = min(random.randint(300, 700), height, width)
        # `top` is bounded by the number of rows, `left` by the number of columns.
        top = random.randint(0, height - size)
        left = random.randint(0, width - size)

        cropped = tuple(
            _crop_and_resize(tensor, top, left, size, height, width)
            for tensor in originals
        )
        result.append(channels_last(*cropped))

    labels = (rgb_label_tensor, class_label_tensor, edge_tensor)
    for color_augmented_image in (brightness_image, contrast_image, blured_image, color_invert_image):
        result.append(channels_last(color_augmented_image, *labels))

    result.append(channels_last(*h_flipped))
    result.append(channels_last(*v_flipped))
    result.append(channels_last(*b_flipped))

    return result

# KNN (künstliches neuronales Netz)

### Params

In [None]:
# Momentum passed to the BatchNorm2d layers of the network blocks.
batchnorm_momentum = 0.1
# align_corners flag passed to the bilinear F.interpolate calls of the network.
align_corners = False

# One weight per class index: the reciprocal of a per-class count
# (presumably how often each class occurs in the data set - TODO confirm).
loss_weights = torch.tensor([
    1 / count
    for count in (1288, 405, 239, 275, 101, 1028, 1030, 635)
])

### Loss Functions

In [None]:
class SemanticCrossEntropyLoss(nn.Module):
    """Cross-entropy loss for one or two segmentation outputs.

    Args:
        ignore_label: target value excluded from the loss.
        thres: stored as `self.thresh`; not used by the code in this class.
        min_kept: stored (at least 1) as `self.min_kept`; not used by the code
            in this class.
        weight: optional per-class weights for `nn.CrossEntropyLoss`.
    """

    def __init__(self, ignore_label=-1, thres=0.7, min_kept=100_000, weight=None):
        super(SemanticCrossEntropyLoss, self).__init__()

        self.thresh = thres
        self.min_kept = max(1, min_kept)
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_label,
        )

    def _ce_forward(self, prediction, target):
        """Return the cross entropy between raw class scores and the target."""
        # nn.CrossEntropyLoss applies log-softmax itself and expects
        # unnormalised scores. Passing softmax output would normalise twice,
        # which flattens the loss and its gradients.
        return self.criterion(prediction, target)

    def _ohem_forward(self, prediction, target):
        """Return the mean of the cross-entropy loss."""
        # NOTE(review): despite the name no hard-example selection happens
        # here; `thresh` and `min_kept` are not consulted.
        prediction = self._ce_forward(prediction, target).contiguous().view(-1)

        return prediction.mean()

    def forward(self, prediction, target):
        """Compute the weighted loss.

        Args:
            prediction: a single score tensor, or a list/tuple of score
                tensors.
            target: target passed unchanged to `nn.CrossEntropyLoss`.

        Returns:
            For two predictions the sum of both losses, each weighted 0.5;
            for a single prediction 0.5 times its loss.

        Raises:
            ValueError: for any other number of predictions.
        """
        if not (isinstance(prediction, list) or isinstance(prediction, tuple)):
            prediction = [prediction]

        balance_weights = [0.5, 0.5]
        if len(balance_weights) == len(prediction):
            functions = [self._ce_forward] * (len(balance_weights) - 1) + [self._ohem_forward]
            return sum([
                weight * func(x, target)
                for (weight, x, func) in zip(balance_weights, prediction, functions)
            ])

        elif len(prediction) == 1:
            return 0.5 * self._ohem_forward(prediction[0], target)

        else:
            raise ValueError("lengths of prediction and target are not identical!")

In [None]:
class BoundaryLoss(nn.Module):
    """Class-balanced binary cross entropy on edge logits, scaled by `coeff_bce`."""

    def __init__(self, coeff_bce=20.0):
        super().__init__()

        self.coeff_bce = coeff_bce

    def forward(self, prediction, target):
        """Return `coeff_bce` times the weighted BCE of `prediction` against `target`."""
        bce = self.weighted_bce(prediction, target)
        return self.coeff_bce * bce

    @staticmethod
    def weighted_bce(prediction, target):
        """Binary cross entropy with logits, weighted against class imbalance.

        Positions where the target is 1 are weighted with the share of
        0-positions and vice versa; positions with any other target value get
        weight 0. The result is the mean over all positions.
        """
        # Flatten both tensors to a single row; the prediction is moved to a
        # channels-last layout first.
        flat_prediction = prediction.permute(0, 2, 3, 1).contiguous().view(1, -1)
        flat_target = target.view(1, -1)

        is_edge = (flat_target == 1)
        is_background = (flat_target == 0)

        edge_count = is_edge.sum()
        background_count = is_background.sum()
        total_count = edge_count + background_count

        pixel_weights = torch.zeros_like(flat_prediction)
        pixel_weights[is_edge] = background_count * 1.0 / total_count
        pixel_weights[is_background] = edge_count * 1.0 / total_count

        return F.binary_cross_entropy_with_logits(flat_prediction, flat_target, pixel_weights, reduction='mean')

### Blocks

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut;
            the unchanged input is used when it is None.
        no_relu: when True, the sum is returned without the final ReLU.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, no_relu=False):
        super().__init__()

        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=batchnorm_momentum)

        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=batchnorm_momentum)

    def forward(self, x):
        """Return conv path plus shortcut, with a final ReLU unless `no_relu` is set."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return out if self.no_relu else self.relu(out)

In [None]:
 class BottleneckBlock(nn.Module):
    """Residual block: 1x1 convolution, 3x3 convolution, 1x1 convolution.

    The last convolution produces `out_channels * expansion` channels. Every
    convolution is followed by batch norm; ReLU follows the first two.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut;
            the unchanged input is used when it is None.
        no_relu: when True (the default), the sum is returned without a final
            ReLU.
    """

    expansion = 2

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, no_relu=True):
        super().__init__()

        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

        self.relu = nn.ReLU(inplace=True)

        expanded_channels = out_channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=batchnorm_momentum)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=batchnorm_momentum)

        self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded_channels, momentum=batchnorm_momentum)

    def forward(self, x):
        """Return conv path plus shortcut, with a final ReLU unless `no_relu` is set."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return out if self.no_relu else self.relu(out)

In [None]:
class SegmentHead(nn.Module):
    """Output head: BN-ReLU-3x3 conv followed by BN-ReLU-1x1 conv, optionally upsampled.

    Args:
        in_channels: channels of the input.
        inter_channels: channels between the two convolutions.
        out_channels: channels of the result.
        scale_factor: when not None, the result is resized bilinearly to the
            spatial size of the features multiplied by this factor.
    """

    def __init__(self, in_channels, inter_channels, out_channels, scale_factor=None):
        super().__init__()

        self.scale_factor = scale_factor

        self.relu = nn.ReLU(inplace=True)

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=batchnorm_momentum)
        self.conv1 = nn.Conv2d(in_channels, inter_channels, kernel_size=3, padding=1, bias=False)

        self.bn2 = nn.BatchNorm2d(inter_channels, momentum=batchnorm_momentum)
        self.conv2 = nn.Conv2d(inter_channels, out_channels, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        features = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(features)))

        if self.scale_factor is None:
            return out

        # The target size is derived from the intermediate feature map.
        target_size = [
            features.shape[-2] * self.scale_factor,
            features.shape[-1] * self.scale_factor,
        ]

        return F.interpolate(
            input=out,
            size=target_size,
            mode='bilinear',
            align_corners=align_corners
        )

In [None]:
class PAPPM(nn.Module):
    """Pyramid pooling module with five parallel branches.

    `scale0` works at input resolution; `scale1`-`scale3` average-pool with
    kernel/stride 5/2, 9/4 and 17/8; `scale4` pools globally to 1x1. Every
    branch ends in BN-ReLU-1x1 conv producing `branch_channels` channels.

    Args:
        in_channels: channels of the input.
        branch_channels: channels produced by each branch.
        out_channels: channels of the result.
    """

    def __init__(self, in_channels, branch_channels, out_channels):
        super().__init__()

        def bn_relu_conv(channels_in, channels_out, kernel_size=1, **conv_options):
            # The BN -> ReLU -> bias-free conv triple shared by every stage.
            return (
                nn.BatchNorm2d(channels_in, momentum=batchnorm_momentum),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, bias=False, **conv_options),
            )

        self.scale0 = nn.Sequential(*bn_relu_conv(in_channels, branch_channels))

        self.scale1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
            *bn_relu_conv(in_channels, branch_channels),
        )

        self.scale2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
            *bn_relu_conv(in_channels, branch_channels),
        )

        self.scale3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
            *bn_relu_conv(in_channels, branch_channels),
        )

        self.scale4 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            *bn_relu_conv(in_channels, branch_channels),
        )

        # Fuses the five concatenated branch outputs with a grouped 3x3 conv.
        self.scale_process = nn.Sequential(
            *bn_relu_conv(branch_channels * 5, branch_channels * 4, kernel_size=3, padding=1, groups=4),
        )

        # Input: scale0 output concatenated with the fused output.
        self.compression = nn.Sequential(*bn_relu_conv(branch_channels * 5, out_channels))

        self.shortcut = nn.Sequential(*bn_relu_conv(in_channels, out_channels))

    def forward(self, x):
        height, width = x.shape[-2], x.shape[-1]

        base = self.scale0(x)
        pyramid = [base]

        # Each pooled branch is resized back to the input size and added to
        # the full-resolution branch.
        for branch in (self.scale1, self.scale2, self.scale3, self.scale4):
            resized = F.interpolate(
                input=branch(x),
                size=[height, width],
                mode='bilinear',
                align_corners=align_corners
            )
            pyramid.append(resized + base)

        fused = self.scale_process(torch.cat(pyramid, 1))

        return self.compression(torch.cat([base, fused], 1)) + self.shortcut(x)

In [None]:
class PagFM(nn.Module):
    """Blend two feature maps with a learned, sigmoid-gated similarity map.

    ``x`` and ``y`` are each projected by a 1x1 convolution + batch norm to
    ``inter_channels`` channels. The element-wise product of the projections
    is turned into a gate in (0, 1): either a single-channel map (channel
    sum, the default) or a per-channel map produced by the ``up`` projection
    (``with_channel=True``). The output is ``(1 - gate) * x + gate * y`` with
    ``y`` resized bilinearly to the spatial size of ``x``.

    Args:
        in_channels: Channel count of both ``x`` and ``y``.
        inter_channels: Channel count of the projections used for the gate.
        after_relu: If true, an in-place ReLU is applied to ``x`` and ``y``
            before anything else (this modifies the tensors passed in).
        with_channel: If true, compute a per-channel gate via ``up``.
    """

    def __init__(self, in_channels, inter_channels, after_relu=False, with_channel=False):
        super().__init__()

        self.with_channel = with_channel
        self.after_relu = after_relu

        self.f_x = self._projection(in_channels, inter_channels)
        self.f_y = self._projection(in_channels, inter_channels)

        if with_channel:
            self.up = self._projection(inter_channels, in_channels)

        if after_relu:
            self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _projection(source_channels, target_channels):
        """Return a bias-free 1x1 convolution followed by batch norm."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=source_channels,
                out_channels=target_channels,
                kernel_size=1,
                bias=False
            ),
            nn.BatchNorm2d(target_channels)
        )

    def forward(self, x, y):
        """Return ``x`` and the resized ``y`` mixed by the similarity gate."""
        target_size = [x.size(2), x.size(3)]

        if self.after_relu:
            y = self.relu(y)
            x = self.relu(x)

        # Project y first, then bring it to the resolution of x.
        query = F.interpolate(
            input=self.f_y(y),
            size=target_size,
            mode='bilinear',
            align_corners=False
        )
        key = self.f_x(x)

        similarity = key * query
        if self.with_channel:
            gate = torch.sigmoid(self.up(similarity))
        else:
            gate = torch.sigmoid(similarity.sum(dim=1).unsqueeze(1))

        y_resized = F.interpolate(
            input=y,
            size=target_size,
            mode='bilinear',
            align_corners=False
        )

        return (1 - gate) * x + gate * y_resized

In [None]:
class LightBag(nn.Module):
    """Fuse three feature maps ``p``, ``i`` and ``d`` using ``d`` as a gate.

    ``sigmoid(d)`` weights how much of ``p`` is mixed into ``i`` and how much
    of ``i`` is mixed into ``p``. Each mixture goes through its own 1x1
    convolution + batch norm, and the two results are summed.

    Args:
        in_channels: Channel count of the inputs to both convolutions.
        out_channels: Channel count of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_p = self._conv_bn(in_channels, out_channels)
        self.conv_i = self._conv_bn(in_channels, out_channels)

    @staticmethod
    def _conv_bn(in_channels, out_channels):
        """Return a bias-free 1x1 convolution followed by batch norm."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False
            ),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, p, i, d):
        """Return ``conv_p((1 - g) * i + p) + conv_i(i + g * p)``, ``g = sigmoid(d)``."""
        gate = torch.sigmoid(d)

        from_p = self.conv_p((1 - gate) * i + p)
        from_i = self.conv_i(i + gate * p)

        return from_p + from_i

In [None]:
class LightBagV2(nn.Module):
    """Gate-fuse ``p`` and ``i`` with ``d``, using pre-activated projections.

    Same mixing rule as ``LightBag`` (``sigmoid(d)`` decides how much of each
    input is mixed into the other), but every mixture is passed through
    batch norm -> ReLU -> 1x1 convolution -> batch norm before the two
    results are summed.

    Args:
        in_channels: Channel count of the inputs to both projections.
        out_channels: Channel count of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_p = self._bn_relu_conv_bn(in_channels, out_channels)
        self.conv_i = self._bn_relu_conv_bn(in_channels, out_channels)

    @staticmethod
    def _bn_relu_conv_bn(in_channels, out_channels):
        """Return batch norm, in-place ReLU, bias-free 1x1 conv and batch norm."""
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False
            ),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, p, i, d):
        """Return ``conv_p((1 - g) * i + p) + conv_i(i + g * p)``, ``g = sigmoid(d)``."""
        gate = torch.sigmoid(d)

        from_p = self.conv_p((1 - gate) * i + p)
        from_i = self.conv_i(i + gate * p)

        return from_p + from_i

### Model

In [None]:
class PIDNet(nn.Module):
    def __init__(self, name: str, learning_rate: float, in_channels=3, out_channels=32, ppm_channels=96, head_channels=128, num_classes=len(rgb2classes), loss_weights=None):
        """Build all layers, initialise their weights and create losses and optimiser.

        Args:
            name: Model name; stored and used as the first part of the spec string.
            learning_rate: Learning rate stored on the model and handed to SGD.
            in_channels: Channel count expected by the first convolution.
            out_channels: Base width; all other layer widths are multiples of it.
            ppm_channels: ``branch_channels`` of the ``PAPPM`` module (``self.spp``).
            head_channels: ``inter_channels`` of ``seghead_p`` and ``final_layer``.
            num_classes: Output channels of ``seghead_p`` and ``final_layer``.
            loss_weights: Passed as ``weight`` to ``SemanticCrossEntropyLoss``.
        """
        super(PIDNet, self).__init__()

        self.name = name
        self.epochsTrained = 0

        self.num_classes = num_classes

        # Shared in-place ReLU used between the stages in forward().
        self.relu = nn.ReLU(inplace=True)

        # I Branch
        # Stem: two stride-2 3x3 convolutions, i.e. a 4x spatial reduction.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(
                num_features=out_channels,
                momentum=batchnorm_momentum
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(
                num_features=out_channels,
                momentum=batchnorm_momentum
            ),
            nn.ReLU(inplace=True),
        )

        self.layer1 = self.__make_layer(
            block=BasicBlock,
            in_channels=out_channels,
            out_channels=out_channels,
            blocks=2
        )

        self.layer2 = self.__make_layer(
            block=BasicBlock,
            in_channels=out_channels,
            out_channels=out_channels * 2,
            blocks=2,
            stride=2
        )

        self.layer3 = self.__make_layer(
            block=BasicBlock,
            in_channels=out_channels * 2,
            out_channels=out_channels * 4,
            blocks=3,
            stride=2
        )

        self.layer4 = self.__make_layer(
            block=BasicBlock,
            in_channels=out_channels * 4,
            out_channels=out_channels * 8,
            blocks=3,
            stride=2
        )

        self.layer5 = self.__make_layer(
            block=BottleneckBlock,
            in_channels=out_channels * 8,
            out_channels=out_channels * 8,
            blocks=2,
            stride=2
        )

        # P Branch
        # 1x1 projections that bring layer3/layer4 output of the I branch
        # down to the P branch width before it is fed into pag3/pag4.
        self.compression3 = nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels * 4,
                out_channels=out_channels * 2,
                kernel_size=1,
                bias=False
            ),
            nn.BatchNorm2d(
                num_features=out_channels * 2,
                momentum=batchnorm_momentum
            ),
        )

        self.compression4 = nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels * 8,
                out_channels=out_channels * 2,
                kernel_size=1,
                bias=False
            ),
            nn.BatchNorm2d(
                num_features=out_channels * 2,
                momentum=batchnorm_momentum
            ),
        )

        self.pag3 = PagFM(
            in_channels=out_channels * 2,
            inter_channels=out_channels
        )

        self.pag4 = PagFM(
            in_channels=out_channels * 2,
            inter_channels=out_channels
        )

        self.layer3_ = self.__make_layer(
            block=BasicBlock,
            in_channels=out_channels * 2,
            out_channels=out_channels * 2,
            blocks=2

        )
        self.layer4_ = self.__make_layer(
            block=BasicBlock,
            in_channels=out_channels * 2,
            out_channels=out_channels * 2,
            blocks=2
        )

        self.layer5_ = self.__make_layer(
            block=BottleneckBlock,
            in_channels=out_channels * 2,
            out_channels=out_channels * 2,
            blocks=1
        )

        # D Branch
        self.layer3_d = self.__make_single_layer(
            block=BasicBlock,
            in_channels=out_channels * 2,
            out_channels=out_channels
        )

        self.layer4_d = self.__make_layer(
            block=BottleneckBlock,
            in_channels=out_channels,
            out_channels=out_channels,
            blocks=1
        )

        # 3x3 projections of layer3/layer4 output of the I branch that are
        # added to the D branch in forward().
        self.diff3 = nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels * 4,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(
                num_features=out_channels,
                momentum=batchnorm_momentum
            ),
        )

        self.diff4 = nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels * 8,
                out_channels=out_channels * 2,
                kernel_size=3,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(
                num_features=out_channels * 2,
                momentum=batchnorm_momentum
            ),
        )

        # Pyramid pooling applied to the output of layer5 (I branch).
        self.spp = PAPPM(
            in_channels=out_channels * 16,
            branch_channels=ppm_channels,
            out_channels=out_channels * 4
        )

        # Fuses the P, I and D branch outputs; called as dfm(p, i, d) in forward().
        self.dfm = LightBagV2(
            in_channels=out_channels * 4,
            out_channels=out_channels * 4
        )

        self.layer5_d = self.__make_layer(
            block=BottleneckBlock,
            in_channels=out_channels * 2,
            out_channels=out_channels * 2,
            blocks=1
        )

        # Prediction Head
        # seghead_p and seghead_d are fed the intermediate P and D branch
        # features; final_layer is fed the fused features.
        self.seghead_p = SegmentHead(
            in_channels=out_channels * 2,
            inter_channels=head_channels,
            out_channels=num_classes
        )

        self.seghead_d = SegmentHead(
            in_channels=out_channels * 2,
            inter_channels=out_channels,
            out_channels=1
        )


        self.final_layer = SegmentHead(
            in_channels=out_channels * 4,
            inter_channels=head_channels,
            out_channels=num_classes
        )


        # Weight initialisation: Kaiming-normal for every convolution,
        # weight 1 / bias 0 for every batch norm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    tensor=m.weight,
                    mode='fan_out',
                    nonlinearity='relu'
                )

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Losses and optimiser live on the model so that process_data() can
        # run a complete training step on its own.
        self.semantic_loss_function = SemanticCrossEntropyLoss(
            ignore_label=-1,
            thres=0.9,
            min_kept=100_000,
            weight=loss_weights
        )

        self.boundary_loss_function = BoundaryLoss()
        self.learning_rate = learning_rate
        self.optimiser = optim.SGD(
            self.parameters(),
            lr=learning_rate,
            momentum=0.9,
            weight_decay=0.0001,
            nesterov=False
        )

    @staticmethod
    def __make_layer(block, in_channels, out_channels, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1 convolution + batch norm as its
        downsample path. All following blocks use stride 1; the last of them
        is built with ``no_relu=True``, the others with ``no_relu=False``.
        """
        expanded_channels = out_channels * block.expansion

        projection = None
        if stride != 1 or in_channels != expanded_channels:
            projection = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    expanded_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(
                    expanded_channels,
                    momentum=batchnorm_momentum
                ),
            )

        stages = [block(in_channels, out_channels, stride, projection)]

        last_index = blocks - 1
        stages.extend(
            block(expanded_channels, out_channels, stride=1, no_relu=(index == last_index))
            for index in range(1, blocks)
        )

        return nn.Sequential(*stages)

    @staticmethod
    def __make_single_layer(block, in_channels, out_channels, stride=1):
        """Build one ``block`` with ``no_relu=True``.

        A 1x1 convolution + batch norm is supplied as the downsample path
        when the stride is not 1 or the channel count changes; otherwise the
        block gets ``None``.
        """
        expanded_channels = out_channels * block.expansion
        needs_projection = stride != 1 or in_channels != expanded_channels

        projection = nn.Sequential(
            nn.Conv2d(
                in_channels,
                expanded_channels,
                kernel_size=1,
                stride=stride,
                bias=False
            ),
            nn.BatchNorm2d(
                expanded_channels,
                momentum=batchnorm_momentum
            ),
        ) if needs_projection else None

        return block(in_channels, out_channels, stride, projection, no_relu=True)

    def forward(self, x):
        """Run the three branches and return the three head outputs.

        Args:
            x: Input batch with channels in dim 1 and the spatial dims last.

        Returns:
            Tuple ``(x_extra_p, x_, x_extra_d)``: output of ``seghead_p`` on
            the intermediate P-branch features, output of ``final_layer`` on
            the fused features, and output of ``seghead_d`` on the
            intermediate D-branch features. All are computed at 1/8 of the
            input's spatial size.
        """
        # Resolution of the P/D branches: the stem reduces by 4, layer2 by 2.
        width_output = x.shape[-1] // 8
        height_output = x.shape[-2] // 8

        # Shared trunk.
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        x = self.relu(x)

        # Stage 3: P (x_) and D (x_d) branches split off from the trunk.
        x_ = self.layer3_(x)
        x_d = self.layer3_d(x)

        x = self.layer3(x)
        x = self.relu(x)

        # Inject I-branch information into P (gated) and D (additive, resized).
        x_ = self.pag3(x_, self.compression3(x))
        x_d = x_d + F.interpolate(
            input=self.diff3(x),
            size=[height_output, width_output],
            mode='bilinear',
            align_corners=align_corners
        )

        # NOTE(review): self.relu is in-place, so the self.relu(x_) below also
        # rewrites the tensor temp_p refers to before seghead_p sees it.
        temp_p = x_

        # Stage 4.
        x = self.layer4(x)
        x = self.relu(x)

        x_ = self.relu(x_)
        x_ = self.layer4_(x_)

        x_d = self.relu(x_d)
        x_d = self.layer4_d(x_d)

        x_ = self.pag4(x_, self.compression4(x))
        x_d = x_d + F.interpolate(
            input=self.diff4(x),
            size=[height_output, width_output],
            mode='bilinear',
            align_corners=align_corners
        )

        # NOTE(review): same aliasing as temp_p - the in-place self.relu(x_d)
        # below also changes temp_d.
        temp_d = x_d

        # Stage 5.
        x_ = self.relu(x_)
        x_ = self.layer5_(x_)

        x_d = self.relu(x_d)
        x_d = self.layer5_d(x_d)

        # I branch: deepest stage + pyramid pooling, resized to the P/D resolution.
        x = F.interpolate(
            self.spp(self.layer5(x)),
            size=[height_output, width_output],
            mode='bilinear',
            align_corners=align_corners
        )

        # Fuse the branches (dfm takes p, i, d) and predict the final map.
        x_ = self.final_layer(self.dfm(x_, x, x_d))

        # Auxiliary heads on the intermediate P and D features.
        x_extra_p = self.seghead_p(temp_p)
        x_extra_d = self.seghead_d(temp_d)

        return x_extra_p, x_, x_extra_d

    train_mode = 0
    test_mode = 1

    def __predict(self, image_tensor, output_height, output_width):
        """Run a timed forward pass and resize all three outputs.

        Args:
            image_tensor: Image batch with the channel dim last; it is
                permuted with ``(0, 3, 1, 2)`` before the forward pass.
            output_height: Height every returned tensor must have.
            output_width: Width every returned tensor must have.

        Returns:
            Tuple ``(p_branch, i_branch, d_branch, prediction_time)`` where
            the first three are the outputs of ``forward`` resized
            bilinearly to ``(output_height, output_width)`` and
            ``prediction_time`` is the wall-clock duration of the forward
            call in microseconds.
        """
        image_tensor = image_tensor.permute(0, 3, 1, 2)

        prediction_start_ts = datetime.now()

        branch_outputs = self.forward(image_tensor)

        prediction_stop_ts = datetime.now()
        prediction_time = (prediction_stop_ts - prediction_start_ts).total_seconds() * 1_000_000

        target_size = (output_height, output_width)
        resized_outputs = []

        for branch_output in branch_outputs:
            # Check each tensor's own (height, width); a mismatch in either
            # dimension requires resizing.
            current_size = (branch_output.size(2), branch_output.size(3))
            if current_size != target_size:
                branch_output = F.interpolate(
                    branch_output,
                    size=target_size,
                    mode='bilinear',
                    align_corners=True
                )
            resized_outputs.append(branch_output)

        batch_p_branch_prediction_tensor, batch_i_branch_prediction_tensor, batch_d_branch_prediction_tensor = resized_outputs

        return batch_p_branch_prediction_tensor, batch_i_branch_prediction_tensor, batch_d_branch_prediction_tensor, prediction_time

    def predict(self, image_tensor, output_height, output_width, plot=False):
        """Predict a class-index map for an image.

        Args:
            image_tensor: Image with the channel dim last, either a single
                image (3 dims) or a batch (4 dims). A single image gets a
                batch dim added.
            output_height: Height of the returned class map.
            output_width: Width of the returned class map.
            plot: If true, plot the image, the prediction and the overlay.

        Returns:
            Tensor of class indices with shape ``(output_height,
            output_width)`` for a single image; for a batch of more than one
            image the batch dim is kept in front.
        """
        if len(image_tensor.shape) == 3:
            image_tensor = image_tensor.unsqueeze(dim=0)

        _, prediction, _ = self.forward(
            image_tensor.permute(0, 3, 1, 2)
        )

        prediction = F.softmax(
            prediction.detach(),
            dim=1
        )

        # prediction is (N, C, H, W): the spatial dims are 2 and 3.
        prediction_height, prediction_width = prediction.size(2), prediction.size(3)
        if prediction_height != output_height or prediction_width != output_width:
            prediction = F.interpolate(
                prediction,
                size=(output_height, output_width),
                mode='bilinear',
                align_corners=True
            )

        # Arg-max over the class dim; only the batch dim is squeezed so that
        # a spatial dim of size 1 is never dropped by accident.
        prediction = torch.argmax(prediction, dim=1).squeeze(0)

        if plot:
            plot_image(image_tensor.squeeze(), image_color=True)
            plot_image(prediction, image_color=False, image_cmap=classColorMap)
            plot_image(image_tensor.squeeze(), mask=prediction, mask_color=False, mask_cmap=classColorMap)

        return prediction

    def process_data(self, mode: int, batch_image_tensor, batch_image_label_tensor, batch_boundary_label_tensor, print_out=False):
        """Evaluate one batch and, in training mode, take one optimiser step.

        Args:
            mode: ``PIDNet.train_mode`` to back-propagate and update the
                weights; any other value only evaluates.
            batch_image_tensor: Image batch, channel dim last.
            batch_image_label_tensor: Label batch, channel dim last; it is
                permuted with ``(0, 3, 1, 2)`` and its last two dims define
                the size the predictions are resized to.
                # presumably one-hot encoded (calc_metrics takes argmax over dim 1) - confirm
            batch_boundary_label_tensor: Target for the boundary loss.
            print_out: If true, print the metrics of this batch.

        Returns:
            Tuple ``(semantic_loss, boundary_loss, combined_loss, full_loss,
            pixel_accuracy, iou, prediction_time)`` of Python numbers.
        """
        batch_image_label_tensor = batch_image_label_tensor.permute(0, 3, 1, 2)
        label_height, label_width = batch_image_label_tensor.size(2), batch_image_label_tensor.size(3)

        batch_p_branch_prediction_tensor, batch_i_branch_prediction_tensor, batch_d_branch_prediction_tensor, prediction_time = self.__predict(batch_image_tensor, label_height, label_width)

        iou, pixel_accuracy = self.calc_metrics(batch_i_branch_prediction_tensor, batch_image_label_tensor)

        semantic_loss = self.semantic_loss_function([batch_p_branch_prediction_tensor, batch_i_branch_prediction_tensor], batch_image_label_tensor)
        boundary_loss = self.boundary_loss_function(batch_d_branch_prediction_tensor, batch_boundary_label_tensor)

        # Semantic loss restricted to pixels the D branch marks as boundary
        # (sigmoid > 0.8); every other pixel is set to the ignore label -1.
        filler = torch.ones_like(batch_image_label_tensor) * -1
        boundary_label = torch.where(
            torch.sigmoid(batch_d_branch_prediction_tensor) > 0.8,
            batch_image_label_tensor,
            filler
        )
        combined_loss = self.semantic_loss_function(batch_i_branch_prediction_tensor, boundary_label)

        full_loss = semantic_loss + boundary_loss + combined_loss
        full_loss = torch.unsqueeze(full_loss, 0).mean()

        if print_out:
            print(
                # One separator rule of 100 '=' characters.
                "\n" + "=" * 100,
                "\nPixel_Accuracy:", pixel_accuracy,
                "\nSemantic Loss:", semantic_loss,
                "\nBoundary Loss:", boundary_loss,
                "\nCombined Loss:", combined_loss,
                "\nFull Loss:", full_loss
            )

        if mode == PIDNet.train_mode:
            self.optimiser.zero_grad(set_to_none=True)
            full_loss.backward()
            self.optimiser.step()

        return semantic_loss.mean().item(), boundary_loss.item(), combined_loss.item(), full_loss.item(), pixel_accuracy, iou, prediction_time

    def calc_metrics(self, batch_prediction_tensor, batch_image_label_tensor):
        """Return ``(iou, pixel_accuracy)`` of a prediction batch as floats.

        The labels are reduced to class indices with an arg-max over dim 1
        and the predictions are soft-maxed over dim 1 (detached). Both
        metrics use ``average="weighted"``; the pixel accuracy is computed
        per sample and then averaged over the batch.
        """
        target_classes = batch_image_label_tensor.argmax(dim=1)
        class_probabilities = F.softmax(batch_prediction_tensor, dim=1).detach()

        jaccard = metrics.MulticlassJaccardIndex(
            num_classes=self.num_classes,
            average="weighted"
        )
        accuracy = metrics.MulticlassAccuracy(
            num_classes=self.num_classes,
            average="weighted",
            multidim_average="samplewise"
        )

        iou_value = jaccard(class_probabilities, target_classes).item()
        accuracy_value = accuracy(class_probabilities, target_classes).mean().item()

        return iou_value, accuracy_value

    def get_spec_string(self, override=""):
        """Return a short identifier of the model configuration.

        A non-empty ``override`` is returned unchanged. Otherwise the string
        is ``<name>-<loss>-<optimiser>-lr<learning_rate>`` where ``<loss>``
        is the text of the semantic loss up to its first ``(`` and
        ``<optimiser>`` the text of the optimiser up to its first space.
        """
        if override != "":
            return override

        loss_label = str(self.semantic_loss_function).split("(")[0]
        optimiser_label = str(self.optimiser).split(" ")[0]
        rate_label = "lr" + str(self.learning_rate)

        return "-".join([self.name, loss_label, optimiser_label, rate_label])

# Training und Testen

### Util

In [None]:
class Result(object):
    """Parallel lists of the metrics recorded for each step or epoch.

    Every call to ``append`` adds one entry to each of the seven lists
    (``semantic_loss``, ``boundary_loss``, ``combined_loss``, ``full_loss``,
    ``pixel_acc``, ``iou``, ``calc_time``), so index ``k`` of every list
    belongs to the same record. ``len`` mirrors the number of records.
    """

    def __init__(self):
        self.len = 0

        self.semantic_loss = []
        self.boundary_loss = []
        self.combined_loss = []
        self.full_loss = []
        self.pixel_acc = []
        self.iou = []
        self.calc_time = []

    def append(self, semantic_loss, boundary_loss, combined_loss, full_loss, pixel_acc, iou, calc_time):
        """Add one record.

        Raises:
            ValueError: If any argument is ``None`` (nothing is stored).
            Exception: If the lists have diverged in length, e.g. because
                one of them was modified from outside.
        """
        values = (semantic_loss, boundary_loss, combined_loss, full_loss, pixel_acc, iou, calc_time)
        if any(value is None for value in values):
            raise ValueError("None of the given Parameters can be None!")

        self.semantic_loss.append(semantic_loss)
        self.boundary_loss.append(boundary_loss)
        self.combined_loss.append(combined_loss)
        self.full_loss.append(full_loss)
        self.pixel_acc.append(pixel_acc)
        self.iou.append(iou)
        self.calc_time.append(calc_time)

        # The lists are public, so verify they are still in lock-step.
        if len({len(column) for column in self.value()}) != 1:
            raise Exception("All properties must be of the same length!")

        self.len += 1

    def append_as_result(self, result: 'Result'):
        """Add the complete lists of ``result`` as ONE record (one nested list per metric)."""
        semantic_loss, boundary_loss, combined_loss, full_loss, pixel_acc, iou, calc_time = result.value()
        self.append(
            semantic_loss=semantic_loss,
            boundary_loss=boundary_loss,
            combined_loss=combined_loss,
            full_loss=full_loss,
            pixel_acc=pixel_acc,
            iou=iou,
            calc_time=calc_time
        )

    def append_avg(self, result: 'Result'):
        """Add the per-metric means of ``result`` as one record."""
        mean_semantic_loss, mean_boundary_loss, mean_combined_loss, mean_full_loss, mean_pixel_acc, mean_iou, mean_calc_time = result.average()
        self.append(
            semantic_loss=mean_semantic_loss,
            boundary_loss=mean_boundary_loss,
            combined_loss=mean_combined_loss,
            full_loss=mean_full_loss,
            pixel_acc=mean_pixel_acc,
            iou=mean_iou,
            calc_time=mean_calc_time
        )

    def __len__(self):
        return self.len

    def __str__(self):
        """Return all lists as text below a rule of 200 ``=`` characters."""
        # __str__ must return a single str, so the parts are joined here.
        return (
            "\n\n" + "=" * 200
            + "\nSemantic Loss: " + str(self.semantic_loss)
            + "\nBoundary Loss: " + str(self.boundary_loss)
            + "\nCombined Loss: " + str(self.combined_loss)
            + "\nFull Loss: " + str(self.full_loss)
            + "\nPixel Accuracy: " + str(self.pixel_acc)
            + "\nIoU: " + str(self.iou)
            + "\nCalc Time: " + str(self.calc_time)
        )

    def get_loss(self):
        """Return the four losses of the most recent record."""
        return self.semantic_loss[-1], self.boundary_loss[-1], self.combined_loss[-1], self.full_loss[-1]

    def get_accuracy(self):
        """Return ``(iou, pixel_acc, calc_time)`` of the most recent record."""
        return self.iou[-1], self.pixel_acc[-1], self.calc_time[-1]

    def value(self):
        """Return the seven underlying lists (not copies)."""
        return (self.semantic_loss,
                self.boundary_loss,
                self.combined_loss,
                self.full_loss,
                self.pixel_acc,
                self.iou,
                self.calc_time)

    def value_at_index(self, index: int):
        """Return the record stored at ``index`` as a 7-tuple."""
        return (self.semantic_loss[index],
                self.boundary_loss[index],
                self.combined_loss[index],
                self.full_loss[index],
                self.pixel_acc[index],
                self.iou[index],
                self.calc_time[index])

    def printable_value_at_index(self, index: int):
        """Return the record stored at ``index`` as one ``|``-separated line."""
        return "Semantic Loss: " + str(self.semantic_loss[index]) + \
            " | Boundary Loss: " + str(self.boundary_loss[index]) + \
            " | Combined Loss: " + str(self.combined_loss[index]) + \
            " | Full Loss: " + str(self.full_loss[index]) + \
            " | Pixel Accuracy: " + str(self.pixel_acc[index]) + \
            " | IoU: " + str(self.iou[index]) + \
            " | Prediction Time: " + str(self.calc_time[index])

    def average(self):
        """Return the mean of every list as a 7-tuple."""
        return tuple(np.array(column).mean() for column in self.value())

    def max(self):
        """Return the maximum of every list as a 7-tuple."""
        return tuple(np.array(column).max() for column in self.value())

    def min(self):
        """Return the minimum of every list as a 7-tuple."""
        return tuple(np.array(column).min() for column in self.value())

### ModelHandler

In [None]:
class ModelHandler:
    """Drive training, evaluation and persistence for one model.

    On construction a timestamped run directory is created below ``Modelle/``
    and the model's initial weights are stored in its ``Model`` subfolder.
    Three ``Result`` collections are kept: per-batch training metrics and the
    per-epoch averages of evaluating on the training and on the test data.

    The class relies on module-level names defined elsewhere in the file:
    ``train_set``, ``test_set``, ``device``, ``augment_entry``,
    ``free_gpu_cache``, ``PIDNet`` and ``Result``.
    """

    def __init__(self, model):
        """Attach to ``model``, create the run directory and save the initial weights.

        Args:
            model: Network to manage. It must expose ``optimiser``,
                ``get_spec_string()`` and ``state_dict()``.

        Raises:
            OSError: If the run directory cannot be created.
        """
        self.model = model

        # The scheduler is stepped only at the milestone epochs in
        # ``training``; every step multiplies the base learning rate by 0.1.
        # The former factor ``epoch / 10`` evaluated to 0 when the scheduler
        # was constructed, which set the learning rate to 0 until the first
        # milestone was reached.
        # NOTE(review): a x0.1 decay per milestone is an assumption about the
        # intended schedule - confirm.
        self.scheduler = scheduler.LambdaLR(
            optimizer=model.optimiser,
            lr_lambda=lambda step: 0.1 ** step,
        )

        self.train_results = Result()
        self.test_results_with_train_data = Result()
        self.test_results_with_test_data = Result()

        timestamp = datetime.now().strftime("%d%b%y-%H-%M")
        self.directory = "Modelle/" + model.get_spec_string() + "_" + timestamp

        # makedirs creates every missing level and tolerates an existing run
        # directory (two handlers created within the same minute), so the
        # "Model" subfolder is guaranteed to exist before saving. A genuine
        # failure propagates as the original OSError with its message intact.
        os.makedirs(self.directory + "/Model", exist_ok=True)

        # state_dict() can be saved directly; copying the whole model first
        # only duplicated the weights in memory.
        torch.save(model.state_dict(), self.directory + "/Model/model_init_state.pt")

    @staticmethod
    def _make_loader(dataset, batch_size):
        """Build a shuffling, single-process DataLoader whose RNG lives on ``device``."""
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            generator=torch.Generator(device=device),
        )

    def training(self, train_epochs=10, train_batch_size=8, augment=True, print_out=True):
        """Train the model for ``train_epochs`` epochs and evaluate after each one.

        Per epoch: train on every batch (plus the entries returned by
        ``augment_entry`` when ``augment`` is set), plot a prediction for one
        random test image, evaluate on the training and the test data, print a
        summary and save the weights. The scheduler is stepped when the
        model's epoch counter equals 10 or 20. After the last epoch the
        collected results and a spec file are written.

        Args:
            train_epochs: Number of epochs to run in this call.
            train_batch_size: Batch size of the training loader.
            augment: Whether to also train on augmented copies of each batch.
            print_out: Whether to print a progress line every
                ``train_batch_size * 5`` training steps.
        """
        train_loader = self._make_loader(train_set, train_batch_size)
        bs1_train_loader = self._make_loader(train_set, 1)
        test_loader = self._make_loader(test_set, 1)

        # Exactly ``train_epochs`` epochs: the former bound added one extra.
        final_epoch = self.model.epochsTrained + train_epochs

        # Global cuDNN switch; setting it once is enough.
        torch.backends.cudnn.benchmark = True

        while self.model.epochsTrained < final_epoch:
            self.model.train()
            start_time = datetime.now()

            for step, data in enumerate(train_loader, 0):
                report_step = step != 0 and step % (train_batch_size * 5) == 0

                entries = [data]
                if augment:
                    entries += augment_entry(data[0], data[1], data[2], data[3])

                for image, _, class_label, edge in entries:
                    semantic_loss, boundary_loss, combined_loss, full_loss, pixel_acc, iou, calc_time = \
                        self.model.process_data(PIDNet.train_mode, image, class_label, edge)
                    self.train_results.append(
                        semantic_loss=semantic_loss,
                        boundary_loss=boundary_loss,
                        combined_loss=combined_loss,
                        full_loss=full_loss,
                        pixel_acc=pixel_acc,
                        iou=iou,
                        calc_time=calc_time
                    )

                    free_gpu_cache([image, class_label, edge])

                if print_out and report_step:
                    print("Done", step, "Batch-Trainingsteps.")

            train_duration = datetime.now() - start_time

            # Visual sanity check on one random test image. This uses the
            # handler's own model rather than a module-level variable, so the
            # handler also works for models bound to a different name.
            random_index = random.randint(0, len(test_set) - 1)
            self.model.eval()
            test_image = test_set[random_index][0]
            with torch.no_grad():
                self.model.predict(test_image, 1024, 1024, plot=True)

            self.testing(bs1_train_loader, self.test_results_with_train_data, print_out=False)
            self.testing(test_loader, self.test_results_with_test_data, print_out=False)

            self.print_epoch(self.model.epochsTrained, train_duration)
            self.save_model(self.model.epochsTrained)

            # Learning-rate milestones.
            if self.model.epochsTrained in (10, 20):
                self.scheduler.step()

            self.model.epochsTrained += 1

        self.save_results(batch_size=train_batch_size)

    def testing(self, dataloader, result: Result, print_out=True):
        """Evaluate the model on ``dataloader`` and append the averages to ``result``.

        Args:
            dataloader: Iterable yielding ``(image, _, label, edge)`` entries.
            result: Collection that receives the averaged metrics of this pass
                through ``append_avg``.
            print_out: Whether to print the metrics of every tenth sample.
        """
        self.model.eval()
        test_results = Result()

        with torch.no_grad():
            for index, (image, _, label, edge) in enumerate(dataloader, 0):
                semantic_loss, boundary_loss, combined_loss, full_loss, pixel_acc, iou, calc_time = \
                    self.model.process_data(PIDNet.test_mode, image, label, edge)
                test_results.append(
                    semantic_loss=semantic_loss,
                    boundary_loss=boundary_loss,
                    combined_loss=combined_loss,
                    full_loss=full_loss,
                    pixel_acc=pixel_acc,
                    iou=iou,
                    calc_time=calc_time
                )

                free_gpu_cache([image, label, edge])

                if print_out and index % 10 == 0:
                    # Colour by the IoU of the sample just processed. The
                    # former check compared the whole ``result.iou`` series
                    # with a number, which raises TypeError for a list.
                    # NOTE(review): the threshold 50 presumes IoU is reported
                    # in percent - confirm.
                    color = 'green' if iou > 50 else 'red'
                    output = test_results.printable_value_at_index(-1)
                    print(colored(output, color, attrs=['reverse', 'blink']))

        result.append_avg(test_results)

    def print_epoch(self, epoch_nr: int, train_duration):
        """Print the summary for one finished epoch.

        Args:
            epoch_nr: Epoch number shown in the header.
            train_duration: Elapsed training time; rendered with ``str()``
                (``training`` passes a ``timedelta``).
        """
        print("\n----------------- Epoch", epoch_nr, "-----------------")
        print("-->  Train-Duration:", str(train_duration))

        print("\n-->  Train-Results:", self.train_results.printable_value_at_index(-1))
        print("\n-->  Test-Results with Train Data:", self.test_results_with_train_data.printable_value_at_index(-1))
        print("\n-->  Test-Results with Test Data:", self.test_results_with_test_data.printable_value_at_index(-1))
        print("\n", 43 * "-")

    def save_model(self, epoch_nr: int):
        """Save the current weights; the file name encodes epoch and test metrics.

        Args:
            epoch_nr: Epoch number, zero-padded to two digits in the file name.
        """
        # get_accuracy() is defined on Result outside this class.
        # NOTE(review): the unpacking order (iou, pixel_acc, calc_time) is
        # taken from the existing code - confirm it matches get_accuracy().
        iou, pixel_acc, calc_time = self.test_results_with_test_data.get_accuracy()

        filename = "model_e" + str(epoch_nr).zfill(2) + "_pixel-ac" + str(pixel_acc) + "_iou" + str(iou) + "_time" + str(calc_time)
        torch.save(self.model.state_dict(), self.directory + "/Model/" + filename + ".pt")
        print("Safed model as " + filename + ".pt to " + self.directory + "/Model/")

    def save_results(self, batch_size):
        """Write the result collections and a JSON spec file to the run directory.

        Args:
            batch_size: Training batch size recorded in ``spec.json``.
        """
        np.save(self.directory + "/loss.npy", self.train_results)
        np.save(self.directory + "/avg_train_loss.npy", self.test_results_with_train_data)
        np.save(self.directory + "/avg_test_loss.npy", self.test_results_with_test_data)

        model_specs = {
            "model": {
                "spec_string": self.model.get_spec_string(),
                "name": self.model.name,
                "semantic loss-function": str(self.model.semantic_loss_function).split("(")[0],
                "boundary loss-function": str(self.model.boundary_loss_function).split("(")[0],
                "optimiser": str(self.model.optimiser).split(" ")[0],
                "learning_rate": self.model.learning_rate
            },
            "training": {
                "batch_size": batch_size
            }
        }
        with open(self.directory + "/spec.json", "w") as file:
            file.write(json.dumps(model_specs))

        # Optional: archive the run directory and remove the unpacked copy.
        #shutil.make_archive(self.directory, 'zip', self.directory)
        #shutil.rmtree(self.directory)

# Start

In [None]:
# Path of a saved state dict to resume from; leave empty to start with
# freshly initialised weights.
classifier_model_path = ""
classifier = PIDNet(
    name="PIDNet-S",
    learning_rate=1e-2,
    out_channels=32,
    ppm_channels=96,
    head_channels=128,
    loss_weights=loss_weights
)

if classifier_model_path:
    # map_location remaps the stored tensors onto the device chosen for this
    # session, so a checkpoint written on a GPU also loads on a CPU-only host.
    classifier.load_state_dict(torch.load(classifier_model_path, map_location=device))

classifier.to(device)

print("\n\n" + 100 * "=")
print(100 * "=")
print("Created " + classifier.get_spec_string())

# Create the run directory, then train and evaluate.
c_handler = ModelHandler(classifier)
c_handler.training(train_epochs=50, train_batch_size=16)