In [None]:
import shutil
import os

# Copy the Kaggle API token from the attached input dataset into /root/.config/kaggle.
kaggle_config_dir = '/root/.config/kaggle'
kaggle_token_path = os.path.join(kaggle_config_dir, 'kaggle.json')

os.makedirs(kaggle_config_dir, exist_ok=True)
shutil.copy('/kaggle/input/kaggle-api/kaggle.json', kaggle_token_path)

# shutil.copy carries over the source file's permission bits; restrict the
# credential file to its owner so the token is not readable by other users.
os.chmod(kaggle_token_path, 0o600)

In [None]:
# Install albumentations pinned to version 1.4.18 (-q reduces pip's output).
!pip install -q albumentations==1.4.18

In [None]:
# Install yacs and tensorboardX.
# NOTE(review): neither package is version-pinned, so the resolved versions may differ between runs.
!pip install yacs
!pip install tensorboardX

In [None]:
!pip install -q kaggle

import os
import json
import shutil
import subprocess
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
dataset_id = user_secrets.get_secret("DATASET_ID")


# Step 1: Check if dataset already exists
try:
    # Run the Kaggle API command to check dataset
    result = subprocess.run(
        ["kaggle", "datasets", "metadata", "-d", dataset_id],
        capture_output=True,
        text=True
    )

    if result.returncode == 0:
        print(f"✅ Dataset already exists: {dataset_id} — Skipping creation.")
    else:
        raise FileNotFoundError

except FileNotFoundError:
    # Download the zip files
    !wget https://zenodo.org/records/5706578/files/Train.zip?download=1 -O Train.zip
    !wget https://zenodo.org/records/5706578/files/Val.zip?download=1 -O Val.zip
    
    # Unzip them
    !unzip -q Train.zip -d splits/
    !unzip -q Val.zip -d splits/
    
    # Step 2: Organize files
    os.makedirs('zips/', exist_ok=True)
    shutil.copy('Train.zip', 'zips')
    shutil.copy('Val.zip', 'zips')

    # Step 3: Create metadata
    dataset_metadata = {
        "title": "LoveDA",
        "id": dataset_id,
        "licenses": [{"name": "CC0-1.0"}]
    }
    with open('zips/dataset-metadata.json', 'w') as f:
        json.dump(dataset_metadata, f)

    # Step 4: Create the dataset
    !kaggle datasets create -p zips
    

In [None]:
import os

# Root under which the project directory skeleton is created.
base_kaggle_path = "/kaggle/working/"

# Relative directories for the dataset, the .lst index files and the pretrained weights.
folder_names = [
    'data/loveda/train/',
    'data/loveda/val/',
    'data/list/loveda/rural',
    'data/list/loveda/urban_rural',
    'data/list/loveda/urban_urban',
    'pretrained_models/imagenet'
]

for relative_path in folder_names:
    full_path = os.path.join(base_kaggle_path, relative_path)
    # Report directories that are already present and leave them untouched.
    if os.path.exists(full_path):
        print(f"Already exists: {full_path}")
        continue
    os.makedirs(full_path)
    print(f"Created: {full_path}")



In [None]:
import shutil
import os

# Source directories inside the attached input dataset ...
list_src = [
    '/kaggle/input/loveda-splits/Train/Train/Rural',
    '/kaggle/input/loveda-splits/Train/Train/Urban',
    '/kaggle/input/loveda-splits/Val/Val/Rural',
    '/kaggle/input/loveda-splits/Val/Val/Urban'
]

# ... and the matching destinations in the working directory (same order as list_src).
list_dst = [
    '/kaggle/working/data/loveda/train/Rural',
    '/kaggle/working/data/loveda/train/Urban',
    '/kaggle/working/data/loveda/val/Rural',
    '/kaggle/working/data/loveda/val/Urban'
]

for src, dst in zip(list_src, list_dst):
    # copytree refuses to write into an existing directory, so finished copies are skipped.
    if os.path.exists(dst):
        print(f"Skipped (already exists): {dst}")
        continue
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    shutil.copytree(src, dst)
    print(f"Copied: {src} → {dst}")


In [None]:
# Copy the attached config directory into the working directory.
# shutil.copytree raises FileExistsError when the destination already exists,
# so the copy is skipped on re-runs (same guard as the dataset-split copy cell).
configs_src = '/kaggle/input/configs/configs'
configs_dst = '/kaggle/working/configs/'

if not os.path.exists(configs_dst):
    shutil.copytree(configs_src, configs_dst)
    print(f"Copied: {configs_src} → {configs_dst}")
else:
    print(f"Skipped (already exists): {configs_dst}")

In [None]:
# Copy the PIDNet_S_ImageNet.pth.tar checkpoint from the attached input dataset
# into the pretrained_models/imagenet directory of the working tree.
shutil.copy('/kaggle/input/pidnet-pretrained/PIDNet_S_ImageNet.pth.tar', '/kaggle/working/pretrained_models/imagenet')

In [None]:
# Change the notebook's working directory; the relative data/... paths used in later cells resolve against it.
%cd '/kaggle/working'

In [None]:
import os

def create_lst_file(image_dir, label_dir, output_lst):
    """Write a ``.lst`` index that pairs every image with its label file.

    Both directories are listed, sorted by the integer value of each file
    stem, and written one pair per line as ``<image path> <label path>``
    (forward slashes only).

    Args:
        image_dir: Directory whose entries are named ``<integer>.<ext>``.
        label_dir: Directory whose entries are named ``<integer>.<ext>``.
        output_lst: Path of the list file to write. Missing parent
            directories are created; a bare file name is written to the
            current working directory.

    Raises:
        ValueError: If a file stem is not an integer, or if the image and
            label directories do not contain the same set of stems (pairing
            by position would otherwise silently mismatch or drop samples).
    """
    def numeric_stem(filename):
        return int(os.path.splitext(filename)[0])

    # List and sort files numerically
    images = sorted(os.listdir(image_dir), key=numeric_stem)
    labels = sorted(os.listdir(label_dir), key=numeric_stem)

    # Pairing below is positional, so both sides must describe the same samples.
    image_ids = [numeric_stem(name) for name in images]
    label_ids = [numeric_stem(name) for name in labels]
    if image_ids != label_ids:
        unmatched = sorted(set(image_ids) ^ set(label_ids))
        raise ValueError(
            f"Images in {image_dir!r} and labels in {label_dir!r} do not match; "
            f"unpaired ids: {unmatched[:10]}"
        )

    # os.makedirs('') raises, so only create a parent when the path has one.
    parent_dir = os.path.dirname(output_lst)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_lst, 'w') as f:
        for img, lbl in zip(images, labels):
            # Generate full paths and normalize to use forward slashes
            img_path = os.path.join(image_dir, img).replace("\\", "/")
            lbl_path = os.path.join(label_dir, lbl).replace("\\", "/")
            f.write(f"{img_path} {lbl_path}\n")

# LoveDA image / mask directories (relative to the current working directory)
urban_train_image_dir, urban_train_label_dir = (
    "data/loveda/train/Urban/images_png",
    "data/loveda/train/Urban/masks_png",
)
urban_test_image_dir, urban_test_label_dir = (
    "data/loveda/val/Urban/images_png",
    "data/loveda/val/Urban/masks_png",
)
rural_train_image_dir, rural_train_label_dir = (
    "data/loveda/train/Rural/images_png",
    "data/loveda/train/Rural/masks_png",
)
rural_test_image_dir, rural_test_label_dir = (
    "data/loveda/val/Rural/images_png",
    "data/loveda/val/Rural/masks_png",
)

# Scenario 1: train on urban, validate on urban
train_lst_path = "data/list/loveda/urban_urban/train.lst"
test_lst_path = "data/list/loveda/urban_urban/val.lst"

for _image_dir, _label_dir, _lst_path in (
    (urban_train_image_dir, urban_train_label_dir, train_lst_path),
    (urban_test_image_dir, urban_test_label_dir, test_lst_path),
):
    create_lst_file(_image_dir, _label_dir, _lst_path)

# Scenario 2: train on urban, validate on rural; the rural training split is
# indexed separately under data/list/loveda/rural.
train_lst_path = "data/list/loveda/urban_rural/train.lst"
target_lst_path = "data/list/loveda/rural/train.lst"
test_lst_path = "data/list/loveda/urban_rural/val.lst"

for _image_dir, _label_dir, _lst_path in (
    (urban_train_image_dir, urban_train_label_dir, train_lst_path),
    (rural_train_image_dir, rural_train_label_dir, target_lst_path),
    (rural_test_image_dir, rural_test_label_dir, test_lst_path),
):
    create_lst_file(_image_dir, _label_dir, _lst_path)

# DeepLab implementation

In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet101
affine_par = True


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 (optionally dilated) -> 1x1 convolutions.

    The stride is applied by the first 1x1 convolution, the 3x3 convolution uses
    ``padding == dilation``, and the last convolution outputs ``planes * 4``
    channels. The parameters of all three batch-norm layers are frozen
    (``requires_grad = False``).

    Args:
        inplanes: Number of input channels.
        planes: Channel width of the two inner convolutions.
        stride: Stride of the first 1x1 convolution.
        dilation: Dilation (and padding) of the 3x3 convolution.
        downsample: Optional module applied to the input to form the residual.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super().__init__()
        widened = planes * 4

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened, affine=affine_par)

        # Batch-norm affine parameters are excluded from gradient updates.
        for norm in (self.bn1, self.bn2, self.bn3):
            for param in norm.parameters():
                param.requires_grad = False

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Project the input when a downsample module was supplied, otherwise use it as is.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class ClassifierModule(nn.Module):
    """Sum of parallel dilated 3x3 convolutions (the ASPP-style classifier head).

    One ``nn.Conv2d(inplanes, num_classes, kernel_size=3)`` branch is created
    per ``(dilation, padding)`` pair, and the branch outputs are added together.

    Args:
        inplanes: Number of input channels.
        dilation_series: Dilation rate of each branch.
        padding_series: Padding of each branch, paired positionally with
            ``dilation_series``.
        num_classes: Number of output channels of every branch.
    """

    def __init__(self, inplanes, dilation_series, padding_series, num_classes):
        super().__init__()
        branches = [
            nn.Conv2d(inplanes, num_classes, kernel_size=3, stride=1,
                      padding=pad, dilation=rate, bias=True)
            for rate, pad in zip(dilation_series, padding_series)
        ]
        self.conv2d_list = nn.ModuleList(branches)

        # Branch weights are re-drawn from N(0, 0.01); biases keep their default init.
        for branch in self.conv2d_list:
            branch.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.conv2d_list[0](x)
        # Accumulate the remaining branches onto the first branch's output.
        for branch in list(self.conv2d_list)[1:]:
            out += branch(x)
        return out


class ResNetMulti(nn.Module):
    """DeepLabv2-style segmentation network: dilated ResNet backbone plus ASPP head.

    Layers:
        - Convolutional stem:
            - conv1: 7x7 convolution with stride 2.
            - bn1: batch normalization (parameters frozen).
            - relu: ReLU activation.
            - maxpool: 3x3 max-pooling with stride 2 (``ceil_mode=True``).
        - ResNet stages:
            - layer1 to layer4 are built by ``_make_layer`` from ``block``.
            - layer2 uses stride 2; layer3 and layer4 keep stride 1 and use
              dilation 2 and 4 respectively instead of downsampling.
        - ASPP:
            - layer6 is a ``ClassifierModule`` over 2048 input channels with
              dilation/padding rates [6, 12, 18, 24].

    Forward pass:
        stem -> layer1..layer4 -> layer6 -> bilinear resize to the input's
        height and width. The same tensor is returned in training and in
        evaluation mode.
    """

    def __init__(self, block, layers, num_classes, multi_level=False):
        """Build the network.

        Args:
            block: Residual block class; it must expose ``expansion`` and accept
                ``(inplanes, planes, stride, dilation=..., downsample=...)``.
            layers: Number of blocks in each of the four stages.
            num_classes: Number of output channels of the classifier head.
            multi_level: Stored on the instance; only read by
                ``get_10x_lr_params``. No auxiliary head is built here.
        """
        # Initialise nn.Module first so attribute assignment never depends on
        # the order in which Module sets up its internal registries.
        super(ResNetMulti, self).__init__()
        self.multi_level = multi_level
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        for i in self.bn1.parameters():
            i.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        self.layer6 = ClassifierModule(2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)
        # Re-initialise every convolution with N(0, 0.01) and reset batch-norm affine terms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Create one stage made of ``blocks`` residual blocks.

        Args:
            block: Residual block class (e.g. a bottleneck) exposing ``expansion``.
            planes: Base channel width; the stage outputs ``planes * block.expansion`` channels.
            blocks: Number of residual blocks in the stage.
            stride: Stride of the first block (and of its projection shortcut).
            dilation: Dilation rate passed to every block of the stage.

        Returns:
            An ``nn.Sequential`` holding the blocks. ``self.inplanes`` is updated
            to the stage's output channel count.
        """
        downsample = None
        if (stride != 1
                or self.inplanes != planes * block.expansion
                or dilation == 2
                or dilation == 4):
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
            # Freeze the shortcut's batch norm. This must stay inside the branch:
            # when no projection is needed `downsample` is None.
            for i in downsample._modules['1'].parameters():
                i.requires_grad = False
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores resized to the spatial size of ``x``."""
        _, _, H, W = x.size()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer6(x)

        # align_corners=False is what bilinear interpolation uses when the
        # argument is omitted; it is spelled out to make the choice explicit.
        return torch.nn.functional.interpolate(
            x, size=(H, W), mode='bilinear', align_corners=False)

    def get_1x_lr_params_no_scale(self):
        """Yield the trainable backbone parameters (everything except the classifier).

        Parameters with ``requires_grad == False`` (the frozen batch-norm
        parameters) are skipped. Each parameter is yielded exactly once:
        ``Module.parameters()`` already recurses into sub-modules, so walking
        ``modules()`` as well would repeat every parameter once per ancestor.
        """
        backbone = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]
        for module in backbone:
            for param in module.parameters():
                if param.requires_grad:
                    yield param

    def get_10x_lr_params(self):
        """Yield the parameters of the classification head(s).

        ``layer6`` is always included. ``layer5`` is included only when
        ``multi_level`` is set and such an attribute exists; this class does
        not create it, so it is looked up defensively.
        """
        heads = []
        if self.multi_level:
            layer5 = getattr(self, 'layer5', None)
            if layer5 is not None:
                heads.append(layer5)
        heads.append(self.layer6)

        for head in heads:
            for param in head.parameters():
                yield param

    def optim_parameters(self, lr):
        """Return optimizer parameter groups: backbone at ``lr``, classifier at ``10 * lr``."""
        return [{'params': self.get_1x_lr_params_no_scale(), 'lr': lr},
                {'params': self.get_10x_lr_params(), 'lr': 10 * lr}]


def get_deeplab_v2(num_classes=19, pretrain=True, pretrain_model_path='DeepLab_resnet_pretrained_imagenet.pth'):
    """Build the DeepLabv2 model (ResNet-101 layout) and optionally load pretrained weights.

    Args:
        num_classes: Number of output classes of the segmentation head.
        pretrain: When True, copy weights from torchvision's pretrained
            ResNet-101 into the backbone.
        pretrain_model_path: Unused; kept so existing callers that pass it keep working.

    Returns:
        A ``ResNetMulti`` instance built from ``Bottleneck`` with stage sizes [3, 4, 23, 3].

    Note:
        ``Bottleneck`` is resolved when the function is called. This notebook
        defines a second ``Bottleneck`` class in the ``model_utils.py`` cell, so
        the function must be called before that cell rebinds the name.
    """
    model = ResNetMulti(Bottleneck, [3, 4, 23, 3], num_classes)

    # Pretraining loading from torchvision's ResNet-101
    if pretrain:
        print('Loading ResNet-101 pretrained weights...')
        resnet_model = resnet101(pretrained=True)
        saved_state_dict = resnet_model.state_dict()

        # ResNetMulti names its stem and stages exactly like torchvision's ResNet
        # (conv1, bn1, layer1..layer4 with conv*/bn*/downsample sub-modules), so
        # keys are matched as they are. Stripping the leading "layerN." component
        # produces keys such as "0.conv1.weight" that exist nowhere in the model,
        # which left every residual stage at its random initialisation.
        new_params = model.state_dict().copy()
        loaded_keys = []
        for key, value in saved_state_dict.items():
            # Entries without a same-shaped counterpart (e.g. the fc classifier) are skipped.
            if key in new_params and new_params[key].shape == value.shape:
                new_params[key] = value
                loaded_keys.append(key)

        model.load_state_dict(new_params, strict=False)
        print(f'Copied {len(loaded_keys)} of {len(saved_state_dict)} pretrained tensors into the model.')
        if loaded_keys:
            print('Pretrained weights loaded successfully.')
        else:
            print('WARNING: no pretrained tensor matched the model; weights are randomly initialised.')

    return model

In [None]:
# Build DeepLabv2 with 7 output classes; `pretrain` defaults to True, so the
# torchvision ResNet-101 pretrained weights are loaded as part of this call.
deeplab_model = get_deeplab_v2(num_classes=7)

_init_paths.py

In [None]:
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import sys


def add_path(path):
    """Put ``path`` at the front of ``sys.path`` unless it is already listed."""
    if path in sys.path:
        return
    sys.path.insert(0, path)


model_utils.py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

BatchNorm2d = nn.BatchNorm2d
bn_mom = 0.1
algc = False

class BasicBlock(nn.Module):
    """Two 3x3 convolutions with batch norm and a residual connection.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels of both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to form the residual.
        no_relu: When True the sum is returned without the final ReLU.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is the input itself unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return out if self.no_relu else self.relu(out)

class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 2`` channels.

    The stride is applied by the 3x3 convolution, and by default
    (``no_relu=True``) the residual sum is returned without a final ReLU.

    NOTE: an earlier cell of this notebook defines a different ``Bottleneck``
    (expansion 4, with a ``dilation`` argument) for DeepLab. Running this cell
    rebinds the name, so code that looks up ``Bottleneck`` afterwards, such as
    ``get_deeplab_v2``, receives this class instead.

    Args:
        inplanes: Number of input channels.
        planes: Channel width of the first two convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to form the residual.
        no_relu: When True the sum is returned without the final ReLU.
    """
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
        super().__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1,
                               bias=False)
        self.bn3 = BatchNorm2d(out_planes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.no_relu = no_relu

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is the input itself unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return out if self.no_relu else self.relu(out)

class segmenthead(nn.Module):
    """Prediction head: BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv.

    Args:
        inplanes: Number of input channels.
        interplanes: Number of channels between the two convolutions.
        outplanes: Number of output channels.
        scale_factor: When not None, the output is bilinearly resized to the
            feature map's height and width multiplied by this factor.
    """

    def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
        super().__init__()
        self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
        self.conv1 = nn.Conv2d(inplanes, interplanes, kernel_size=3, padding=1, bias=False)
        self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(interplanes, outplanes, kernel_size=1, padding=0, bias=True)
        self.scale_factor = scale_factor

    def forward(self, x):
        hidden = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(hidden)))

        if self.scale_factor is None:
            return out

        # The target size is derived from the intermediate feature map's spatial shape.
        target_size = [hidden.shape[-2] * self.scale_factor,
                       hidden.shape[-1] * self.scale_factor]
        return F.interpolate(out, size=target_size,
                             mode='bilinear', align_corners=algc)

class DAPPM(nn.Module):
    """Multi-scale pooling module with cascaded branch fusion.

    The input is average-pooled at four scales (kernel/stride 5/2, 9/4, 17/8 and
    global). Each pooled branch is resized back to the input resolution, added to
    the previous branch's result and refined by a 3x3 convolution. The five
    branch outputs are concatenated, compressed by a 1x1 convolution and added to
    a 1x1-convolution shortcut of the input.

    Args:
        inplanes: Number of input channels.
        branch_planes: Number of channels of every branch.
        outplanes: Number of output channels.
        BatchNorm: Normalisation layer class used throughout the module.
    """

    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        bn_mom = 0.1

        def bn_relu_conv(in_ch, out_ch, **conv_kwargs):
            # Pre-activation unit: norm -> ReLU -> bias-free convolution.
            return [
                BatchNorm(in_ch, momentum=bn_mom),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_ch, out_ch, bias=False, **conv_kwargs),
            ]

        def pooled_branch(pool):
            return nn.Sequential(pool, *bn_relu_conv(inplanes, branch_planes, kernel_size=1))

        def refine():
            return nn.Sequential(
                *bn_relu_conv(branch_planes, branch_planes, kernel_size=3, padding=1))

        # Sub-modules are created in a fixed order so parameter initialisation
        # and state-dict ordering stay stable.
        self.scale1 = pooled_branch(nn.AvgPool2d(kernel_size=5, stride=2, padding=2))
        self.scale2 = pooled_branch(nn.AvgPool2d(kernel_size=9, stride=4, padding=4))
        self.scale3 = pooled_branch(nn.AvgPool2d(kernel_size=17, stride=8, padding=8))
        self.scale4 = pooled_branch(nn.AdaptiveAvgPool2d((1, 1)))
        self.scale0 = nn.Sequential(*bn_relu_conv(inplanes, branch_planes, kernel_size=1))
        self.process1 = refine()
        self.process2 = refine()
        self.process3 = refine()
        self.process4 = refine()
        self.compression = nn.Sequential(
            *bn_relu_conv(branch_planes * 5, outplanes, kernel_size=1))
        self.shortcut = nn.Sequential(*bn_relu_conv(inplanes, outplanes, kernel_size=1))

    def forward(self, x):
        target_size = [x.shape[-2], x.shape[-1]]
        branches = [self.scale0(x)]

        stages = (
            (self.scale1, self.process1),
            (self.scale2, self.process2),
            (self.scale3, self.process3),
            (self.scale4, self.process4),
        )
        for pool, refine in stages:
            resized = F.interpolate(pool(x), size=target_size,
                                    mode='bilinear', align_corners=algc)
            # Each branch is fused with the previous branch's output before refinement.
            branches.append(refine(resized + branches[-1]))

        out = self.compression(torch.cat(branches, 1)) + self.shortcut(x)
        return out

class PAPPM(nn.Module):
    """Multi-scale pooling module whose branches are refined by one grouped convolution.

    The input is average-pooled at four scales (kernel/stride 5/2, 9/4, 17/8 and
    global). Each pooled branch is resized to the input resolution and added to
    the un-pooled branch. The four sums are concatenated and processed together
    by a single 3x3 convolution with ``groups=4``. The result is concatenated
    with the un-pooled branch, compressed by a 1x1 convolution and added to a
    1x1-convolution shortcut of the input.

    Args:
        inplanes: Number of input channels.
        branch_planes: Number of channels of every branch.
        outplanes: Number of output channels.
        BatchNorm: Normalisation layer class used throughout the module.
    """

    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        bn_mom = 0.1

        def bn_relu_conv(in_ch, out_ch, **conv_kwargs):
            # Pre-activation unit: norm -> ReLU -> bias-free convolution.
            return [
                BatchNorm(in_ch, momentum=bn_mom),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_ch, out_ch, bias=False, **conv_kwargs),
            ]

        def pooled_branch(pool):
            return nn.Sequential(pool, *bn_relu_conv(inplanes, branch_planes, kernel_size=1))

        # Sub-modules are created in a fixed order so parameter initialisation
        # and state-dict ordering stay stable.
        self.scale1 = pooled_branch(nn.AvgPool2d(kernel_size=5, stride=2, padding=2))
        self.scale2 = pooled_branch(nn.AvgPool2d(kernel_size=9, stride=4, padding=4))
        self.scale3 = pooled_branch(nn.AvgPool2d(kernel_size=17, stride=8, padding=8))
        self.scale4 = pooled_branch(nn.AdaptiveAvgPool2d((1, 1)))

        self.scale0 = nn.Sequential(*bn_relu_conv(inplanes, branch_planes, kernel_size=1))

        self.scale_process = nn.Sequential(
            *bn_relu_conv(branch_planes*4, branch_planes*4, kernel_size=3, padding=1, groups=4))

        self.compression = nn.Sequential(
            *bn_relu_conv(branch_planes * 5, outplanes, kernel_size=1))

        self.shortcut = nn.Sequential(*bn_relu_conv(inplanes, outplanes, kernel_size=1))

    def forward(self, x):
        target_size = [x.shape[-2], x.shape[-1]]

        x_ = self.scale0(x)
        # Every pooled branch is resized and added to the same un-pooled branch.
        scale_list = [
            F.interpolate(pool(x), size=target_size,
                          mode='bilinear', align_corners=algc) + x_
            for pool in (self.scale1, self.scale2, self.scale3, self.scale4)
        ]

        scale_out = self.scale_process(torch.cat(scale_list, 1))

        out = self.compression(torch.cat([x_, scale_out], 1)) + self.shortcut(x)
        return out


class PagFM(nn.Module):
    """Fuse two feature maps with a sigmoid similarity map computed per pixel.

    ``y`` is resized to the spatial size of ``x``. A similarity map is computed
    from 1x1-convolution embeddings of both inputs, and the output is
    ``(1 - sim) * x + sim * y``.

    Args:
        in_channels: Number of channels of both inputs.
        mid_channels: Number of channels of the embeddings.
        after_relu: When True, an in-place ReLU is applied to both inputs first
            (this modifies the tensors passed by the caller).
        with_channel: When True, the similarity map keeps ``in_channels``
            channels (via an extra 1x1 convolution); otherwise the embedding
            product is summed over channels into a single-channel map.
        BatchNorm: Normalisation layer class.
    """

    def __init__(self, in_channels, mid_channels, after_relu=False, with_channel=False, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        self.with_channel = with_channel
        self.after_relu = after_relu

        def embed(src_ch, dst_ch):
            # 1x1 bias-free convolution followed by normalisation.
            return nn.Sequential(
                nn.Conv2d(src_ch, dst_ch, kernel_size=1, bias=False),
                BatchNorm(dst_ch),
            )

        self.f_x = embed(in_channels, mid_channels)
        self.f_y = embed(in_channels, mid_channels)
        if with_channel:
            self.up = embed(mid_channels, in_channels)
        if after_relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x, y):
        target_size = [x.size()[2], x.size()[3]]
        if self.after_relu:
            y = self.relu(y)
            x = self.relu(x)

        y_q = F.interpolate(self.f_y(y), size=target_size,
                            mode='bilinear', align_corners=False)
        x_k = self.f_x(x)
        correlation = x_k * y_q

        if self.with_channel:
            sim_map = torch.sigmoid(self.up(correlation))
        else:
            sim_map = torch.sigmoid(torch.sum(correlation, dim=1).unsqueeze(1))

        y_resized = F.interpolate(y, size=target_size,
                                  mode='bilinear', align_corners=False)
        return (1-sim_map)*x + sim_map*y_resized

class Light_Bag(nn.Module):
    """Fuse three feature maps using ``sigmoid(d)`` as a mixing gate.

    Two mixtures of ``p`` and ``i`` are formed from the gate, each is passed
    through its own 1x1 convolution + normalisation, and the two results are
    summed.

    Args:
        in_channels: channel count of ``p`` and ``i``.
        out_channels: channel count of the fused output.
        BatchNorm: normalisation layer class used after each convolution.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        self.conv_p = self._pointwise(in_channels, out_channels, BatchNorm)
        self.conv_i = self._pointwise(in_channels, out_channels, BatchNorm)

    @staticmethod
    def _pointwise(c_in, c_out, norm_layer):
        # Bias-free 1x1 convolution followed by normalisation.
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=1, bias=False),
            norm_layer(c_out),
        )

    def forward(self, p, i, d):
        """Return ``conv_p((1 - g) * i + p) + conv_i(i + g * p)`` with ``g = sigmoid(d)``."""
        gate = torch.sigmoid(d)
        p_branch = self.conv_p((1 - gate) * i + p)
        i_branch = self.conv_i(i + gate * p)
        return p_branch + i_branch


class DDFMv2(nn.Module):
    """Gated fusion of ``p`` and ``i`` with pre-activation 1x1 projections.

    ``sigmoid(d)`` acts as the gate. Each of the two mixtures goes through
    normalisation -> ReLU -> 1x1 convolution -> normalisation, and the two
    projected maps are added.

    Args:
        in_channels: channel count of ``p`` and ``i``.
        out_channels: channel count of the fused output.
        BatchNorm: normalisation layer class.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        def preact_pointwise():
            # Layer order inside the Sequential fixes the state-dict indices.
            stages = [
                BatchNorm(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                BatchNorm(out_channels),
            ]
            return nn.Sequential(*stages)

        self.conv_p = preact_pointwise()
        self.conv_i = preact_pointwise()

    def forward(self, p, i, d):
        """Return ``conv_p((1 - g) * i + p) + conv_i(i + g * p)`` with ``g = sigmoid(d)``."""
        gate = torch.sigmoid(d)
        p_branch = self.conv_p((1 - gate) * i + p)
        i_branch = self.conv_i(i + gate * p)
        return p_branch + i_branch

class Bag(nn.Module):
    """Blend ``p`` and ``i`` with the gate ``sigmoid(d)`` and project the result.

    The blend ``g * p + (1 - g) * i`` is passed through normalisation -> ReLU ->
    3x3 convolution (padding 1, no bias).

    Args:
        in_channels: channel count of ``p`` and ``i``.
        out_channels: channel count of the output.
        BatchNorm: normalisation layer class.
    """

    def __init__(self, in_channels, out_channels, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        stages = [
            BatchNorm(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, p, i, d):
        """Return ``conv(g * p + (1 - g) * i)`` with ``g = sigmoid(d)``."""
        gate = torch.sigmoid(d)
        blended = gate * p + (1 - gate) * i
        return self.conv(blended)


pidnet.py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import logging

# Normalisation layer class used for the layers built directly in PIDNet.__init__
# and matched by its weight-initialisation loop.
BatchNorm2d = nn.BatchNorm2d
# Momentum passed to the batch-norm layers created inside PIDNet.
bn_mom = 0.1
# Value passed as `align_corners` to the bilinear F.interpolate calls in PIDNet.forward.
algc = False

class PIDNet(nn.Module):
    """Three-branch segmentation network (branches are labelled I, P and D below).

    Args:
        m: number of blocks in layer1, layer2, layer3_ and layer4_. ``m == 2``
            additionally selects the narrower D branch together with PAPPM and
            Light_Bag; any other value selects DAPPM and Bag.
        n: number of blocks in layer3 and layer4.
        num_classes: number of output channels of the segmentation heads.
        planes: base channel width; all other widths are multiples of it.
        ppm_planes: second argument passed to the PAPPM / DAPPM module.
        head_planes: second argument passed to the ``segmenthead`` modules that
            predict classes.
        augment: if True, two auxiliary heads are built and ``forward`` returns
            their outputs alongside the main prediction.

    BasicBlock, Bottleneck, segmenthead, PAPPM and DAPPM are defined elsewhere
    in the notebook.
    """

    def __init__(self, m=2, n=3, num_classes=19, planes=64, ppm_planes=96, head_planes=128, augment=True):
        super(PIDNet, self).__init__()
        self.augment = augment

        # I Branch: stem (two stride-2 convolutions) followed by five stages.
        self.conv1 =  nn.Sequential(
                          nn.Conv2d(3,planes,kernel_size=3, stride=2, padding=1),
                          BatchNorm2d(planes, momentum=bn_mom),
                          nn.ReLU(inplace=True),
                          nn.Conv2d(planes,planes,kernel_size=3, stride=2, padding=1),
                          BatchNorm2d(planes, momentum=bn_mom),
                          nn.ReLU(inplace=True),
                      )

        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(BasicBlock, planes, planes, m)
        self.layer2 = self._make_layer(BasicBlock, planes, planes * 2, m, stride=2)
        self.layer3 = self._make_layer(BasicBlock, planes * 2, planes * 4, n, stride=2)
        self.layer4 = self._make_layer(BasicBlock, planes * 4, planes * 8, n, stride=2)
        self.layer5 =  self._make_layer(Bottleneck, planes * 8, planes * 8, 2, stride=2)

        # P Branch: compression3/4 reduce I-branch features to planes * 2 channels
        # before they are fused into the P branch by pag3/pag4.
        self.compression3 = nn.Sequential(
                                          nn.Conv2d(planes * 4, planes * 2, kernel_size=1, bias=False),
                                          BatchNorm2d(planes * 2, momentum=bn_mom),
                                          )

        self.compression4 = nn.Sequential(
                                          nn.Conv2d(planes * 8, planes * 2, kernel_size=1, bias=False),
                                          BatchNorm2d(planes * 2, momentum=bn_mom),
                                          )
        self.pag3 = PagFM(planes * 2, planes)
        self.pag4 = PagFM(planes * 2, planes)

        self.layer3_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer4_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer5_ = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # D Branch: diff3/diff4 project I-branch features that are added to the
        # D branch in forward(). The two variants differ in channel widths and in
        # the pooling (spp) and fusion (dfm) modules they use.
        if m == 2:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes)
            self.layer4_d = self._make_layer(Bottleneck, planes, planes, 1)
            self.diff3 = nn.Sequential(
                                        nn.Conv2d(planes * 4, planes, kernel_size=3, padding=1, bias=False),
                                        BatchNorm2d(planes, momentum=bn_mom),
                                        )
            self.diff4 = nn.Sequential(
                                     nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                                     BatchNorm2d(planes * 2, momentum=bn_mom),
                                     )
            self.spp = PAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Light_Bag(planes * 4, planes * 4)
        else:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.layer4_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.diff3 = nn.Sequential(
                                        nn.Conv2d(planes * 4, planes * 2, kernel_size=3, padding=1, bias=False),
                                        BatchNorm2d(planes * 2, momentum=bn_mom),
                                        )
            self.diff4 = nn.Sequential(
                                     nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                                     BatchNorm2d(planes * 2, momentum=bn_mom),
                                     )
            self.spp = DAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Bag(planes * 4, planes * 4)

        self.layer5_d = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # Prediction Head: the auxiliary heads exist only when augment is True.
        if self.augment:
            self.seghead_p = segmenthead(planes * 2, head_planes, num_classes)
            self.seghead_d = segmenthead(planes * 2, planes, 1)

        self.final_layer = segmenthead(planes * 4, head_planes, num_classes)


        # Weight initialisation. NOTE: the loop variable `m` shadows the
        # constructor argument `m`, which is no longer read after this point.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1 or the
        channel count changes, a 1x1 conv + batch-norm ``downsample`` module.
        The remaining blocks use stride 1; the last of them is created with
        ``no_relu=True`` and the others with ``no_relu=False``. When
        ``blocks == 1`` the single block is created without an explicit
        ``no_relu`` argument, so the block's own default applies.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            if i == (blocks-1):
                layers.append(block(inplanes, planes, stride=1, no_relu=True))
            else:
                layers.append(block(inplanes, planes, stride=1, no_relu=False))

        return nn.Sequential(*layers)

    def _make_single_layer(self, block, inplanes, planes, stride=1):
        """Build one ``block`` with ``no_relu=True``, adding a 1x1 conv + batch-norm
        ``downsample`` module when the stride is not 1 or the channel count changes."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )

        layer = block(inplanes, planes, stride, downsample, no_relu=True)

        return layer

    def forward(self, x):
        """Run the three branches and fuse them.

        Returns ``[x_extra_p, x_, x_extra_d]`` when ``self.augment`` is True
        (auxiliary P-branch prediction, main prediction, auxiliary D-branch
        prediction) and only the main prediction ``x_`` otherwise.
        """

        # Target spatial size (input size // 8) for all resampled I-branch features.
        width_output = x.shape[-1] // 8
        height_output = x.shape[-2] // 8

        x = self.conv1(x)
        x = self.layer1(x)
        x = self.relu(self.layer2(self.relu(x)))
        # P and D branches start from the layer2 output.
        x_ = self.layer3_(x)
        x_d = self.layer3_d(x)

        x = self.relu(self.layer3(x))
        x_ = self.pag3(x_, self.compression3(x))
        x_d = x_d + F.interpolate(
                        self.diff3(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            # Kept for the auxiliary P head.
            temp_p = x_

        x = self.relu(self.layer4(x))
        x_ = self.layer4_(self.relu(x_))
        x_d = self.layer4_d(self.relu(x_d))

        x_ = self.pag4(x_, self.compression4(x))
        x_d = x_d + F.interpolate(
                        self.diff4(x),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)
        if self.augment:
            # Kept for the auxiliary D head.
            temp_d = x_d

        x_ = self.layer5_(self.relu(x_))
        x_d = self.layer5_d(self.relu(x_d))
        x = F.interpolate(
                        self.spp(self.layer5(x)),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=algc)

        # Fuse P (x_), I (x) and D (x_d) features, then predict.
        x_ = self.final_layer(self.dfm(x_, x, x_d))

        if self.augment:
            x_extra_p = self.seghead_p(temp_p)
            x_extra_d = self.seghead_d(temp_d)
            return [x_extra_p, x_, x_extra_d]
        else:
            return x_

def get_seg_model(cfg, imgnet_pretrained):
    """Build a training-mode PIDNet (``augment=True``) and load matching weights.

    The variant is chosen by substring tests on ``cfg.MODEL.NAME``: a name
    containing ``'s'`` gives the 32-plane model, otherwise one containing
    ``'m'`` gives the 64-plane model, and anything else the largest model.

    Args:
        cfg: configuration object providing ``MODEL.NAME``,
            ``MODEL.PRETRAINED`` and ``DATASET.NUM_CLASSES``.
        imgnet_pretrained: if truthy, the checkpoint must contain a
            ``'state_dict'`` entry whose keys are used as they are. Otherwise
            ``'state_dict'`` is optional and the first six characters of every
            key are dropped before matching (presumably a ``'model.'`` prefix;
            verify against the checkpoint being loaded).

    Returns:
        The PIDNet instance. Only checkpoint tensors whose (possibly stripped)
        key exists in the model with an identical shape are copied in.
    """
    arch_name = cfg.MODEL.NAME
    if 's' in arch_name:
        variant = dict(m=2, n=3, planes=32, ppm_planes=96, head_planes=128)
    elif 'm' in arch_name:
        variant = dict(m=2, n=3, planes=64, ppm_planes=96, head_planes=128)
    else:
        variant = dict(m=3, n=4, planes=64, ppm_planes=112, head_planes=256)
    model = PIDNet(num_classes=cfg.DATASET.NUM_CLASSES, augment=True, **variant)

    checkpoint = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')
    if imgnet_pretrained:
        weights = checkpoint['state_dict']
        prefix_len = 0
    else:
        weights = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        prefix_len = 6

    own_state = model.state_dict()
    usable = {}
    for key, tensor in weights.items():
        target_key = key[prefix_len:]
        if target_key in own_state and tensor.shape == own_state[target_key].shape:
            usable[target_key] = tensor

    logging.info('Attention!!!')
    logging.info('Loaded {} parameters!'.format(len(usable)))
    logging.info('Over!!!')

    own_state.update(usable)
    model.load_state_dict(own_state, strict = False)

    return model

def get_pred_model(name, num_classes):
    """Build an inference-mode PIDNet (``augment=False``) without loading weights.

    The variant is chosen by substring tests on ``name``: a name containing
    ``'s'`` gives the 32-plane model, otherwise one containing ``'m'`` gives
    the 64-plane model, and anything else the largest model.
    """
    if 's' in name:
        variant = dict(m=2, n=3, planes=32, ppm_planes=96, head_planes=128)
    elif 'm' in name:
        variant = dict(m=2, n=3, planes=64, ppm_planes=96, head_planes=128)
    else:
        variant = dict(m=3, n=4, planes=64, ppm_planes=112, head_planes=256)

    return PIDNet(num_classes=num_classes, augment=False, **variant)

base_dataset.py

In [None]:
import cv2
import numpy as np
import random
import albumentations as A
from torch.nn import functional as F
from torch.utils import data
import matplotlib.pyplot as plt
# Border widths (rows / columns) that BaseDataset.gen_sample blanks out of the
# edge map when it is called with edge_pad=True.
y_k_size = 6
x_k_size = 6

def show_images(x_original, x_augmented, unnormalize = False):
    """Plot an image and its augmented version side by side.

    Args:
        x_original: image shown in the left panel.
        x_augmented: image shown in the right panel.
        unnormalize: if True, both inputs are treated as channel-first arrays:
            they are multiplied by the ImageNet std, shifted by the ImageNet
            mean, clipped to [0, 1] and transposed to channel-last. If False
            they are handed to ``imshow`` unchanged.
    """
    panels = [x_original, x_augmented]

    if unnormalize:
        # ImageNet statistics shaped (3, 1, 1) so they broadcast over (C, H, W).
        mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        panels = [
            np.transpose(np.clip(img * std + mean, 0, 1), (1, 2, 0))
            for img in panels
        ]

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    titles = ("Original Image", "Augmented Image")
    for ax, img, title in zip(axs, panels, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

class BaseDataset(data.Dataset):
    """Shared preprocessing helpers for the segmentation datasets.

    Subclasses are expected to fill ``self.files`` and implement
    ``__getitem__``.

    Args:
        ignore_label: label value used when padding label maps.
        base_size: long-side length (before scaling) used by ``multi_scale_aug``.
        crop_size: ``(height, width)`` of the random crops.
        scale_factor: stored on the instance; not read by this class.
        mean: per-channel mean subtracted in ``input_transform``.
        std: per-channel std divided out in ``input_transform``.
    """

    def __init__(self,
                 ignore_label=255,
                 base_size=2048,
                 crop_size=(512, 512),
                 scale_factor=16,
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):

        self.base_size = base_size
        self.crop_size = crop_size
        self.ignore_label = ignore_label

        self.mean = mean
        self.std = std
        self.scale_factor = scale_factor

        self.files = []

    def __len__(self):
        """Number of samples listed in ``self.files``."""
        return len(self.files)

    def input_transform(self, image, city=False):
        """Convert an (H, W, C) image to float32, scale to [0, 1] and normalise.

        With ``city=True`` the channel axis is reversed first (e.g. BGR as
        returned by ``cv2.imread`` becomes RGB).
        """
        if city:
            image = image.astype(np.float32)[:, :, ::-1]
        else:
            image = image.astype(np.float32)
        image = image / 255.0
        image -= self.mean
        image /= self.std
        return image

    def label_transform(self, label):
        """Return the label map as a ``uint8`` NumPy array."""
        return np.array(label).astype(np.uint8)

    def pad_image(self, image, h, w, size, padvalue):
        """Pad ``image`` at the bottom/right so that it is at least ``size``.

        Args:
            image: 2-D (H, W) array, or 3-D array in (H, W, C) layout. A 3-D
                array whose first axis has length <= 3 is treated as (C, H, W).
            h, w: current height and width, as measured by the caller.
            size: ``(height, width)`` target size.
            padvalue: per-channel fill values. As with an OpenCV scalar,
                missing trailing channels are filled with 0.

        Returns:
            The input itself when no padding is needed, otherwise a new array
            of the same dtype and the SAME layout as the input. (Previously a
            3-channel (H, W, C) input came back as (C, H, W), which made
            ``rand_crop`` slice the wrong axes.)
        """
        pad_h = max(size[0] - h, 0)
        pad_w = max(size[1] - w, 0)

        # Nothing to do: hand back the original array.
        if pad_h == 0 and pad_w == 0:
            return image

        # Remember the incoming layout *before* converting, so that only true
        # (C, H, W) inputs are converted back at the end.
        was_chw = len(image.shape) == 3 and image.shape[0] <= 3
        if was_chw:
            image = np.transpose(image, (1, 2, 0))

        src_h, src_w = image.shape[:2]
        is_multichannel = len(image.shape) == 3
        channels = image.shape[2] if is_multichannel else 1

        # Per-channel fill vector, zero-extended like an OpenCV scalar.
        values = np.asarray(padvalue, dtype=np.float64).reshape(-1)
        fill = np.zeros(channels, dtype=np.float64)
        n_given = min(channels, values.size)
        fill[:n_given] = values[:n_given]

        padded = np.empty((src_h + pad_h, src_w + pad_w) + tuple(image.shape[2:]),
                          dtype=image.dtype)
        padded[...] = fill if is_multichannel else fill[0]
        padded[:src_h, :src_w] = image

        if was_chw:
            padded = np.transpose(padded, (2, 0, 1))

        return padded

    def rand_crop(self, image, label, edge):
        """Take the same random ``crop_size`` window from image, label and edge.

        Inputs smaller than ``crop_size`` are padded first (image with 0, label
        with ``ignore_label``, edge with 0). A (C, H, W) image is converted to
        (H, W, C) and is returned in (H, W, C) layout. ``edge`` may be None, in
        which case None is returned for it.

        Raises:
            ValueError: if the label is still smaller than ``crop_size`` after
                padding.
        """
        # Work in (H, W, C) layout.
        if len(image.shape) == 3 and image.shape[0] <= 3:
            image = np.transpose(image, (1, 2, 0))

        h, w = image.shape[:2]
        crop_h, crop_w = self.crop_size[0], self.crop_size[1]

        # Pad when the sample is smaller than the crop window.
        if h < crop_h or w < crop_w:
            image = self.pad_image(image, h, w, self.crop_size, (0.0, 0.0, 0.0))
            label = self.pad_image(label, h, w, self.crop_size, (self.ignore_label,))
            if edge is not None:
                edge = self.pad_image(edge, h, w, self.crop_size, (0.0,))

        # Sizes after padding.
        new_h, new_w = label.shape
        if new_h < crop_h or new_w < crop_w:
            raise ValueError(f"Dimensioni insufficienti per il ritaglio: label={label.shape}, crop_size={self.crop_size}")

        # Top-left corner of the random window.
        x = random.randint(0, new_w - crop_w)
        y = random.randint(0, new_h - crop_h)

        image = image[y:y+crop_h, x:x+crop_w]
        label = label[y:y+crop_h, x:x+crop_w]
        if edge is not None:
            edge = edge[y:y+crop_h, x:x+crop_w]

        return image, label, edge

    def multi_scale_aug(self, image, label=None, edge=None,
                        rand_scale=1, rand_crop=True):
        """Resize so the long side equals ``base_size * rand_scale``, then crop.

        Returns only the resized image when ``label`` is None; otherwise
        ``(image, label, edge)``, randomly cropped when ``rand_crop`` is True.
        """
        long_size = int(self.base_size * rand_scale + 0.5)
        h, w = image.shape[:2]
        if h > w:
            new_h = long_size
            new_w = int(w * long_size / h + 0.5)
        else:
            new_w = long_size
            new_h = int(h * long_size / w + 0.5)

        image = cv2.resize(image, (new_w, new_h),
                           interpolation=cv2.INTER_LINEAR)
        if label is not None:
            # Nearest-neighbour keeps label / edge values discrete.
            label = cv2.resize(label, (new_w, new_h),
                               interpolation=cv2.INTER_NEAREST)
            if edge is not None:
                edge = cv2.resize(edge, (new_w, new_h),
                                   interpolation=cv2.INTER_NEAREST)
        else:
            return image

        if rand_crop:
            image, label, edge = self.rand_crop(image, label, edge)

        return image, label, edge


    def gen_sample(self, image, label, edge_pad=True, edge_size=4, city=False, transform=None, show=False):
        """Augment a sample, derive its edge map and normalise the image.

        Args:
            image: (H, W, C) image.
            label: (H, W) label map.
            edge_pad: if True, a border of ``y_k_size`` rows / ``x_k_size``
                columns of the edge map is set to zero.
            edge_size: side of the square kernel used to dilate the edges.
            city: forwarded to ``input_transform``.
            transform: optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` entries.
            show: if True (and a transform is given), plot the image before
                and after augmentation.

        Returns:
            ``(image, label, edge)``: normalised (C, H, W) image, ``uint8``
            label map, and a float edge map with values 0.0 / 1.0.
        """
        if transform is not None:
            # Image and mask go through the transform together.
            augmented = transform(image=image, mask=label)

            if show:
                show_images(image, augmented["image"])

            image = augmented['image']
            label = augmented['mask']

        # The edge map must be derived after augmentation so that it matches
        # the augmented label.
        edge = cv2.Canny(label, 0.1, 0.2)
        kernel = np.ones((edge_size, edge_size), np.uint8)
        if edge_pad:
            edge = edge[y_k_size:-y_k_size, x_k_size:-x_k_size]
            edge = np.pad(edge, ((y_k_size,y_k_size),(x_k_size,x_k_size)), mode='constant')
        edge = (cv2.dilate(edge, kernel, iterations=1)>50)*1.0

        # Input transforms.
        image = self.input_transform(image, city=city)  # city=True reverses the channel order
        label = self.label_transform(label)             # label map as uint8 array
        image = image.transpose((2, 0, 1))              # (H, W, C) -> (C, H, W)

        return image, label, edge


    def inference(self, config, model, image):
        """Run ``model`` on ``image`` and return ``exp`` of the resized output.

        When ``config.MODEL.NUM_OUTPUTS > 1`` the output selected by
        ``config.TEST.OUTPUT_INDEX`` is used. The prediction is bilinearly
        resized to the spatial size of ``image``.
        """
        size = image.size()
        pred = model(image)

        if config.MODEL.NUM_OUTPUTS > 1:
            pred = pred[config.TEST.OUTPUT_INDEX]

        pred = F.interpolate(
            input=pred, size=size[-2:],
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
        )

        return pred.exp()

discriminator.py

In [None]:
#------------------------------------------------------------------------
    #Discriminator based on DCGAN

import torch.nn as nn
import torch.nn.functional as F


class FCDiscriminator(nn.Module):
    """Fully convolutional discriminator built from five stride-2 convolutions.

    Args:
        num_classes: number of input channels.
        ndf: channel width of the first convolution; it doubles at each of the
            following three layers.

    ``forward`` returns the raw single-channel output of the last convolution
    (logits); neither a sigmoid nor upsampling is applied.
    """

    def __init__(self, num_classes, ndf = 64):
        super().__init__()

        def downsample(c_in, c_out):
            # Each layer halves the spatial resolution.
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)

        self.conv1 = downsample(num_classes, ndf)
        self.conv2 = downsample(ndf, ndf*2)
        self.conv3 = downsample(ndf*2, ndf*4)
        self.conv4 = downsample(ndf*4, ndf*8)
        self.classifier = downsample(ndf*8, 1)

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Apply conv1..conv4, each followed by LeakyReLU, then the classifier."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.leaky_relu(stage(x))
        return self.classifier(x)


loveda.py

In [None]:
import cv2
import os
import numpy as np
import torch
import random
import logging
from PIL import Image
import torchvision.transforms as tf
import matplotlib.pyplot as plt

def compare_images(image, blurred_image):
    """Plot an image, its blurred version and their absolute difference.

    Args:
        image: reference image in (H, W, C) layout.
        blurred_image: processed image in (C, H, W) layout; it is transposed
            to (H, W, C) before the comparison.

    Returns:
        ``np.abs(image - blurred_image)`` in (H, W, C) layout.
        NOTE(review): for unsigned integer inputs the subtraction wraps around
        before ``abs`` is taken; confirm the callers pass float arrays.
    """
    blurred_hwc = blurred_image.transpose(1, 2, 0)
    diff = np.abs(image - blurred_hwc)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        ("Original Image", image),
        ("Blurred Image", blurred_hwc),
        ("Difference Image", diff),
    )
    for ax, (title, picture) in zip(axs, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis("off")

    plt.show()

    return diff


def show_images(x_original, x_augmented, unnormalize = False):
    """Plot an image and its augmented version side by side.

    NOTE: a function with the same name and signature is already defined in
    the base_dataset cell earlier in this notebook; running this cell rebinds
    the name in the notebook namespace.

    Args:
        x_original: image shown in the left panel.
        x_augmented: image shown in the right panel.
        unnormalize: if True, both inputs are treated as channel-first arrays:
            they are multiplied by the ImageNet std, shifted by the ImageNet
            mean, clipped to [0, 1] and transposed to channel-last. If False
            they are handed to ``imshow`` unchanged.
    """
    panels = [x_original, x_augmented]

    if unnormalize:
        # ImageNet statistics shaped (3, 1, 1) so they broadcast over (C, H, W).
        mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        panels = [
            np.transpose(np.clip(img * std + mean, 0, 1), (1, 2, 0))
            for img in panels
        ]

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    titles = ("Original Image", "Augmented Image")
    for ax, img, title in zip(axs, panels, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Helper class that applies the optional data-augmentation techniques
class DataAugmentation:
    """Apply a configurable set of augmentations to an (image, label, edge) triple.

    Args:
        config: mapping with the keys ``"ENABLE"`` (bool), ``"PROBABILITY"``
            (probability that the enabled techniques are applied to a sample)
            and ``"TECHNIQUES"`` (mapping of technique name -> bool; recognised
            names are ``HORIZONTAL_FLIP``, ``GAUSSIAN_BLUR``, ``MULTIPLY``,
            ``RANDOM_BRIGHTNESS`` and ``RANDOM_CROP``).
        dataset_instance: dataset object whose ``rand_crop`` method is used for
            the random crop.
    """

    def __init__(self, config, dataset_instance):
        self.enable = config["ENABLE"]
        self.probability = config["PROBABILITY"]
        self.techniques = config["TECHNIQUES"]
        self.dataset = dataset_instance  # used only by random_crop

    def apply(self, image, label, edge):
        """Apply every enabled technique, in a fixed order, with probability ``self.probability``.

        Returns the inputs untouched when augmentation is disabled or the
        random draw says to skip this sample.
        """
        if not self.enable or random.random() > self.probability:
            return image,label,edge

        if self.techniques.get("HORIZONTAL_FLIP", False):
            image,label,edge = self.horizontal_flip(image, label, edge)

        if self.techniques.get("GAUSSIAN_BLUR", False):
            image, label, edge = self.gaussian_blur(image, label, edge)

        if self.techniques.get("MULTIPLY", False):
            image, label, edge = self.multiply(image, label, edge)

        if self.techniques.get("RANDOM_BRIGHTNESS", False):
            image, label, edge = self.random_brightness(image, label, edge)

        if self.techniques.get("RANDOM_CROP", False):
            image, label, edge = self.random_crop(image, label, edge)

        return image, label, edge

    def random_crop(self, image, label, edge):
        """Delegate to the dataset's ``rand_crop``.

        NOTE(review): ``BaseDataset.rand_crop`` returns the image in (H, W, C)
        layout, whereas ``horizontal_flip`` and ``gaussian_blur`` treat it as
        (C, H, W); confirm what the consumer of ``apply`` expects.
        """
        return self.dataset.rand_crop(image, label, edge)

    def horizontal_flip(self, image, label, edge):
        """Reverse the last axis of image, label and edge (returns views)."""
        flipped_image = image[:, :, ::-1]
        flipped_label = label[:, ::-1]
        flipped_edge = edge[:, ::-1]
        return flipped_image, flipped_label, flipped_edge

    def gaussian_blur(self, image, label, edge, kernel_size=5, show = False):
        """Blur a (C, H, W) image with a square Gaussian kernel; label and edge pass through.

        NOTE(review): with ``show=True`` the (C, H, W) arrays are handed to
        ``show_images`` without ``unnormalize=True``, so they are not
        transposed to channel-last before plotting; verify this path.
        """
        # OpenCV expects channel-last data.
        transposed_image = image.transpose(1, 2, 0)

        blurred_image = cv2.GaussianBlur(transposed_image, (kernel_size, kernel_size), 0)

        # Back to channel-first.
        blurred_image = blurred_image.transpose(2, 0, 1)

        if show:
            show_images(image, blurred_image)

        return blurred_image, label, edge

    def multiply(self, image, label, edge, factor_range=(0.8, 1.2), show = False):
        """Scale the image by a random factor drawn uniformly from ``factor_range``.

        The result is a new float32 array in the same value range as the
        input. The previous version first multiplied any image whose maximum
        was <= 1.0 by 255 and never scaled it back, so the output range
        jumped by a factor of 255 depending on image content.
        """
        factor = random.uniform(*factor_range)
        # astype copies, so the caller's array is never modified.
        image = image.astype(np.float32)

        multiplied_image = image * factor

        if show:
            show_images(image, multiplied_image)

        return multiplied_image, label, edge

    def random_brightness(self, image, label, edge, brightness_range=(-0.5, 0.5), show = False):
        """Add a random offset drawn uniformly from ``brightness_range`` to the image.

        No clipping is applied to the result.
        """
        brightness = np.float32(np.random.uniform(*brightness_range))
        brightened_image = image + brightness

        if show:
            show_images(image, brightened_image)
        return brightened_image, label, edge



class LoveDA(BaseDataset):
    """LoveDA segmentation dataset read from a list file of image/label paths.

    Each non-empty line of ``root + list_path`` must contain two
    whitespace-separated paths: the image and its label map.

    Args:
        root: prefix that is string-concatenated with ``list_path``.
        list_path: path of the list file, relative to ``root``.
        num_classes: stored on the instance.
        ignore_label, base_size, crop_size, scale_factor, mean, std:
            forwarded to ``BaseDataset``.
        multi_scale, flip, enable_augmentation, augmentation_probability,
        horizontal_flip, gaussian_blur, multiply, random_brightness,
        random_crop: stored on the instance; not read by this class.
        bd_dilate_size: kernel size used to dilate the edge map.
        pseudo_label: stored on the instance.
        weighted: if True, ``class_weights`` holds a fixed 8-entry weight
            tensor, otherwise None.
        transform: optional callable forwarded to ``gen_sample``.
    """

    def __init__(self,
                 root,
                 list_path,
                 num_classes=7,
                 multi_scale=False,
                 flip=False,
                 ignore_label=0,
                 base_size=1024,
                 crop_size=(512, 512),
                 scale_factor=16,  # multi-scale is the augmentation already provided
                 enable_augmentation=False,
                 augmentation_probability=0.5,
                 horizontal_flip=False,
                 gaussian_blur=False,
                 multiply=False,
                 random_brightness=False,
                 random_crop=True,
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225],
                 bd_dilate_size=4,
                 pseudo_label=False,
                 weighted=True,
                 transform=None):

        super(LoveDA, self).__init__(ignore_label, base_size,
                                     crop_size, scale_factor, mean, std)

        self.root = root
        self.list_path = list_path
        self.num_classes = num_classes
        self.multi_scale = multi_scale
        self.flip = flip
        self.ignore_label = ignore_label
        self.base_size = base_size
        self.crop_size = crop_size
        self.scale_factor = scale_factor
        self.enable_augmentation = enable_augmentation
        self.augmentation_probability = augmentation_probability
        self.horizontal_flip = horizontal_flip
        self.gaussian_blur = gaussian_blur
        self.multiply = multiply
        self.random_brightness = random_brightness
        self.random_crop = random_crop
        self.bd_dilate_size = bd_dilate_size

        # Read the list with a context manager so the file handle is closed,
        # and skip blank lines (they would break the unpacking in read_files).
        with open(root + list_path) as list_file:
            self.img_list = [line.strip().split() for line in list_file if line.strip()]
        self.files = self.read_files()
        self.color_list = [[0, 0, 0], [1, 1, 1], [2, 2, 2],
                            [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7]]
        self.class_weights = None
        if weighted:
            self.class_weights = torch.tensor([0.000000, 0.116411, 0.266041, 0.607794, 1.511413, 0.745507, 0.712438, 3.040396])

        self.pseudo_label = pseudo_label
        self.transform=transform

    def read_files(self):
        """Turn ``self.img_list`` into a list of ``{"img", "label", "name"}`` dicts.

        ``name`` is the label file's base name without extension.
        """
        files = []

        for item in self.img_list:
            image_path, label_path = item
            name = os.path.splitext(os.path.basename(label_path))[0]
            files.append({
                "img": image_path,
                "label": label_path,
                "name": name
            })

        return files

    def color2label(self, color_map):
        """Convert an (H, W, 3) colour map to a ``uint8`` label map.

        Pixels matching no entry of ``color_list`` keep ``ignore_label``.
        """
        label = np.ones(color_map.shape[:2]) * self.ignore_label
        for i, v in enumerate(self.color_list):
            label[(color_map == v).sum(2) == 3] = i

        return label.astype(np.uint8)

    def convert_label(self, label, inverse=False):
        """Remap label values in place through ``self.label_mapping``.

        NOTE(review): ``label_mapping`` is not assigned anywhere in this class
        or in ``BaseDataset`` as defined in this notebook, so calling this
        method raises AttributeError unless the attribute is set externally.
        """
        temp = label.copy()
        if inverse:
            for v, k in self.label_mapping.items():
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label

    def label2color(self, label):
        """Convert a label map to a ``uint8`` colour map using ``color_list``."""
        color_map = np.zeros(label.shape + (3,))
        for i, v in enumerate(self.color_list):
            color_map[label == i] = self.color_list[i]

        return color_map.astype(np.uint8)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            ``(image, label, edge, size, name)``: normalised (C, H, W) image,
            label map, edge map, the original image shape as an array, and the
            sample name.

        Raises:
            FileNotFoundError: if the image or label file cannot be read.
        """
        item = self.files[index]
        name = item["name"]

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(item["img"], cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {item['img']}")

        size = image.shape

        label = cv2.imread(item["label"], cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError(f"Could not read label file: {item['label']}")

        # NOTE(review): city=False keeps the channel order produced by
        # cv2.imread (BGR) while mean/std default to the ImageNet values,
        # which are usually quoted in RGB order; confirm this is intended.
        image, label, edge = self.gen_sample(image, label, edge_pad=False,
                                             edge_size=self.bd_dilate_size, city=False, transform=self.transform, show=False)

        return image.copy(), label.copy(), edge.copy(), np.array(size), name

    def single_scale_inference(self, config, model, image):
        """Run ``BaseDataset.inference`` on a single-scale input."""
        pred = self.inference(config, model, image)
        return pred

    def save_pred(self, preds, sv_path, name):
        """Save the arg-max class map of every prediction in the batch as a PNG.

        Files are written to ``sv_path`` as ``<name[i]>.png``.
        """
        preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
        for i in range(preds.shape[0]):
            pred = self.label2color(preds[i])
            save_img = Image.fromarray(pred)
            save_img.save(os.path.join(sv_path, name[i] + '.png'))

criterion.py

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from configs import config


class CrossEntropy(nn.Module):
    """Cross-entropy loss over one or several network outputs.

    Args:
        ignore_label: target value excluded from the loss.
        weight: optional per-class weight tensor.

    The combination weights are read from the global ``config``
    (``LOSS.BALANCE_WEIGHTS`` for multiple outputs, ``LOSS.SB_WEIGHTS`` for a
    single one).
    """

    def __init__(self, ignore_label=-1, weight=None):
        super().__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def _forward(self, score, target):
        """Loss of a single output."""
        return self.criterion(score, target)

    def forward(self, score, target):
        """Weighted sum of the per-output losses.

        Raises:
            ValueError: if the number of outputs matches neither the number of
                balance weights nor 1.
        """
        outputs = [score] if config.MODEL.NUM_OUTPUTS == 1 else score

        balance_weights = config.LOSS.BALANCE_WEIGHTS
        sb_weights = config.LOSS.SB_WEIGHTS

        if len(balance_weights) == len(outputs):
            total = 0
            for weight, output in zip(balance_weights, outputs):
                total = total + weight * self._forward(output, target)
            return total

        if len(outputs) == 1:
            return sb_weights * self._forward(outputs[0], target)

        raise ValueError("lengths of prediction and target are not identical!")




class OhemCrossEntropy(nn.Module):
    """Cross-entropy with online hard example mining on the last head.

    Args:
        ignore_label: target value excluded from the loss.
        thres: probability threshold; pixels whose true-class probability
            is below the effective threshold are kept as "hard".
        min_kept: index (into the ascending-sorted true-class
            probabilities) used to raise the threshold so that enough
            pixels are kept; clamped to at least 1.
        weight: optional per-class weights for ``nn.CrossEntropyLoss``.
    """

    def __init__(self, ignore_label=-1, thres=0.7,
                 min_kept=100000, weight=None):
        super().__init__()
        self.thresh = thres
        self.min_kept = max(1, min_kept)
        self.ignore_label = ignore_label
        # reduction='none' keeps the per-pixel losses needed for mining.
        self.criterion = nn.CrossEntropyLoss(
            weight=weight, ignore_index=ignore_label, reduction='none')

    def _ce_forward(self, score, target):
        """Plain per-pixel (unreduced) cross-entropy."""
        return self.criterion(score, target)

    def _ohem_forward(self, score, target, **kwargs):
        """Mean loss over the hard, non-ignored pixels.

        NOTE(review): when every pixel carries ``ignore_label`` the sorted
        probability tensor is empty and the threshold lookup is expected to
        raise; confirm callers handle that case.
        """
        probs = F.softmax(score, dim=1)
        flat_losses = self.criterion(score, target).contiguous().view(-1)
        valid = target.contiguous().view(-1) != self.ignore_label

        # gather() needs in-range indices, so ignored pixels temporarily
        # point at class 0; they are dropped again through ``valid``.
        safe_target = target.clone()
        safe_target[safe_target == self.ignore_label] = 0
        true_prob = probs.gather(1, safe_target.unsqueeze(1))
        sorted_prob, order = true_prob.contiguous().view(-1,)[valid].contiguous().sort()

        kth = min(self.min_kept, sorted_prob.numel() - 1)
        threshold = max(sorted_prob[kth], self.thresh)

        hard = sorted_prob < threshold
        return flat_losses[valid][order][hard].mean()

    def forward(self, score, target):
        """Weighted sum of head losses; only the last head uses OHEM.

        Raises:
            ValueError: if several heads are given and their count does not
                match the number of balance weights.
        """
        if not isinstance(score, (list, tuple)):
            score = [score]

        balance_weights = config.LOSS.BALANCE_WEIGHTS
        sb_weights = config.LOSS.SB_WEIGHTS
        n_weights = len(balance_weights)

        if n_weights == len(score):
            head_fns = [self._ce_forward] * (n_weights - 1) + [self._ohem_forward]
            total = 0
            for w, x, fn in zip(balance_weights, score, head_fns):
                total = total + w * fn(x, target)
            return total

        if len(score) == 1:
            return sb_weights * self._ohem_forward(score[0], target)

        raise ValueError("lengths of prediction and target are not identical!")


def weighted_bce(bd_pre, target):
    """Class-balanced binary cross-entropy on boundary logits.

    Pixels with target 1 are weighted by the fraction of target-0 pixels
    and vice versa; pixels with any other target value get weight 0.

    Args:
        bd_pre: 4-D logits tensor (unpacked as n, c, h, w).
        target: tensor with as many elements as ``bd_pre``.

    Returns:
        The mean weighted BCE-with-logits loss over all elements.
    """
    # The unpack doubles as a rank check: non-4-D input raises here.
    _n, _c, _h, _w = bd_pre.size()
    logits = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
    flat_target = target.view(1, -1)

    is_pos = flat_target == 1
    is_neg = flat_target == 0
    n_pos = is_pos.sum()
    n_neg = is_neg.sum()
    n_total = n_pos + n_neg

    weight = torch.zeros_like(logits)
    weight[is_pos] = n_neg * 1.0 / n_total
    weight[is_neg] = n_pos * 1.0 / n_total

    return F.binary_cross_entropy_with_logits(
        logits, flat_target, weight, reduction='mean')


class BondaryLoss(nn.Module):
    """Boundary loss: ``weighted_bce`` scaled by a constant coefficient.

    Args:
        coeff_bce: multiplier applied to the weighted BCE value.
    """

    def __init__(self, coeff_bce = 20.0):
        super().__init__()
        self.coeff_bce = coeff_bce

    def forward(self, bd_pre, bd_gt):
        """Return ``coeff_bce * weighted_bce(bd_pre, bd_gt)``."""
        return self.coeff_bce * weighted_bce(bd_pre, bd_gt)

utils.py

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import time
from pathlib import Path

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from configs import default
# Module-level working copy of the default configuration node (`default._C`),
# cloned so that edits made here do not touch the original object.
config = default._C.clone()
# Re-export of the helper from `configs.default`; elsewhere in this file it is
# called as `update_config(config, args)` with parsed CLI arguments.
update_config = default.update_config

class FullModel(nn.Module):
    """Bundle a segmentation network with its semantic and boundary losses.

    Args:
        model: network returning a sequence of output tensors; the last
            element is used as the boundary head and the second-to-last as
            the main semantic head.
        sem_loss: callable ``(list_of_scores, labels) -> loss``.
        bd_loss: callable ``(boundary_logits, boundary_gt) -> loss``.
    """

    def __init__(self, model, sem_loss, bd_loss):
        super(FullModel, self).__init__()
        self.model = model
        self.sem_loss = sem_loss
        self.bd_loss = bd_loss

    def pixel_acc(self, pred, label):
        """Fraction of pixels with ``label >= 0`` predicted correctly.

        NOTE(review): validity is tested as ``label >= 0``, so a
        non-negative ignore label is still counted as a valid pixel here;
        confirm this is intended.
        """
        _, preds = torch.max(pred, dim=1)
        valid = (label >= 0).long()
        acc_sum = torch.sum(valid * (preds == label).long())
        pixel_sum = torch.sum(valid)
        # Epsilon guards against division by zero when nothing is valid.
        acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
        return acc

    def forward(self, inputs, labels, bd_gt, *args, **kwargs):
        """Run the network and compute the combined training loss.

        Returns:
            A 4-tuple ``(loss, outputs, acc, [loss_s, loss_b])`` where
            ``loss`` has a leading dimension of size 1, ``outputs`` holds
            every head except the boundary one (resized to the label
            resolution when needed) and ``acc`` is the pixel accuracy of
            the main semantic head.
        """
        outputs = self.model(inputs, *args, **kwargs)

        h, w = labels.size(1), labels.size(2)
        ph, pw = outputs[0].size(2), outputs[0].size(3)
        if ph != h or pw != w:
            # Build a new list instead of assigning in place so that a
            # network returning a tuple is supported as well.
            outputs = [
                F.interpolate(out, size=(h, w), mode='bilinear',
                              align_corners=config.MODEL.ALIGN_CORNERS)
                for out in outputs
            ]

        acc = self.pixel_acc(outputs[-2], labels)
        loss_s = self.sem_loss(outputs[:-1], labels)
        loss_b = self.bd_loss(outputs[-1], bd_gt)

        # Semantic loss restricted to pixels the boundary head is confident
        # about (sigmoid > 0.8); everything else becomes the ignore label.
        filler = torch.ones_like(labels) * config.TRAIN.IGNORE_LABEL
        try:
            bd_label = torch.where(torch.sigmoid(outputs[-1][:, 0, :, :]) > 0.8, labels, filler) # 0.7
            loss_sb = self.sem_loss([outputs[-2]], bd_label)
        except Exception:
            # Best-effort fallback kept from the original code; only
            # ordinary exceptions are caught so that KeyboardInterrupt and
            # SystemExit still propagate.
            print("Error in loss computation")
            loss_sb = self.sem_loss([outputs[-2]], labels)
        loss = loss_s + loss_b + loss_sb

        # outputs[:-1] is a sequence of tensors (all heads but the boundary one)
        return torch.unsqueeze(loss, 0), outputs[:-1], acc, [loss_s, loss_b]


class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.initialized = False
        self.val = self.avg = self.sum = self.count = None

    def initialize(self, val, weight):
        """Seed the meter with its first observation."""
        self.val, self.avg = val, val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record ``val`` with the given ``weight`` (first call initialises)."""
        if self.initialized:
            self.add(val, weight)
        else:
            self.initialize(val, weight)

    def add(self, val, weight):
        """Accumulate a further observation and refresh the average."""
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        """Most recently recorded value."""
        return self.val

    def average(self):
        """Weighted average of everything recorded so far."""
        return self.avg

def create_logger(cfg, cfg_name, phase='train'):
    """Create the output/log directories and configure the root logger.

    The run folder is ``OUTPUT_DIR/<dataset>/<cfg stem>/<folder_name>``,
    where ``folder_name`` is ``no_aug``/``aug`` depending on the enabled
    augmentation techniques, overridden by ``dacs`` or ``gan`` when those
    modes are enabled, and suffixed with one tag per enabled technique
    when ``TRAIN.AUGMENTATION.ENABLE`` is set.

    Args:
        cfg: configuration node exposing ``OUTPUT_DIR``, ``LOG_DIR``,
            ``DATASET.DATASET``, ``MODEL.NAME`` and the ``TRAIN`` flags
            read below.
        cfg_name: path of the config file; only its stem is used.
        phase: tag appended to the log file name.

    Returns:
        ``(logger, final_output_dir, tensorboard_log_dir)`` with both
        directories given as strings.
    """
    root_output_dir = Path(cfg.OUTPUT_DIR)

    techniques = cfg.TRAIN.AUGMENTATION.TECHNIQUES
    # (flag, suffix) pairs in the order the suffixes are appended.
    technique_suffixes = (
        (techniques.HORIZONTAL_FLIP, "_hf"),
        (techniques.GAUSSIAN_BLUR, "_gb"),
        (techniques.RANDOM_CROP, "_rc"),
        (techniques.COLOR_JITTER, "_cj"),
        (techniques.GAUSSIAN_NOISE, "_gn"),
    )

    folder_name = "aug" if any(flag for flag, _ in technique_suffixes) else "no_aug"

    if cfg.TRAIN.DACS.ENABLE:
        folder_name = "dacs"

    if cfg.TRAIN.GAN.ENABLE:
        folder_name = "gan"

    if cfg.TRAIN.AUGMENTATION.ENABLE:
        folder_name += "".join(suffix for flag, suffix in technique_suffixes if flag)

    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        # parents=True so a nested OUTPUT_DIR does not fail on a missing parent.
        root_output_dir.mkdir(parents=True, exist_ok=True)

    dataset = cfg.DATASET.DATASET
    model = cfg.MODEL.NAME
    cfg_name = os.path.basename(cfg_name).split('.')[0]

    final_output_dir = root_output_dir / dataset / cfg_name / folder_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    # NOTE: logging.basicConfig does nothing when the root logger already
    # has handlers, so the file handler is only attached on the first
    # configuration of the process.
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Add the console handler only once; re-running this function (e.g. when
    # a notebook cell is executed again) must not duplicate every log line.
    # FileHandler subclasses StreamHandler, hence the exact type check.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        logger.addHandler(logging.StreamHandler())

    tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
            (cfg_name + '_' + time_str)
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return logger, str(final_output_dir), str(tensorboard_log_dir)

def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
    """
    Calculate the confusion matrix for the given labels and predictions.

    Args:
        label: ground-truth tensor; cropped to ``size[-2]`` x ``size[-1]``
            on its last two axes.
        pred: score tensor with the class axis at dim 1.
        size: sequence whose last two entries give the crop height/width.
        num_class: number of classes (matrix is num_class x num_class).
        ignore: ground-truth value excluded from the count.

    Returns:
        Float array where entry ``[gt, predicted]`` is the pixel count.
        Pairs whose flattened index falls outside the matrix are dropped.
    """
    scores = pred.cpu().numpy().transpose(0, 2, 3, 1)
    seg_pred = np.asarray(np.argmax(scores, axis=3), dtype=np.uint8)
    seg_gt = np.asarray(
        label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=int)

    keep = seg_gt != ignore
    # Encode every (gt, pred) pair as one flat cell index gt * C + pred.
    flat_index = (seg_gt[keep] * num_class + seg_pred[keep]).astype('int32')

    n_cells = num_class * num_class
    counts = np.bincount(flat_index, minlength=n_cells)[:n_cells]
    return counts.reshape(num_class, num_class).astype(np.float64)

def adjust_learning_rate(optimizer, base_lr, max_iters,
        cur_iters, power=0.9, nbb_mult=10):
    """Apply polynomial decay to the optimizer's learning rate.

    ``lr = base_lr * (1 - cur_iters / max_iters) ** power``. The result is
    written to the first parameter group; when there are exactly two
    groups, the second one receives ``lr * nbb_mult``.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts).
        base_lr: initial learning rate.
        max_iters: total number of iterations of the schedule.
        cur_iters: current iteration index.
        power: exponent of the decay.
        nbb_mult: multiplier for the second parameter group.

    Returns:
        The learning rate assigned to the first parameter group.
    """
    # Clamp at zero: once cur_iters exceeds max_iters the base is negative,
    # and a negative float raised to a fractional power is a complex number
    # in Python 3, which is not a usable learning rate.
    remaining = max(0.0, 1 - float(cur_iters) / max_iters)
    lr = base_lr * (remaining ** power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) == 2:
        optimizer.param_groups[1]['lr'] = lr * nbb_mult
    return lr

classmix.py

In [None]:
import argparse
import torch
import random
import torch.nn as nn
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import numpy as np
import cv2
import matplotlib.patches as mpatches

def _denormalize_to_hwc(batch, mean, std):
    """Return the first image of ``batch`` un-normalized as an H×W×C array."""
    img = batch[0].cpu().clone()  # clone so the caller's tensor is untouched
    for channel, m, s in zip(img, mean, std):
        channel.mul_(s).add_(m)
    return np.transpose(img.numpy(), (1, 2, 0))


def show_mixed_visualization(x_s, y_s, x_t, y_t, x_mixed, y_mixed, bd_mixed, ignore_label=0):
    """
    Displays source, target, mixed images and labels alongside the edge map,
    with a legend whose colors exactly match the label images.

    Only the first element of every batch is shown.

    Args:
        x_s, x_t, x_mixed: Tensors [B,C,H,W], normalized (ImageNet mean/std
            is assumed when undoing the normalization).
        y_s, y_t, y_mixed: Tensors [B,H,W] integer labels.
        bd_mixed: Tensor [B,H,W] edge/boundary map.
        ignore_label: integer value to ignore in the edge-map visualization.
    """
    class_names = {
        0: "no-data",
        1: "background",
        2: "building",
        3: "road",
        4: "water",
        5: "barren",
        6: "forest",
        7: "agriculture",
    }

    # Un-normalize images (assumes ImageNet mean/std) and move to H×W×C
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std  = [0.229, 0.224, 0.225]
    x_s = _denormalize_to_hwc(x_s, imagenet_mean, imagenet_std)
    x_t = _denormalize_to_hwc(x_t, imagenet_mean, imagenet_std)
    x_mixed = _denormalize_to_hwc(x_mixed, imagenet_mean, imagenet_std)

    y_s = y_s[0].cpu().numpy()
    y_t = y_t[0].cpu().numpy()
    y_mixed = y_mixed[0].cpu().numpy()
    edge_map = bd_mixed[0].cpu().numpy()

    # Prepare edge-map visualization: ignored pixels are zeroed, every
    # remaining positive value is drawn white.
    edge_vis = (edge_map != ignore_label).astype(np.uint8) * edge_map
    edge_vis = (edge_vis > 0).astype(np.uint8) * 255

    # Discrete colormap with exactly len(class_names) colors. pyplot.get_cmap
    # is used because the matplotlib.cm.get_cmap accessor is no longer
    # available in recent Matplotlib releases.
    num_classes = len(class_names)
    cmap = plt.get_cmap("tab20", num_classes)

    fig, axs = plt.subplots(2, 4, figsize=(16, 6))

    # Row 1: images + edge map
    top_row = ((x_s, "Source Image"), (x_t, "Target Image"), (x_mixed, "Mixed Image"))
    for col, (img, title) in enumerate(top_row):
        axs[0, col].imshow(img)
        axs[0, col].set_title(title)
        axs[0, col].axis("off")
    axs[0, 3].imshow(edge_vis, cmap='gray', vmin=0, vmax=255)
    axs[0, 3].set_title("Edge Map")
    axs[0, 3].axis("off")

    # Row 2: labels using the same cmap with explicit vmin/vmax
    bottom_row = ((y_s, "Source Label"), (y_t, "Pseudo Label"), (y_mixed, "Mixed Label"))
    for col, (lbl, title) in enumerate(bottom_row):
        axs[1, col].imshow(lbl, cmap=cmap, vmin=0, vmax=num_classes-1)
        axs[1, col].set_title(title)
        axs[1, col].axis("off")

    # Legend: one patch per class, using the exact same cmap
    handles = [
        mpatches.Patch(color=cmap(i), label=class_names[i])
        for i in sorted(class_names)
    ]
    axs[1, 3].legend(handles=handles, loc="center", ncol=2, frameon=False, fontsize="small")
    axs[1, 3].axis("off")

    plt.tight_layout()
    plt.show()
    
    

def classmix(x1, y1, x2, y2, verbose=False, deterministic=False):
    """Paste the pixels of half of the classes in ``y1`` from (x1, y1) onto (x2, y2).

    Args:
        x1, y1: images [B,C,H,W] and labels [B,H,W] the pixels are taken from.
        x2, y2: images and labels that receive the pasted pixels.
        verbose: print the selected class ids.
        deterministic: take the first half of the sorted candidate classes
            instead of a random half; the global ``random`` state is left
            untouched in this mode.

    Returns:
        ``(x_mix, y_mix, mask)`` where ``mask`` is a boolean [B,H,W] tensor
        that is True where the pixel comes from (x1, y1).
    """
    # Labels 0 and 1 are never candidates for pasting (the original comment
    # motivates this with background occupying too much of the image).
    classes = [c for c in y1.unique().tolist() if c not in (0, 1)]
    n_selected = len(classes) // 2
    if deterministic:
        selected = classes[:n_selected]
    else:
        selected = random.sample(classes, n_selected)
    if verbose:
        print(f"Classes from the source domain: {selected}")

    mask = torch.zeros_like(y1, dtype=torch.bool)
    for c in selected:
        mask |= (y1 == c)

    # The image mask needs a channel axis to broadcast over C.
    x_mix = torch.where(mask.unsqueeze(1), x1, x2)
    y_mix = torch.where(mask, y1, y2)
    return x_mix, y_mix, mask



def generate_confident_edge_map(y_mixed, source_mask, confidence_mask,
                                edge_size=4, edge_pad=False, ignore_label=255):
    """
    Build a boundary map from ``y_mixed`` restricted to trusted pixels:
     - all source pixels (source_mask is True)
     - target pixels with confidence_mask True
    Every other pixel is set to ``ignore_label``.

    Args:
        y_mixed: label tensor of shape [B,H,W].
        source_mask, confidence_mask: per-pixel masks indexed like y_mixed.
        edge_size: side of the square dilation kernel; also the border
            width cleared when ``edge_pad`` is set.
        edge_pad: zero out a border of ``edge_size`` pixels in the raw
            edge map before dilation.
        ignore_label: value written outside the trusted region.

    Returns:
        Long tensor [B,H,W] on the device of ``y_mixed``.
    """
    batch_size, _height, _width = y_mixed.shape
    dilate_kernel = np.ones((edge_size, edge_size), np.uint8)
    border = edge_size

    per_image = []
    for idx in range(batch_size):
        labels_np = y_mixed[idx].cpu().numpy().astype(np.uint8)
        from_source = source_mask[idx].cpu().numpy().astype(bool)
        confident = confidence_mask[idx].cpu().numpy().astype(bool)

        # Untrusted labels are zeroed before edge detection.
        trusted = from_source | confident
        trusted_labels = np.where(trusted, labels_np, 0).astype(np.uint8)

        edges = cv2.Canny(trusted_labels, 0.1, 0.2)
        if edge_pad:
            inner = edges[border:-border, border:-border]
            edges = np.pad(inner, ((border, border), (border, border)), mode='constant')

        # Thicken the edges, then binarise to {0, 1}.
        dilated = cv2.dilate(edges, dilate_kernel, iterations=1)
        edges = (dilated > 50).astype(np.uint8)

        edges[~trusted] = ignore_label

        per_image.append(torch.from_numpy(edges).long().unsqueeze(0))

    return torch.cat(per_image, dim=0).to(y_mixed.device)

def parse_args():
    """Parse the command line and fold it into the module-level ``config``.

    Recognised arguments: ``--cfg`` (config file path), ``--seed`` (int)
    and a trailing ``opts`` remainder. ``update_config(config, args)`` is
    called before the parsed namespace is returned.
    """
    cli = argparse.ArgumentParser(description='Train segmentation network')
    cli.add_argument('--cfg', type=str, default="configs/loveda/pidnet_small_loveda.yaml")
    cli.add_argument('--seed', default=304, type=int)
    cli.add_argument('opts', nargs=argparse.REMAINDER, default=None)

    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed

function.py

In [None]:
import logging
import os
import time
import numpy as np
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision.transforms.functional as TF
import numpy as np

import torch
from torch.nn import functional as F
import torch.nn as nn

# Define the transform once (outside the training loop ideally).
# Albumentations pipeline holding a single ColorJitter that is applied with
# probability 0.5; no jitter ranges are passed, so the library defaults apply.
color_jitter = A.Compose([
    A.ColorJitter(
        p=0.5
    ),
])

# Per-channel mean/std reshaped to (3, 1, 1) so they broadcast over a
# (C, H, W) image tensor. Created on CPU; callers move them as needed.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def apply_color_jitter_tensor(img_tensor):
    """Apply the module-level ``color_jitter`` to one normalized image tensor.

    The (C, H, W) tensor is de-normalized with IMAGENET_MEAN/STD, converted
    to an 8-bit H×W×C array for the augmentation, then converted back and
    re-normalized.

    Returns:
        Float tensor (C, H, W) on the same device as the input.
    """
    device = img_tensor.device
    mean = IMAGENET_MEAN.to(device)
    std = IMAGENET_STD.to(device)

    # Undo the normalization and quantise to uint8 for the augmentation.
    denorm = img_tensor * std + mean
    hwc = denorm.cpu().numpy().transpose(1, 2, 0)
    hwc_uint8 = np.clip(hwc * 255.0, 0, 255).astype(np.uint8)

    jittered = color_jitter(image=hwc_uint8)['image']

    # Back to CHW float; the re-normalization runs on CPU with the
    # module-level constants before the result returns to ``device``.
    chw = (jittered.astype(np.float32) / 255.0).transpose(2, 0, 1)
    renorm = (torch.tensor(chw, dtype=torch.float) - IMAGENET_MEAN) / IMAGENET_STD
    return renorm.to(device)

def update_ema_variables(ema_model, model, alpha_teacher, iteration):
    """Blend ``model``'s parameters into ``ema_model`` (exponential moving average).

    ``ema = alpha * ema + (1 - alpha) * param`` for every parameter pair,
    where ``alpha = min(1 - 1 / (iteration + 1), alpha_teacher)``.

    Args:
        ema_model: module updated in place (parameters only).
        model: module providing the current parameters.
        alpha_teacher: upper bound of the EMA decay.
        iteration: zero-based global iteration index.

    Returns:
        ``ema_model`` (the same object that was passed in).
    """
    # Use the "true" average until the exponential average is more correct
    alpha_teacher = min(1 - 1 / (iteration + 1), alpha_teacher)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # In-place update on .data: works for parameters of any rank,
        # including 0-dim ones that cannot be sliced with [:].
        ema_param.data.mul_(alpha_teacher).add_(param.data, alpha=1 - alpha_teacher)
    return ema_model


def train(config, epoch, num_epoch, epoch_iters, base_lr,
          num_iters, trainloader, optimizer, model, writer_dict, deeplab_model = None, ema_model=None, targetloader=None):
    """Train ``model`` for one epoch, optionally with DACS-style domain mixing.

    When ``config.TRAIN.DACS.ENABLE`` is set, every source batch is paired
    with a target batch: pseudo-labels for the target images are produced
    (by ``deeplab_model``, else ``ema_model``, else ``model`` itself),
    filtered by confidence, mixed with the source batch through
    ``classmix`` and the mixed loss is added with weight
    ``(epoch + 1) * 0.01``. Otherwise a plain supervised step is run.

    Args:
        config: configuration node (TRAIN.DACS, TRAIN.IGNORE_LABEL,
            MODEL.ALIGN_CORNERS and PRINT_FREQ are read).
        epoch, num_epoch: current epoch index and total, used for logging,
            the LR schedule offset and the mixing weight.
        epoch_iters: iterations per epoch.
        base_lr, num_iters: base learning rate and total iteration count of
            the polynomial schedule.
        trainloader: iterable of 5-tuples whose first three items are
            images, labels and boundary maps.
        optimizer: optimizer stepped once per batch.
        model: called as ``model(images, labels, bd_gts)``; in the DACS
            branch the raw network is reached through ``.module.model``.
        writer_dict: dict with ``'writer'`` (exposing ``add_scalar``) and
            ``'train_global_steps'`` (incremented by one per call).
        deeplab_model: optional external pseudo-labeller.
        ema_model: optional teacher updated by EMA after every step.
        targetloader: target-domain loader; required for DACS.

    Returns:
        The average loss over the epoch.

    Raises:
        ValueError: if DACS is enabled and ``targetloader`` is None.
    """
    model.train()
    show_dacs = False  # debug switch: set to True to plot every mixed batch
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_acc  = AverageMeter()
    avg_sem_loss = AverageMeter()
    avg_bce_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch*epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']

    use_dacs = config.TRAIN.DACS.ENABLE and epoch > -1
    if use_dacs and targetloader is None:
        # Fail early with a clear message instead of hitting an unbound
        # iterator on the first batch.
        raise ValueError("DACS training requires a targetloader")
    target_iter = iter(targetloader) if targetloader is not None else None

    lambda_weight = 0
    for i_iter, batch in enumerate(trainloader, 0):
        if use_dacs:
            # === SOURCE BATCH ===
            x_s, y_s, bd_s, _, _ = batch
            x_s, y_s, bd_s = x_s.cuda(), y_s.long().cuda(), bd_s.float().cuda()

            # === TARGET BATCH === (restart the target loader when exhausted)
            try:
                x_t, y_t, _, _, _ = next(target_iter)
            except StopIteration:
                target_iter = iter(targetloader)
                x_t, y_t, _, _, _ = next(target_iter)
            x_t = x_t.cuda()

            # === PSEUDO-LABELS ===
            with torch.no_grad():
                if deeplab_model is not None:
                    x_t_resized = torch.nn.functional.interpolate(x_t, size=(224, 224), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS)
                    logits_t = deeplab_model(x_t_resized)

                elif ema_model is not None:
                    logits_t = ema_model.module.model(x_t)[-2]

                else:
                    logits_t = model.module.model(x_t)[-2]

                logits_t = torch.nn.functional.interpolate(
                    logits_t, size=x_t.shape[2:], mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
                )

                pseudo_t = torch.argmax(logits_t, dim=1)
                if deeplab_model is not None:
                    # NOTE(review): shifts the external model's class ids by
                    # one, presumably to match this dataset's label
                    # numbering; confirm.
                    pseudo_t += 1
                conf_t = torch.softmax(logits_t, dim=1).max(dim=1)[0]

            # === CONFIDENCE USAGE ===
            # confidence_mask marks the pixels whose pseudo-label confidence
            # exceeds the threshold.
            confidence_mask = conf_t > config.TRAIN.DACS.THRESHOLD

            # Keep the pseudo-label where confident, otherwise write the
            # ignore label.
            pseudo_t_filtered = torch.where(
               confidence_mask,
               pseudo_t,
               torch.tensor(config.TRAIN.IGNORE_LABEL, device=pseudo_t.device)
            )

            # Apply classmix
            x_mixed, y_mixed, source_mask = classmix(x_s, y_s, x_t, pseudo_t_filtered)

            # Generate edges from the mixed labels
            bd_mixed = generate_confident_edge_map(
                y_mixed,                       # mixed labels with ignore where conf low
                source_mask,                   # True where the pixel comes from the source
                confidence_mask,               # True where pseudo-label valid
                edge_size=3, edge_pad=True,
                ignore_label=config.TRAIN.IGNORE_LABEL
            )

            # Colour-jitter the source and the mixed images. The helper
            # returns each image on the device it received it on, so no
            # explicit device round trip is needed.
            x_s = torch.stack([
                apply_color_jitter_tensor(img) for img in x_s
            ])
            x_mixed = torch.stack([
                apply_color_jitter_tensor(img) for img in x_mixed
            ])

            if show_dacs:
                show_mixed_visualization(x_s, y_s, x_t, pseudo_t_filtered, x_mixed, y_mixed, bd_mixed)

            x_mixed = x_mixed.cuda()
            y_mixed = y_mixed.long().cuda()
            bd_mixed = bd_mixed.float().cuda()

            # === FORWARD PASSES ===
            losses_src, _, acc_src, loss_list_src = model(x_s, y_s, bd_s)
            losses_mix, _, acc_mix, loss_list_mix = model(x_mixed, y_mixed, bd_mixed)

            loss_src = losses_src.mean()
            loss_mix = losses_mix.mean()

            # === COMBINE LOSSES ===
            # The mixed loss is ramped up linearly with the epoch index.
            lambda_weight = (epoch+1) * 0.01
            loss = loss_src + lambda_weight * loss_mix
            # Reduce with .mean() like the supervised branch does, so that
            # acc.item() below also works when acc holds several values.
            acc = (acc_src.mean() + acc_mix.mean()) / 2

            sem_loss = loss_list_src[0] + lambda_weight * loss_list_mix[0]
            bce_loss = loss_list_src[1] + lambda_weight * loss_list_mix[1]
        else:
            images, labels, bd_gts, _, _ = batch
            images = images.cuda()
            labels = labels.long().cuda()
            bd_gts = bd_gts.float().cuda()

            losses, _, acc, loss_list = model(images, labels, bd_gts)
            loss = losses.mean()
            acc  = acc.mean()
            sem_loss = loss_list[0]
            bce_loss = loss_list[1]

        model.zero_grad()
        loss.backward()
        optimizer.step()

        if use_dacs and ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model = ema_model, model = model, alpha_teacher=alpha_teacher, iteration=i_iter+cur_iters)

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(loss.item())
        ave_acc.update(acc.item())
        avg_sem_loss.update(sem_loss.mean().item())
        avg_bce_loss.update(bce_loss.mean().item())

        lr = adjust_learning_rate(optimizer,
                                  base_lr,
                                  num_iters,
                                  i_iter+cur_iters)

        if i_iter % config.PRINT_FREQ == 0:
            # "SB loss" is reported as the remainder of the total loss after
            # subtracting the semantic and boundary parts.
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {}, Loss: {:.6f}, Lambda: {:.2f}, Acc:{:.6f}, Semantic loss: {:.6f}, BCE loss: {:.6f}, SB loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), [x['lr'] for x in optimizer.param_groups], ave_loss.average(), lambda_weight,
                      ave_acc.average(), avg_sem_loss.average(), avg_bce_loss.average(),ave_loss.average()-avg_sem_loss.average()-avg_bce_loss.average())
            logging.info(msg)

    writer.add_scalar('train_loss', ave_loss.average(), global_steps)
    writer_dict['train_global_steps'] = global_steps + 1

    # Return the average loss of the epoch
    return ave_loss.average()

def validate(config, testloader, model, writer_dict):
    """Evaluate ``model`` on ``testloader`` and log loss and mIoU.

    One confusion matrix is accumulated per output head
    (``config.MODEL.NUM_OUTPUTS`` of them). The IoU values that are
    returned and written to the summary writer are those of the last head.

    Returns:
        ``(average_loss, mean_IoU, IoU_array)``.
    """
    model.eval()
    loss_meter = AverageMeter()
    num_outputs = config.MODEL.NUM_OUTPUTS
    confusion = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES, num_outputs))

    with torch.no_grad():
        for batch_idx, (image, label, bd_gts, _, _) in enumerate(testloader):
            label_size = label.size()
            image = image.cuda()
            label = label.long().cuda()
            bd_gts = bd_gts.float().cuda()

            losses, heads, _, _ = model(image, label, bd_gts)
            if not isinstance(heads, (list, tuple)):
                heads = [heads]

            for head_idx, logits in enumerate(heads):
                # Bring every head to the label resolution before scoring.
                resized = F.interpolate(
                    input=logits, size=label_size[-2:],
                    mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
                )
                confusion[..., head_idx] += get_confusion_matrix(
                    label,
                    resized,
                    label_size,
                    config.DATASET.NUM_CLASSES,
                    config.TRAIN.IGNORE_LABEL
                )

            if batch_idx % 10 == 0:
                print(batch_idx)

            loss_meter.update(losses.mean().item())

    for head_idx in range(num_outputs):
        head_confusion = confusion[..., head_idx]
        gt_totals = head_confusion.sum(1)
        pred_totals = head_confusion.sum(0)
        true_pos = np.diag(head_confusion)
        # max(1, union) avoids dividing by zero for absent classes.
        IoU_array = true_pos / np.maximum(1.0, gt_totals + pred_totals - true_pos)
        mean_IoU = IoU_array.mean()

        logging.info('{} {} {}'.format(head_idx, IoU_array, mean_IoU))

    summary_writer = writer_dict['writer']
    step = writer_dict['valid_global_steps']
    summary_writer.add_scalar('valid_loss', loss_meter.average(), step)
    summary_writer.add_scalar('valid_mIoU', mean_IoU, step)
    writer_dict['valid_global_steps'] = step + 1
    return loss_meter.average(), mean_IoU, IoU_array


def testval(config, test_dataset, testloader, model,
            sv_dir='./', sv_pred=False):
    """Evaluate ``model`` with the dataset's single-scale inference.

    Args:
        config: configuration node (DATASET.NUM_CLASSES,
            MODEL.ALIGN_CORNERS and TRAIN.IGNORE_LABEL are read).
        test_dataset: provides ``single_scale_inference`` and ``save_pred``.
        testloader: iterable of 5-tuples ``(image, label, _, _, name)``.
        model: network put in eval mode.
        sv_dir: base directory for saved predictions.
        sv_pred: when True, predictions are written to
            ``<sv_dir>/val_results``.

    Returns:
        ``(mean_IoU, IoU_array, pixel_acc, mean_acc)``.
    """
    model.eval()
    n_cls = config.DATASET.NUM_CLASSES
    confusion = np.zeros((n_cls, n_cls))

    with torch.no_grad():
        for index, batch in enumerate(tqdm(testloader)):
            image, label, _, _, name = batch
            size = label.size()
            pred = test_dataset.single_scale_inference(config, model, image.cuda())

            # Resize only when the spatial size differs from the label's.
            if (pred.size()[-2], pred.size()[-1]) != (size[-2], size[-1]):
                pred = F.interpolate(
                    pred, size[-2:],
                    mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
                )

            confusion += get_confusion_matrix(
                label,
                pred,
                size,
                config.DATASET.NUM_CLASSES,
                config.TRAIN.IGNORE_LABEL)

            if sv_pred:
                sv_path = os.path.join(sv_dir, 'val_results')
                if not os.path.exists(sv_path):
                    os.mkdir(sv_path)
                test_dataset.save_pred(pred, sv_path, name)

            if index % 100 == 0:
                # Periodic progress report with the running mIoU.
                logging.info('processing: %d images' % index)
                gt_totals = confusion.sum(1)
                pred_totals = confusion.sum(0)
                true_pos = np.diag(confusion)
                running_iou = true_pos / np.maximum(1.0, gt_totals + pred_totals - true_pos)
                logging.info('mIoU: %.4f' % (running_iou.mean()))

    gt_totals = confusion.sum(1)
    pred_totals = confusion.sum(0)
    true_pos = np.diag(confusion)
    pixel_acc = true_pos.sum()/gt_totals.sum()
    mean_acc = (true_pos/np.maximum(1.0, gt_totals)).mean()
    IoU_array = true_pos / np.maximum(1.0, gt_totals + pred_totals - true_pos)
    mean_IoU = IoU_array.mean()

    return mean_IoU, IoU_array, pixel_acc, mean_acc


def test(config, test_dataset, testloader, model,
         sv_dir='./', sv_pred=True):
    model.eval()
    with torch.no_grad():
        for _, batch in enumerate(tqdm(testloader)):
            image, size, name = batch
            size = size[0]
            pred = test_dataset.single_scale_inference(
                config,
                model,
                image.cuda())

            if pred.size()[-2] != size[0] or pred.size()[-1] != size[1]:
                pred = F.interpolate(
                    pred, size[-2:],
                    mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
                )

            if sv_pred:
                sv_path = os.path.join(sv_dir,'test_results')
                if not os.path.exists(sv_path):
                    os.mkdir(sv_path)
                test_dataset.save_pred(pred, sv_path, name)

def train_adv(config, epoch, num_epoch, epoch_iters, base_lr,
          num_iters, trainloader, targetloader, optimizer_G, optimizer_D,
          model, discriminator, writer_dict, lambda_adv=0.001, iter_size=4):
    """Run one epoch of adversarial training with a single discriminator.

    For every (source, target) batch pair, gradients are accumulated over
    ``iter_size`` passes on that same pair before both optimizers step:
      1. supervised segmentation loss on the source batch;
      2. adversarial loss pushing the discriminator to label the target
         predictions as source (weighted by ``lambda_adv``);
      3. discriminator loss on detached predictions
         (source -> 1, target -> 0).

    Only the generator's learning rate is adjusted by the schedule.
    Accuracy, semantic-loss and BCE-loss meters are fed zeros.
    """
    model.train()
    discriminator.train()

    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_acc = AverageMeter()
    avg_sem_loss = AverageMeter()
    avg_bce_loss = AverageMeter()

    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']

    # Pick the best available device.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Stateless criterion shared by the generator and discriminator terms.
    bce_logits = nn.BCEWithLogitsLoss()

    for i_iter, (batch_source, batch_target) in enumerate(zip(trainloader, targetloader)):

        optimizer_G.zero_grad()
        optimizer_D.zero_grad()

        cumulative_loss_G = 0.0
        cumulative_loss_D = 0.0

        for sub_iter in range(iter_size):

            # ---- Generator step: the discriminator must not collect grads
            for d_param in discriminator.parameters():
                d_param.requires_grad = False

            images_source, labels, bd_gts, _, _ = batch_source
            images_target, _, _, _, _ = batch_target

            images_source = images_source.to(device)
            images_target = images_target.to(device)
            labels = labels.long().to(device)
            bd_gts = bd_gts.float().to(device)

            # 1. Supervised segmentation loss on the source domain
            seg_loss, output_source, _, _ = model(images_source, labels, bd_gts)
            seg_loss = seg_loss.squeeze().mean() / iter_size
            cumulative_loss_G += seg_loss.item()
            seg_loss.backward()

            # 2. Adversarial loss on the target domain
            output_target = model(images_target, labels, bd_gts)[1]
            target_scores = discriminator(F.softmax(output_target[-2], dim=1))
            adv_loss = bce_logits(target_scores, torch.ones_like(target_scores))
            adv_loss = (adv_loss * lambda_adv) / iter_size  # normalise the loss
            cumulative_loss_G += adv_loss.item()
            adv_loss.backward()

            # Cut the graph so the discriminator update leaves the generator alone.
            output_source = [t.detach() for t in output_source]
            output_target = [t.detach() for t in output_target]

            # ---- Discriminator step
            for d_param in discriminator.parameters():
                d_param.requires_grad = True

            source_scores = discriminator(F.softmax(output_source[-2], dim=1)) #TODO: check -2
            target_scores = discriminator(F.softmax(output_target[-2], dim=1))
            d_loss_source = bce_logits(source_scores, torch.ones_like(source_scores))
            d_loss_target = bce_logits(target_scores, torch.zeros_like(target_scores))
            d_loss = (d_loss_source + d_loss_target) / 2
            d_loss = d_loss / iter_size  # normalise the loss
            cumulative_loss_D += d_loss.item()
            d_loss.backward()

        # Update the weights after all sub-iterations
        optimizer_G.step()
        optimizer_D.step()

        # Log
        batch_time.update(time.time() - tic)
        tic = time.time()

        ave_loss.update(cumulative_loss_G)
        ave_acc.update(0)
        avg_sem_loss.update(0)
        avg_bce_loss.update(0)

        lr = adjust_learning_rate(optimizer_G, base_lr, num_iters, i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0:
            msg = ('Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, lr: {}, '
                   'Loss_G: {:.6f}, Loss_D: {:.6f}, Acc: {:.6f}, Semantic Loss: {:.6f}, BCE Loss: {:.6f}').format(
                      epoch, num_epoch, i_iter, epoch_iters, batch_time.average(),
                      [x['lr'] for x in optimizer_G.param_groups], ave_loss.average(), cumulative_loss_D,
                      ave_acc.average(), avg_sem_loss.average(), avg_bce_loss.average()
                  )
            logging.info(msg)

    writer.add_scalar('train_loss_G', ave_loss.average(), global_steps)
    # Note: this is the discriminator loss of the last batch, not an average.
    writer.add_scalar('train_loss_D', cumulative_loss_D, global_steps)
    writer_dict['train_global_steps'] = global_steps + 1

def train_adv_multi(config, epoch, num_epoch, epoch_iters, base_lr,
                    num_iters, trainloader, targetloader, optimizer_G,
                    optimizer_D1, optimizer_D2, model, discriminator1, discriminator2,
                    writer_dict, lambda_adv1=0.001, lambda_adv2=0.0005, iter_size=4):
    """Run one epoch of two-level adversarial domain-adaptation training.

    For every source batch the gradients of ``iter_size`` sub-iterations are
    accumulated (each paired with a fresh target batch) before a single
    optimizer step is taken for the segmentation network and for both
    discriminators.

    Args:
        config: experiment configuration; only ``config.PRINT_FREQ`` is read here.
        epoch: index of the current epoch (used for logging and LR scheduling).
        num_epoch: total number of epochs (logging only).
        epoch_iters: number of iterations per epoch (logging and LR scheduling).
        base_lr: base learning rate handed to ``adjust_learning_rate``.
        num_iters: total number of iterations handed to ``adjust_learning_rate``.
        trainloader: iterable of labelled source batches (5-tuples).
        targetloader: iterable of target batches (5-tuples); restarted when exhausted.
        optimizer_G: optimizer of the segmentation network.
        optimizer_D1, optimizer_D2: optimizers of the two discriminators.
        model: segmentation wrapper; called with labels it returns
            ``(losses, final_out, intermediate_out, acc, loss_list)`` and with
            images only ``(final_out, intermediate_out)``.
        discriminator1: discriminator fed with the final output.
        discriminator2: discriminator fed with the intermediate output.
        writer_dict: dict holding the summary ``'writer'`` and ``'train_global_steps'``.
        lambda_adv1, lambda_adv2: weights of the two adversarial terms.
        iter_size: number of gradient-accumulation sub-iterations (>= 1).

    Returns:
        The running average of the generator loss over the epoch.

    Raises:
        ValueError: if ``iter_size`` is smaller than 1.
    """
    if iter_size < 1:
        raise ValueError('iter_size must be >= 1, got {}'.format(iter_size))

    def _set_discriminators_trainable(flag):
        # Toggle gradient tracking on the parameters of both discriminators.
        for disc in (discriminator1, discriminator2):
            for param in disc.parameters():
                param.requires_grad = flag

    # Training mode
    model.train()
    discriminator1.train()
    discriminator2.train()

    # BCEWithLogitsLoss is stateless, so one instance serves every term.
    bce_criterion = nn.BCEWithLogitsLoss()

    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_acc = AverageMeter()
    avg_sem_loss = AverageMeter()
    avg_bce_loss = AverageMeter()

    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']

    # Defined up front so the end-of-epoch summaries exist even if the
    # source loader yields nothing.
    cumulative_loss_D1 = 0.0
    cumulative_loss_D2 = 0.0

    # Iterator over the target loader; re-created whenever it is exhausted.
    target_iter = iter(targetloader)

    for i_iter, batch_source in enumerate(trainloader):
        optimizer_G.zero_grad()
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()

        cumulative_loss_G = 0.0
        cumulative_loss_D1 = 0.0
        cumulative_loss_D2 = 0.0

        # The source batch is the same for every sub-iteration, so it is
        # unpacked and moved to the GPU once.
        # NOTE(review): only the target batch changes between sub-iterations;
        # confirm that re-using the source batch iter_size times is intended.
        images_source, labels, bd_gts, _, _ = batch_source
        images_source = images_source.cuda()
        labels = labels.long().cuda()
        bd_gts = bd_gts.float().cuda()

        for sub_iter in range(iter_size):  # accumulate gradients over the sub-iterations
            # Fetch the next target batch, restarting the loader when needed.
            try:
                batch_target = next(target_iter)
            except StopIteration:
                target_iter = iter(targetloader)
                batch_target = next(target_iter)

            images_target, _, _, _, _ = batch_target
            images_target = images_target.cuda()

            # ------------------ GENERATOR ------------------
            # Freeze the discriminators: the adversarial loss must only push
            # gradients into the segmentation network. Without this, the
            # "fool the discriminator" gradients would accumulate in D1/D2
            # and be applied by their optimizers.
            _set_discriminators_trainable(False)

            # Supervised forward pass on the source domain.
            losses, output_source_final, output_source_intermediate, acc, loss_list = model(images_source, labels, bd_gts)
            loss_seg = losses.mean()

            # Unsupervised forward pass on the target domain.
            output_target_final, output_target_intermediate = model(images_target)

            # Adversarial loss: target predictions should look like source ones.
            # NOTE(review): the single-level trainer applies a softmax before
            # the discriminator; confirm whether these outputs need it too.
            fake_preds1 = discriminator1(output_target_final)
            fake_preds2 = discriminator2(output_target_intermediate)

            loss_adv1 = bce_criterion(fake_preds1, torch.ones_like(fake_preds1))
            loss_adv2 = bce_criterion(fake_preds2, torch.ones_like(fake_preds2))

            loss_G = (loss_seg + lambda_adv1 * loss_adv1 + lambda_adv2 * loss_adv2) / iter_size
            cumulative_loss_G += loss_G.item()
            loss_G.backward()

            # Re-enable discriminator gradients for their own updates.
            _set_discriminators_trainable(True)

            # ------------------ DISCRIMINATOR 1 ------------------
            # Inputs are detached so no gradient reaches the segmentation net.
            real_preds1 = discriminator1(output_source_final.detach())
            fake_preds1 = discriminator1(output_target_final.detach())

            loss_D1_real = bce_criterion(real_preds1, torch.ones_like(real_preds1))
            loss_D1_fake = bce_criterion(fake_preds1, torch.zeros_like(fake_preds1))
            loss_D1 = (loss_D1_real + loss_D1_fake) / 2 / iter_size
            cumulative_loss_D1 += loss_D1.item()
            loss_D1.backward()

            # ------------------ DISCRIMINATOR 2 ------------------
            real_preds2 = discriminator2(output_source_intermediate.detach())
            fake_preds2 = discriminator2(output_target_intermediate.detach())

            loss_D2_real = bce_criterion(real_preds2, torch.ones_like(real_preds2))
            loss_D2_fake = bce_criterion(fake_preds2, torch.zeros_like(fake_preds2))
            loss_D2 = (loss_D2_real + loss_D2_fake) / 2 / iter_size
            cumulative_loss_D2 += loss_D2.item()
            loss_D2.backward()

        # Apply the accumulated gradients once all sub-iterations are done.
        optimizer_G.step()
        optimizer_D1.step()
        optimizer_D2.step()

        # Log
        batch_time.update(time.time() - tic)
        tic = time.time()

        # Accuracy and loss components come from the last sub-iteration.
        ave_loss.update(cumulative_loss_G)
        ave_acc.update(acc.mean().item())
        avg_sem_loss.update(loss_list[0].mean().item())
        avg_bce_loss.update(loss_list[1].mean().item())

        # Called for its side effect of updating the generator's learning rate.
        adjust_learning_rate(optimizer_G, base_lr, num_iters, i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0:
            msg = ('Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, lr: {}, '
                   'Loss_G: {:.6f}, Loss_D1: {:.6f}, Loss_D2: {:.6f}, Acc: {:.6f}, Semantic Loss: {:.6f}, BCE Loss: {:.6f}').format(
                      epoch, num_epoch, i_iter, epoch_iters, batch_time.average(),
                      [x['lr'] for x in optimizer_G.param_groups], ave_loss.average(), cumulative_loss_D1,
                      cumulative_loss_D2, ave_acc.average(), avg_sem_loss.average(), avg_bce_loss.average()
                  )
            logging.info(msg)

    # Per-epoch summaries; the discriminator values are those of the last iteration.
    epoch_loss_G = ave_loss.average()
    writer.add_scalar('train_loss_G', epoch_loss_G, global_steps)
    writer.add_scalar('train_loss_D1', cumulative_loss_D1, global_steps)
    writer.add_scalar('train_loss_D2', cumulative_loss_D2, global_steps)
    writer_dict['train_global_steps'] = global_steps + 1

    return epoch_loss_G

### Main
This cell builds the data loaders, the model and the optimizer, runs the training loop, and runs validation at the end of selected epochs.

In [None]:
from re import L
# ------------------------------------------------------------------------------
# Modified based on https://github.com/HRNet/HRNet-Semantic-Segmentation
# ------------------------------------------------------------------------------

import argparse
import os
import pprint
import albumentations as A
import logging
import timeit

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
from tensorboardX import SummaryWriter
from torch import optim
import matplotlib.pyplot as plt
from IPython.display import clear_output, Image, display
import sys

# Ask albumentations to skip its online version check.
# NOTE(review): this is set after `import albumentations` above; the check
# presumably runs at import time, so confirm it still has an effect here.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

def parse_args():
    """Parse the known command-line options and load them into the global config.

    Unrecognised arguments (such as the ones the notebook kernel injects) are
    ignored. The parsed namespace is forwarded to ``update_config`` and then
    returned.

    Returns:
        argparse.Namespace with the ``cfg``, ``seed`` and ``opts`` attributes.
    """
    cli = argparse.ArgumentParser(description='Train segmentation network')

    # Configuration file to use for the experiment.
    cli.add_argument(
        '--cfg',
        type=str,
        default="configs/loveda/pidnet_small_loveda.yaml",
        help='experiment configure file name',
    )
    cli.add_argument('--seed', default=304, type=int)
    cli.add_argument(
        '--opts',
        nargs=argparse.REMAINDER,
        default=[],
        help="Modify config options using the command-line",
    )

    # Only the recognised options are kept; the leftovers are discarded.
    parsed = cli.parse_known_args()[0]

    # Push every parameter found in the configuration file into `config`.
    update_config(config, parsed)
    return parsed


def get_pidnet_model(config):
    """Build a segmentation network wrapped together with its loss criteria.

    The semantic criterion is OHEM cross-entropy when ``config.LOSS.USE_OHEM``
    is set and plain cross-entropy otherwise; both use a fixed per-class
    weight vector. The network itself comes from ``get_seg_model``.

    Args:
        config: experiment configuration (LOSS, TRAIN and MODEL sections are read).

    Returns:
        A ``FullModel`` combining the network, the semantic criterion and the
        boundary criterion.
    """
    # Fixed per-class weights handed to the semantic criterion.
    weights = torch.tensor([0.000000, 0.116411, 0.266041, 0.607794, 1.511413, 0.745507, 0.712438, 3.040396])
    ignore_label = config.TRAIN.IGNORE_LABEL

    if not config.LOSS.USE_OHEM:
        semantic_loss = CrossEntropy(ignore_label=ignore_label, weight=weights)
    else:
        semantic_loss = OhemCrossEntropy(
            ignore_label=ignore_label,
            thres=config.LOSS.OHEMTHRES,
            min_kept=config.LOSS.OHEMKEEP,
            weight=weights,
        )

    # ImageNet initialisation is requested when the pretrained path mentions it.
    network = get_seg_model(config, imgnet_pretrained='imagenet' in config.MODEL.PRETRAINED)

    return FullModel(network, semantic_loss, BondaryLoss())

def create_ema_model(model, config, gpus):
    """Create the EMA (exponential moving average) copy of ``model``.

    A fresh network is built with ``get_pidnet_model``, its parameters are
    detached from autograd and overwritten, position by position, with the
    current parameter values of ``model``. The copy is then wrapped in
    ``nn.DataParallel`` and moved to the GPU.

    In DACS-style training such a copy is updated after each student step as
    ``ema = alpha * ema + (1 - alpha) * student`` and used to produce more
    stable pseudo-labels; that update is not performed here.

    Args:
        model: network whose parameters initialise the copy.
        config: experiment configuration forwarded to ``get_pidnet_model``.
        gpus: device ids handed to ``nn.DataParallel``.

    Returns:
        The ``nn.DataParallel``-wrapped copy, placed on CUDA.
    """
    teacher = get_pidnet_model(config)

    # The copy is never trained by backpropagation.
    for weight in teacher.parameters():
        weight.detach_()

    student_params = list(model.parameters())
    teacher_params = list(teacher.parameters())

    # Parameters are matched by position in the two lists.
    for idx, source in enumerate(student_params):
        teacher_params[idx].data[:] = source.data[:].clone()

    return nn.DataParallel(teacher, device_ids=gpus).cuda()


def main():
    """Train the segmentation network and validate it periodically.

    Steps: parse arguments and seed the RNGs, create the logger and summary
    writer, pick the device, build the source / target / test loaders, wrap
    the network with its criteria, create the SGD optimizer, optionally
    resume from ``checkpoint.pth.tar``, then run the epoch loop using plain,
    DACS or adversarial (single- or multi-level) training depending on the
    configuration. A checkpoint is written after every epoch, ``best.pt``
    whenever validation improves, and ``final_state.pt`` at the end.

    Returns:
        0 when the number of visible GPUs does not match ``config.GPUS``,
        otherwise None.

    Raises:
        ValueError: if ``config.TRAIN.OPTIMIZER`` is not ``'sgd'``.
    """
    args = parse_args()

    if args.seed > 0:
        import random
        print('Seeding with', args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)  # NumPy-based randomness is seeded too
        torch.manual_seed(args.seed)

    logger, final_output_dir, tb_log_dir = create_logger(config, args.cfg, 'train')
    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS")
    else:
        device = torch.device("cpu")
        print("Using CPU")

    gpus = list(config.GPUS)
    if torch.cuda.is_available():
        # cudnn related setting
        cudnn.benchmark = config.CUDNN.BENCHMARK
        cudnn.deterministic = config.CUDNN.DETERMINISTIC
        cudnn.enabled = config.CUDNN.ENABLED
        if torch.cuda.device_count() != len(gpus):
            print("The gpu numbers do not match!")
            return 0

    imgnet = 'imagenet' in config.MODEL.PRETRAINED
    model = get_seg_model(config, imgnet_pretrained=imgnet)

    batch_size = config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus)
    # prepare data
    #crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    crop_size = (512, 512)

    # NOTE(review): this pipeline is composed but never handed to the datasets
    # below (they receive transform=None); confirm whether that is intended.
    train_trasform = None

    if config.TRAIN.AUGMENTATION.ENABLE:
        list_augmentations = []
        if config.TRAIN.AUGMENTATION.TECHNIQUES.RANDOM_CROP:
            list_augmentations.append(A.RandomResizedCrop(1024, 1024, p=0.5))
        if config.TRAIN.AUGMENTATION.TECHNIQUES.HORIZONTAL_FLIP:
            list_augmentations.append(A.HorizontalFlip(p=0.5))
        if config.TRAIN.AUGMENTATION.TECHNIQUES.COLOR_JITTER:
            list_augmentations.append(A.ColorJitter(p=0.5))
        if config.TRAIN.AUGMENTATION.TECHNIQUES.GAUSSIAN_BLUR:
            list_augmentations.append(A.GaussianBlur(p=0.5))
        if config.TRAIN.AUGMENTATION.TECHNIQUES.GAUSSIAN_NOISE:
            list_augmentations.append(A.GaussNoise(std_range=(0.2, 0.3), p=0.5))
        if len(list_augmentations) != 0:
            train_trasform = A.Compose(list_augmentations)

    # Labelled source-domain training set.
    train_dataset = LoveDA(
                        root=config.DATASET.ROOT,
                        list_path=config.DATASET.TRAIN_SET,
                        num_classes=config.DATASET.NUM_CLASSES,
                        multi_scale=config.TRAIN.MULTI_SCALE,
                        flip=config.TRAIN.FLIP,
                        enable_augmentation=True,
                        ignore_label=config.TRAIN.IGNORE_LABEL,
                        base_size=config.TRAIN.BASE_SIZE,
                        crop_size=crop_size,
                        scale_factor=config.TRAIN.SCALE_FACTOR,
                        horizontal_flip=config.TRAIN.AUGMENTATION.TECHNIQUES.HORIZONTAL_FLIP,
                        gaussian_blur=config.TRAIN.AUGMENTATION.TECHNIQUES.GAUSSIAN_BLUR,
                        transform=None)

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=False,
        drop_last=True)

    # Target-domain loader, only needed by the domain-adaptation modes.
    targetloader = None
    if config.TRAIN.DACS.ENABLE or config.TRAIN.GAN.ENABLE:
        target_dataset = LoveDA(
            root=config.DATASET.ROOT,
            list_path=config.DATASET.TARGET_SET,
            num_classes=config.DATASET.NUM_CLASSES,
            multi_scale=config.TRAIN.MULTI_SCALE,
            flip=config.TRAIN.FLIP,
            enable_augmentation=True,
            ignore_label=config.TRAIN.IGNORE_LABEL,
            base_size=config.TRAIN.BASE_SIZE,
            crop_size=crop_size,
            scale_factor=config.TRAIN.SCALE_FACTOR,
            horizontal_flip=config.TRAIN.AUGMENTATION.TECHNIQUES.HORIZONTAL_FLIP,
            gaussian_blur=config.TRAIN.AUGMENTATION.TECHNIQUES.GAUSSIAN_BLUR,
            random_crop=config.TRAIN.AUGMENTATION.TECHNIQUES.RANDOM_CROP,
            transform=None)

        targetloader = torch.utils.data.DataLoader(
            target_dataset,
            batch_size=batch_size,
            shuffle=config.TRAIN.SHUFFLE,
            num_workers=config.WORKERS,
            pin_memory=False,
            drop_last=True)

    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = LoveDA(
                        root=config.DATASET.ROOT,
                        list_path=config.DATASET.TEST_SET,
                        num_classes=config.DATASET.NUM_CLASSES,
                        multi_scale=False,
                        flip=False,
                        ignore_label=config.TRAIN.IGNORE_LABEL,
                        base_size=config.TEST.BASE_SIZE,
                        crop_size=test_size,
                        weighted=False)

    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=False)

    # criterion
    if config.LOSS.USE_OHEM:
        sem_criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                        thres=config.LOSS.OHEMTHRES,
                                        min_kept=config.LOSS.OHEMKEEP,
                                        weight=train_dataset.class_weights)
    else:
        sem_criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                    weight=train_dataset.class_weights)

    bd_criterion = BondaryLoss()

    model = FullModel(model, sem_criterion, bd_criterion)

    if torch.cuda.is_available():
        model = nn.DataParallel(model, device_ids=gpus).cuda()
    else:
        model = model.to(device)

    # The model is only wrapped in DataParallel on CUDA, so `.module` does not
    # exist on CPU / MPS. `net` is the underlying FullModel in both cases.
    net = model.module if isinstance(model, nn.DataParallel) else model

    # optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':
        params_dict = dict(model.named_parameters())
        params = [{'params': list(params_dict.values()), 'lr': config.TRAIN.LR}]

        optimizer = torch.optim.SGD(params,
                                lr=config.TRAIN.LR,
                                momentum=config.TRAIN.MOMENTUM,
                                weight_decay=config.TRAIN.WD,
                                nesterov=config.TRAIN.NESTEROV,
                                )
    else:
        raise ValueError('Only Support SGD optimizer')

    epoch_iters = int(len(train_dataset) / config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))

    best_mIoU = 0
    last_epoch = 0
    flag_rm = config.TRAIN.RESUME
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            print('-'*60)
            checkpoint = torch.load(model_state_file, map_location={'cuda:0': 'cpu'})
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            dct = checkpoint['state_dict']

            # Keep the network weights only, stripping just the leading
            # 'model.' prefix (str.replace would remove every occurrence).
            prefix = 'model.'
            net.model.load_state_dict({k[len(prefix):]: v for k, v in dct.items() if k.startswith(prefix)})
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    real_end = 120+1 if 'camvid' in config.DATASET.TRAIN_SET else end_epoch

    # Per-epoch histories.
    train_loss_history = []
    eval_loss_history = []
    mean_iou_history = []

    activate_ema_model = True
    ema_model = None
    deeplab_model = None
    # deeplab_model.load_state_dict(torch.load("/kaggle/input/deeplab-pretrained/deeplabv2_loveda_best.pth"))
    # deeplab_model.to(device)
    # deeplab_model.eval()

    # Adversarial components are created ONCE, before the epoch loop, so the
    # discriminators and their Adam state persist across epochs.
    # NOTE(review): they are not part of the checkpoint, so a resumed run
    # starts with freshly initialised discriminators.
    discriminator = None
    discriminator2 = None
    optimizer_D = None
    optimizer_D2 = None
    optimizer_G = optimizer  # the paper's generator settings are not used: the network differs
    if config.TRAIN.GAN.ENABLE:
        discriminator = FCDiscriminator(num_classes=8).to(device)
        optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.99))  # given by the paper
        if config.TRAIN.GAN.MULTI_LEVEL:
            # Second discriminator for the intermediate output.
            # NOTE(review): assumes the intermediate output has the same
            # number of channels as the final one; confirm against the model.
            discriminator2 = FCDiscriminator(num_classes=8).to(device)
            optimizer_D2 = optim.Adam(discriminator2.parameters(), lr=1e-4, betas=(0.9, 0.99))

    for epoch in range(last_epoch, real_end):

        # --- EMA model creation (once, on the first executed epoch) ---
        if config.TRAIN.DACS.ENABLE and activate_ema_model and epoch > -1:
            ema_model = create_ema_model(model, config, gpus)
            activate_ema_model = False

        current_trainloader = trainloader
        if current_trainloader.sampler is not None and hasattr(current_trainloader.sampler, 'set_epoch'):
            current_trainloader.sampler.set_epoch(epoch)

        if config.TRAIN.GAN.ENABLE:
            # The adversarial trainers may return nothing; train_loss is then None.
            if config.TRAIN.GAN.MULTI_LEVEL:
                train_loss = train_adv_multi(config, epoch, config.TRAIN.END_EPOCH, epoch_iters, config.TRAIN.LR,
                                             num_iters, trainloader, targetloader, optimizer_G,
                                             optimizer_D, optimizer_D2, model, discriminator, discriminator2,
                                             writer_dict)
            else:
                train_loss = train_adv(config, epoch, config.TRAIN.END_EPOCH, epoch_iters, config.TRAIN.LR,
                                       num_iters, trainloader, targetloader, optimizer_G, optimizer_D,
                                       model, discriminator, writer_dict)

        elif config.TRAIN.DACS.ENABLE:
            train_loss = train(config, epoch, config.TRAIN.END_EPOCH,
                  epoch_iters, config.TRAIN.LR, num_iters,
                  trainloader, optimizer, model, writer_dict, deeplab_model=deeplab_model, ema_model=ema_model, targetloader=targetloader)
        else:
            train_loss = train(config, epoch, config.TRAIN.END_EPOCH,
                  epoch_iters, config.TRAIN.LR, num_iters,
                  trainloader, optimizer, model, writer_dict, targetloader=targetloader)

        train_loss_history.append(train_loss)

        # Validate right after a resume, every 5th epoch early on, and on
        # every epoch during the last 100.
        run_validation = (flag_rm == 1
                          or (epoch % 5 == 0 and epoch < real_end - 100)
                          or (epoch >= real_end - 100))
        if run_validation:
            valid_loss, mean_IoU, IoU_array = validate(config, testloader, model, writer_dict)
            eval_loss_history.append(valid_loss)
            mean_iou_history.append(mean_IoU)

        if flag_rm == 1:
            flag_rm = 0

        checkpoint_path = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        logger.info('=> saving checkpoint to {}'.format(checkpoint_path))
        torch.save({
            'epoch': epoch+1,
            'best_mIoU': best_mIoU,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint_path)

        # The validation metrics only exist (and are only fresh) on epochs
        # where validation actually ran.
        if run_validation:
            if mean_IoU > best_mIoU:
                best_mIoU = mean_IoU
                torch.save(net.state_dict(),
                        os.path.join(final_output_dir, 'best.pt'))
            msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                        valid_loss, mean_IoU, best_mIoU)
            logging.info(msg)
            logging.info(IoU_array)

    torch.save(net.state_dict(),
            os.path.join(final_output_dir, 'final_state.pt'))

    writer_dict['writer'].close()
    end = timeit.default_timer()
    logger.info('Hours: %d' % int((end-start)/3600))
    logger.info('Done')


if __name__ == '__main__':
    main()

In [None]:
import shutil

# Create ZIP archives of the log and output folders under /kaggle/working/
# (written next to them as log.zip and output.zip).
shutil.make_archive('/kaggle/working/log', 'zip', '/kaggle/working/log')
shutil.make_archive('/kaggle/working/output', 'zip', '/kaggle/working/output')

In [None]:
from IPython.display import FileLink, display

# A cell only renders the value of its LAST expression, so each link is
# displayed explicitly; otherwise the output.zip link would never be shown.
display(FileLink('/kaggle/working/output.zip'))
display(FileLink('/kaggle/working/log.zip'))