In [None]:
!pip install -q albumentations==1.4.18

In [None]:
!pip install yacs
!pip install tensorboardX

In [None]:
import shutil

# Copy the pretrained PIDNet checkpoints into the writable working directory.
# dirs_exist_ok=True keeps the cell idempotent: without it, re-running the cell
# raises FileExistsError because the destination already exists.
shutil.copytree('/kaggle/input/pidnet-pretrained/', '/kaggle/working/pretrained_model', dirs_exist_ok=True)

In [None]:
import os

def create_lst_file(image_dir, label_dir, output_lst):
    """Write a list file pairing every image with its mask, one pair per line.

    Each line has the form ``"<image_path> <label_path>"`` with forward slashes.
    Lines are ordered by the numeric file stem (``"2.png"`` before ``"10.png"``).

    Args:
        image_dir: directory containing the images; file stems must be integers.
        label_dir: directory containing the masks, with the same stems as the images.
        output_lst: path of the list file to (over)write. Missing parent
            directories are created.

    Raises:
        ValueError: if a file stem is not an integer, or if the two directories do
            not contain exactly the same stems. Pairing the sorted lists
            positionally would otherwise silently match images with wrong masks.
    """
    def numeric_stem(filename):
        # "123.png" -> 123; raises ValueError for non-numeric names.
        return int(os.path.splitext(filename)[0])

    images = sorted(os.listdir(image_dir), key=numeric_stem)
    labels = sorted(os.listdir(label_dir), key=numeric_stem)

    image_stems = [numeric_stem(name) for name in images]
    label_stems = [numeric_stem(name) for name in labels]
    if image_stems != label_stems:
        missing_labels = sorted(set(image_stems) - set(label_stems))
        missing_images = sorted(set(label_stems) - set(image_stems))
        raise ValueError(
            f"Image/label mismatch between {image_dir!r} and {label_dir!r}: "
            f"{len(missing_labels)} image(s) without a mask (e.g. {missing_labels[:5]}), "
            f"{len(missing_images)} mask(s) without an image (e.g. {missing_images[:5]})"
        )

    parent_dir = os.path.dirname(output_lst)
    if parent_dir:  # os.makedirs('') raises for a bare filename such as "train.lst"
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_lst, 'w') as f:
        for img, lbl in zip(images, labels):
            # Full paths, normalized to forward slashes
            img_path = os.path.join(image_dir, img).replace("\\", "/")
            lbl_path = os.path.join(label_dir, lbl).replace("\\", "/")
            f.write(f"{img_path} {lbl_path}\n")

# Paths to the LoveDA dataset directories.
# Urban training split (images + masks).
train_image_dir = "/kaggle/input/loveda-splits/Train/Train/Urban/images_png"
train_label_dir = "/kaggle/input/loveda-splits/Train/Train/Urban/masks_png"

# Rural training split ("target"; presumably the adaptation target domain).
target_image_dir = "/kaggle/input/loveda-splits/Train/Train/Rural/images_png"
target_label_dir = "/kaggle/input/loveda-splits/Train/Train/Rural/masks_png"

# Rural validation split.
val_image_dir = "/kaggle/input/loveda-splits/Val/Val/Rural/images_png"
val_label_dir = "/kaggle/input/loveda-splits/Val/Val/Rural/masks_png"

# Output list files, one per split.
train_lst_path = "list/urban/train.lst"
target_lst_path = "list/rural/train.lst"
val_lst_path = "list/rural/val.lst"

# Write one "<image> <mask>" list file per split.
for split_image_dir, split_label_dir, split_lst_path in (
    (train_image_dir, train_label_dir, train_lst_path),
    (target_image_dir, target_label_dir, target_lst_path),
    (val_image_dir, val_label_dir, val_lst_path),
):
    create_lst_file(split_image_dir, split_label_dir, split_lst_path)

In [None]:
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import sys


def add_path(path):
    """Prepend `path` to ``sys.path`` unless it is already present.

    Repeated calls with the same path leave ``sys.path`` unchanged.
    """
    if path in sys.path:
        return
    sys.path.insert(0, path)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-level aliases shared by the building blocks defined in this cell.
BatchNorm2d = nn.BatchNorm2d  # normalization layer used by BasicBlock, Bottleneck and segmenthead
bn_mom = 0.1  # BatchNorm momentum for those layers
algc = False  # align_corners for the bilinear resizes in segmenthead, DAPPM and PAPPM (PagFM hardcodes False)

class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (output channels = planes).

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the shortcut
            (needed when the stride or the channel count changes).
        no_relu: if True, the ReLU after the residual addition is skipped.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
        super(BasicBlock, self).__init__()
        # Attribute names and registration order are significant: they define the
        # state_dict keys and the order in which parameters are initialized.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        identity = x if self.downsample is None else self.downsample(x)
        branch += identity

        return branch if self.no_relu else self.relu(branch)

class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 conv, output = planes * 2 channels.

    Args:
        inplanes: channels of the incoming tensor.
        planes: width of the inner convolutions (output is planes * expansion).
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to build the shortcut.
        no_relu: if True (the default here), the ReLU after the residual
            addition is skipped.
    """
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
        super(Bottleneck, self).__init__()
        # Registration order fixes state_dict keys and parameter-init order.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        widened = planes * self.expansion
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(widened, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        identity = x if self.downsample is None else self.downsample(x)
        branch += identity

        return branch if self.no_relu else self.relu(branch)

class segmenthead(nn.Module):
    """Prediction head: BN-ReLU-Conv3x3 followed by BN-ReLU-Conv1x1 (with bias).

    Args:
        inplanes: input channels.
        interplanes: channels of the intermediate 3x3 convolution.
        outplanes: output channels (e.g. number of classes, or 1 for boundaries).
        scale_factor: if given, the logits are bilinearly resized to
            (H * scale_factor, W * scale_factor) of the input feature map.
    """

    def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
        super(segmenthead, self).__init__()
        self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
        self.conv1 = nn.Conv2d(inplanes, interplanes, kernel_size=3, padding=1, bias=False)
        self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(interplanes, outplanes, kernel_size=1, padding=0, bias=True)
        self.scale_factor = scale_factor

    def forward(self, x):
        hidden = self.conv1(self.relu(self.bn1(x)))
        logits = self.conv2(self.relu(self.bn2(hidden)))

        if self.scale_factor is None:
            return logits

        target_size = [dim * self.scale_factor for dim in hidden.shape[-2:]]
        return F.interpolate(logits, size=target_size, mode='bilinear', align_corners=algc)

class DAPPM(nn.Module):
    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super(DAPPM, self).__init__()
        bn_mom = 0.1
        self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale0 = nn.Sequential(
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.process1 = nn.Sequential(
                                    BatchNorm(branch_planes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
                                    )
        self.process2 = nn.Sequential(
                                    BatchNorm(branch_planes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
                                    )
        self.process3 = nn.Sequential(
                                    BatchNorm(branch_planes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
                                    )
        self.process4 = nn.Sequential(
                                    BatchNorm(branch_planes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
                                    )
        self.compression = nn.Sequential(
                                    BatchNorm(branch_planes * 5, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
                                    )
        self.shortcut = nn.Sequential(
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
                                    )

    def forward(self, x):
        width = x.shape[-1]
        height = x.shape[-2]
        x_list = []

        x_list.append(self.scale0(x))
        x_list.append(self.process1((F.interpolate(self.scale1(x),
                        size=[height, width],
                        mode='bilinear', align_corners=algc)+x_list[0])))
        x_list.append((self.process2((F.interpolate(self.scale2(x),
                        size=[height, width],
                        mode='bilinear', align_corners=algc)+x_list[1]))))
        x_list.append(self.process3((F.interpolate(self.scale3(x),
                        size=[height, width],
                        mode='bilinear', align_corners=algc)+x_list[2])))
        x_list.append(self.process4((F.interpolate(self.scale4(x),
                        size=[height, width],
                        mode='bilinear', align_corners=algc)+x_list[3])))

        out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x)
        return out

class PAPPM(nn.Module):
    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super(PAPPM, self).__init__()
        bn_mom = 0.1
        self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )

        self.scale0 = nn.Sequential(
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )

        self.scale_process = nn.Sequential(
                                    BatchNorm(branch_planes*4, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes*4, branch_planes*4, kernel_size=3, padding=1, groups=4, bias=False),
                                    )


        self.compression = nn.Sequential(
                                    BatchNorm(branch_planes * 5, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
                                    )

        self.shortcut = nn.Sequential(
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
                                    )


    def forward(self, x):
        width = x.shape[-1]
        height = x.shape[-2]
        scale_list = []

        x_ = self.scale0(x)
        scale_list.append(F.interpolate(self.scale1(x), size=[height, width],
                        mode='bilinear', align_corners=algc)+x_)
        scale_list.append(F.interpolate(self.scale2(x), size=[height, width],
                        mode='bilinear', align_corners=algc)+x_)
        scale_list.append(F.interpolate(self.scale3(x), size=[height, width],
                        mode='bilinear', align_corners=algc)+x_)
        scale_list.append(F.interpolate(self.scale4(x), size=[height, width],
                        mode='bilinear', align_corners=algc)+x_)

        scale_out = self.scale_process(torch.cat(scale_list, 1))

        out = self.compression(torch.cat([x_,scale_out], 1)) + self.shortcut(x)
        return out


class PagFM(nn.Module):
    def __init__(self, in_channels, mid_channels, after_relu=False, with_channel=False, BatchNorm=nn.BatchNorm2d):
        super(PagFM, self).__init__()
        self.with_channel = with_channel
        self.after_relu = after_relu
        self.f_x = nn.Sequential(
                                nn.Conv2d(in_channels, mid_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(mid_channels)
                                )
        self.f_y = nn.Sequential(
                                nn.Conv2d(in_channels, mid_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(mid_channels)
                                )
        if with_channel:
            self.up = nn.Sequential(
                                    nn.Conv2d(mid_channels, in_channels,
                                              kernel_size=1, bias=False),
                                    BatchNorm(in_channels)
                                   )
        if after_relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x, y):
        input_size = x.size()
        if self.after_relu:
            y = self.relu(y)
            x = self.relu(x)

        y_q = self.f_y(y)
        y_q = F.interpolate(y_q, size=[input_size[2], input_size[3]],
                            mode='bilinear', align_corners=False)
        x_k = self.f_x(x)

        if self.with_channel:
            sim_map = torch.sigmoid(self.up(x_k * y_q))
        else:
            sim_map = torch.sigmoid(torch.sum(x_k * y_q, dim=1).unsqueeze(1))

        y = F.interpolate(y, size=[input_size[2], input_size[3]],
                            mode='bilinear', align_corners=False)
        x = (1-sim_map)*x + sim_map*y

        return x

class Light_Bag(nn.Module):
    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super(Light_Bag, self).__init__()
        self.conv_p = nn.Sequential(
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(out_channels)
                                )
        self.conv_i = nn.Sequential(
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(out_channels)
                                )

    def forward(self, p, i, d):
        edge_att = torch.sigmoid(d)

        p_add = self.conv_p((1-edge_att)*i + p)
        i_add = self.conv_i(i + edge_att*p)

        return p_add + i_add


class DDFMv2(nn.Module):
    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super(DDFMv2, self).__init__()
        self.conv_p = nn.Sequential(
                                BatchNorm(in_channels),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(out_channels)
                                )
        self.conv_i = nn.Sequential(
                                BatchNorm(in_channels),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, bias=False),
                                BatchNorm(out_channels)
                                )

    def forward(self, p, i, d):
        edge_att = torch.sigmoid(d)

        p_add = self.conv_p((1-edge_att)*i + p)
        i_add = self.conv_i(i + edge_att*p)

        return p_add + i_add

class Bag(nn.Module):
    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super(Bag, self).__init__()

        self.conv = nn.Sequential(
                                BatchNorm(in_channels),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(in_channels, out_channels,
                                          kernel_size=3, padding=1, bias=False)
                                )


    def forward(self, p, i, d):
        edge_att = torch.sigmoid(d)
        return self.conv(edge_att*p + (1-edge_att)*i)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import logging

# Same aliases as in the building-block cell, re-declared here (presumably so
# this cell can be run on its own). Keep both copies in sync.
BatchNorm2d = nn.BatchNorm2d  # normalization layer used by PIDNet (also matched by its init loop)
bn_mom = 0.1  # BatchNorm momentum
algc = False  # align_corners for PIDNet's bilinear resizes

class PIDNet(nn.Module):
    """Three-branch segmentation network (I, P and D branches).

    - I branch: `conv1` (two stride-2 convs, 1/4), `layer1`, then `layer2`
      (1/8), `layer3` (1/16), `layer4` (1/32) and `layer5` (1/64). It ends in a
      pyramid pooling module (`spp`).
    - P branch: `layer3_`, `layer4_`, `layer5_`, all stride 1, so it stays at
      1/8. It absorbs the I branch through `pag3` / `pag4` (PagFM) after a 1x1
      `compression3` / `compression4`.
    - D branch: `layer3_d`, `layer4_d`, `layer5_d`, also at 1/8. It adds
      resized `diff3` / `diff4` projections of the I branch.
    `dfm` fuses the three branches and `final_layer` predicts the logits.

    Args:
        m: number of BasicBlocks in layer1, layer2, layer3_ and layer4_. It also
           selects the variant: m == 2 -> PAPPM + Light_Bag, otherwise DAPPM + Bag.
        n: number of BasicBlocks in layer3 and layer4.
        num_classes: output channels of the segmentation heads.
        planes: base channel width.
        ppm_planes: branch width inside the pyramid pooling module.
        head_planes: intermediate width of the class-prediction heads.
        augment: if True, also build the auxiliary P head (`seghead_p`) and
           1-channel D head (`seghead_d`), and return their outputs.
    """

    def __init__(self, m=2, n=3, num_classes=19, planes=64, ppm_planes=96, head_planes=128, augment=True):
        super(PIDNet, self).__init__()
        self.augment = augment

        # I Branch
        self.conv1 =  nn.Sequential(
                          nn.Conv2d(3,planes,kernel_size=3, stride=2, padding=1),
                          BatchNorm2d(planes, momentum=bn_mom),
                          nn.ReLU(inplace=True),
                          nn.Conv2d(planes,planes,kernel_size=3, stride=2, padding=1),
                          BatchNorm2d(planes, momentum=bn_mom),
                          nn.ReLU(inplace=True),
                      )

        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(BasicBlock, planes, planes, m)
        self.layer2 = self._make_layer(BasicBlock, planes, planes * 2, m, stride=2)
        self.layer3 = self._make_layer(BasicBlock, planes * 2, planes * 4, n, stride=2)
        self.layer4 = self._make_layer(BasicBlock, planes * 4, planes * 8, n, stride=2)
        self.layer5 =  self._make_layer(Bottleneck, planes * 8, planes * 8, 2, stride=2)

        # P Branch
        # 1x1 projections that bring I-branch features to the P-branch width (planes * 2).
        self.compression3 = nn.Sequential(
                                          nn.Conv2d(planes * 4, planes * 2, kernel_size=1, bias=False),
                                          BatchNorm2d(planes * 2, momentum=bn_mom),
                                          )

        self.compression4 = nn.Sequential(
                                          nn.Conv2d(planes * 8, planes * 2, kernel_size=1, bias=False),
                                          BatchNorm2d(planes * 2, momentum=bn_mom),
                                          )
        self.pag3 = PagFM(planes * 2, planes)
        self.pag4 = PagFM(planes * 2, planes)

        self.layer3_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer4_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer5_ = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # D Branch
        if m == 2:
            # Light variant: narrower D branch, PAPPM context module, Light_Bag fusion.
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes)
            self.layer4_d = self._make_layer(Bottleneck, planes, planes, 1)
            self.diff3 = nn.Sequential(
                                        nn.Conv2d(planes * 4, planes, kernel_size=3, padding=1, bias=False),
                                        BatchNorm2d(planes, momentum=bn_mom),
                                        )
            self.diff4 = nn.Sequential(
                                     nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                                     BatchNorm2d(planes * 2, momentum=bn_mom),
                                     )
            self.spp = PAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Light_Bag(planes * 4, planes * 4)
        else:
            # Full variant: DAPPM context module and Bag fusion.
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.layer4_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.diff3 = nn.Sequential(
                                        nn.Conv2d(planes * 4, planes * 2, kernel_size=3, padding=1, bias=False),
                                        BatchNorm2d(planes * 2, momentum=bn_mom),
                                        )
            self.diff4 = nn.Sequential(
                                     nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                                     BatchNorm2d(planes * 2, momentum=bn_mom),
                                     )
            self.spp = DAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Bag(planes * 4, planes * 4)

        self.layer5_d = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # Prediction Head
        if self.augment:
            # Auxiliary heads on intermediate P (class logits) and D (1-channel) features.
            self.seghead_p = segmenthead(planes * 2, head_planes, num_classes)
            self.seghead_d = segmenthead(planes * 2, planes, 1)

        self.final_layer = segmenthead(planes * 4, head_planes, num_classes)


        # Weight init over every submodule, in registration order.
        # NOTE: the loop variable `m` shadows the constructor argument `m`; harmless
        # because `m` is not used after this point.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack `blocks` residual blocks into an nn.Sequential stage.

        The first block takes the stride and, when the stride or channel count
        changes, a 1x1 conv + BN `downsample` shortcut. The first block uses the
        block class's default `no_relu`. In a multi-block stage the last block
        gets no_relu=True, so forward() applies `self.relu` explicitly where needed.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            if i == (blocks-1):
                layers.append(block(inplanes, planes, stride=1, no_relu=True))
            else:
                layers.append(block(inplanes, planes, stride=1, no_relu=False))

        return nn.Sequential(*layers)

    def _make_single_layer(self, block, inplanes, planes, stride=1):
        """Build one residual block (no_relu=True), with a 1x1 conv + BN shortcut when shapes change."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layer = block(inplanes, planes, stride, downsample, no_relu=True)

        return layer

    def forward(self, x):
        """Run the network.

        Returns:
            If `augment`: [P-head logits, main logits, D-head logits (1 channel)].
            Otherwise: main logits. All outputs are at the 1/8 feature resolution
            (the heads are built without scale_factor).
        """

        # Target size of the 1/8-resolution P/D branches for resized I-branch features.
        # NOTE(review): P/D tensors get their size from stride-2 convs (ceil
        # division), while I-branch features are resized to H // 8 x W // 8. The
        # two only agree when H and W are multiples of 8; presumably the crop size
        # guarantees this.
        width_output = x.shape[-1] // 8
        height_output = x.shape[-2] // 8

        # Shared stem down to 1/8. The last block of a multi-block stage skips its
        # ReLU (see _make_layer), hence the explicit self.relu calls.
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.relu(self.layer2(self.relu(x)))
        # P and D branches both start from the 1/8 stem features.
        x_ = self.layer3_(x)
        x_d = self.layer3_d(x)

        # Stage 3: I goes to 1/16; P fuses it via PagFM; D adds a resized projection.
        x = self.relu(self.layer3(x))
        x_ = self.pag3(x_, self.compression3(x))
        x_d = x_d + F.interpolate(
                        self.diff3(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            temp_p = x_  # input of the auxiliary P head

        # Stage 4: same pattern with I at 1/32.
        x = self.relu(self.layer4(x))
        x_ = self.layer4_(self.relu(x_))
        x_d = self.layer4_d(self.relu(x_d))

        x_ = self.pag4(x_, self.compression4(x))
        x_d = x_d + F.interpolate(
                        self.diff4(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            temp_d = x_d  # input of the auxiliary D head

        # Stage 5: P/D bottlenecks; I goes through layer5 + pyramid pooling, then back to 1/8.
        x_ = self.layer5_(self.relu(x_))
        x_d = self.layer5_d(self.relu(x_d))
        x = F.interpolate(
                        self.spp(self.layer5(x)),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)

        # Fuse P (x_), I (x) and D (x_d) and predict the main logits.
        x_ = self.final_layer(self.dfm(x_, x, x_d))

        if self.augment:
            x_extra_p = self.seghead_p(temp_p)
            x_extra_d = self.seghead_d(temp_d)
            return [x_extra_p, x_, x_extra_d]
        else:
            return x_

def get_seg_model(num_classes=8,
                  pretrained_path='/kaggle/working/pretrained_model/PIDNet_S_ImageNet.pth.tar',
                  augment=True):
    """Build a PIDNet-S-sized model and initialize it from a pretrained checkpoint.

    Only checkpoint tensors whose name *and* shape match the model are copied.
    Everything else, e.g. heads whose output size depends on `num_classes`,
    keeps its fresh initialization.

    Args:
        num_classes: output channels of the segmentation heads.
        pretrained_path: checkpoint file. Either a dict holding the weights under
            'state_dict' or a bare state dict is accepted.
        augment: build the auxiliary P/D heads (PIDNet's `augment` flag).

    Returns:
        The PIDNet model, with weights loaded on CPU.
    """
    model = PIDNet(m=2, n=3, num_classes=num_classes, planes=32, ppm_planes=96,
                   head_planes=128, augment=augment)

    checkpoint = torch.load(pretrained_path, map_location='cpu')
    # Some checkpoints wrap the weights in a 'state_dict' entry, others are bare.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    model_dict = model.state_dict()
    compatible = {k: v for k, v in checkpoint.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    skipped = len(checkpoint) - len(compatible)

    if not compatible:
        # Without this warning a naming mismatch (e.g. a key prefix) would silently
        # leave the whole network at its random initialization.
        logging.warning('No pretrained parameters from %s matched the model (%d skipped)!',
                        pretrained_path, skipped)
    else:
        logging.info('Loaded %d pretrained parameters from %s (%d skipped).',
                     len(compatible), pretrained_path, skipped)

    model_dict.update(compatible)
    model.load_state_dict(model_dict, strict=False)

    return model

def get_pred_model(name, num_classes):
    """Build an inference PIDNet (augment=False, no auxiliary heads) from a size name.

    The size is chosen by substring tests on `name`, checked in this order:
    's' -> planes=32 (presumably PIDNet-S); 'm' -> planes=64 (presumably
    PIDNet-M); anything else -> deeper m=3/n=4 config (presumably PIDNet-L).

    NOTE(review): the test runs on the whole string, so any 's' wins. For
    example, 'models/pidnet_l' contains an 's' and builds the small model.
    Pass just the model name.
    """
    if 's' in name:
        config = dict(m=2, n=3, planes=32, ppm_planes=96, head_planes=128)
    elif 'm' in name:
        config = dict(m=2, n=3, planes=64, ppm_planes=96, head_planes=128)
    else:
        config = dict(m=3, n=4, planes=64, ppm_planes=112, head_planes=256)

    return PIDNet(num_classes=num_classes, augment=False, **config)

In [None]:
import cv2
import numpy as np
import random
import albumentations as A
from torch.nn import functional as F
from torch.utils import data
import matplotlib.pyplot as plt
# Kernel sizes (y / x). They are not referenced in the code visible here;
# presumably they are used later for boundary/edge-label generation. TODO confirm.
y_k_size = 6
x_k_size = 6

def show_images(x_original, x_augmented, unnormalize = False):

    if unnormalize:
        # ImageNet mean and std
        imagenet_mean = np.array([0.485, 0.456, 0.406])[:, None, None]
        imagenet_std = np.array([0.229, 0.224, 0.225])[:, None, None]

        # Denormalize using NumPy broadcasting
        x_original = x_original * imagenet_std + imagenet_mean
        x_augmented = x_augmented * imagenet_std + imagenet_mean

        # Clip to [0, 1] in case of overflows
        x_original = np.clip(x_original, 0, 1)
        x_augmented = np.clip(x_augmented, 0, 1)

        # Transpose to HWC for matplotlib
        x_original = np.transpose(x_original, (1, 2, 0))
        x_augmented = np.transpose(x_augmented, (1, 2, 0))

    # Plot
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].imshow(x_original)
    axs[0].set_title("Original Image")
    axs[0].axis("off")

    axs[1].imshow(x_augmented)
    axs[1].set_title("Augmented Image")
    axs[1].axis("off")

    plt.tight_layout()
    plt.show()

class BaseDataset(data.Dataset):
    """Shared preprocessing helpers for the segmentation datasets.

    Subclasses fill ``self.files`` and implement ``__getitem__``. This class
    provides normalization, padding, random cropping, multi-scale resizing,
    boundary-map generation and a single-scale inference helper.
    """

    def __init__(self,
                 ignore_label=255,
                 base_size=2048,
                 crop_size=(512, 512),
                 scale_factor=16,
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):

        self.base_size = base_size
        self.crop_size = crop_size
        self.ignore_label = ignore_label

        self.mean = mean
        self.std = std
        self.scale_factor = scale_factor

        self.files = []

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _is_chw(image):
        """Heuristic layout check: a 3-D array is treated as channel-first
        (C, H, W) when its first axis is small (<= 3) and its last axis is not."""
        return image.ndim == 3 and image.shape[0] <= 3 < image.shape[2]

    def input_transform(self, image, city=False):
        """Scale ``image`` to [0, 1] and normalize it with ``self.mean``/``self.std``.

        With ``city=True`` the last (channel) axis is reversed first, e.g. to
        swap BGR and RGB order.
        """
        if city:
            image = image.astype(np.float32)[:, :, ::-1]
        else:
            image = image.astype(np.float32)
        image = image / 255.0
        image -= self.mean
        image /= self.std
        return image

    def label_transform(self, label):
        """Convert ``label`` to a uint8 numpy array."""
        return np.array(label).astype(np.uint8)

    def pad_image(self, image, h, w, size, padvalue):
        """Pad ``image`` on the bottom/right so it is at least ``size`` = (H, W).

        Args:
            image: 2-D array, (H, W, C) array, or channel-first (C, H, W) array.
            h, w: current spatial size of ``image``.
            size: target minimum (H, W).
            padvalue: cv2 border value used for the new pixels.

        Returns:
            The padded array in the SAME layout as the input (the input itself
            when no padding is needed).
        """
        pad_h = max(size[0] - h, 0)
        pad_w = max(size[1] - w, 0)

        if pad_h == 0 and pad_w == 0:
            return image

        # Decide the layout once, on the input, so only a channel-first input
        # is transposed back afterwards; (H, W, C) inputs keep their layout.
        was_chw = self._is_chw(image)
        if was_chw:
            image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)

        padded = cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w,
                                    cv2.BORDER_CONSTANT, value=padvalue)
        # cv2 drops a trailing singleton channel axis; restore it.
        if padded.ndim == 2 and image.ndim == 3:
            padded = padded[:, :, None]

        if was_chw:
            padded = np.transpose(padded, (2, 0, 1))  # back to (C, H, W)

        return padded

    def rand_crop(self, image, label, edge):
        """Randomly crop ``image``, ``label`` and ``edge`` to ``self.crop_size``.

        Inputs smaller than the crop are padded first: image with 0, label with
        ``self.ignore_label``, edge with 0. A channel-first image is converted to
        (H, W, C), so the returned image is always channel-last. ``edge`` may be
        None, in which case None is returned in its place.

        Raises:
            ValueError: if the label is still smaller than the crop after padding.
        """
        if self._is_chw(image):
            image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)

        h, w = image.shape[:2]
        crop_h, crop_w = self.crop_size

        if h < crop_h or w < crop_w:
            image = self.pad_image(image, h, w, self.crop_size, (0.0, 0.0, 0.0))
            label = self.pad_image(label, h, w, self.crop_size, (self.ignore_label,))
            if edge is not None:
                edge = self.pad_image(edge, h, w, self.crop_size, (0.0,))

        new_h, new_w = label.shape
        if new_h < crop_h or new_w < crop_w:
            raise ValueError(f"Dimensioni insufficienti per il ritaglio: label={label.shape}, crop_size={self.crop_size}")

        # Top-left corner of the random crop.
        x = random.randint(0, new_w - crop_w)
        y = random.randint(0, new_h - crop_h)

        image = image[y:y + crop_h, x:x + crop_w]
        label = label[y:y + crop_h, x:x + crop_w]
        if edge is not None:
            edge = edge[y:y + crop_h, x:x + crop_w]

        return image, label, edge

    def multi_scale_aug(self, image, label=None, edge=None,
                        rand_scale=1, rand_crop=True):
        """Resize so the longer side is ``base_size * rand_scale``, then optionally crop.

        The image is resized bilinearly; label and edge use nearest-neighbour.
        When ``label`` is None only the resized image is returned; otherwise
        ``(image, label, edge)`` is returned, randomly cropped if ``rand_crop``.
        """
        long_size = int(self.base_size * rand_scale + 0.5)
        h, w = image.shape[:2]
        if h > w:
            new_h = long_size
            new_w = int(w * long_size / h + 0.5)
        else:
            new_w = long_size
            new_h = int(h * long_size / w + 0.5)

        image = cv2.resize(image, (new_w, new_h),
                           interpolation=cv2.INTER_LINEAR)
        if label is not None:
            label = cv2.resize(label, (new_w, new_h),
                               interpolation=cv2.INTER_NEAREST)
            if edge is not None:
                edge = cv2.resize(edge, (new_w, new_h),
                                  interpolation=cv2.INTER_NEAREST)
        else:
            return image

        if rand_crop:
            image, label, edge = self.rand_crop(image, label, edge)

        return image, label, edge

    def gen_sample(self, image, label, edge_pad=True, edge_size=4, city=False, transform=None, show=False):
        """Augment one (image, label) pair, build its boundary map and normalize it.

        Args:
            image: channel-last (H, W, C) image.
            label: label map; presumably (H, W) uint8, as ``cv2.Canny`` needs 8-bit input.
            edge_pad: if True, zero a ``y_k_size``/``x_k_size`` border of the edge map.
            edge_size: side of the square kernel used to dilate the edges.
            city: forwarded to ``input_transform`` (reverses the channel axis).
            transform: optional callable taking ``image=``/``mask=`` and returning
                a dict with ``'image'`` and ``'mask'`` (albumentations-style).
            show: if True, plot the image before and after ``transform``.

        Returns:
            ``(image, label, edge)``: the normalized image transposed to
            (C, H, W), the uint8 label, and a float edge map with 1.0 on dilated
            label boundaries and 0.0 elsewhere.
        """
        if transform is not None:
            # Augment image and mask together so they stay aligned.
            augmented = transform(image=image, mask=label)

            if show:
                show_images(image, augmented["image"])

            image = augmented['image']
            label = augmented['mask']

        # Edges are computed after augmentation so they match the final mask.
        edge = cv2.Canny(label, 0.1, 0.2)
        kernel = np.ones((edge_size, edge_size), np.uint8)
        if edge_pad:
            edge = edge[y_k_size:-y_k_size, x_k_size:-x_k_size]
            edge = np.pad(edge, ((y_k_size, y_k_size), (x_k_size, x_k_size)), mode='constant')
        edge = (cv2.dilate(edge, kernel, iterations=1) > 50) * 1.0

        image = self.input_transform(image, city=city)
        label = self.label_transform(label)
        image = image.transpose((2, 0, 1))  # (H, W, C) -> (C, H, W)

        return image, label, edge

    def inference(self, model, image):
        """Run ``model`` on ``image`` and upsample the selected output to the input size.

        NOTE(review): element 1 of the model output is used, presumably the main
        head of a multi-head model, and ``exp`` of the interpolated values is
        returned. Confirm this matches the model's output format.
        """
        size = image.size()
        pred = model(image)
        pred = pred[1]
        pred = F.interpolate(
            input=pred, size=size[-2:],
            mode='bilinear', align_corners=True
        )
        return pred.exp()

In [None]:
import cv2
import os
import numpy as np
import torch
import random
import logging
from PIL import Image
import torchvision.transforms as tf
import matplotlib.pyplot as plt

def show_images(x_original, x_augmented, unnormalize = False):
    """Display two images next to each other for visual comparison.

    show_images is also defined in an earlier cell; this definition shadows it
    for all code that runs after this cell.

    When ``unnormalize`` is True, the inputs are taken as channel-first and
    ImageNet-normalized. They are mapped back to [0, 1] and transposed to
    channel-last so matplotlib can draw them.
    """
    panels = [x_original, x_augmented]
    if unnormalize:
        channel_mean = np.array([0.485, 0.456, 0.406])[:, None, None]
        channel_std = np.array([0.229, 0.224, 0.225])[:, None, None]
        panels = [
            np.transpose(np.clip(panel * channel_std + channel_mean, 0, 1), (1, 2, 0))
            for panel in panels
        ]

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    for ax, img, title in zip(axs, panels, ("Original Image", "Augmented Image")):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()



class LoveDA(BaseDataset):
    """LoveDA semantic-segmentation dataset driven by a ``.lst`` file.

    Every non-empty line of ``root + list_path`` holds an image path and a label
    path separated by whitespace. ``__getitem__`` returns
    ``(image, label, edge, original_image_shape, name)`` as produced by
    ``gen_sample``.
    """

    def __init__(self,
                 root,
                 list_path,
                 num_classes=7,
                 flip=False,
                 ignore_label=0,
                 crop_size=(512, 512),
                 scale_factor=16,  # multi-scale factor used for data augmentation (already provided)
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225],
                 bd_dilate_size=4,
                 weighted=True,
                 transform=None):

        # Pass the base-class arguments by keyword. BaseDataset's second
        # positional parameter is base_size, so a positional call shifts every
        # argument by one (crop_size -> base_size, scale_factor -> crop_size,
        # mean -> scale_factor, std -> mean). Images then get normalized with
        # the std values used as the mean.
        super(LoveDA, self).__init__(ignore_label=ignore_label,
                                     crop_size=crop_size,
                                     scale_factor=scale_factor,
                                     mean=mean,
                                     std=std)

        self.root = root
        self.list_path = list_path
        self.num_classes = num_classes
        self.flip = flip
        self.ignore_label = ignore_label
        self.scale_factor = scale_factor
        self.bd_dilate_size = bd_dilate_size

        # The context manager closes the list file. Blank lines are skipped
        # because they would break the two-way unpacking in read_files().
        with open(root + list_path) as list_file:
            self.img_list = [line.split() for line in list_file if line.strip()]
        self.files = self.read_files()
        # Class i <-> "colour" (i, i, i), used by color2label / label2color.
        self.color_list = [[0, 0, 0], [1, 1, 1], [2, 2, 2],
                           [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7]]
        self.class_weights = None
        if weighted:
            self.class_weights = torch.tensor([0.000000, 0.116411, 0.266041, 0.607794,
                                               1.511413, 0.745507, 0.712438, 3.040396])

        self.transform = transform

    def read_files(self):
        """Turn ``self.img_list`` pairs into dicts with img/label paths and a name.

        The name is the label file's basename without extension.
        """
        files = []
        for image_path, label_path in self.img_list:
            name = os.path.splitext(os.path.basename(label_path))[0]
            files.append({
                "img": image_path,
                "label": label_path,
                "name": name
            })
        return files

    def color2label(self, color_map):
        """Map an (H, W, 3) colour image to class indices using ``self.color_list``.

        Pixels matching no colour get ``self.ignore_label``.
        """
        label = np.ones(color_map.shape[:2]) * self.ignore_label
        for i, v in enumerate(self.color_list):
            label[(color_map == v).sum(2) == 3] = i

        return label.astype(np.uint8)

    def convert_label(self, label, inverse=False):
        """Remap label values through ``self.label_mapping`` (or its inverse).

        NOTE(review): neither LoveDA nor BaseDataset sets ``label_mapping``, so
        calling this raises AttributeError unless a subclass defines it.
        """
        temp = label.copy()
        if inverse:
            for v, k in self.label_mapping.items():
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label

    def label2color(self, label):
        """Map class indices back to (H, W, 3) uint8 "colours" from ``self.color_list``."""
        color_map = np.zeros(label.shape + (3,))
        for i, v in enumerate(self.color_list):
            color_map[label == i] = self.color_list[i]

        return color_map.astype(np.uint8)

    def __getitem__(self, index):
        item = self.files[index]
        name = item["name"]
        image = cv2.imread(item["img"], cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None instead of raising on a missing/unreadable file.
            raise FileNotFoundError(f"Could not read image: {item['img']}")

        size = image.shape

        label = cv2.imread(item["label"], cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError(f"Could not read label: {item['label']}")

        # NOTE(review): cv2.imread returns channels in BGR order and city=False
        # leaves them as-is, while the default mean/std look like the usual
        # RGB-ordered ImageNet statistics. Confirm this is intended.
        # edge is (H, W); the returned image is channel-first (C, H, W).
        image, label, edge = self.gen_sample(image, label, edge_pad=False,
                                             edge_size=self.bd_dilate_size, city=False,
                                             transform=self.transform, show=False)

        return image.copy(), label.copy(), edge.copy(), np.array(size), name

    def single_scale_inference(self, model, image):
        """Predict at the input's own scale (delegates to ``inference``)."""
        pred = self.inference(model, image)
        return pred

    def save_pred(self, preds, sv_path, name):
        """Save the argmax of each prediction in the batch as ``<sv_path>/<name[i]>.png``."""
        preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
        for i in range(preds.shape[0]):
            pred = self.label2color(preds[i])
            save_img = Image.fromarray(pred)
            save_img.save(os.path.join(sv_path, name[i] + '.png'))

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F



class CrossEntropy(nn.Module):
    """Cross-entropy over one or two segmentation heads.

    Two heads are combined as ``0.4 * loss(head0) + 1.0 * loss(head1)``. A
    single head, given as a bare tensor or a one-element list/tuple, gets
    weight 1.0.
    """
    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_label
        )

    def _forward(self, score, target):
        """Mean cross-entropy of one head's logits against ``target``."""
        return self.criterion(score, target)

    def forward(self, score, target):
        # Wrap a bare tensor, as OhemCrossEntropy does. Otherwise len(score) is
        # the batch size, and a batch of 2 would be zipped sample-by-sample with
        # the head weights.
        if not isinstance(score, (list, tuple)):
            score = [score]

        balance_weights = [0.4, 1.0]
        sb_weights = 1.0
        if len(balance_weights) == len(score):
            return sum(w * self._forward(x, target) for w, x in zip(balance_weights, score))
        elif len(score) == 1:
            return sb_weights * self._forward(score[0], target)
        else:
            raise ValueError("lengths of prediction and target are not identical!")




class OhemCrossEntropy(nn.Module):
    """Online hard example mining (OHEM) cross-entropy for one or two heads.

    With two heads, the first gets plain cross-entropy (weight 0.4) and the
    second gets OHEM cross-entropy (weight 1.0). With one head, only OHEM is
    used. ``forward`` always returns a scalar tensor.
    """
    def __init__(self, ignore_label=-1, thres=0.7,
                 min_kept=100000, weight=None):
        super(OhemCrossEntropy, self).__init__()
        self.thresh = thres
        self.min_kept = max(1, min_kept)
        self.ignore_label = ignore_label
        # reduction='none': per-pixel losses are needed for hard-example selection.
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_label,
            reduction='none'
        )

    def _ce_forward(self, score, target):
        """Plain cross-entropy, averaged over all pixels.

        Ignored pixels contribute 0 to the average. Reducing here keeps
        ``forward`` scalar-valued; the value is identical to taking ``.mean()``
        of the per-pixel map afterwards.
        """
        return self.criterion(score, target).mean()

    def _ohem_forward(self, score, target, **kwargs):
        """Average the loss over the hardest valid pixels.

        A valid pixel is kept when its predicted probability for the true class
        is below ``max(self.thresh, p_k)``. Here ``p_k`` is the probability at
        sorted position ``min(self.min_kept, n_valid - 1)``. When no pixel is
        valid or kept, a zero loss still attached to ``score``'s graph is
        returned instead of failing on an empty tensor or yielding NaN.
        """
        pred = F.softmax(score, dim=1)
        pixel_losses = self.criterion(score, target).contiguous().view(-1)
        mask = target.contiguous().view(-1) != self.ignore_label

        if not mask.any():
            # Every pixel carries the ignore label: nothing to mine.
            return score.sum() * 0.0

        # Probability of the ground-truth class (ignored pixels temporarily map
        # to class 0 just so gather() has a valid index; they are masked out).
        tmp_target = target.clone()
        tmp_target[tmp_target == self.ignore_label] = 0
        pred = pred.gather(1, tmp_target.unsqueeze(1))
        pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort()
        min_value = pred[min(self.min_kept, pred.numel() - 1)]
        threshold = max(min_value, self.thresh)

        pixel_losses = pixel_losses[mask][ind]
        pixel_losses = pixel_losses[pred < threshold]
        if pixel_losses.numel() == 0:
            return score.sum() * 0.0
        return pixel_losses.mean()

    def forward(self, score, target):
        if not isinstance(score, (list, tuple)):
            score = [score]

        balance_weights = [0.4, 1.0]
        sb_weights = 1.0
        if len(balance_weights) == len(score):
            # Plain CE on every head but the last, OHEM on the last one.
            functions = [self._ce_forward] * \
                (len(balance_weights) - 1) + [self._ohem_forward]
            return sum([
                w * func(x, target)
                for (w, x, func) in zip(balance_weights, score, functions)
            ])

        elif len(score) == 1:
            return sb_weights * self._ohem_forward(score[0], target)

        else:
            raise ValueError("lengths of prediction and target are not identical!")


def weighted_bce(bd_pre, target):
    """Class-balanced binary cross-entropy on boundary logits.

    ``bd_pre`` must be 4-D (N, C, H, W) logits; it is flattened in (N, H, W, C)
    order and compared with the flattened ``target``. Positive pixels
    (target == 1) are weighted by the fraction of negatives and negative
    pixels (target == 0) by the fraction of positives. Any other target value
    gets weight 0.
    """
    _n, _c, _h, _w = bd_pre.size()  # enforces a 4-D input
    logits = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
    flat_target = target.view(1, -1)

    is_pos = flat_target == 1
    is_neg = flat_target == 0
    n_pos = is_pos.sum()
    n_neg = is_neg.sum()
    n_total = n_pos + n_neg

    weight = torch.zeros_like(logits)
    weight[is_pos] = n_neg * 1.0 / n_total
    weight[is_neg] = n_pos * 1.0 / n_total

    return F.binary_cross_entropy_with_logits(logits, flat_target, weight, reduction='mean')


class BondaryLoss(nn.Module):
    """Boundary-head loss: ``weighted_bce`` scaled by ``coeff_bce``."""

    def __init__(self, coeff_bce = 20.0):
        super(BondaryLoss, self).__init__()
        self.coeff_bce = coeff_bce

    def forward(self, bd_pre, bd_gt):
        # Class-balanced BCE on the boundary logits, multiplied by the coefficient.
        return self.coeff_bce * weighted_bce(bd_pre, bd_gt)

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import time
from pathlib import Path

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class FullModel(nn.Module):
    """Wrap a segmentation network together with its semantic and boundary losses.

    ``model(inputs)`` must return a sequence of 4-D logit maps whose last
    element is the boundary head. ``forward`` upsamples them to the label
    size and returns ``(loss of shape (1,), semantic_outputs, pixel_acc,
    [loss_s, loss_b])``.
    """

    def __init__(self, model, sem_loss, bd_loss):
        super(FullModel, self).__init__()
        self.model = model
        self.sem_loss = sem_loss
        self.bd_loss = bd_loss

    def pixel_acc(self, pred, label, ignore_label=None):
        """Fraction of valid pixels whose argmax over dim 1 equals ``label``.

        A pixel is valid when ``label >= 0`` and, if ``ignore_label`` is given,
        ``label != ignore_label``.
        """
        _, preds = torch.max(pred, dim=1)
        valid = label >= 0
        if ignore_label is not None:
            valid = valid & (label != ignore_label)
        valid = valid.long()
        acc_sum = torch.sum(valid * (preds == label).long())
        pixel_sum = torch.sum(valid)
        return acc_sum.float() / (pixel_sum.float() + 1e-10)

    def forward(self, inputs, labels, bd_gt, *args, **kwargs):
        outputs = self.model(inputs, *args, **kwargs)

        # Bring every head to the label resolution. A new list is built instead
        # of assigning into ``outputs``, so tuple-returning models also work.
        h, w = labels.size(1), labels.size(2)
        ph, pw = outputs[0].size(2), outputs[0].size(3)
        if ph != h or pw != w:
            outputs = [F.interpolate(o, size=(h, w), mode='bilinear', align_corners=True)
                       for o in outputs]

        # Ignore label of the semantic loss, if it exposes one.
        sem_ignore = getattr(self.sem_loss, 'ignore_label', None)

        acc = self.pixel_acc(outputs[-2], labels, ignore_label=sem_ignore)
        loss_s = self.sem_loss(outputs[:-1], labels)
        loss_b = self.bd_loss(outputs[-1], bd_gt)

        # Boundary-aware semantic loss: keep labels only where the boundary head
        # is confident (sigmoid > 0.8). Everything else is set to the ignore
        # label, with 0 as the fallback when the loss does not expose one.
        filler = torch.full_like(labels, 0 if sem_ignore is None else sem_ignore)
        try:
            bd_label = torch.where(torch.sigmoid(outputs[-1][:, 0, :, :]) > 0.8, labels, filler)
            loss_sb = self.sem_loss([outputs[-2]], bd_label)
        except Exception as exc:
            # Best-effort fallback: report the cause and use the full labels.
            # Exception (not a bare except) so KeyboardInterrupt still propagates.
            print(f"Error in loss computation: {exc}")
            loss_sb = self.sem_loss([outputs[-2]], labels)
        loss = loss_s + loss_b + loss_sb

        return torch.unsqueeze(loss, 0), outputs[:-1], acc, [loss_s, loss_b]

# NOTE(review): the triple-quoted string below is an inactive alternative
# implementation of FullModel.forward kept for reference. It is a bare string
# expression at module level (not inside the FullModel class), so it is never
# executed; consider deleting it.
'''def forward(self, inputs, labels, bd_gt, *args, **kwargs):
    # ——— cast targets to the right dtype
    labels = labels.long()
    # for BCEWithLogitsLoss we need float targets with a channel dim
    bd_gt = bd_gt.float().unsqueeze(1)  

    outputs = self.model(inputs, *args, **kwargs)

    # ——— resize outputs to match labels
    h, w = labels.size(1), labels.size(2)
    if outputs[0].size(2)!=h or outputs[0].size(3)!=w:
        for i in range(len(outputs)):
            outputs[i] = F.interpolate(outputs[i],
                                       size=(h, w),
                                       mode='bilinear',
                                       align_corners=True)

    # ——— pixel accuracy on the penultimate head
    acc = self.pixel_acc(outputs[-2], labels)

    # ——— multi-level segmentation loss
    loss_s = sum(self.sem_loss(out, labels) for out in outputs[:-1])

    # ——— boundary loss (now shapes match: [B,1,H,W] vs [B,1,H,W])
    loss_b = self.bd_loss(outputs[-1], bd_gt)

    # ——— selective semantic loss over high-confidence boundary regions
    filler = torch.zeros_like(labels)
    try:
        mask = torch.sigmoid(outputs[-1][:, 0, ...]) > 0.8
        bd_label = torch.where(mask, labels, filler)
        loss_sb = self.sem_loss(outputs[-2], bd_label)
    except Exception as e:
        loss_sb = self.sem_loss(outputs[-2], labels)

    total_loss = loss_s + loss_b + loss_sb

    return (
        total_loss.unsqueeze(0),
        outputs[:-1],
        acc,
        [loss_s, loss_b, loss_sb]
    )'''



class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self):
        # Nothing recorded yet; the first update() seeds every field.
        self.initialized = False
        self.val = self.avg = self.sum = self.count = None

    def initialize(self, val, weight):
        """Seed the statistics with a first observation."""
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record ``val`` with ``weight``, seeding the meter on the first call."""
        if self.initialized:
            self.add(val, weight)
            return
        self.initialize(val, weight)

    def add(self, val, weight):
        """Fold another observation into the running sum, count and average."""
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        """Most recently recorded value (None before the first update)."""
        return self.val

    def average(self):
        """Weighted running average (None before the first update)."""
        return self.avg

def create_logger(cfg, cfg_name, phase='train',
                  root_output_dir='/kaggle/working/output', tensorboard_root='log'):
    """Configure file + console logging and create the output directories.

    ``cfg`` and ``cfg_name`` are accepted for interface compatibility but are
    not used: the dataset ('loveda'), model ('pidnet_small') and config name
    ('log') are fixed.

    Args:
        phase: suffix of the log file name.
        root_output_dir: base directory for the final outputs and the log file.
        tensorboard_root: base directory for the TensorBoard log folder.

    Returns:
        ``(logger, final_output_dir, tensorboard_log_dir)``. The root logger is
        returned; the directories are ``<root_output_dir>/loveda/log/gan`` and
        ``<tensorboard_root>/loveda/pidnet_small/log_<timestamp>``.
    """
    root_output_dir = Path(root_output_dir)
    folder_name = "gan"

    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        root_output_dir.mkdir(parents=True, exist_ok=True)

    dataset = 'loveda'
    model = 'pidnet_small'
    cfg_name = 'log'  # the cfg_name argument is intentionally ignored

    final_output_dir = root_output_dir / dataset / cfg_name / folder_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format('log', time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    # NOTE: basicConfig does nothing once the root logger has handlers, so only
    # the first call in a kernel session attaches a log file.
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Attach the console handler only once per process. Re-running the cell
    # would otherwise stack handlers and print every message several times.
    console_name = 'create_logger_console'
    if not any(h.get_name() == console_name for h in logger.handlers):
        console = logging.StreamHandler()
        console.set_name(console_name)
        logger.addHandler(console)

    tensorboard_log_dir = Path(tensorboard_root) / dataset / model / \
            ('log' + '_' + time_str)
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return logger, str(final_output_dir), str(tensorboard_log_dir)

def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
    """Count (ground truth, prediction) pixel pairs for one batch.

    Args:
        label: 3-D (N, H, W) tensor of ground-truth class indices; it is cropped
            to ``size[-2:]``.
        pred: 4-D (N, C, H, W) tensor of scores; the argmax over C is used.
        size: sequence whose last two entries give the spatial crop.
        num_class: number of classes.
        ignore: ground-truth value whose pixels are skipped.

    Returns:
        ``num_class x num_class`` float64 array; entry [i, j] counts pixels with
        ground truth i predicted as j. Pairs whose flat index falls beyond the
        matrix (e.g. ground truth >= num_class) are not counted.
    """
    scores = pred.cpu().numpy().transpose(0, 2, 3, 1)
    predicted = np.asarray(np.argmax(scores, axis=3), dtype=np.uint8)
    truth = np.asarray(label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=int)

    keep = truth != ignore
    flat_index = (truth[keep] * num_class + predicted[keep]).astype('int32')

    n_cells = num_class * num_class
    counts = np.bincount(flat_index, minlength=n_cells)[:n_cells]
    return counts.reshape(num_class, num_class).astype(np.float64)

def adjust_learning_rate(optimizer, base_lr, max_iters,
        cur_iters, power=0.9, nbb_mult=10):
    """Poly schedule: ``lr = base_lr * (1 - cur_iters / max_iters) ** power``.

    Sets the first param group's lr to ``lr``. When the optimizer has exactly
    two groups, the second gets ``lr * nbb_mult``. Returns ``lr``.

    Progress is capped at 1, so calls past ``max_iters`` give lr = 0. Without
    the cap, a negative base raised to a fractional power is a complex number.
    """
    progress = min(float(cur_iters) / max_iters, 1.0)
    lr = base_lr * ((1 - progress) ** power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) == 2:
        optimizer.param_groups[1]['lr'] = lr * nbb_mult
    return lr

In [None]:
def validate(testloader, model, num_classes=8, ignore_label=0, align_corners=False, head=1):
    """Evaluate ``model`` on ``testloader`` and report IoU for one output head.

    ``model`` is called as ``model(image, label, bd_gt)``. Its second return
    value is the prediction head(s); a bare tensor counts as one head. One
    confusion matrix is accumulated per head, skipping pixels whose label
    equals ``ignore_label``.

    Args:
        head: index of the head to report (default 1, the second head). If the
            model produces fewer heads, the last available one is used.

    Returns:
        ``(mean IoU, per-class IoU array)`` for the chosen head, over classes
        1..num_classes-1; class 0 is always excluded.
    """
    model.eval()
    confusion_matrices = None  # one (num_classes, num_classes) matrix per head
    with torch.no_grad():
        for batch in testloader:
            image, label, bd_gts, _, _ = batch
            size = label.size()
            image = image.cuda()
            label = label.long().cuda()
            bd_gts = bd_gts.float().cuda()

            _, preds, _, _ = model(image, label, bd_gts)
            if not isinstance(preds, (list, tuple)):
                preds = [preds]
            if confusion_matrices is None:
                # Size the accumulator by the number of heads the model actually
                # returns instead of assuming exactly two.
                confusion_matrices = [np.zeros((num_classes, num_classes), dtype=np.float64)
                                      for _ in preds]

            for i, x in enumerate(preds):
                x = F.interpolate(
                    x,
                    size=size[-2:],
                    mode='bilinear',
                    align_corners=align_corners
                )
                confusion_matrices[i] += get_confusion_matrix(
                    label,
                    x,
                    size,
                    num_classes,
                    ignore=ignore_label
                )

    if confusion_matrices is None:
        # Empty loader: report zeros, as all-zero confusion matrices would.
        ious = np.zeros(num_classes)
        return ious[1:].mean(), ious[1:]

    head_ious = []
    for cm in confusion_matrices:
        tp = np.diag(cm)
        union = cm.sum(1) + cm.sum(0) - tp
        head_ious.append(tp / np.maximum(1.0, union))

    ious = head_ious[min(head, len(head_ious) - 1)]
    # Class 0 is excluded from the report (it is the ignore label by default).
    return ious[1:].mean(), ious[1:]


In [None]:
import os
import itertools
import time
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import gc
import albumentations as A
# evita fork: usa spawn (più sicuro per DataLoader)
# --- Discriminator definition for multi-level adaptation ---
class Discriminator(nn.Module):
    """Fully-convolutional discriminator producing a 1-channel logit map.

    Three 4x4 stride-2 convolutions (padding 1) shrink each spatial side to
    floor(side / 8). With ``in_channels == 1`` every convolution is 1->1 with
    bias. Otherwise channels go in -> in//2 -> in//4 (bias-free, each followed
    by BatchNorm), then a final convolution to 1 channel. All hidden
    activations are LeakyReLU(0.2).
    """

    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        if in_channels != 1:
            half, quarter = in_channels // 2, in_channels // 4
            layers = [
                nn.Conv2d(in_channels, half, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(half),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(half, quarter, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(quarter),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(quarter, 1, kernel_size=4, stride=2, padding=1),
            ]
        else:
            layers = []
            for stage in range(3):
                layers.append(nn.Conv2d(1, 1, kernel_size=4, stride=2, padding=1))
                if stage < 2:  # no activation after the final logit layer
                    layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

# --- Utility: compute IoU from confusion matrix ---
def compute_iou(conf_matrix, ignore_index=0):
    """Per-class intersection-over-union from a square confusion matrix.

    Returns a list with one float per class: ``tp / (tp + fp + fn)``. The
    ``ignore_index`` entry is NaN, as is any class whose denominator is 0.
    """
    def _class_iou(cls):
        if cls == ignore_index:
            return float('nan')
        tp = conf_matrix[cls, cls]
        fp = conf_matrix[:, cls].sum() - tp
        fn = conf_matrix[cls, :].sum() - tp
        denom = tp + fp + fn
        return tp / denom if denom > 0 else float('nan')

    return [_class_iou(cls) for cls in range(conf_matrix.shape[0])]




# --- Main training loop ---
def train_domain_adaptation(
    root, list_src, list_tgt, list_val,
    num_classes=8, ignore_label=0,
    batch_size=2, num_epochs=20,
    base_lr=1e-2, 
    device='cuda'
):
    train_trasform = A.Compose([A.ColorJitter(p=0.5)])
    # Datasets and loaders
    src_dataset = LoveDA(root, list_src, num_classes=num_classes, ignore_label=ignore_label, transform=train_trasform)
    tgt_dataset = LoveDA(root, list_tgt, num_classes=num_classes, ignore_label=ignore_label, transform=train_trasform)
    val_dataset = LoveDA(root, list_val, num_classes=num_classes, ignore_label=ignore_label)

    src_loader = DataLoader(src_dataset, batch_size=batch_size, pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
    tgt_loader = DataLoader(tgt_dataset, batch_size=batch_size, pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=False, shuffle=False, num_workers=4)


    target_iter = iter(tgt_loader)
    # Models
    seg_model = get_seg_model()
    seg_model = seg_model.cuda()

    ouput_branches = ['P', 'I']
    
    # FullModel wraps semantic and boundary losses
    # Assuming sem_loss and bd_loss are defined elsewhere 
    class_weights_loveda = torch.tensor([0.000000, 0.116411, 0.266041, 0.607794, 1.511413, 0.745507, 0.712438, 3.040396])
    
    # sem_loss = nn.CrossEntropyLoss(ignore_index=ignore_label, weight=class_weights_loveda)
    sem_loss = OhemCrossEntropy(ignore_label=ignore_label, thres=0.9, min_kept=131072, weight=class_weights_loveda)
    bd_loss = BondaryLoss()
    full_model = FullModel(seg_model, sem_loss, bd_loss).cuda()

    n_discriminators = 3
    
    # Discriminators for two adaptation levels
    # Level 1: intermediate head output channels (P branch)
    if 'P' in ouput_branches:
        d1 = Discriminator(in_channels=num_classes).cuda()

    if 'I' in ouput_branches: 
        # Level 2: main output (I branch combined with P and D)
        d2 = Discriminator(in_channels=num_classes).cuda()
    if 'D' in ouput_branches:
        # Level 3: intermediate head output channels (D branch)
        d3 = Discriminator(in_channels=1).cuda()
        

    # Optimizers
    params_dict = dict(full_model.named_parameters())
    params = [{'params': list(params_dict.values()), 'lr': base_lr}]
    optimizer_seg = torch.optim.SGD(params,
                                lr=base_lr,
                                momentum=0.7,
                                weight_decay=0.0005,
                                )
    if 'P' in ouput_branches:
        optimizer_d1 = optim.Adam(d1.parameters(), lr=1e-4, betas=(0.9, 0.99))
        base_lr_d1 = optimizer_d1.param_groups[0]['lr']
    if 'I' in ouput_branches:
        optimizer_d2 = optim.Adam(d2.parameters(), lr=1e-4, betas=(0.9, 0.99))
        base_lr_d2 = optimizer_d2.param_groups[0]['lr']
        
    if 'D' in ouput_branches:
        optimizer_d3 = optim.Adam(d3.parameters(), lr=1e-4, betas=(0.9, 0.99))
        base_lr_d3 = optimizer_d3.param_groups[0]['lr']
        

    # Adversarial labels
    src_label = 1.0
    tgt_label = 0.0

    
    lambda_adv1=0.0002
    lambda_adv2=0.001
    lambda_adv3=0.00002
    # Training
    max_iters = num_epochs * len(src_loader) # corresponds to num_epoch * num_batches
    cur_iter = 0 # incremented by 1 at each iteration, by the end of the training it will be equal to max_iters
    best_mIoU = 0.0
    for epoch in range(num_epochs):
        seg_model.train()
        if 'P' in ouput_branches:
            d1.train()
            meter_loss_d1 = AverageMeter()
            meter_adv1  = AverageMeter()
            
        if 'I' in ouput_branches:
            d2.train()
            meter_adv2  = AverageMeter()
            meter_loss_d2 = AverageMeter()
        
        if 'D' in ouput_branches:    
            d3.train()
            meter_adv3  = AverageMeter()
            meter_loss_d3 = AverageMeter()
            
        meter_seg   = AverageMeter()
        meter_bd    = AverageMeter()
        meter_acc   = AverageMeter()
        pbar = tqdm(src_loader, total=len(src_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
        for imgs_s, labels_s, edges_s, _, _ in pbar:
            imgs_s = imgs_s.cuda(); labels_s = labels_s.long().cuda(); edges_s = edges_s.float().cuda()
            try:
                imgs_t, _, _, _, _ = next(target_iter)
            except StopIteration:
                target_iter = iter(tgt_loader)
                imgs_t, _, _, _, _ = next(target_iter)
            imgs_t = imgs_t.cuda()

            # ----- Train Discriminators -----
            with torch.no_grad():
                out_s1, out_s2, out_s3 = seg_model(imgs_s)
                out_t1, out_t2, out_t3 = seg_model(imgs_t)
            if 'P' in ouput_branches:
                prob_s1 = F.softmax(out_s1.detach(), dim=1)
                prob_t1 = F.softmax(out_t1.detach(), dim=1)
            if 'I' in ouput_branches:
                prob_s2 = F.softmax(out_s2.detach(), dim=1)
                prob_t2 = F.softmax(out_t2.detach(), dim=1)
            if 'D' in ouput_branches:
                prob_s3 = torch.sigmoid(out_s3.detach())
                prob_t3 = torch.sigmoid(out_t3.detach())

            # D1
            if 'P' in ouput_branches:
                optimizer_d1.zero_grad()
                pred_d_s1 = d1(prob_s1)
                pred_d_t1 = d1(prob_t1)
                loss_d1 = 0.5 * (F.binary_cross_entropy_with_logits(pred_d_s1, torch.full_like(pred_d_s1, src_label)) +
                                  F.binary_cross_entropy_with_logits(pred_d_t1, torch.full_like(pred_d_t1, tgt_label)))
                loss_d1.backward()
                optimizer_d1.step()
                meter_loss_d1.update(loss_d1.item())

            # D2
            if 'I' in ouput_branches:
                optimizer_d2.zero_grad()
                pred_d_s2 = d2(prob_s2)
                pred_d_t2 = d2(prob_t2)
                loss_d2 = 0.5 * (F.binary_cross_entropy_with_logits(pred_d_s2, torch.full_like(pred_d_s2, src_label)) +
                                  F.binary_cross_entropy_with_logits(pred_d_t2, torch.full_like(pred_d_t2, tgt_label)))
                loss_d2.backward()
                optimizer_d2.step()
                meter_loss_d2.update(loss_d2.item())
                
               
            # D3
            if 'D' in ouput_branches:
                optimizer_d3.zero_grad()
                pred_d_s3 = d3(prob_s3)
                pred_d_t3 = d3(prob_t3)
                loss_d3 = 0.5 * (F.binary_cross_entropy_with_logits(pred_d_s3, torch.full_like(pred_d_s3, src_label)) +
                                  F.binary_cross_entropy_with_logits(pred_d_t3, torch.full_like(pred_d_t3, tgt_label)))
                loss_d3.backward()
                optimizer_d3.step()
                meter_loss_d3.update(loss_d3.item())
                
            # ----- Train Segmentation Network -----
            optimizer_seg.zero_grad()
            loss_seg, seg_outputs, acc, loss_components = full_model(imgs_s, labels_s, edges_s)
            
            # Adversarial on target
            ot1, ot2, ot3 = seg_model(imgs_t)

            if 'P' in ouput_branches:
                prob_t1 = F.softmax(ot1, dim=1)
                pred_d_t1_for_seg = d1(prob_t1)
                loss_adv1 = F.binary_cross_entropy_with_logits(pred_d_t1_for_seg, torch.full_like(pred_d_t1_for_seg, src_label)) # predictions from target, but source ground truth
                
            if 'I' in ouput_branches:
                prob_t2 = F.softmax(ot2, dim=1)
                pred_d_t2_for_seg = d2(prob_t2)
                loss_adv2 = F.binary_cross_entropy_with_logits(pred_d_t2_for_seg, torch.full_like(pred_d_t2_for_seg, src_label)) # predictions from target, but source ground truth
                
            if 'D' in ouput_branches:
                prob_t3 = torch.sigmoid(ot3)
                pred_d_t3_for_seg = d3(prob_t3)
                loss_adv3 = F.binary_cross_entropy_with_logits(pred_d_t3_for_seg, torch.full_like(pred_d_t3_for_seg, src_label)) # predictions from target, but source ground truth

            loss_total = loss_seg.mean()
            
            if 'P' in ouput_branches:
                loss_total += lambda_adv1 * loss_adv1 
            if 'I' in ouput_branches:
                loss_total += lambda_adv2 * loss_adv2
            if 'D' in ouput_branches:
                loss_total += lambda_adv3 * loss_adv3
            
            loss_total.backward()
            optimizer_seg.step()

            meter_seg.update( loss_components[0].mean().item() )
            meter_bd .update( loss_components[1].mean().item() )

            if 'P' in ouput_branches:
                meter_adv1.update( loss_adv1.item() )
            if 'I' in ouput_branches:
                meter_adv2.update( loss_adv2.item() )
            if 'D' in ouput_branches:
                meter_adv3.update( loss_adv3.item() )
            
            meter_acc.update( acc.item() )

            # Update learning rate
            cur_iter += 1
            adjust_learning_rate(optimizer_seg, base_lr, max_iters, cur_iter)
            if 'P' in ouput_branches:
                adjust_learning_rate(optimizer_d1, base_lr_d1, max_iters, cur_iter)
            if 'I' in ouput_branches:
                adjust_learning_rate(optimizer_d2, base_lr_d2, max_iters, cur_iter)
            if 'D' in ouput_branches:
                adjust_learning_rate(optimizer_d3, base_lr_d3, max_iters, cur_iter)

            display_dict = {
                    'seg_loss': loss_components[0].mean().item(),
                    'bd_loss': loss_components[1].mean().item(),
                    'acc': acc.item()
            }
            if 'P' in ouput_branches:
                display_dict['adv1'] = loss_adv1.item()
                display_dict['loss_d1'] = loss_d1.item()
                
            if 'I' in ouput_branches:
                display_dict['adv2'] = loss_adv2.item()
                display_dict['loss_d2'] = loss_d2.item()
                
            if 'D' in ouput_branches:
                display_dict['adv3'] = loss_adv3.item()
                display_dict['loss_d3'] = loss_d3.item()

            # Display losses
            pbar.set_postfix(display_dict)

        log_str = (
            f"Epoch {epoch+1} Train Avg — "
            f"seg_loss: {meter_seg.average():.4f}, "
            f"bd_loss: {meter_bd.average():.4f}, "
            f"acc: {meter_acc.average():.4f}, "
        )
        
        if 'P' in ouput_branches:
            log_str += f"adv1: {meter_adv1.average():.4f}, "
            log_str += f"loss_d1: {meter_loss_d1.average():.4f}, "    
        
        if 'I' in ouput_branches:
            log_str += f"adv2: {meter_adv2.average():.4f}, "
            log_str += f"loss_d2: {meter_loss_d2.average():.4f}, "      
        
        if 'D' in ouput_branches:
            log_str += f"adv3: {meter_adv3.average():.4f}, " 
            log_str += f"loss_d3: {meter_loss_d3.average():.4f}, " 
        
        print(log_str)
        
        torch.cuda.empty_cache()       # se usi GPU
        gc.collect()                   # raccogli il garbage Python


       # ——— (3) Validation on the FULL model ———
        print(f"\nEpoch {epoch+1} Validation (full_model):")
        mIoU, cls_ious = validate(
            val_loader,
            full_model,
            num_classes=num_classes,
            ignore_label=ignore_label,
            align_corners=True 
        )
        print(f"  ==> full_model mIoU: {mIoU:.4f}")
        for idx, iou in enumerate(cls_ious, start=1):
            print(f"    Class {idx:2d}: IoU = {iou:.4f}")

        final_output_dir = os.path.join('/kaggle/working', 'best')
        os.makedirs(final_output_dir, exist_ok=True)
        
        if mIoU > best_mIoU:
            best_mIoU = mIoU
            torch.save(full_model.state_dict(), os.path.join(final_output_dir, 'best_full_model.pt'))

        

if __name__ == '__main__':
    # Domain-adaptation run: Urban train split is the source domain, Rural train
    # split is the target domain, and Rural val is used for validation.
    # The .lst paths below are resolved relative to `root`.
    train_domain_adaptation(
        device='cuda',
        # --- data ---
        root='/kaggle/working/list/',
        list_src='urban/train.lst',
        list_tgt='rural/train.lst',
        list_val='rural/val.lst',
        # --- label space ---
        num_classes=8,
        ignore_label=0,
        # --- optimisation ---
        batch_size=6,
        num_epochs=20,
        base_lr=1e-2,
    )

In [None]:
import torch
from torch.utils.data import DataLoader
import os

# ---------- 1.  Paths & basic config ----------
root      = '/kaggle/working/list/'
list_val  = 'rural/val.lst'
ckpt_path = '/kaggle/working/best/best_full_model.pt'
batch_sz  = 6
num_cls   = 8
ignore_lb = 0
# Fall back to CPU when no GPU is attached instead of crashing on `.to('cuda')`
# (same policy as the FGSM setup cell further down).
device    = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---------- 2.  Rebuild the model wrapper ----------
seg_model = get_seg_model()
seg_model = seg_model.to(device)

# Per-class loss weights used during training; class 0 (the ignore label) gets 0.
class_weights_loveda = torch.tensor(
    [0.000000, 0.116411, 0.266041, 0.607794,
     1.511413, 0.745507, 0.712438, 3.040396],
    device=device
)
sem_loss = OhemCrossEntropy(
    ignore_label=ignore_lb, thres=0.9, min_kept=131072,
    weight=class_weights_loveda
)
bd_loss  = BondaryLoss()

full_model = FullModel(seg_model, sem_loss, bd_loss).to(device)

# ---------- 3.  Load the checkpoint ----------
# Load onto CPU first so the checkpoint opens regardless of where it was saved;
# load_state_dict then copies the weights onto the model's device.
state_dict = torch.load(ckpt_path, map_location='cpu')
full_model.load_state_dict(state_dict, strict=True)
full_model.eval()                               # important!

# ---------- 4.  Validation dataset & loader ----------
val_dataset = LoveDA(
    root, list_val,
    num_classes=num_cls,
    ignore_label=ignore_lb,
    transform=None
)
# shuffle=False keeps loader order == dataset order, which the FGSM
# visualisation relies on to pair adversarial samples with dataset items.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_sz,
    shuffle=False,
    num_workers=4,
    pin_memory=False
)


In [None]:
# ---------- 5.  Run the evaluation of the best model ----------
# Score the reloaded best checkpoint on the Rural validation split.
mIoU, cls_ious = validate(
    val_loader,
    full_model,
    num_classes=num_cls,
    ignore_label=ignore_lb,
    align_corners=True,
)

print(f'\nBest-checkpoint mIoU: {mIoU:0.4f}')
for idx, iou in enumerate(cls_ious, start=1):
    print(f'   class {idx:2d}: IoU = {iou:0.4f}')

# Adversarial Attack
According to the authors of "Learning to Adapt Structured Output Space for Semantic Segmentation", the model should have learned to produce a similarly structured output whether the input comes from the source or the target domain. We test whether the model actually learned to enforce this segmentation-map structure by attacking it with the Fast Gradient Sign Method (FGSM) and measuring the drop in performance. If the model really enforces a structured segmentation map despite the perturbed input, its drop in performance should be smaller than that of the other models, trained with data augmentation alone and with data augmentation + color jitter.

In [None]:
# ------------------ 1. Setup ------------------
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F  # used by the FGSM helpers below
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CKPT_PATH = '/kaggle/working/best/best_full_model.pt'   # <- your saved model
NUM_CLASSES = 8
IGNORE_LABEL = 0
SEED = 0  # makes the random examples picked for visualisation repeatable

# Seed every RNG the attack / visualisation cells may touch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# class weights used during training (class 0 = ignore label, weight 0)
CLASS_WEIGHTS = torch.tensor(
    [0.000000, 0.116411, 0.266041, 0.607794,
     1.511413, 0.745507, 0.712438, 3.040396],
    device=device
)

# ------------------ 2. Rebuild + load ------------------
seg_model   = get_seg_model().to(device)
sem_loss    = OhemCrossEntropy(ignore_label=IGNORE_LABEL,
                               thres=0.9, min_kept=131072,
                               weight=CLASS_WEIGHTS).to(device)
bd_loss     = BondaryLoss().to(device)

full_model  = FullModel(seg_model, sem_loss, bd_loss).to(device)
full_model.load_state_dict(torch.load(CKPT_PATH, map_location='cpu'), strict=True)
full_model.eval()

print('✓ checkpoint loaded')

# ------------------ 4. FGSM attack helpers ------------------
def fgsm_attack(model, dataloader, epsilon=0.03,
                ignore_index=0, max_samples=200):
    """Build up to ``max_samples`` FGSM-perturbed images from ``dataloader``.

    For each batch, the class-weighted cross-entropy of ``model.model(imgs)[-2]``
    (upsampled to label resolution) is differentiated w.r.t. the input images.
    Each image is then un-normalised, moved by ``epsilon * sign(grad)``,
    clamped to [0, 1] and re-normalised.

    Args:
        model: wrapper exposing the segmentation network as ``model.model``.
        dataloader: yields ``(imgs, lbls, *extra)`` batches.
        epsilon: perturbation budget in un-normalised [0, 1] pixel units.
        ignore_index: label value excluded from the attack loss.
        max_samples: stop as soon as this many samples have been collected.

    Returns:
        ``(adv_imgs, adv_lbls)`` as stacked, detached CPU tensors in loader order.

    Raises:
        ValueError: if the loader yields no samples.

    Uses the module-level ``device`` and ``CLASS_WEIGHTS``.
    """
    model.eval()

    # Normalisation statistics, hoisted out of the loop.
    # NOTE(review): assumes the loader normalises with these ImageNet stats — confirm.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    adv_imgs, adv_lbls = [], []
    for imgs, lbls, *_ in tqdm(dataloader, desc='FGSM'):
        imgs = imgs.to(device).float().requires_grad_(True)
        lbls = lbls.to(device).long()

        logits = model.model(imgs)[-2]              # underlying PIDNet output
        logits = F.interpolate(logits, size=lbls.shape[1:],
                               mode='bilinear', align_corners=True)
        loss = F.cross_entropy(logits, lbls,
                               weight=CLASS_WEIGHTS,
                               ignore_index=ignore_index)
        # Gradient w.r.t. the input only, so no model parameter gets a .grad.
        (grad,) = torch.autograd.grad(loss, imgs)

        # The perturbation must not be tracked by autograd. Otherwise every stored
        # sample keeps a graph (and its saved GPU tensors) alive.
        with torch.no_grad():
            img_unnorm = imgs * std + mean
            adv_unnorm = (img_unnorm + epsilon * grad.sign()).clamp(0, 1)
            adv_batch = ((adv_unnorm - mean) / std).cpu()

        for adv_img, lbl in zip(adv_batch, lbls.cpu()):
            adv_imgs.append(adv_img)
            adv_lbls.append(lbl)
            if len(adv_imgs) >= max_samples:
                return torch.stack(adv_imgs), torch.stack(adv_lbls)

    if not adv_imgs:
        raise ValueError('fgsm_attack: dataloader yielded no samples')
    return torch.stack(adv_imgs), torch.stack(adv_lbls)

def evaluate_on_adv(model, adv_imgs, lbls, num_classes, epsilon=None,
                    ignore_label=0, batch_size=4):
    """Print and return the mean IoU of ``model`` on pre-computed (adversarial) samples.

    Predictions are the arg-max of ``model.model(x)[-2]``, upsampled to the
    label resolution. Intersection and union are accumulated per class over all
    samples. Pixels labelled ``ignore_label`` are excluded from the
    prediction masks. The ``ignore_label`` class itself is not scored, and
    classes absent from both predictions and labels are skipped.

    Args:
        model: wrapper exposing the segmentation network as ``model.model``.
        adv_imgs: batch of images accepted by ``model.model``, indexed along dim 0.
        lbls: integer labels aligned with ``adv_imgs``; ``lbls.shape[1:]`` is the
            spatial size predictions are upsampled to.
        num_classes: label ids are ``0 .. num_classes - 1``.
        epsilon: attack budget, only used in the printed message (optional).
        ignore_label: "void" label id (0 for LoveDA in this notebook).
        batch_size: inference mini-batch size.

    Returns:
        The mean IoU as a float (``nan`` when there is nothing to score).
    """
    model.eval()
    # Run on the model's own device instead of a global (CPU for parameter-less models).
    try:
        dev = next(model.parameters()).device
    except StopIteration:
        dev = torch.device('cpu')

    tag = f' (ε={epsilon})' if epsilon is not None else ''
    if len(adv_imgs) == 0:
        print(f'\nFGSM mIoU{tag}: no samples to evaluate')
        return float('nan')

    preds = []
    with torch.no_grad():
        for start in range(0, len(adv_imgs), batch_size):
            x = adv_imgs[start:start + batch_size].to(dev)
            logits = model.model(x)[-2]
            logits = F.interpolate(logits, size=lbls.shape[1:],
                                   mode='bilinear', align_corners=True)
            preds.append(logits.argmax(1).cpu())
    preds = torch.cat(preds)

    labels = lbls[:len(preds)].cpu()
    # Void pixels must not count as false positives for any class.
    valid = labels != ignore_label

    ious = []
    for cls in range(num_classes):
        if cls == ignore_label:
            continue
        pred_c = (preds == cls) & valid
        lbl_c = labels == cls
        union = (pred_c | lbl_c).sum().item()
        if union > 0:
            ious.append((pred_c & lbl_c).sum().item() / union)

    miou = float(np.mean(ious)) if ious else float('nan')
    print(f'\nFGSM mIoU{tag}: {miou:.4f}')
    return miou

def decode_segmap(segmentation):
    """Colourise an integer label/prediction map as an RGB ``uint8`` image.

    Every value ``v`` of ``segmentation`` is replaced by row ``v`` of the
    8-colour palette below, so the result gains a trailing axis of length 3.
    Values >= 8 raise ``IndexError`` (numpy fancy indexing).
    """
    palette = np.array(
        [
            [0, 0, 0],          # id 0 (ignore label in this notebook)
            [128, 64, 128],     # id 1
            [70, 70, 70],       # id 2
            [128, 0, 0],        # id 3
            [0, 0, 255],        # id 4
            [153, 153, 153],    # id 5
            [0, 128, 0],        # id 6
            [255, 255, 0],      # id 7
        ],
        dtype=np.uint8,
    )
    return palette[segmentation]

def visualise_examples(model, dataloader, adv_imgs, adv_lbls, k=4):
    """Plot RGB / clean prediction / FGSM prediction / GT for ``k`` random samples.

    Adversarial sample ``idx`` is paired with ``dataloader.dataset[idx]``. That pairing is
    only correct when ``adv_imgs`` was built from a non-shuffled loader over
    the same dataset. ``adv_lbls`` is accepted for interface symmetry but unused.
    At most ``len(adv_imgs)`` examples are shown.

    Uses the module-level ``device`` and ``decode_segmap``.
    """
    import matplotlib.pyplot as plt

    # random.sample raises if asked for more items than exist; show what we have.
    n_show = max(0, min(k, len(adv_imgs)))
    idxs = random.sample(range(len(adv_imgs)), n_show)

    # Normalisation stats used to turn the model input back into a viewable RGB image.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for idx in idxs:
        # original sample from loader
        img_np, gt_lbl, *_ = dataloader.dataset[idx]
        img_t = torch.from_numpy(img_np).unsqueeze(0).float().to(device)
        gt_np = gt_lbl

        with torch.no_grad():
            logits_clean = model.model(img_t)[-2]
            logits_clean = F.interpolate(logits_clean, size=gt_np.shape,
                                         mode='bilinear', align_corners=True)
            pred_clean = logits_clean.argmax(1).squeeze(0).cpu().numpy()

            logits_adv = model.model(adv_imgs[idx:idx+1].to(device))[-2]
            logits_adv = F.interpolate(logits_adv, size=gt_np.shape,
                                       mode='bilinear', align_corners=True)
            pred_adv = logits_adv.argmax(1).squeeze(0).cpu().numpy()

        # un-normalise RGB (CHW -> HWC)
        rgb_img = (img_np.transpose(1, 2, 0) * std + mean).clip(0, 1)

        fig, ax = plt.subplots(1, 4, figsize=(18, 6))
        for a, title, data in zip(ax,
            ['RGB', 'Prediction', 'FGSM Prediction', 'GT'],
            [rgb_img, decode_segmap(pred_clean),
             decode_segmap(pred_adv), decode_segmap(gt_np)]):
            a.imshow(data); a.set_title(title); a.axis('off')
        plt.show()
        # Release the figure so many calls in one session don't accumulate memory.
        plt.close(fig)

# ------------------ 5. FGSM run ------------------
epsilon = 0.04  # FGSM step size ε, in un-normalised [0, 1] pixel units

# Perturb up to 200 validation images, score the model on them, then plot a few.
adv_imgs, adv_lbls = fgsm_attack(full_model, val_loader, epsilon=epsilon,
                                 ignore_index=IGNORE_LABEL, max_samples=200)
evaluate_on_adv(full_model, adv_imgs, adv_lbls, NUM_CLASSES)
visualise_examples(full_model, val_loader, adv_imgs, adv_lbls, k=6)


In [None]:
# Archive the whole working directory for download.
# NOTE(review): root_dir is all of /kaggle/working, not a `checkpoints` sub-folder — confirm intended.
shutil.make_archive('/kaggle/working/checkpoints', 'zip', '/kaggle/working/')
# Archive only the best-model folder written during training (/kaggle/working/best).
# Previously this zipped all of /kaggle/working again, including checkpoints.zip.
shutil.make_archive('/kaggle/working/best', 'zip', '/kaggle/working/best')