# Use Amazon SageMaker for Parameter-Efficient Fine-Tuning of the ESM-2 Protein Language Model

Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0

Note: We recommend running this notebook on a **ml.m5.large** instance with the **Data Science 3.0** image.

### What is a Protein?

Proteins are complex molecules that are essential for life. The shape and structure of a protein determine what it can do in the body. Knowing how a protein is folded and how it works helps scientists design drugs that target it. For example, if a protein causes disease, a drug might be designed to block its function. The drug needs to fit into the protein like a key in a lock, and understanding the protein's molecular structure reveals where drugs can attach. This knowledge helps drive the discovery of innovative new drugs.

![Proteins are made up of long chains of amino acids](img/protein.png)

### What is a Protein Language Model?

Proteins are made up of linear chains of molecules called amino acids, each with its own chemical structure and properties. If we think of each amino acid in a protein like a word in a sentence, it becomes possible to analyze proteins using methods originally developed for human language. Scientists have trained these so-called "protein language models" (pLMs) on millions of protein sequences from thousands of organisms. With enough data, these models can begin to capture the underlying evolutionary relationships between different amino acid sequences.

Training a pLM from scratch for a specific task can take a lot of time and compute. For example, a team at Tsinghua University [recently described](https://www.biorxiv.org/content/10.1101/2023.07.05.547496v3) training a 100 billion-parameter pLM on 768 A100 GPUs for 164 days! Fortunately, in many cases we can save time and resources by adapting an existing pLM to our needs. This technique, called "fine-tuning", also lets us borrow advanced tools from other types of language modeling.

### What is LoRA?

One such method, originally developed in 2021 for natural language models, is ["Low-Rank Adaptation of Large Language Models"](https://arxiv.org/abs/2106.09685), or "LoRA". LoRA adapts a large pre-trained model to a new task by freezing the original weights and training only a small set of additional low-rank matrices. Because so few parameters are updated, fine-tuning is fast and memory-efficient, which makes it practical to customize a large model for new uses.
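The core LoRA idea can be sketched in a few lines of NumPy: a frozen weight matrix `W` is combined with a trainable update `B @ A` of rank `r`, scaled by `alpha / r` (the scaling used in the LoRA paper and by `peft`'s `lora_alpha` setting). This is a minimal illustration, not the training script used later; the layer size of 320 and the values `r = 8`, `alpha = 16` are illustrative assumptions.

```python
import numpy as np

def lora_forward(x, W, A, B, alpha):
    """Apply a frozen weight W plus a trainable low-rank update B @ A, scaled by alpha / r."""
    r = A.shape[0]                      # rank of the update
    delta_W = (alpha / r) * (B @ A)     # full (d_out, d_in) shape, built from only r * (d_in + d_out) numbers
    return x @ (W + delta_W).T

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 320, 320, 8, 16      # illustrative sizes, not taken from the notebook

W = rng.normal(size=(d_out, d_in))           # pre-trained weight, kept frozen
A = rng.normal(scale=0.01, size=(r, d_in))   # trainable, small random init
B = np.zeros((d_out, r))                     # trainable, zero init so training starts from the original model

x = rng.normal(size=(4, d_in))
# With B = 0, the adapted layer reproduces the frozen layer exactly
assert np.allclose(lora_forward(x, W, A, B, alpha), x @ W.T)

full_params = d_out * d_in
lora_params = A.size + B.size
print(f"Full fine-tuning: {full_params:,} parameters; LoRA (r={r}): {lora_params:,} ({lora_params / full_params:.1%})")
```

For this single layer, LoRA trains 5,120 parameters instead of 102,400 (5.0%). Zero-initializing `B` is the design choice that guarantees the adapted model starts out identical to the pre-trained one.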

`peft` is an open-source library from Hugging Face for running parameter-efficient fine-tuning jobs, including LoRA. We'll also use int8 quantization to further increase efficiency.
Together, LoRA and quantization reduce the GPU memory (VRAM) needed to train large language models, giving us more flexibility in choosing compute.
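To see where the memory savings of int8 quantization come from, here is a deliberately simplified sketch of symmetric "absmax" quantization with a single scale per tensor. It is an assumption-laden illustration: the int8 scheme in `bitsandbytes` (which `transformers` and `peft` rely on) is more elaborate, using vector-wise scales and special handling of outlier values, and the exact quantization settings of this notebook's training script are not shown here.

```python
import numpy as np

def quantize_int8(w):
    """Symmetric absmax quantization: map [-max|w|, max|w|] onto the int8 range [-127, 127]."""
    scale = np.abs(w).max() / 127.0
    q = np.round(w / scale).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    """Recover an approximate float32 tensor from int8 values and the stored scale."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(size=(320, 320)).astype(np.float32)   # stand-in for one weight matrix

q, scale = quantize_int8(w)
w_hat = dequantize_int8(q, scale)

print(f"fp32: {w.nbytes:,} bytes, int8: {q.nbytes:,} bytes")
print(f"max abs error: {np.abs(w - w_hat).max():.4f} (rounding bound: scale / 2 = {scale / 2:.4f})")
```

Storing int8 values uses a quarter of the memory of float32, at the cost of a rounding error of at most half the scale per value.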

### What is ESM-2?

[ESM-2](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1) is a pLM trained using unsupervised masked language modeling on 250 million protein sequences by researchers at [Facebook AI Research (FAIR)](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1). It is available in several sizes, ranging from 8 million to 15 billion parameters. The smaller models are suitable for various sequence and token classification tasks. The FAIR team also adapted the 3 billion-parameter version into the ESMFold protein structure prediction algorithm. They have since used ESMFold to predict the structure of [more than 700 million metagenomic proteins](https://esmatlas.com/about).

ESM-2 is a powerful pLM. However, fine-tuning it has traditionally required multiple A100 GPUs. In this notebook, we demonstrate how to use QLoRA to fine-tune ESM-2 on an inexpensive Amazon SageMaker training instance. We will use ESM-2 to predict [subcellular localization](https://academic.oup.com/nar/article/50/W1/W228/6576357). Understanding where proteins appear in cells can help us understand their role in disease and find new drug targets.

---
## 1. Set up environment

In [None]:
%pip install -q -U --disable-pip-version-check -r notebook-requirements.txt

Load the SageMaker SDK and other packages, then create the service clients and session variables used throughout the notebook.

In [None]:
import boto3
from datasets import Dataset, DatasetDict
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import sagemaker
from sagemaker.experiments.run import Run
from sagemaker.huggingface import HuggingFace, HuggingFaceModel
from sagemaker.inputs import TrainingInput
import sagemaker_datawrangler
from time import strftime
from transformers import AutoTokenizer

boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session)
S3_BUCKET = sagemaker_session.default_bucket()
s3 = boto_session.client("s3")
sagemaker_client = boto_session.client("sagemaker")
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
REGION_NAME = sagemaker_session.boto_region_name
print(f"Assumed SageMaker role is {sagemaker_execution_role}")

S3_PREFIX = "esm-loc-ft"
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

EXPERIMENT_NAME = "esm-loc-ft-" + strftime("%Y-%m-%d-%H-%M-%S")
print(f"Experiment name is {EXPERIMENT_NAME}")

---
## 2. Build Dataset

We'll use a version of the [DeepLoc-2 data set](https://services.healthtech.dtu.dk/services/DeepLoc-2.0/) to fine-tune our localization model. It consists of several thousand protein sequences, each with one or more experimentally observed location labels. The DeepLoc team at the Technical University of Denmark extracted this data from the public [UniProt sequence database](https://www.uniprot.org/). In this example, we train a binary classifier on the `Membrane` label.

In [None]:
df = pd.read_csv(
    "https://services.healthtech.dtu.dk/services/DeepLoc-2.0/data/Swissprot_Train_Validation_dataset.csv"
).drop(["Unnamed: 0", "Partition"], axis=1)
df["Membrane"] = df["Membrane"].astype("int32")

# Keep sequences between 100 and 512 amino acids long
df = df[df["Sequence"].str.len().between(100, 512)]

# Keep only the columns we need
df = df[["Sequence", "Kingdom", "Membrane"]]

# Resample rows to randomize their order and balance the Membrane classes:
# each row is weighted by 1 / (number of rows in its class)
weights = 1.0 / df.groupby("Membrane")["Membrane"].transform("count")
df = df.sample(n=3000, weights=weights).reset_index(drop=True)

# Visualize the data with the SageMaker Data Wrangler widget
df
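Why does inverse-frequency weighting balance the classes? Each row gets weight `1 / (size of its class)`, so every class carries the same total weight. The toy example below uses made-up labels (80 soluble, 20 membrane) to show this. It samples with replacement so that 1,000 draws are possible from 100 rows; the notebook samples without replacement, in which case the balance is only approximate and depends on each class having enough rows.

```python
import pandas as pd

# Toy, made-up labels: 80 rows with Membrane == 0 and 20 rows with Membrane == 1
toy = pd.DataFrame({"Membrane": [0] * 80 + [1] * 20})

# Same weighting as in the notebook: 1 / (number of rows in the row's class)
toy_weights = 1.0 / toy.groupby("Membrane")["Membrane"].transform("count")

# Every class now carries the same total weight (1.0), so draws are balanced in expectation
print(toy_weights.groupby(toy["Membrane"]).sum())

toy_sample = toy.sample(n=1000, replace=True, weights=toy_weights, random_state=0)
print(toy_sample["Membrane"].value_counts(normalize=True))
```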

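Next, we split the data into training (90%) and test (10%) sets, then tokenize the sequences and truncate them to a maximum length of 512 tokens. The ESM-2 tokenizer produces one token per amino acid and adds a `<cls>` token at the start and an `<eos>` token at the end of each sequence.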

In [41]:
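# Hold out 10% of the rows as a test set
dataset = Dataset.from_pandas(df).train_test_split(test_size=0.1, shuffle=True)
# All ESM-2 checkpoints share the same vocabulary, so this tokenizer works for any model size
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")

def preprocess_data(examples, max_length=512):
    """Tokenize a batch of sequences and attach the Membrane labels."""
    text = examples["Sequence"]
    encoding = tokenizer(
        text,
        # No padding here, so each sequence keeps its natural length
        truncation=True,
        max_length=max_length,
    )
    encoding["labels"] = examples["Membrane"]
    return encoding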


encoded_dataset = dataset.map(
    preprocess_data,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=dataset["train"].column_names,
)

encoded_dataset.set_format("torch")
print(encoded_dataset)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2700
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 300
    })
})


Look at an example record:

In [42]:
# Pick a valid random index (randint includes both endpoints)
random_idx = random.randint(0, len(encoded_dataset["train"]) - 1)
example = encoded_dataset["train"][random_idx]

print(f"Viewing example record {random_idx}")
print(f"Raw sequence:\n{tokenizer.decode(example['input_ids'])}\n")
print(f"Tokenized sequence:\n{example['input_ids'].tolist()}\n")
print(f"Label:\n{example['labels']}")

Viewing example record 2447
Raw sequence:
<cls> M Q R R R R A P P A S Q P A Q D S G H S E E V E V Q F S A G R L G S A A P A G P P V R G T A E D E E R L E R E H F W K V I N A F R Y Y G T S M H E R V N R T E R Q F R S L P D N Q Q K L L P Q F P L H L D K I R K C V D H N Q E I L L T I V N D C I H M F E N K E Y G E D A N G K I M P A S T F D M D K L K S T L K Q F V R D W S G T G K A E R D A C Y K P I I K E I I K N F P K E R W D P S K V N I L V P G A G L G R L A W E I A M L G Y A C Q G N E W S F F M L F S S N F V L N R C S E V D K Y K L Y P W I H Q F S N N R R S A D Q I R P I F F P D V D P H S L P P G S N F S M T A G D F Q E I Y S E C N T W D C I A T C F F I D T A H N V I D Y I D T I W R I L K P G G I W I N L G P L L Y H F E N L A N E L S I E L S Y E D I K N V V L Q Y G F Q L E V E K E S V L S T Y T V N D L S M M K Y Y Y E C V L F V V R K P Q <eos>

Tokenized sequence:
[0, 20, 16, 10, 10, 10, 10, 5, 14, 14, 5, 8, 16, 14, 5, 16, 13, 8, 6, 21, 8, 9, 9, 7, 9, 7, 16, 18, 8, 5, 6, 10, 4, 6, 8, 5, 
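The token IDs above follow a simple pattern: one ID per residue, with `<cls>` (0) at the start and `<eos>` (2) at the end. The hypothetical `toy_esm_encode` helper below reproduces that pattern for the 20 standard amino acids, using the ESM-2 vocabulary order (which matches the IDs printed above, e.g. M→20, Q→16, R→10). It is only an illustration; the notebook itself uses `AutoTokenizer`, whose real vocabulary also includes tokens such as `X`, `B`, `U`, `Z`, `O`, and `<mask>`.

```python
# Hypothetical stand-in for the ESM-2 tokenizer: special tokens plus the
# 20 standard amino acids only, in ESM-2 vocabulary order
ESM_VOCAB = {"<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3}
for i, aa in enumerate("LAGVSERTIDPKQNFYMHWC", start=4):
    ESM_VOCAB[aa] = i

def toy_esm_encode(sequence, max_length=512):
    """One token per residue, wrapped in <cls> ... <eos>, truncated to max_length tokens in total."""
    body = [ESM_VOCAB.get(aa, ESM_VOCAB["<unk>"]) for aa in sequence]
    body = body[: max_length - 2]  # leave room for the two special tokens
    return [ESM_VOCAB["<cls>"]] + body + [ESM_VOCAB["<eos>"]]

# The first 35 residues of the example record shown above
ids = toy_esm_encode("MQRRRRAPPASQPAQDSGHSEEVEVQFSAGRLGSA")
print(ids)

# A 600-residue sequence is cut to 510 residues + <cls> + <eos> = 512 tokens
print(len(toy_esm_encode("A" * 600)))
```

This also shows why `max_length=512` leaves room for 510 residues: the two special tokens count toward the limit.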

Finally, we upload the processed training and test data to S3.

In [43]:
train_s3_uri = S3_PATH + "/data/train"
test_s3_uri = S3_PATH + "/data/test"

encoded_dataset["train"].save_to_disk(train_s3_uri)
encoded_dataset["test"].save_to_disk(test_s3_uri)

---
## 3. Train model in SageMaker

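Next, we'll fine-tune ESM-2 with LoRA on an ml.g5.2xlarge instance. The hyperparameters below select the 8 million-parameter model for a quick run; to train a larger model, up to the 3 billion-parameter version, change the `model_id` value.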

Define the metrics for SageMaker to extract from the job logs and send to SageMaker Experiments. You can customize these to log more or fewer values.

In [72]:
metric_definitions = [
    {"Name": "epoch", "Regex": "'epoch': ([0-9.]*)"},
    {"Name": "learning_rate", "Regex": "'learning_rate': ([0-9.e-]*)"},
    {"Name": "train_loss", "Regex": "'loss': ([0-9.e-]*)"},
    {"Name": "train_runtime", "Regex": "'train_runtime': ([0-9.e-]*)"},
    {"Name": "train_samples_per_second", "Regex": "'train_samples_per_second': ([0-9.e-]*)"},
    {"Name": "train_steps_per_second", "Regex": "'train_steps_per_second': ([0-9.e-]*)"},
    {"Name": "eval_loss", "Regex": "'eval_loss': ([0-9.e-]*)"},
    {"Name": "eval_runtime", "Regex": "'eval_runtime': ([0-9.e-]*)"},
    {"Name": "eval_samples_per_second", "Regex": "'eval_samples_per_second': ([0-9.e-]*)"},
    {"Name": "eval_steps_per_second", "Regex": "'eval_steps_per_second': ([0-9.e-]*)"},
    {"Name": "eval_accuracy", "Regex": "'eval_accuracy': ([0-9.e-]*)"},
    {"Name": "eval_f1", "Regex": "'eval_f1': ([0-9.e-]*)"},
    {"Name": "eval_roc_auc", "Regex": "'eval_roc_auc': ([0-9.e-]*)"},

]
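SageMaker applies each regex to the training job's log output and records the value captured by the first group. The sketch below approximates that locally with Python's `re` module so you can test a regex before launching a job. The log lines are hypothetical examples in the dict format that `transformers.Trainer` prints, and the definitions are copied into a separate `demo_definitions` list so the notebook's `metric_definitions` variable is not overwritten.

```python
import re

# A subset of the metric definitions above, copied for the demo
demo_definitions = [
    {"Name": "train_loss", "Regex": "'loss': ([0-9.e-]*)"},
    {"Name": "eval_loss", "Regex": "'eval_loss': ([0-9.e-]*)"},
    {"Name": "eval_accuracy", "Regex": "'eval_accuracy': ([0-9.e-]*)"},
]

# Hypothetical Trainer log lines (made-up values)
log_lines = [
    "{'loss': 0.6931, 'learning_rate': 4.5e-05, 'epoch': 0.5}",
    "{'eval_loss': 0.4123, 'eval_accuracy': 0.84, 'epoch': 1.0}",
]

def extract_metrics(lines, definitions):
    """Return {metric name: [values]} using the first capture group of each regex."""
    found = {d["Name"]: [] for d in definitions}
    for line in lines:
        for d in definitions:
            match = re.search(d["Regex"], line)
            if match:
                found[d["Name"]].append(float(match.group(1)))
    return found

print(extract_metrics(log_lines, demo_definitions))
```

Note that the leading quote in `'loss'` matters: it keeps the `train_loss` regex from also matching `'eval_loss'`.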

In [105]:
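# Hyperparameters passed to the training script
hyperparameters = {
    "epochs": 1,
    # Smallest ESM-2 model, for a quick run; swap in a larger checkpoint below to scale up
    "model_id": "facebook/esm2_t6_8M_UR50D",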
    # "model_id": "facebook/esm2_t12_35M_UR50D",
    # "model_id": "facebook/esm2_t30_150M_UR50D",
    # "model_id": "facebook/esm2_t33_650M_UR50D",
    # "model_id": "facebook/esm2_t36_3B_UR50D",
    "lora": True,
    "use_gradient_checkpointing": True,
}

# Create the Hugging Face estimator that runs the LoRA training script
hf_estimator = HuggingFace(
    base_job_name="esm-localization-fine-tuning",
    entry_point="lora-train.py",
    source_dir="scripts/training/peft",
    instance_type="ml.g5.2xlarge",
    instance_count=1,
    transformers_version="4.28",
    pytorch_version="2.0",
    py_version="py310",
    output_path=f"{S3_PATH}/output",
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    checkpoint_local_path="/opt/ml/checkpoints",
    sagemaker_session=sagemaker_session,
    keep_alive_period_in_seconds=1800,
    tags=[{"Key": "project", "Value": "esm-fine-tuning"}],
)

In [None]:
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    hf_estimator.fit(
        {
            "train": TrainingInput(s3_data=train_s3_uri, input_mode="File"),
            "test": TrainingInput(s3_data=test_s3_uri, input_mode="File"),
        },
        wait=True,
    )

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: esm-localization-fine-tuning-2023-10-31-03-43-52-505


Using provided s3_resource
2023-10-31 03:43:52 Starting - Starting the training job...
2023-10-31 03:44:06 Downloading - Downloading input data
2023-10-31 03:44:06 Training - Training image download completed. Training in progress.
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2023-10-31 03:44:07,420 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2023-10-31 03:44:07,433 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2023-10-31 03:44:07,442 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2023-10-31 03:44:07,444 sagemaker_pytorch_container.training INFO     Invoking user training script.
2023-10-31 03:44:08,786 sagemaker-training-toolkit INFO     Installing dependencies from requirements.txt:
/opt/conda/bin/python3.10 -m pip install

You can view metrics and debugging information for this run in SageMaker Experiments. On the left-side navigation panel, select the Home icon, then "Experiments". From there, you can select your experiment name and training job name and view the Debugger insights.

While the training job is running, take a look at the training script in `scripts/training/peft/lora-train.py`.

---
## 4. Deploy Model as Real-Time Inference Endpoint

In [None]:
%%time

# Create a Hugging Face Model from the trained model artifacts in S3
huggingface_model = HuggingFaceModel(
    model_data=hf_estimator.model_data,
    role=sagemaker_execution_role,
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    model_server_workers=1,
    env={"HF_TASK": "text-classification"},
)

# Deploy the model to a SageMaker real-time endpoint
# (the IAM role is already set on the model above)
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
)

In [None]:
test_seq = "MAAAVVLAAGLRAARRAVAATGVRGGQVRGAAGVTDGNEVAKAQQATPGGAAPTIFSRILDKSLPADILYEDQQCLVFRDVAPQAPVHFLVIPKKPIPRISQAEEEDQQLLGHLLLVAKQTAKAEGLGDGYRLVINDGKLGAQSVYHLHIHVLGGRQLQWPPG"
sample = {"inputs": test_seq}
predictor.predict(sample)
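A Hugging Face `text-classification` endpoint returns a list of dictionaries with `label` and `score` keys. The hypothetical helper below turns that response into a readable result. The label names are an assumption: it presumes the model kept the default `LABEL_0`/`LABEL_1` names and that `LABEL_1` corresponds to `Membrane == 1`; check the `id2label` mapping your training script saves before relying on it.

```python
# Hypothetical mapping from default label names to the Membrane classes
LABEL_NAMES = {"LABEL_0": "not membrane-bound", "LABEL_1": "membrane-bound"}

def summarize_prediction(response):
    """Pick the highest-scoring entry of a [{'label': ..., 'score': ...}] response and format it."""
    top = max(response, key=lambda r: r["score"])
    name = LABEL_NAMES.get(top["label"], top["label"])
    return f"{name} (score {top['score']:.3f})"

# Made-up response for illustration; in the notebook, pass predictor.predict(sample) instead
print(summarize_prediction([{"label": "LABEL_1", "score": 0.912}]))
```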

In [None]:
# Delete the endpoint to stop incurring charges
try:
    predictor.delete_endpoint()
except Exception as e:
    print(f"Could not delete endpoint: {e}")

In [97]:
from peft import PeftModelForSequenceClassification