In [30]:
import mlflow
import torch
import torch.nn as nn
from torchmetrics.segmentation import DiceScore 
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
import albumentations as A
import lightning as L
from terratorch.datamodules import GenericNonGeoSegmentationDataModule

In [31]:
class DistilPrithvi(L.LightningModule):
    def __init__(
        self,
        teacher,
        student,
        soft_loss_func,
        hard_loss_func,
        soft_loss_weight=0.5,
    ):
        super().__init__()
        self.teacher = teacher.eval()
        self.student = student
        self.soft_loss_func = soft_loss_func
        self.hard_loss_func = hard_loss_func
        self.soft_loss_weight = soft_loss_weight

        self.dice = DiceScore(num_classes=1)

    def forward(self, image):
        return self.student(image)["out"]

    def step(self, batch, stage):
        image = batch["image"]
        mask = batch["mask"].unsqueeze(1).float()

        y = self.forward(image)
        y = torch.sigmoid(y)

        loss = self.hard_loss_func(y, mask)

        self.log(f"{stage}_loss", loss, on_epoch=True)
        self.log(f"{stage}_dice", self.dice(y, mask), on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.001)

In [32]:
def get_student(num_channels):
    model = deeplabv3_mobilenet_v3_large(
        weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT,
    )
    model.backbone["0"][0] = nn.Conv2d(
        num_channels, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    )
    model.classifier[4] = nn.Conv2d(
        256, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
    )
    return model

In [33]:
datamodule = GenericNonGeoSegmentationDataModule(
    batch_size=40,
    num_workers=8,
    dataset_bands=[
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    ],
    output_bands=[
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    ],
    rgb_indices=[2, 1, 0],
    train_data_root="datasets/hls_burn_scars/data",
    val_data_root="datasets/hls_burn_scars/data",
    test_data_root="datasets/hls_burn_scars/data",
    train_split="datasets/hls_burn_scars/splits/train.txt",
    val_split="datasets/hls_burn_scars/splits/val.txt",
    test_split="datasets/hls_burn_scars/splits/test.txt",
    img_grep="*_merged.tif",
    label_grep="*.mask.tif",
    means=[
        0.033349706741586264,
        0.05701185520536176,
        0.05889748132001316,
        0.2323245113436119,
        0.1972854853760658,
        0.11944914225186566,
    ],
    stds=[
        0.02269135568823774,
        0.026807560223070237,
        0.04004109844362779,
        0.07791732423672691,
        0.08708738838140137,
        0.07241979477437814,
    ],
    num_classes=2,
    train_transform=[A.D4(), A.pytorch.ToTensorV2()],
    test_transform=[A.pytorch.ToTensorV2()],
    no_data_replace=0,
    no_label_replace=-1,
)

In [34]:
# TODO: replace with an actual teacher model
teacher = get_student(num_channels=len(datamodule.output_bands))
student = get_student(num_channels=len(datamodule.output_bands))

distilprithvi = DistilPrithvi(
    teacher=teacher,
    student=student,
    soft_loss_func=torch.nn.BCEWithLogitsLoss(),
    hard_loss_func=torch.nn.BCEWithLogitsLoss(),
    soft_loss_weight=0.0,
)

In [35]:
trainer = L.Trainer(
    max_epochs=100,
    log_every_n_steps=1,
    val_check_interval=0.5,
)

INFO:pytorch_lightning.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
mlflow.pytorch.autolog()

with mlflow.start_run():
    trainer.fit(
        distilprithvi,
        datamodule,
    )

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type              | Params | Mode 
-------------------------------------------------------------
0 | teacher        | DeepLabV3         | 11.0 M | eval 
1 | student        | DeepLabV3         | 11.0 M | train
2 | soft_loss_func | BCEWithLogitsLoss | 0      | train
3 | hard_loss_func | BCEWithLogitsLoss | 0      | train
4 | dice           | DiceScore         | 0      | train
-------------------------------------------------------------
22.0 M    Trainable params
0         Non-trainable params
22.0 M    Total params
88.197    Total estimated model params size (MB)
291       Modules in train mode
288       Modules in eval mode


Epoch 58:  46%|████▌     | 6/13 [05:02<05:53,  0.02it/s, v_num=38]         
Epoch 23:  77%|███████▋  | 10/13 [00:06<00:01,  1.58it/s, v_num=40]

Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 259, in _feed
    reader_close()
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 178, in close
    self._close()
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 377, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor


Epoch 60:   0%|          | 0/13 [00:00<?, ?it/s, v_num=40]         