In [8]:
import os
import sys
import torch
import torchgeo
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
from terratorch.models import EncoderDecoderFactory
from terratorch.models.decoders import IdentityDecoder
from albumentations.pytorch import ToTensorV2
import warnings

warnings.filterwarnings('ignore')

In [15]:
datamodule = terratorch.datamodules.TorchNonGeoDataModule(
    transforms = [
      albumentations.augmentations.geometric.resize.Resize(height=224, width=224),
      ToTensorV2()],
      cls=torchgeo.datamodules.EuroSATDataModule,
      batch_size=32,
      num_workers=8,
      root="./EuroSat",
      download=True,
      bands = ["B02","B03", "B04", "B8A", "B11", "B12"]
)


In [16]:
pl.seed_everything(0)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/burnscars/checkpoints/",
    mode="max",
    monitor="val/Multiclass_Jaccard_Index", # Variable to monitor
    filename="best-{epoch:02d}",
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1, # Deactivate multi-gpu because it often fails in notebooks
    precision='bf16-mixed',  # Speed up training
    num_nodes=1,
    logger=True,  # Uses TensorBoard by default
    max_epochs=3, # For demos
    log_every_n_steps=1,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="output/eurosat",
    detect_anomaly=True,
)


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
INFO:lightning.pytorch.utilities.rank_zero:You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [17]:
model = terratorch.tasks.ClassificationTask(
        model_args={
      "decoder": "IdentityDecoder",
      "backbone_pretrained": True,
      "backbone": "prithvi_eo_v2_300",
      "head_dim_list": [384, 128],
      "backbone_bands":
        ["BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2"],
      "num_classes": 10,
     "head_dropout": 0.1
      },
     loss = "ce",
     freeze_backbone = False,
     model_factory = "EncoderDecoderFactory",
     optimizer = "AdamW",
     lr = 1.e-4,
     #weight_decay = 0.05
)



In [18]:
trainer.fit(model, datamodule=datamodule)

INFO: You are using a CUDA device ('NVIDIA RTX A4500 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA RTX A4500 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Downloading https://cdn-lfs.hf.co/repos/fc/1d/fc1dee780dee1dae2ad48856d0961ac6aa5dfcaaaa4fb3561be4aedf19b7ccc7/751f070f9bffa2eed48b24ca2dd0b02959280c08837e8c9a5532a67ba611df59?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27EuroSATallBands.zip%3B+filename%3D%22EuroSATallBands.zip%22%3B&response-content-type=application%2Fzip&Expires=1745615064&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NTYxNTA2NH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mYy8xZC9mYzFkZWU3ODBkZWUxZGFlMmFkNDg4NTZkMDk2MWFjNmFhNWRmY2FhYWE0ZmIzNTYxYmU0YWVkZjE5YjdjY2M3Lzc1MWYwNzBmOWJmZmEyZWVkNDhiMjRjYTJkZDBiMDI5NTkyODBjMDg4MzdlOGM5YTU1MzJhNjdiYTYxMWRmNTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=Skdxl4luv1o5fbjdiGSAcA-1PrSFKNb-MAsLRVzEjrE56Sn2lBFUTzysqmPt77D1%7E%7EsMGfXkrTmng3s9TNGvOUYEPrJgIF18mZLvIq-rutpJFWTG5GYZ0qA2dE6Y-%7Efh1Pe6M-Y7jUolR0qoWa5u1Dms5XyetY7x5pF8Hgm01ULOCqtpfqZGFaFDdTWbLkXvn4AJyz-XwE

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.07G/2.07G [02:40<00:00, 12.9MB/s]


Downloading https://huggingface.co/datasets/torchgeo/eurosat/resolve/1ce6f1bfb56db63fd91b6ecc466ea67f2509774c/eurosat-train.txt to ./EuroSat/eurosat-train.txt


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 314k/314k [00:00<00:00, 1.22MB/s]


Downloading https://huggingface.co/datasets/torchgeo/eurosat/resolve/1ce6f1bfb56db63fd91b6ecc466ea67f2509774c/eurosat-val.txt to ./EuroSat/eurosat-val.txt


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105k/105k [00:00<00:00, 5.39MB/s]


Downloading https://huggingface.co/datasets/torchgeo/eurosat/resolve/1ce6f1bfb56db63fd91b6ecc466ea67f2509774c/eurosat-test.txt to ./EuroSat/eurosat-test.txt


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104k/104k [00:00<00:00, 5.29MB/s]
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Float'