## Overview
This notebook focuses on fine-tuning the Prithvi EO v2.0 model to classify crops in an HLS scene. The main takeaways from this notebook are as follows:
1. Learn how to use Terratorch to finetune Prithvi EO v2.0 300m for crop classification (13 classes).
2. Use Huggingface datasets with Prithvi EO.
3. Understand the effects of specific parameters on training and hardware utilization.
4. Use finetuned model for inference.

## Setup
1. Go to "Kernel"
2. Select "prithvi_eo"

In [None]:
import os
import sys
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from terratorch.datamodules import MultiTemporalCropClassificationDataModule
import warnings
warnings.filterwarnings('ignore')

3. Download the dataset from the Hugging Face Hub and extract the chip archives

In [None]:
from huggingface_hub import hf_hub_download, snapshot_download
from pathlib import Path

dataset_path = "../data/multi-temporal-crop-classification"

snapshot_download(
    repo_id="ibm-nasa-geospatial/multi-temporal-crop-classification",
    allow_patterns="*.tgz",
    repo_type="dataset",
    local_dir=dataset_path,
)
snapshot_download(
    repo_id="ibm-nasa-geospatial/multi-temporal-crop-classification",
    allow_patterns="*.txt",
    repo_type="dataset",
    local_dir=dataset_path,
)
!mkdir ../data/multi-temporal-crop-classification/training_chips; tar -xzf ../data/multi-temporal-crop-classification/training_chips.tgz -C ../data/multi-temporal-crop-classification/
!mkdir ../data/multi-temporal-crop-classification/validation_chips; tar -xzf ../data/multi-temporal-crop-classification/validation_chips.tgz -C ../data/multi-temporal-crop-classification/


## Multi-temporal Crop Dataset

Let's start by analyzing the dataset.


In [None]:
# List the contents of the dataset directory
!ls "{dataset_path}"

In [None]:
# Each merged sample includes the stacked bands of three time steps.
# Show only the first entries of the training chips folder (`head`).
!ls "{dataset_path}/training_chips" | head

In [None]:
# Parameters to modify
batch_size = 4
num_workers = 2

# The crop classification task in this notebook has 13 classes.
num_classes = 13

# Model can be either prithvi_eo_v1_100, prithvi_eo_v2_300, prithvi_eo_v2_300_tl, prithvi_eo_v2_600, prithvi_eo_v2_600_tl
# NOTE: this must be a plain string. A trailing comma after the literal would turn it
# into a 1-tuple and make every backbone comparison below fail silently.
prithvi_backbone = "prithvi_eo_v2_300_tl"

# Indices used by the `SelectIndices` neck for each supported backbone.
backbone_indices = {
    "prithvi_eo_v1_100": [2, 5, 8, 11],
    "prithvi_eo_v2_100": [2, 5, 8, 11],  # spelling used by the earlier check; kept for compatibility
    "prithvi_eo_v2_300": [5, 11, 17, 23],
    "prithvi_eo_v2_300_tl": [5, 11, 17, 23],
    "prithvi_eo_v2_600": [7, 15, 23, 31],
    "prithvi_eo_v2_600_tl": [7, 15, 23, 31],
}

# Fail early on a typo instead of silently falling back to the 300m indices.
if prithvi_backbone not in backbone_indices:
    raise ValueError(
        f"Unknown backbone {prithvi_backbone!r}; expected one of {sorted(backbone_indices)}"
    )

indices = backbone_indices[prithvi_backbone]

# Total number of epochs the training will run for. Since we are short on time, we will just be running it for 1 epoch. This can be updated to any positive integer.
max_epochs = 1


In [None]:
# Dataset-specific datamodule for this dataset (the general dataset could be used as well).

# Training-time transform pipeline.
train_transforms = [
    # Required for temporal data
    terratorch.datasets.transforms.FlattenTemporalIntoChannels(),
    # Random flips and rotation
    albumentations.D4(),
    albumentations.pytorch.transforms.ToTensorV2(),
    terratorch.datasets.transforms.UnflattenTemporalFromChannels(n_timesteps=3),
]

datamodule_kwargs = dict(
    data_root=dataset_path,
    batch_size=batch_size,
    num_workers=num_workers,
    train_transform=train_transforms,
    # Validation and test splits use ToTensor() by default
    val_transform=None,
    test_transform=None,
    expand_temporal_dimension=True,
    # The crop dataset has metadata for location and time; it is not used here
    use_metadata=False,
    reduce_zero_label=True,
)
datamodule = MultiTemporalCropClassificationDataModule(**datamodule_kwargs)

# Prepare the train and val datasets
datamodule.setup("fit")

In [None]:
# Mean and standard deviation calculated from the training dataset for all 6 bands and 3 timesteps,
# used for zero-mean normalization.
# Display the means and stds stored on the datamodule.
datamodule.means, datamodule.stds

In [None]:
# Check the size of the training split (number of samples)
train_dataset = datamodule.train_dataset
len(train_dataset)

In [None]:
# Check which bands the dataset provides
train_dataset.all_band_names

In [None]:
# Check the class names of the dataset
train_dataset.class_names

In [None]:
# Plot the first five training samples with the dataset's own plot method
for sample_idx in range(5):
    sample = train_dataset[sample_idx]
    train_dataset.plot(sample)

In [None]:
# Check the size of the validation split (number of samples)
val_dataset = datamodule.val_dataset
len(val_dataset)

In [None]:
# Check the size of the testing split (number of samples).
# The datamodule was so far only set up for "fit", so set it up for "test" first.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
len(test_dataset)

# Fine-tune Prithvi

In [None]:
# Fix the random seed so that runs are repeatable.
pl.seed_everything(0)

# Checkpoint callback that tracks the monitored validation metric (higher is better).
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="../output/multicrop/checkpoints/",
    mode="max",
    monitor="val/Multiclass_Jaccard_Index", # Variable to monitor
    filename="best-{epoch:02d}",
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1, # Lightning multi-gpu often fails in notebooks
    precision='bf16-mixed',  # Speed up training
    num_nodes=1,
    logger=True, # Uses TensorBoard by default
    max_epochs=max_epochs, # Taken from the parameters cell (1 epoch for demos)
    log_every_n_steps=5,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="../output/multicrop",
)

# Model
# NOTE: backbone name, neck indices and number of classes are spelled out here;
# keep them consistent with each other when switching to a different backbone.
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # Backbone
        "backbone": "prithvi_eo_v2_300_tl",
        "backbone_pretrained": True,
        "backbone_num_frames": 3,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        "backbone_coords_encoding": [], # use ["time", "location"] for time and location metadata
        
        # Necks 
        "necks": [
            {
                "name": "SelectIndices",
                # "indices": [2, 5, 8, 11]  # 100m model
                "indices": [5, 11, 17, 23]  # 300m model
                # "indices": [7, 15, 23, 31]  # 600m model
            },
            {
                "name": "ReshapeTokensToImage",
                "effective_time_dim": 3
            },
            {"name": "LearnedInterpolateToPyramidal"},
        ],
        
        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
        
        # Head
        "head_dropout": 0.1,
        "num_classes": 13,
    },

    loss="ce",
    lr=1e-4,
    optimizer="AdamW",
    ignore_index=-1,
    freeze_backbone=True,  # Speeds up fine-tuning
    freeze_decoder=False,
    plot_on_val=True,
)

In [None]:
# You can start the tensorboard with (run it in a terminal window)
!pip install tensorboard
!tensorboard --logdir output &

# tensorboard can be accessed by updating the `lab` part in the current jupyterlab browser tab with `/proxy/6006/`:
# Eg: https://gvipa9zcdsccwe6.studio.us-west-2.sagemaker.aws/jupyterlab/default/lab -> https://gvipa9zcdsccwe6.studio.us-west-2.sagemaker.aws/jupyterlab/default/proxy/6006/

In [None]:
# Training: fit the task on the train/val data provided by the datamodule
trainer.fit(model, datamodule=datamodule)

# Test the fine-tuned model

In [None]:
# Use the path recorded by the checkpoint callback during `trainer.fit`: the file name
# depends on which epoch was best (and Lightning may add a version suffix when a file
# with the same name already exists), so a hard-coded name is fragile.
# Fall back to the default name when no training was run in this session.
best_ckpt_path = (
    checkpoint_callback.best_model_path
    or "../output/multicrop/checkpoints/best-epoch=00.ckpt"
)

In [None]:
def run_test_and_plot(model, ckpt_path):
    """Compute test metrics for a checkpoint and plot predictions for the first predict batch.

    Uses the notebook-level `trainer`, `datamodule` and `batch_size`.

    Args:
        model: Task passed to `trainer.test` and `trainer.predict`.
        ckpt_path: Checkpoint path forwarded as `ckpt_path` to both calls.
    """
    # calculate test metrics
    trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)

    # get predictions
    preds = trainer.predict(model, datamodule=datamodule, ckpt_path=ckpt_path)

    # get data
    data_loader = trainer.predict_dataloaders
    batch = next(iter(data_loader))

    # Predictions belonging to the first batch, indexed per sample below.
    # NOTE(review): the preds[0][0][0] nesting is carried over unchanged -- confirm it
    # against the structure returned by the task's predict step.
    first_batch_preds = preds[0][0][0]

    # Plot at most `batch_size` samples, but never index past the end of a first batch
    # that holds fewer samples than `batch_size`.
    n_samples = min(batch_size, len(first_batch_preds))

    # plot
    for i in range(n_samples):
        sample = {key: batch[key][i] for key in batch}
        sample["prediction"] = first_batch_preds[i].cpu().numpy()

        datamodule.predict_dataset.plot(sample)
        

In [None]:
# Evaluate and plot with the checkpoint at `best_ckpt_path`
run_test_and_plot(model, best_ckpt_path)

In [None]:
# Checkpoint that is downloaded instead of trained in this notebook
# (presumably the best epoch of a 100-epoch run, going by the variable name -- confirm).
best_ckpt_100_epoch_path = "multicrop_best-epoch=76.ckpt"

# Download only when the file is not there yet. Passing `output` ties the saved file name
# to the path that is checked here and used later, independent of the remote file name.
if not os.path.isfile(best_ckpt_100_epoch_path):
    gdown.download(
        "https://drive.google.com/uc?id=1cO5a9PmV70j6mvlTc8zH8MnKsRCGbefm",
        output=best_ckpt_100_epoch_path,
    )

In [None]:
# Evaluate and plot with the downloaded checkpoint
run_test_and_plot(model, best_ckpt_100_epoch_path)

# Fine-tuning via CLI

You might want to restart the session to free up GPU memory.

In [None]:
# Run fine-tuning through the TerraTorch command line interface, configured by the YAML file
!terratorch fit -c "../configs/prithvi_v2_eo_300_tl_unet_multitemporal_crop.yaml"

To run this via terminal:
1. Open terminal from the jupyterlab home page
2. Activate conda `source /opt/conda/bin/activate`
3. Activate appropriate conda environment `conda activate prithvi_eo`
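4. Navigate to the notebook directory (adjust the path if this notebook lives in a different folder): `cd "/home/sagemaker-user/ESA-NASA-workshop-2025/Track 1 (EO)/TerraMind/notebooks/"`
5. Run the TerraTorch training script with the same config as in the cell above: `terratorch fit -c "../configs/prithvi_v2_eo_300_tl_unet_multitemporal_crop.yaml"`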