# HLS Foundation Model Finetuning notebook

This notebook demonstrates the steps to fine-tune the HLS foundation model (also known as Prithvi), which was trained on the HLSL30 and HLSS30 datasets.

Note: This entire notebook is designed to run in the AWS SageMaker environment. Access to the SageMaker environment for your account is available at http://smd-ai-workshop-creds-webapp.s3-website-us-east-1.amazonaws.com/.

![HLS Training](../images/HLS-training.png)

In [None]:
# Install required packages
!pip install -r ../requirements.txt

# Create directories needed for data, model, and config preparations
# (-p keeps the cell from failing when it is re-run)
!mkdir -p datasets
!mkdir -p models
!mkdir -p configs

## Dataset preparation

For this hands-on session, the Burn Scars example is used for fine-tuning. All of the data and pre-trained models are available on Hugging Face. Hugging Face packages and git are used to download and prepare the datasets and pre-trained models.


Note: Git Large File Storage (Git LFS) is used to download the larger files from Hugging Face.

In [None]:
# Install Git LFS (-y answers the apt prompt, which cannot be answered from a notebook cell)
! sudo apt-get install -y git-lfs && git lfs install

### Download HLS Burn Scars dataset from Huggingface: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars

In [None]:
! cd datasets && git clone https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars && tar -xvzf hls_burn_scars/hls_burn_scars.tar.gz
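
Before uploading anything, it can help to confirm that the archive extracted into the `training` and `validation` folders that the upload cell expects. The following is a minimal sketch of such a check. The helper name `count_files_by_split` and the `*.tif` pattern are assumptions made for illustration, not part of the dataset's documentation.

```python
from pathlib import Path


def count_files_by_split(dataset_root, splits=("training", "validation"), pattern="*.tif"):
    """Return {split: number of files matching `pattern`} for each split folder.

    `pattern` is an assumption; adjust it to the file extension actually
    present in the extracted dataset.
    """
    root = Path(dataset_root)
    counts = {}
    for split in splits:
        split_dir = root / split
        # A missing folder is reported as 0 instead of raising, so the
        # summary is still printed when extraction failed part-way.
        counts[split] = len(list(split_dir.glob(pattern))) if split_dir.is_dir() else 0
    return counts


# Example usage in this notebook (run after the extraction cell):
# print(count_files_by_split("datasets"))
```

A count of zero for either split usually means the `tar` step ran in a different directory than expected.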

## Download config and Pre-trained model

The HLS Foundation Model (pre-trained model) and the configuration for the Burn Scars downstream task are available on Hugging Face. In this notebook, the pre-trained checkpoint is downloaded with `curl`, and the configuration file is read from `../configs/burn_scars.yaml`.

In [None]:
# Define constants
BUCKET_NAME = 'workshop-1-015' # Replace this with the bucket name available from https://creds-workshop.nasa-impact.net/ 
CONFIG_PATH = './configs'
DATASET_PATH = './datasets'
MODEL_PATH = './models'

In [None]:
# Download the pre-trained model checkpoint
# (-f fails on HTTP errors instead of saving an error page, -L follows redirects)
! cd models && curl -fL https://www.nsstc.uah.edu/data/sujit.roy/Prithvi_checkpoints/checkpoint.pt -o prithvi_global_v1.pt

**Warning:** Before running the remaining cells, update the configuration file as described below:

1. Update line 13 from `data_root = '<path to data root>'` to `data_root = '/opt/ml/data/'`. This is the base directory of our data inside SageMaker.
2. Update line 41 from `pretrained_weights_path = '<path to pretrained weights>'` to `pretrained_weights_path = f"{data_root}/models/prithvi_global_v1.pt"`. This gives the train script the path to the pre-trained model; the filename must match the checkpoint that this notebook uploads to the S3 bucket.
3. Update line 53 from `experiment = '<experiment name>'` to `experiment = 'burn_scars'` or an experiment name of your choice.
4. Update line 54 from `project_dir = '<project directory name>'` to `project_dir = 'v1'` or a project directory name of your choice.
5. Save the config file.
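
The manual edits above can also be applied programmatically. The sketch below is one possible way to do that with plain string replacement, assuming the config contains the four placeholder assignments exactly as quoted in the list. The function name `fill_config_placeholders` is hypothetical, and the checkpoint filename is left as a parameter so that it can be matched to the uploaded file.

```python
def fill_config_placeholders(text, data_root, weights_filename, experiment, project_dir):
    """Replace the four placeholder assignments from the list above in a config's text."""
    replacements = {
        "data_root = '<path to data root>'":
            f"data_root = '{data_root}'",
        # Deliberately not an f-string: `{data_root}` must stay literal so the
        # config itself resolves it when it is loaded.
        "pretrained_weights_path = '<path to pretrained weights>'":
            'pretrained_weights_path = f"{data_root}/models/' + weights_filename + '"',
        "experiment = '<experiment name>'":
            f"experiment = '{experiment}'",
        "project_dir = '<project directory name>'":
            f"project_dir = '{project_dir}'",
    }
    for placeholder, value in replacements.items():
        # Fail loudly if the config does not look the way the list describes.
        if placeholder not in text:
            raise ValueError(f"placeholder not found: {placeholder}")
        text = text.replace(placeholder, value)
    return text


# Example usage (path taken from this notebook; run before the rename cell):
# path = "../configs/burn_scars.yaml"
# with open(path) as f:
#     updated = fill_config_placeholders(
#         f.read(), "/opt/ml/data/", "prithvi_global_v1.pt", "burn_scars", "v1")
# with open(path, "w") as f:
#     f.write(updated)
```

Raising on a missing placeholder avoids silently uploading a config that still contains unresolved values.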

In [None]:
# Prepare sagemaker session with files uploaded to s3 bucket
import sagemaker

sagemaker_session = sagemaker.Session()
train_images = sagemaker_session.upload_data(path='datasets/training', bucket=BUCKET_NAME, key_prefix='data/training')
val_images = sagemaker_session.upload_data(path='datasets/validation', bucket=BUCKET_NAME, key_prefix='data/validation')
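# The validation split is uploaded a second time under the 'data/test' prefix to serve as the test split
test_images = sagemaker_session.upload_data(path='datasets/validation', bucket=BUCKET_NAME, key_prefix='data/test')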
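
After the upload, the number of objects under each prefix can be compared with the local file counts. This is a small sketch using the S3 `list_objects_v2` paginator. The helper name `count_s3_objects` is an assumption, and the S3 client is passed in as a parameter so the function itself does not depend on any particular credentials setup.

```python
def count_s3_objects(s3_client, bucket, prefix):
    """Count the objects stored under `prefix` in `bucket`."""
    paginator = s3_client.get_paginator("list_objects_v2")
    total = 0
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # Pages without any matching objects have no "Contents" key.
        total += len(page.get("Contents", []))
    return total


# Example usage in this notebook:
# import boto3
# s3 = boto3.client("s3")
# for prefix in ("data/training", "data/validation", "data/test"):
#     print(prefix, count_s3_objects(s3, BUCKET_NAME, prefix))
```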

In [None]:
# Rename configuration file name to user specific filename
import os

identifier = 'workshop-015' # Please update this with an identifier

config_filename = '../configs/burn_scars.yaml'
new_config_filename = f"../configs/{identifier}-burn_scars.py"
os.rename(config_filename, new_config_filename)

In [None]:
# Upload the config file and the pre-trained model to the S3 bucket
configs = sagemaker_session.upload_data(path=new_config_filename, bucket=BUCKET_NAME, key_prefix='data/configs')
models = sagemaker_session.upload_data(path='models/prithvi_global_v1.pt', bucket=BUCKET_NAME, key_prefix='data/models')


Note: The HLS Foundation Model was built with MMCV and MMSEG, which use PyTorch underneath for training, data distribution, and so on. These packages are not available in SageMaker by default, so training requires a custom script, which SageMaker runs from a Docker image. If you are interested, the code included in the image we are using for training (637423382292.dkr.ecr.us-west-2.amazonaws.com/sagemaker_hls:latest) is bundled with this repository, and the train script used is `train.py`.

The current HLS Foundation Model fits on a single NVIDIA Tesla V100 GPU (16 GB VRAM), so an `ml.p3.2xlarge` instance is used for training.

In [None]:
# Setup variables for training using sagemaker
from sagemaker import get_execution_role
from sagemaker.estimator import Estimator


name = f'{identifier}-sagemaker'
role = get_execution_role()
input_s3_uri = f"s3://{BUCKET_NAME}/data"
finetuned_model_name = f"{identifier}-workshop.pth"
environment_variables = {
    'CONFIG_FILE': f"/opt/ml/data/configs/{new_config_filename.split('/')[-1]}",
    'MODEL_DIR': "/opt/ml/data/models/",
    'MODEL_NAME': finetuned_model_name,
    'S3_URL': input_s3_uri,
    'BUCKET_NAME': BUCKET_NAME,
    'ROLE_ARN': role,
    'ROLE_NAME': role.split('/')[-1],
    'EVENT_TYPE': 'burn_scars',
    'VERSION': 'v1'
}

ecr_container_url = '637423382292.dkr.ecr.us-west-2.amazonaws.com/prithvi_global:latest'
sagemaker_role = 'SageMaker-ExecutionRole-20240206T151814'

instance_type = 'ml.p3.2xlarge'

instance_count = 1
memory_volume = 50
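
The `environment_variables` dictionary is handed to the training container as ordinary process environment variables. How `train.py` actually consumes them is not shown in this notebook; the following is only a sketch of how a script inside the container could read them. The function name `read_training_settings`, the set of required keys, and the idea of joining `MODEL_DIR` and `MODEL_NAME` into one output path are all assumptions for illustration.

```python
import os
import posixpath

# Keys taken from the `environment_variables` dictionary above; treating
# these particular ones as required is an assumption.
REQUIRED_KEYS = ("CONFIG_FILE", "MODEL_DIR", "MODEL_NAME", "S3_URL", "BUCKET_NAME")


def read_training_settings(env=None):
    """Collect the training settings from environment variables."""
    env = os.environ if env is None else env
    missing = [key for key in REQUIRED_KEYS if key not in env]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    return {
        "config_file": env["CONFIG_FILE"],
        # Container paths are POSIX paths, so posixpath is used explicitly.
        "model_output": posixpath.join(env["MODEL_DIR"], env["MODEL_NAME"]),
        "s3_url": env["S3_URL"],
        "bucket": env["BUCKET_NAME"],
    }


# Example usage in this notebook, as a dry run on the dictionary itself:
# print(read_training_settings(environment_variables))
```

Running the helper against the dictionary before starting the job is a cheap way to catch a missing or misspelled key.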

In [None]:
# Create an estimator (training job definition) from the settings in the previous cell.
estimator = Estimator(image_uri=ecr_container_url,
                      role=role,
                      base_job_name=name,
                      instance_count=instance_count,
                      environment=environment_variables,
                      instance_type=instance_type)


In [None]:
# Start training
estimator.fit()
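
Once `fit()` returns, or if the notebook kernel was interrupted while the job kept running, the job state can be queried directly. This is a minimal sketch built on the SageMaker `describe_training_job` API. The helper name `summarize_training_job` is hypothetical, and the client is passed in as a parameter.

```python
def summarize_training_job(sagemaker_client, job_name):
    """Return the status of a training job, plus failure reason and artifacts if present."""
    desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
    summary = {"status": desc["TrainingJobStatus"]}
    # FailureReason is only included in the response for failed jobs.
    if "FailureReason" in desc:
        summary["failure_reason"] = desc["FailureReason"]
    # The artifact location is reported under ModelArtifacts when available.
    artifacts = desc.get("ModelArtifacts", {})
    if "S3ModelArtifacts" in artifacts:
        summary["model_artifacts"] = artifacts["S3ModelArtifacts"]
    return summary


# Example usage in this notebook (assumes the SageMaker Python SDK exposes
# the job name as `estimator.latest_training_job.name`):
# import boto3
# print(summarize_training_job(boto3.client("sagemaker"),
#                              estimator.latest_training_job.name))
```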

In [None]:
# 'Platform' tells SageMaker that the inference image is hosted in Amazon ECR
image_config = {
     'RepositoryAccessMode': 'Platform'
}

In [None]:
IMAGE_URI = '637423382292.dkr.ecr.us-west-2.amazonaws.com/prithvi_global_inference'

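ENV = {
    "CHECKPOINT_FILENAME": f"s3://{BUCKET_NAME}/models/{finetuned_model_name}",
    # The config was uploaded under the 'data/configs' prefix, so use its basename, not the local relative path
    "S3_CONFIG_FILENAME": f"s3://{BUCKET_NAME}/data/configs/{new_config_filename.split('/')[-1]}",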
    "BUCKET_NAME": BUCKET_NAME,
    "AIP_PREDICT_ROUTE": "/invocations",
    "BACKBONE_FILENAME": f"s3://{BUCKET_NAME}/data/models/prithvi_global_v1.pt"
}

primary_container = {
    'ContainerHostname': 'ModelContainer',
    'Image': IMAGE_URI,
    'ImageConfig': image_config,
    'Environment': ENV
}

In [None]:
model_name = f'prithvi-global-{identifier}'
execution_role_arn = get_execution_role()

In [None]:
import boto3
sagem = boto3.client('sagemaker')

resp = sagem.create_model(
        ModelName=model_name,
        PrimaryContainer=primary_container,
        ExecutionRoleArn=execution_role_arn
    )

endpoint_config_name = f'{model_name}-endpoint-config'

sagem.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            'VariantName': 'v1',
            'ModelName': model_name,
            'InitialInstanceCount': 1,
            'InstanceType': 'ml.p3.2xlarge'
        },
    ],
)

endpoint_name = f'{model_name}-endpoint'

sagem.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)

sagem.describe_endpoint(EndpointName=endpoint_name)
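
Endpoint creation is asynchronous and typically takes several minutes, so invoking the endpoint right after `create_endpoint` will fail until its status is `InService`. The sketch below blocks until then, using the boto3 `endpoint_in_service` waiter. The helper name `wait_until_in_service` is an assumption.

```python
def wait_until_in_service(sagemaker_client, endpoint_name):
    """Block until the endpoint is in service, then return its status."""
    waiter = sagemaker_client.get_waiter("endpoint_in_service")
    # The waiter polls describe_endpoint and raises if creation fails
    # or the maximum number of attempts is exceeded.
    waiter.wait(EndpointName=endpoint_name)
    return sagemaker_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]


# Example usage in this notebook:
# print(wait_until_in_service(sagem, endpoint_name))
```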

In [None]:
# Send a sample query to the deployed endpoint
import json
sm = sagemaker.Session().sagemaker_runtime_client

query = {
    'bounding_box': [27.844454, 36.076852, 28.860310, 37.279049],
    'date': '2023-07-23',
    'model_id': 'burn'
}

response = sm.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(query),
    ContentType="application/json"
)

json.loads(response['Body'].read())
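
A malformed query only surfaces as an error after a round trip to the endpoint, so a light check on the client side can save time. The sketch below validates a query before it is sent. It assumes that `bounding_box` is ordered `[min_lon, min_lat, max_lon, max_lat]` and that `date` is an ISO `YYYY-MM-DD` string; both are inferred from the sample query above rather than from the inference container's documentation, and the helper name `build_query` is hypothetical.

```python
from datetime import date


def build_query(bounding_box, day, model_id):
    """Validate the inputs and return a query dictionary for the endpoint."""
    if len(bounding_box) != 4:
        raise ValueError("bounding_box must contain exactly four numbers")
    # Assumed order: [min_lon, min_lat, max_lon, max_lat]
    min_lon, min_lat, max_lon, max_lat = bounding_box
    if not (-180 <= min_lon < max_lon <= 180):
        raise ValueError("longitudes must satisfy -180 <= min_lon < max_lon <= 180")
    if not (-90 <= min_lat < max_lat <= 90):
        raise ValueError("latitudes must satisfy -90 <= min_lat < max_lat <= 90")
    # Raises ValueError when `day` is not a valid YYYY-MM-DD date.
    date.fromisoformat(day)
    return {"bounding_box": list(bounding_box), "date": day, "model_id": model_id}


# Example usage in this notebook:
# query = build_query([27.844454, 36.076852, 28.860310, 37.279049], '2023-07-23', 'burn')
```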

In [None]:
# Clean up: delete the endpoint first, then its configuration, then the model
sagem.delete_endpoint(EndpointName=endpoint_name)
sagem.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sagem.delete_model(ModelName=model_name)