# Benchmarking Protein Language Model Pretraining on UniRef50

---
## 1. Setup

In [5]:
import boto3
import os
import sagemaker
from time import strftime

# One explicit boto3 session backs the SageMaker session, so the region, the
# default bucket and the execution role below all come from the same credentials.
boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session=boto_session)

REGION_NAME = sagemaker_session.boto_region_name
S3_BUCKET = sagemaker_session.default_bucket()
S3_PREFIX = "plm-uniref50-benchmarking"

# Root S3 location; later cells place data, caches and model output beneath it.
S3_FOLDER = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 uri is {S3_FOLDER}")

# SageMaker Experiments name used to group benchmark runs.
EXPERIMENT_NAME = "uniref50-benchmarking"

SAGEMAKER_EXECUTION_ROLE = sagemaker.session.get_execution_role(
    sagemaker_session=sagemaker_session
)
print(f"Assumed SageMaker role is {SAGEMAKER_EXECUTION_ROLE}")

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole


S3 uri is s3://sagemaker-us-west-2-111918798052/plm-uniref50-benchmarking
Assumed SageMaker role is arn:aws:iam::111918798052:role/bloyal-esm-training-230722-SageMakerExecutionRole-SGS3MFJIVY26


---
## 2. Data Processing

### 2.1. Download the UniRef50 FASTA File and Convert It to Partitioned CSVs

In [None]:
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.processing import ProcessingOutput

# S3 prefix that receives the partitioned CSVs written by get_fasta.py.
# s3_path_join (already used to build S3_FOLDER) always joins with "/", whereas
# os.path.join applies local-filesystem rules (OS separator, absolute-component
# reset) that do not belong in S3 URIs.
# NOTE(review): the tokenization cell later reads from <S3_FOLDER>/data/raw/csv,
# not from this ".../data/raw/big" prefix -- confirm which location is intended.
raw_data_uri = sagemaker.s3.s3_path_join(S3_FOLDER, "data", "raw", "big")

processor = PyTorchProcessor(
    base_job_name="fasta-processing",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    framework_version="2.0.0",
    py_version="py310",
    role=SAGEMAKER_EXECUTION_ROLE,
    sagemaker_session=sagemaker_session,
)

# Download the UniRef50 FASTA and split it into partitions of at most
# 1,000,000 records each (arguments consumed by scripts/processing/get_fasta.py).
# wait=False: the cell returns immediately while the job keeps running, so
# downstream cells must not assume its output exists yet.
processor.run(
    arguments=[
        "--source",
        "https://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref50/uniref50.fasta.gz",
        "--output_dir",
        "/opt/ml/processing/output",
        "--max_records_per_partition",
        "1000000",
    ],
    code="get_fasta.py",
    source_dir="scripts/processing",
    outputs=[
        ProcessingOutput(
            source="/opt/ml/processing/output",  # When the job finishes, SageMaker will copy data from here...
            destination=raw_data_uri,  # ...to here
        )
    ],
    wait=False,
)

### 2.2. Convert the CSVs to a Hugging Face Dataset and Tokenize

In [None]:
# Prefix whose "csv" sub-folder is the input of the tokenization job below.
# NOTE(review): this rebinds raw_data_uri from the FASTA-processing cell
# (".../data/raw/big") to ".../data/raw"; confirm this is where the CSVs live.
raw_data_uri = sagemaker.s3.s3_path_join(S3_FOLDER, "data", "raw")
raw_data_uri

In [None]:
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.processing import ProcessingInput, ProcessingOutput

# Destination prefix for the tokenized Hugging Face dataset; both training cells
# read their input channel from hf_data_uri. s3_path_join keeps S3 URIs "/"-joined
# regardless of the local OS, matching how S3_FOLDER itself is built.
hf_data_uri = sagemaker.s3.s3_path_join(S3_FOLDER, "data", "processed", "arrow")

processor = PyTorchProcessor(
    base_job_name="hf-tokenization",
    instance_count=1,
    instance_type="ml.c5.9xlarge",
    framework_version="2.0.0",
    py_version="py310",
    role=SAGEMAKER_EXECUTION_ROLE,
    sagemaker_session=sagemaker_session,
    volume_size_in_gb=512,
)

# Arguments are consumed by scripts/processing/hf_tokenization.py: ESM-2 150M
# tokenizer, max_seq_length 512 with padding to that length, 24 workers.
# wait=False: the training cells depend on this job's output, so make sure it
# has completed before launching them.
processor.run(
    arguments=[
        "--model_name_or_path",
        "facebook/esm2_t30_150M_UR50D",
        "--low_cpu_mem_usage",
        "True",
        "--max_seq_length",
        "512",
        "--preprocessing_num_workers",
        "24",
        "--line_by_line",
        "True",
        "--pad_to_max_length",
        "True",
        "--train_dir",
        "/opt/ml/processing/input",
        "--output_dir",
        "/opt/ml/processing/output",
    ],
    code="hf_tokenization.py",
    source_dir="scripts/processing",
    inputs=[
        ProcessingInput(
            source=sagemaker.s3.s3_path_join(raw_data_uri, "csv"),  # When the job starts, SageMaker will copy data from here...
            destination="/opt/ml/processing/input",  # ...to here
        )
    ],
    outputs=[
        ProcessingOutput(
            source="/opt/ml/processing/output",  # When the job finishes, SageMaker will copy data from here...
            destination=hf_data_uri,  # ...to here
        )
    ],
    wait=False,
)

## 4. Training

### 4.1. Define Metrics

In [6]:
from sagemaker.huggingface import HuggingFace

# Metric definitions passed to both estimators: SageMaker applies each "Regex"
# to the training-job log and records the first capture group under "Name".
# Every pattern matches "<name> ... =", skips any non-digit characters after the
# "=", and captures the rest of the line.
# Raw strings are required: in a normal literal "\D" is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning on Python 3.12+).
metric_definitions = [
    {"Name": "epoch", "Regex": r"epoch.*=\D*(.*?)$"},
    {
        "Name": "max_train_gpu_utilization",
        "Regex": r"max_train_gpu_utilization.*=\D*(.*?)$",
    },
    {"Name": "train_loss", "Regex": r"train_loss.*=\D*(.*?)$"},
    {"Name": "train_runtime", "Regex": r"train_runtime.*=\D*(.*?)$"},
    # The lookahead keeps this from also matching "train_samples_per_second = ...",
    # which would mix two different metrics into one series.
    {"Name": "train_samples", "Regex": r"train_samples(?!_per_second).*=\D*(.*?)$"},
    {
        "Name": "train_samples_per_second",
        "Regex": r"train_samples_per_second.*=\D*(.*?)$",
    },
    {"Name": "train_steps_per_second", "Regex": r"train_steps_per_second.*=\D*(.*?)$"},
    {"Name": "eval_accuracy", "Regex": r"eval_accuracy.*=\D*(.*?)$"},
    {"Name": "eval_loss", "Regex": r"eval_loss.*=\D*(.*?)$"},
    {"Name": "eval_runtime", "Regex": r"eval_runtime.*=\D*(.*?)$"},
    # Same guard as train_samples, against "eval_samples_per_second = ...".
    {"Name": "eval_samples", "Regex": r"eval_samples(?!_per_second).*=\D*(.*?)$"},
    {"Name": "eval_samples_per_second", "Regex": r"eval_samples_per_second.*=\D*(.*?)$"},
    {"Name": "eval_steps_per_second", "Regex": r"eval_steps_per_second.*=\D*(.*?)$"},
    {"Name": "eval_perplexity", "Regex": r"perplexity.*=\D*(.*?)$"},
]

### 4.2. Train the 150M-Parameter ESM-2 Model on UniRef50 Using Hugging Face on ml.p4d.24xlarge

In [None]:
from sagemaker.huggingface import HuggingFace
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput

# SageMaker Experiments name for the p4d run.
# NOTE(review): this differs from EXPERIMENT_NAME, which the Trn1 training cell
# uses, so the two benchmark runs are recorded under separate experiments --
# confirm that is intended before comparing them side by side.
experiment_name = "240112-esm-training-full-data"

# Command-line hyperparameters forwarded to run_mlm.py (key order preserved).
hyperparameters = dict(
    bf16=True,
    config_name="facebook/esm2_t30_150M_UR50D",
    dataloader_num_workers=8,
    do_eval=False,
    do_preprocess=False,
    do_train=True,
    gradient_accumulation_steps=16,
    logging_steps=16,
    num_train_epochs=1,
    output_dir="/opt/ml/model",
    per_device_train_batch_size=24,
    tokenizer_name="facebook/esm2_t30_150M_UR50D",
    # Container path of the "training" channel supplied to fit() below.
    dataset_dir="/opt/ml/input/data/training",
    torch_compile=True,
)

p4_estimator = HuggingFace(
    # Job naming, code location and entry point.
    base_job_name="p4-hf",
    source_dir="scripts/training/cuda",
    entry_point="run_mlm.py",
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    # Compute: a single ml.p4d.24xlarge with torch.distributed launching.
    instance_type="ml.p4d.24xlarge",
    instance_count=1,
    distribution={"torch_distributed": {"enabled": True}},
    # Framework versions that select the Hugging Face training container.
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    # Credentials / session.
    role=SAGEMAKER_EXECUTION_ROLE,
    sagemaker_session=sagemaker_session,
)

p4_training_inputs = {
    "training": TrainingInput(s3_data=hf_data_uri, input_mode="File"),
}

# Calling fit() inside the Run context links the training job to the experiment;
# wait=False returns as soon as the job is submitted.
with Run(experiment_name=experiment_name, sagemaker_session=sagemaker_session) as run:
    p4_estimator.fit(p4_training_inputs, wait=False)

### 4.3. Train the 150M-Parameter ESM-2 Model on UniRef50 Using AWS Trainium (ml.trn1.32xlarge)

In [7]:
from sagemaker.pytorch import PyTorch
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput


# S3 prefix handed to NEURON_COMPILE_CACHE_URL below, presumably so compiled
# Neuron graphs can be reused across jobs. s3_path_join keeps the URI
# "/"-joined regardless of the local OS, matching how S3_FOLDER is built.
neuron_cache = sagemaker.s3.s3_path_join(S3_FOLDER, "parallel-neuron-cache")

# Hugging Face PyTorch training container for Neuron (neuronx); the tag pins
# PyTorch 1.13.1, transformers 4.34.1, Python 3.10 and Neuron SDK 2.15.0.
IMAGE_URI = f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/huggingface-pytorch-training-neuronx:1.13.1-transformers4.34.1-neuronx-py310-sdk2.15.0-ubuntu20.04"

MODEL_ID = "facebook/esm2_t30_150M_UR50D"

# Hyperparameters forwarded to torch_xla_train.py.
hyperparameters = {
    "num_train_epochs": 1,
    "model_id": MODEL_ID,
    "per_device_train_batch_size": 6,
    "bf16": True,
    "logging_steps": 8,
    "optim": "adamw_torch",
    "gradient_accumulation_steps": 16,
    "device": "xla",
}

# Generic PyTorch estimator: the Hugging Face Neuron container is chosen
# explicitly through image_uri rather than via framework-version arguments.
trn1_estimator = PyTorch(
    base_job_name="trn1-training",
    entry_point="torch_xla_train.py",
    source_dir="scripts/training/neuron",
    instance_type="ml.trn1.32xlarge",
    instance_count=1,
    image_uri=IMAGE_URI,
    output_path=sagemaker.s3.s3_path_join(S3_FOLDER, "output"),
    role=SAGEMAKER_EXECUTION_ROLE,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    # NOTE(review): per the SageMaker SDK docs, files under checkpoint_local_path
    # are only synced to S3 when checkpoint_s3_uri is also set; none is set here,
    # so checkpoints are not persisted -- confirm that is intended.
    checkpoint_local_path="/opt/ml/checkpoints",
    sagemaker_session=sagemaker_session,
    distribution={"torch_distributed": {"enabled": True}},
    # Neuron / EFA / torch-xla runtime settings for the training container.
    environment={
        "NEURON_COMPILE_CACHE_URL": neuron_cache,
        "FI_EFA_FORK_SAFE": "1",
        "XLA_USE_BF16": "1",
    },
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
)

# Launch inside a Run context so the job is tracked under EXPERIMENT_NAME;
# wait=False returns as soon as the job is submitted.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    trn1_estimator.fit(
        {
            "train": TrainingInput(
                s3_data=hf_data_uri,
                input_mode="File",
            ),
        },
        wait=False,
    )

INFO:sagemaker:Creating training-job with name: trn1-training-2024-02-02-03-56-38-065
