## Overview
This notebook focuses on fine-tuning the TerraMind model to identify floods in Sentinel scenes. The main takeaways from this notebook are:
1. Learn how to use TerraTorch to fine-tune TerraMind for flood mapping in Sentinel scenes.
2. Understand the effects of specific parameters on training and hardware utilization.

**Note:** This notebook is tuned to work in a SageMaker environment, and a SageMaker Training Job is used to train the model. If you are interested in running this in a Colab environment, please use the [ESA NASA Foundation Model Workshop](https://github.com/NASA-IMPACT/ESA-NASA-workshop-2025/tree/main/Track%201%20(EO)/TerraMind) materials.

## Setup
1. Go to "Kernel"
2. Select "terramind"

In [None]:
import boto3
import gdown
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numpy as np
import os
import rasterio
import sagemaker
import terratorch
import torch
import yaml

from datetime import time
from glob import glob
from huggingface_hub import hf_hub_download, snapshot_download
from pathlib import Path

from sagemaker import get_execution_role
from sagemaker.estimator import Estimator


3. Download the dataset from Google Drive

In [None]:
!mkdir -p ../data/
if not os.path.isfile("../data/sen1floods11_v1.1.tar.gz"):
    gdown.download("https://drive.google.com/uc?id=1lRw3X7oFNq_WyzBO6uyUJijyTuYm23VS", '../data/sen1floods11_v1.1.tar.gz')

!tar -xzf ../data/sen1floods11_v1.1.tar.gz -C ../data/

In [None]:
# Local dataset path and target S3 bucket
dataset_path = Path("../data/sen1floods11_v1.1")
BUCKET_NAME = "<BUCKET_NAME>"  # Replace with the name of your S3 bucket

# Prepare sagemaker session with files uploaded to s3 bucket
sagemaker_session = sagemaker.Session()
all_files = sagemaker_session.upload_data(path=f'{dataset_path}', bucket=BUCKET_NAME, key_prefix='data')

In [7]:
# check config file.
import yaml
with open('../configs/terramind_v1_base_sen1floods11.yaml') as config_file:
    config = yaml.safe_load(config_file)

config

{'seed_everything': 42,
 'trainer': {'accelerator': 'auto',
  'strategy': 'auto',
  'devices': 'auto',
  'num_nodes': 1,
  'precision': '16-mixed',
  'logger': True,
  'callbacks': [{'class_path': 'RichProgressBar'},
   {'class_path': 'LearningRateMonitor',
    'init_args': {'logging_interval': 'epoch'}},
   {'class_path': 'ModelCheckpoint',
    'init_args': {'dirpath': '../output/burnscars/checkpoints',
     'mode': 'max',
     'monitor': 'val/Multiclass_Jaccard_Index',
     'filename': 'best-mIoU',
     'save_weights_only': True}}],
  'max_epochs': 100,
  'log_every_n_steps': 5,
  'default_root_dir': '../output/terramind_base_sen1floods11/'},
 'data': {'class_path': 'terratorch.datamodules.GenericMultiModalDataModule',
  'init_args': {'task': 'segmentation',
   'batch_size': 4,
   'num_workers': 4,
   'modalities': ['S2L1C', 'S1GRD'],
   'rgb_modality': 'S2L1C',
   'rgb_indices': [3, 2, 1],
   'train_data_root': {'S2L1C': '../data/sen1floods11_v1.1/data/S2L1CHand',
    'S1GRD': '../d

# TerraTorch model registry

TerraTorch includes its own backbone registry with many EO FMs. It also includes meta registries for all model components that include other sources like timm image models or SMP decoders.

In [None]:
from terratorch.registry import BACKBONE_REGISTRY, TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY

**Note:** TiM models are using the Thinking-in-Modalities approach, see our [paper](https://arxiv.org/abs/2504.11171) for details.

In [None]:
# Print all TerraMind v1 backbones.  
[backbone   
 for backbone in TERRATORCH_BACKBONE_REGISTRY
 if 'terramind_v1' in backbone]

In [None]:
# Available decoders. We use the UNetDecoder in this example.
list(TERRATORCH_DECODER_REGISTRY)

# TerraMind Training Configuration Overview

This section provides a detailed explanation of the parameters used for configuring a TerraMind model fine-tuning job.

---

## 📁 **Core Setup & Identifiers**

-   **`root_dir`**:
    -   **Explanation**: Specifies the absolute base directory where essential data and configurations are located, especially relevant in cloud environments like SageMaker.
    -   **Example**: `/opt/ml/data/`
    -   **Usage**: Defines the root for accessing input data and often for storing outputs if not otherwise specified.

-   **`identifier`**:
    -   **Explanation**: A user-defined string that serves as a unique prefix for naming all artifacts generated during the training and deployment process. This includes model checkpoints, logs, and output files.
    -   **Example**: `sen1floods_experiment_001`
    -   **Usage**: Helps in organizing and distinguishing different experimental runs.

-   **`usecase`**:
    -   **Explanation**: A string defining the specific Earth observation application or dataset the model is being trained or fine-tuned for.
    -   **Example**: `terramind_base_sen1floods11`
    -   **Usage**: Used for naming conventions, potentially for loading specific configurations, and organizing outputs related to this particular task.

-   **`data_path`**:
    -   **Explanation**: The local file system path where the input dataset resides before any processing or uploading to cloud storage.
    -   **Example**: `../data/sen1floods11_v1.1`
    -   **Usage**: Primarily used to locate raw data files for preprocessing, statistics calculation, or transferring to a training environment.

-   **`output_folder`**:
    -   **Explanation**: The base directory where all output artifacts from the training process (like model checkpoints, logs, and predictions) will be stored. It's constructed using the `root_dir`.
    -   **Example**: `/opt/ml/data/output` (derived from `root_dir`)
    -   **Usage**: Central location for all generated training outputs.

-   **`default_root_dir`**:
    -   **Explanation**: A more specific root directory, typically within the `output_folder`, tailored to the current `usecase`. It serves as the base for storing processed data, data splits, and other use-case-specific files.
    -   **Example**: `/opt/ml/data/output/terramind_base_sen1floods11/`
    -   **Usage**: Organizes data specific to a particular use case, like processed images and labels.

-   **`checkpoint_path`**:
    -   **Explanation**: The specific directory where model checkpoints (saved states of the model during training) are stored. This path is usually within the `output_folder`.
    -   **Example**: `/opt/ml/data/output/terramind_base_sen1floods11/checkpoints` (built as `{output_folder}/{usecase}/checkpoints`; the `burnscars` path in the template config is overwritten with this value)
    -   **Usage**: Essential for resuming training or for deploying a trained model from a specific point.

---

## 📊 **Data Standardization & Normalization**

-   **`data_mean`**:
    -   **Explanation**: A dictionary where keys are modality names (e.g., "S2L1C", "S1GRD") and values are lists of mean values for each band/channel of that modality. These values are crucial for normalizing the input data (subtracting the mean). Using pre-training values is common, but dataset-specific statistics can also be computed and used.
    -   **Example**: `{"S1GRD": [-12.599, -20.293], "S2L1C": [2357.089, ...]}`
    -   **Usage**: Applied during data preprocessing to standardize the input features, which typically helps in model convergence and performance.

-   **`data_stds`**:
    -   **Explanation**: Similar to `data_mean`, this dictionary holds the standard deviation values for each band/channel of the specified modalities. Used for scaling the input data after mean subtraction (dividing by standard deviation).
    -   **Example**: `{"S1GRD": [5.195, 5.890], "S2L1C": [1624.683, ...]}`
    -   **Usage**: Works in conjunction with `data_mean` for Z-score normalization of the input data.
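
The normalization described above can be sketched in a few lines. This is a minimal illustration in plain Python, not TerraTorch's implementation: `standardize` is a hypothetical helper, and the pixel values are made up so that they sit exactly at the mean and one standard deviation above it.

```python
def standardize(bands, means, stds):
    """Apply per-band z-score normalization: (x - mean) / std.

    `bands` is a list of bands, each band a flat list of pixel values.
    `means` and `stds` hold one value per band, in the same order.
    """
    if not (len(bands) == len(means) == len(stds)):
        raise ValueError("means and stds must have one entry per band")
    return [
        [(pixel - mean) / std for pixel in band]
        for band, mean, std in zip(bands, means, stds)
    ]


# S1GRD pre-training statistics from this notebook (VV, VH), in dB.
s1_mean = [-12.599, -20.293]
s1_std = [5.195, 5.890]

# Made-up pixels: the first of each band equals the mean, the second is one std above it.
s1_pixels = [[-12.599, -7.404], [-20.293, -14.403]]
normalized = standardize(s1_pixels, s1_mean, s1_std)
print(normalized)
```

The length check mirrors the requirement that the statistics contain exactly one value per selected band.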

---

## ⚙️ **Training Job Parameters**

-   **`batch_size`**:
    -   **Explanation**: The number of data samples (e.g., image patches) that are processed by the model in one forward/backward pass during training.
    -   **Example**: `4`
    -   **Usage**: Affects memory consumption (GPU VRAM) and training stability. Larger batch sizes can lead to faster training per epoch but require more memory.

-   **`num_workers`**:
    -   **Explanation**: The number of CPU worker processes assigned to load and preprocess data in parallel while the GPU is busy with model computations.
    -   **Example**: `2`
    -   **Usage**: Can significantly speed up data loading, preventing bottlenecks if data preprocessing is complex. The optimal value depends on CPU cores and I/O capabilities.

-   **`num_classes`**:
    -   **Explanation**: The total number of distinct classes the model is expected to predict in a segmentation task. For binary segmentation (e.g., flood vs. no-flood), this would be 2.
    -   **Example**: `2`
    -   **Usage**: Defines the output dimension of the final layer in the segmentation model.
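
To get a feel for how `batch_size` affects the length of an epoch, the sketch below counts optimizer steps per epoch for a few batch sizes. The sample count is a hypothetical placeholder (use the number of lines in your training split file), and `steps_per_epoch` is an illustrative helper rather than a TerraTorch or Lightning function.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches needed to see every sample once."""
    if drop_last:
        # The final incomplete batch is discarded.
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


# Hypothetical training-set size; substitute the size of your own train split.
num_train_samples = 250

for bs in (2, 4, 8):
    print(bs, steps_per_epoch(num_train_samples, bs))
```

Doubling the batch size roughly halves the number of steps per epoch, while each step holds about twice as many samples in GPU memory.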

---

## 🛰️ **Data Modality & Structure Configuration**

-   **`modalities`**:
    -   **Explanation**: A list of strings specifying the names of the data modalities (e.g., sensor types or data products) that will be used as input to the model.
    -   **Example**: `["S2L1C", "S1GRD"]`
    -   **Usage**: Informs the data loader which types of data to load and process for each sample.

-   **`rgb_modality`**:
    -   **Explanation**: A string specifying which of the input `modalities` should be considered the primary source for generating RGB visualizations or plots. If not provided, it typically defaults to the first modality in the `modalities` list.
    -   **Example**: `"S2L1C"`
    -   **Usage**: Used by visualization tools or callbacks to display image samples during or after training.

-   **`rgb_indices`**:
    -   **Explanation**: A list of integers indicating the band positions (0-indexed) within the `rgb_modality` that correspond to the Red, Green, and Blue channels, respectively, for creating true or false-color composite images.
    -   **Example**: `[3, 2, 1]` (for S2L1C, this would typically be B4, B3, B2 for a standard RGB view)
    -   **Usage**: Essential for correct visualization of multi-band satellite imagery.

-   **`train_data_root`**, **`val_data_root`**, **`test_data_root`**:
    -   **Explanation**: Dictionaries where keys are modality names and values are the file paths to the root directories containing the training, validation, and testing image data for each respective modality.
    -   **Example (`train_data_root`)**: `{"S2L1C": "/opt/ml/data/data/S2L1CHand", "S1GRD": "..."}`
    -   **Usage**: Tells the data loader where to find the image files for each data split and modality.

-   **`train_label_data_root`**, **`val_label_data_root`**, **`test_label_data_root`**:
    -   **Explanation**: Strings specifying the file paths to the root directories containing the corresponding label (ground truth segmentation masks) data for the training, validation, and testing sets.
    -   **Example (`train_label_data_root`)**: `/opt/ml/data/data/LabelHand`
    -   **Usage**: Directs the data loader to the location of ground truth masks.

-   **`train_split`**, **`val_split`**, **`test_split`**:
    -   **Explanation**: File paths to text files that define the data splits. These files typically list the names or identifiers of the samples belonging to the training, validation, and testing sets, respectively. This is useful when all image/label files are stored in common directories.
    -   **Example (`train_split`)**: `/opt/ml/data/splits/flood_train_data.txt`
    -   **Usage**: Ensures consistent and reproducible data splits across different runs.

-   **`img_grep`**:
    -   **Explanation**: A dictionary where keys are modality names and values are glob patterns (wildcard expressions) used to identify and filter image files for each modality within their respective data root directories.
    -   **Example**: `{"S2L1C": "*_S2Hand.tif", "S1GRD": "*_S1Hand.tif"}`
    -   **Usage**: Helps in selectively loading files that match a specific naming convention for each modality.

-   **`label_grep`**:
    -   **Explanation**: A glob pattern used to identify and filter label files (segmentation masks) within their data root directory.
    -   **Example**: `*_LabelHand.tif`
    -   **Usage**: Selectively loads label files matching a specific naming convention.

-   **`dataset_bands`**:
    -   **Explanation**: An optional dictionary where keys are modality names and values are lists of strings representing the names of all available bands in the original dataset files for that modality. This provides context for band selection.
    -   **Example**: `{"S1GRD": ["VV", "VH"]}`
    -   **Usage**: Used in conjunction with `output_bands` to specify which bands to select from the source files.

-   **`output_bands`**:
    -   **Explanation**: An optional dictionary, similar to `dataset_bands`. It specifies the subset of bands (by name, corresponding to names in `dataset_bands`) that should actually be extracted and used as input to the model for each modality. If not provided for a modality, all bands are typically used.
    -   **Example**: `{"S1GRD": ["VV", "VH"]}` (Here, selecting both available S1GRD bands. Could be `{"S1GRD": ["VV"]}` to use only VV).
    -   **Usage**: Allows for experimentation with different band combinations without modifying the original data files. The `data_mean` and `data_stds` must align with these selected `output_bands`.

-   **`no_label_replace`**:
    -   **Explanation**: A value used to replace NaN (Not a Number) or missing values in the label data (ground truth masks). The default of `-1` is often used because it can be configured to be ignored by the loss function and evaluation metrics.
    -   **Example**: `-1`
    -   **Usage**: Handles missing or invalid pixels in ground truth data.

-   **`no_data_replace`**:
    -   **Explanation**: A value used to replace NaN or missing values in the input image data across all modalities.
    -   **Example**: `0`
    -   **Usage**: Ensures that the model receives valid numerical inputs by handling missing pixels in the source imagery.
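
The interplay of glob patterns, split files, and band selection can be illustrated with the standard library alone. This is a sketch under stated assumptions, not the data module's actual logic: the file names are invented, and it assumes a split file lists one sample ID per line and that a file belongs to a split when its name starts with that ID followed by an underscore. `select_files` and `select_bands` are hypothetical helpers.

```python
from fnmatch import fnmatch


def select_files(filenames, pattern, split_ids):
    """Keep files that match the glob pattern and belong to the split."""
    return sorted(
        name for name in filenames
        if fnmatch(name, pattern)
        and any(name.startswith(f"{sid}_") for sid in split_ids)
    )


def select_bands(dataset_bands, output_bands):
    """Return the positions of output_bands within dataset_bands."""
    return [dataset_bands.index(band) for band in output_bands]


# Invented file names following the *_S2Hand / *_S1Hand / *_LabelHand convention.
files = [
    "Ghana_103272_S2Hand.tif",
    "Ghana_103272_S1Hand.tif",
    "India_56450_S2Hand.tif",
    "India_56450_LabelHand.tif",
]
train_ids = {"Ghana_103272"}

s2_train = select_files(files, "*_S2Hand.tif", train_ids)
print(s2_train)

# Selecting only VH from S1GRD: the statistics must be subset the same way.
band_idx = select_bands(["VV", "VH"], ["VH"])
s1_mean = [-12.599, -20.293]
print(band_idx, [s1_mean[i] for i in band_idx])
```

The last two lines show why `data_mean` and `data_stds` must match `output_bands`: after subsetting, only the statistics of the selected bands remain.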

---

## 🧠 **Model Architecture & Training Strategy**

-   **`terramind_backbone`**:
    -   **Explanation**: A string specifying the pre-trained TerraMind foundation model architecture to be used as the backbone (feature extractor) for the fine-tuning task. Different versions offer trade-offs in size, speed, and performance.
    -   **Choices**: `'terramind_v1_base'`, `'terramind_v1_base_tim'`, `'terramind_v1_large'`, `'terramind_v1_large_tim'`
    -   **Example**: `'terramind_v1_base'`
    -   **Usage**: Critical choice that determines the core feature extraction capabilities of the model.

-   **`max_epochs`**:
    -   **Explanation**: The maximum number of times the entire training dataset will be passed forward and backward through the neural network.
    -   **Example**: `100`
    -   **Usage**: Controls the total duration of training. Training might stop earlier if early stopping criteria are met.

-   **`indices`**:
    -   **Explanation**: A list of integers specifying which Transformer blocks (by their 0-indexed layer number) from the TerraMind backbone will provide their output feature embeddings to the decoder part of the model. This enables the decoder to use multi-scale features. The appropriate indices depend on the chosen `terramind_backbone`'s architecture (e.g., 'base' vs. 'large' versions have different total numbers of layers).
    -   **Conditional Logic**: The configuration cell sets these indices depending on whether a 'base' or 'large' backbone is selected.
        -   For 'base' models: `[2, 5, 8, 11]`
        -   For 'large' models: `[5, 11, 17, 23]`
    -   **Usage**: Crucial for how the decoder reconstructs segmentation maps from backbone features.

-   **`model_factory`**:
    -   **Explanation**: A string specifying the factory class responsible for constructing the overall model architecture. The `EncoderDecoderFactory` typically combines the chosen backbone with a neck (optional), a decoder, and a task-specific head.
    -   **Example**: `"EncoderDecoderFactory"`
    -   **Usage**: Defines the high-level structure for assembling the model components.

-   **`remove_cls_token`**:
    -   **Explanation**: A boolean passed to the `ReshapeTokensToImage` neck that indicates whether the first token of the backbone output should be dropped as a CLS (classification) token before the tokens are reshaped into a feature map. Many Vision Transformer architectures prepend such a token for image-level classification, but TerraMind is trained without one.
    -   **Example**: `False` (TerraMind outputs no CLS token, so there is nothing to remove; this needs to be specified explicitly.)
    -   **Usage**: Ensures compatibility between the backbone's output and the decoder's input for segmentation.

-   **`freeze_backbone`**:
    -   **Explanation**: A boolean value. If `True`, the weights of the pre-trained TerraMind backbone are frozen and not updated during fine-tuning. Only the weights of the newly added decoder and head are trained.
    -   **Example**: `True`
    -   **Usage**: Can speed up fine-tuning and reduce memory usage, especially for demonstrations or when fine-tuning data is very limited. However, for best performance, setting this to `False` (fine-tuning the entire backbone or parts of it) is generally recommended.

-   **`freeze_decoder`**:
    -   **Explanation**: A boolean value. If `True`, the weights of the decoder part of the model are frozen. This is rarely set to `True` because the decoder is typically randomly initialized or adapted from a different task and needs to be trained for the specific fine-tuning use case.
    -   **Example**: `False`
    -   **Usage**: Should generally be `False` to allow the decoder to learn its task.

-   **`class_names`**:
    -   **Explanation**: An optional list of strings that provide human-readable names for each class index. The order of names should correspond to the class indices (e.g., class 0, class 1, ...).
    -   **Example**: `["Others", "Water"]` (for `num_classes = 2`)
    -   **Usage**: Useful for logging, creating legends in visualizations, and interpreting model outputs.
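
Both sets of `indices` listed above follow the same pattern: the last block of each of four equally sized stages. The sketch below reproduces them, assuming the base backbone has 12 Transformer blocks and the large backbone has 24 (an assumption that is consistent with the highest indices 11 and 23). `evenly_spaced_indices` is a hypothetical helper, not part of TerraTorch.

```python
def evenly_spaced_indices(num_blocks, num_features=4):
    """Pick the last block of each of `num_features` equal-sized stages."""
    if num_blocks % num_features != 0:
        raise ValueError("num_blocks must be divisible by num_features")
    stage = num_blocks // num_features
    return [stage * (i + 1) - 1 for i in range(num_features)]


# Assumed block counts: 12 for the base model, 24 for the large model.
print(evenly_spaced_indices(12))
print(evenly_spaced_indices(24))
```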

In [None]:
root_dir = '/opt/ml/data'

# Parameters to modify
identifier = "<identifier>"  # Replace with a unique prefix, e.g. "sen1floods_experiment_001"

# Usecase definition
usecase = "terramind_base_sen1floods11"

#local data path
data_path = '../data/sen1floods11_v1.1'

# Output base directory
output_folder = f"{root_dir}/output"

# Default root dir for all data and splits
default_root_dir = f"{output_folder}/terramind_base_sen1floods11/"

# Path to save checkpoints in
checkpoint_path = f"{output_folder}/{usecase}/checkpoints"

# Define standardization values. We use the pre-training values here and providing the additional modalities is not a problem, 
# which makes it simple to experiment with different modality combinations. 
# Alternatively, use the dataset statistics that you can generate using 
# `terratorch compute_statistics -c config.yaml` (requires concat_bands: true for this multimodal datamodule).
data_mean = {
    "S2L1C": [2357.089, 2137.385, 2018.788, 2082.986, 2295.651, 2854.537, 3122.849, 3040.560, 3306.481, 1473.847, 506.070, 2472.825, 1838.929],
    "S2L2A": [1390.458, 1503.317, 1718.197, 1853.910, 2199.100, 2779.975, 2987.011, 3083.234, 3132.220, 3162.988, 2424.884, 1857.648],
    "S1GRD": [-12.599, -20.293],
    "S1RTC": [-10.93, -17.329],
    "RGB": [87.271, 80.931, 66.667],
    "DEM": [670.665]
}

data_stds = {
    "S2L1C": [1624.683, 1675.806, 1557.708, 1833.702, 1823.738, 1733.977, 1732.131, 1679.732, 1727.26, 1024.687, 442.165, 1331.411, 1160.419],
    "S2L2A": [2106.761, 2141.107, 2038.973, 2134.138, 2085.321, 1889.926, 1820.257, 1871.918, 1753.829, 1797.379, 1434.261, 1334.311],
    "S1GRD": [5.195, 5.890],
    "S1RTC": [4.391, 4.459],
    "RGB": [58.767, 47.663, 42.631],
    "DEM": [951.272],
}

# Batches of examples to use at a time
batch_size = 4

# Number of data loader worker processes
num_workers = 2

# Number of classes for segmentation
num_classes = 2


modalities = ["S2L1C", "S1GRD"]
rgb_modality = "S2L1C"  # Used for plotting. Defaults to the first modality if not provided.
rgb_indices = [3,2,1]  # RGB channel positions in the rgb_modality.

# Define data paths as dicts using the modality names as keys.
train_data_root = {
    "S2L1C": f"{root_dir}/data/S2L1CHand",
    "S1GRD": f"{root_dir}/data/S1GRDHand",
}
train_label_data_root = f"{root_dir}/data/LabelHand"
val_data_root = {
    "S2L1C": f"{root_dir}/data/S2L1CHand",
    "S1GRD": f"{root_dir}/data/S1GRDHand",
}
val_label_data_root = f"{root_dir}/data/LabelHand"
test_data_root = {
    "S2L1C": f"{root_dir}/data/S2L1CHand",
    "S1GRD": f"{root_dir}/data/S1GRDHand",
}
test_label_data_root = f"{root_dir}/data/LabelHand"

# Define split files because all samples are saved in the same folder.
train_split = f"{root_dir}/splits/flood_train_data.txt"
val_split = f"{root_dir}/splits/flood_valid_data.txt"
test_split = f"{root_dir}/splits/flood_test_data.txt"



# Define suffix, again using dicts.
img_grep = {
    "S2L1C": "*_S2Hand.tif",
    "S1GRD": "*_S1Hand.tif",
}
label_grep = "*_LabelHand.tif"

# With TerraTorch, you can select a subset of the dataset bands as model inputs by providing dataset_bands (all bands in the data) and output_bands (selected bands). This setting is optional for all modalities and needs to be provided as dicts.
# Here is an example for S-1 GRD. You could change the output to ["VV"] to train only on the first band. Note that means and stds must be aligned with the output_bands (equal length of values). 
dataset_bands = {
    "S1GRD": ["VV", "VH"]
}
output_bands = {
    "S1GRD": ["VV", "VH"]
}

no_label_replace = -1  # Replace NaN labels. defaults to -1 which is ignored in the loss and metrics.
no_data_replace = 0  # Replace NaN data

terramind_backbone = 'terramind_v1_base' # choice of 'terramind_v1_base', 'terramind_v1_base_tim', 'terramind_v1_large', 'terramind_v1_large_tim'

max_epochs = 100

indices = [5, 11, 17, 23]
if 'terramind_v1_base' in terramind_backbone:
    indices = [2, 5, 8, 11]  # indices for terramind_v1_base
elif 'terramind_v1_large' in terramind_backbone: 
    indices = [5, 11, 17, 23]  # indices for terramind_v1_large

model_factory="EncoderDecoderFactory"  # Combines a backbone with necks, the decoder, and a head

remove_cls_token = False # TerraMind is trained without a CLS token, which needs to be specified.

freeze_backbone = True # Only used to speed up fine-tuning in this demo, we highly recommend fine-tuning the backbone for the best performance. 
freeze_decoder = False  # Should be false in most cases as the decoder is randomly initialized.
class_names = ["Others", "Water"]  # optionally define class names

In [None]:
# Trainer specific configurations
config['trainer']['default_root_dir'] = default_root_dir
config['trainer']['max_epochs'] = max_epochs
config['trainer']['callbacks'][-1]['init_args']['dirpath'] = checkpoint_path

# Model specific configurations
config['model']['init_args']['model_args']['necks'][0]['indices'] = indices
config['model']['init_args']['model_args']['backbone'] = terramind_backbone
config['model']['init_args']['model_factory'] = model_factory
config['model']['init_args']['class_names'] = class_names
config['model']['init_args']['freeze_decoder'] = freeze_decoder
config['model']['init_args']['freeze_backbone'] = freeze_backbone

for index, neck in enumerate(config['model']['init_args']['model_args']['necks']):
    if neck['name'] == 'ReshapeTokensToImage':
        config['model']['init_args']['model_args']['necks'][index]['remove_cls_token'] = remove_cls_token
    elif neck['name'] == 'LearnedInterpolateToPyramidal':
        # Some decoders like UNet or UperNet expect hierarchical features. 
        # Therefore, we need to learn a upsampling for the intermediate embedding layers when using a ViT like TerraMind.
        continue


# Data specific configurations
config['data']['init_args']['train_data_root'] = train_data_root
config['data']['init_args']['test_data_root'] = test_data_root
config['data']['init_args']['val_data_root'] = val_data_root

config['data']['init_args']['train_label_data_root'] = train_label_data_root
config['data']['init_args']['val_label_data_root'] = val_label_data_root
config['data']['init_args']['test_label_data_root'] = test_label_data_root

config['data']['init_args']['train_split'] = train_split
config['data']['init_args']['val_split'] = val_split
config['data']['init_args']['test_split'] = test_split

config['data']['init_args']['modalities'] = modalities

config['data']['init_args']['label_grep'] = label_grep
config['data']['init_args']['image_grep'] = img_grep

config['data']['init_args']['num_classes'] = num_classes
config['data']['init_args']['no_data_replace'] = no_data_replace
config['data']['init_args']['no_label_replace'] = no_label_replace

config['data']['init_args']['means'] = data_mean 
config['data']['init_args']['stds'] = data_stds

config['data']['init_args']['batch_size'] = batch_size
config['data']['init_args']['num_workers'] = num_workers
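
The cell above updates the checkpoint directory through `callbacks[-1]`, which relies on `ModelCheckpoint` being the last callback in the YAML file. A slightly more defensive variant looks the callback up by its `class_path`. The sketch below uses a trimmed stand-in for the loaded config; `find_callback` is a hypothetical helper.

```python
def find_callback(config, class_path):
    """Return the first trainer callback whose class_path matches."""
    for callback in config["trainer"]["callbacks"]:
        if callback.get("class_path") == class_path:
            return callback
    raise KeyError(f"No callback with class_path={class_path!r}")


# Trimmed stand-in for the loaded YAML config (same structure as printed earlier).
demo_config = {
    "trainer": {
        "callbacks": [
            {"class_path": "RichProgressBar"},
            {"class_path": "LearningRateMonitor",
             "init_args": {"logging_interval": "epoch"}},
            {"class_path": "ModelCheckpoint",
             "init_args": {"dirpath": "../output/burnscars/checkpoints"}},
        ]
    }
}

# The returned dict is the same object as in the config, so the update is in place.
checkpoint_cb = find_callback(demo_config, "ModelCheckpoint")
checkpoint_cb["init_args"]["dirpath"] = (
    "/opt/ml/data/output/terramind_base_sen1floods11/checkpoints"
)
print(demo_config["trainer"]["callbacks"][-1]["init_args"]["dirpath"])
```

If the callback is missing, the lookup raises a `KeyError` instead of silently modifying the wrong entry.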

In [None]:
# Write the updated configuration to a user-specific file
import os

config_filename = f"{identifier}-flood_terramind.yaml"
config_filepath = f"../configs/{config_filename}"
with open(config_filepath, 'w') as config_file:
    yaml.dump(config, config_file, default_flow_style=False)

# Upload config files to s3 bucket
configs = sagemaker_session.upload_data(path=config_filepath, bucket=BUCKET_NAME, key_prefix='data/configs')

## Sen1Floods11 Dataset

Let's start by analysing the dataset.

TerraTorch provides generic data modules that work directly with PyTorch Lightning.

Sen1Floods11 is a multimodal dataset that provides Sentinel-2 L1C and Sentinel-1 GRD data. 
Therefore, we are using the `GenericMultiModalDataModule`. 
This module is similar to the `GenericNonGeoSegmentationDataModule`, which is used for standard segmentation tasks.
However, the data roots, `img_grep`, and other settings are provided as dicts to account for the multimodal inputs. You can find all settings in the [documentation](https://ibm.github.io/terratorch/stable/generic_datamodules/). 
In a Lightning config, the data module is defined with the `data` key.

In [None]:
# Helper function for plotting both modalities
def plot_sample(sample):
    s1 = sample['image']['S1GRD']
    s2 = sample['image']['S2L1C']
    mask = sample['mask']
    
    # Scaling data. Using -30 to 0 scaling for S-1 and 0 - 2000 for S-2. S-1 is visualized as [VH, VV, VH]
    s1 = (s1.clip(-30, 0) / 30 + 1) * 255
    s2 = (s2.clip(0, 2000) / 2000) * 255
    s1_rgb = np.stack([s1[1], s1[0], s1[1]], axis=0).astype(np.uint8).transpose(1,2,0)
    s2_rgb = s2[[3,2,1]].astype(np.uint8).transpose(1,2,0)
    
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(s1_rgb)
    ax[0].set_title('S-1 GRD')
    ax[0].axis('off')
    ax[1].imshow(s2_rgb)
    ax[1].set_title('S-2 L1C')
    ax[1].axis('off')   
    ax[2].imshow(mask, vmin=-1, vmax=1, interpolation='nearest')
    ax[2].set_title('Mask')
    ax[2].axis('off')
    fig.tight_layout()
    plt.show()
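
The scaling used in `plot_sample` maps each modality's typical value range to 0-255 for display. The arithmetic can be checked on single values. `scale_s1_db` and `scale_s2` are hypothetical scalar helpers that mirror the array expressions in the cell above.

```python
def scale_s1_db(value_db):
    """Map backscatter in dB from [-30, 0] to [0, 255], clipping outliers."""
    clipped = min(max(value_db, -30.0), 0.0)
    return (clipped / 30.0 + 1.0) * 255.0


def scale_s2(value):
    """Map S-2 pixel values from [0, 2000] to [0, 255], clipping outliers."""
    clipped = min(max(value, 0.0), 2000.0)
    return clipped / 2000.0 * 255.0


for db in (-40.0, -30.0, -15.0, 0.0):
    print(db, scale_s1_db(db))

for dn in (0.0, 1000.0, 5000.0):
    print(dn, scale_s2(dn))
```

Values outside the range are clipped to its ends, so very dark or very bright pixels saturate instead of wrapping around when cast to `uint8`.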

In [None]:
# plot some samples
import random
from glob import glob

samples = glob(f"{data_path}/data/S*/*.tif")
samples = random.sample(samples, 3)

for sample in samples:
    s1_file = sample.replace('S2L1CHand', 'S1GRDHand').replace('S2Hand', 'S1Hand')
    s2_file = sample.replace('S1GRDHand', 'S2L1CHand').replace('S1Hand', 'S2Hand')
    mask_file = sample.replace('S1GRDHand', 'LabelHand').replace('S2L1CHand', 'LabelHand').replace('S2Hand', 'LabelHand').replace('S1Hand', 'LabelHand')
    updated_values = {
        'image': {
            'S1GRD': rasterio.open(s1_file).read(),
            'S2L1C': rasterio.open(s2_file).read(),
        },
        'mask': rasterio.open(mask_file).read()[0]
    }
    plot_sample(updated_values)
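
The chained `.replace()` calls above derive the matching S-1, S-2, and label files from whichever file was sampled. An alternative sketch builds the three paths from a sample ID instead, assuming the directory and suffix naming used in this notebook. `paired_paths` and `sample_id_from` are hypothetical helpers, and the example file name is invented.

```python
from pathlib import Path


def sample_id_from(path):
    """Strip the modality suffix, e.g. 'X_S2Hand.tif' -> 'X'."""
    return Path(path).stem.rsplit("_", 1)[0]


def paired_paths(data_root, sample_id):
    """Build the S-1, S-2, and label paths for one sample ID."""
    root = Path(data_root)
    return {
        "S1GRD": root / "S1GRDHand" / f"{sample_id}_S1Hand.tif",
        "S2L1C": root / "S2L1CHand" / f"{sample_id}_S2Hand.tif",
        "mask": root / "LabelHand" / f"{sample_id}_LabelHand.tif",
    }


# Invented file name that follows the dataset's naming convention.
picked = "../data/sen1floods11_v1.1/data/S2L1CHand/Ghana_103272_S2Hand.tif"
paths = paired_paths("../data/sen1floods11_v1.1/data", sample_id_from(picked))
for key, value in paths.items():
    print(key, value.as_posix())
```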


# Fine-tune TerraMind via PyTorch Lightning

With TerraTorch, we can use standard Lightning components for the fine-tuning.
These include callbacks and the trainer class.
TerraTorch provides EO-specific tasks that define the training and validation steps.
In this case, we are using the `SemanticSegmentationTask`.
We refer to the [TerraTorch paper](https://arxiv.org/abs/2503.20563) for a detailed explanation of the TerraTorch tasks.

In [None]:
# Setup variables for training using sagemaker

name = f'{identifier}-sagemaker'
role = get_execution_role()
input_s3_uri = f"s3://{BUCKET_NAME}/data"
model_name = f"{identifier}-workshop.ckpt"
model_path = f"{root_dir}/{usecase}/checkpoints"

environment_variables = {
    'CONFIG_FILE': f"{root_dir}/configs/{config_filename}",
    'MODEL_DIR': model_path,
    'MODEL_NAME': model_name,
    'S3_URL': input_s3_uri,
    'ROLE_ARN': role,
    'ROLE_NAME': role.split('/')[-1],
    'EVENT_TYPE': usecase,
    'SPLITS': 'splits,configs,data/LabelHand,data/S1GRDHand,data/S2L1CHand',
    'VERSION': 'v1'
}

account_id = boto3.client('sts').get_caller_identity().get('Account')
ecr_container_url = f'{account_id}.dkr.ecr.us-west-2.amazonaws.com/eo_training:latest'
sagemaker_role = get_execution_role().split('/')[-1]

instance_type = 'ml.p3.2xlarge'

instance_count = 1
memory_volume = 100 


In [None]:

# Establish an estimator (model) using sagemaker and the configurations from the previous cell.
estimator = Estimator(image_uri=ecr_container_url,
                      role=role,
                      base_job_name=name,
                      instance_count=instance_count,
                      volume_size=memory_volume,  # EBS volume size in GB
                      environment=environment_variables,
                      instance_type=instance_type)

estimator.fit()

In [None]:

# Save important values in a file for reuse.
export_values = {
    'identifier': identifier,
    'model_name': model_name,
    'config_filename': config_filename,
    'bucket_name': BUCKET_NAME
}

with open('../variables.yaml', 'w') as variable_export:
    yaml.dump(export_values, variable_export, default_flow_style=False)


# Test the fine-tuned model 

In [None]:
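# Let's test the fine-tuned model.
# `trainer`, `model`, and `datamodule` are not defined in the cells above because training runs in a
# SageMaker Training Job. Instantiate them locally from the same config before running this cell,
# and point `best_ckpt_path` at a local copy of the checkpoint.
best_ckpt_path = f"{output_folder}/terramind_base_sen1floods11/checkpoints/best-mIoU.ckpt"
trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

# Note: Test metrics depend on the number of training epochs; a short run (e.g., 5 epochs) does not result in good test metrics.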
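
The best checkpoint is selected by `val/Multiclass_Jaccard_Index` (mean IoU), and labels set to `no_label_replace = -1` are ignored. The following sketch shows the idea on flat lists of labels. It is a plain-Python illustration with made-up values, not the metric implementation used during training, and library implementations may treat classes that are absent from both prediction and target differently.

```python
def mean_iou(preds, targets, num_classes, ignore_index=-1):
    """Mean intersection-over-union over classes, skipping ignored pixels.

    `preds` and `targets` are flat lists of integer class labels.
    Classes absent from both prediction and target are left out of the mean.
    """
    ious = []
    for cls in range(num_classes):
        intersection = 0
        union = 0
        for p, t in zip(preds, targets):
            if t == ignore_index:
                continue  # pixel has no valid label
            in_pred, in_target = p == cls, t == cls
            intersection += in_pred and in_target
            union += in_pred or in_target
        if union:
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else float("nan")


# Made-up labels: 0 = "Others", 1 = "Water", -1 = ignored.
targets = [0, 0, 1, 1, -1, 1]
preds = [0, 1, 1, 1, 1, 0]
print(mean_iou(preds, targets, num_classes=2))
```

Here class 0 has an IoU of 1/3 and class 1 has an IoU of 1/2, so the mean is 5/12. The ignored pixel does not count for or against either class.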

In [None]:
# Now we can use the model for predictions and plotting
model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
    best_ckpt_path,
    model_factory=model.hparams.model_factory,
    model_args=model.hparams.model_args,
)

# `datamodule` is assumed to be the same data module instance used for testing above.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
test_loader = datamodule.test_dataloader()

model.eval()
with torch.no_grad():
    batch = next(iter(test_loader))
    images = batch["image"]
    for mod, value in images.items():
        images[mod] = value.to(model.device)
    masks = batch["mask"].numpy()

    outputs = model(images)
    preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

# Plot at most four samples from the batch.
for i in range(min(4, len(preds))):
    sample = {
        "image": batch["image"]["S2L1C"][i].cpu(),
        "mask": batch["mask"][i],
        "prediction": preds[i],
    }
    test_dataset.plot(sample)
    plt.show()

# Note: Prediction quality depends on the number of training epochs; a short run (e.g., 5 epochs) does not result in good predictions.