In [1]:
# Mount Google Drive into the Colab filesystem; the data and checkpoint paths
# configured later in this notebook live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from pathlib import Path
import os

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import jaccard_score
from tqdm import tqdm

from collections import OrderedDict

# from utils.collate import pad_collate
# from dataset import BaselineDataset
# # from baseline.model import SimpleSegmentationModel
# # from torchvision.models.vision_transformer import VisionTransformer
# # from baseline.SegmentationViT import SegmentationViT
# # from baseline.model_vision_transformer import TemporalVisionTransformer

# from models import TemporalViTEncoder, ConvTransformerTokensToEmbeddingNeck, ViTConvNeckModel

## Utils

In [3]:
import collections.abc
import re

import torch
from torch.nn import functional as F


def pad_tensor(x, l, pad_value=0):
    """Extend ``x`` along its first dimension so that it has length ``l``.

    Args:
        x: tensor to pad; only dimension 0 is modified.
        l: target size of dimension 0.
        pad_value: constant written into the added entries.

    Returns:
        The padded tensor, with the new entries appended at the end of dim 0.
    """
    # F.pad expects (before, after) pairs listed from the LAST dimension
    # backwards, so every trailing dimension gets a (0, 0) pair and the pair
    # for dimension 0 comes last.
    widths = [0, 0] * (len(x.shape) - 1)
    missing = l - x.shape[0]
    widths.extend([0, missing])
    return F.pad(x, pad=widths, value=pad_value)


# Matches a NumPy ``dtype.str`` containing one of the characters S, a, U or O.
# pad_collate uses it to reject arrays with such dtypes instead of trying to
# convert them to tensors.
np_str_obj_array_pattern = re.compile(r"[SaUO]")


def pad_collate(batch, pad_value=0):
    """Collate a list of samples into a batch, padding variable-length tensors.

    Intended as ``collate_fn`` for a PyTorch ``DataLoader``. Tensors whose first
    dimension differs between samples are padded at the end of that dimension
    with ``pad_value`` before being stacked. Containers (mappings, namedtuples,
    sequences) are collated recursively and ``pad_value`` is forwarded to every
    nested call.

    Modified ``default_collate`` from the official PyTorch repository:
    https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py

    Args:
        batch: list of samples, all of the same structure.
        pad_value: constant used to fill padded positions (default 0).

    Returns:
        The collated batch: a stacked tensor, or a container of collated items
        mirroring the structure of one sample. Strings are returned as the
        unchanged list of samples.

    Raises:
        TypeError: for NumPy string/object arrays and unsupported element types.
        RuntimeError: if sequence samples do not all have the same length.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if len(elem.shape) > 0:
            sizes = [e.shape[0] for e in batch]
            m = max(sizes)
            if not all(s == m for s in sizes):
                # pad tensors which have a temporal dimension
                batch = [pad_tensor(e, m, pad_value=pad_value) for e in batch]
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            # Shape the output up front (using the padded sample shape) so that
            # torch.stack does not have to resize a mismatching `out` tensor.
            out = elem.new(storage).resize_(len(batch), *batch[0].shape)
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError("Format not managed : {}".format(elem.dtype))

            return pad_collate(
                [torch.as_tensor(b) for b in batch], pad_value=pad_value
            )
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)

    elif isinstance(elem, (str, bytes)):
        # Strings are Sequences too: without this branch they would be
        # "transposed" character by character and recurse without end.
        return batch

    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)

    elif isinstance(elem, int):
        return torch.tensor(batch)

    elif isinstance(elem, collections.abc.Mapping):
        return {
            key: pad_collate([d[key] for d in batch], pad_value=pad_value)
            for key in elem
        }

    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(
            *(pad_collate(samples, pad_value=pad_value) for samples in zip(*batch))
        )

    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [pad_collate(samples, pad_value=pad_value) for samples in transposed]

    raise TypeError("Format not managed : {}".format(elem_type))


## Dataset

In [4]:
import os
from pathlib import Path

import geopandas as gpd
import numpy as np
import torch


class BaselineDataset(torch.utils.data.Dataset):
    """Training dataset yielding (satellite time series, annotation) pairs.

    Patch IDs are read from ``metadata.geojson`` inside ``folder`` and sorted.
    For a patch with a given ID the sample is built from
    ``DATA_S2/S2_<ID>.npy`` and ``ANNOTATIONS/TARGET_<ID>.npy``.

    Args:
        folder: root directory of the split.
        channels: indices kept along axis 1 of the loaded satellite array.
    """

    def __init__(self, folder: Path, channels=(2, 3, 4, 5, 8, 9)):
        super().__init__()
        self.folder = folder

        # Get metadata
        print("Reading patch metadata ...")
        self.meta_patch = gpd.read_file(os.path.join(folder, "metadata.geojson"))
        self.meta_patch.index = self.meta_patch["ID"].astype(int)
        self.meta_patch.sort_index(inplace=True)
        print("Done.")

        self.len = self.meta_patch.shape[0]
        self.id_patches = self.meta_patch.index
        print("Dataset ready.")

        # Keep a private list copy: the default is immutable and a caller's
        # list cannot be altered through (or alter) this dataset afterwards.
        self.channels = list(channels)

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, item: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(data, target)`` for the ``item``-th patch in sorted-ID order.

        ``data`` is the float32 satellite array restricted to ``self.channels``
        on axis 1 (T x C x H x W according to the original comment); ``target``
        is the first entry of the annotation array cast to int64.
        """
        id_patch = self.id_patches[item]

        # Open and prepare satellite data into T x C x H x W arrays
        path_patch = os.path.join(self.folder, "DATA_S2", "S2_{}.npy".format(id_patch))
        data = np.load(path_patch).astype(np.float32)
        data = torch.from_numpy(data)[:, self.channels, :, :]

        # Open and prepare targets
        target = np.load(
            os.path.join(self.folder, "ANNOTATIONS", "TARGET_{}.npy".format(id_patch))
        )
        # Explicit 64-bit integers: the builtin ``int`` maps to a
        # platform-dependent NumPy width.
        target = torch.from_numpy(target[0].astype(np.int64))

        return data, target

class BaselineDatasetTest(torch.utils.data.Dataset):
    """Inference dataset yielding satellite time series only (no annotations).

    Patch IDs are read from ``metadata.geojson`` inside ``folder`` and sorted.
    For a patch with a given ID the sample is loaded from
    ``DATA_S2/S2_<ID>.npy``.

    Args:
        folder: root directory of the split.
        channels: indices kept along axis 1 of the loaded satellite array.
    """

    def __init__(self, folder: Path, channels=(2, 3, 4, 5, 8, 9)):
        super().__init__()
        self.folder = folder

        # Get metadata
        print("Reading patch metadata ...")
        self.meta_patch = gpd.read_file(os.path.join(folder, "metadata.geojson"))
        self.meta_patch.index = self.meta_patch["ID"].astype(int)
        self.meta_patch.sort_index(inplace=True)
        print("Done.")

        self.len = self.meta_patch.shape[0]
        self.id_patches = self.meta_patch.index
        print("Dataset ready.")

        # Keep a private list copy: the default is immutable and a caller's
        # list cannot be altered through (or alter) this dataset afterwards.
        self.channels = list(channels)

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, item: int) -> torch.Tensor:
        """Return the float32 satellite array of the ``item``-th patch.

        The array is restricted to ``self.channels`` on axis 1 (T x C x H x W
        according to the original comment).
        """
        id_patch = self.id_patches[item]

        # Open and prepare satellite data into T x C x H x W arrays
        path_patch = os.path.join(self.folder, "DATA_S2", "S2_{}.npy".format(id_patch))
        data = np.load(path_patch).astype(np.float32)
        data = torch.from_numpy(data)[:, self.channels, :, :]

        return data

## Models

In [5]:
!pip install timm



In [6]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import to_2tuple
from timm.models.vision_transformer import Block
from typing import List
import torchvision.transforms as vt

def _convTranspose2dOutput(
    input_size: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    output_padding: int,
):
    """
    Calculate the output size of a ConvTranspose2d.
    Taken from: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    """
    return (
        (input_size - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )


def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor):
    """
    Sine/cosine encoding of a set of positions.

    embed_dim: output dimension for each position (must be even)
    pos: positions to encode; flattened to size (M,) before use
    out: (M, D) array, first D/2 columns are sines, last D/2 are cosines
    """
    assert embed_dim % 2 == 0

    # Frequencies 1 / 10000^(i / (D/2)) for i = 0 .. D/2 - 1
    freqs = np.arange(embed_dim // 2, dtype=np.float32)
    freqs /= embed_dim / 2.0
    freqs = 1.0 / 10000**freqs  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum("m,d->md", flat_pos, freqs)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def get_3d_sincos_pos_embed(embed_dim: int, grid_size: tuple, cls_token: bool = False):
    # Copyright (c) Meta Platforms, Inc. and affiliates.
    # All rights reserved.

    # This source code is licensed under the license found in the
    # LICENSE file in the root directory of this source tree.
    # --------------------------------------------------------
    # Position embedding utils
    # --------------------------------------------------------
    """
    Build a sin-cos position table for a (t, h, w) token grid.

    grid_size: 3d tuple of grid size: t, h, w
    cls_token: when True, an all-zero row is prepended
    return:
    pos_embed: L, D (L = t*h*w, plus one if cls_token); columns are laid out
    as [width part | height part | time part] with 6/16, 6/16 and 4/16 of D
    """

    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    sixteenth = embed_dim // 16
    spatial_dim = sixteenth * 6
    temporal_dim = sixteenth * 4

    w_table = get_1d_sincos_pos_embed_from_grid(spatial_dim, np.arange(w_size))
    h_table = get_1d_sincos_pos_embed_from_grid(spatial_dim, np.arange(h_size))
    t_table = get_1d_sincos_pos_embed_from_grid(temporal_dim, np.arange(t_size))

    # Tokens are ordered with time as the slowest axis and width as the
    # fastest, so each 1-D table is repeated/tiled accordingly.
    tokens_per_frame = h_size * w_size
    w_part = np.tile(w_table, (t_size * h_size, 1))
    h_part = np.tile(np.repeat(h_table, w_size, axis=0), (t_size, 1))
    t_part = np.repeat(t_table, tokens_per_frame, axis=0)

    table = np.concatenate((w_part, h_part, t_part), axis=1)

    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


class PatchEmbed(nn.Module):
    """Turn a stack of frames into a sequence of patch embeddings.

    3D counterpart of timm.models.vision_transformer.PatchEmbed: a Conv3d whose
    kernel and stride both equal (tubelet_size, patch_h, patch_w) projects every
    non-overlapping tubelet to an ``embed_dim`` vector.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        num_frames: int = 3,
        tubelet_size: int = 1,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module = None,
        flatten: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size

        # Number of tubelets along (time, height, width).
        frames = num_frames // tubelet_size
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (frames, rows, cols)
        self.num_patches = frames * rows * cols
        self.flatten = flatten

        tubelet = (tubelet_size, self.patch_size[0], self.patch_size[1])
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=tubelet, stride=tubelet, bias=bias
        )
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed ``x`` of shape (B, C, T, H, W).

        Returns the embedded tokens together with the height and width of the
        patch grid produced by the projection.
        """
        _, _, _, height, width = x.shape
        assert (
            height == self.img_size[0]
        ), f"Input image height ({height}) doesn't match model ({self.img_size[0]})."
        assert (
            width == self.img_size[1]
        ), f"Input image width ({width}) doesn't match model ({self.img_size[1]})."

        tokens = self.proj(x)
        grid_h, grid_w = tokens.shape[3], tokens.shape[4]
        if self.flatten:
            # B,C,T,H,W -> B,C,L -> B,L,C
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens), grid_h, grid_w


class Norm2d(nn.Module):
    """LayerNorm (eps=1e-6) applied over axis 1 of a 4-D tensor.

    The input is permuted so that axis 1 becomes the last axis, normalised, and
    permuted back; the result is returned as a contiguous tensor.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)
        normalised = self.ln(channels_last)
        return normalised.permute(0, 3, 1, 2).contiguous()


class ConvTransformerTokensToEmbeddingNeck(nn.Module):
    """
    Neck that turns the token sequence produced by a transformer into a dense
    feature map usable by standard convolutional layers.
    The tokens are rearranged onto an (Hp, Wp) grid and upsampled by four
    ConvTranspose2d layers with kernel_size=2 and stride=2 (two per stage, with
    a Norm2d + GELU between the two layers of each stage).
    """

    def __init__(
        self,
        embed_dim: int,
        output_embed_dim: int,
        # num_frames: int = 1,
        Hp: int = 14,
        Wp: int = 14,
        drop_cls_token: bool = True,
    ):
        """

        Args:
            embed_dim (int): Input embedding dimension
            output_embed_dim (int): Output embedding dimension
            Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
            Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
            drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. This assumes the cls token is the first token. Defaults to True.
        """
        super().__init__()
        self.drop_cls_token = drop_cls_token
        self.Hp = Hp
        self.Wp = Wp
        # self.num_frames = num_frames

        # Settings shared by the four transposed convolutions.
        deconv_cfg = dict(
            kernel_size=2,
            stride=2,
            dilation=1,
            padding=0,
            output_padding=0,
        )

        # Spatial size after the four upsampling layers.
        out_h, out_w = Hp, Wp
        for _ in range(4):
            out_h = _convTranspose2dOutput(
                out_h,
                deconv_cfg["stride"],
                deconv_cfg["padding"],
                deconv_cfg["dilation"],
                deconv_cfg["kernel_size"],
                deconv_cfg["output_padding"],
            )
            out_w = _convTranspose2dOutput(
                out_w,
                deconv_cfg["stride"],
                deconv_cfg["padding"],
                deconv_cfg["dilation"],
                deconv_cfg["kernel_size"],
                deconv_cfg["output_padding"],
            )
        self.H_out = out_h
        self.W_out = out_w

        self.embed_dim = embed_dim
        self.output_embed_dim = output_embed_dim

        def upsample(in_channels):
            # One x2 transposed convolution producing output_embed_dim channels.
            return nn.ConvTranspose2d(in_channels, self.output_embed_dim, **deconv_cfg)

        self.fpn1 = nn.Sequential(
            upsample(self.embed_dim),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
            upsample(self.output_embed_dim),
        )
        self.fpn2 = nn.Sequential(
            upsample(self.output_embed_dim),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
            upsample(self.output_embed_dim),
        )

    def forward(self, x):
        """Map tokens (B, L, D) to a feature map (N, output_embed_dim, H_out, W_out)."""
        tokens = x[:, 1:, :] if self.drop_cls_token else x
        batch = tokens.shape[0]
        # Move the embedding axis in front, then fold the token axis onto the
        # (Hp, Wp) grid; any remaining factor ends up in the channel axis.
        grid = tokens.permute(0, 2, 1).reshape(batch, -1, self.Hp, self.Wp)

        upsampled = self.fpn2(self.fpn1(grid))

        return upsampled.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))


class TemporalViTEncoder(nn.Module):
    """ViT encoder that accepts a temporal stack of frames.

    The input is split into tubelets by PatchEmbed, a class token is prepended,
    and the sequence goes through ``depth`` transformer blocks and a final norm.

    NOTE: the sin-cos initialisation of ``pos_embed`` (and the custom weight
    initialisation / loading of ``pretrained``) that this class was written with
    is disabled in this notebook. ``pos_embed`` therefore stays all-zero, and
    ``pretrained`` is only stored, until weights are loaded from outside, e.g.
    with ``load_state_dict``.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        num_frames: int = 1,
        tubelet_size: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        norm_pix_loss: bool = False,
        pretrained: str = None
    ):
        """

        Args:
            img_size (int, optional): Input image size. Defaults to 224.
            patch_size (int, optional): Patch size to be used by the transformer. Defaults to 16.
            num_frames (int, optional): Number of frames (temporal dimension) to be input to the encoder. Defaults to 1.
            tubelet_size (int, optional): Tubelet size used in patch embedding. Defaults to 1.
            in_chans (int, optional): Number of input channels. Defaults to 3.
            embed_dim (int, optional): Embedding dimension. Defaults to 1024.
            depth (int, optional): Encoder depth. Defaults to 24.
            num_heads (int, optional): Number of heads used in the encoder blocks. Defaults to 16.
            mlp_ratio (float, optional): Ratio to be used for the size of the MLP in encoder blocks. Defaults to 4.0.
            norm_layer (nn.Module, optional): Norm layer to be used. Defaults to nn.LayerNorm.
            norm_pix_loss (bool, optional): Stored on the instance; not used by this class. Defaults to False.
            pretrained (str, optional): Path to pretrained encoder weights; stored only. Defaults to None.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim
        )
        self.num_frames = num_frames

        # One extra position for the class token.
        sequence_length = self.patch_embed.num_patches + 1
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Not trainable; meant to hold a fixed sin-cos embedding.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, sequence_length, embed_dim), requires_grad=False
        )

        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
            )
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = norm_layer(embed_dim)

        self.norm_pix_loss = norm_pix_loss
        self.pretrained = pretrained

    def _init_weights(self, m):
        """Xavier-uniform init for Linear layers, identity init for LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Encode ``x`` and return the normalised token sequence (class token first)."""
        # embed patches
        tokens, _, _ = self.patch_embed(x)

        # positional embedding for the patch tokens (index 0 belongs to cls)
        tokens = tokens + self.pos_embed[:, 1:, :]

        # prepend the class token, with its own positional embedding
        cls = self.cls_token + self.pos_embed[:, :1, :]
        cls = cls.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            tokens = block(tokens)

        return self.norm(tokens)


class ViTConvNeckModel(nn.Module):
    """Segmentation model: temporal ViT encoder -> conv neck -> conv head.

    Args:
        vit_encoder: encoder exposing ``patch_embed.img_size`` and
            ``num_frames``; called with a (B, C, T, S, S) tensor where S is
            ``patch_embed.img_size[0]``.
        conv_neck: module exposing ``output_embed_dim``; called with the encoder
            output and expected to return a (B, output_embed_dim, H', W') map.
        num_classes: number of output channels of the head.
    """

    def __init__(self, vit_encoder, conv_neck, num_classes):
        super().__init__()
        self.vit_encoder = vit_encoder
        self.conv_neck = conv_neck

        self.head = nn.Sequential(
            nn.Conv2d(self.conv_neck.output_embed_dim, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # 1x1 projection to class logits; no softmax is applied here.
            nn.Conv2d(64, num_classes, kernel_size=1),
        )

        self.encoder_im_size = self.vit_encoder.patch_embed.img_size[0]
        self.encoder_temp_length = self.vit_encoder.num_frames

    def forward(self, x):
        """Compute per-pixel class logits.

        Args:
            x: input of shape (B, T, C, H, W) with H and W not larger than the
                encoder image size.

        Returns:
            Logits of shape (B, num_classes, H, W).

        Raises:
            ValueError: if H or W exceeds the encoder image size.
        """
        B, T, C, H, W = x.shape
        size = self.encoder_im_size
        if H > size or W > size:
            raise ValueError(
                f"Input spatial size ({H}, {W}) exceeds the encoder image size ({size})."
            )

        # Resample only the temporal axis to the number of frames the encoder
        # expects. The axes are arranged as (B, C, H, W, T) so that interpolate
        # (default nearest mode) treats H, W, T as the three resampled dims and
        # leaves H and W at their original size.
        x = x.permute(0, 2, 3, 4, 1)
        x = F.interpolate(x, size=(H, W, self.encoder_temp_length))
        x = x.permute(0, 1, 4, 2, 3)  # [B, C, H, W, T] -> [B, C, T, H, W]
        T_enc = x.shape[2]

        # Centre the frame inside the encoder canvas by replicating border
        # pixels. When the size difference is odd the extra row/column goes to
        # the bottom/right so the padded size is always exactly `size`.
        pad_top = (size - H) // 2
        pad_bottom = size - H - pad_top
        pad_left = (size - W) // 2
        pad_right = size - W - pad_left
        x = x.reshape(B * C, T_enc, H, W)
        x = F.pad(x, (pad_left, pad_right, pad_top, pad_bottom), mode="replicate")
        x_padded = x.view(B, C, T_enc, size, size)

        # Pass through ViT encoder
        vit_output = self.vit_encoder(x_padded)

        # Pass through the convolutional neck to transform tokens into spatial embeddings
        neck_output = self.conv_neck(vit_output)

        # Crop the central H x W window, i.e. undo the padding applied above.
        H_out, W_out = neck_output.shape[-2:]
        crop_h_start = (H_out - H) // 2
        crop_w_start = (W_out - W) // 2
        neck_output_cropped = neck_output[
            :, :, crop_h_start:crop_h_start + H, crop_w_start:crop_w_start + W
        ]

        return self.head(neck_output_cropped)


## TRAINING

In [10]:
# Notebook-wide configuration; every path is resolved under the mounted Drive.
DIR = Path("/content/drive/My Drive/")  # Path to your Google Drive
DATA_PATH_TRAIN = DIR / "data-challenge-invent-mines-2024/DATA/DATA/TRAIN"  # Folder passed to train_model as data_folder; adjust to where your copy of the data is stored
# Checkpoint read by custom_model_init (its 'state_dict' entry is split by key prefix).
CHECKPOINT_PATH = DIR / "multi_temporal_crop_classification_Prithvi_100M.pth"
# Fully pickled model that train_model loads and continues training from.
LAST_CHECKPOINT_PATH = DIR / "checkpoints/vit_epoch1.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {DEVICE}")

device: cuda


In [8]:
def print_iou_per_class(
    targets: torch.Tensor,
    preds: torch.Tensor,
    nb_classes: int,
) -> None:
    """
    Print, for every class id in ``range(nb_classes)``, the binary IoU between
    predictions and targets together with the number of target and predicted
    elements of that class.

    Args:
        targets: ground-truth class ids.
        preds: predicted class ids (compared element-wise against each class
            id, so these are labels, not per-class scores).
        nb_classes (int): Number of classes in the segmentation task.
    """

    # Score every class id, including those absent from the given data
    # (zero_division=0 makes their IoU 0 instead of raising a warning).
    scores = [
        jaccard_score(
            targets == class_id,
            preds == class_id,
            average="binary",
            zero_division=0,
        )
        for class_id in range(nb_classes)
    ]

    template = "class {} - IoU: {:.4f} - targets: {} - preds: {}"
    for class_id, score in enumerate(scores):
        n_targets = (targets == class_id).sum()
        n_preds = (preds == class_id).sum()
        print(template.format(class_id, score, n_targets, n_preds))


def print_mean_iou(targets: torch.Tensor, preds: torch.Tensor) -> None:
    """
    Print the macro-averaged IoU between predictions and targets.

    Args:
        targets: ground-truth class ids.
        preds: predicted class ids, same layout as ``targets``.

    NOTE(review): the set of classes entering the average is whatever
    ``jaccard_score`` selects by default; confirm that it matches the wording of
    the printed message.
    """

    score = jaccard_score(targets, preds, average="macro")
    print("meanIOU (over existing classes in targets): {:.4f}".format(score))

def split_state_dict(state_dict: OrderedDict) -> dict:
    """Group the entries of a state dict by their top-level component prefix.

    Keys starting with ``backbone.``, ``neck.``, ``decode_head.`` or
    ``auxiliary_head.`` are stored, prefix removed, in the OrderedDict of the
    matching component. Keys with any other prefix are dropped.

    Returns:
        dict mapping each of the four component names to its OrderedDict.
    """
    components = ("backbone", "neck", "decode_head", "auxiliary_head")
    groups = {component: OrderedDict() for component in components}

    for full_key, value in state_dict.items():
        for component in components:
            prefix = component + "."
            if full_key.startswith(prefix):
                # Strip the prefix so the key matches the sub-module's own names.
                groups[component][full_key[len(prefix):]] = value
                break

    return groups

def custom_model_init(
    checkpoint='model/multi_temporal_crop_classification_Prithvi_100M.pth',
    device="cpu",
    num_classes=20,
):
    """Build a ViTConvNeckModel whose encoder and neck start from a checkpoint.

    The checkpoint's ``'state_dict'`` entry is split by key prefix; the
    ``backbone.*`` weights are loaded into the encoder and the ``neck.*``
    weights into the neck, both non-strictly. Keys that do not match are
    printed. The segmentation head of the returned model is freshly
    initialised.

    Args:
        checkpoint: path to the checkpoint file.
        device: device the checkpoint tensors are mapped to while loading. The
            returned model is NOT moved to this device by this function.
        num_classes: number of output classes of the segmentation head.

    Returns:
        ViTConvNeckModel: the assembled model.
    """
    # NOTE(security): torch.load may unpickle arbitrary objects depending on the
    # PyTorch version's defaults; only use checkpoints from a trusted source.
    state_dict = torch.load(checkpoint, map_location=torch.device(device))['state_dict']
    split_dicts = split_state_dict(state_dict)

    # Architecture hyper-parameters (single source for encoder and neck).
    num_frames = 3
    img_size = 224
    patch_size = 16
    tubelet_size = 1
    embed_dim = 768
    depth = 6
    num_heads = 8
    bands = [0, 1, 2, 3, 4, 5]
    # The neck folds the temporal axis into channels, hence the factor.
    neck_dim = embed_dim * num_frames

    vit_encoder = TemporalViTEncoder(
        img_size=img_size,
        patch_size=patch_size,
        num_frames=num_frames,
        tubelet_size=tubelet_size,
        in_chans=len(bands),
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=4.0,
        norm_pix_loss=False
    )
    missing_keys, unexpected_keys = vit_encoder.load_state_dict(
        split_dicts["backbone"], strict=False
    )
    if missing_keys or unexpected_keys:
        print("VIT Encoder INIT:")
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")

    conv_neck = ConvTransformerTokensToEmbeddingNeck(
        embed_dim=neck_dim,
        output_embed_dim=neck_dim,
        drop_cls_token=True,
        Hp=14,
        Wp=14,
    )
    missing_keys, unexpected_keys = conv_neck.load_state_dict(
        split_dicts["neck"], strict=False
    )
    if missing_keys or unexpected_keys:
        print("ConvNeck Encoder INIT:")
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")

    return ViTConvNeckModel(
        vit_encoder=vit_encoder,
        conv_neck=conv_neck,
        num_classes=num_classes,
    )


def train_model(
    data_folder: Path,
    nb_classes: int,
    input_channels: int,
    num_epochs: int = 10,
    accumulation_steps: int = 10,
    batch_size: int = 4,
    learning_rate: float = 1e-3,
    device: str = "cpu",
    verbose: bool = False,
) -> ViTConvNeckModel:
    """
    Training pipeline.

    Loads the pickled model stored at ``LAST_CHECKPOINT_PATH`` and trains it on
    ``BaselineDataset(data_folder)`` with cross-entropy loss and Adam, using
    gradient accumulation. After every epoch the whole model is saved to
    ``DIR / "checkpoints" / "vit_prithvi<epoch>.pth"`` (1-based epoch number).

    Args:
        data_folder: root folder of the training data.
        nb_classes: currently unused (the class count comes from the loaded model).
        input_channels: currently unused (the dataset selects its own channels).
        num_epochs: number of passes over the data.
        accumulation_steps: number of batches whose gradients are accumulated
            before each optimizer step.
        batch_size: number of samples per batch.
        learning_rate: Adam learning rate.
        device: device used for training.
        verbose: if True, print the mean IoU of the batches seen since the last
            report every ``accumulation_steps`` batches.

    Returns:
        The trained model.
    """
    # Create data loader
    dataset = BaselineDataset(data_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, collate_fn=pad_collate, shuffle=True
    )
    num_batches = len(dataloader)

    # The checkpoint is a fully pickled model, which needs weights_only=False on
    # PyTorch versions where safe loading is the default.
    # NOTE(security): unpickling executes arbitrary code; only load trusted files.
    model = torch.load(
        LAST_CHECKPOINT_PATH, map_location=torch.device('cpu'), weights_only=False
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Move the model to the appropriate device (GPU if available)
    device = torch.device(device)
    model.to(device)

    checkpoint_dir = DIR / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Training loop
    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        running_loss = 0.0
        # Flattened targets/predictions since the last IoU report, kept on CPU.
        target_chunks = []
        pred_chunks = []
        optimizer.zero_grad()

        for i, (inputs, targets) in tqdm(enumerate(dataloader), total=num_batches):
            # Move data to device
            inputs = inputs.to(device)  # Satellite data
            targets = targets.long().to(device)

            # Forward pass
            outputs = model(inputs)

            # Scale the loss so the accumulated gradient is an average over the
            # accumulation window.
            loss = criterion(outputs, targets) / accumulation_steps

            # Backward pass (accumulate gradients)
            loss.backward()

            # Track the unscaled loss for reporting
            running_loss += loss.item() * accumulation_steps

            # Step every accumulation_steps batches, and on the last batch so
            # leftover gradients are not discarded.
            window_complete = (i + 1) % accumulation_steps == 0
            if window_complete or (i + 1) == num_batches:
                optimizer.step()
                optimizer.zero_grad()

            if verbose:
                # Predicted class per pixel (B, H, W); only needed for the
                # report, so nothing is accumulated when verbose is off.
                preds = torch.argmax(outputs.detach(), dim=1)
                target_chunks.append(targets.reshape(-1).cpu())
                pred_chunks.append(preds.reshape(-1).cpu())

                if window_complete:
                    print_mean_iou(
                        torch.cat(target_chunks).numpy(),
                        torch.cat(pred_chunks).numpy(),
                    )
                    target_chunks.clear()
                    pred_chunks.clear()

        # Print the loss for this epoch
        epoch_loss = running_loss / num_batches
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")
        # One file per epoch; the last one is named after num_epochs.
        torch.save(model, checkpoint_dir / f"vit_prithvi{epoch + 1}.pth")

    print("Training complete.")
    return model



In [None]:
# Continue training from LAST_CHECKPOINT_PATH for two epochs. With batch_size=4
# and accumulation_steps=32 the optimizer steps once every 128 samples.
# nb_classes and input_channels are accepted by train_model but not used by it.
model = train_model(
    data_folder=Path(DATA_PATH_TRAIN),
    nb_classes=20,
    input_channels=6,
    num_epochs=2,
    batch_size=4,
    accumulation_steps = 32,
    learning_rate=1e-4,
    device=DEVICE,
    verbose=True,
)

Reading patch metadata ...
Done.
Dataset ready.


  model = torch.load(LAST_CHECKPOINT_PATH, map_location=torch.device('cpu'))
  7%|▋         | 32/490 [04:58<1:10:35,  9.25s/it]

meanIOU (over existing classes in targets): 0.0211


 13%|█▎        | 64/490 [09:50<1:05:02,  9.16s/it]

meanIOU (over existing classes in targets): 0.0192


 20%|█▉        | 96/490 [14:41<1:00:10,  9.16s/it]

meanIOU (over existing classes in targets): 0.0187


 26%|██▌       | 128/490 [19:37<56:32,  9.37s/it]

meanIOU (over existing classes in targets): 0.0203


 33%|███▎      | 160/490 [24:26<50:19,  9.15s/it]

meanIOU (over existing classes in targets): 0.0185


 39%|███▉      | 192/490 [29:19<46:37,  9.39s/it]

meanIOU (over existing classes in targets): 0.0214


 46%|████▌     | 224/490 [34:12<41:02,  9.26s/it]

meanIOU (over existing classes in targets): 0.0023


 47%|████▋     | 231/490 [35:17<40:45,  9.44s/it]

## TEST

In [None]:
# Folder passed to test_model as data_folder; adjust to where your copy of the data is stored.
DATA_PATH_TEST = DIR / "data-challenge-invent-mines-2024/DATA/DATA/TEST"  # Replace 'dataset' with the actual folder name where your data is stored
# The two assignments below repeat the values already set in the training
# configuration cell.
CHECKPOINT_PATH = DIR / "multi_temporal_crop_classification_Prithvi_100M.pth"
LAST_CHECKPOINT_PATH = DIR / "checkpoints/vit_epoch1.pth"

In [None]:
import pandas as pd

def masks_to_str(predictions: np.ndarray) -> list[str]:
    """
    Convert a batch of prediction masks to their string representation.

    Args:
        predictions (np.ndarray): predictions as a 3D batch (B, H, W)

    Returns:
        list[str]: a list of B strings, each string holds the flattened values
        of one mask separated by single spaces
    """
    encoded = []
    for mask in predictions:
        flat_values = np.ravel(mask)
        encoded.append(" ".join(f"{value}" for value in flat_values))
    return encoded


def decode_masks(
    masks: list[str],
    target_shape: tuple[int, int] = (128, 128),
) -> np.ndarray:
    """
    Parse stringified masks back into arrays.

    Each string is read as space-separated uint8 values and reshaped to
    ``target_shape``.

    Args:
        masks (list[str]): list of stringified masks
        target_shape (tuple[int, int]): shape given to every decoded mask

    Returns:
        np.ndarray: reconstructed batch of masks
    """
    decoded = []
    for encoded in masks:
        flat = np.fromstring(encoded, sep=" ", dtype=np.uint8)
        decoded.append(flat.reshape(target_shape))
    return np.array(decoded)

def test_model(
        name: str,
        checkpoint_path: str,
        input_channels: int,
        nb_classes: int,
        data_folder: Path,
        batch_size: int = 1,
        device: str = "cpu",
):
    """Run inference on the test split and write a submission CSV.

    The pickled model at ``checkpoint_path`` is loaded, applied to every sample
    of ``BaselineDatasetTest(data_folder)`` in dataset order, and the per-pixel
    argmax predictions are written to
    ``DIR / "submissions" / "submission_<name>.csv"`` with columns ``ID`` and
    ``MASKS``.

    Args:
        name: suffix of the submission file name.
        checkpoint_path: path to a fully pickled model.
        input_channels: currently unused.
        nb_classes: currently unused.
        data_folder: root folder of the test data.
        batch_size: number of samples per inference batch.
        device: device used for inference (default "cpu").
    """
    # The checkpoint is a fully pickled model, which needs weights_only=False on
    # PyTorch versions where safe loading is the default.
    # NOTE(security): unpickling executes arbitrary code; only load trusted files.
    model = torch.load(
        checkpoint_path, map_location=torch.device('cpu'), weights_only=False
    )
    device = torch.device(device)
    model.to(device)
    model.eval()

    # Load dataset
    dataset = BaselineDatasetTest(data_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, collate_fn=pad_collate, shuffle=False
    )

    # Collect predictions batch by batch. Concatenating afterwards yields one
    # mask per SAMPLE for any batch size and any spatial size.
    pred_batches = []
    with torch.no_grad():
        for images in tqdm(dataloader, total=len(dataloader)):
            logits = model(images.to(device))
            pred_batches.append(torch.argmax(logits, dim=1).cpu())

    if pred_batches:
        all_preds = torch.cat(pred_batches, dim=0).int().numpy()
    else:
        all_preds = np.zeros((0, 0, 0), dtype=np.int32)

    # Generate the csv submission file
    masks = masks_to_str(all_preds)
    submission = pd.DataFrame.from_dict({"ID": range(len(all_preds)), "MASKS": masks})
    # NOTE(review): IDs are assumed to be 20000 + position in the sorted
    # dataset; confirm this against dataset.id_patches.
    submission["ID"] = submission["ID"] + 20000

    submission_dir = DIR / "submissions"
    submission_dir.mkdir(parents=True, exist_ok=True)
    submission.to_csv(submission_dir / f"submission_{name}.csv", index=False)


In [None]:
# Write submissions/submission_1epoch.csv using the model pickled at
# LAST_CHECKPOINT_PATH. input_channels and nb_classes are accepted by
# test_model but not used by it.
test_model(
    name="1epoch",
    checkpoint_path=LAST_CHECKPOINT_PATH,
    input_channels=6,
    nb_classes=20,
    data_folder=Path(DATA_PATH_TEST),
    batch_size=1,
)