# ESM-2 Domain Adaptation

Note: This notebook was last run on the `PyTorch 1.13 Python 3.9 CPU Optimized` image on an `ml.m5.4xlarge` instance.

In this notebook, we demonstrate how to perform full-parameter fine-tuning of the ESM-2 protein language model on samples from the Observed Antibody Space (OAS) database. Specifically, we use a collection of heavy-chain amino acid sequences from antibodies found in human COVID-19 patients.

---
## 0. Install dependencies

Install Python packages.

In [None]:
%pip install -q --upgrade pip
%pip install -q --upgrade sagemaker boto3 awscli transformers accelerate datasets boto3 ipywidgets tqdm s3fs

In [1]:
import boto3
from datasets import load_dataset, DatasetDict
import os
import sagemaker
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch
from time import strftime
from transformers import AutoTokenizer

# Resolve the AWS profile and region from the environment rather than a
# machine-specific named profile, so the notebook runs on any machine.
# With AWS_PROFILE unset, profile_name is None and boto3 falls back to its
# default credential chain. Set AWS_PROFILE to use a named local profile.
AWS_PROFILE = os.environ.get("AWS_PROFILE")
AWS_REGION = os.environ.get("AWS_DEFAULT_REGION", "us-west-2")

boto_session = boto3.session.Session(profile_name=AWS_PROFILE, region_name=AWS_REGION)
sagemaker_session = sagemaker.session.Session(boto_session)
S3_BUCKET = sagemaker_session.default_bucket()
s3 = boto_session.client("s3")
sagemaker_client = boto_session.client("sagemaker")
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
REGION_NAME = sagemaker_session.boto_region_name
print(f"Assumed SageMaker role is {sagemaker_execution_role}")

# Every S3 artifact written by this notebook lives under S3_PATH.
S3_PREFIX = "esm-2-benchmarking"
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

# A timestamp suffix gives each notebook session its own experiment name.
EXPERIMENT_NAME = "esm-2-benchmarking-" + strftime("%Y-%m-%d-%H-%M-%S")
print(f"Experiment name is {EXPERIMENT_NAME}")

sagemaker.config INFO - Not applying SDK defaults from location: /Library/Application Support/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /Users/bloyal/Library/Application Support/sagemaker/config.yaml
Assumed SageMaker role is arn:aws:iam::111918798052:role/Admin
S3 path is s3://sagemaker-us-west-2-111918798052/esm-2-benchmarking
Experiment name is esm-2-benchmarking-2024-02-13-12-57-25


In [None]:
# Hugging Face Hub ID of the ESM-2 checkpoint used throughout this notebook.
# To try a different checkpoint, replace the value below with one of:
#   "facebook/esm2_t48_15B_UR50D"
#   "facebook/esm2_t36_3B_UR50D"
#   "facebook/esm2_t33_650M_UR50D"
#   "facebook/esm2_t30_150M_UR50D"
#   "facebook/esm2_t12_35M_UR50D"
#   "facebook/esm2_t6_8M_UR50D"
MODEL_ID = "facebook/esm2_t33_650M_UR50D"

---
## 1. Process Data

Load OAS sequence data from HuggingFace

In [None]:
# Hub dataset to sample from; the test subset is sized at 20% of the train subset.
src = "bloyal/oas_paired_human_sars_cov_2"
train_sample_count = 10000
test_sample_count = int(train_sample_count * 0.2)

# Slice the first N rows of each split using the split-string syntax.
train_split = f"train[:{train_sample_count}]"
test_split = f"test[:{test_sample_count}]"
train_dataset = load_dataset(src, split=train_split)
test_dataset = load_dataset(src, split=test_split)

# Expose the heavy-chain sequence column under the name "text" for tokenization.
raw_splits = DatasetDict({"train": train_dataset, "test": test_dataset})
dataset = raw_splits.rename_column("sequence_alignment_aa_heavy", "text")

dataset

Tokenize heavy chain sequences

In [None]:
# Every sequence is padded or truncated to this many tokens.
# NOTE(review): 142 is presumably sized for the heavy-chain sequences used here — confirm.
sequence_length = 142
# Tokenizer that ships with the selected ESM-2 checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def tokenize_data(examples, tokenizer, sequence_length):
    """Tokenize a batch of sequences to a fixed length.

    Args:
        examples: Batch mapping whose "text" entry holds the sequences to encode.
        tokenizer: Callable tokenizer applied to the whole batch at once.
        sequence_length: Length to which every sequence is padded or truncated.

    Returns:
        Whatever the tokenizer returns for the batch, unchanged.
    """
    return tokenizer(
        examples["text"],
        max_length=sequence_length,
        truncation=True,
        padding="max_length",
    )


# Arguments forwarded to tokenize_data for every batch.
tokenize_kwargs = {
    "tokenizer": tokenizer,
    "sequence_length": sequence_length,
}
# Drop all original columns so only the tokenizer outputs remain.
raw_columns = dataset["train"].column_names

encoded_dataset = dataset.map(
    tokenize_data,
    fn_kwargs=tokenize_kwargs,
    remove_columns=raw_columns,
    batched=True,
    num_proc=os.cpu_count(),
)

# Return the model inputs as torch tensors when rows are accessed.
encoded_dataset.set_format("torch", columns=["input_ids", "attention_mask"])
encoded_dataset

In [None]:
# Sanity check: show the first training example as token IDs and decoded back to text.
example = encoded_dataset["train"][0]
example_ids = example["input_ids"]
print(example_ids)
print(tokenizer.decode(example_ids))

Upload data to S3

In [None]:
# S3 destinations for the encoded splits; the training cells read from these URIs.
train_s3_uri = f"{S3_PATH}/data/train"
test_s3_uri = f"{S3_PATH}/data/test"

# Write each split straight to its S3 location.
for split_name, split_uri in (("train", train_s3_uri), ("test", test_s3_uri)):
    encoded_dataset[split_name].save_to_disk(split_uri)

---
## 2. Train on Trn1

In [None]:
# (metric name, log regex) pairs handed to the estimator as metric_definitions.
# Each regex has a single capture group holding the metric value.
_metric_patterns = (
    ("epoch", "Epoch: ([0-9.]*)"),
    ("step", "Step: ([0-9.]*)"),
    ("train_loss", "Training Loss: ([0-9.e-]*)"),
    ("train_perplexity", "Training Perplexity: ([0-9.e-]*)"),
    ("train_samples_per_second", "Training Samples/sec: ([0-9.e-]*)"),
    ("train_tokens_per_second", "Training Tokens/sec: ([0-9.e-]*)"),
    ("eval_loss", "Eval Loss: ([0-9.e-]*)"),
    ("eval_perplexity", "Eval Perplexity: ([0-9.e-]*)"),
    ("eval_samples_per_second", "Eval Samples/sec: ([0-9.e-]*)"),
    ("eval_tokens_per_second", "Eval Tokens/sec: ([0-9.e-]*)"),
)
metric_definitions = [
    {"Name": metric_name, "Regex": log_pattern}
    for metric_name, log_pattern in _metric_patterns
]

In [None]:
# S3 location passed to the job as NEURON_COMPILE_CACHE_URL. Derived from
# S3_PATH so it stays under the same root as the data and output paths.
neuron_cache = f"{S3_PATH}/neuron-cache"

# Hyperparameters forwarded to the training entry-point script.
hyperparameters = {
    "num_train_epochs": 2,
    "model_id": MODEL_ID,
    "per_device_train_batch_size": 2,
    "per_device_eval_batch_size": 8,
    "bf16": True,
    "logging_steps": 8,
    "optim": "adamw_torch",
    "gradient_accumulation_steps": 4,
    "device": "xla",
}

# SageMaker PyTorch estimator for a single ml.trn1.32xlarge instance, using
# the Neuron training container and torch_distributed launch.
trn1_estimator = PyTorch(
    base_job_name="esm-2-oas-trn1",
    entry_point="trn1-oas-mlm-train-dp.py",
    source_dir="scripts/training/neuron",
    instance_type="ml.trn1.32xlarge",
    instance_count=1,
    image_uri=f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.12.0-ubuntu20.04",
    output_path=f"{S3_PATH}/output",
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    # NOTE(review): no checkpoint_s3_uri is set, so this local path is
    # presumably not synced to S3 — confirm whether checkpoint persistence
    # is intended before relying on it.
    checkpoint_local_path="/opt/ml/checkpoints",
    sagemaker_session=sagemaker_session,
    keep_alive_period_in_seconds=1800,
    distribution={"torch_distributed": {"enabled": True}},
    environment={
        "NEURON_COMPILE_CACHE_URL": neuron_cache,
        "FI_EFA_FORK_SAFE": "1",
        "XLA_USE_BF16": "1",
    },
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
)


# Launch the job inside an Experiments run; wait=False returns immediately
# instead of blocking the notebook until the job finishes.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    trn1_estimator.fit(
        {
            "train": TrainingInput(s3_data=train_s3_uri, input_mode="FastFile"),
            "test": TrainingInput(s3_data=test_s3_uri, input_mode="FastFile"),
        },
        wait=False,
    )

---
## 3. Train on multiple g5.2xlarge with Distributed Data Parallel

In [55]:
# (metric name, log regex) pairs handed to the estimator as metric_definitions.
# Each regex has a single capture group holding the metric value.
_metric_patterns = (
    ("epoch", "Epoch: ([0-9.]*)"),
    ("step", "Step: ([0-9.]*)"),
    ("train_loss", "Training Loss: ([0-9.e-]*)"),
    ("train_perplexity", "Training Perplexity: ([0-9.e-]*)"),
    ("train_samples_per_second", "Training Samples/sec: ([0-9.e-]*)"),
    ("train_tokens_per_second", "Training Tokens/sec: ([0-9.e-]*)"),
    ("eval_loss", "Eval Loss: ([0-9.e-]*)"),
    ("eval_perplexity", "Eval Perplexity: ([0-9.e-]*)"),
    ("eval_samples_per_second", "Eval Samples/sec: ([0-9.e-]*)"),
    ("eval_tokens_per_second", "Eval Tokens/sec: ([0-9.e-]*)"),
)
metric_definitions = [
    {"Name": metric_name, "Regex": log_pattern}
    for metric_name, log_pattern in _metric_patterns
]

In [None]:
# Hyperparameters forwarded to the training entry-point script.
hyperparameters = {
    "num_train_epochs": 2,
    "model_id": MODEL_ID,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "bf16": True,
    "logging_steps": 8,
    "optim": "adamw_torch",
    "pretrain": 1,
    # Reuse the sample count from the data-processing cell so the value
    # cannot drift from the dataset that was actually uploaded.
    "train_sample_count": train_sample_count,
}

# SageMaker PyTorch estimator for two ml.g5.2xlarge instances launched with
# torch_distributed. The variable keeps its historical name `p4_estimator`
# even though this job targets g5 instances.
p4_estimator = PyTorch(
    base_job_name="esm-2-oas-g5d",
    entry_point="cuda-oas-mlm-train-ddp.py",
    source_dir="scripts/training/cuda",
    instance_type="ml.g5.2xlarge",
    instance_count=2,
    image_uri=f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/pytorch-training:1.13.1-gpu-py39-cu117-ubuntu20.04-sagemaker",
    output_path=f"{S3_PATH}/output",
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    sagemaker_session=sagemaker_session,
    distribution={"torch_distributed": {"enabled": True}},
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
)

# Launch inside an Experiments run; wait=True blocks the notebook until the
# training job completes.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    p4_estimator.fit(
        {
            "train": TrainingInput(s3_data=train_s3_uri, input_mode="FastFile"),
            "test": TrainingInput(s3_data=test_s3_uri, input_mode="FastFile"),
        },
        wait=True,
    )

---
## 4. Train on g5.12xlarge with Fully Sharded Data Parallel

In [21]:
# Hyperparameters forwarded to the training entry-point script.
hyperparameters = {
    "num_train_epochs": 2,
    "model_id": MODEL_ID,
    "per_device_train_batch_size": 24,
    "per_device_eval_batch_size": 24,
    "bf16": True,
    "logging_steps": 8,
    "optim": "adamw_torch",
    "pretrain": 1,
    # Reuse the sample count from the data-processing cell so the value
    # cannot drift from the dataset that was actually uploaded.
    "train_sample_count": train_sample_count,
}

# SageMaker PyTorch estimator for a single ml.g5.12xlarge instance running the
# DDP/FSDP entry point. The variable keeps its historical name `p4_estimator`
# even though this job targets a g5 instance. `metric_definitions` comes from
# an earlier cell.
p4_estimator = PyTorch(
    base_job_name="esm-2-oas-g512x-ddp-fsdp",
    entry_point="cuda-oas-mlm-train-ddp-fsdp.py",
    source_dir="scripts/training/cuda",
    instance_type="ml.g5.12xlarge",
    instance_count=1,
    image_uri=f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/pytorch-training:2.0.1-gpu-py310-cu118-ubuntu20.04-sagemaker",
    output_path=f"{S3_PATH}/output",
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    sagemaker_session=sagemaker_session,
    distribution={"torch_distributed": {"enabled": True}},
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
)

# Launch inside an Experiments run; wait=False returns immediately instead of
# blocking the notebook until the job finishes.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    p4_estimator.fit(
        {
            "train": TrainingInput(s3_data=train_s3_uri, input_mode="FastFile"),
            "test": TrainingInput(s3_data=test_s3_uri, input_mode="FastFile"),
        },
        wait=False,
    )

INFO:sagemaker:Creating training-job with name: esm-2-oas-g512x-ddp-fsdp-2023-10-16-20-51-51-413


Using provided s3_resource


---
## 5. Train on p4d.24xlarge

In [None]:
# (metric name, log regex) pairs handed to the estimator as metric_definitions.
# Each regex has a single capture group holding the metric value.
_metric_patterns = (
    ("epoch", "Epoch: ([0-9.]*)"),
    ("step", "Step: ([0-9.]*)"),
    ("train_loss", "Training Loss: ([0-9.e-]*)"),
    ("train_perplexity", "Training Perplexity: ([0-9.e-]*)"),
    ("train_samples_per_second", "Training Samples/sec: ([0-9.e-]*)"),
    ("train_tokens_per_second", "Training Tokens/sec: ([0-9.e-]*)"),
    ("eval_loss", "Eval Loss: ([0-9.e-]*)"),
    ("eval_perplexity", "Eval Perplexity: ([0-9.e-]*)"),
    ("eval_samples_per_second", "Eval Samples/sec: ([0-9.e-]*)"),
    ("eval_tokens_per_second", "Eval Tokens/sec: ([0-9.e-]*)"),
)
metric_definitions = [
    {"Name": metric_name, "Regex": log_pattern}
    for metric_name, log_pattern in _metric_patterns
]

In [None]:
# Settings forwarded to the training entry-point script as hyperparameters.
hyperparameters = dict(
    num_train_epochs=2,
    model_id=MODEL_ID,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    bf16=True,
    logging_steps=8,
    optim="adamw_torch",
)

# Training container image, resolved for the session's region.
p4d_image_uri = f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/pytorch-training:2.0.1-gpu-py310-cu118-ubuntu20.04-sagemaker"

# SageMaker PyTorch estimator for one ml.p4d.24xlarge instance, launched with
# torch_distributed.
p4_estimator = PyTorch(
    entry_point="cuda-oas-mlm-train-ddp-fsdp.py",
    source_dir="scripts/training/cuda",
    base_job_name="esm-2-oas-p4d",
    image_uri=p4d_image_uri,
    instance_type="ml.p4d.24xlarge",
    instance_count=1,
    role=sagemaker_execution_role,
    sagemaker_session=sagemaker_session,
    output_path=f"{S3_PATH}/output",
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    distribution={"torch_distributed": {"enabled": True}},
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
)

# Submit the job inside an Experiments run without blocking the notebook.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    p4d_channels = {
        channel_name: TrainingInput(s3_data=channel_uri, input_mode="FastFile")
        for channel_name, channel_uri in (
            ("train", train_s3_uri),
            ("test", test_s3_uri),
        )
    }
    p4_estimator.fit(p4d_channels, wait=False)