# Overview

This notebook is a practical guide to fine-tuning the Prithvi Earth Observation (EO) v2.0 model to identify burn scars in Harmonized Landsat and Sentinel-2 (HLS) imagery. The learning objectives are:

- To demonstrate the use of TerraTorch for fine-tuning the Prithvi EO v2.0 model (300M parameters) to detect burn scars.
- To illustrate the integration and use of Hugging Face datasets with Prithvi EO models during the fine-tuning process.
- To develop an understanding of how various training parameters affect model performance and the utilization of hardware resources.

AWS SageMaker training jobs will be used to fine-tune the models.


# Setup
Open the "Kernel" menu and select the "prithvi_eo" kernel.

In [None]:
import boto3
import numpy as np
import yaml
import rasterio
import sagemaker

from datetime import time
from glob import glob

from huggingface_hub import hf_hub_download, snapshot_download
from pathlib import Path

from sagemaker import get_execution_role
from sagemaker.estimator import Estimator

## Dataset preparation

This hands-on session uses the Burn Scars example for fine-tuning. All of the data and pre-trained models are available on Hugging Face. The Hugging Face packages and git are used to download and prepare the datasets and pre-trained models.


### Download HLS Burn Scars dataset from Hugging Face: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars

In [None]:
DATA_PATH = "../data/hls_burn_scars"

snapshot_download(
    repo_id="ibm-nasa-geospatial/hls_burn_scars",
    allow_patterns="hls_burn_scars.tar.gz",
    repo_type="dataset",
    local_dir=DATA_PATH,
)
!tar -xvzf ../data/hls_burn_scars/hls_burn_scars.tar.gz -C ../data/hls_burn_scars
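
Before uploading anything, it can help to confirm that the archive extracted as expected. The sketch below counts files per split with `pathlib`. The `training` and `validation` folder names and the `*_merged.tif` pattern are taken from later cells in this notebook; the helper name `count_split_files` is hypothetical.

```python
from pathlib import Path


def count_split_files(data_path, splits=("training", "validation"), pattern="*_merged.tif"):
    """Return {split: number of files matching pattern} for each split folder."""
    root = Path(data_path)
    counts = {}
    for split in splits:
        split_dir = root / split
        # A missing folder is reported as zero files rather than raising
        counts[split] = len(list(split_dir.glob(pattern))) if split_dir.is_dir() else 0
    return counts


if __name__ == "__main__":
    # Same location as DATA_PATH in the download cell
    print(count_split_files("../data/hls_burn_scars"))
```

A count of zero for either split usually means the archive was extracted into a different folder than expected.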

In [None]:
# Name of the S3 bucket that will hold the datasets and configs (replace the placeholder)
BUCKET_NAME = "<BUCKET_NAME>"

# Prepare sagemaker session with files uploaded to s3 bucket
sagemaker_session = sagemaker.Session()
train_images = sagemaker_session.upload_data(path=f'{DATA_PATH}/training', bucket=BUCKET_NAME, key_prefix='datasets/training')
val_images = sagemaker_session.upload_data(path=f'{DATA_PATH}/validation', bucket=BUCKET_NAME, key_prefix='datasets/validation')
# The test channel reuses the validation split
test_images = sagemaker_session.upload_data(path=f'{DATA_PATH}/validation', bucket=BUCKET_NAME, key_prefix='datasets/test')

## Prepare config and parameters

In [None]:
with open('../configs/prithvi_v2_eo_300_tl_unet_burnscars.yaml') as config_file:
    config = yaml.safe_load(config_file)

config
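
Later cells assign into nested keys of this config, so a quick structural check can surface a mismatched template early. This is a minimal sketch: the dotted paths listed are the parent containers that the cells below index into, and `missing_config_paths` is a hypothetical helper, not part of TerraTorch or PyYAML.

```python
def missing_config_paths(config, required_paths):
    """Return the dotted paths from required_paths that are absent in config."""
    missing = []
    for dotted in required_paths:
        node = config
        for key in dotted.split("."):
            if isinstance(node, dict) and key in node:
                node = node[key]
            else:
                missing.append(dotted)
                break
    return missing


# Parent containers that later cells in this notebook write into
REQUIRED_PATHS = [
    "data.init_args",
    "model.init_args.model_args.necks",
    "trainer.callbacks",
]
```

Calling `missing_config_paths(config, REQUIRED_PATHS)` on the loaded template should return an empty list if the template has the layout this notebook expects.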

In [None]:
def calculate_band_statistics(image_directory, image_pattern, bands=(0, 1, 2, 3, 4, 5)):
    """
    Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.

    Args:
        image_directory (str): Directory containing the source GeoTIFF files that are passed to the model for training.
        image_pattern (str): Glob pattern that selects the GeoTIFF files used for computing statistics.
        bands (sequence, optional): Zero-based band indices to calculate statistics for. Defaults to (0, 1, 2, 3, 4, 5).

    Raises:
        Exception: If no images are found in the given directory.

    Returns:
        tuple: Two lists containing the means and standard deviations of each band.
    """
    # Initialize lists to store the means and standard deviations
    all_means = []
    all_stds = []

    # Use glob to get a list of all .tif images in the directory
    all_images = glob(f"{image_directory}/{image_pattern}")

    # Make sure there are images to process
    if not all_images:
        raise Exception("No images found")

    # Get the number of bands
    num_bands = len(bands)

    # Initialize arrays to hold sums and sum of squares for each band
    band_sums = np.zeros(num_bands)
    band_sq_sums = np.zeros(num_bands)
    pixel_counts = np.zeros(num_bands)

    # Iterate over each image
    for image_file in all_images:
        with rasterio.open(image_file) as src:
            # For each band, accumulate the sum, squared sum, and valid-pixel count
            for position, band in enumerate(bands):
                # rasterio band indices start at 1; cast to float64 so that
                # squaring integer pixel values cannot overflow
                data = src.read(band + 1).astype(np.float64)
                band_sums[position] += np.nansum(data)
                band_sq_sums[position] += np.nansum(data**2)
                pixel_counts[position] += np.count_nonzero(~np.isnan(data))

    # Calculate means and standard deviations for each band
    for position in range(num_bands):
        mean = band_sums[position] / pixel_counts[position]
        std = np.sqrt((band_sq_sums[position] / pixel_counts[position]) - (mean**2))
        all_means.append(mean)
        all_stds.append(std)

    return all_means, all_stds
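
The function above relies on the identity `std = sqrt(mean(x**2) - mean(x)**2)`, which lets the statistics be accumulated one image at a time. A minimal sketch of that identity on plain Python lists, checked against the standard library; the toy chunks and the name `streaming_mean_std` are illustrative only.

```python
import math
import statistics


def streaming_mean_std(chunks):
    """Accumulate sum, sum of squares, and count over chunks; return (mean, population std)."""
    total = 0.0
    total_sq = 0.0
    count = 0
    for chunk in chunks:
        for value in chunk:
            value = float(value)
            total += value
            total_sq += value * value
            count += 1
    mean = total / count
    # Clamp at zero: rounding can make the difference slightly negative
    variance = max(total_sq / count - mean * mean, 0.0)
    return mean, math.sqrt(variance)


chunks = [[1, 2, 3], [4, 5], [6]]
mean, std = streaming_mean_std(chunks)
flat = [v for chunk in chunks for v in chunk]
assert math.isclose(mean, statistics.fmean(flat))
assert math.isclose(std, statistics.pstdev(flat))
```

The sum-of-squares form can lose precision when the mean is large relative to the spread. If that becomes a concern, Welford's online algorithm is a more numerically stable alternative.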

# Configuration for training

-   `identifier`: This variable will be used as a prefix for all artifacts related to fine-tuning and deployments. Please update it with an appropriate identifier.
-   `usecase`: The use case the Prithvi model will be fine-tuned for, e.g., `burn_scars`, `flood_detection`, etc. This hands-on session uses `burn_scars`. If you have your own data, please update accordingly.
-   `data_path`: Data path is where the data locally resides. This will be used to find the files for fine-tuning. These files will then be used to calculate statistics like `mean` and `standard deviation`. These files will also be uploaded to an S3 bucket for the training job to use.
-   `batch_size`: This is the number of data samples processed by the model in one iteration during training. We are using `4` by default. Depending on the availability of GPUs and resources, this can be increased.
-   `num_workers`: This is the number of worker processes used for data loading during training. We are using `2` by default. This can be adjusted based on CPU and I/O capabilities.
-   `num_classes`: This variable represents the number of classes in the fine-tuning job. For `burn_scars`, we have two classes: `burn_scar` and `no_burn_scar`. Update it according to the data you are using for training.
-   `prithvi_backbone`: This variable represents the Prithvi Earth Observation Foundation Model (EO FM) pre-trained using HLS data. There are several variations:
    -   `prithvi_eo_v1_100`: This is an older version of the Prithvi EO FM. It will not be used in this hands-on session.
    -   `prithvi_eo_v2_300`: This version of the Prithvi EO FM has approximately 300 million parameters (typically around 24 Transformer encoder layers). It can be selected for faster training and a lower memory footprint than the 600M versions.
    -   `prithvi_eo_v2_300_tl`: This version also has ~300 million parameters (typically around 24 Transformer encoder layers) and is pre-trained with **T**emporal and **L**ocation embeddings. It is ideal for fine-tuning use cases where spatial and temporal information is important and a smaller footprint is desired. For example, crop classification using imagery from multiple time steps.
    -   `prithvi_eo_v2_600`: This is a larger version of the Prithvi EO FM with approximately 600 million parameters (typically 32 Transformer encoder layers). It can be selected for use cases requiring high precision, accuracy, or recall. Note: the memory footprint of this model is significantly larger than the 300M versions. Ensure sufficient resources are available.
    -   `prithvi_eo_v2_600_tl`: This version also has ~600 million parameters (typically 32 Transformer encoder layers) and includes **T**emporal and **L**ocation embeddings. It is best suited for high-performance fine-tuning on use cases where precise spatial and temporal information is crucial, such as detailed change detection or multi-temporal crop type mapping. The resource considerations are similar to the `prithvi_eo_v2_600` model.
-   `base_path`: This variable specifies the base directory for training operations, including the path for input data, configuration files, and the location for storing model artifacts post-training. For SageMaker training jobs, `/opt/ml/data` is commonly used. If using a different environment, please update accordingly.
-   `max_epochs`: This variable limits the number of `epochs` (full passes through the training dataset) a fine-tuning job runs for. A higher number of epochs equates to longer training time and may lead to better-performing models, but this needs to be validated on a case-by-case basis to avoid overfitting.
-   `indices`: For most of our fine-tuning jobs, a `decoder` is added on top of the selected Prithvi backbone. This variable specifies which Transformer blocks (by their index) from the Prithvi backbone will provide their output feature embeddings to this decoder. Commonly, features from blocks at or around 1/4, 1/2, 3/4, and the final depth of the backbone are used to capture multi-scale information. The selection of these indices impacts the architecture and parameter count of the decoder, not the Prithvi backbone itself.
-   `means`: This is the mean of the pixel values across each channel of the training dataset. The mean values, along with standard deviations, will be used for zero-center normalization of input values.
-   `stds`: This is the standard deviation of the pixel values across each channel of the training dataset. The standard deviation values, along with means, will be used for zero-center normalization of input values.
-   `model_path`: This variable specifies where the model artifacts will be stored after training.
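
The quarter-depth rule described for `indices` can be sketched as a small helper. The depths of 24 and 32 are the layer counts quoted in the list above; a depth of 12 for `prithvi_eo_v1_100` is an assumption inferred from the index list `[2, 5, 8, 11]` that this notebook uses for that backbone. `quarter_depth_indices` is a hypothetical name, not a TerraTorch function.

```python
def quarter_depth_indices(depth):
    """Zero-based indices of the blocks that end each quarter of a backbone with `depth` blocks."""
    if depth < 4 or depth % 4 != 0:
        raise ValueError("depth must be a positive multiple of 4")
    return [depth * quarter // 4 - 1 for quarter in range(1, 5)]


print(quarter_depth_indices(24))  # [5, 11, 17, 23]
print(quarter_depth_indices(32))  # [7, 15, 23, 31]
```

Other index choices are possible. The rule only reproduces the defaults used in this notebook.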

In [None]:
# Parameters to modify
identifier = "<identifier>"  # replace with your own prefix
usecase = "<usecase>"  # e.g. "burn_scars" for this session
# Local data path
data_path = '../data/hls_burn_scars/'

batch_size = 4
num_workers = 2

num_classes = 2

"""
Model backbone can be either:
  - prithvi_eo_v1_100
  - prithvi_eo_v2_300
  - prithvi_eo_v2_300_tl
  - prithvi_eo_v2_600
  - prithvi_eo_v2_600_tl
"""
prithvi_backbone = 'prithvi_eo_v2_300' 

base_path = '/opt/ml/data'

max_epochs = 100

config['data']['init_args']['batch_size'] = batch_size
config['data']['init_args']['num_workers'] = num_workers

config['data']['init_args']['num_classes'] = num_classes


config['model']['init_args']['model_args']['backbone'] = prithvi_backbone


indices = [5, 11, 17, 23]
if 'prithvi_eo_v1_100' in prithvi_backbone:
    indices = [2, 5, 8, 11]  # indices for prithvi_eo_v1_100
elif 'prithvi_eo_v2_300' in prithvi_backbone:
    indices = [5, 11, 17, 23]  # indices for prithvi_eo_v2_300 and prithvi_eo_v2_300_tl
elif 'prithvi_eo_v2_600' in prithvi_backbone:
    indices = [7, 15, 23, 31]  # indices for prithvi_eo_v2_600 and prithvi_eo_v2_600_tl

config['model']['init_args']['model_args']['necks'][0]['indices'] = indices

means, stds = calculate_band_statistics(data_path, 'training/*_merged.tif')

model_path = f"{base_path}/{usecase}/checkpoints"

In [None]:
# Mean and standard deviation calculated from the training dataset for all 6 bands,
# for zero center normalization.
config['data']['init_args']['means'] = [float(val) for val in means]
config['data']['init_args']['stds'] = [float(val) for val in stds]

# Maximum number of epochs the training will run for, taken from `max_epochs` above.
# Any positive integer works; lower it (for example to 1) for a quick run when short on time.
config['trainer']['max_epochs'] = max_epochs

config['data']['init_args']['test_data_root'] = f"{base_path}/test"
config['data']['init_args']['val_data_root'] = f"{base_path}/validation"
config['data']['init_args']['train_data_root'] = f"{base_path}/training"

config['trainer']['default_root_dir'] = f"{base_path}/{usecase}"

config['trainer']['callbacks'][2]['init_args']["dirpath"] = model_path
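
The `means` and `stds` written into the config are meant for the zero-center normalization described earlier. The sketch below illustrates only the arithmetic, `(x - mean) / std` per band, on a made-up array. How the datamodule applies it internally is up to TerraTorch, and `normalize_bands` is a hypothetical helper.

```python
import numpy as np


def normalize_bands(image, means, stds):
    """Standardize a (bands, height, width) array band by band: (x - mean) / std."""
    means = np.asarray(means, dtype=np.float64).reshape(-1, 1, 1)
    stds = np.asarray(stds, dtype=np.float64).reshape(-1, 1, 1)
    return (image.astype(np.float64) - means) / stds


# Toy 2-band, 2x2 image with made-up values (not real HLS data)
image = np.array([[[1, 3], [5, 7]], [[10, 10], [20, 20]]])
band_means = image.reshape(2, -1).mean(axis=1)
band_stds = image.reshape(2, -1).std(axis=1)
normalized = normalize_bands(image, band_means, band_stds)
```

After normalization each band of the toy image has mean 0 and standard deviation 1, which is the property the dataset-wide statistics aim for across the training set.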

In [None]:
# Write the updated configuration to a user-specific file
import os

config_filename = f"{identifier}-burn_scars_Prithvi-EO.yaml"
config_filepath = f"../configs/{config_filename}"
with open(config_filepath, 'w') as config_file:
    yaml.dump(config, config_file, default_flow_style=False)

# Upload config files to s3 bucket
configs = sagemaker_session.upload_data(path=config_filepath, bucket=BUCKET_NAME, key_prefix='data/configs')

In [None]:
# Setup variables for training using sagemaker

name = f'{identifier}-sagemaker'
role = get_execution_role()
input_s3_uri = f"s3://{BUCKET_NAME}/data"
model_name = f"{identifier}-workshop.ckpt"

environment_variables = {
    'CONFIG_FILE': f"{base_path}/configs/{config_filename}",
    'MODEL_DIR': model_path,
    'MODEL_NAME': model_name,
    'S3_URL': input_s3_uri,
    'ROLE_ARN': role,
    'ROLE_NAME': role.split('/')[-1],
    'EVENT_TYPE': usecase,
    'VERSION': 'v1'
}
account_id = boto3.client('sts').get_caller_identity().get('Account')
ecr_container_url = f'{account_id}.dkr.ecr.us-west-2.amazonaws.com/eo_training:latest'
sagemaker_role = 'SageMaker-ExecutionRole-20240206T151814'

instance_type = 'ml.p3.2xlarge'

instance_count = 1
memory_volume = 50
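
SageMaker passes the training environment as a map of string keys to string values, so it is worth checking the dictionary before submitting the job. In Python, a stray trailing comma after an assignment silently produces a tuple, which is one way a non-string value can slip in. A minimal sketch, with `non_string_entries` as a hypothetical helper:

```python
def non_string_entries(environment):
    """Return {key: type name} for every entry whose key or value is not a str."""
    problems = {}
    for key, value in environment.items():
        if not isinstance(key, str) or not isinstance(value, str):
            problems[key] = type(value).__name__
    return problems


# Made-up example: the first value is a one-element tuple, not a string
example = {"MODEL_NAME": ("demo.ckpt",), "VERSION": "v1"}
print(non_string_entries(example))  # {'MODEL_NAME': 'tuple'}
```

Running `non_string_entries(environment_variables)` should return an empty dictionary before the estimator is created.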

In [None]:
# Establish an estimator (model) using sagemaker and the configurations from the previous cell.
estimator = Estimator(image_uri=ecr_container_url,
                      role=role,
                      base_job_name=name,
                      instance_count=instance_count,
                      volume_size=memory_volume,  # EBS volume size in GB
                      environment=environment_variables,
                      instance_type=instance_type)

estimator.fit()
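
By default `estimator.fit()` blocks until the job finishes. If it is called with `wait=False` instead, the job status can be polled separately. The sketch below assumes a boto3 SageMaker client and uses its `describe_training_job` call; the helper name, the polling interval, and the injectable `sleep` argument are illustrative choices.

```python
import time

# Statuses after which a training job no longer changes
TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(client, job_name, poll_seconds=60, sleep=time.sleep):
    """Poll describe_training_job until the job reaches a terminal state; return that state."""
    while True:
        response = client.describe_training_job(TrainingJobName=job_name)
        status = response["TrainingJobStatus"]
        if status in TERMINAL_STATES:
            return status
        sleep(poll_seconds)


# Example usage (requires AWS credentials and a submitted job):
# client = boto3.client("sagemaker")
# wait_for_training_job(client, estimator.latest_training_job.name)
```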

In [None]:
# Save important values in a file for reuse.
export_values = {
    'identifier': identifier,
    'model_name': model_name,
    'config_filename': config_filename,
    'bucket_name': BUCKET_NAME
}

with open('../variables.yaml', 'w') as variable_export:
    yaml.dump(export_values, variable_export, default_flow_style=False)
