# Pretraining Protein Language Models on UniRef Sequences

---
## 1. Setup

In [None]:
import boto3
import os
import sagemaker

# Session, bucket and S3 locations shared by every job launched below.
boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session)
REGION_NAME = sagemaker_session.boto_region_name
S3_BUCKET = sagemaker_session.default_bucket()
S3_PREFIX = "plm-pretraining"
S3_FOLDER = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 uri is {S3_FOLDER}")

# SageMaker Experiments name used by the training cells' `Run` contexts.
EXPERIMENT_NAME = "plm-pretraining"

SAGEMAKER_EXECUTION_ROLE = sagemaker.session.get_execution_role(sagemaker_session)
print(f"Assumed SageMaker role is {SAGEMAKER_EXECUTION_ROLE}")

# Build S3 URIs with s3_path_join (as for S3_FOLDER above), not os.path.join:
# os.path.join uses the local OS separator and would put "\" into the URI
# when the notebook runs on Windows.
RAW_DATA_URI = sagemaker.s3.s3_path_join(S3_FOLDER, "data", "raw")
print(f"Raw data uri is {RAW_DATA_URI}")

PROCESSED_DATA_URI = sagemaker.s3.s3_path_join(S3_FOLDER, "data", "processed")
print(f"Processed data uri is {PROCESSED_DATA_URI}")

---
## 2. Data Processing

### 2.1. Download the UniRef50 FASTA File and Convert It to Partitioned CSVs

In [None]:
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.processing import ProcessingOutput

# One CPU instance runs scripts/processing/fasta_to_csv.py, which downloads the
# UniRef50 FASTA archive and writes it out as partitioned CSV files.
processor = PyTorchProcessor(
    role=SAGEMAKER_EXECUTION_ROLE,
    sagemaker_session=sagemaker_session,
    base_job_name="fasta-processing",
    framework_version="2.0.0",
    py_version="py310",
    instance_type="ml.m5.xlarge",
    instance_count=1,
)

# Command-line flags for fasta_to_csv.py; flattened below into the
# ["--flag", "value", ...] list that SageMaker passes to the script.
fasta_to_csv_args = {
    "--source": "https://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref50/uniref50.fasta.gz",
    "--output_dir": "/opt/ml/processing/output",
    "--max_records_per_partition": "500000",
}

processor.run(
    code="fasta_to_csv.py",
    source_dir="scripts/processing",
    arguments=[
        token for flag_and_value in fasta_to_csv_args.items() for token in flag_and_value
    ],
    outputs=[
        # When the job finishes, SageMaker copies the container's output
        # directory to RAW_DATA_URI.
        ProcessingOutput(
            source="/opt/ml/processing/output",
            destination=RAW_DATA_URI,
        )
    ],
    # Returns immediately. The tokenization job in the next cell reads from
    # RAW_DATA_URI, so let this job complete before running that cell.
    wait=False,
)

### 2.2. Convert the CSVs to a Hugging Face Dataset and Tokenize

In [None]:
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.processing import ProcessingInput, ProcessingOutput


# Tokenizes the CSV partitions with the ESM-2 tokenizer and writes the
# resulting dataset to PROCESSED_DATA_URI (scripts/processing/tokenize_uniref_csv.py).
# The large volume leaves room for the downloaded inputs plus the tokenized output.
processor = PyTorchProcessor(
    base_job_name="hf-tokenization",
    instance_count=1,
    instance_type="ml.c5.9xlarge",
    framework_version="2.0.0",
    py_version="py310",
    role=SAGEMAKER_EXECUTION_ROLE,
    sagemaker_session=sagemaker_session,
    volume_size_in_gb=512,
)

# NOTE: this job reads from RAW_DATA_URI, which the "fasta-processing" job above
# populates. That job was launched with wait=False, so confirm it has completed
# before running this cell.
processor.run(
    arguments=[
        "--tokenizer_name",
        "facebook/esm2_t30_150M_UR50D",
        "--max_seq_length",
        "512",
        "--preprocessing_num_workers",
        "24",
        # Passed as the string "True"; presumably parsed as a bool by the
        # script's argument parser -- confirm in tokenize_uniref_csv.py.
        "--line_by_line",
        "True",
        "--train_size",
        "10000000",
        "--validation_size",
        "50000",
        "--test_size",
        "50000",
    ],
    code="tokenize_uniref_csv.py",
    source_dir="scripts/processing",
    inputs=[
        ProcessingInput(
            # s3_path_join keeps "/" separators on every OS; os.path.join would
            # insert "\" into the S3 URI on Windows.
            source=sagemaker.s3.s3_path_join(
                RAW_DATA_URI, "csv"
            ),  # When the job starts, SageMaker will copy data from here...
            destination="/opt/ml/processing/input",  # ...to here
        )
    ],
    outputs=[
        ProcessingOutput(
            source="/opt/ml/processing/output",  # When the job finishes, SageMaker will copy data from here...
            destination=PROCESSED_DATA_URI,  # ...to here
        )
    ],
    wait=False,
)

---
## 3. Training

### 3.1. CUDA on ml.p4d.24xlarge

In [None]:
# SageMaker metric definitions for the CUDA / Hugging Face Trainer job. Each
# regex captures the number that follows a quoted key, as printed in a Python
# dict such as {'loss': 2.31, 'epoch': 0.05}.
metric_definitions = [
    {"Name": metric_name, "Regex": f"'{log_key}': ([0-9.]+)"}
    for metric_name, log_key in (
        ("TrainingLoss", "loss"),
        ("Epoch", "epoch"),
        ("TrainingRuntime", "train_runtime"),
        ("TrainingSamplesPerSecond", "train_samples_per_second"),
        ("TrainingStepsPerSecond", "train_steps_per_second"),
    )
]

In [None]:
from sagemaker.experiments.run import Run
from sagemaker.huggingface import HuggingFace
from sagemaker.inputs import TrainingInput

# Hyperparameters forwarded to the entry point run_mlm.py.
# "dataset_dir" is where SageMaker mounts the "training" channel passed to fit().
hyperparameters = {
    "bf16": True,
    "config_name": "facebook/esm2_t30_150M_UR50D",
    "dataloader_num_workers": 8,
    "do_eval": True,
    "do_preprocess": False,
    "do_train": True,
    "gradient_accumulation_steps": 16,
    "logging_steps": 16,
    "num_train_epochs": 1,
    "output_dir": "/opt/ml/model",
    "per_device_train_batch_size": 24,
    "tokenizer_name": "facebook/esm2_t30_150M_UR50D",
    "dataset_dir": "/opt/ml/input/data/training",
    "torch_compile": True,
    "pad_to_max_length": True,
    "max_seq_length": 512,
}

# Uses the Hugging Face Trainer metric_definitions from the previous cell.
p4_estimator = HuggingFace(
    base_job_name="p4-plm-training",
    distribution={"torch_distributed": {"enabled": True}},
    entry_point="run_mlm.py",
    hyperparameters=hyperparameters,
    instance_count=1,
    instance_type="ml.p4d.24xlarge",
    metric_definitions=metric_definitions,
    pytorch_version="2.0.0",
    py_version="py310",
    role=SAGEMAKER_EXECUTION_ROLE,
    sagemaker_session=sagemaker_session,
    source_dir="scripts/training/cuda",
    transformers_version="4.28.1",
)

# Launching inside a Run context records the training job under EXPERIMENT_NAME.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    p4_estimator.fit(
        {
            # s3_path_join keeps "/" separators on every OS; os.path.join would
            # insert "\" into the S3 URI on Windows.
            "training": TrainingInput(
                s3_data=sagemaker.s3.s3_path_join(PROCESSED_DATA_URI, "arrow"),
                input_mode="File",
            ),
        },
        wait=False,
    )

### 3.2. Torch-NeuronX on ml.trn1.32xlarge

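(Optional) Pre-compile: run a short job (`steps_this_run=64`) with `RUN_NEURON_PARALLEL_COMPILE=1` so that compiled graphs are stored in the S3 Neuron compile cache (`NEURON_COMPILE_CACHE_URL`). The training job below points at the same cache, so let this job finish first.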

In [None]:
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch

# Neuron compile cache. The training cell below points NEURON_COMPILE_CACHE_URL
# at the same location, so graphs compiled here can presumably be reused there.
# s3_path_join keeps "/" separators on every OS (os.path.join would not on Windows).
NEURON_CACHE = sagemaker.s3.s3_path_join(S3_FOLDER, "parallel-neuron-cache")
IMAGE_URI = f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/huggingface-pytorch-training-neuronx:1.13.1-transformers4.34.1-neuronx-py310-sdk2.15.0-ubuntu20.04"
MODEL_ID = "facebook/esm2_t30_150M_UR50D"

# Keep gradient_accumulation_steps / per_device_train_batch_size in sync with
# the training cell below; presumably compiled graphs only match when the
# shapes they depend on agree.
hyperparameters = {
    "data_dir": "/opt/ml/input/data/training",
    "gradient_accumulation_steps": 8,
    "logging_steps": 16,
    "model_id": MODEL_ID,
    # Presumably caps this short pre-compile run at 64 steps -- confirm in
    # torch_xla_train.py.
    "steps_this_run": 64,
    "optim": "adamw_torch",
    "per_device_train_batch_size": 6,
}

trn1_estimator = PyTorch(
    base_job_name="trn1-plm-precompilation",
    entry_point="torch_xla_train.py",
    source_dir="scripts/training/neuron",
    instance_type="ml.trn1.32xlarge",
    instance_count=1,
    image_uri=IMAGE_URI,
    output_path=f"{S3_FOLDER}/output",
    role=SAGEMAKER_EXECUTION_ROLE,
    hyperparameters=hyperparameters,
    sagemaker_session=sagemaker_session,
    distribution={"torch_distributed": {"enabled": True}},
    environment={
        "NEURON_COMPILE_CACHE_URL": NEURON_CACHE,
        "XLA_USE_BF16": "1",
        # Switches the entry point into parallel-compile mode (the only
        # environment difference from the training cell).
        "RUN_NEURON_PARALLEL_COMPILE": "1",
    },
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
)


# Returns immediately. For the training job to benefit from the cache, let
# this job finish before launching the training cell.
trn1_estimator.fit(
    {
        "training": TrainingInput(
            s3_data=sagemaker.s3.s3_path_join(PROCESSED_DATA_URI, "arrow"),
            input_mode="File",
        ),
    },
    wait=False,
)

In [None]:
# SageMaker metric definitions for the log lines printed by the Neuron
# training script. Each regex lazily captures the text after "<label>: " up to
# the terminator: a comma for every field except the last one on the line,
# which runs to end of line.
metric_definitions = [
    {"Name": metric_name, "Regex": f"{log_label}: (.*?){terminator}"}
    for metric_name, log_label, terminator in (
        ("epoch", "Epoch", ","),
        ("step", "Step", ","),
        ("train_loss", "Training Loss", ","),
        ("train_perplexity", "Training Perplexity", ","),
        ("train_samples_per_sec", "Training Samples/sec", ","),
        ("train_tokens_per_sec", "Training Tokens/sec", "$"),
    )
]

In [None]:
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch

# Same Neuron compile cache location as the pre-compilation cell above.
# s3_path_join keeps "/" separators on every OS (os.path.join would not on Windows).
NEURON_CACHE = sagemaker.s3.s3_path_join(S3_FOLDER, "parallel-neuron-cache")
IMAGE_URI = f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/huggingface-pytorch-training-neuronx:1.13.1-transformers4.34.1-neuronx-py310-sdk2.15.0-ubuntu20.04"
MODEL_ID = "facebook/esm2_t30_150M_UR50D"

# Hyperparameters forwarded to torch_xla_train.py. The batch size and gradient
# accumulation match the pre-compilation cell.
hyperparameters = {
    "data_dir": "/opt/ml/input/data/training",
    "gradient_accumulation_steps": 8,
    "logging_steps": 16,
    "model_id": MODEL_ID,
    "num_train_epochs": 1,
    "optim": "adamw_torch",
    "per_device_train_batch_size": 6,
}

# Uses the Neuron-log metric_definitions from the previous cell.
trn1_estimator = PyTorch(
    base_job_name="trn1-plm-training",
    entry_point="torch_xla_train.py",
    source_dir="scripts/training/neuron",
    instance_type="ml.trn1.32xlarge",
    instance_count=1,
    image_uri=IMAGE_URI,
    output_path=f"{S3_FOLDER}/output",
    role=SAGEMAKER_EXECUTION_ROLE,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    # checkpoint_local_path on its own is ignored: SageMaker only configures
    # checkpointing (CheckpointConfig) when checkpoint_s3_uri is set. With both
    # set, files written under the local path are synced to S3 during training
    # and copied back into it when a job starts. The prefix is fixed, so a
    # re-run will find the previous run's files there. Point it elsewhere for
    # a clean start.
    checkpoint_s3_uri=sagemaker.s3.s3_path_join(
        S3_FOLDER, "checkpoints", "trn1-plm-training"
    ),
    checkpoint_local_path="/opt/ml/checkpoints",
    sagemaker_session=sagemaker_session,
    distribution={"torch_distributed": {"enabled": True}},
    environment={
        "NEURON_COMPILE_CACHE_URL": NEURON_CACHE,
        "XLA_USE_BF16": "1",
    },
    tags=[{"Key": "project", "Value": "esm-benchmarking"}],
)

# Launching inside a Run context records the training job under EXPERIMENT_NAME.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    trn1_estimator.fit(
        {
            "training": TrainingInput(
                s3_data=sagemaker.s3.s3_path_join(PROCESSED_DATA_URI, "arrow"),
                input_mode="File",
            ),
        },
        wait=False,
    )