# Fine-Tune the ESM-2 Protein Language Model on Paired Antibody Sequence Data

Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0

Note: We recommend running this notebook on a **ml.m5.large** instance with the **Data Science 3.0** image.

### What is a Protein?

Proteins are complex molecules that are essential for life. The shape and structure of a protein determine what it can do in the body. Knowing how a protein is folded and how it works helps scientists design drugs that target it. For example, if a protein causes disease, a drug might be made to block its function. The drug needs to fit into the protein like a key in a lock. Understanding the protein's molecular structure reveals where drugs can attach. This knowledge helps drive the discovery of innovative new drugs.

![Proteins are made up of long chains of amino acids](img/protein.png)

### What is a Protein Language Model?

Proteins are made up of linear chains of molecules called amino acids, each with its own chemical structure and properties. If we think of each amino acid in a protein as a word in a sentence, we can analyze proteins with methods originally developed for human language. Scientists have trained these so-called "protein language models" (pLMs) on millions of protein sequences from thousands of organisms. With enough data, these models can begin to capture the underlying evolutionary relationships between different amino acid sequences.
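
To make the word analogy concrete, here is a minimal sketch of character-level tokenization, where each amino acid becomes one integer ID. The vocabulary (three hypothetical special tokens followed by the 20 standard amino acids) and the `tokenize` helper are illustrative assumptions, not ESM-2's actual vocabulary or token IDs.

```python
# Illustrative vocabulary: hypothetical special tokens followed by the
# 20 standard amino acids (one-letter codes). These are NOT ESM-2's real IDs.
SPECIAL_TOKENS = ["<cls>", "<pad>", "<eos>"]
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + list(AMINO_ACIDS))}


def tokenize(sequence: str, max_length: int) -> list[int]:
    """Treat each residue as a 'word': <cls> + residues + <eos>, truncated/padded to max_length."""
    ids = [VOCAB["<cls>"]] + [VOCAB[aa] for aa in sequence] + [VOCAB["<eos>"]]
    ids = ids[:max_length]
    return ids + [VOCAB["<pad>"]] * (max_length - len(ids))


print(tokenize("EVQL", max_length=8))  # [0, 6, 20, 16, 12, 2, 1, 1]
```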

It can take a lot of time and compute to train a pLM from scratch for a certain task. For example, a team at Tsinghua University [recently described](https://www.biorxiv.org/content/10.1101/2023.07.05.547496v3) training a 100-billion-parameter pLM on 768 A100 GPUs for 164 days! Fortunately, in many cases we can save time and resources by adapting an existing pLM to our needs. This technique is called "fine-tuning", and it also lets us borrow advanced tools from other types of language modeling, such as the parameter-efficient fine-tuning and quantization methods developed for large text models.

### What is ESM-2?

[ESM-2](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1) is a pLM trained using unsupervised masked language modeling on 250 million protein sequences by researchers at [Facebook AI Research (FAIR)](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1). It is available in several sizes, ranging from 8 million to 15 billion parameters. The smaller models are suitable for various sequence and token classification tasks. The FAIR team also adapted the 3-billion-parameter version into the ESMFold protein structure prediction algorithm. They have since used ESMFold to predict the structure of [more than 700 million metagenomic proteins](https://esmatlas.com/about).
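
Masked language modeling hides a random subset of residues and trains the model to recover them from the surrounding context. The library-free sketch below shows only the core idea; the 15% masking rate, the `<mask>` placeholder, and the `mask_sequence` helper are illustrative assumptions, not ESM-2's exact training recipe. BERT-style recipes typically also replace some selected positions with random tokens or leave them unchanged, as the custom data collator later in this notebook does.

```python
import random

MASK = "<mask>"
IGNORE_INDEX = -100  # label value for positions that do not contribute to the loss


def mask_sequence(residues, mask_prob=0.15, seed=0):
    """Replace a random subset of residues with <mask>; labels keep only the hidden residues."""
    rng = random.Random(seed)
    inputs, labels = [], []
    for aa in residues:
        if rng.random() < mask_prob:
            inputs.append(MASK)
            labels.append(aa)  # the model must predict this residue
        else:
            inputs.append(aa)
            labels.append(IGNORE_INDEX)  # no loss on unmasked residues
    return inputs, labels


inputs, labels = mask_sequence(list("EVQLVESGGGLVQPGGSLRLSCAAS"), mask_prob=0.15)
print(inputs)
print(labels)
```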

ESM-2 is a powerful pLM. However, fine-tuning it has traditionally required multiple A100 GPUs. In this notebook, we demonstrate how to use QLoRA (low-rank adapters trained on top of a 4-bit quantized model) to fine-tune ESM-2 on an inexpensive, single-GPU Amazon SageMaker training instance. We will continue training ESM-2 with masked language modeling on paired heavy- and light-chain antibody sequences from the Observed Antibody Space (OAS), masking residues in the complementarity-determining regions (CDRs) more often than the rest of the sequence.
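
A little arithmetic shows why QLoRA fits on a single GPU. LoRA freezes a pretrained weight matrix $W \in \mathbb{R}^{d \times k}$ and learns a low-rank update $BA$, with $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, so each adapted matrix adds only $r(d + k)$ trainable parameters. QLoRA additionally stores the frozen weights in 4 bits. The sketch below assumes rank $r = 8$ adapters on the query, key, and value projections of the 650M checkpoint (hidden size 1280, 33 layers); the rank and target modules are illustrative choices, not necessarily what `oas_mlm_accelerate.py` uses, and the memory figure ignores activations, optimizer state, and quantization constants.

```python
def lora_trainable_params(d: int, k: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d x k weight: B (d x r) plus A (r x k)."""
    return r * (d + k)


def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory needed to store num_params weights at the given precision, in GB (1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Assumed setup (illustrative): esm2_t33_650M_UR50D, hidden size 1280, 33 layers,
# rank-8 LoRA on the query, key, and value projections of every layer.
HIDDEN, LAYERS, RANK = 1280, 33, 8
per_layer = 3 * lora_trainable_params(HIDDEN, HIDDEN, RANK)
total_lora = LAYERS * per_layer

print(f"LoRA trainable parameters: {total_lora:,}")  # 2,027,520
print(f"bf16 weights:  {weight_memory_gb(650e6, 16):.3f} GB")  # 1.300 GB
print(f"4-bit weights: {weight_memory_gb(650e6, 4):.3f} GB")  # 0.325 GB
```

Under these assumptions, the adapters train about 2 million parameters (roughly 0.3% of the model), while 4-bit storage cuts the frozen weights to a quarter of their bf16 size.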

Note: ESM-2 checkpoint names

- esm2_t48_15B_UR50D
- esm2_t36_3B_UR50D (12,009 MB)
- esm2_t33_650M_UR50D (3,641 MB)
- esm2_t30_150M_UR50D (1,643 MB)
- esm2_t12_35M_UR50D (1,171 MB)
- esm2_t6_8M_UR50D (1,037 MB)
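
The checkpoint names encode the architecture: in `esm2_t33_650M_UR50D`, `t33` is the number of transformer layers, `650M` is the approximate parameter count, and `UR50D` identifies the UniRef50-derived training data. The `parse_checkpoint` helper below is hypothetical (not part of the ESM or Hugging Face APIs) and simply makes that naming convention explicit.

```python
import re

# Pattern for ESM-2 checkpoint names such as "esm2_t33_650M_UR50D".
CHECKPOINT_RE = re.compile(r"esm2_t(?P<layers>\d+)_(?P<size>\d+)(?P<unit>[MB])_(?P<data>\w+)")
UNIT_SCALE = {"M": 10**6, "B": 10**9}


def parse_checkpoint(name: str) -> dict:
    """Split an ESM-2 checkpoint name into layer count, nominal parameter count, and dataset tag."""
    match = CHECKPOINT_RE.fullmatch(name.split("/")[-1])  # tolerate "facebook/..." prefixes
    if match is None:
        raise ValueError(f"Not an ESM-2 checkpoint name: {name}")
    return {
        "layers": int(match["layers"]),
        "params": int(match["size"]) * UNIT_SCALE[match["unit"]],
        "dataset": match["data"],
    }


for name in ["esm2_t6_8M_UR50D", "facebook/esm2_t33_650M_UR50D", "esm2_t48_15B_UR50D"]:
    print(name, parse_checkpoint(name))
```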

## 0. Setup

In [None]:
%pip install -U pip
# s3fs lets `datasets` write directly to s3:// paths in step 1.7
%pip install -U transformers datasets torchinfo accelerate bitsandbytes boto3 sagemaker peft nvidia-ml-py3 s3fs
# %pip install -U --disable-pip-version-check  --no-warn-conflicts -r notebook-requirements.txt

In [None]:
import boto3

# from datasets import Dataset
import datasets
from datasets import load_dataset, Dataset
import os
import pandas as pd
import random
import sagemaker
from sagemaker.experiments.run import Run
from sagemaker.huggingface import HuggingFace
from sagemaker.inputs import TrainingInput
from time import strftime
from transformers import AutoTokenizer

boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session)
S3_BUCKET = sagemaker_session.default_bucket()
s3 = boto_session.client("s3")
sagemaker_client = boto_session.client("sagemaker")
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
REGION_NAME = sagemaker_session.boto_region_name
print(f"Assumed SageMaker role is {sagemaker_execution_role}")

S3_PREFIX = "esm-pair-oas-ft"
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

EXPERIMENT_NAME = "esm-pair-oas-ft-" + strftime("%Y-%m-%d-%H-%M-%S")
print(f"Experiment name is {EXPERIMENT_NAME}")

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    DataCollatorForLanguageModeling,
    get_scheduler,
    EsmForMaskedLM,
    EsmTokenizer,
)

# from torchinfo import summary

## 1. Prepare OAS Paired Sequence Data

### 1.1. Download OAS Paired Sequence Data

In [None]:
PRETRAINED_MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
DATASET_NAME = "bloyal/oas-paired-sequence-data"
DATASET_CONFIG = "rat_SD"
MAX_SEQ_LENGTH = 256

In [None]:
raw_datasets = load_dataset(DATASET_NAME, DATASET_CONFIG)
raw_datasets["train"][:3]

### 1.2. Remove duplicate Heavy Chain CDR3 sequences

In [None]:
df = raw_datasets["train"].to_pandas()
df = df.drop_duplicates(["cdr3_aa_heavy"], ignore_index=True)
print(df.head())
df = datasets.Dataset.from_pandas(df)

### 1.3. Split into train-validation-test data

In [None]:
# 80% train, 20% test + validation
train_test = df.train_test_split(test_size=0.2)
# Split the 20% test + validation set in half: 10% validation, 10% test
test_valid = train_test["test"].train_test_split(test_size=0.5)
df = datasets.DatasetDict(
    {
        "train": train_test["train"],
        "validation": test_valid["train"],
        "test": test_valid["test"],
    }
)
print(df)
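
The two-stage 80/10/10 split above can be sketched without `datasets` as follows. The `three_way_split` helper, its rounding, and the fixed seed are illustrative assumptions, so exact set sizes may differ slightly from `train_test_split`. Passing a `seed` to `train_test_split` in the cell above would likewise make the split reproducible across runs.

```python
import random


def three_way_split(items, holdout_frac=0.2, seed=42):
    """Shuffle, hold out holdout_frac of the items, then split the holdout in half."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_holdout = round(len(shuffled) * holdout_frac)
    train, holdout = shuffled[n_holdout:], shuffled[:n_holdout]
    half = len(holdout) // 2
    return {"train": train, "validation": holdout[:half], "test": holdout[half:]}


splits = three_way_split(range(1000))
print({name: len(rows) for name, rows in splits.items()})  # {'train': 800, 'validation': 100, 'test': 100}
```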

### 1.4. Tokenize sequence data

In [None]:
tokenizer = transformers.EsmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)


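def get_cdr_mask(examples, max_length=256):
    """Build a per-token mask that is 1 at CDR residues and 0 everywhere else.

    Relies on the dataset's column order: columns 1-4 hold the heavy chain sequence
    and its CDR1-3, columns 5-8 hold the light chain sequence and its CDR1-3.
    """
    cdr_mask = []
    for example in zip(*examples.values()):
        example_mask = [0]  # <cls> token
        for chain in (example[1:5], example[5:9]):
            seq = chain[0]
            chain_mask = [0] * len(seq)
            for i in range(1, 4):
                cdr_start = seq.find(chain[i])
                if cdr_start == -1:
                    # CDR not found verbatim in the aligned sequence; leave it unmasked
                    # (a -1 index would otherwise corrupt the mask length)
                    continue
                cdr_len = len(chain[i])
                chain_mask = (
                    chain_mask[:cdr_start]
                    + [1] * cdr_len
                    + chain_mask[(cdr_start + cdr_len) :]
                )
            example_mask += chain_mask + [0]  # <eos> after each chain
        # Pad with zeros, then truncate to the tokenized length
        example_mask = (example_mask + [0] * max_length)[:max_length]
        cdr_mask.append(example_mask)

    return cdr_mask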


def tokenize_function(examples):
    tokenized_data = tokenizer(
        examples["sequence_alignment_aa_heavy"],
        examples["sequence_alignment_aa_light"],
        return_special_tokens_mask=True,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )
    tokenized_data["cdr_mask"] = get_cdr_mask(examples, max_length=MAX_SEQ_LENGTH)
    return tokenized_data


tokenized_datasets = df.map(
    tokenize_function,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=raw_datasets["train"].column_names,
    desc="Creating and tokenizing paired sequences",
)
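
To see how the CDR mask lines up with the tokenized pair, here is a toy, library-free version of the same logic. The ESM tokenizer encodes a sequence pair as `<cls> heavy <eos> light <eos>`, so the mask carries a 0 at each of those special positions and at every padding position. The sequences and CDRs below are made up for illustration, and the `toy_cdr_mask` helper uses slice assignment in place of the list concatenation in `get_cdr_mask`.

```python
def toy_cdr_mask(heavy, heavy_cdrs, light, light_cdrs, max_length):
    """1 marks CDR residues; 0 marks framework residues, special tokens, and padding."""
    mask = [0]  # <cls>
    for seq, cdrs in ((heavy, heavy_cdrs), (light, light_cdrs)):
        chain_mask = [0] * len(seq)
        for cdr in cdrs:
            start = seq.find(cdr)
            if start == -1:
                continue  # CDR not found verbatim; leave it unmasked
            chain_mask[start : start + len(cdr)] = [1] * len(cdr)
        mask += chain_mask + [0]  # <eos> after each chain
    return (mask + [0] * max_length)[:max_length]


# Made-up chains: heavy "QVGFTLIDYK" with CDRs GFT/ID/K, light "DIQSNWP" with CDRs Q/SN/P
mask = toy_cdr_mask("QVGFTLIDYK", ["GFT", "ID", "K"], "DIQSNWP", ["Q", "SN", "P"], 24)
print(mask)
# [0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
```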

### 1.5. Validate a random sample from the training set

In [None]:
for index in random.sample(range(len(df["train"])), 1):
    decoded_seq = tokenizer.decode(tokenized_datasets["train"][index]["input_ids"])
    print(
        f"""
Sample {index} of the training set:\n
original_sequence:
{decoded_seq}\n
original_cdrs:
{
    [
        df['train'][index]['cdr1_aa_heavy'],
        df['train'][index]['cdr2_aa_heavy'],
        df['train'][index]['cdr3_aa_heavy'],
        df['train'][index]['cdr1_aa_light'],
        df['train'][index]['cdr2_aa_light'], 
        df['train'][index]['cdr3_aa_light']
    ]}\n
{'#'*50}\n
sequence_length:
{len(tokenized_datasets['train'][index]['input_ids'])}\n
input_ids:
{tokenized_datasets['train'][index]['input_ids']}\n
special_tokens_mask:
{tokenized_datasets['train'][index]['special_tokens_mask']}\n
attention_mask:
{tokenized_datasets['train'][index]['attention_mask']}\n
cdr_mask:
{tokenized_datasets['train'][index]['cdr_mask']}\n
Decoded CDRs:"""
    )
    print(
        [
            tokenizer.decode(seq) if cdr == 1 else 0
            for seq, cdr in zip(
                tokenized_datasets["train"][index]["input_ids"],
                tokenized_datasets["train"][index]["cdr_mask"],
            )
        ]
    )

### 1.6. Group pairs of sequences into chunks of 512 tokens

In [None]:
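def chunk_seqs(examples):
    """Concatenate consecutive examples (0+1, 2+3, ...) column by column.

    Each input row is padded to MAX_SEQ_LENGTH (256) tokens, so each output row has 512.
    """
    result = {
        k: [x + y for x, y in zip(examples[k][::2], examples[k][1::2])]
        for k in examples.keys()
    }

    return result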


tokenized_datasets = tokenized_datasets.map(
    chunk_seqs,
    batched=True,
    num_proc=os.cpu_count(),
    desc="Combining pairs of tokenized sequences.",
)
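
The pairing step can be checked on a toy batch. The `chunk_pairs` function below repeats the logic of `chunk_seqs` on made-up token IDs. Note that when a batch holds an odd number of examples, `zip` silently drops the last one; because `map` runs with `batched=True` (1,000 examples per batch by default), this can discard a few examples per batch or per worker shard.

```python
def chunk_pairs(batch):
    """Concatenate examples 0+1, 2+3, ... for every column (same logic as chunk_seqs)."""
    return {
        key: [x + y for x, y in zip(values[::2], values[1::2])]
        for key, values in batch.items()
    }


toy_batch = {
    "input_ids": [[0, 5, 2], [0, 6, 2], [0, 7, 2]],
    "attention_mask": [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
}
print(chunk_pairs(toy_batch))
# {'input_ids': [[0, 5, 2, 0, 6, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1]]}
# The third example has no partner and is dropped.
```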

In [None]:
for index in random.sample(range(len(tokenized_datasets["train"])), 1):
    print(
        f"""
Sample {index} of the training set:\n

sequence_length:
{len(tokenized_datasets['train'][index]['input_ids'])}\n
input_ids:
{tokenized_datasets['train'][index]['input_ids']}\n
special_tokens_mask:
{tokenized_datasets['train'][index]['special_tokens_mask']}\n
attention_mask:
{tokenized_datasets['train'][index]['attention_mask']}\n
cdr_mask:
{tokenized_datasets['train'][index]['cdr_mask']}\n
Decoded CDRs:"""
    )
    print(
        [
            tokenizer.decode(seq) if cdr == 1 else 0
            for seq, cdr in zip(
                tokenized_datasets["train"][index]["input_ids"],
                tokenized_datasets["train"][index]["cdr_mask"],
            )
        ]
    )

### 1.7. Save encoded data to S3

In [None]:
tokenized_datasets["train"].save_to_disk(S3_PATH + "/data/train")
tokenized_datasets["validation"].save_to_disk(S3_PATH + "/data/validation")
tokenized_datasets["test"].save_to_disk(S3_PATH + "/data/test")

tokenized_datasets["train"].save_to_disk("data/train")
tokenized_datasets["validation"].save_to_disk("data/validation")

## 2. Define a custom data collator

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Mapping
from transformers.data.data_collator import _torch_collate_batch


@dataclass
class DataCollatorForCDRLanguageModeling(DataCollatorForLanguageModeling):
    cdr_probability: float = 0.3  # New attribute: masking probability for CDR tokens

    def torch_mask_tokens(
        self,
        inputs: Any,
        special_tokens_mask: Optional[Any] = None,
        cdr_mask: Optional[Any] = None,  # New parameter
    ) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        Tokens where `cdr_mask` is 1 are selected with probability `self.cdr_probability`
        instead of `self.mlm_probability`.
        """
        import torch

        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(
                    val, already_has_special_tokens=True
                )
                for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # New Code ###########################
        if cdr_mask is not None:
            probability_matrix.masked_fill_(cdr_mask.bool(), value=self.cdr_probability)
        # ###################################

        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token
        )

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self.tokenizer), labels.shape, dtype=torch.long
        )
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def torch_call(
        self, examples: List[Union[List[int], Any, Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """
        Handle dict or lists with proper padding and conversion to tensor.
        """

        if isinstance(examples[0], Mapping):
            batch = self.tokenizer.pad(
                examples,
                return_tensors="pt",
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
        else:
            batch = {
                "input_ids": _torch_collate_batch(
                    examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of
                )
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        cdr_mask = batch.pop("cdr_mask", None)  # New code
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                batch["input_ids"],
                special_tokens_mask=special_tokens_mask,
                cdr_mask=cdr_mask,  # New code
            )
        else:
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch
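
Put differently, the collator gives each token a selection probability: 0 for special and padding tokens, `mlm_probability` for framework residues, and `cdr_probability` for CDR residues. The library-free sketch below reproduces that probability row for one toy sequence and computes the expected number of selected tokens. The `masking_probabilities` helper and the toy masks are illustrative; the real collator works on whole batches of tensors.

```python
def masking_probabilities(special_tokens_mask, cdr_mask, mlm_probability, cdr_probability):
    """Per-token selection probability, mirroring the order of operations in torch_mask_tokens."""
    probs = []
    for is_special, is_cdr in zip(special_tokens_mask, cdr_mask):
        p = 0.0 if is_special else mlm_probability  # special/padding tokens start at 0
        if is_cdr:
            p = cdr_probability  # CDR positions override the default rate
        probs.append(p)
    return probs


# Toy sequence: <cls>, 8 residues (3 of them in a CDR), <eos>, 2 padding tokens
special = [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
cdr = [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
probs = masking_probabilities(special, cdr, mlm_probability=0.1, cdr_probability=0.3)
print(probs)
print(f"Expected selected tokens: {sum(probs):.2f}")  # 3 * 0.3 + 5 * 0.1 = 1.40
```

With `cdr_probability=0.3` and `mlm_probability=0.1`, each CDR residue is three times as likely to be selected as a framework residue, which concentrates the training signal on the hypervariable regions.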

In [None]:
import torch
from torch.utils.data import DataLoader

data_collator = DataCollatorForCDRLanguageModeling(
    tokenizer=tokenizer, cdr_probability=0.3, mlm_probability=0.1
)

train_dataloader = DataLoader(
    tokenized_datasets["test"], shuffle=True, collate_fn=data_collator, batch_size=3
)

torch.set_printoptions(threshold=10_000)
batch = next(iter(train_dataloader))

print("---Example batch---")
print(f"input_ids:\n{batch['input_ids']}")
print(f"attention_mask:\n{batch['attention_mask']}")
print(f"labels:\n{batch['labels']}")

## 3. Train

### 3.1. (Optional) Test the training loop locally

In [None]:
# !SM_CHANNEL_TRAIN="data/train" \
#   SM_CHANNEL_VALIDATION="data/validation" \
#   SM_MODEL_DIR="data/output" \
#   python scripts/oas_mlm_accelerate.py \
#   --model_name_or_path="facebook/esm2_t6_8M_UR50D" \
#   --output_dir="output" \
#   --mixed_precision="bf16" \
#   --max_train_steps=64 \
#   --lora=True \
#   --use_gradient_checkpointing=True \
#   --quantization="8bit"

### 3.2. Submit SageMaker training job

**esm2_t6_8M_UR50D**

- bf16 only: 2.48, 0.70 GB, 169 sec
- LoRA: 2.61, 0.62 GB, 172 sec
- LoRA + 8bit quant: 2.64, 0.89 GB, 389 sec.
- LoRA + 4bit quant: 2.65, 0.83 GB, 236 sec

**esm2_t30_150M_UR50D**

- bf16 only: 1.57, 4.99 GB, 1148 sec
- LoRA: 1.61, 3.74 GB, 1055 sec
- LoRA + 8bit quant: 
- LoRA + 4bit quant: 1.62, 6.3 GB, 1359 sec

In [None]:
# Additional training parameters
hyperparameters = {
    "model_name_or_path": "facebook/esm2_t33_650M_UR50D",
    "num_train_epochs": 1,
    # "max_train_steps": 64,
    "mixed_precision": "bf16",
    "lora": True,
    "use_gradient_checkpointing": True,
    "quantization": "4bit",
}

# creates Hugging Face estimator
huggingface_estimator = sagemaker.huggingface.HuggingFace(
    base_job_name="esm-oas-mlm-lora-gc-4bit",
    entry_point="oas_mlm_accelerate.py",
    source_dir="scripts",
    instance_type="ml.g5.2xlarge",
    instance_count=1,
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    output_path=f"{S3_PATH}/output",
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    # metric_definitions=metric_definitions,
    checkpoint_local_path="/opt/ml/checkpoints",
    sagemaker_session=sagemaker_session,
    keep_alive_period_in_seconds=1800,
    tags=[{"Key": "project", "Value": "esm-ft"}],
)

with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    huggingface_estimator.fit(
        {
            "train": TrainingInput(
                s3_data=S3_PATH + "/data/train", input_mode="FastFile"
            ),
            "validation": TrainingInput(
                s3_data=S3_PATH + "/data/validation", input_mode="FastFile"
            ),
        },
        wait=False,
    )