# Setup
1. In Colab: go to "Runtime" -> "Change runtime type" and select "T4 GPU"
2. Install TerraTorch

In [None]:
!pip install terratorch==1.1
!pip install gdown tensorboard

In [None]:
import os
import sys
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from terratorch.datamodules import MultiTemporalCropClassificationDataModule
import warnings
warnings.filterwarnings('ignore')

3. Download the dataset from Google Drive

In [None]:
# Download a random subset for demos (~1 GB)

if not os.path.isdir('multi-temporal-crop-classification-subset/'):
    if not os.path.isfile('multi-temporal-crop-classification-subset.tar.gz'):
        gdown.download("https://drive.google.com/uc?id=1SycflNslu47yfMg2i_z8FqYkhZQv7JQM")
    !tar -xzf multi-temporal-crop-classification-subset.tar.gz

dataset_path = "multi-temporal-crop-classification-subset"

## Multi-temporal Crop Dataset

Let's start by analyzing the dataset.


In [None]:
# List the top-level contents of the dataset directory
!ls "{dataset_path}"

In [None]:
# Each merged sample includes the stacked bands of three time steps.
# Only the first entries of the training_chips folder are shown (output is piped through `head`).
!ls "{dataset_path}/training_chips" | head

In [None]:
# Augmentations applied to training samples only.
train_transform = [
    # Required for temporal data: fold the time steps into the channel axis first ...
    terratorch.datasets.transforms.FlattenTemporalIntoChannels(),
    # Random flips and rotation
    albumentations.D4(),
    albumentations.pytorch.transforms.ToTensorV2(),
    # ... and split them back into the three time steps afterwards.
    terratorch.datasets.transforms.UnflattenTemporalFromChannels(n_timesteps=3),
]

# Datamodule tailored to this dataset (the general dataset class could be used as well)
datamodule = MultiTemporalCropClassificationDataModule(
    data_root=dataset_path,
    batch_size=8,
    num_workers=2,
    train_transform=train_transform,
    # None -> ToTensor() is used by default for validation and test
    val_transform=None,
    test_transform=None,
    expand_temporal_dimension=True,
    # The crop dataset has metadata for location and time; it is not used here
    use_metadata=False,
    reduce_zero_label=True,
)

# Build the train and val datasets
datamodule.setup("fit")

In [None]:
# Display the means and stds stored on the datamodule for this dataset
# (the trailing tuple expression is what the cell renders as output)
datamodule.means, datamodule.stds

In [None]:
# Grab the training split prepared by datamodule.setup("fit") and show its
# size; `train_dataset` is reused by the following cells.
train_dataset = datamodule.train_dataset
len(train_dataset)

In [None]:
# Show every band name the dataset reports as available
train_dataset.all_band_names

In [None]:
# Show the class names defined by the dataset
train_dataset.class_names

In [None]:
# Visualize the first five training samples with the dataset's own plot helper
n_preview = 5
for sample_idx in range(n_preview):
    preview_sample = train_dataset[sample_idx]
    train_dataset.plot(preview_sample)

In [None]:
# Grab the validation split prepared by datamodule.setup("fit") and show its size
val_dataset = datamodule.val_dataset
len(val_dataset)

In [None]:
# Only the "fit" stage has been set up so far, so set up the "test" stage
# first, then show the size of the test split.
# `test_dataset` is used again later for plotting predictions.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
len(test_dataset)

# Fine-tune TerraMind via PyTorch Lightning

With TerraTorch, we can use standard Lightning components for the fine-tuning.
These include callbacks and the trainer class.
TerraTorch provides EO-specific tasks that define the training and validation steps.
In this case, we are using the `SemanticSegmentationTask`.
We refer to the [TerraTorch paper](https://arxiv.org/abs/2503.20563) for a detailed explanation of the TerraTorch tasks.

## Temporal Wrapper

TerraMind does not support multi-temporal inputs natively. Therefore, we use the temporal wrapper that applies the encoder on each image and merges the latents before the decoder in a mid-fusion fashion. More details: https://ibm.github.io/terratorch/stable/guide/temporal_wrapper/

In [None]:
# Fix all random seeds so that repeated runs are comparable
pl.seed_everything(0)

# Store the weights (weights only) of the epoch with the lowest validation loss
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/terramind_base_multicrop/checkpoints/",
    filename="best-loss",
    monitor="val/loss",
    mode="min",
    save_weights_only=True,
)
training_callbacks = [checkpoint_callback, pl.callbacks.RichProgressBar()]

# Lightning Trainer
trainer = pl.Trainer(
    default_root_dir="output/terramind_base_multicrop",
    accelerator="auto",
    strategy="auto",
    devices=1,  # Lightning multi-gpu often fails in notebooks
    num_nodes=1,
    precision="16-mixed",  # Speed up training
    max_epochs=3,  # For demos
    logger=True,  # Uses TensorBoard by default
    log_every_n_steps=1,
    callbacks=training_callbacks,
)

# --- Model arguments, grouped by component ---------------------------------

# TerraMind backbone
backbone_args = {
    "backbone": "terramind_v1_base",  # large version: terramind_v1_large
    "backbone_pretrained": True,
    "backbone_modalities": ["S2L2A"],
    "backbone_bands": {"S2L2A": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]},
}

# Temporal wrapper (its params are passed with the prefix backbone_temporal)
temporal_args = {
    "backbone_use_temporal": True,
    # Defaults to "mean" which also supports flexible input lengths
    "backbone_temporal_pooling": "concat",
    # Required for pooling = concat
    "backbone_temporal_n_timestamps": 3,
}

# Necks
neck_args = {
    "necks": [
        {
            "name": "SelectIndices",
            # indices for terramind_v1_tiny, small, and base
            # (use [5, 11, 17, 23] for terramind_v1_large)
            "indices": [2, 5, 8, 11],
        },
        {
            "name": "ReshapeTokensToImage",
            "remove_cls_token": False,
        },
        {"name": "LearnedInterpolateToPyramidal"},
    ],
}

# Decoder and head
decoder_head_args = {
    "decoder": "UNetDecoder",
    "decoder_channels": [512, 256, 128, 64],
    "head_dropout": 0.1,
    "num_classes": 13,
}

model_args = {**backbone_args, **temporal_args, **neck_args, **decoder_head_args}

class_names = [
    "Natural Vegetation",
    "Forest",
    "Corn",
    "Soybeans",
    "Wetlands",
    "Developed / Barren",
    "Open Water",
    "Winter Wheat",
    "Alfalfa",
    "Fallow / Idle Cropland",
    "Cotton",
    "Sorghum",
    "Other",
]

# Model
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args=model_args,
    loss="ce",
    # The optimal learning rate varies between datasets; we recommend testing
    # different ones between 1e-5 and 1e-4. You can perform hyperparameter
    # optimization using terratorch-iterate.
    lr=1e-4,
    optimizer="AdamW",
    ignore_index=-1,
    freeze_backbone=True,  # Speeds up fine-tuning
    freeze_decoder=False,
    plot_on_val=True,
    class_names=class_names,
)

In [None]:
# Before starting the fine-tuning, you can start the tensorboard with:
# (the trainer's default_root_dir and the checkpoint folder both live under `output`)
%load_ext tensorboard
%tensorboard --logdir output

In [None]:
# Training: fine-tune the task using the datamodule.
# The checkpoint callback registered on the trainer monitors "val/loss" (mode="min").
trainer.fit(model, datamodule=datamodule)

In [None]:
# Let's test the fine-tuned model.
# Prefer the path recorded by the checkpoint callback: when "best-loss.ckpt" already
# exists from an earlier run, Lightning stores the new checkpoint under a versioned
# name (e.g. "best-loss-v1.ckpt"), so a hard-coded path would silently load stale weights.
# Fall back to the default location if no checkpoint was recorded in this session.
best_ckpt_path = (
    checkpoint_callback.best_model_path
    or "output/terramind_base_multicrop/checkpoints/best-loss.ckpt"
)
trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

# Note: This demo only trains for a few epochs by default, which does not result in good test metrics.

In [None]:
# Now we can use the model for predictions and plotting
model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
    best_ckpt_path,
    model_factory=model.hparams.model_factory,
    model_args=model.hparams.model_args,
)
# Switch to inference mode. Without this, dropout (head_dropout=0.1) stays
# active and randomly perturbs the predictions.
model.eval()

test_loader = datamodule.test_dataloader()
batch = next(iter(test_loader))
# NOTE(review): this batch is taken straight from the dataloader, so any processing
# the datamodule performs in Lightning's batch-transfer hooks (e.g. a normalization
# step) is not applied here — confirm whether it has to be applied manually.

# A single no_grad context is enough; no gradients are needed for inference.
with torch.no_grad():
    outputs = model(batch["image"].to(model.device))

# Predicted class index per pixel: argmax over dim=1 (presumably the class dimension)
preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

# Plot up to five samples; the batch may contain fewer than five items.
for i in range(min(5, len(preds))):
    sample = {
        "image": batch["image"][i].cpu(),
        "mask": batch["mask"][i],
        "prediction": preds[i],
    }
    test_dataset.plot(sample)
    plt.show()

# Note: This demo only trains for a few epochs by default, which does not result in good predictions.

# Fine-tuning via CLI

Locally, run the fine-tuning command in your terminal rather than in this notebook.

In Colab, restart the session first to free up GPU memory, and set `freeze_backbone: true` in the config to avoid out-of-memory (OOM) errors.

In [None]:
# Download the YAML fine-tuning config from the TerraMind GitHub repository
# into the current working directory
!wget https://raw.githubusercontent.com/IBM/terramind/refs/heads/main/configs/terramind_v1_base_multitemporal_crop.yaml

In [None]:
# Run fine-tuning through the TerraTorch CLI, using the config downloaded in the previous cell
!terratorch fit -c terramind_v1_base_multitemporal_crop.yaml