# Setup
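Install TerraTorch together with `gdown` (used to download the dataset) and TensorBoard (used to view the training logs).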

In [1]:
!pip install terratorch gdown tensorboard


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [None]:
import os
import sys
from pathlib import Path

import albumentations
import albumentations.pytorch  # `import albumentations` alone does not load the .pytorch subpackage (ToTensorV2)
import gdown
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import terratorch
import torch
from terratorch.datamodules import GenericNonGeoSegmentationDataModule

  from .autonotebook import tqdm as notebook_tqdm


Download the Sen1Floods11 dataset from Google Drive (skipped if it is already present).

In [None]:
if not os.path.isfile('sen1floods11_v1.1.tar.gz'):
    !gdown.download("https://drive.google.com/uc?id=1lRw3X7oFNq_WyzBO6uyUJijyTuYm23VS")
    !tar -xzvf sen1floods11_v1.1.tar.gz

## Sen1Floods11 Dataset

Let's start by exploring the dataset.

In [None]:
dataset_path = Path('sen1floods11_v1.1')
!ls "sen1floods11_v1.1/data"

In [None]:
!ls "sen1floods11_v1.1/data/S2L1CHand/" | head

In [None]:
# TerraTorch provides generic data modules that work directly with PyTorch Lightning.
# All samples live in the same image/label folders; the split files select
# which samples belong to train, validation and test.
data_dir = dataset_path / 'data'
splits_dir = dataset_path / 'splits'
image_dir = data_dir / 'S2L1CHand'
label_dir = data_dir / 'LabelHand'

datamodule = GenericNonGeoSegmentationDataModule(
    batch_size=8,
    num_workers=2,
    num_classes=2,

    # The same folders serve every stage
    train_data_root=image_dir,
    train_label_data_root=label_dir,
    val_data_root=image_dir,
    val_label_data_root=label_dir,
    test_data_root=image_dir,
    test_label_data_root=label_dir,

    # Per-stage split files
    train_split=splits_dir / 'flood_train_data.txt',
    val_split=splits_dir / 'flood_valid_data.txt',
    test_split=splits_dir / 'flood_test_data.txt',

    # File-name patterns for images and labels
    img_grep='*_S2Hand.tif',
    label_grep='*_LabelHand.tif',

    train_transform=[
        albumentations.D4(),  # random flips and rotations
        albumentations.pytorch.transforms.ToTensorV2(),
    ],
    # None -> the data module's default conversion to tensors
    val_transform=None,
    test_transform=None,

    # Bands stored in the files, and the subset fed to the model (optional)
    dataset_bands=[
      "COASTAL_AEROSOL",
      "BLUE",
      "GREEN",
      "RED",
      "RED_EDGE_1",
      "RED_EDGE_2",
      "RED_EDGE_3",
      "NIR_BROAD",
      "NIR_NARROW",
      "CIRRUS",
      "SWIR_1",
      "SWIR_2",
    ],
    output_bands=[
      "BLUE",
      "GREEN",
      "RED",
      "NIR_NARROW",
      "SWIR_1",
      "SWIR_2",
    ],

    # Standardization statistics, one value per entry of output_bands (same order)
    means=[
      0.11076498225107874,
      0.13456047562676646,
      0.12477149645635542,
      0.3248933937526503,
      0.23118412840904512,
      0.15624583324071273,
    ],
    stds=[
      0.15469174852002912,
      0.13070592427323752,
      0.12786689586224442,
      0.13925781946803198,
      0.11303782829438778,
      0.10207461132314981,
    ],
)

# Build the train and validation datasets
datamodule.setup("fit")

In [None]:
# Size of the training split. An empty split almost always means the data or
# split paths are wrong, so fail here rather than deep inside trainer.fit().
train_dataset = datamodule.train_dataset
assert len(train_dataset) > 0, "Training split is empty - check the dataset and split paths."
len(train_dataset)

In [None]:
# Size of the validation split; fail early if it is empty (wrong paths).
val_dataset = datamodule.val_dataset
assert len(val_dataset) > 0, "Validation split is empty - check the dataset and split paths."
len(val_dataset)

In [None]:
# Plot a few validation samples. A loop keeps the cell's final expression None:
# returning the last Figure would make the notebook render it a second time on
# top of the inline backend's own display.
for sample_idx in (0, 9, 11):
    val_dataset.plot(val_dataset[sample_idx])

In [None]:
# Build the test dataset and check its size; fail early if it is empty (wrong paths).
datamodule.setup("test")
test_dataset = datamodule.test_dataset
assert len(test_dataset) > 0, "Test split is empty - check the dataset and split paths."
len(test_dataset)

# TerraTorch model factory

In [None]:
# TerraTorch includes meta registries for all model components 
from terratorch.registry import BACKBONE_REGISTRY, TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY

In [None]:
# First five entries of the TerraTorch backbone registry
[*TERRATORCH_BACKBONE_REGISTRY][:5]

In [None]:
# All entries of the TerraTorch decoder registry
[*TERRATORCH_DECODER_REGISTRY]

In [None]:
# Build a standalone, pretrained PyTorch backbone that can be used in a custom
# training pipeline (independent of the Lightning task defined further below).
backbone_name = "prithvi_eo_v2_300_tl"
model = BACKBONE_REGISTRY.build(backbone_name, pretrained=True)

In [None]:
# Display the module tree of the backbone built in the previous cell
model

# Fine-tune Prithvi via PyTorch Lightning

In [None]:
pl.seed_everything(0)

# Keep the checkpoint with the highest validation Jaccard index (IoU).
# Lightning inserts the metric name into the filename template, producing
# files such as "best-epoch=01.ckpt".
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/sen1floods11/checkpoints/",
    mode="max",
    monitor="val/Multiclass_Jaccard_Index", # Variable to monitor
    filename="best-{epoch:02d}",
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1, # Deactivate multi-gpu because it often fails in notebooks
    # fp16 mixed precision speeds up training on CUDA GPUs, but Lightning rejects
    # it on CPU; fall back to full precision so the notebook also runs without a GPU.
    precision='16-mixed' if torch.cuda.is_available() else '32-true',
    num_nodes=1,
    logger=True,  # Uses TensorBoard by default
    max_epochs=5, # For demos
    log_every_n_steps=1,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="output/sen1floods11/",
)

# Model: Prithvi encoder + necks + UNet decoder, wrapped in a Lightning task
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # Backbone
        "backbone": "prithvi_eo_v2_300_tl", # Model can be either prithvi_eo_v1_100, prithvi_eo_v2_300, prithvi_eo_v2_300_tl, prithvi_eo_v2_600, prithvi_eo_v2_600_tl
        "backbone_pretrained": True,
        "backbone_num_frames": 1, # 1 is the default value
        "backbone_img_size": 512, # if not provided: interpolate pos embedding from 224 pre-training which also works well
        # Must match output_bands of the datamodule
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        "backbone_coords_encoding": [], # use ["time", "location"] for time and location metadata

        # Necks
        "necks": [
            {
                "name": "SelectIndices",
                # "indices": [2, 5, 8, 11] # indices for prithvi_eo_v1_100
                "indices": [5, 11, 17, 23] # indices for prithvi_eo_v2_300
                # "indices": [7, 15, 23, 31] # indices for prithvi_eo_v2_600
            },
            {"name": "ReshapeTokensToImage",},
            {"name": "LearnedInterpolateToPyramidal"}
        ],

        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],

        # Head
        "head_dropout": 0.1,
        "num_classes": 2,
    },

    loss="dice",
    optimizer="AdamW",
    lr=1e-4,
    ignore_index=-1,  # label value excluded from loss/metrics (presumably the no-data value - confirm)
    freeze_backbone=True, # Speeds up fine-tuning
    freeze_decoder=False,
    plot_on_val=True,
    class_names=['no water', 'water']  # optionally define class names
)

In [None]:
# Start TensorBoard inside the notebook, reading logs from ./output
# (the Trainer above logs under default_root_dir="output/sen1floods11/").
%load_ext tensorboard
%tensorboard --logdir output

In [None]:
# Fine-tune the task on the train split, validating after each epoch
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# Path of the best checkpoint according to the monitored validation metric.
# Read from the callback instead of hard-coding an epoch: which epoch scores
# best varies between runs, so a fixed filename may point to the wrong file or
# to none at all.
best_ckpt_path = checkpoint_callback.best_model_path
assert best_ckpt_path, "No checkpoint was saved - run trainer.fit() first."
best_ckpt_path

In [None]:
# Evaluate the best checkpoint on the test split
trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path)

In [None]:
# now we can use the model for predictions and plotting!
model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
    best_ckpt_path,
    model_factory=model.hparams.model_factory,
    model_args=model.hparams.model_args,
)
# Inference mode: disables dropout (head_dropout above) and switches
# normalization layers to their running statistics. load_from_checkpoint
# does not do this by itself.
model.eval()

test_loader = datamodule.test_dataloader()
with torch.no_grad():
    batch = next(iter(test_loader))
    images = batch["image"].to(model.device)

    outputs = model(images)
    # Class scores along dim 1 -> predicted class index per pixel
    preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

# Plot up to 5 samples; the batch may contain fewer
n_plots = min(5, len(preds))
for i in range(n_plots):
    sample = {key: batch[key][i] for key in batch}
    sample["prediction"] = preds[i]
    test_dataset.plot(sample)

# Fine-tuning via CLI

You might want to restart the kernel first to free up the GPU memory still held by the model trained above.

In [None]:
# Download config
!wget wget https://raw.githubusercontent.com/ibm/TerraTorch/refs/heads/main/examples/tutorial/configs/prithvi_v2_eo_300_tl_unet_sen1floods11.yaml

In [None]:
# Run fine-tuning
!terratorch fit -c prithvi_v2_eo_300_tl_unet_sen1floods11.yaml