In [17]:
import mlflow
import torch
import torch.nn as nn
from torchmetrics.segmentation import DiceScore 
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
import albumentations as A
import lightning as L
from terratorch.models import EncoderDecoderFactory
from terratorch.datamodules import GenericNonGeoSegmentationDataModule

In [18]:
datamodule = GenericNonGeoSegmentationDataModule(
    batch_size=32,
    num_workers=8,
    dataset_bands=[
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    ],
    output_bands=[
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    ],
    rgb_indices=[2, 1, 0],
    train_data_root="datasets/hls_burn_scars/data",
    val_data_root="datasets/hls_burn_scars/data",
    test_data_root="datasets/hls_burn_scars/data",
    train_split="datasets/hls_burn_scars/splits/train.txt",
    val_split="datasets/hls_burn_scars/splits/val.txt",
    test_split="datasets/hls_burn_scars/splits/test.txt",
    img_grep="*_merged.tif",
    label_grep="*.mask.tif",
    means=[
        0.033349706741586264,
        0.05701185520536176,
        0.05889748132001316,
        0.2323245113436119,
        0.1972854853760658,
        0.11944914225186566,
    ],
    stds=[
        0.02269135568823774,
        0.026807560223070237,
        0.04004109844362779,
        0.07791732423672691,
        0.08708738838140137,
        0.07241979477437814,
    ],
    num_classes=2,
    train_transform=[A.D4(), A.pytorch.ToTensorV2()],
    test_transform=[A.pytorch.ToTensorV2()],
    no_data_replace=0,
    no_label_replace=-1,
)

In [19]:
class Student(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.model = deeplabv3_mobilenet_v3_large(
            weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT,
        )
        self.model.backbone["0"][0] = nn.Conv2d(
            num_channels, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        self.model.classifier[4] = nn.Conv2d(
            256, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
        )

    def forward(self, x):
        return self.model(x)

In [20]:
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        factory = EncoderDecoderFactory()
        self.model = factory.build_model(
            task="segmentation",
            backbone="prithvi_eo_v2_300",
            backbone_pretrained=True,
            backbone_bands=["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
            necks=[
                {"name": "SelectIndices", "indices": [5, 11, 17, 23]},
                {"name": "ReshapeTokensToImage"},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
            decoder="UNetDecoder",
            decoder_channels=[512, 256, 128, 64],
            num_classes=2,
        )
        checkpoint = torch.load(
            "teachers/Prithvi_EO_V2_300M_BurnScars.pt", map_location="cpu"
        )
        self.model.load_state_dict(checkpoint["state_dict"], strict=False)

    def forward(self, x):
        return self.model(x)

In [None]:
class DistilPrithvi(L.LightningModule):
    def __init__(
        self,
        soft_loss_weight=1.0,
    ):
        super().__init__()
        self.teacher = Teacher()
        self.student = Student(num_channels=6)
        self.soft_loss_func = nn.BCEWithLogitsLoss()
        self.hard_loss_func = nn.BCEWithLogitsLoss()
        self.dice = DiceScore(num_classes=1)

        self.soft_loss_weight = soft_loss_weight
        self.teacher.eval()
        self.teacher.requires_grad_(False)

    def forward(self, image):
        return self.student(image)["out"]

    def step(self, batch, stage):
        image = batch["image"]
        mask = batch["mask"].unsqueeze(1).float()

        y = self.forward(image)

        loss = self.hard_loss_func(y, mask)
        self.log(f"{stage}_loss", loss, on_epoch=True)

        y_dice = torch.sigmoid(y) > 0.5
        self.log(f"{stage}_dice", self.dice(y_dice, mask), on_epoch=True)
        return loss

    def training_step(self, batch):
        return self.step(batch, "train")

    def validation_step(self, batch):
        return self.step(batch, "val")

    def test_step(self, batch):
        return self.step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.0001)

In [22]:
distilprithvi = DistilPrithvi(
    soft_loss_weight=0.0,
)

INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed


In [23]:
trainer = L.Trainer(
    max_epochs=100,
    log_every_n_steps=1,
    val_check_interval=0.5,
)

INFO:pytorch_lightning.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [24]:
mlflow.pytorch.autolog()

with mlflow.start_run():
    trainer.fit(
        distilprithvi,
        datamodule,
    )
    trainer.test(
        distilprithvi,
        datamodule,
    )

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type              | Params | Mode 
-------------------------------------------------------------
0 | teacher        | Teacher           | 324 M  | eval 
1 | student        | Student           | 11.0 M | train
2 | soft_loss_func | BCEWithLogitsLoss | 0      | train
3 | hard_loss_func | BCEWithLogitsLoss | 0      | train
4 | dice           | DiceScore         | 0      | train
-------------------------------------------------------------
11.0 M    Trainable params
324 M     Non-trainable params
335 M     Total params
1,340.917 Total estimated model params size (MB)
292       Modules in train mode
589       Modules in eval mode


Epoch 27:  62%|██████▎   | 10/16 [04:56<02:57,  0.03it/s, v_num=64]        
Epoch 99: 100%|██████████| 16/16 [00:15<00:00,  1.03it/s, v_num=66]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 16/16 [00:19<00:00,  0.84it/s, v_num=66]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 4/4 [00:00<00:00,  8.02it/s]


/home/mkoza/workspace/ml/distilprithvi/venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 24. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
