# HLS Burn Scars

## Task: Adapt this notebook to fine-tune TerraMind on HLS Burn Scars.

You will find several TODOs in this notebook.

Use the dataset description (https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars) and the TerraMind docs (https://terrastackai.github.io/terratorch/stable/guide/terramind/) to solve the TODOs.

# Setup

In Colab:
1. Go to "Runtime" -> "Change runtime type" -> Select "T4 GPU"
2. Install TerraTorch

In [None]:
!pip install terratorch==1.1.1
!pip install gdown tensorboard

In [None]:
import os
import torch
import gdown
import terratorch
import albumentations
import albumentations.pytorch  # Makes albumentations.pytorch.transforms.ToTensorV2 available
import numpy as np
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import warnings
warnings.filterwarnings("ignore")

3. Download the dataset from Google Drive

In [None]:
# This version is an adaptation from https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars with splits from https://github.com/IBM/peft-geofm/tree/main/datasets_splits/burn_scars.

if not os.path.isfile('hls_burn_scars.tar.gz'):
    gdown.download("https://drive.google.com/uc?id=1yFDNlGqGPxkc9lh9l1O70TuejXAQYYtC")

if not os.path.isdir('hls_burn_scars/'):
    !tar -xzf hls_burn_scars.tar.gz
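Outside a notebook, the `!tar` shell escape is not available. The same extract-once logic can be sketched with Python's standard `tarfile` module. The helper name `extract_if_missing` is hypothetical and only mirrors the `isdir` check used in the cell above.

```python
import os
import tarfile


def extract_if_missing(archive_path: str, expected_dir: str, target_dir: str = ".") -> bool:
    """Extract a .tar.gz archive unless `expected_dir` already exists in `target_dir`.

    Returns True if the archive was extracted and False if extraction was skipped.
    """
    if os.path.isdir(os.path.join(target_dir, expected_dir)):
        return False  # Already extracted, same check as the notebook cell
    with tarfile.open(archive_path, mode="r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Python versions support safe extraction filters
            tar.extractall(path=target_dir, filter="data")
        else:
            tar.extractall(path=target_dir)
    return True


# Usage with the paths from the cell above:
# extract_if_missing("hls_burn_scars.tar.gz", "hls_burn_scars")
```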

## HLS Burn Scars Dataset

Let's start by analysing the dataset.

In [None]:
dataset_path = Path("hls_burn_scars")
!ls "hls_burn_scars"

In [None]:
!ls "hls_burn_scars/splits/" | head

In [None]:
!ls "hls_burn_scars/data/" | head
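The `splits/` folder presumably holds one text file per split listing its samples, and the data module's `img_grep`/`label_grep` arguments select files by a glob-style suffix. The sketch below is a simplified, hypothetical model of how these two filters combine. It assumes one sample ID per line and that a file belongs to a split when its name contains one of the IDs. The file names are made up (check the listings above for the real suffixes), and TerraTorch's own matching may differ in details.

```python
from fnmatch import fnmatch
from pathlib import Path


def read_split_file(path):
    """Read one sample ID per line from a split file, skipping blank lines."""
    return [line.strip() for line in Path(path).read_text().splitlines() if line.strip()]


def select_files(file_names, grep_pattern, split_ids):
    """Keep files that match a glob-style pattern and contain one of the split IDs.

    Simplified, hypothetical model of how a suffix pattern such as "*_img.tif"
    and a split file narrow down the samples.
    """
    return sorted(
        name
        for name in file_names
        if fnmatch(name, grep_pattern) and any(sample_id in name for sample_id in split_ids)
    )


# Hypothetical file names, not the real HLS Burn Scars naming scheme
files = ["scene_a_img.tif", "scene_a_lbl.tif", "scene_b_img.tif", "scene_b_lbl.tif"]
print(select_files(files, "*_img.tif", ["scene_a"]))  # ['scene_a_img.tif']
```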

In [None]:
import rioxarray as rxr
sample = rxr.open_rasterio('hls_burn_scars/data/subsetted_512x512_HLS.S30.T10SDH.2020248.v1.4_merged.tif')
sample
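The `means` and `stds` in the data module are used to standardize each band as `(x - mean) / std`, so each list needs one entry per band of the input images (the `band` dimension of the sample above). If you prefer dataset-specific statistics over the TerraMind pretraining values, a minimal pure-Python sketch for computing and applying them could look like this. It assumes NaN-free pixel values and accumulates sums and sums of squares, which is adequate for a sketch (Welford's algorithm is numerically more robust). Both function names are hypothetical.

```python
import math


def band_statistics(images):
    """Per-band mean and population std over all pixels of all images.

    `images` is an iterable of images; each image is a list of bands and each
    band a flat list of pixel values. Sums are accumulated, so images can be streamed.
    """
    sums, sq_sums, count = None, None, 0
    for image in images:
        if sums is None:
            sums = [0.0] * len(image)
            sq_sums = [0.0] * len(image)
        for b, band in enumerate(image):
            sums[b] += sum(band)
            sq_sums[b] += sum(v * v for v in band)
        count += len(image[0])  # Assumes all bands have the same number of pixels
    if count == 0:
        raise ValueError("no pixels provided")
    means = [s / count for s in sums]
    # Clamp tiny negative values caused by floating-point rounding
    stds = [math.sqrt(max(q / count - m * m, 0.0)) for q, m in zip(sq_sums, means)]
    return means, stds


def standardize(pixel, means, stds):
    """Apply (x - mean) / std per band; means and stds need one entry per band."""
    if not (len(pixel) == len(means) == len(stds)):
        raise ValueError("means and stds need one entry per image band")
    return [(x - m) / s for x, m, s in zip(pixel, means, stds)]


# Two tiny 2-band images with 2 pixels each
means, stds = band_statistics([[[1, 3], [10, 10]], [[5, 7], [10, 10]]])
# means == [4.0, 10.0], stds is approximately [2.236, 0.0]
```

With the rioxarray sample above, `sample.values.reshape(sample.shape[0], -1).tolist()` yields band-wise lists in the expected layout, although NumPy reductions are much faster for full 512x512 tiles.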

TerraTorch provides generic data modules that work directly with PyTorch Lightning.

In [None]:
datamodule = terratorch.datamodules.GenericMultiModalDataModule(
    task="segmentation",
    batch_size=8,
    num_workers=2,
    num_classes=2,

    # TODO: Define your input modalities. The names must match the keys in the following dicts
    modalities=["TODO"],
    rgb_modality="TODO",  # Used for plotting. Defaults to the first modality if not provided.
    rgb_indices=[3,2,1],  # RGB channel positions in the rgb_modality. Adjust them to the band order of your input data.

    # TODO: Define data paths as dicts using the modality names as keys.
    train_data_root={
        "TODO": dataset_path / "TODO",
    },
    train_label_data_root=dataset_path / "TODO",
    val_data_root={
        "TODO": dataset_path / "TODO",
    },
    val_label_data_root=dataset_path / "TODO",
    test_data_root={
        "TODO": dataset_path / "TODO",
    },
    test_label_data_root=dataset_path / "TODO",

    # TODO: Define split files
    train_split=dataset_path / "TODO",
    val_split=dataset_path / "TODO",
    test_split=dataset_path / "TODO",

    # TODO: Define suffix, again using dicts.
    img_grep={
        "TODO": "*TODO",
    },
    label_grep="*TODO",

    # TODO: Update the standardization values. Each list needs one value per image band.
    means={
      # TerraMind pretraining means:
      "S2L2A": [1390.458, 1503.317, 1718.197, 1853.910, 2199.100, 2779.975, 2987.011, 3083.234, 3132.220, 3162.988, 2424.884, 1857.648],
    },
    stds={
      # TerraMind pretraining stds:
      "S2L2A": [2106.761, 2141.107, 2038.973, 2134.138, 2085.321, 1889.926, 1820.257, 1871.918, 1753.829, 1797.379, 1434.261, 1334.311],
    },
    
    # Albumentations supports shared transformations and can handle multimodal inputs.
    train_transform=[
        albumentations.D4(), # Random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
    ],
    val_transform=None,  # Fallback to ToTensor
    test_transform=None,
    
    no_label_replace=-1,  # Replace NaN labels. Defaults to -1, which is ignored in the loss and metrics.
    no_data_replace=0,  # Replace NaN data
)

# Setup train and val datasets
datamodule.setup("fit")
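The data module requires the keys of every per-modality dict to match the names in `modalities`. A hypothetical helper such as `check_modality_keys` (not part of TerraTorch) can catch mismatches before the data module is built. It assumes every per-modality dict should contain exactly the declared modalities.

```python
def check_modality_keys(modalities, **per_modality_dicts):
    """Raise ValueError if a per-modality dict does not use exactly the given modality names.

    Hypothetical helper (not part of TerraTorch) for arguments such as
    train_data_root, img_grep, means, or stds.
    """
    expected = set(modalities)
    problems = []
    for arg_name, mapping in per_modality_dicts.items():
        keys = set(mapping)
        if keys != expected:
            problems.append(
                f"{arg_name}: missing {sorted(expected - keys)}, unexpected {sorted(keys - expected)}"
            )
    if problems:
        raise ValueError("; ".join(problems))


# Example with a made-up modality name
check_modality_keys(
    ["optical"],
    train_data_root={"optical": "data"},
    img_grep={"optical": "*_img.tif"},
)  # passes silently
```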

In [None]:
# Check the size of the validation split
val_dataset = datamodule.val_dataset
len(val_dataset)

In [None]:
# Plotting a few samples
val_dataset.plot(val_dataset[1])
plt.show()
val_dataset.plot(val_dataset[2])
plt.show()
val_dataset.plot(val_dataset[3])
plt.show()
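The plots show validation samples, which are not augmented because `val_transform=None` only converts them to tensors. During training, `albumentations.D4()` applies one of the eight symmetries of a square (four rotations, each optionally flipped), and the same transform is applied to the image and its mask. The pure-Python sketch below illustrates that idea on nested lists; it is not Albumentations' implementation, and all function names are hypothetical.

```python
import random


def rotate90(grid):
    """Rotate a 2D grid (list of rows) by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*grid)][::-1]


def hflip(grid):
    """Mirror a 2D grid horizontally."""
    return [row[::-1] for row in grid]


def apply_d4(grid, k, flip):
    """Apply k counter-clockwise rotations, then an optional horizontal flip."""
    for _ in range(k):
        grid = rotate90(grid)
    return hflip(grid) if flip else grid


def d4_variants(grid):
    """All 8 elements of the dihedral group D4 applied to a square grid."""
    return [apply_d4(grid, k, flip) for k in range(4) for flip in (False, True)]


def random_d4(bands, mask, rng=random):
    """Apply one randomly chosen D4 element to every band and to the mask."""
    k, flip = rng.randrange(4), rng.random() < 0.5
    return [apply_d4(b, k, flip) for b in bands], apply_d4(mask, k, flip)


print(rotate90([[1, 2], [3, 4]]))  # [[2, 4], [1, 3]]
```

Sharing `k` and `flip` between the bands and the mask is the essential design choice: an image rotated without its mask would teach the model wrong pixel labels.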

In [None]:
# Check the size of the test split
datamodule.setup("test")
test_dataset = datamodule.test_dataset
len(test_dataset)

# Fine-tune TerraMind via PyTorch Lightning

With TerraTorch, we can use standard Lightning components for the fine-tuning.
These include callbacks and the trainer class.
TerraTorch provides EO-specific tasks that define the training and validation steps.
In this case, we are using the `SemanticSegmentationTask`.
We refer to the [TerraTorch paper](https://arxiv.org/abs/2503.20563) for a detailed explanation of the TerraTorch tasks.

In [None]:
pl.seed_everything(0)

# By default, TerraTorch saves the model with the best validation loss. You can override this by defining a custom ModelCheckpoint, e.g., saving the model with the highest validation mIoU.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/terramind_small_burnscars/checkpoints/",
    mode="max",
    monitor="val/mIoU", # Variable to monitor
    filename="best-mIoU",
    save_weights_only=True,
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1, # Deactivate multi-gpu because it often fails in notebooks
    precision="16-mixed",  # Speed up training with half precision, delete for full precision training.
    num_nodes=1,
    logger=True,  # Uses TensorBoard by default
    max_epochs=3, # For demos
    log_every_n_steps=1,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="output/terramind_small_burnscars/",
)

# Segmentation task that builds the model and handles the training and validation steps.
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",  # Combines a backbone with necks, the decoder, and a head
    model_args={
        # TerraMind backbone
        "backbone": "terramind_v1_small",
        "backbone_pretrained": True,
        # TODO Select the modality
        "backbone_modalities": ["TODO"],
        # TODO define the input bands. This is only needed because you need to select a subset of the pre-training bands for Burn Scars
        "backbone_bands": {"TODO": ["TODO"]},
        
        # Necks 
        "necks": [
            {
                "name": "SelectIndices",
                "indices": [2, 5, 8, 11]  # indices for the 12-layer models, e.g., terramind_v1_small and terramind_v1_base
                # "indices": [5, 11, 17, 23]  # indices for terramind_v1_large
            },
            {"name": "ReshapeTokensToImage",
             "remove_cls_token": False},  # TerraMind is trained without a CLS token, which needs to be specified.
            {"name": "LearnedInterpolateToPyramidal"}  # Some decoders like UNet or UperNet expect hierarchical features, so we learn an upsampling of the intermediate embedding layers when using a ViT like TerraMind.
        ],
        
        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [256, 128, 64, 32],
        
        # Head
        "head_dropout": 0.1,
        "num_classes": 2,
    },
    
    loss="dice",  # We recommend dice for binary tasks and ce for tasks with multiple classes. 
    optimizer="AdamW",
    lr=2e-5,  # The optimal learning rate varies between datasets; we recommend testing different values between 1e-5 and 1e-4. You can perform hyperparameter optimization using terratorch-iterate.
    ignore_index=-1,
    freeze_backbone=True, # Only used to speed up fine-tuning in this demo, we highly recommend fine-tuning the backbone for the best performance. 
    freeze_decoder=False,  # Should be false in most cases as the decoder is randomly initialized.
    plot_on_val=True,  # Plot predictions during validation steps  
    class_names=["Others", "Burned"]  # optionally define class names
)
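The neck indices listed above follow a simple pattern: they pick the last layer of each quarter of the encoder, which gives `[2, 5, 8, 11]` for 12 layers and `[5, 11, 17, 23]` for 24 layers. `ReshapeTokensToImage` then turns each selected token sequence back into a 2D grid. The sketch below reproduces both ideas in pure Python with hypothetical function names; it assumes a square token grid and uses a `[h][w][D]` layout, whereas the real neck works on batched, channels-first tensors.

```python
import math


def evenly_spaced_layer_indices(depth, num_features=4):
    """Index of the last layer in each of `num_features` equal blocks of the encoder."""
    if depth % num_features:
        raise ValueError("depth must be divisible by num_features")
    step = depth // num_features
    return [step * (k + 1) - 1 for k in range(num_features)]


def tokens_to_image(tokens, has_cls_token=False):
    """Reshape a token sequence [N][D] into a square grid [h][w][D]."""
    if has_cls_token:
        tokens = tokens[1:]  # Drop the leading CLS token (TerraMind has none)
    side = math.isqrt(len(tokens))
    if side * side != len(tokens):
        raise ValueError("token count is not a perfect square")
    return [tokens[r * side:(r + 1) * side] for r in range(side)]


print(evenly_spaced_layer_indices(12))  # [2, 5, 8, 11]
# Assuming 16x16 patches, a 512x512 tile yields 32 x 32 = 1024 tokens
print(len(tokens_to_image([[0.0]] * 1024)))  # 32
```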

In [None]:
# Before starting the fine-tuning, you can start the tensorboard with:
%load_ext tensorboard
%tensorboard --logdir output

In [None]:
# Training
trainer.fit(model, datamodule=datamodule)
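The training objective is the Dice loss selected with `loss="dice"`. As a rough illustration, a soft Dice loss for the binary case compares predicted burn probabilities with the labels and skips pixels marked with `ignore_index`. Dice is commonly preferred when the positive class covers only a small fraction of the pixels. This is a minimal sketch with a hypothetical function name; TerraTorch's Dice loss (e.g., its multi-class formulation and smoothing constants) may differ.

```python
def soft_dice_loss(probs, targets, ignore_index=-1, eps=1e-7):
    """Soft Dice loss for binary segmentation over flat per-pixel lists.

    probs: predicted probability of the positive class ("Burned") per pixel.
    targets: 0/1 labels per pixel; pixels equal to `ignore_index` are skipped.
    """
    intersection = pred_sum = target_sum = 0.0
    for p, t in zip(probs, targets):
        if t == ignore_index:
            continue
        intersection += p * t
        pred_sum += p
        target_sum += t
    dice = (2.0 * intersection + eps) / (pred_sum + target_sum + eps)
    return 1.0 - dice


print(soft_dice_loss([1.0, 0.0, 1.0], [1, 0, 1]))  # 0.0
```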

After fine-tuning, we can evaluate the model on the test set:

In [None]:
# Let's test the fine-tuned model
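# Prefer the path tracked by the callback; it includes a version suffix (e.g., "-v1") if the file already existed
best_ckpt_path = checkpoint_callback.best_model_path or "output/terramind_small_burnscars/checkpoints/best-mIoU.ckpt"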
trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

# Note: This demo only trains for 3 epochs by default, which does not result in good test metrics.
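The checkpoint callback keeps the weights with the highest `val/mIoU`, and the test run reports the same kind of metric. mIoU is the intersection over union computed per class and then averaged, with pixels labeled `ignore_index` (here -1, which also replaces NaN labels) excluded. The sketch below works on flat lists of class IDs and uses a hypothetical function name; TerraTorch's metric implementation may aggregate differently, for example across batches.

```python
def mean_iou(preds, targets, num_classes=2, ignore_index=-1):
    """Mean intersection over union across classes, skipping ignored pixels.

    Classes that appear in neither predictions nor targets are left out of the mean.
    """
    intersection = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(preds, targets):
        if t == ignore_index:
            continue
        if p == t:
            intersection[p] += 1
            union[p] += 1
        else:
            # A misclassified pixel belongs to the union of both classes involved
            union[p] += 1
            union[t] += 1
    ious = [i / u for i, u in zip(intersection, union) if u > 0]
    return sum(ious) / len(ious) if ious else float("nan")


print(mean_iou([0, 1, 1, 0], [0, 1, 0, -1]))  # 0.5
```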

In [None]:
# Now we can use the model for predictions and plotting
model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
    best_ckpt_path,
    model_factory=model.hparams.model_factory,
    model_args=model.hparams.model_args,
)

model.eval()  # Disable dropout for inference

test_loader = datamodule.test_dataloader()
batch = next(iter(test_loader))
images = batch["image"]
for mod, value in images.items():
    images[mod] = value.to(model.device)  # Move every modality to the model's device

with torch.no_grad():
    outputs = model(images)

preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

for i in range(5):
    sample = {
        "image": batch["image"]["S2L2A"][i].cpu(),  # Replace "S2L2A" with your modality name if it differs
        "mask": batch["mask"][i],
        "prediction": preds[i],
    }
    test_dataset.plot(sample)
    plt.show()
    
# Note: This demo only trains for 3 epochs by default, which does not result in good predictions.
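`torch.argmax(outputs.output, dim=1)` picks, for every pixel, the class with the highest score along the class dimension of the model output, presumably shaped `[batch, classes, height, width]`. For a single sample stored as nested lists, the same operation can be sketched as follows (hypothetical function name).

```python
def argmax_over_classes(logits):
    """Per-pixel class IDs from logits shaped [C][H][W] (one sample, channels first)."""
    num_classes = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    return [
        [max(range(num_classes), key=lambda c: logits[c][y][x]) for x in range(width)]
        for y in range(height)
    ]


# Two classes, a 1x2 image: class 1 wins at the first pixel, class 0 at the second
print(argmax_over_classes([[[0.1, 2.0]], [[0.5, -1.0]]]))  # [[1, 0]]
```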

# Burn Scars config

If you are struggling with this task, you can check this burn scars config for guidance.

Please note that this config uses the generic segmentation dataset instead of the generic multimodal dataset. The idea is similar, but the details differ (e.g., paths and suffixes are plain strings instead of dicts keyed by modality).
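When transferring settings from the config to the multimodal data module, the per-modality arguments have to be wrapped in dicts, while label paths, label suffixes, and split files stay plain. A hypothetical helper (not part of TerraTorch) could do this conversion; the set of wrapped keys is an assumption based on the arguments used in this notebook.

```python
def to_multimodal(settings, modality,
                  keys=("train_data_root", "val_data_root", "test_data_root", "img_grep", "means", "stds")):
    """Return a copy of `settings` with per-modality values wrapped as {modality: value}.

    Hypothetical helper: label-related settings and split files are left unchanged.
    """
    converted = dict(settings)
    for key in keys:
        if key in converted:
            converted[key] = {modality: converted[key]}
    return converted


# Made-up values for illustration
print(to_multimodal({"img_grep": "*_img.tif", "label_grep": "*_lbl.tif"}, "optical"))
# {'img_grep': {'optical': '*_img.tif'}, 'label_grep': '*_lbl.tif'}
```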

In [None]:
# Download config
!wget https://raw.githubusercontent.com/IBM/terramind/refs/heads/main/configs/terramind_v1_base_burnscars.yaml

In [None]:
# Check the config
!cat terramind_v1_base_burnscars.yaml