# Utils

In [86]:
# Libraries
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
import numpy as np
import glob
import torchvision.transforms.v2 as v2
from typing import Optional, Tuple, Dict, List, Any


In [87]:
# # HPC Terrabyte
# # adapt the user to your needs
# USER = "di97ren"
# # keep the following unchanged
# ROOT = Path("/dss/dsstbyfs02/pn49ci/pn49ci-dss-0022")
# USER_PATH = ROOT / f"users/{USER}"
# DATA_PATH = ROOT / "data"

# # when you are on a local dev client
# # uncomment these lines and make necessary ajdustments
# #ROOT = Path("C:/projects/hands-on-DL")
# #DATA_PATH = Path("../data")

# # Configure the path to the GWHD dataset for your environment
# DATASET_ROOT = DATA_PATH / "xview2-subset"

# IMG_PATH = DATASET_ROOT / "png_images"
# TARGET_PATH = DATASET_ROOT / 'targets'

     

# Dataset Class

In [88]:
def transform():
    """Transform für Bilder & Masken"""
    return v2.Compose([
        v2.RandomHorizontalFlip(),
        v2.RandomVerticalFlip(),
        v2.RandomRotation(degrees=15),
        v2.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
        v2.ToDtype(torch.float32, scale=True)  # Automatische Skalierung auf [0,1]
    ])

def image_transform():
    """Nur für RGB-Bilder"""
    return v2.Compose([
        v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        v2.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

class xView2Dataset(Dataset):

    def __init__(self,
                 png_path: str,
                 target_path: str,
                 transform: callable = None,
                 image_transform: callable = None):
        
        self.png_path = png_path
        self.target_path = target_path
        self.transform = transform
        self.image_transform = image_transform
        

        # get all pre-disaster images:
        self.pre_images = sorted(self.png_path.glob("*_pre_disaster.png"))
        
        self.pairs = [] #

        for pre_img_path in self.pre_images:
            post_img_path = self.png_path / pre_img_path.name.replace("_pre_disaster", "_post_disaster")

            post_target_path = self.target_path / pre_img_path.name.replace("_pre_disaster", "_post_disaster")
            pre_target_path = pre_target_path = self.target_path / pre_img_path.name


            if post_img_path.exists() and post_target_path.exists() and pre_target_path.exists():
                self.pairs.append((pre_img_path, post_img_path, pre_target_path, post_target_path))

        print(f"Total pairs found: {len(self.pairs)}")
        assert len(self.pairs) > 0, "No matching image-pairs found!"



        # super().__init__()

    def __len__(self):
        return len(self.pre_images)

    def __getitem__(self, index):
        pre_img_path, post_img_path, pre_target_path, post_target_path = self.pairs[index]

        # load images and target masks with 
        
        pre_img = Image.open(pre_img_path).convert("RGB")
        post_img = Image.open(post_img_path).convert("RGB")
        pre_target_mask = Image.open(pre_target_path).convert('L')
        post_target_mask = Image.open(post_target_path).convert('L')

        # convert to numpy arrays
        pre_img = np.array(pre_img, dtype=np.float32) / 255.0
        post_img = np.array(post_img, dtype=np.float32) / 255.0
        pre_target_mask = np.array(pre_target_mask, dtype=np.float32)
        post_target_mask = np.array(post_target_mask, dtype=np.float32)

        # convert to Tensor
        pre_img = torch.tensor(pre_img).permute(2, 0, 1)  # (H, W, C) → (C, H, W)
        post_img = torch.tensor(post_img).permute(2, 0, 1)
        pre_target_mask = torch.tensor(pre_target_mask).unsqueeze(0)  # (H, W) → (1, H, W)
        post_target_mask = torch.tensor(post_target_mask).unsqueeze(0)

    # Transformation (optional)


        if self.transform:
            stack = torch.cat([pre_img, post_img, pre_target_mask, post_target_mask], dim=0)  # (8, H, W)
            stack = self.transform(stack)

            pre_img, post_img, pre_target_mask, post_target_mask = stack[:3], stack[3:6], stack[6:7], stack[7:8]
        
        if self.image_transform:
            
            # Nur auf Bilder Normalisierung anwenden
            pre_img = self.image_transform(pre_img)
            post_img = self.image_transform(post_img)

        return pre_img, post_img, pre_target_mask, post_target_mask 
    

    
        
        

In [89]:
# dataset = xView2Dataset(png_path= IMG_PATH,
#                         target_path = TARGET_PATH,
#                         transform = transform (),
#                         image_transform = image_transform())

In [90]:
# print(len(dataset))

# Data Loader

In [91]:
def collate_fn(batch):
    # Extrahiere die Daten aus dem Batch
    pre_imgs, post_imgs, pre_masks, post_masks = zip(*batch)

    # Staple die Tensoren entlang der Batch-Dimension
    pre_imgs = torch.stack(pre_imgs, dim=0)  # [Batch size, Channels, Height, Width]
    post_imgs = torch.stack(post_imgs, dim=0)  # [Batch size, Channels, Height, Width]
    pre_masks = torch.stack(pre_masks, dim=0)  # [Batch size, Height, Width]
    post_masks = torch.stack(post_masks, dim=0)  # [Batch size, Height, Width]

    # Wenn Masken ebenfalls als 4D-Tensor erwartet werden (Batch size, 1, Height, Width)
    #pre_masks = pre_masks.unsqueeze(1)  # [Batch size, 1, Height, Width]
    #post_masks = post_masks.unsqueeze(1)  # [Batch size, 1, Height, Width]

    # # Die Eingabe ist ein 4D-Tensor [Batch size, Channels, Height, Width]
    images = torch.cat([pre_imgs, post_imgs], dim=0)  # [2 * Batch size, Channels, Height, Width]
    
    # # Masken als 4D-Tensor [Batch size, 1, Height, Width]
    masks = torch.cat([pre_masks, post_masks], dim=0)  # [2 * Batch size, 1, Height, Width]

    return {
        "image": images,  # 4D Tensor
        "masks": masks,   # 4D Tensor
    }


In [92]:
# def collate_fn(batch):
#     # Extrahiere die Daten aus dem Batch
#     pre_imgs, post_imgs, pre_masks, post_masks = zip(*batch)

#     # Staple die Tensoren entlang der Batch-Dimension
#     pre_imgs = torch.stack(pre_imgs, dim=0)
#     post_imgs = torch.stack(post_imgs, dim=0)
#     pre_masks = torch.stack(pre_masks, dim=0)
#     post_masks = torch.stack(post_masks, dim=0)

#     Format für das Modell: {"image": [pre_imgs, post_imgs], "mask": [pre_masks, post_masks]}
#     Damit das Modell beide Masken gleichzeitig verwenden kann
#     return {
#         "image": [pre_imgs, post_imgs],  # Beide Bilder als Eingabe
#         "masks": [pre_masks, post_masks]  # Beide Masken als Ziel
#     }



# train_loader = DataLoader(
#     dataset,
#     batch_size=5,
#     collate_fn = collate_fn,
#     shuffle=True,
#     num_workers= 0,
#     drop_last=True
# )

# # Teste die Anzahl der Batches im DataLoader
# print(f"Number of batches in train_loader: {len(train_loader)}")


In [93]:
# import argparse

# args = argparse.Namespace(
#         train_images_path = IMG_PATH,
#         train_masks_path = TARGET_PATH,
#         batch_size = 5
# )  


# training_dataset = create_train_dataloader(args)

In [94]:
# print(len(train_dataset))

In [95]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import F1Score

def convert_to_labels(loss_str, logits):
    if loss_str == "mse":
        preds = torch.round(F.relu(logits[:, 0], inplace=True)) + 1
        preds[preds > 4] = 4
    # elif loss_str == "coral":
    #     preds = torch.sum(torch.sigmoid(logits) > 0.5, dim=1) + 1
    else:
        preds = torch.argmax(logits, dim=1) + 1
    return preds

class F1(torchmetrics.Metric):
    def __init__(self, args):
        super().__init__()
        self.loss_str = args.loss_str
        self.n_class = 5
        self.softmax = nn.Softmax(dim=1)
        self.f1_metric = F1Score(task = "multiclass", num_classes=self.n_class) 
                                 #,average='macro', mdmc_average='global')

    def update(self, preds, targets):
        probs = self.softmax(preds) if self.loss_str not in "mse" else preds
        if self.n_class == 5:
            preds = convert_to_labels(self.loss_str, probs)
            mask = targets > 0
            targets = targets[mask]
            preds = preds[mask]
        else:
            preds = torch.argmax(probs, dim=1)

        self.f1_metric.update(preds, targets)

    def compute(self):
        f1_score = self.f1_metric.compute()
        return f1_score.cpu()

    def reset(self):
        self.f1_metric.reset()




#  Model
Ich möchte ein Siamese Neural Network trainiern. Es soll UNets nutzen und ResNet50 als Encoder


## Definition: 
### Dilation:
Die Dilation in einer Convolutional-Schicht erweitert das Empfangsfenster des Filters, indem Lücken zwischen den benachbarten Pixeln eingefügt werden. Dadurch wird der Kontext, den der Filter bei der Verarbeitung eines Bildes sieht, vergrößert, ohne die Filtergröße zu erhöhen.

Auswirkungen im Code:
In deinem U-Net-Code hat die Dilation Auswirkungen auf den Decoder:

dilation = 1: Standard-Konvolution (kein Abstand zwischen den Filtern). Der Filter sieht nur benachbarte Pixel.
dilation = 2: Der Filter überspringt jedes zweite Pixel und erfasst somit ein größeres Gebiet des Bildes.
dilation = 4: Der Filter überspringt jedes vierte Pixel und sieht ein noch größeres Gebiet.
Was passiert im Code:
Die Dilation beeinflusst, wie die Schichten im Decoder miteinander verbunden sind und wie viel Kontext sie erfassen.
Je höher der Dilation-Wert, desto mehr "Sicht" hat der Filter auf das Bild und desto mehr Kontext kann er bei der Segmentierung erfassen.
Beispiel:
Dilation = 1: Filter sieht 3x3 benachbarte Pixel.
Dilation = 2: Filter überspringt 1 Pixel zwischen den benachbarten Pixeln (sieht ein 5x5-Feld).
Dilation = 4: Filter überspringt 3 Pixel zwischen den benachbarten Pixeln (sieht ein 7x7-Feld).

In [96]:
import torch
import torch.nn.functional as F
import torchvision.models as models
from torch import nn

## Loss

In [97]:
losses = {
    "ce": nn.CrossEntropyLoss(),
    "mse": nn.MSELoss(),
}

# class Loss(nn.Module):
#     def __init__(self, args):
#         super().__init__()
#         self.loss_str = args.loss_str
#         self.losses = nn.ModuleList([losses[loss_fn] for loss_fn in self.loss_str.split("+")])

#     def forward(self, y_pred, y_true):
#         # Sicherstellen, dass y_true ein Tensor ist
#         if not isinstance(y_true, torch.Tensor):  # Wenn y_true kein Tensor ist
#             y_true = torch.tensor(y_true, dtype=torch.float32).to(y_pred.device)

#         device = y_pred.device
#         mask = y_true > 0  # Maskierung basierend auf Tensor-Werten

#         # Interpolation
#         y_pred = F.interpolate(y_pred, size=y_true.shape[1:], mode="bilinear", align_corners=False)

#         # Maskierung der Vorhersage und Zielwerte
#         y_pred = torch.stack([y_pred[:, i][mask] for i in range(y_pred.shape[1])], 1).to(device)
#         y_true = y_true[mask] - 1

#         if self.loss_str == "mse":
#             y_pred = F.relu(y_pred[:, 0], inplace=True)
#             y_true = y_true.float()
#         else:  # "ce"
#             y_true = y_true.long()

#         # Berechnung des Gesamtverlustes
#         loss = 0
#         for loss_fn in self.losses:
#             loss += loss_fn(y_pred, y_true)
#        return loss

class Loss(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.loss_str = args.loss_str
       #  self.post = args.type == "post" -> brauche ich nicht
        self.losses = nn.ModuleList([losses[loss_fn] for loss_fn in self.loss_str.split("+")])

    

    def forward(self, y_pred, y_true):
        # if self.post:
        #     device = y_pred.device
        #     mask = y_true > 0
        #     y_pred = torch.stack([y_pred[:, i][mask] for i in range(y_pred.shape[1])], 1).to(device)
        #     y_true = y_true[mask] - 1

        # device = y_pred.device
        # mask = y_true > 0

        # y_pred = torch.stack([y_pred[:, i][mask] for i in range(y_pred.shape[1])], 1).to(device)
        # y_pred = F.interpolate(y_pred, size=y_true.shape[1:], mode="bilinear", align_corners=False)

        # y_true = y_true[mask] -1
        device = y_pred.device
        mask = y_true > 0

        # Erst y_pred auf die gleiche Größe wie y_true bringen
        y_pred = F.interpolate(y_pred, size=y_true.shape[1:], mode="bilinear", align_corners=False)

        # Dann maskieren
        y_pred = torch.stack([y_pred[:, i][mask] for i in range(y_pred.shape[1])], 1).to(device)
        y_true = y_true[mask] - 1


        if self.loss_str == "mse":
            y_pred = F.relu(y_pred[:, 0], inplace=True)
            y_true = y_true.float()
        else: # "ce"
            y_true = y_true.long()

        loss = 0
        for loss_fn in self.losses:
            loss += loss_fn(y_pred, y_true)
        return loss

## Layer

In [98]:

class PPM(nn.Module):
    def __init__(self, in_channels):
        super(PPM, self).__init__()
        self.features = []
        out_channels = in_channels // 4
        print(in_channels.size())
        for bin in (1, 2, 3, 6):
            self.features.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(bin),
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                    nn.BatchNorm2d(out_channels, affine=True),
                    nn.LeakyReLU(negative_slope=0.01, inplace=True),
                )
            )
        self.features = nn.ModuleList(self.features)
        self.conv = nn.Conv2d(2 * in_channels, in_channels, kernel_size=1, bias=True)

    def forward(self, x):
        # Sicherstellen, dass die Eingabe die Form [Batch size, Channels, Height, Width] hat


        x_size = x.size()
        out = [x]
        for f in self.features:
            out.append(F.interpolate(f(x), x_size[2:], mode="bilinear", align_corners=True))
        out = self.conv(torch.cat(out, 1))
        return out


class ASPPModule(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(ASPPModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, affine=True)
        self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        self._init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)


class ASPP(nn.Module):
    def __init__(self, in_channels, dilation):
        super(ASPP, self).__init__()
        out_channels = in_channels // 4
        dilations = [1, 3 * dilation, 6 * dilation, 9 * dilation]
        self.aspp1 = ASPPModule(in_channels, out_channels, 1, padding=0, dilation=dilations[0])
        self.aspp2 = ASPPModule(in_channels, out_channels, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = ASPPModule(in_channels, out_channels, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = ASPPModule(in_channels, out_channels, 3, padding=dilations[3], dilation=dilations[3])

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        out = torch.cat((x1, x2, x3, x4), dim=1)
        return out


class AttentionLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionLayer, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(out_channels, affine=True)

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.batch_norm(out)
        return out


class ConvTranspose(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvTranspose, self).__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=False)

    def forward(self, inputs):
        return self.conv(inputs)


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(out_channels, affine=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.01, inplace=True)

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.batch_norm(out)
        out = self.lrelu(out)
        return out


class FusionBlock(nn.Module):
    def __init__(self, pre_conv, post_conv, channels):
        super(FusionBlock, self).__init__()
        self.pre_conv = pre_conv
        self.post_conv = post_conv
        self.conv_pre = ConvLayer(2 * channels, channels)
        self.conv_post = ConvLayer(2 * channels, channels)

    def forward(self, pre, post, dec_pre=None, dec_post=None, last_dec=False):
        pre = self.pre_conv(pre, dec_pre) if dec_pre is not None or last_dec else self.pre_conv(pre)
        post = self.post_conv(post, dec_post) if dec_post is not None or last_dec else self.post_conv(post)
        fmap = torch.cat([pre, post], 1)
        pre, post = self.conv_pre(fmap), self.conv_post(fmap)
        return pre, post


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = ConvLayer(in_channels, out_channels)
        self.conv2 = ConvLayer(out_channels, out_channels)

    def forward(self, inputs):
        out = self.conv1(inputs)
        out = self.conv2(out)
        return out


class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels, attention, dec_interp):
        super(UpsampleBlock, self).__init__()
        self.attention = attention
        self.dec_interp = dec_interp
        self.skip_channels = skip_channels
        inc = skip_channels + out_channels
        if self.dec_interp:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)
        else:
            self.conv_tranpose = ConvTranspose(in_channels, out_channels)

        self.conv_block = ConvBlock(inc, out_channels)
        if skip_channels > 0 and self.attention:
            att_out = out_channels // 2
            self.conv_o = AttentionLayer(out_channels, att_out)
            self.conv_s = AttentionLayer(skip_channels, att_out)
            self.psi = AttentionLayer(att_out, 1)
            self.sigmoid = nn.Sigmoid()
            self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs, skip):
        if self.dec_interp:
            out = F.interpolate(self.conv(inputs), scale_factor=2, mode="bilinear", align_corners=True)
        else:
            out = self.conv_tranpose(inputs)

        if self.skip_channels == 0:
            return self.conv_block(out)

        if self.attention:
            out_a = self.conv_o(out)
            skip_a = self.conv_s(skip)
            psi_a = self.psi(self.relu(out_a + skip_a))
            attention = self.sigmoid(psi_a)
            skip = skip * attention
        out = self.conv_block(torch.cat((out, skip), dim=1))
        return out


class OutputBlock(nn.Module):
    def __init__(self, in_channels, n_class, interpolate):
        super(OutputBlock, self).__init__()
        self.interpolate = interpolate
        # self.coral_loss = n_class == 3 -> vermutlich unwichtig, weil ich immer mehr als 3 Klassen habe. 
        # if self.coral_loss:
        #     self.conv = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
        #     self.bias = nn.Parameter(torch.tensor([[[1.0]], [[0.0]], [[-1.0]]]))
        # else:
        #     self.conv = nn.Conv2d(in_channels, n_class, kernel_size=1)
        self.conv = nn.Conv2d(in_channels, n_class, kernel_size=1)

    def forward(self, inputs):
        out = self.conv(inputs)
        # if self.coral_loss:
        #     out = out + self.bias
        if self.interpolate:
            size = (512, 512) if self.training else (1024, 1024)
            out = F.interpolate(out, size, mode="bilinear", align_corners=True)
        return out

## Unet

In [99]:

def get_encoder(encoder_str, dilation, pretrained=True, in_channels=3):
    assert "resnet" in encoder_str 

   
    encoder_channels = [64, 256, 512, 1024, 2048]
    replace_stride_with_dilation = [False, dilation == 4, dilation in [2, 4]]

    # ResNet Modelle
    if encoder_str == "resnet50":
        encoder = models.resnet50(pretrained=pretrained, replace_stride_with_dilation=replace_stride_with_dilation)
    elif encoder_str == "resnet101":
        encoder = models.resnet101(pretrained=pretrained, replace_stride_with_dilation=replace_stride_with_dilation)
    elif encoder_str == "resnet152":
        encoder = models.resnet152(pretrained=pretrained, replace_stride_with_dilation=replace_stride_with_dilation)
    else:
        raise f"Not implemented encoder {encoder_str}"

    if in_channels != 3:
        conv1 = encoder.conv1[0] if "st" in encoder else encoder.conv1
        conv1 = torch.nn.Conv2d(
            in_channels,
            conv1.out_channels,
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias,
        )
        if "st" in encoder:
            encoder.conv1[0] = conv1
        else:
            encoder.conv1 = conv1

    encoder_layer1 = nn.Sequential(encoder.conv1, encoder.bn1, nn.ReLU(inplace=True))
    encoder_layer2 = nn.Sequential(encoder.maxpool, encoder.layer1)
    encoder_layer3 = encoder.layer2
    encoder_layer4 = encoder.layer3
    encoder_layer5 = encoder.layer4

    return encoder_channels, encoder_layer1, encoder_layer2, encoder_layer3, encoder_layer4, encoder_layer5


def get_decoder(encf, dilation, attn, no_skip=False, dec_interp=False):
    decf = [512, 256, 128, 64, 32]
    if dilation == 1:
        decoder_layer1 = UpsampleBlock(encf[-1], decf[0], 0 if no_skip else encf[-2], attn, dec_interp)
        decoder_layer2 = UpsampleBlock(decf[0], decf[1], 0 if no_skip else encf[-3], attn, dec_interp)
        decoder_layer3 = UpsampleBlock(decf[1], decf[2], 0 if no_skip else encf[-4], attn, dec_interp)
        decoder_layer4 = UpsampleBlock(decf[2], decf[3], 0 if no_skip else encf[-5], attn, dec_interp)
        decoder_layer5 = UpsampleBlock(decf[3], decf[4], 0, attn, dec_interp)
    elif dilation == 2:
        decoder_layer1 = None
        decoder_layer2 = UpsampleBlock(encf[-1], decf[1], 0 if no_skip else encf[-3], attn, dec_interp)
        decoder_layer3 = UpsampleBlock(decf[1], decf[2], 0 if no_skip else encf[-4], attn, dec_interp)
        decoder_layer4 = UpsampleBlock(decf[2], decf[3], 0 if no_skip else encf[-5], attn, dec_interp)
        decoder_layer5 = UpsampleBlock(decf[3], decf[4], 0, attn, dec_interp)
    elif dilation == 4:
        decoder_layer1, decoder_layer2 = None, None
        decoder_layer3 = UpsampleBlock(encf[-1], decf[2], 0 if no_skip else encf[-4], attn, dec_interp)
        decoder_layer4 = UpsampleBlock(decf[2], decf[3], 0 if no_skip else encf[-5], attn, dec_interp)
        decoder_layer5 = UpsampleBlock(decf[3], decf[4], 0, attn, dec_interp)
    else:
        raise ValueError("Dilation can be set to 1, 2 or 4")
    return decf, decoder_layer1, decoder_layer2, decoder_layer3, decoder_layer4, decoder_layer5




class UNetTemplate(nn.Module):
    def __init__(self, args, in_channels=3):

        super(UNetTemplate, self).__init__()
        
        self.use_ppm = args.ppm
        self.use_aspp = args.aspp
        self.dilation = args.dilation
        self.no_skip = args.no_skip
        self.interpolate = args.interpolate
        self.enc_chn, self.enc_l1, self.enc_l2, self.enc_l3, self.enc_l4, self.enc_l5 = get_encoder(
            args.encoder, self.dilation, in_channels=in_channels
        )

        if self.use_ppm:
            self.ppm = PPM(self.enc_chn[-1])
        elif self.use_aspp:
            self.aspp = ASPP(self.enc_chn[-1], self.dilation)

        self.dec_chn = None
        if not self.interpolate:
            self.dec_chn, self.dec_l1, self.dec_l2, self.dec_l3, self.dec_l4, self.dec_l5 = get_decoder(
                self.enc_chn, self.dilation, args.attention, self.no_skip, args.dec_interp
            )

    def forward(self, data):
        enc1 = self.enc_l1(data)
        enc2 = self.enc_l2(enc1)
        enc3 = self.enc_l3(enc2)
        enc4 = self.enc_l4(enc3)
        enc5 = self.enc_l5(enc4)

        if self.use_ppm:
            enc5 = self.ppm(enc5)
        elif self.use_aspp:
            enc5 = self.aspp(enc5)
        if self.interpolate:
            return enc5, None, None

        if self.dilation == 1:
            if self.no_skip:
                enc1, enc2, enc3, enc4 = None, None, None, None
            dec1 = self.dec_l1(enc5, enc4)
            dec2 = self.dec_l2(dec1, enc3)
            dec3 = self.dec_l3(dec2, enc2)
            dec4 = self.dec_l4(dec3, enc1)
            dec5 = self.dec_l5(dec4, None)
        elif self.dilation == 2:
            if self.no_skip:
                enc1, enc2, enc3 = None, None, None
            dec2 = self.dec_l2(enc5, enc3)
            dec3 = self.dec_l3(dec2, enc2)
            dec4 = self.dec_l4(dec3, enc1)
            dec5 = self.dec_l5(dec4, None)
        elif self.dilation == 4:
            if self.no_skip:
                enc1, enc2 = None, None
            dec3 = self.dec_l3(enc5, enc2)
            dec4 = self.dec_l4(dec3, enc1)
            dec5 = self.dec_l5(dec4, None)

        return dec5, dec4, dec3

# bekommt die Decoder Outputs als Eingabe - besteht aus einer Output-Blick Schicht, die die Schadensklasse für jedes Pixel vorhersagt

class OutputTemplate(nn.Module):
    def __init__(self, 
                 n_class: int, 
                 deep_supervision: bool, # Sollen beim Training Zwischen Schichten überwacht werden
                 dec_chn: List[int], # Liste von Ints, die die Kanäle der Decoder-Schichten des Modells repräsentiert
                 scale: int = 1, 
                 interp: bool = False, # interpolation der Ausgabe
                 enc_last:int = 0): # Anzahl der Kanäle für die letzte Encoder schicht
        
        super(OutputTemplate, self).__init__()
        self.deep_supervision = deep_supervision
        self.interp = interp
        if self.interp:
            d5 = enc_last * scale
            self.deep_supervision = False
        else:
            d3, d4, d5 = scale * dec_chn[-3], scale * dec_chn[-2], scale * dec_chn[-1]

        if self.deep_supervision:
            self.output_block_ds3 = OutputBlock(d3, n_class, interp)
            self.output_block_ds4 = OutputBlock(d4, n_class, interp)
        self.output_block = OutputBlock(d5, n_class, interp)

    def forward(self, dec5, dec4, dec3):
        out = self.output_block(dec5)
        if self.training and self.deep_supervision:
            out_dec3 = self.output_block_ds3(dec3)
            out_dec4 = self.output_block_ds4(dec4)
            return [out, out_dec4, out_dec3]
        return out


In [100]:

def concat(x, y):
    return None if x is None or y is None else torch.cat([x, y], 1)

class SiameseUNet(nn.Module):
    
    def __init__(self,
                 args: Any, 
                 n_classes: int ):
        super(SiameseUNet, self).__init__()

        # Parameter von args:
        # args.ppm: bool              -> Aktiviert PPM (Pyramid Pooling Module)
        # args.aspp: bool             -> Aktiviert ASPP (Atrous Spatial Pyramid Pooling)
        # args.dilation: int          -> Dilationstyp (1, 2, 4)
        # args.no_skip: bool          -> Deaktiviert Skip-Verbindungen
        # args.interpolate: bool      -> Aktiviert Interpolation im Decoder
        # args.deep_supervision: bool -> Aktiviert Tiefenüberwachung (falls gewünscht)
        # args.encoder_str: str       -> kann resnet50, resnet101, resnet152 sein
        
    

        # Übergabe der argumente and UNetTempale
        self.unet = UNetTemplate(args)

        # Übergabe der fest definierten n_class an OutputTemplate
        self.output_block = OutputTemplate(
            n_classes,
            args.deep_supervision,
            self.unet.dec_chn,
            2,
            args.interpolate,
            self.unet.enc_chn[-1],
        )

    def forward(self, data):
        # Zugriff auf Pre- und Post-Bilder
        pre_data = data[0]  # Vorher-Bild
        post_data = data[1]  # Nachher-Bild
        
        # Verarbeitung der Bilder im UNet
        pre_dec5, pre_dec4, pre_dec3 = self.unet(pre_data)  # Verarbeitung des Vorher-Bildes
        post_dec5, post_dec4, post_dec3 = self.unet(post_data)  # Verarbeitung des Nachher-Bildes
        
        # Kombination der Ergebnisse von Pre- und Post-Bildern
        dec5, dec4, dec3 = concat(pre_dec5, post_dec5), concat(pre_dec4, post_dec4), concat(pre_dec3, post_dec3)
        
        # Ausgabe des Modells
        out = self.output_block(dec5, dec4, dec3)
        return out


    # def forward(self, data):
        
    #      # Vorwärtsdurchlauf: Verarbeitung von Pre- und Post-Disaster-Daten
    #     pre_dec5, pre_dec4, pre_dec3 = self.unet(data[:, :3]) # da UNet5 Stufen hat - startet decoder bei 5 
    #     post_dec5, post_dec4, post_dec3 = self.unet(data[:, 3:])

    #     # Kombination der Ergebnisse von Pre- und Post-Bildern
    #     dec5, dec4, dec3 = concat(pre_dec5, post_dec5), concat(pre_dec4, post_dec4), concat(pre_dec3, post_dec3)
        
    #     # Ausgabe des Modells
    #     out = self.output_block(dec5, dec4, dec3)
    #     return out        

Arguemnte die Übergeben werden müssen: 

n_classes = 5 Nummer der Klassen für Klassifikation
        # Parameter von args:
        # args.ppm: bool              -> Aktiviert PPM (Pyramid Pooling Module)
        # args.aspp: bool             -> Aktiviert ASPP (Atrous Spatial Pyramid Pooling)
        # args.dilation: int          -> Dilationstyp (1, 2, 4)
        # args.no_skip: bool          -> Deaktiviert Skip-Verbindungen
        # args.interpolate: bool      -> Aktiviert Interpolation im Decoder
        # args.deep_supervision: bool -> Aktiviert Tiefenüberwachung (falls gewünscht)

## Model Klasse

In [101]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
# from utils.f1 import F1
# from model.loss import Loss
# from model.unet import SiameseUNet

class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()

        self.args = args
        self.f1_score = F1(args)
        self.model = SiameseUNet(args, n_classes=5)
        self.loss = Loss(args)
        self.best_f1 = torch.tensor(0)
        self.best_epoch = 0
        self.tta_flips = [[2], [3], [2, 3]]
        self.lr = args.lr
        self.n_class = 5
        self.softmax = nn.Softmax(dim=1)
        self.test_idx = 0


        ''' defines the computation performed at every call and must be overridden by all subclasses of torch.nn.Module. 
    takes input data, processes it through the networks layers and returns the output. 
    Output: logits, probabilities or any other form of processed data 
    '''

    def forward(self, img):
        pred = self.model(img)
        if self.args.tta:
            for flip_idx in self.tta_flips:
                pred += self.flip(self.model(self.flip(img, flip_idx)), flip_idx)
            pred /= len(self.tta_flips) + 1
        return pred
    
    def compute_loss(self, preds, label):
        if self.args.deep_supervision:
            loss = self.loss(preds[0], label)
            for i, pred in enumerate(preds[1:]):
                downsampled_label = torch.nn.functional.interpolate(label.unsqueeze(1), pred.shape[2:])
                loss += 0.5 ** (i + 1) * self.loss(pred, downsampled_label.squeeze(1))
            c_norm = 1 / (2 - 2 ** (-len(preds)))
            return c_norm * loss
        return self.loss(preds, label)
    
    @staticmethod
    def flip(data, axis):
        return torch.flip(data, dims=axis)
    
    def save_predictions(self, preds, targets):
        if self.args.loss_str == "coral":
            probs = torch.sum(torch.sigmoid(preds) > 0.5, dim=1) + 1
        elif self.args.loss_str == "mse":
            probs = torch.round(F.relu(preds[:, 0], inplace=True)) + 1
        else:
            probs = self.softmax(preds)

        probs = probs.cpu().detach().numpy()
        targets = targets.cpu().detach().numpy().astype(np.uint8)
        
        for prob, target in zip(probs, targets):
            fname = os.path.join(self.args.results, "probs", f"test_damage_{self.test_idx:05d}")
            self.test_idx += 1
            np.save(fname, prob)
            Image.fromarray(target).save(fname.replace("probs", "targets") + "_target.png")




    


## Implement Training, Validation and Testing

## Write Argument Parser:

In [102]:

from typing import Tuple, Dict, List, Any

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights
import torch
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import tv_tensors
from torchvision.transforms import v2
from pathlib import Path
import pandas as pd
import PIL
     

In [103]:
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

### LOCAL!!!

In [104]:
# ROOT = Path(r"C:\Users\elena\Documents\04-geo-oma\data")
# DATA_PATH = ROOT / "xview2-subset"
# IMG_PATH = DATA_PATH / "png_images"
# TARGET_PATH = DATA_PATH / "targets"

# EXPERIMENT_GROUP = "local1"
# EXPERIMENT_ID = "exp_001"
# # Local dev client path config
# # uncomment these lines and make adjustments if necessary
# EXPERIMENT_DIR = ROOT / "experiments" / EXPERIMENT_GROUP
# EXPERIMENT_DIR.mkdir(parents=True, exist_ok=True)

# CHECKPOINTS_DIR = ROOT / "checkpoints" / EXPERIMENT_GROUP
# CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)

# writer = SummaryWriter(EXPERIMENT_DIR / EXPERIMENT_ID)

# TRAIN_BATCH_SIZE = 5
# VAL_BATCH_SIZE = 5

# VAL_SCORE_THRESHOLD =0.0

In [105]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [106]:
# HPC Terrabyte
# adapt the user to your needs
USER = "di97ren"
# keep the following unchanged
ROOT = Path("/dss/dsstbyfs02/pn49ci/pn49ci-dss-0022")
USER_PATH = ROOT / f"users/{USER}"
DATA_PATH = ROOT / "data"

# when you are on a local dev client
# uncomment these lines and make necessary ajdustments
#ROOT = Path("C:/projects/hands-on-DL")
#DATA_PATH = Path("../data")

# Configure the path to the GWHD dataset for your environment
DATASET_ROOT = DATA_PATH / "xview2-subset"

IMG_PATH = DATASET_ROOT / "png_images"
TARGET_PATH = DATASET_ROOT / 'targets'

EXPERIMENT_GROUP = "XView2"
EXPERIMENT_ID = "exp_001"

# HPC Terrabyte path config
EXPERIMENT_DIR = USER_PATH / f"experiments/{EXPERIMENT_GROUP}"
EXPERIMENT_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINTS_DIR = USER_PATH / f"checkpoints/{EXPERIMENT_GROUP}"
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)

# Local dev client path config
# uncomment these lines and make adjustments if necessary
# EXPERIMENT_DIR = ROOT / "experiments" / EXPERIMENT_GROUP
# EXPERIMENT_DIR.mkdir(parents=True, exist_ok=True)

# CHECKPOINTS_DIR = ROOT / "checkpoints" / EXPERIMENT_GROUP
# CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)

writer = SummaryWriter(EXPERIMENT_DIR / EXPERIMENT_ID)

# Configure the batch size in order to fit the model into your GPU (these settings were used on a A100 GPU)
TRAIN_BATCH_SIZE = 5
VAL_BATCH_SIZE = 5

# the val score threshold is used to determine if a prediction should be considered
# 0.0 means that all predictions are considered, normally this is what you want
VAL_SCORE_THRESHOLD = 0.0


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     

In [107]:
print(DATASET_ROOT)

/dss/dsstbyfs02/pn49ci/pn49ci-dss-0022/data/xview2-subset


In [108]:
import argparse

args = argparse.Namespace(
    # train_images_path = IMG_PATH,
    # train_masks_path = TARGET_PATH,
    # val_images_path = IMG_PATH,
    # val_masks_path = TARGET_PATH,
    # test_images_path = IMG_PATH,
    # test_masks_path = TARGET_PATH,
    # batch_size = TRAIN_BATCH_SIZE,
    epochs = 10,
    lr = 0.0001,
    tta = True,
    deep_supervision = False,
    loss_str = 'mse',
    results = EXPERIMENT_DIR,
    optimizer = 'adamw', 
    weight_decay = 1e-5,
    momentum = 0.9,
    use_scheduler = False,
    ppm = True,
    aspp = False,
    dilation = 1,
    no_skip = True,
    interpolate = True, 
    encoder = "resnet50")




In [109]:
train_dataset = xView2Dataset(
    png_path=IMG_PATH,
    target_path = TARGET_PATH,
    # transform = transform(),
    # image_transform = image_transform()
    )

train_dataloader = DataLoader(
    train_dataset, 
    batch_size = TRAIN_BATCH_SIZE,
    shuffle = True,
    num_workers = 3,
    prefetch_factor = 2,
    drop_last = True,
    collate_fn = collate_fn
)
 
val_dataset = xView2Dataset(
    png_path=IMG_PATH,
    target_path = TARGET_PATH,
    # transform = transform(),
    # image_transform = image_transform()
    )

val_dataloader = DataLoader(
    train_dataset, 
    batch_size = VAL_BATCH_SIZE,
    shuffle = True,
    num_workers = 3,
    prefetch_factor = 2,
    drop_last = True,
    collate_fn = collate_fn
)

test_dataset = xView2Dataset(
    png_path=IMG_PATH,
    target_path = TARGET_PATH,
    # transform = transform(),
    # image_transform = image_transform()
    )

test_dataloader = DataLoader(
    train_dataset, 
    batch_size = TRAIN_BATCH_SIZE,
    shuffle = True,
    num_workers = 3,
    prefetch_factor = 2,
    drop_last = True,
    collate_fn = collate_fn
)


Total pairs found: 20
Total pairs found: 20
Total pairs found: 20


In [110]:
     # Beispiel, um Dimensionen des Dataloaders zu überprüfen
for i, batch in enumerate(test_dataloader):  # Oder valid_loader für das Validierungsset
    images = batch["image"]
    masks = batch["masks"]
    
    # Zeige die Dimensionen von Bildern und Masken an
    print(f"Batch {i}:")
    print(f"Image dimensions: {images.shape}")  # Form der Bilder
    print(f"Mask dimensions: {masks.shape}")    # Form der Masken
    
    # Stoppe nach dem ersten Batch (optional)
    break


Batch 0:
Image dimensions: torch.Size([10, 3, 1024, 1024])
Mask dimensions: torch.Size([10, 1, 1024, 1024])


In [111]:
# def train_epoch(model, train_loader, optimizer, device):
#     model.train()  # Set the model to training mode
#     total_loss = 0

#     for batch in train_loader:
#         # Hole Bilder und Masken
#         img, lbl = batch["image"].to(device), batch["masks"].to(device)

#         # Überprüfe die Dimensionen der Eingabedaten
#         print(f"Image dimensions: {img.shape}")
#         print(f"Mask dimensions: {lbl.shape}")

#         # Sicherstellen, dass es sich um 4D-Tensoren handelt
#         assert img.ndimension() == 4, f"Expected 4D input (got {img.ndimension()}D input)"
#         assert lbl.ndimension() == 4, f"Expected 4D input (got {lbl.ndimension()}D input)"
        
#         optimizer.zero_grad()

#         # Berechne die Vorhersage
#         pred = model(img)

#         # Berechne den Verlust
#         loss = model.compute_loss(pred, lbl)
#         total_loss += loss.item()

#         # Backpropagation und Optimierung
#         loss.backward()
#         optimizer.step()

#     avg_loss = total_loss / len(train_loader)
#     return avg_loss


def train_epoch(model, train_loader, optimizer, device): # adding epoch ??
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        img = [x.to(device) for x in batch["image"]]  # Beide Bilder zu einem Tensor mit 6 Kanälen
        lbl = [x.to(device) for x in batch["masks"]]

    # for batch in train_loader:
    #     img = [x.to(device) for x in batch["image"]]
    #     lbl = batch['mask'].to(device)

        # img, lbl = batch["image"].to(device), batch["masks"].to(device)
        
        optimizer.zero_grad()
        pred = model.model(img)  # Use model.model directly for training
        loss = model.compute_loss(pred, lbl)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(train_loader)

def validate(model, val_loader, device): # epoch?
    model.eval()
    total_loss = 0.0
    model.f1_score.reset()
    
    with torch.no_grad():
        for batch in val_loader:
            img = [x.to(device) for x in batch["image"]]
            lbl = [x.to(device) for x in batch["masks"]]


            # img, lbl = batch["image"].to(device), batch["masks"].to(device)
            pred = model(img)  # Use model for validation (includes TTA if enabled)
            loss = model.loss(pred, lbl)
            model.f1_score.update(pred, lbl)
            total_loss += loss.item()
    
    f1_score, dmgs_f1 = model.f1_score.compute()
    model.f1_score.reset()
    
    # Update best F1 score
    if f1_score >= model.best_f1:
        model.best_f1 = f1_score
        model.best_epoch = current_epoch  # You need to keep track of current_epoch
    
    metrics = {
        "f1": round(f1_score.item(), 3),
        "val_loss": round(total_loss / len(val_loader), 3),
        "top_f1": round(model.best_f1.item(), 3),
    }
    
    # Add damage scores if available
    if dmgs_f1 is not None:
        for i in range(4):
            metrics.update({f"D{i+1}": round(dmgs_f1[i].item(), 3)})
    
    return metrics

def test(model, test_loader, device):
    model.eval()
    model.f1_score.reset()
    
    with torch.no_grad():
        for batch in test_loader:
            img = [x.to(device) for x in batch["image"]]
            lbl = [x.to(device) for x in batch["masks"]]


            # img, lbl = batch["image"].to(device), batch["masks"].to(device)
            pred = model(img)
            model.f1_score.update(pred, lbl)
            model.save_predictions(pred, lbl)
    
    f1_score, dmgs_f1 = model.f1_score.compute()
    model.f1_score.reset()
    
    metrics = {"f1": round(f1_score.item(), 3)}
    if dmgs_f1 is not None:
        for i in range(4):
            metrics.update({f"D{i+1}": round(dmgs_f1[i].item(), 3)})
    
    return metrics

In [112]:
def train_model(model, train_loader, val_loader, args, device):
    # Setup optimizer
    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        # Add other optimizers as needed
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    
    # Setup scheduler if needed
    scheduler = None
    if args.use_scheduler:
        # Implement your own scheduler or use PyTorch's built-in schedulers
        # For example, to mimic the NoamLR scheduler:
        from torch.optim.lr_scheduler import LambdaLR
        
        def lr_lambda(step):
            warmup_steps = args.warmup * len(train_loader)
            total_steps = args.epochs * len(train_loader)
            step += 1  # Avoid division by zero
            
            if step < warmup_steps:
                return args.init_lr + step * (args.lr - args.init_lr) / warmup_steps
            else:
                return args.lr * (args.final_lr / args.lr) ** ((step - warmup_steps) / (total_steps - warmup_steps))
        
        scheduler = LambdaLR(optimizer, lr_lambda)
    
    # Training loop
    for epoch in range(args.epochs):
        # Train
        train_loss = train_epoch(model, train_loader, optimizer, device)
        
        # Validate
        model.eval()
        metrics = validate(model, val_loader, device, epoch)
        
        # Print metrics
        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={metrics['val_loss']:.4f}, f1={metrics['f1']:.4f}")
        
        # Update scheduler
        if scheduler is not None:
            if isinstance(scheduler, LambdaLR):
                for _ in range(len(train_loader)):
                    scheduler.step()
            else:
                scheduler.step()
        
        # Save best model
        if metrics['f1'] >= model.best_f1:
            print(f"Saving best model with F1 score: {metrics['f1']:.4f}")
            torch.save(model.state_dict(), os.path.join(args.results, "best_model.pth"))
    
    return model

In [113]:
def main(args):
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize model
    model = Model(args).to(device)
    
    # Create dataloaders (you need to implement this based on your data)
    train_loader = train_dataloader
    val_loader = val_dataloader
    test_loader = test_dataloader
    
    # Train model
    model = train_model(model, train_loader, val_loader, args, device)
    
    # Load best model for testing
    model.load_state_dict(torch.load(os.path.join(args.results, "best_model.pth")))
    
    # Test model
    test_metrics = test(model, test_loader, device)
    print(f"Test metrics: {test_metrics}")

if __name__ == "__main__":

    # Parse arguments
    args = args
    
    # Create results directory
    os.makedirs(args.results, exist_ok=True)
    os.makedirs(os.path.join(args.results, "probs"), exist_ok=True)
    os.makedirs(os.path.join(args.results, "targets"), exist_ok=True)
    
    main(args)



AttributeError: 'int' object has no attribute 'size'