# Creating a finetuning workload with the script interface
This tutorial does not intend to create an accurate finetuned example (we are running for a single epoch!), but to describe step-by-step how to instantiate and run this kind of task. 

In [1]:
from lightning.pytorch import Trainer
import terratorch
import albumentations
from albumentations.pytorch import ToTensorV2
from terratorch.models import EncoderDecoderFactory
from terratorch.models.necks import SelectIndices, LearnedInterpolateToPyramidal, ReshapeTokensToImage
from terratorch.models.decoders import UNetDecoder
from terratorch.datasets import HLSBands
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
from terratorch.tasks import SemanticSegmentationTask

### Defining fundamental parameters:
* `lr` - learning rate.
* `accelerator` - The kind of device in which the model will be executed. It is usually `gpu` or `cpu`. If we set it as `auto`, Lightning will select the most appropiate available device.
* `max_epochs` - The maximum number of epochs used to train the model. 

In [2]:
lr = 1e-4
accelerator = "auto"
max_epochs = 1

### Next, we will instantiate the datamodule, the object we will use to load the files from disk to memory.

In [3]:
datamodule = GenericNonGeoSegmentationDataModule(
    batch_size = 2,
    num_workers = 8,
    dataset_bands = [HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED, HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2],
    output_bands = [HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED, HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2],
    rgb_indices = [2, 1, 0],
    means = [
          0.033349706741586264,
          0.05701185520536176,
          0.05889748132001316,
          0.2323245113436119,
          0.1972854853760658,
          0.11944914225186566,
    ],
    stds = [
          0.02269135568823774,
          0.026807560223070237,
          0.04004109844362779,
          0.07791732423672691,
          0.08708738838140137,
          0.07241979477437814,
    ],
    train_data_root = "../burn_scars/hls_burn_scars/training",
    val_data_root = "../burn_scars/hls_burn_scars/validation",
    test_data_root = "../burn_scars/hls_burn_scars/validation",
    img_grep = "*_merged.tif",
    label_grep = "*.mask.tif",
    num_classes = 2,
    train_transform = [albumentations.D4(), ToTensorV2()],
    test_transform = [ToTensorV2()],
    no_data_replace = 0,
    no_label_replace =  -1,
)

### A dictionary containing all the arguments necessary to instantiate a complete `backbone-neck-decoder-head`, which will be passed to the task object.

In [4]:
model_args = dict(
  backbone="prithvi_eo_v2_300",
  backbone_pretrained=True,
  backbone_num_frames=1,
  num_classes = 2,
  backbone_bands=[
      "BLUE",
      "GREEN",
      "RED",
      "NIR_NARROW",
      "SWIR_1",
      "WIR_2",
  ],
  decoder = "UNetDecoder",
  decoder_channels = [512, 256, 128, 64],
  necks=[{"name": "SelectIndices", "indices": [5, 11, 17, 23]},
         {"name": "ReshapeTokensToImage"},
         {"name": "LearnedInterpolateToPyramidal"}],
  head_dropout=0.1
)

### Creating the `task` object, which will be used to properly define how the model will be trained and used after it.

In [5]:
task = SemanticSegmentationTask(
    model_args,
    "EncoderDecoderFactory",
    loss="ce",
    lr=lr,
    ignore_index=-1,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone = False,
    plot_on_val = False,
    class_names = ["Not burned", "Burn scar"],
)

INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed


### The object `Trainer` manages all the training process. It can be interpreted as an improved optimization loop, in which parallelism and checkpointing are transparently managed by the system. 

In [6]:
trainer = Trainer(
    accelerator=accelerator,
    max_epochs=max_epochs,
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


### Executing the training.

In [7]:
trainer.fit(model=task, datamodule=datamodule)

You are using a CUDA device ('NVIDIA RTX A4500 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | PixelWiseModel   | 324 M  | train
1 | criterion     | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | ModuleList       | 0      | train
-----------------------------------------------------------
324 M     Trainable params
0         Non-trainable params
324 M     Total params
1,296.819 Total estimated model params size (MB)
617       Modul

Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=1` reached.


### Testing the trained model (extracting metrics). 

In [8]:
trainer.test(dataloaders=datamodule)

Restoring states from the checkpoint path at /home/jalmeida/Projetos/SBSR_courses/SBSR_notebooks/local/lightning_logs/version_5/checkpoints/epoch=0-step=270.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jalmeida/Projetos/SBSR_courses/SBSR_notebooks/local/lightning_logs/version_5/checkpoints/epoch=0-step=270.ckpt


Testing: |                                                                                                    …

[{'test/loss': 0.2669268250465393,
  'test/Multiclass_Accuracy': 0.9274423718452454,
  'test/multiclassaccuracy_Not burned': 0.9267654418945312,
  'test/multiclassaccuracy_Burn scar': 0.9340785145759583,
  'test/Multiclass_F1_Score': 0.9274423718452454,
  'test/Multiclass_Jaccard_Index': 0.7321492433547974,
  'test/multiclassjaccardindex_Not burned': 0.9205750226974487,
  'test/multiclassjaccardindex_Burn scar': 0.5437235236167908,
  'test/Multiclass_Jaccard_Index_Micro': 0.8647016882896423}]