In [1]:
import collections.abc
import re

import torch
from torch.nn import functional as F


def pad_tensor(x, l, pad_value=0):
    padlen = l - x.shape[0]
    pad = [0 for _ in range(2 * len(x.shape[1:]))] + [0, padlen]
    return F.pad(x, pad=pad, value=pad_value)


np_str_obj_array_pattern = re.compile(r"[SaUO]")


def pad_collate(batch, pad_value=0):
    # Utility function to be used as collate_fn for the PyTorch dataloader
    # to handle sequences of varying length.
    # Sequences are padded with zeros by default.
    #
    # Modified default_collate from the official pytorch repo
    # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if len(elem.shape) > 0:
            sizes = [e.shape[0] for e in batch]
            m = max(sizes)
            if not all(s == m for s in sizes):
                # pad tensors which have a temporal dimension
                batch = [pad_tensor(e, m, pad_value=pad_value) for e in batch]
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError("Format not managed : {}".format(elem.dtype))

            return pad_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)

    elif isinstance(elem, collections.abc.Mapping):
        return {key: pad_collate([d[key] for d in batch]) for key in elem}

    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(pad_collate(samples) for samples in zip(*batch)))

    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [pad_collate(samples) for samples in transposed]

    raise TypeError("Format not managed : {}".format(elem_type))


In [2]:
"""
Baseline Pytorch Dataset
"""

import os
from pathlib import Path

import geopandas as gpd
import numpy as np
import torch


class BaselineDataset(torch.utils.data.Dataset):
    def __init__(self, folder: Path):
        super(BaselineDataset, self).__init__()
        self.folder = folder

        # Get metadata
        print("Reading patch metadata ...")
        self.meta_patch = gpd.read_file(os.path.join(folder, "metadata.geojson"))
        self.meta_patch.index = self.meta_patch["ID"].astype(int)
        self.meta_patch.sort_index(inplace=True)
        print("Done.")

        self.len = self.meta_patch.shape[0]
        self.id_patches = self.meta_patch.index
        print("Dataset ready.")

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, item: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        id_patch = self.id_patches[item]

        # Open and prepare satellite data into T x C x H x W arrays
        path_patch = os.path.join(self.folder, "DATA_S2", "S2_{}.npy".format(id_patch))
        data = np.load(path_patch).astype(np.float32)
        data = {"S2": torch.from_numpy(data)}

        # If you have other modalities, add them as fields of the `data` dict ...
        # data["radar"] = ...

        # Open and prepare targets
        target = np.load(
            os.path.join(self.folder, "ANNOTATIONS", "TARGET_{}.npy".format(id_patch))
        )
        target = torch.from_numpy(target[0].astype(int))

        return data, target


In [3]:
import torch.nn as nn


class SimpleSegmentationModel(nn.Module):
    def __init__(self, input_channels: int, nb_classes: int):
        super(SimpleSegmentationModel, self).__init__()

        # A very basic architecture: Encoder + Decoder
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, nb_classes, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Input x shape: (B, Channels, H, W)
        x = self.encoder(x)
        x = self.decoder(x)
        # Output x shape: (B, Classes, H, W)
        return x


In [4]:
from pathlib import Path
from sklearn.metrics import jaccard_score
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def print_iou_per_class(
    targets: torch.Tensor,
    preds: torch.Tensor,
    nb_classes: int,
) -> None:
    """
    Compute IoU between predictions and targets, for each class.

    Args:
        targets (torch.Tensor): Ground truth of shape (B, H, W).
        preds (torch.Tensor): Model predictions of shape (B, nb_classes, H, W).
        nb_classes (int): Number of classes in the segmentation task.
    """

    # Compute IoU for each class
    # Note: I use this for loop to iterate also on classes not in the demo batch

    iou_per_class = []
    for class_id in range(nb_classes):
        iou = jaccard_score(
            targets == class_id,
            preds == class_id,
            average="binary",
            zero_division=0,
        )
        iou_per_class.append(iou)

    for class_id, iou in enumerate(iou_per_class):
        print(
            "class {} - IoU: {:.4f} - targets: {} - preds: {}".format(
                class_id, iou, (targets == class_id).sum(), (preds == class_id).sum()
            )
        )


def print_mean_iou(targets: torch.Tensor, preds: torch.Tensor) -> None:
    """
    Compute mean IoU between predictions and targets.

    Args:
        targets (torch.Tensor): Ground truth of shape (B, H, W).
        preds (torch.Tensor): Model predictions of shape (B, nb_classes, H, W).
    """

    mean_iou = jaccard_score(targets, preds, average="macro")
    print(f"meanIOU (over existing classes in targets): {mean_iou:.4f}")


def train_model(
    data_folder: Path,
    nb_classes: int,
    input_channels: int,
    num_epochs: int = 10,
    batch_size: int = 4,
    learning_rate: float = 1e-3,
    device: str = "cpu",
    verbose: bool = False,
) -> SimpleSegmentationModel:
    """
    Training pipeline.
    """
    # Create data loader
    dataset = BaselineDataset(data_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, collate_fn=pad_collate, shuffle=True
    )

    # Initialize the model, loss function, and optimizer
    model = SimpleSegmentationModel(input_channels, nb_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Move the model to the appropriate device (GPU if available)
    device = torch.device(device)
    model.to(device)

    # Training loop
    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        running_loss = 0.0

        for i, (inputs, targets) in tqdm(enumerate(dataloader), total=len(dataloader)):
            # Move data to device
            inputs["S2"] = inputs["S2"].to(device)  # Satellite data
            targets = targets.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(torch.median(inputs["S2"],1).values)  # only use the 10th image

            # Loss computation
            loss = criterion(outputs, targets)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Get the predicted class per pixel (B, H, W)
            preds = torch.argmax(outputs, dim=1)

            # Move data from GPU/Metal to CPU
            targets = targets.cpu().numpy().flatten()
            preds = preds.cpu().numpy().flatten()

            if verbose:
                # Print IOU for debugging
                print_iou_per_class(targets, preds, nb_classes)
                print_mean_iou(targets, preds)

        # Print the loss for this epoch
        epoch_loss = running_loss / len(dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    print("Training complete.")
    return model


# New model

In [12]:
from enum import Enum, IntEnum
from typing import Optional, Union

import torch
from torch import nn


class ActivationFunction(str, Enum):
    RELU: str = "relu"
    LEAKY: str = "leaky"
    ELU: str = "elu"


class NormalizationLayer(str, Enum):
    BATCH: str = "batch"
    INSTANCE: str = "instance"


class Dimensions(IntEnum):
    TWO: int = 2
    THREE: int = 3


class ConvMode(str, Enum):
    SAME: str = "same"
    VALID: str = "valid"


class UpMode(str, Enum):
    TRANSPOSED: str = "transposed"
    NEAREST: str = "nearest"
    LINEAR: str = "linear"
    BILINEAR: str = "bilinear"
    BICUBIC: str = "bicubic"
    TRILINEAR: str = "trilinear"


@torch.jit.script
def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor):
    """
    Center-crops the encoder_layer to the size of the decoder_layer,
    so that merging (concatenation) between levels/blocks is possible.
    This is only necessary for input sizes != 2**n for 'same' padding and always required for 'valid' padding.
    """
    if encoder_layer.shape[2:] != decoder_layer.shape[2:]:
        ds = encoder_layer.shape[2:]
        es = decoder_layer.shape[2:]
        assert ds[0] >= es[0]
        assert ds[1] >= es[1]
        if encoder_layer.dim() == 4:  # 2D
            encoder_layer = encoder_layer[
                :,
                :,
                ((ds[0] - es[0]) // 2) : ((ds[0] + es[0]) // 2),
                ((ds[1] - es[1]) // 2) : ((ds[1] + es[1]) // 2),
            ]
        elif encoder_layer.dim() == 5:  # 3D
            assert ds[2] >= es[2]
            encoder_layer = encoder_layer[
                :,
                :,
                ((ds[0] - es[0]) // 2) : ((ds[0] + es[0]) // 2),
                ((ds[1] - es[1]) // 2) : ((ds[1] + es[1]) // 2),
                ((ds[2] - es[2]) // 2) : ((ds[2] + es[2]) // 2),
            ]
    return encoder_layer, decoder_layer


def conv_layer(dim: int) -> Union[nn.Conv2d, nn.Conv3d]:
    conv_layers: dict = {Dimensions.TWO: nn.Conv2d, Dimensions.THREE: nn.Conv3d}
    return conv_layers[dim]


def get_conv_layer(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 3,
    stride: int = 1,
    padding: int = 1,
    bias: bool = True,
    dim: int = Dimensions.TWO,
) -> Union[nn.Conv2d, nn.Conv3d]:
    layer: Union[nn.Conv2d, nn.Conv3d] = conv_layer(dim=dim)
    return layer(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )


def conv_transpose_layer(dim: int) -> Union[nn.ConvTranspose2d, nn.ConvTranspose3d]:
    conv_transpose_layers: dict = {
        Dimensions.TWO: nn.ConvTranspose2d,
        Dimensions.THREE: nn.ConvTranspose3d,
    }

    return conv_transpose_layers[dim]


def get_up_layer(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 2,
    stride: int = 2,
    dim: int = Dimensions.TWO,
    up_mode: str = UpMode.TRANSPOSED,
) -> Union[Union[nn.ConvTranspose2d, nn.ConvTranspose3d], nn.Upsample]:
    if up_mode == UpMode.TRANSPOSED:
        return conv_transpose_layer(dim=dim)(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
    else:
        return nn.Upsample(scale_factor=2.0, mode=up_mode)


def maxpool_layer(dim: int) -> Union[nn.MaxPool2d, nn.MaxPool3d]:
    maxpool_layers: dict = {
        Dimensions.TWO: nn.MaxPool2d,
        Dimensions.THREE: nn.MaxPool3d,
    }
    return maxpool_layers[dim]


def get_maxpool_layer(
    kernel_size: int = 2, stride: int = 2, padding: int = 0, dim: int = Dimensions.TWO
) -> Union[nn.MaxPool2d, nn.MaxPool3d]:
    layer = maxpool_layer(dim=dim)
    return layer(kernel_size=kernel_size, stride=stride, padding=padding)


def get_activation_layer(activation: str) -> Union[nn.ReLU, nn.LeakyReLU, nn.ELU]:
    activation_functions: dict = {
        ActivationFunction.RELU: nn.ReLU(),
        ActivationFunction.LEAKY: nn.LeakyReLU(negative_slope=0.1),
        ActivationFunction.ELU: nn.ELU(),
    }

    return activation_functions[activation]


def get_normalization_layer(
    normalization: str, num_channels: int, dim: int
) -> Union[
    Union[nn.BatchNorm2d, nn.BatchNorm3d],
    Union[nn.InstanceNorm2d, nn.InstanceNorm3d],
]:
    normalization_layers: dict = {
        Dimensions.TWO: {
            NormalizationLayer.BATCH: nn.BatchNorm2d(num_channels),
            NormalizationLayer.INSTANCE: nn.InstanceNorm2d(num_channels),
        },
        Dimensions.THREE: {
            NormalizationLayer.BATCH: nn.BatchNorm3d(num_channels),
            NormalizationLayer.INSTANCE: nn.InstanceNorm3d(num_channels),
        },
    }

    return normalization_layers[dim][normalization]


class Concatenate(nn.Module):
    def __init__(self):
        super(Concatenate, self).__init__()

    def forward(self, layer_1, layer_2):
        x = torch.cat((layer_1, layer_2), 1)

        return x


class DownBlock(nn.Module):
    """
    A helper Module that performs 2 Convolutions and 1 MaxPool.
    An activation follows each convolution.
    A normalization layer follows each convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        pooling: bool = True,
        activation: str = ActivationFunction.RELU,
        normalization: Optional[str] = None,
        dim: int = Dimensions.TWO,
        conv_mode: str = ConvMode.SAME,
    ):
        super().__init__()

        conv_modes: dict = {ConvMode.SAME: 1, ConvMode.VALID: 0}

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling
        self.normalization = normalization
        self.padding = conv_modes[conv_mode]
        self.dim = dim
        self.activation = activation

        # conv layers
        self.conv1 = get_conv_layer(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=self.padding,
            bias=True,
            dim=self.dim,
        )
        self.conv2 = get_conv_layer(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=self.padding,
            bias=True,
            dim=self.dim,
        )

        # pooling layer
        if self.pooling:
            self.pool = get_maxpool_layer(
                kernel_size=2, stride=2, padding=0, dim=self.dim
            )

        # activation layers
        self.act1 = get_activation_layer(activation=self.activation)
        self.act2 = get_activation_layer(activation=self.activation)

        # normalization layers
        if self.normalization:
            self.norm1 = get_normalization_layer(
                normalization=self.normalization,
                num_channels=self.out_channels,
                dim=self.dim,
            )
            self.norm2 = get_normalization_layer(
                normalization=self.normalization,
                num_channels=self.out_channels,
                dim=self.dim,
            )

    def forward(self, x):
        y = self.conv1(x)  # convolution 1
        y = self.act1(y)  # activation 1
        if self.normalization:
            y = self.norm1(y)  # normalization 1
        y = self.conv2(y)  # convolution 2
        y = self.act2(y)  # activation 2
        if self.normalization:
            y = self.norm2(y)  # normalization 2

        before_pooling = y  # save the outputs before the pooling operation
        if self.pooling:
            y = self.pool(y)  # pooling
        return y, before_pooling


class UpBlock(nn.Module):
    """
    A helper Module that performs 2 Convolutions and 1 UpConvolution/Upsample.
    An activation follows each convolution.
    A normalization layer follows each convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: str = ActivationFunction.RELU,
        normalization: Optional[str] = None,
        dim: int = Dimensions.TWO,
        conv_mode: str = ConvMode.SAME,
        up_mode: str = UpMode.TRANSPOSED,
    ):
        super().__init__()

        conv_modes: dict = {ConvMode.SAME: 1, ConvMode.VALID: 0}

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalization = normalization
        self.padding = conv_modes[conv_mode]
        self.dim = dim
        self.activation = activation

        self.up_mode = up_mode

        # upconvolution/upsample layer
        self.up = get_up_layer(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=2,
            stride=2,
            dim=self.dim,
            up_mode=self.up_mode,
        )

        # conv layers
        self.conv0 = get_conv_layer(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            dim=self.dim,
        )
        self.conv1 = get_conv_layer(
            in_channels=2 * self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=self.padding,
            bias=True,
            dim=self.dim,
        )
        self.conv2 = get_conv_layer(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=self.padding,
            bias=True,
            dim=self.dim,
        )

        # activation layers
        self.act0 = get_activation_layer(self.activation)
        self.act1 = get_activation_layer(self.activation)
        self.act2 = get_activation_layer(self.activation)

        # normalization layers
        if self.normalization:
            self.norm0 = get_normalization_layer(
                normalization=self.normalization,
                num_channels=self.out_channels,
                dim=self.dim,
            )
            self.norm1 = get_normalization_layer(
                normalization=self.normalization,
                num_channels=self.out_channels,
                dim=self.dim,
            )
            self.norm2 = get_normalization_layer(
                normalization=self.normalization,
                num_channels=self.out_channels,
                dim=self.dim,
            )

        # concatenate layer
        self.concat = Concatenate()

    def forward(self, encoder_layer, decoder_layer):
        """
        Forward pass
        encoder_layer: Tensor from the encoder pathway
        decoder_layer: Tensor from the decoder pathway (to be up'd)
        """
        up_layer = self.up(decoder_layer)  # up-convolution/up-sampling
        cropped_encoder_layer, dec_layer = autocrop(encoder_layer, up_layer)  # cropping

        if self.up_mode != UpMode.TRANSPOSED:
            # We need to reduce the channel dimension with a conv layer
            up_layer = self.conv0(up_layer)  # convolution 0
        up_layer = self.act0(up_layer)  # activation 0
        if self.normalization:
            up_layer = self.norm0(up_layer)  # normalization 0

        merged_layer = self.concat(up_layer, cropped_encoder_layer)  # concatenation
        y = self.conv1(merged_layer)  # convolution 1
        y = self.act1(y)  # activation 1
        if self.normalization:
            y = self.norm1(y)  # normalization 1
        y = self.conv2(y)  # convolution 2
        y = self.act2(y)  # acivation 2
        if self.normalization:
            y = self.norm2(y)  # normalization 2
        return y


class UNet(nn.Module):
    """
    activation: 'relu', 'leaky', 'elu'
    normalization: 'batch', 'instance', 'group{group_size}'
    conv_mode: 'same', 'valid'
    dim: 2, 3
    up_mode: 'transposed', 'nearest', 'linear', 'bilinear', 'bicubic', 'trilinear'
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 2,
        n_blocks: int = 3,
        start_filters: int = 32,
        activation: str = ActivationFunction.RELU,
        normalization: str = NormalizationLayer.BATCH,
        conv_mode: str = ConvMode.SAME,
        dim: int = Dimensions.TWO,
        up_mode: str = UpMode.TRANSPOSED,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_blocks = n_blocks
        self.start_filters = start_filters
        self.activation = activation
        self.normalization = normalization
        self.conv_mode = conv_mode
        self.dim = dim
        self.up_mode = up_mode

        self.down_blocks = []
        self.up_blocks = []

        # create encoder path
        for i in range(self.n_blocks):
            num_filters_in = self.in_channels if i == 0 else num_filters_out
            num_filters_out = self.start_filters * (2**i)
            pooling = True if i < self.n_blocks - 1 else False

            down_block = DownBlock(
                in_channels=num_filters_in,
                out_channels=num_filters_out,
                pooling=pooling,
                activation=self.activation,
                normalization=self.normalization,
                conv_mode=self.conv_mode,
                dim=self.dim,
            )

            self.down_blocks.append(down_block)

        # create decoder path (requires only n_blocks-1 blocks)
        for i in range(n_blocks - 1):
            num_filters_in = num_filters_out
            num_filters_out = num_filters_in // 2

            up_block = UpBlock(
                in_channels=num_filters_in,
                out_channels=num_filters_out,
                activation=self.activation,
                normalization=self.normalization,
                conv_mode=self.conv_mode,
                dim=self.dim,
                up_mode=self.up_mode,
            )

            self.up_blocks.append(up_block)

        # final convolution
        self.conv_final = get_conv_layer(
            num_filters_out,
            self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            dim=self.dim,
        )

        # add the list of modules to current module
        self.down_blocks = nn.ModuleList(self.down_blocks)
        self.up_blocks = nn.ModuleList(self.up_blocks)

        # initialize the weights
        self.initialize_parameters()

    @staticmethod
    def weight_init(module, method, **kwargs):
        if isinstance(
            module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)
        ):
            method(module.weight, **kwargs)  # weights

    @staticmethod
    def bias_init(module, method, **kwargs):
        if isinstance(
            module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)
        ):
            method(module.bias, **kwargs)  # bias

    def initialize_parameters(
        self, method_weights=nn.init.xavier_uniform_, method_bias=nn.init.zeros_
    ):
        for module in self.modules():
            self.weight_init(module, method_weights)  # initialize weights
            self.bias_init(module, method_bias)  # initialize bias

    def forward(self, x: torch.tensor):
        encoder_output = []

        # Encoder pathway
        x = torch.transpose(x,1,2)
        for module in self.down_blocks:
            x, before_pooling = module(x)
            encoder_output.append(before_pooling)

        # Decoder pathway
        for i, module in enumerate(self.up_blocks):
            before_pool = encoder_output[-(i + 2)]
            x = module(before_pool, x)

        x = self.conv_final(x)

        return x

    def __repr__(self):
        attributes = {
            attr_key: self.__dict__[attr_key]
            for attr_key in self.__dict__.keys()
            if "_" not in attr_key[0] and "training" not in attr_key
        }
        d = {self.__class__.__name__: attributes}
        return f"{d}"


# if __name__ == "__main__":
#     unet = UNet(
#         in_channels=1,
#         out_channels=2,
#         n_blocks=4,
#         start_filters=32,
#         activation=ActivationFunction.RELU,
#         normalization=NormalizationLayer.BATCH,
#         conv_mode=ConvMode.SAME,
#         dim=Dimensions.TWO,
#         up_mode=UpMode.TRANSPOSED,
#     )
#     from torchinfo import summary

#     # [B, C, H, W]
#     summary = summary(model=unet, input_size=(1, 1, 512, 512), device="cpu")

In [11]:
from pathlib import Path
from sklearn.metrics import jaccard_score
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


def print_iou_per_class(
    targets: torch.Tensor,
    preds: torch.Tensor,
    nb_classes: int,
) -> None:
    """
    Compute IoU between predictions and targets, for each class.

    Args:
        targets (torch.Tensor): Ground truth of shape (B, H, W).
        preds (torch.Tensor): Model predictions of shape (B, nb_classes, H, W).
        nb_classes (int): Number of classes in the segmentation task.
    """

    # Compute IoU for each class
    # Note: I use this for loop to iterate also on classes not in the demo batch

    iou_per_class = []
    for class_id in range(nb_classes):
        iou = jaccard_score(
            targets == class_id,
            preds == class_id,
            average="binary",
            zero_division=0,
        )
        iou_per_class.append(iou)

    for class_id, iou in enumerate(iou_per_class):
        print(
            "class {} - IoU: {:.4f} - targets: {} - preds: {}".format(
                class_id, iou, (targets == class_id).sum(), (preds == class_id).sum()
            )
        )


def print_mean_iou(targets: torch.Tensor, preds: torch.Tensor) -> None:
    """
    Compute mean IoU between predictions and targets.

    Args:
        targets (torch.Tensor): Ground truth of shape (B, H, W).
        preds (torch.Tensor): Model predictions of shape (B, nb_classes, H, W).
    """

    mean_iou = jaccard_score(targets, preds, average="macro")
    print(f"meanIOU (over existing classes in targets): {mean_iou:.4f}")


def train_model_2(
    data_folder: Path,
    nb_classes: int,
    input_channels: int,
    num_epochs: int = 10,
    batch_size: int = 4,
    learning_rate: float = 1e-3,
    device: str = "cpu",
    verbose: bool = False,
) -> UNet:
    """
    Training pipeline.
    """
    # Create data loader
    dataset = BaselineDataset(data_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, collate_fn=pad_collate, shuffle=True
    )

    # Initialize the model, loss function, and optimizer
    model = UNet( in_channels=input_channels,out_channels=nb_classes,dim=3)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Move the model to the appropriate device (GPU if available)
    device = torch.device(device)
    model.to(device)

    # Training loop
    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        running_loss = 0.0

        for i, (inputs, targets) in tqdm(enumerate(dataloader), total=len(dataloader)):
            # Move data to device
            inputs["S2"] = inputs["S2"].to(device)  # Satellite data
            targets = targets.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass

            outputs = model(inputs["S2"]) 
            outputs_median_time = torch.median(outputs,2).values

            # Loss computation
            loss = criterion(outputs_median_time, targets)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Get the predicted class per pixel (B, H, W)
            preds = torch.argmax(outputs_median_time, dim=1)

            # Move data from GPU/Metal to CPU
            targets = targets.cpu().numpy().flatten()
            preds = preds.cpu().numpy().flatten()

            if verbose:
                # Print IOU for debugging
                print_iou_per_class(targets, preds, nb_classes)
                print_mean_iou(targets, preds)

        # Print the loss for this epoch
        epoch_loss = running_loss / len(dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    print("Training complete.")
    return model


# if __name__ == "__main__":
#     # Example usage:
#     model = train_model(
#         data_folder=Path(
#             "/Users/louis.stefanuto.c/Documents/pastis-benchmark-mines2024/DATA/TRAIN/"
#         ),
#         nb_classes=20,
#         input_channels=10,
#         num_epochs=100,
#         batch_size=32,
#         learning_rate=1e-3,
#         device="mps",
#         verbose=True,
#     )


NameError: name 'UNet' is not defined

In [None]:

DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
folder=Path('/kaggle/input/data-challenge-invent-mines-2024/DATA/DATA/TRAIN')
model = train_model_2(
            data_folder=folder,
            nb_classes= 20,
            input_channels=10,
            num_epochs = 10,
            batch_size= 5,
            learning_rate= 1e-3,
            device= DEVICE,
            verbose= True,
        )


In [None]:
import gc
gc.collect()

In [None]:
torch.cuda.empty_cache()


In [None]:
torch.save(model.state_dict(), 'unet3d.pt')

In [26]:
import torch
from torch.utils.data import DataLoader
from torch import nn
model_file = 'unet3d.pt'

"""
Baseline Pytorch Dataset
"""

import os
from pathlib import Path

import geopandas as gpd
import numpy as np
import torch


class OutputDataset(torch.utils.data.Dataset):
    def __init__(self, folder: Path):
        super(OutputDataset, self).__init__()
        self.folder = folder

        # Get metadata
        print("Reading patch metadata ...")
        self.meta_patch = gpd.read_file(os.path.join(folder, "metadata.geojson"))
        self.meta_patch.index = self.meta_patch["ID"].astype(int)
        self.meta_patch.sort_index(inplace=True)
        print("Done.")

        self.len = self.meta_patch.shape[0]
        self.id_patches = self.meta_patch.index
        print("Dataset ready.")

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, item: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        id_patch = self.id_patches[item]

        # Open and prepare satellite data into T x C x H x W arrays
        path_patch = os.path.join(self.folder, "DATA_S2", "S2_{}.npy".format(id_patch))
        data = np.load(path_patch).astype(np.float32)
        data = {"S2": torch.from_numpy(data)}
        data['ID']=id_patch
        # If you have other modalities, add them as fields of the `data` dict ...
        # data["radar"] = ...

        # Open and prepare targets
#         target = np.load(
#             os.path.join(self.folder, "ANNOTATIONS", "TARGET_{}.npy".format(id_patch))
#         )
#         target = torch.from_numpy(target[0].astype(int))

        return data

import numpy as np


def masks_to_str(predictions: np.ndarray) -> list[str]:
    """
    Convert the

    Args:
        predictions (np.ndarray): predictions as a 3D batch (B, H, W)

    Returns:
        list[str]: a list of B strings, each string is a flattened stringified prediction mask
    """
    return [" ".join(f"{x}" for x in np.ravel(x)) for x in predictions]

import pandas as pd 

def eval_model(
    data_folder: Path,
    model_file: str,
    batch_size: int = 1,
    device: str = "cpu",

) -> None:
    """
    Training pipeline.
    """
    # Create data loader
    dataset = OutputDataset(data_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, collate_fn=pad_collate, shuffle=False
    )

    # Load the saved model
    model = UNet(in_channels=10,out_channels=20,dim=3)
    model.load_state_dict(torch.load(model_file, weights_only=True))
    model.to(device)

    # Set the model in evaluation mode
    model.eval()

    # 3. Evaluate the Model on Test Samples
    # Disable gradient computation for evaluation
    res = []
    with torch.no_grad():
        for i, (inputs) in tqdm(enumerate(dataloader), total=len(dataloader)):
            # Move data to device
            inputs["S2"] = inputs["S2"].to(device)  # Satellite data
            patch_id = inputs['ID']
            # Forward pass through the model
            outputs = model(inputs['S2'])
            outputs_median_time = torch.median(outputs,2).values
            preds = torch.argmax(outputs_median_time, dim=1).cpu()
            preds_str = masks_to_str(preds)
            res.append([patch_id.item(),preds_str[0]])
    return pd.DataFrame(res,columns=['ID','MASKS'])
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
preds = eval_model(data_folder='/kaggle/input/data-challenge-invent-mines-2024/DATA/DATA/TEST',model_file='/kaggle/input/unet3d/pytorch/default/1/unet3d.pt',device=DEVICE)
preds.to_csv('submissions_1_unet3d.csv',index=False)

Reading patch metadata ...
Done.
Dataset ready.


100%|██████████| 474/474 [01:10<00:00,  6.76it/s]


In [28]:
preds

Unnamed: 0,ID,MASKS
0,20000,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
1,20001,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
2,20002,19 19 19 19 19 19 19 19 19 9 9 9 9 9 9 9 9 9 9...
3,20003,19 19 19 19 19 19 19 19 19 19 19 19 19 19 19 1...
4,20004,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
...,...,...
469,20469,0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...
470,20470,19 1 19 19 19 19 19 19 19 19 19 19 19 19 1 1 1...
471,20471,19 19 19 19 19 19 19 19 19 19 19 19 19 19 19 1...
472,20472,19 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...


In [29]:
preds.to_csv('submissions_1_unet3d.csv',index=False)