In [23]:
import os
import sys
import torch
import torchgeo
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
from terratorch.models import EncoderDecoderFactory
from terratorch.models.decoders import IdentityDecoder
from albumentations.pytorch import ToTensorV2
import warnings

warnings.filterwarnings('ignore')

In [24]:
# Intended epoch budget for the Trainer. Kept tiny for a demo: a single epoch
# already takes ~12 minutes on the 300M-parameter backbone in the logged run.
max_epochs: int = 1

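### Data module

EuroSAT is loaded through TorchGeo's `EuroSATDataModule`, wrapped by TerraTorch's `TorchNonGeoDataModule`. Six bands are selected, and every tile is resized to 224 × 224 before being converted to a tensor.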

In [25]:
# EuroSAT via TorchGeo's EuroSATDataModule, wrapped by TerraTorch's
# TorchNonGeoDataModule. Each sample is resized to 224x224 and converted to a
# tensor. The band list must match `backbone_bands` of the task, in the same
# order: B02/B03/B04/B8A/B11/B12 -> BLUE/GREEN/RED/NIR_NARROW/SWIR_1/SWIR_2.
datamodule = terratorch.datamodules.TorchNonGeoDataModule(
    cls=torchgeo.datamodules.EuroSATDataModule,
    batch_size=32,
    num_workers=8,
    transforms=[
        albumentations.Resize(height=224, width=224),  # public alias of geometric.resize.Resize
        ToTensorV2(),
    ],
    root="./EuroSat",
    download=True,
    bands=["B02", "B03", "B04", "B8A", "B11", "B12"],
)


### Instantiating the Lightning Trainer

In [26]:
# Seed Python/NumPy/torch and, with workers=True, the dataloader worker
# processes too (the datamodule uses num_workers=8).
pl.seed_everything(0, workers=True)

# Keeps the checkpoint with the best validation Jaccard index. It only takes
# effect if passed to the Trainer via `callbacks=` (commented out below);
# otherwise `enable_checkpointing=True` makes Lightning add its default
# checkpoint callback instead.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/eurosat/checkpoints/",  # same tree as default_root_dir
    mode="max",
    monitor="val/Multiclass_Jaccard_Index",  # metric to maximise
    filename="best-{epoch:02d}",
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1,  # deactivate multi-GPU because it often fails in notebooks
    num_nodes=1,  # single machine: this is a node count, not an epoch count
    precision="bf16-mixed",  # speed up training
    logger=True,  # uses TensorBoard by default
    max_epochs=max_epochs,  # epoch budget from the config cell
    log_every_n_steps=1,
    enable_checkpointing=True,
    # callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="output/eurosat",
    # Debugging aid (NaN/Inf detection in backward); Lightning warns it
    # significantly slows compute. Set to False for real training runs.
    detect_anomaly=True,
)


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
INFO:lightning.pytorch.utilities.rank_zero:You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


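### Classification task

A pretrained Prithvi-EO-2.0 300M backbone (`prithvi_eo_v2_300`) with a classification head for EuroSAT's 10 classes, fine-tuned end to end (`freeze_backbone=False`).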

In [27]:
# Prithvi-EO-2.0 300M backbone (`prithvi_eo_v2_300`) with an identity decoder
# and a classification head (hidden sizes from `head_dim_list`) for EuroSAT's
# 10 classes. The backbone is not frozen, so all parameters are trained.
# `backbone_bands` must list the same bands, in the same order, as the
# datamodule's `bands` (B02, B03, B04, B8A, B11, B12).
model = terratorch.tasks.ClassificationTask(
    model_args={
        "decoder": "IdentityDecoder",
        "backbone_pretrained": True,
        "backbone": "prithvi_eo_v2_300",
        "head_dim_list": [384, 128],
        "backbone_bands": [
            "BLUE",
            "GREEN",
            "RED",
            "NIR_NARROW",
            "SWIR_1",
            "SWIR_2",
        ],
        "num_classes": 10,
        "head_dropout": 0.1,
    },
    loss="ce",
    freeze_backbone=False,
    model_factory="EncoderDecoderFactory",
    optimizer="AdamW",
    lr=1.0e-4,
    # weight_decay is an AdamW argument, so it belongs to the optimizer
    # hyperparameters. `scheduler_hparams` configures the LR scheduler, and
    # no scheduler is set here.
    optimizer_hparams={"weight_decay": 0.05},
)



### Training

In [None]:
# Fit on the datamodule's training split, validating after each epoch.
trainer.fit(model=model, datamodule=datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name          | Type              | Params | Mode 
------------------------------------------------------------
0 | model         | ScalarOutputModel | 304 M  | train
1 | criterion     | CrossEntropyLoss  | 0      | train
2 | train_metrics | MetricCollection  | 0      | train
3 | val_metrics   | MetricCollection  | 0      | train
4 | test_metrics  | ModuleList        | 0      | train
------------------------------------------------------------
304 M     Trainable params
0         Non-trainable params
304 M     Total params
1,217.322 Total estimated model params size (MB)
553       Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name          | Type              | Params | Mode 
------------------------------------------------------------
0 | model         | ScalarOutputModel | 304 M  | trai

Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 507/507 [11:54<00:00,  0.71it/s, v_num=4]
[Aidation: |                                                                                                                                                           | 0/? [00:00<?, ?it/s]
[Aidation:   0%|                                                                                                                                                     | 0/169 [00:00<?, ?it/s]
[Aidation DataLoader 0:   0%|                                                                                                                                        | 0/169 [00:00<?, ?it/s]
[Aidation DataLoader 0:   1%|▊                                                                                                                               | 1/169 [00:00<01:18,  2.15it/s]
[Aidation DataLoader 0:   1%|█▌             

### Testing

In [None]:
# Evaluate the in-memory model (weights as left by `trainer.fit`) on the test split.
trainer.test(model=model, datamodule=datamodule)