In [1]:
from lightning.pytorch import Trainer
import terratorch
import albumentations
from albumentations.pytorch import ToTensorV2
from terratorch.models import EncoderDecoderFactory
from terratorch.models.necks import SelectIndices, LearnedInterpolateToPyramidal, ReshapeTokensToImage
from terratorch.models.decoders import UNetDecoder
from terratorch.datasets import HLSBands
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
from terratorch.tasks import SemanticSegmentationTask

  from .autonotebook import tqdm as notebook_tqdm


### Escolhendo alguns parâmetros fundamentais:
* `lr` - learning rate.
* `accelerator` - O tipo de dispositivo em que iremos executar o modelo. Geralmente será `gpu` ou `cpu`, mas podemos definir como `auto` e deixar o sistema escolher o que estiver disponível.
* `max_epochs` - O número máximo de iterações sobre o conjunto de treinamento (`epochs` no jargão de machine learning) que usaremos para treinar o modelo.

In [2]:
# Run-level settings, assigned together in one statement:
#   lr          -> learning rate handed to the task
#   accelerator -> device selector handed to the Trainer
#   max_epochs  -> epoch limit handed to the Trainer
lr, accelerator, max_epochs = 1e-4, "auto", 1

### Abaixo, a instanciação do datamodule, o objeto que usaremos para gerenciar o carregamento dos dados do disco para a memória.

In [3]:
# The six bands used throughout; the same selection is given both as the
# bands present in the files and as the bands to output.
hls_bands = [
    HLSBands.BLUE,
    HLSBands.GREEN,
    HLSBands.RED,
    HLSBands.NIR_NARROW,
    HLSBands.SWIR_1,
    HLSBands.SWIR_2,
]

# One normalisation value per band (six entries each).
# NOTE(review): presumably listed in the same order as `hls_bands` - confirm.
band_means = [
    0.033349706741586264,
    0.05701185520536176,
    0.05889748132001316,
    0.2323245113436119,
    0.1972854853760658,
    0.11944914225186566,
]
band_stds = [
    0.02269135568823774,
    0.026807560223070237,
    0.04004109844362779,
    0.07791732423672691,
    0.08708738838140137,
    0.07241979477437814,
]

datamodule = GenericNonGeoSegmentationDataModule(
    # --- where the data lives and how files are matched ---
    train_data_root="hls_burn_scars/training",
    val_data_root="hls_burn_scars/validation",
    # NOTE(review): the test root is the same directory as the validation root.
    test_data_root="hls_burn_scars/validation",
    img_grep="*_merged.tif",
    label_grep="*.mask.tif",
    # --- loading ---
    batch_size=2,
    num_workers=8,
    # --- bands and normalisation ---
    dataset_bands=hls_bands,
    output_bands=list(hls_bands),  # separate list object with equal contents
    rgb_indices=[2, 1, 0],  # positions of RED, GREEN, BLUE in `hls_bands`
    means=band_means,
    stds=band_stds,
    # --- labels and fill values ---
    num_classes=2,
    no_data_replace=0,
    no_label_replace=-1,
    # --- transforms ---
    train_transform=[albumentations.D4(), ToTensorV2()],
    test_transform=[ToTensorV2()],
)

### Abaixo, o dicionário contendo todos os argumentos necessários para instanciar nosso modelo, um objeto `backbone-neck-decoder-head` completo. Esse dicionário é passado para a instância do objeto `task`, que, por sua vez, irá criar uma nova versão do modelo na memória.

In [4]:
# Arguments describing the full backbone -> necks -> decoder -> head model.
# The dict is handed to the task together with the factory name.
model_args = dict(
    backbone="prithvi_eo_v2_300",
    backbone_pretrained=True,
    backbone_num_frames=1,
    num_classes=2,
    # Same six bands, in the same order, as the datamodule's band list.
    # The last entry must be "SWIR_2": the recorded run log reports weights
    # loaded for HLSBands.SWIR_2 in position 5 of the patch embedding, which
    # a misspelled band name would not match.
    backbone_bands=[
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    ],
    decoder="UNetDecoder",
    decoder_channels=[512, 256, 128, 64],
    # Necks are listed in order: pick four backbone outputs by index,
    # reshape tokens to image layout, then interpolate to a pyramid.
    # NOTE(review): descriptions follow the neck names - confirm against the
    # terratorch neck documentation.
    necks=[
        {"name": "SelectIndices", "indices": [5, 11, 17, 23]},
        {"name": "ReshapeTokensToImage"},
        {"name": "LearnedInterpolateToPyramidal"},
    ],
    head_dropout=0.1,
)

### A criação do objeto `task`, que conduzirá o treinamento do modelo que configuramos na etapa anterior para a tarefa de segmentação que temos como objetivo. 

In [5]:
# Build the segmentation task from `model_args` and the named model factory,
# then configure the loss, the optimiser and the reporting options.
task = SemanticSegmentationTask(
    model_args=model_args,
    model_factory="EncoderDecoderFactory",
    # Loss; -1 is the same value the datamodule uses for `no_label_replace`.
    loss="ce",
    ignore_index=-1,
    # Optimiser settings.
    optimizer="AdamW",
    optimizer_hparams=dict(weight_decay=0.05),
    lr=lr,
    # Backbone is not frozen, so all parameters are trained.
    freeze_backbone=False,
    # Reporting.
    class_names=["Not burned", "Burn scar"],
    plot_on_val=False,
)

INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed


### O objeto `Trainer` coordena todo o processo de treinamento. Podemos interpretá-lo como um `laço` de otimização aperfeiçoado, pois suporta recursos adicionais, como paralelismo para múltiplos nós. 

In [6]:
# Lightning trainer; the epoch limit and device selector both come from the
# hyperparameter cell above.
trainer = Trainer(max_epochs=max_epochs, accelerator=accelerator)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


### Executando, de fato, o treinamento. 

In [8]:
# Train `task`, drawing batches from `datamodule`.
trainer.fit(task, datamodule=datamodule)

You are using a CUDA device ('NVIDIA RTX A4500 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | PixelWiseModel   | 324 M  | train
1 | criterion     | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | ModuleList       | 0      | train
-----------------------------------------------------------
324 M     Trainable params
0         Non-trainable params
324 M     Total params
1,296.819 Total estimated model params size (MB)
617       Modul

Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 270/270 [05:23<00:00,  0.83it/s, v_num=7]
Validation: |                                                                                                                                   | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                             | 0/132 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                | 0/132 [00:00<?, ?it/s]
Validation DataLoader 0:   1%|▊                                                                                                       | 1/132 [00:00<01:02,  2.11it/s]
Validation DataLoader 0:   2%|█▌                                                                                                      | 2/132 [00:00<01:02,  2.09it/s]

INFO:tensorboardX.summary:Summary name val/multiclassaccuracy_Not burned is illegal; using val/multiclassaccuracy_Not_burned instead.
INFO:tensorboardX.summary:Summary name val/multiclassaccuracy_Burn scar is illegal; using val/multiclassaccuracy_Burn_scar instead.
INFO:tensorboardX.summary:Summary name val/multiclassjaccardindex_Not burned is illegal; using val/multiclassjaccardindex_Not_burned instead.
INFO:tensorboardX.summary:Summary name val/multiclassjaccardindex_Burn scar is illegal; using val/multiclassjaccardindex_Burn_scar instead.



Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 270/270 [06:26<00:00,  0.70it/s, v_num=7]

INFO:tensorboardX.summary:Summary name train/multiclassaccuracy_Not burned is illegal; using train/multiclassaccuracy_Not_burned instead.
INFO:tensorboardX.summary:Summary name train/multiclassaccuracy_Burn scar is illegal; using train/multiclassaccuracy_Burn_scar instead.
INFO:tensorboardX.summary:Summary name train/multiclassjaccardindex_Not burned is illegal; using train/multiclassjaccardindex_Not_burned instead.
INFO:tensorboardX.summary:Summary name train/multiclassjaccardindex_Burn scar is illegal; using train/multiclassjaccardindex_Burn_scar instead.
`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 270/270 [06:35<00:00,  0.68it/s, v_num=7]


In [9]:
# Evaluate on the test split provided by the datamodule. The datamodule is
# passed through the dedicated `datamodule` keyword, the same way
# `trainer.fit` receives it.
trainer.test(datamodule=datamodule)

Restoring states from the checkpoint path at /home/jalmeida/Projetos/SBSR_courses/SBSR_notebooks/burn_scars/lightning_logs/version_7/checkpoints/epoch=0-step=270.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jalmeida/Projetos/SBSR_courses/SBSR_notebooks/burn_scars/lightning_logs/version_7/checkpoints/epoch=0-step=270.ckpt


Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:43<00:00,  3.06it/s]

INFO:tensorboardX.summary:Summary name test/multiclassaccuracy_Not burned is illegal; using test/multiclassaccuracy_Not_burned instead.
INFO:tensorboardX.summary:Summary name test/multiclassaccuracy_Burn scar is illegal; using test/multiclassaccuracy_Burn_scar instead.
INFO:tensorboardX.summary:Summary name test/multiclassjaccardindex_Not burned is illegal; using test/multiclassjaccardindex_Not_burned instead.
INFO:tensorboardX.summary:Summary name test/multiclassjaccardindex_Burn scar is illegal; using test/multiclassjaccardindex_Burn_scar instead.


Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 132/132 [00:43<00:00,  3.06it/s]


[{'test/loss': 0.18288634717464447,
  'test/Multiclass_Accuracy': 0.9638374447822571,
  'test/multiclassaccuracy_Not burned': 0.9925030469894409,
  'test/multiclassaccuracy_Burn scar': 0.6828244924545288,
  'test/Multiclass_F1_Score': 0.9638374447822571,
  'test/Multiclass_Jaccard_Index': 0.7987370491027832,
  'test/multiclassjaccardindex_Not burned': 0.9613975882530212,
  'test/multiclassjaccardindex_Burn scar': 0.6360765695571899,
  'test/Multiclass_Jaccard_Index_Micro': 0.9301990866661072}]