In [10]:
from terratorch.cli_tools import LightningInferenceModel
import torch
import torch.nn as nn
from torchvision.models.segmentation import (
    deeplabv3_mobilenet_v3_large,
    DeepLabV3_MobileNet_V3_Large_Weights,
)
import lightning as L

In [11]:
torch.set_float32_matmul_precision("high")

In [12]:
CONFIG = "teachers/hls_burn_scars_teacher/burn_scars_config.yaml"
CHECKPOINT = "teachers/hls_burn_scars_teacher/Prithvi_EO_V2_300M_BurnScars.pt"

In [13]:
inference_model = LightningInferenceModel.from_config(CONFIG, CHECKPOINT)
teacher = inference_model.model
datamodule = inference_model.datamodule

/home/mkoza/workspace/ml/distilprithvi/venv/lib/python3.12/site-packages/lightning/pytorch/cli.py:530: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['--f=/run/user/1003/jupyter/runtime/kernel-v388133463edcc071a288a755d17143d0c2631c103.json'], args=['--config', 'teachers/hls_burn_scars_teacher/burn_scars_config.yaml'].
Seed set to 2
INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed
/home/mkoza/workspace/ml/distilprithvi/venv/lib/pytho

In [14]:
class DeepLabMobileNetV3Large(nn.Module):
    def __init__(self, num_channels, num_classes):
        super().__init__()
        self.model = deeplabv3_mobilenet_v3_large(
            weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT,
        )
        self.model.backbone["0"][0] = nn.Conv2d(
            num_channels,
            16,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
            bias=False,
        )
        self.model.classifier[4] = nn.Conv2d(
            256,
            num_classes,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

    def forward(self, x):
        return self.model(x)

In [15]:
class DistilPrithvi(L.LightningModule):
    def __init__(
        self,
        teacher,
        student,
        kd_weight=0.5,
        kd_temperature=2.0,
    ):
        super().__init__()
        self.teacher = teacher
        self.teacher.eval()

        self.student = student

        self.kd_weight = kd_weight
        self.kd_temperature = kd_temperature
        self.kd_criterion = nn.KLDivLoss(reduction="batchmean")

    def forward(self, x):
        return self.student(x)["out"]

    def training_step(self, batch):
        x = batch["image"]
        y = batch["mask"]

        y_hat_s = self(x)
        loss_target = self.teacher.criterion(y_hat_s, y)

        if self.kd_weight == 0:
            loss = loss_target
        else:
            with torch.no_grad():
                y_hat_t = self.teacher(x).output
            loss_kd = self.kd_criterion(
                torch.log_softmax(y_hat_s / self.kd_temperature, dim=1),
                torch.softmax(y_hat_t / self.kd_temperature, dim=1),
            ) * (self.kd_temperature**2)
            self.log("train/loss_target", loss_target, on_epoch=True, on_step=False)
            self.log("train/loss_kd", loss_kd, on_epoch=True, on_step=False)

            loss = self.kd_weight * loss_kd + (1 - self.kd_weight) * loss_target

        self.log("train/loss", loss, on_epoch=True, on_step=False)
        self.teacher.train_metrics.update(y_hat_s.argmax(dim=1), y)
        return loss

    def validation_step(self, batch):
        x = batch["image"]
        y = batch["mask"]
        y_hat_s = self(x)
        loss = self.teacher.criterion(y_hat_s, y)
        self.teacher.val_metrics.update(y_hat_s.argmax(dim=1), y)
        self.log("val/loss", loss, on_epoch=True, on_step=False)

    def test_step(self, batch):
        x = batch["image"]
        y = batch["mask"]
        y_hat_s = self(x)
        loss = self.teacher.criterion(y_hat_s, y)
        self.teacher.test_metrics[0].update(y_hat_s.argmax(dim=1), y)
        self.log("test/loss", loss, on_epoch=True, on_step=False)

    def on_train_epoch_end(self):
        metrics = self.teacher.train_metrics.compute()
        self.log_dict(metrics, on_epoch=True, on_step=False)
        self.teacher.train_metrics.reset()

    def on_validation_epoch_end(self):
        metrics = self.teacher.val_metrics.compute()
        self.log_dict(metrics, on_epoch=True, on_step=False)
        self.teacher.val_metrics.reset()

    def on_test_epoch_end(self):
        metrics = self.teacher.test_metrics[0].compute()
        self.log_dict(metrics, on_epoch=True, on_step=False)
        self.teacher.test_metrics[0].reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.0001)

In [16]:
distilprithvi = DistilPrithvi(
    teacher=teacher,
    student=DeepLabMobileNetV3Large(
        num_channels=len(datamodule.output_bands),
        num_classes=datamodule.num_classes,
    ),
    kd_temperature=4.0,
    kd_weight=0.75,
)

In [17]:
mlf_logger = L.pytorch.loggers.MLFlowLogger(
    experiment_name="distilprithvi",
    run_name="distilprithvi",
)

trainer = L.Trainer(max_epochs=2, logger=mlf_logger)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [18]:
trainer.fit(
    distilprithvi,
    datamodule,
)
trainer.test(
    distilprithvi,
    datamodule,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name         | Type                     | Params | Mode 
------------------------------------------------------------------
0 | teacher      | SemanticSegmentationTask | 324 M  | eval 
1 | student      | DeepLabMobileNetV3Large  | 11.0 M | train
2 | kd_criterion | KLDivLoss                | 0      | train
------------------------------------------------------------------
335 M     Trainable params
0         Non-trainable params
335 M     Total params
1,340.918 Total estimated model params size (MB)
290       Modules in train mode
618       Modules in eval mode


                                                                           

/home/mkoza/workspace/ml/distilprithvi/venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (18) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 18/18 [00:18<00:00,  0.99it/s, v_num=76ff]



Epoch 1: 100%|██████████| 18/18 [00:18<00:00,  0.95it/s, v_num=76ff]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 18/18 [00:21<00:00,  0.84it/s, v_num=76ff]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 5/5 [00:00<00:00, 12.02it/s]


[{'test/loss': 0.41054394841194153,
  'test/Multiclass_Accuracy': 0.8825483322143555,
  'test/multiclassaccuracy_not_burned': 0.9733701944351196,
  'test/multiclassaccuracy_burn_scar': 0.10242204368114471,
  'test/Multiclass_F1_Score': 0.8825483322143555,
  'test/Multiclass_Jaccard_Index': 0.48231783509254456,
  'test/multiclassjaccardindex_not_burned': 0.8812803030014038,
  'test/multiclassjaccardindex_burn_scar': 0.08335535228252411,
  'test/Multiclass_Jaccard_Index_Micro': 0.7897866368293762}]