In [None]:
from terratorch.cli_tools import LightningInferenceModel
from distillers.semantic_segmentation_distiller import SemanticSegmentationDistiller
from students.deeplabv3_mobilenet_v3_large import DeepLabV3MobileNetV3Large
from students.lraspp_mobilenet_v3_large import LRASPPMobileNetV3Large
import torch
from lightning import Trainer
from lightning.pytorch.loggers import MLFlowLogger

# Let PyTorch use reduced-precision internals for float32 matmuls
# (trades some numerical precision for speed on supporting hardware).
torch.set_float32_matmul_precision('medium')

# --- Experiment configuration -------------------------------------------
TEACHER_CONFIG = "teachers/hls_burn_scars_teacher/burn_scars_config.yaml"
TEACHER_CHECKPOINT = "teachers/hls_burn_scars_teacher/Prithvi_EO_V2_300M_BurnScars.pt"
STUDENT_MODEL = "lraspp"  # or "deeplabv3"
BATCH_SIZE = 16
NUM_EPOCHS = 50
EXPERIMENT_NAME = "hls_burn_scars_distillation"
RUN_NAME = "hls_burn_scars_distillation_run"
# Distillation hyperparameters, forwarded unchanged to the distiller.
# NOTE(review): presumably the softmax temperature and the weight of the
# distillation term in the loss -- confirm against SemanticSegmentationDistiller.
KD_TEMPERATURE = 4.0
KD_WEIGHT = 0.75

# Supported student architectures, keyed by the value of STUDENT_MODEL.
# Both constructors are called with the same keyword arguments below.
STUDENT_CLASSES = {
    "deeplabv3": DeepLabV3MobileNetV3Large,
    "lraspp": LRASPPMobileNetV3Large,
}

# Validate the selection up front: without this check an unrecognised name
# would leave `student` undefined and only surface later as a NameError,
# after the teacher checkpoint had already been loaded.
if STUDENT_MODEL not in STUDENT_CLASSES:
    raise ValueError(
        f"Unknown STUDENT_MODEL {STUDENT_MODEL!r}; "
        f"expected one of {sorted(STUDENT_CLASSES)}"
    )

# --- Teacher and data ---------------------------------------------------
# The inference wrapper built from the teacher config/checkpoint exposes
# both the teacher model and the datamodule reused for distillation.
inference_model = LightningInferenceModel.from_config(
    TEACHER_CONFIG, TEACHER_CHECKPOINT
)
teacher = inference_model.model
datamodule = inference_model.datamodule
datamodule.batch_size = BATCH_SIZE

# --- Student ------------------------------------------------------------
# Size the student's input channels and output classes from the datamodule
# so that it matches the data the teacher was configured with.
student = STUDENT_CLASSES[STUDENT_MODEL](
    num_channels=len(datamodule.output_bands),
    num_classes=datamodule.num_classes,
)

# --- Distillation -------------------------------------------------------
distiller = SemanticSegmentationDistiller(
    teacher=teacher,
    student=student,
    kd_temperature=KD_TEMPERATURE,
    kd_weight=KD_WEIGHT,
)

# Log to MLflow, then run training followed by the test loop.
mlf_logger = MLFlowLogger(experiment_name=EXPERIMENT_NAME, run_name=RUN_NAME)
trainer = Trainer(max_epochs=NUM_EPOCHS, logger=mlf_logger)
trainer.fit(distiller, datamodule)
trainer.test(distiller, datamodule)

  from .autonotebook import tqdm as notebook_tqdm
/home/mkoza/workspace/ml/distilprithvi/venv/lib/python3.12/site-packages/lightning/pytorch/cli.py:530: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['--f=/run/user/1003/jupyter/runtime/kernel-v35416a706037c785073752ad43753c4b91fdc7500.json'], args=['--config', 'teachers/hls_burn_scars_teacher/burn_scars_config.yaml'].
Seed set to 2
INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed
/ho

                                                                           

/home/mkoza/workspace/ml/distilprithvi/venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.




/home/mkoza/workspace/ml/distilprithvi/venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (32) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:   3%|▎         | 1/32 [00:01<00:33,  0.91it/s, v_num=5d64]