In [None]:
import os
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import warnings
warnings.filterwarnings("ignore")
import tarfile

from terratorch.registry import BACKBONE_REGISTRY, TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY

# Download Sen1Floods11 dataset
See the [original publication](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w11/Bonafilia_Sen1Floods11_A_Georeferenced_Dataset_to_Train_and_Test_Deep_Learning_CVPRW_2020_paper.pdf) and the [Sen1Floods11 repository](https://github.com/cloudtostreet/Sen1Floods11) for more about the Sen1Floods11 dataset.

- Uncomment the cell below to download the dataset, and set `dataset_root` to the directory where the download should be stored.

- Comment the cell out again after downloading to avoid re-running the download step unnecessarily.

In [None]:
# Optional one-off download of the Sen1Floods11 archive (kept commented out on purpose).
# Uncomment the whole cell for the first run. Both steps are guarded: the archive is only
# downloaded when the .tar.gz file is missing, and only extracted when the target folder is missing.
# NOTE(review): tar.extractall() writes whatever member paths the archive contains; consider
# passing filter="data" where the installed Python supports it -- confirm before enabling.
# dataset_root = "/Users/samuel.omole/Desktop/repos/geofm_datasets" # change dataset root to desired location
# url = "https://drive.google.com/uc?id=1lRw3X7oFNq_WyzBO6uyUJijyTuYm23VS"
# archive = dataset_root + "/sen1floods11_v1.1.tar.gz"
# extract_dir = dataset_root + "/sen1floods11_v1.1"

# # download if missing
# if not os.path.isfile(archive):
#     gdown.download(url, output=archive, quiet=False)

# # extract if not already extracted
# if not os.path.isdir(extract_dir):
#     with tarfile.open(archive, "r:gz") as tar:
#         tar.extractall(path=dataset_root)

# Preparing dataset with TerraTorch datamodule
See the TerraTorch [publication](https://arxiv.org/pdf/2503.20563) and [repository](https://github.com/IBM/terratorch) for details.

- Point `dataset_path` to the location of the dataset
- The dataset has already been pre-processed in the format that the TerraTorch datamodule can accept
- We are using the `GenericMultiModalDataModule` in this case

## Setting up the datamodule

In [None]:
# Location of the extracted Sen1Floods11 dataset.
# Set the SEN1FLOODS11_ROOT environment variable to point at your own copy; the original
# hard-coded path is kept as the fallback so existing setups keep working.
dataset_path = Path(
    os.environ.get(
        "SEN1FLOODS11_ROOT",
        "/Users/samuel.omole/Desktop/repos/geofm_datasets/sen1floods11_v1.1",
    )
)

# Fail early with a clear message instead of an obscure error deeper in the datamodule.
if not dataset_path.is_dir():
    raise FileNotFoundError(
        f"Sen1Floods11 dataset not found at '{dataset_path}'. "
        "Download it first or set the SEN1FLOODS11_ROOT environment variable."
    )

# Train, val and test all read from the same image/label folders;
# the split files below decide which samples belong to which subset.
image_roots = {
    "S2L1C": dataset_path / "data/S2L1CHand",
    "S1GRD": dataset_path / "data/S1GRDHand",
}
label_root = dataset_path / "data/LabelHand"

datamodule = terratorch.datamodules.GenericMultiModalDataModule(
    task="segmentation",
    batch_size=8,
    num_workers=2,
    num_classes=2,
    # Define input modalities. The names must match the keys in the dicts below and everywhere.
    modalities=["S2L1C", "S1GRD"],
    rgb_modality="S2L1C",  # Used for plotting. Defaults to the first modality if not provided.
    rgb_indices=[3, 2, 1],  # RGB channel positions in the rgb_modality.

    # Define data paths as dicts using the modality names as keys.
    # Each split gets its own copy of the dict so the three arguments never share one object.
    train_data_root=dict(image_roots),
    train_label_data_root=label_root,
    val_data_root=dict(image_roots),
    val_label_data_root=label_root,
    test_data_root=dict(image_roots),
    test_label_data_root=label_root,

    # Define split files
    train_split=dataset_path / "splits/flood_train_data.txt",
    val_split=dataset_path / "splits/flood_valid_data.txt",
    test_split=dataset_path / "splits/flood_test_data.txt",

    # Define file-name patterns used to find the images and labels
    image_grep={
        "S2L1C": "*_S2Hand.tif",
        "S1GRD": "*_S1Hand.tif",
    },
    label_grep="*_LabelHand.tif",

    # You can select a subset of the dataset bands as model inputs by providing dataset_bands and output_bands.
    # This setting is optional for all modalities and needs to be provided as dicts.
    # Here is an example with S-1 GRD. You could change the output to ["VV"] to only train on the first band.
    dataset_bands={
        "S1GRD": ["VV", "VH"]
    },
    output_bands={
        "S1GRD": ["VV", "VH"]
    },

    # Define standardization values. We use the pre-training values provided for the TerraMind model.
    # Note that means and stds must be aligned with the output_bands defined earlier (equal length of values).
    # S-2 L1C has no dataset_bands/output_bands entry above, so values for all of its bands are listed here.
    means={
      "S2L1C": [2357.089, 2137.385, 2018.788, 2082.986, 2295.651, 2854.537, 3122.849, 3040.560, 3306.481, 1473.847, 506.070, 2472.825, 1838.929],
    #   "S2L2A": [1390.458, 1503.317, 1718.197, 1853.910, 2199.100, 2779.975, 2987.011, 3083.234, 3132.220, 3162.988, 2424.884, 1857.648],
      "S1GRD": [-12.599, -20.293],
    #   "S1RTC": [-10.93, -17.329],
    #   "RGB": [87.271, 80.931, 66.667],
    #   "DEM": [670.665]
    },
    stds={
      "S2L1C": [1624.683, 1675.806, 1557.708, 1833.702, 1823.738, 1733.977, 1732.131, 1679.732, 1727.26, 1024.687, 442.165, 1331.411, 1160.419],
    #   "S2L2A": [2106.761, 2141.107, 2038.973, 2134.138, 2085.321, 1889.926, 1820.257, 1871.918, 1753.829, 1797.379, 1434.261, 1334.311],
      "S1GRD": [5.195, 5.890],
    #   "S1RTC": [4.391, 4.459],
    #   "RGB": [58.767, 47.663, 42.631],
    #   "DEM": [951.272],
    },

    # Apply albumentations to augment the training data
    train_transform=[
        albumentations.D4(),  # Performs random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
    ],
    val_transform=None,  # Applies ToTensorV2() by default if not provided
    test_transform=None,

    no_label_replace=-1,  # Replace NaN labels. Defaults to -1, which is ignored in the loss and metrics.
    no_data_replace=0,  # Replace NaN data
)

# Setup train and val datasets
datamodule.setup("fit")

## Plotting some training and validation examples

In [None]:
# Keep handles on both splits; val_dataset is plotted separately.
train_dataset = datamodule.train_dataset
val_dataset = datamodule.val_dataset

# Show a few hand-picked training samples (the indices are arbitrary).
for sample_idx in (50, 57, 200):
    train_dataset.plot(train_dataset[sample_idx])
    plt.show()

In [None]:
# Show a few hand-picked validation samples (the indices are arbitrary).
for sample_idx in (8, 15, 68):
    val_dataset.plot(val_dataset[sample_idx])
    plt.show()

In [None]:
# Set up the test split and display how many samples it holds.
# `test_dataset` is reused later on when plotting predictions on the test set.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
len(test_dataset)

# Exploring the [TerraTorch](https://ibm.github.io/terratorch/quick_start/) model registry

In [None]:
# Display the names of all registered backbones that contain 'terramind'.
[name for name in TERRATORCH_BACKBONE_REGISTRY if 'terramind' in name]

In [None]:
# Display every entry of the TerraTorch decoder registry as a list
list(TERRATORCH_DECODER_REGISTRY)

In [None]:
# Build the "terramind_v1_small" backbone on its own (with pretrained=True) and display it
# to get a glance at the architecture.
# Note: the name `model` is reassigned to the segmentation task in a later cell.
model = BACKBONE_REGISTRY.build("terramind_v1_small", pretrained=True)
model

# Building TerraMind and fine-tuning via PyTorch Lightning
Refer to the [publication](https://arxiv.org/pdf/2504.11171) and [repository](https://github.com/IBM/terramind/tree/main) for more details
- This section sets up the trainer for fine-tuning the model on the dataset
- The paths to store the logging details and checkpoint need to be provided
- The decoder parameters are updated while the model backbone is frozen
- Training runs for a set number of epochs (50 is shown below; lower it to quickly test the trainer setup)

## Setting up the trainer

In [None]:
# Seed everything so that repeated runs are comparable.
pl.seed_everything(0)

# Root folder for the logs and checkpoints of this run. Change as appropriate.
output_dir = "../output/terramind_small_sen1floods11/"

# TerraTorch keeps the checkpoint with the best validation loss by default.
# This custom ModelCheckpoint keeps the one with the highest validation mIoU instead.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=output_dir + "checkpoints/",
    filename="best-mIoU",
    monitor="val/mIoU",  # Metric to monitor
    mode="max",
    save_weights_only=True,
)

# Lightning trainer
trainer = pl.Trainer(
    default_root_dir=output_dir,
    accelerator="cpu",  # set to gpu if you have one
    devices=1,  # Multi-gpu is deactivated because it often fails in notebooks
    num_nodes=1,
    strategy="auto",
    precision=32,  # Note: "16-mixed" speeds up training with half precision
    max_epochs=50,  # More epochs mean longer training, fewer mean shorter
    logger=True,  # Uses TensorBoard by default
    log_every_n_steps=1,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
)

# Necks adapt the backbone outputs to what the decoder expects.
necks = [
    {
        "name": "SelectIndices",
        "indices": [2, 5, 8, 11]  # indices for terramind_v1_base & small
        # "indices": [5, 11, 17, 23]  # indices for terramind_v1_large
    },
    # TerraMind is trained without CLS token, which needs to be specified.
    {"name": "ReshapeTokensToImage", "remove_cls_token": False},
    # Some decoders like UNet or UperNet expect hierarchical features.
    {"name": "LearnedInterpolateToPyramidal"},
]

model_args = {
    # TerraMind backbone
    "backbone": "terramind_v1_small",  # change to a specific model, e.g. terramind_v1_large for the large version
    "backbone_pretrained": True,
    "backbone_modalities": ["S2L1C", "S1GRD"],
    # Optionally, define the input bands. This is only needed if you select a subset of the pre-training bands.
    # "backbone_bands": {"S1GRD": ["VV"]},

    # Necks
    "necks": necks,

    # Decoder
    "decoder": "UNetDecoder",
    "decoder_channels": [512, 256, 128, 64],

    # Head
    "head_dropout": 0.1,
    "num_classes": 2,  # the mask label image contains two classes
}

# Segmentation task that builds the model and handles the training and validation steps.
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",  # Combines a backbone with necks, the decoder, and a head
    model_args=model_args,
    loss="dice",  # dice is recommended for binary tasks and ce for multi-class tasks.
    optimizer="AdamW",
    lr=2e-5,  # Learning rate; it could be tuned via hyperparameter optimization, e.g. with terratorch-iterate.
    ignore_index=-1,
    freeze_backbone=True,  # True speeds up fine-tuning; fine-tuning the backbone as well is recommended for best performance.
    freeze_decoder=False,  # Must be False so the decoder parameters are updated
    plot_on_val=True,  # Plot predictions during validation steps
    class_names=["Others", "Water"],  # optionally define class names
)

In [None]:
# Run training: fit the segmentation task using the datamodule's data.
# The trainer was created with the ModelCheckpoint callback that monitors "val/mIoU".
trainer.fit(model, datamodule=datamodule)

## Evaluate the model performance on the test dataset

In [None]:
# Test the fine-tuned model.
# This prints out the test metrics to evaluate the performance of the model on the test dataset.
#
# Prefer the path recorded by the checkpoint callback during this session: if a "best-mIoU.ckpt"
# from an earlier run already exists, Lightning saves the new checkpoint under a versioned name
# (e.g. "best-mIoU-v1.ckpt"), so a hard-coded path could silently evaluate stale weights.
# `best_model_path` is an empty string when no training ran in this kernel session; in that
# case the hard-coded path below is used as the fallback.
best_ckpt_path = (
    checkpoint_callback.best_model_path
    or "../output/terramind_small_sen1floods11/checkpoints/best-mIoU.ckpt"  # Change as appropriate
)
trainer.test(
    model,
    datamodule=datamodule,
    ckpt_path=best_ckpt_path,
)

## Predicting on an example test batch

In [None]:
# Reload the best checkpoint, predict on the first test batch and plot the results.
model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
    best_ckpt_path,
    model_factory=model.hparams.model_factory,
    model_args=model.hparams.model_args,
)
# Switch to evaluation mode so that dropout (head_dropout) is disabled for inference.
model.eval()

test_loader = datamodule.test_dataloader()
batch = next(iter(test_loader))  # this only selects the first batch in the test_dataloader

# Move every modality to the model's device without modifying the batch itself.
images = {mod: value.to(model.device) for mod, value in batch["image"].items()}
masks = batch["mask"].numpy()

# Gradients are not needed for prediction.
with torch.no_grad():
    outputs = model(images)

# Predicted class per pixel.
preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

# Plot at most `max_plots` samples; bounding by len(preds) avoids an IndexError
# when the batch holds fewer samples than that.
max_plots = 8
for i in range(min(max_plots, len(preds))):
    sample = {
        "image": batch["image"]["S2L1C"][i].cpu(),
        "mask": batch["mask"][i],
        "prediction": preds[i],
    }
    test_dataset.plot(sample)
    plt.show()