In [116]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
import albumentations as A
import lightning as L
from terratorch.datamodules import GenericNonGeoSegmentationDataModule

In [117]:
class DistilPrithvi(L.LightningModule):
    """Train a segmentation ``student``, optionally distilling from a frozen ``teacher``.

    Loss per batch::

        loss = (1 - soft_loss_weight) * hard_loss + soft_loss_weight * soft_loss

    * ``hard_loss = hard_loss_func(student_logits, mask)``, computed over labelled
      pixels only (pixels whose mask equals ``ignore_index`` are dropped).
    * ``soft_loss = soft_loss_func(student_logits, sigmoid(teacher_logits))``; the
      teacher is only run when ``soft_loss_weight > 0``.

    Args:
        teacher: Model producing soft targets. May be ``None`` when
            ``soft_loss_weight == 0``. Its parameters are frozen in place.
        student: Model being trained. Its output is either a logits tensor or a
            dict with an ``"out"`` entry (the torchvision segmentation convention).
        soft_loss_func: Loss called as ``(student_logits, teacher_probabilities)``.
        hard_loss_func: Loss called as ``(student_logits, float_mask)``.
        soft_loss_weight: Blend factor in ``[0, 1]`` for the distillation term.
        lr: Adam learning rate for the student.
        ignore_index: Mask value marking unlabelled pixels.

    Raises:
        ValueError: If ``soft_loss_weight`` is outside ``[0, 1]``, or is positive
            while ``teacher`` is ``None``.
    """

    def __init__(
        self,
        teacher,
        student,
        soft_loss_func,
        hard_loss_func,
        soft_loss_weight=0.5,
        lr=0.01,
        ignore_index=-1,
    ):
        super().__init__()
        if not 0.0 <= soft_loss_weight <= 1.0:
            raise ValueError(
                f"soft_loss_weight must be in [0, 1], got {soft_loss_weight}"
            )
        if soft_loss_weight > 0 and teacher is None:
            raise ValueError("a teacher is required when soft_loss_weight > 0")

        self.teacher = teacher
        self.student = student
        self.soft_loss_func = soft_loss_func
        self.hard_loss_func = hard_loss_func
        self.soft_loss_weight = soft_loss_weight
        self.lr = lr
        self.ignore_index = ignore_index

        if self.teacher is not None:
            # The teacher only supplies targets; its weights must never change.
            self.teacher.requires_grad_(False)
            self.teacher.eval()

    @staticmethod
    def _logits(output):
        """Return the logits tensor, unwrapping torchvision's ``{"out": ...}`` dict if present."""
        return output["out"] if isinstance(output, dict) else output

    def forward(self, image):
        """Return the student's raw logits for ``image``."""
        return self._logits(self.student(image))

    def _teacher_probs(self, image):
        """Return the frozen teacher's sigmoid probabilities, used as soft targets.

        NOTE(review): sigmoid assumes a single-logit binary task (matching
        BCEWithLogitsLoss); a multi-class teacher would need softmax instead.
        """
        # Lightning may put the whole module back into train mode; keep the
        # teacher's BatchNorm/dropout frozen regardless.
        self.teacher.eval()
        with torch.no_grad():
            return torch.sigmoid(self._logits(self.teacher(image)))

    def step(self, batch, stage="train"):
        """Compute, log and return the blended loss for one batch.

        Args:
            batch: Dict with ``"image"`` and ``"mask"`` tensors.
            stage: Prefix used for logged metric names (``train``/``val``/``test``).
        """
        image = batch["image"]
        # assumes mask is (B, H, W); add a channel dim to match (B, 1, H, W) logits.
        mask = batch["mask"].unsqueeze(1).float()
        logits = self(image)

        # Unlabelled pixels (e.g. the datamodule's no_label_replace value) must not
        # reach BCE: a target of -1 makes BCE-with-logits unbounded below.
        valid = mask != self.ignore_index
        if valid.any():
            hard_loss = self.hard_loss_func(logits[valid], mask[valid])
        else:
            # No labelled pixel in this batch: zero loss that keeps the graph intact.
            hard_loss = logits.sum() * 0.0

        batch_size = image.shape[0]
        if self.soft_loss_weight > 0:
            soft_loss = self.soft_loss_func(logits, self._teacher_probs(image))
            loss = (
                (1.0 - self.soft_loss_weight) * hard_loss
                + self.soft_loss_weight * soft_loss
            )
            self.log(f"{stage}/hard_loss", hard_loss, batch_size=batch_size)
            self.log(f"{stage}/soft_loss", soft_loss, batch_size=batch_size)
        else:
            loss = hard_loss

        self.log(f"{stage}/loss", loss, prog_bar=True, batch_size=batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, "test")

    def configure_optimizers(self):
        """Adam over the student only; the frozen teacher is never optimised."""
        return torch.optim.Adam(self.student.parameters(), lr=self.lr)

In [118]:
def get_student(num_channels, num_classes, **model_kwargs):
    """Build a DeepLabV3-MobileNetV3-Large model whose stem accepts ``num_channels`` bands.

    The first backbone convolution is replaced by a freshly initialised one with
    ``num_channels`` inputs. Every other hyper-parameter of the replaced layer
    (output channels, kernel size, stride, padding, dilation, bias) is copied from
    it, so the rest of the network still lines up.

    Args:
        num_channels: Number of input bands.
        num_classes: Number of output channels of the segmentation head.
        **model_kwargs: Forwarded to ``deeplabv3_mobilenet_v3_large``
            (e.g. ``weights_backbone=None`` to skip backbone weights).

    Returns:
        The modified model; calling it returns a dict whose ``"out"`` entry holds the logits.
    """
    # TODO: make transfer learning work
    # (e.g. weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT).
    # NOTE(review): torchvision's default ``weights_backbone`` appears to be ImageNet
    # MobileNetV3 weights (downloaded on first use) — confirm for the installed
    # version. Either way, the stem conv built below always starts from scratch.
    student = deeplabv3_mobilenet_v3_large(num_classes=num_classes, **model_kwargs)
    stem = student.backbone["0"][0]
    student.backbone["0"][0] = nn.Conv2d(
        num_channels,
        stem.out_channels,
        kernel_size=stem.kernel_size,
        stride=stem.stride,
        padding=stem.padding,
        dilation=stem.dilation,
        bias=stem.bias is not None,
    )
    return student

In [120]:
# Seed Python, NumPy and torch RNGs (and dataloader workers) before any model is
# built, so weight init, shuffling and augmentations repeat across Restart & Run All.
L.seed_everything(42, workers=True)

# Placeholder teacher: a second student-sized model until the real teacher is wired in.
# With soft_loss_weight=0.0 only the hard (ground-truth) loss drives training.
distilprithvi = DistilPrithvi(
    teacher=get_student(6, 1),  # 6 input bands, 1 binary logit
    student=get_student(6, 1),
    soft_loss_func=torch.nn.BCEWithLogitsLoss(),
    hard_loss_func=torch.nn.BCEWithLogitsLoss(),
    soft_loss_weight=0.0,
)

In [121]:
# Bands read from each *_merged.tif; presumably in file order — verify against the data.
hls_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]

# Per-band normalisation statistics, index-aligned with hls_bands.
hls_band_means = [
    0.033349706741586264,
    0.05701185520536176,
    0.05889748132001316,
    0.2323245113436119,
    0.1972854853760658,
    0.11944914225186566,
]
hls_band_stds = [
    0.02269135568823774,
    0.026807560223070237,
    0.04004109844362779,
    0.07791732423672691,
    0.08708738838140137,
    0.07241979477437814,
]

hls_burn_scars = GenericNonGeoSegmentationDataModule(
    batch_size=8,
    num_workers=8,
    # All stored bands are fed to the model, so input and output band lists match.
    dataset_bands=list(hls_bands),
    output_bands=list(hls_bands),
    rgb_indices=[2, 1, 0],  # positions of RED, GREEN, BLUE within hls_bands
    train_data_root="datasets/hls_burn_scars/data",
    val_data_root="datasets/hls_burn_scars/data",
    test_data_root="datasets/hls_burn_scars/data",
    train_split="datasets/hls_burn_scars/splits/train.txt",
    val_split="datasets/hls_burn_scars/splits/val.txt",
    test_split="datasets/hls_burn_scars/splits/test.txt",
    img_grep="*_merged.tif",
    label_grep="*.mask.tif",
    means=list(hls_band_means),
    stds=list(hls_band_stds),
    num_classes=2,
    # Random dihedral (D4) flips/rotations for training only.
    train_transform=[A.D4(), A.pytorch.ToTensorV2()],
    test_transform=[A.pytorch.ToTensorV2()],
    no_data_replace=0,
    no_label_replace=-1,
)

In [None]:
# Up to 100 epochs, validating every half epoch; artefacts go under default_root_dir.
trainer_kwargs = dict(
    max_epochs=100,
    log_every_n_steps=20,
    val_check_interval=0.5,
    default_root_dir="checkpoints",
)
trainer = L.Trainer(**trainer_kwargs)

In [None]:
# Fit the student; the datamodule supplies the train and validation dataloaders.
trainer.fit(model=distilprithvi, datamodule=hls_burn_scars)