In [9]:
from terratorch.cli_tools import LightningInferenceModel
import mlflow
import torch
import torch.nn as nn
from torchvision.models.segmentation import (
    deeplabv3_mobilenet_v3_large,
    DeepLabV3_MobileNet_V3_Large_Weights,
)
import lightning as L
import matplotlib.pyplot as plt

In [10]:
CONFIG = "teachers/hls_burn_scars_teacher/burn_scars_config.yaml"
CHECKPOINT = "teachers/hls_burn_scars_teacher/Prithvi_EO_V2_300M_BurnScars.pt"

In [11]:
inference_model = LightningInferenceModel.from_config(CONFIG, CHECKPOINT)
teacher = inference_model.model
datamodule = inference_model.datamodule

/home/mkoza/workspace/ml/distilprithvi/venv/lib/python3.12/site-packages/lightning/pytorch/cli.py:530: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['--f=/run/user/1003/jupyter/runtime/kernel-v3299645aeb3caefd4ec88db5fd8907a8a3bc06dc1.json'], args=['--config', 'teachers/hls_burn_scars_teacher/burn_scars_config.yaml'].
Seed set to 2
INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed
/home/mkoza/workspace/ml/distilprithvi/venv/lib/pytho

In [12]:
class DeepLabMobileNetV3Large(nn.Module):
    def __init__(self, num_channels, num_classes):
        super().__init__()
        self.model = deeplabv3_mobilenet_v3_large(
            weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT,
        )
        self.model.backbone["0"][0] = nn.Conv2d(
            num_channels,
            16,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
            bias=False,
        )
        self.model.classifier[4] = nn.Conv2d(
            256,
            num_classes,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

    def forward(self, x):
        return self.model(x)

In [13]:
class DistilPrithvi(L.LightningModule):
    def __init__(
        self,
        teacher,
        student,
        kd_weight=0.5,
        kd_temperature=2.0,
    ):
        super().__init__()
        self.teacher = teacher
        self.teacher.eval()

        self.student = student

        self.kd_weight = kd_weight
        self.kd_temperature = kd_temperature
        self.kd_criterion = nn.KLDivLoss(reduction="batchmean")

        self.metrics = {
            "train": self.teacher.train_metrics,
            "val": self.teacher.val_metrics,
            "test": self.teacher.test_metrics[0],
        }

    def forward(self, x):
        return self.student(x)["out"]

    def _step(self, batch, stage):
        x = batch["image"]
        y = batch["mask"]

        y_hat_s = self(x)
        loss_target = self.teacher.criterion(y_hat_s, y)

        if self.kd_weight == 0:
            loss = loss_target
        else:
            with torch.no_grad():
                y_hat_t = self.teacher(x).output

            loss_kd = self.kd_criterion(
                torch.log_softmax(y_hat_s / self.kd_temperature, dim=1),
                torch.softmax(y_hat_t / self.kd_temperature, dim=1),
            ) * (self.kd_temperature**2)

            loss = self.kd_weight * loss_kd + (1 - self.kd_weight) * loss_target

        self.metrics[stage].update(y_hat_s.argmax(dim=1), y)
        self.log(f"{stage}_loss", loss, on_epoch=True, on_step=False)
        return loss

    def _on_epoch_end(self, stage):
        metrics = self.metrics[stage].compute()
        self.log_dict(metrics, on_epoch=True, on_step=False)
        self.metrics[stage].reset()

    def training_step(self, batch):
        return self._step(batch, "train")

    def validation_step(self, batch):
        self._step(batch, "val")

    def test_step(self, batch):
        self._step(batch, "test")

    def on_train_epoch_end(self):
        self._on_epoch_end("train")

    def on_validation_epoch_end(self):
        self._on_epoch_end("val")

    def on_test_epoch_end(self):
        self._on_epoch_end("test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.0001)

In [14]:
distilprithvi = DistilPrithvi(
    teacher=teacher,
    student=DeepLabMobileNetV3Large(
        num_channels=len(datamodule.output_bands),
        num_classes=datamodule.num_classes,
    ),
    kd_temperature=4.0,
    kd_weight=0.75,
)

In [15]:
trainer = L.Trainer(
    max_epochs=100,
)

INFO:pytorch_lightning.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
mlflow.pytorch.autolog()

with mlflow.start_run():
    trainer.fit(
        distilprithvi,
        datamodule,
    )
    trainer.test(
        distilprithvi,
        datamodule,
    )

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                     | Params | Mode 
------------------------------------------------------------------
0 | teacher      | SemanticSegmentationTask | 324 M  | eval 
1 | student      | DeepLabMobileNetV3Large  | 11.0 M | train
2 | kd_criterion | KLDivLoss                | 0      | train
------------------------------------------------------------------
335 M     Trainable params
0         Non-trainable params
335 M     Total params
1,340.918 Total estimated model params size (MB)
290       Modules in train mode
618       Modules in eval mode


                                                                           