In [None]:
# Includes work derived from International Business Machines
# `prithvi_v2_eo_300_tl_unet_multitemporal_crop.ipynb`
# Licensed under the Apache License, Version 2.0 (the "License");
# You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#
# All rights reserved for this derived work.

In [23]:
# Install the dependencies listed in pyproject.toml with `pip`;
# output is piped to `tail` to limit the length of the text printed below.
!pip install -e . | tail -n 7

Successfully built nasa-prithvi-wetlands
Installing collected packages: nasa-prithvi-wetlands
  Attempting uninstall: nasa-prithvi-wetlands
    Found existing installation: nasa-prithvi-wetlands 0.1.0
    Uninstalling nasa-prithvi-wetlands-0.1.0:
      Successfully uninstalled nasa-prithvi-wetlands-0.1.0
Successfully installed nasa-prithvi-wetlands-0.1.0


## Overview
This notebook focuses on fine-tuning the [Prithvi EO v2.0 model](https://huggingface.co/collections/ibm-nasa-geospatial/prithvi-for-earth-observation-6740a7a81883466bf41d93d6) to classify seagrass.

This notebook:
1. Is intended to be run on Google Colab.
2. Uses Terratorch to fine-tune Prithvi EO v2.0 300m.
3. Uses a seagrass patch dataset for fine-tuning derived from the IMaRS SIMM Seagrass Project.
4. Uses the fine-tuned model for inference.

You may want to take this opportunity to double-check that you are using a GPU on Google Colab before proceeding any further. We have tested this notebook using a T4 GPU on the free Colab account.
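
One way to confirm that the runtime has a GPU is sketched below. `describe_accelerator` is a hypothetical helper name, not part of this notebook or of TerraTorch; it relies only on `torch.cuda.is_available()` and `torch.cuda.get_device_name()`, and it returns a message instead of failing when PyTorch has not been installed yet.

```python
import importlib.util


def describe_accelerator() -> str:
    """Return a short description of the device PyTorch would use."""
    # PyTorch may not be present until the setup step has been run.
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed yet"
    import torch

    if torch.cuda.is_available():
        return f"GPU available: {torch.cuda.get_device_name(0)}"
    return "No GPU detected; training would fall back to the CPU"


print(describe_accelerator())
```

On Colab, a "No GPU detected" message usually means the runtime type still needs to be switched to a GPU runtime.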

## Setup
1. Install terratorch

To install the necessary packages, execute the installation cell. This will take a few minutes. Once the installation is done, a window will pop up asking you to restart the session. This is normal, and you should restart using the interface in the pop-up window. Once the session has restarted, it's important that you skip the installation cell and go straight to step 2, Import dependencies.


2. Import dependencies

In [24]:

import albumentations
import gdown
import lightning.pytorch as pl
import os
import terratorch
import torch
import matplotlib.pyplot as plt
import warnings
from pathlib import Path
from terratorch.datamodules import MultiTemporalCropClassificationDataModule

warnings.filterwarnings('ignore')

In [None]:
# Download the tuning dataset archive (.tar.bz2) from Google Drive into the `dataset_path` directory.
dataset_path = "data/tuning_patches"
fname = "seagrass_tuning_patches.tar.bz2"
os.makedirs(dataset_path, exist_ok=True)  # make sure the target directory exists
gdown.download(
    "https://drive.google.com/uc?id=1VWl2mkTTAG3ih741n3S5UDUj71ygNKDp",
    os.path.join(dataset_path, fname),
    quiet=False,
)

# Unpack the downloaded archive into `dataset_path` (-j selects bzip2 decompression).
!tar -xvjf {os.path.join(dataset_path, fname)} -C {dataset_path}

A tuning dataset should now be in the `dataset_path` directory.
The dataset used is derived from the IMaRS SIMM Seagrass Project.
For the methods used to generate the patches, see the `py/generate_seagrass_patches.py` script in the repo.
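
The downloaded file is a bzip2-compressed tarball, so it can also be unpacked with Python's standard `tarfile` module when shell tools are unavailable. The sketch below is a minimal illustration under that assumption; `extract_tar_bz2` is a hypothetical helper, and the demo builds a throwaway archive rather than touching the real dataset.

```python
import tarfile
import tempfile
from pathlib import Path


def extract_tar_bz2(archive: Path, dest: Path) -> list:
    """Extract a .tar.bz2 archive into dest and return the member names."""
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:bz2") as tar:
        names = tar.getnames()
        # Newer Python versions offer a safer extraction filter; use it if present.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)
    return names


# Self-contained demo on a placeholder archive (not the seagrass dataset).
with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)
    sample = tmp_path / "chip_000_000.txt"
    sample.write_text("placeholder")
    archive = tmp_path / "demo.tar.bz2"
    with tarfile.open(archive, "w:bz2") as tar:
        tar.add(sample, arcname="training_chips/chip_000_000.txt")
    extracted = extract_tar_bz2(archive, tmp_path / "out")
    demo_ok = (tmp_path / "out" / "training_chips" / "chip_000_000.txt").exists()
    print(extracted, demo_ok)
```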



The patches need to be split into training and validation sets.
The sets are specified using a `.txt` file listing the patch file names for each set.
Example `training_data.txt`:

```
chip_257_266
chip_328_501
chip_171_477
chip_236_281
chip_134_482
chip_120_493
chip_161_390
```

The training and validation patches are expected to be in `training_chips` and `validation_chips` subdirectories of the dataset path.
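
If the split files do not exist yet, they could be produced from a list of chip IDs along the lines of the sketch below. `split_chip_ids`, the 20% validation fraction and the fixed seed are illustrative assumptions, not the procedure used to build this dataset.

```python
import random


def split_chip_ids(chip_ids, val_fraction=0.2, seed=0):
    """Shuffle chip IDs reproducibly and split them into (train, val) lists."""
    ids = sorted(set(chip_ids))        # de-duplicate and fix the starting order
    random.Random(seed).shuffle(ids)   # local RNG, global random state untouched
    n_val = max(1, round(len(ids) * val_fraction))
    return ids[n_val:], ids[:n_val]


# Ten made-up chip IDs in the same style as the example file above.
chip_ids = [f"chip_{row:03d}_{col:03d}" for row in range(5) for col in range(2)]
train_ids, val_ids = split_chip_ids(chip_ids, val_fraction=0.2, seed=0)
print(len(train_ids), len(val_ids))

# Each split file then holds one chip ID per line, for example:
# Path(split_file).write_text("\n".join(train_ids) + "\n")
```

Fixing the seed keeps the split stable between runs, so validation metrics stay comparable.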

4. Truncate the dataset for demonstration purposes. Reducing the training dataset to a third of the original size means that model training takes only a few minutes with the resources available during the workshop.

In [None]:
# Keep only the first N entries of each split file (the files are overwritten in place).
training_data_truncation = 800
validation_data_truncation = 4

for split_file, limit in [
    ("training_data.txt", training_data_truncation),
    ("validation_data.txt", validation_data_truncation),
]:
    split_path = os.path.join(dataset_path, split_file)
    with open(split_path, "r") as f:
        chip_ids = f.readlines()
    with open(split_path, "w") as f:
        f.writelines(chip_ids[:limit])
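
Before training, it can help to confirm that every ID in a split file has a matching file on disk. The sketch below assumes that each chip's file name starts with its chip ID; the `_merged.tif` suffix in the demo and the helper name `find_missing_chips` are hypothetical.

```python
import tempfile
from pathlib import Path


def find_missing_chips(split_file: Path, chip_dir: Path) -> list:
    """Return chip IDs listed in split_file that have no matching file in chip_dir."""
    ids = [line.strip() for line in split_file.read_text().splitlines() if line.strip()]
    if not chip_dir.is_dir():
        return ids  # the directory is absent, so every chip is missing
    return [cid for cid in ids if not any(chip_dir.glob(f"{cid}*"))]


# Demo in a temporary directory: one chip is present, one is not.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "training_chips").mkdir()
    (root / "training_chips" / "chip_257_266_merged.tif").touch()
    (root / "training_data.txt").write_text("chip_257_266\nchip_328_501\n")
    missing = find_missing_chips(root / "training_data.txt", root / "training_chips")
    print(missing)
```

An empty result means every listed chip was found; a non-empty one usually points to an incomplete download or extraction.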

## Dataset Details

Let's start by analysing the dataset.

Please note: we have set the `batch_size` parameter to 4 and `max_epochs` to 1 to avoid running out of memory or runtime on the free-tier Colab compute resources. This is enough to demonstrate the entire workflow, but may not give the best performance. For better results, find additional compute resources and increase `batch_size` and `max_epochs` in the downloaded config file.
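
To get a feel for what these settings mean in practice, the number of training steps per epoch can be estimated from the dataset size and the batch size. The sketch assumes the final partial batch is kept rather than dropped; `steps_per_epoch` is an illustrative helper.

```python
import math


def steps_per_epoch(n_samples: int, batch_size: int) -> int:
    """Number of batches in one pass over the data, keeping the last partial batch."""
    return math.ceil(n_samples / batch_size)


# 800 training chips (the truncation used in this notebook) with a batch size of 4.
print(steps_per_epoch(800, 4))  # 200
```

Doubling the batch size halves the number of steps per epoch but roughly doubles the memory needed for each step.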


In [12]:
# Each merged sample includes the stacked bands of three time steps
!ls "{dataset_path}/training_chips" | head

ls: cannot access 'data/tuning_patches/training_chips': No such file or directory


In [None]:
# Set the batch size, number of workers, model backbone and number of epochs
# before initialising MultiTemporalCropClassificationDataModule.
batch_size = 4
num_workers = 2
# Backbone options: prithvi_eo_v1_100, prithvi_eo_v2_300, prithvi_eo_v2_300_tl,
# prithvi_eo_v2_600, prithvi_eo_v2_600_tl
prithvi_backbone = "prithvi_eo_v2_300_tl"

# Total number of epochs the training will run for.
max_epochs = 1  # Use 1 epoch for demos


#### Initialise the DataModule class

A DataModule is a shareable, reusable class that encapsulates all the steps needed to process the data. Here we use a dataset class tailored to this kind of dataset (a generic dataset class would work as well). To learn more about `MultiTemporalCropClassificationDataModule`, take a look at the [TerraTorch docs](https://ibm.github.io/terratorch/stable/datamodules/?h=multitemporalcropclassificationdatamodule#terratorch.datamodules.multi_temporal_crop_classification.MultiTemporalCropClassificationDataModule).
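
Each chip stacks six bands for each of three time steps, giving 18 channels. The sketch below shows how a flat channel index could map back to a time step and band. It assumes time-major ordering (all bands of the first time step, then the second, and so on), which should be checked against the actual files; `locate_channel` is a hypothetical helper.

```python
BANDS = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
N_TIMESTEPS = 3


def locate_channel(flat_index, bands=BANDS, n_timesteps=N_TIMESTEPS):
    """Map a flat channel index to (timestep, band name), assuming time-major order."""
    if not 0 <= flat_index < len(bands) * n_timesteps:
        raise IndexError(f"channel {flat_index} is out of range")
    timestep, band = divmod(flat_index, len(bands))
    return timestep, bands[band]


print(len(BANDS) * N_TIMESTEPS)  # 18 stacked channels per chip
print(locate_channel(7))         # (1, 'GREEN')
```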

In [None]:
datamodule = MultiTemporalCropClassificationDataModule(
    batch_size=batch_size,
    num_workers=num_workers,
    data_root=f"{dataset_path}",
    train_transform=[
        terratorch.datasets.transforms.FlattenTemporalIntoChannels(),  # Required for temporal data
        albumentations.D4(), # Random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
        terratorch.datasets.transforms.UnflattenTemporalFromChannels(n_timesteps=3),
    ],
    val_transform=None,  # Using ToTensor() by default
    test_transform=None,
    expand_temporal_dimension=True,
    use_metadata=False, # The crop dataset has metadata for location and time
    reduce_zero_label=True,
)

In [None]:
# Setup train and val datasets
datamodule.setup("fit")

In [None]:
datamodule.batch_size

In [None]:
# Per-band means and standard deviations, calculated from the training dataset
# for all 6 bands and 3 time steps; they are used for zero-mean normalisation.
datamodule.means, datamodule.stds
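
As a reminder of what zero-mean normalisation does with these statistics, here is a minimal sketch for a single band. The numbers are toy values, not statistics from the seagrass dataset, and `normalize_band` is an illustrative helper rather than the data module's own implementation.

```python
def normalize_band(values, mean, std):
    """Standardise one band: subtract the mean, then divide by the standard deviation."""
    if std == 0:
        raise ValueError("std must be non-zero")
    return [(v - mean) / std for v in values]


# Toy pixel values for one band.
print(normalize_band([2.0, 3.0, 4.0], mean=3.0, std=1.0))  # [-1.0, 0.0, 1.0]
```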

In [None]:
# Check the size of the training split
train_dataset = datamodule.train_dataset
len(train_dataset)

In [None]:
# Check the bands available in the dataset
train_dataset.all_band_names

In [None]:
# Check the dataset's class names
train_dataset.class_names

In [None]:
# plotting a few samples
for i in range(5):
    train_dataset.plot(train_dataset[i])

In [None]:
# Check the size of the validation split
val_dataset = datamodule.val_dataset
len(val_dataset)

In [None]:
# Check the size of the test split
datamodule.setup("test")
test_dataset = datamodule.test_dataset
len(test_dataset)

# Fine-tune Prithvi

Here we set up the fine-tuning: the type of task, the head to use and the model parameters. In this case we are doing a segmentation task (you can read about this and other downstream tasks in the [TerraTorch docs](https://ibm.github.io/terratorch/stable/tasks/)) with a UNet decoder. We also set the number of images per label with the `backbone_num_frames` parameter, which allows us to perform multi-temporal classification.
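
The checkpoint callback in the next cell keeps the model with the highest `val/Multiclass_Jaccard_Index`. As a rough illustration of that metric, the sketch below computes a macro-averaged Jaccard index (intersection over union) on flat label lists. It is a plain-Python approximation under simple assumptions (classes absent from both prediction and target are skipped); the library implementation may treat edge cases differently.

```python
def mean_jaccard(pred, target, num_classes, ignore_index=-1):
    """Macro-averaged intersection-over-union across classes present in the data."""
    # Drop pixels whose target label should be ignored.
    pairs = [(p, t) for p, t in zip(pred, target) if t != ignore_index]
    scores = []
    for cls in range(num_classes):
        intersection = sum(1 for p, t in pairs if p == cls and t == cls)
        union = sum(1 for p, t in pairs if p == cls or t == cls)
        if union:  # skip classes absent from both prediction and target
            scores.append(intersection / union)
    return sum(scores) / len(scores) if scores else 0.0


target = [0, 0, 1, 1]
pred = [0, 1, 1, 1]
print(round(mean_jaccard(pred, target, num_classes=2), 4))  # 0.5833
```

Class 0 scores 1/2 and class 1 scores 2/3, so the mean is about 0.58; a perfect prediction scores 1.0.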

In [None]:
pl.seed_everything(0)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="../output/multicrop/checkpoints/",
    mode="max",
    monitor="val/Multiclass_Jaccard_Index", # Variable to monitor
    filename="best-{epoch:02d}",
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1, # Lightning multi-gpu often fails in notebooks
    precision='bf16-mixed',  # Speed up training
    num_nodes=1,
    logger=True, # Uses TensorBoard by default
    max_epochs=max_epochs,
    log_every_n_steps=5,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="../output/multicrop",
)

# Model
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # Backbone
        "backbone": prithvi_backbone,
        "backbone_pretrained": True,
        "backbone_num_frames": 3,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        "backbone_coords_encoding": [], # use ["time", "location"] for time and location metadata
        
        # Necks 
        "necks": [
            {
                "name": "SelectIndices",
                # "indices": [2, 5, 8, 11]  # 100m model
                "indices": [5, 11, 17, 23]  # 300m model
                # "indices": [7, 15, 23, 31]  # 600m model
            },
            {
                "name": "ReshapeTokensToImage",
                "effective_time_dim": 3
            },
            {"name": "LearnedInterpolateToPyramidal"},
        ],
        
        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
        
        # Head
        "head_dropout": 0.1,
        "num_classes": 13,
    },

    loss="ce",
    lr=1e-4,
    optimizer="AdamW",
    ignore_index=-1,
    freeze_backbone=True,  # Speeds up fine-tuning
    freeze_decoder=False,
    plot_on_val=True,
)
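
The index lists offered for the `SelectIndices` neck follow one pattern: the last layer of each quarter of the encoder. The sketch below reproduces them under the assumption that the encoders have 12, 24 and 32 transformer blocks respectively; `evenly_spaced_layer_indices` is a hypothetical helper, not a TerraTorch function.

```python
def evenly_spaced_layer_indices(depth: int, n_outputs: int = 4) -> list:
    """Pick the last layer of each of n_outputs equal slices of the encoder."""
    if depth % n_outputs:
        raise ValueError("depth must be divisible by n_outputs")
    step = depth // n_outputs
    return [step * (i + 1) - 1 for i in range(n_outputs)]


print(evenly_spaced_layer_indices(12))  # [2, 5, 8, 11]
print(evenly_spaced_layer_indices(24))  # [5, 11, 17, 23]
print(evenly_spaced_layer_indices(32))  # [7, 15, 23, 31]
```

Whichever backbone is selected, the indices have to stay within its depth, otherwise the neck cannot select the requested layers.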

In [None]:
# Training
trainer.fit(model, datamodule=datamodule)

# Test the fine-tuned model

Let's gather and specify the files needed for testing. Look for the `.ckpt` file produced during fine-tuning; here it is `../output/multicrop/checkpoints/best-epoch=00.ckpt`. We have also provided a model trained on the full dataset so that we can compare it with our model.
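
Rather than typing the checkpoint name by hand, the newest `.ckpt` file in the checkpoint directory could be located programmatically. This is a minimal sketch; `newest_checkpoint` is a hypothetical helper, and the demo runs in a temporary directory instead of the real output folder.

```python
import tempfile
from pathlib import Path


def newest_checkpoint(checkpoint_dir: Path):
    """Return the most recently modified .ckpt file in checkpoint_dir, or None."""
    candidates = sorted(checkpoint_dir.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None


# Demo: an empty directory yields None, then the single checkpoint is found.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    empty_result = newest_checkpoint(ckpt_dir)
    (ckpt_dir / "best-epoch=00.ckpt").touch()
    found_name = newest_checkpoint(ckpt_dir).name
    print(empty_result, found_name)
```

Within the same session, Lightning's `ModelCheckpoint` also records the path of the best checkpoint in its `best_model_path` attribute once training has finished.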

In [None]:
best_ckpt_path = "../output/multicrop/checkpoints/best-epoch=00.ckpt"

# Download best model checkpoint fine-tuned on full dataset
best_ckpt_100_epoch_path = "multicrop_best-epoch=76.ckpt"

if not os.path.isfile(best_ckpt_100_epoch_path):
    gdown.download("https://drive.google.com/uc?id=1cO5a9PmV70j6mvlTc8zH8MnKsRCGbefm")

In [None]:
# calculate test metrics
trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_100_epoch_path)

In [None]:
# get predictions
preds = trainer.predict(model, datamodule=datamodule, ckpt_path=best_ckpt_100_epoch_path)

In [None]:
# get data 
data_loader = trainer.predict_dataloaders
batch = next(iter(data_loader))

# plot
for i in range(batch_size):
    sample = {key: batch[key][i] for key in batch}
    sample["prediction"] = preds[0][0][0][i].cpu().numpy()

    datamodule.predict_dataset.plot(sample)

# Fine-tuning via CLI

We can also run the fine-tuning via the [CLI](https://ibm.github.io/terratorch/stable/quick_start/#training-with-lightning-tasks). All parameters we have specified in the notebook can be put in a [YAML file](../configs/prithvi_v2_eo_300_tl_unet_multitemporal_crop.yaml) and run using the command below. Take a look at the [TerraTorch docs](https://ibm.github.io/terratorch/stable/tutorials/the_yaml_config/) for how to set up the config.
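
Individual settings can often be overridden on the command line instead of editing the YAML file. The sketch below only assembles such a command as a string. The override keys `trainer.max_epochs` and `data.init_args.batch_size` are assumptions about how the config is laid out (following the Lightning CLI convention) and should be checked against the actual file; `build_fit_command` is a hypothetical helper.

```python
import shlex


def build_fit_command(config_path: str, overrides: dict) -> str:
    """Compose a `terratorch fit` command line with optional key=value overrides."""
    args = ["terratorch", "fit", "-c", config_path]
    for key, value in overrides.items():
        args.append(f"--{key}={value}")
    return shlex.join(args)


command = build_fit_command(
    "Prithvi-EO/configs/prithvi_v2_eo_300_tl_unet_multitemporal_crop.yaml",
    {"trainer.max_epochs": 1, "data.init_args.batch_size": 4},  # assumed key names
)
print(command)
```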

You might want to restart the session to free up GPU memory.

In [None]:
# First let's get the config file from github.com.
!git init
!git remote add origin https://github.com/IBM/ML4EO-workshop-2025.git
!git fetch --all
!git checkout origin/main -- "Prithvi-EO/configs/prithvi_v2_eo_300_tl_unet_multitemporal_crop.yaml"

In [None]:
# Run fine-tuning
!terratorch fit -c "Prithvi-EO/configs/prithvi_v2_eo_300_tl_unet_multitemporal_crop.yaml"