# Deploy pre-trained ESM-2 model to Inferentia2

Note: This notebook was last tested in SageMaker Studio on the PyTorch 1.13 Python 3.9 CPU Optimized image on an ml.m5.large instance.

---
## 1. Install neuronx and dependencies

In [None]:
%%sh

# Install the AWS Neuron system packages from the Neuron apt repository.

# gpg-agent is installed before the repository signing key is imported.
apt-get install gpg-agent -y
# Download the Neuron repository public key and hand it to apt-key.
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
# Register the Neuron repository as a package source and refresh the package index.
add-apt-repository https://apt.repos.neuron.amazonaws.com
apt-get update -y
# Install the dkms, collectives, runtime-lib and tools packages, each constrained to a 2.x version.
apt-get install aws-neuronx-dkms=2.* aws-neuronx-collectives=2.* aws-neuronx-runtime-lib=2.* aws-neuronx-tools=2.* -y

In [None]:
%pip install -q --upgrade pip
%pip install -q --upgrade --extra-index-url https://pip.repos.neuron.amazonaws.com \
  neuronx-cc==2.* torch-neuronx torch sagemaker boto3 awscli transformers accelerate boto3 --no-cache

---
## 2. Compile pretrained model

In [None]:
import torch
import torch_neuronx
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig
import timeit
from timeit import default_timer as timer

# Reference point for the total wall-clock time reported at the end of this cell.
start_time = timer()

print("Preparing example input")
# Protein sequence with a single <mask> token. It is tokenized to exactly
# `max_length` tokens (padded or truncated) and used as the example input for tracing.
sequence = "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"
max_length = 128
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
tokenized_sequence = tokenizer.encode_plus(
    sequence,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=max_length,
)
# Positional (input_ids, attention_mask) pair handed to the tracer.
tracing_input = (
    tokenized_sequence["input_ids"],
    tokenized_sequence["attention_mask"],
)

print("Loading pre-trained model")
model = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t12_35M_UR50D",
    return_dict=False,
)
model.eval()
model_load_end_time = timer()

print("Beginning model trace")
# Trace the model with the example input; the result is what later cells save and package.
neuron_model = torch_neuronx.trace(model, tracing_input)
neuron_model.eval()
model_trace_end_time = timer()
print(f"Model trace completed in {round(model_trace_end_time - model_load_end_time, 3)} seconds.")

end_time = timer()
print(f"Process completed in {round(end_time - start_time, 3)} seconds.")

---
## 3. Assemble model package

In [None]:
import boto3
import sagemaker

# Build the boto3 session from the default credential chain rather than a named
# local CLI profile, so the cell runs in SageMaker Studio and for other users.
# (boto3 still honours the AWS_PROFILE environment variable if a profile is wanted.)
# The region stays pinned to us-west-2 because the inference container image used
# in the deployment section is pulled from the us-west-2 registry.
boto_session = boto3.session.Session(region_name="us-west-2")
sagemaker_session = sagemaker.session.Session(boto_session)
S3_BUCKET = sagemaker_session.default_bucket()
sagemaker_client = boto_session.client("sagemaker")
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
print(f"Assumed SageMaker role is {sagemaker_execution_role}")

# S3 location (default bucket + prefix) under which the model archive is stored.
S3_PREFIX = "compiled-model"
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

In [None]:
# Serialize the traced module to disk under the file name the inference script's model_fn loads.
torch.jit.save(neuron_model, "neuron_compiled_model.pt")

In [None]:
# Package the compiled model into model.tar.gz (gzip-compressed, verbose listing),
# then delete the loose .pt file that is now inside the archive.
!tar -czvf model.tar.gz neuron_compiled_model.pt
!rm neuron_compiled_model.pt

In [None]:
# Upload the archive to the session's default bucket under S3_PREFIX and keep the
# returned S3 URI for deployment. A hard-coded URI would point at another account's
# bucket and would make the message below untrue.
s3_model_uri = sagemaker_session.upload_data("model.tar.gz", S3_BUCKET, S3_PREFIX)
print(f"Model artifact uploaded to {s3_model_uri}")

In [None]:
# Delete the local model.tar.gz from the working directory.
!rm 'model.tar.gz'

---
## 4. Define inference script

In [36]:
%%writefile scripts/inference/inf2/requirements.txt

# Index options apply to the whole file and must be on their own line;
# appended to a requirement line they are not applied as an index setting.
--extra-index-url https://pip.repos.neuron.amazonaws.com
transformers
torch
torch-neuronx
accelerate
neuronx-cc==2.*

Overwriting scripts/inf2-deploy/requirements.txt


In [37]:
%%writefile scripts/inference/inf2/inference.py

import os
import json
import torch
from transformers import AutoTokenizer

JSON_CONTENT_TYPE = "application/json"


def model_fn(model_dir):
    """Load the compiled Neuron model and the matching tokenizer.

    Args:
        model_dir: Directory that contains ``neuron_compiled_model.pt``.

    Returns:
        A ``(model, tokenizer)`` tuple: the TorchScript module loaded from
        ``model_dir`` and the ``facebook/esm2_t12_35M_UR50D`` tokenizer
        obtained via ``AutoTokenizer.from_pretrained``.
    """
    # Imported only for its side effect: a module traced with torch_neuronx needs
    # the Neuron extension loaded before torch.jit.load can deserialize it. The
    # missing import is the likely cause of the load errors previously noted here.
    import torch_neuronx  # noqa: F401

    tokenizer_init = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
    model_file = os.path.join(model_dir, "neuron_compiled_model.pt")
    model_neuron = torch.jit.load(model_file)
    return (model_neuron, tokenizer_init)


def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
    """Deserialize the request payload.

    Args:
        serialized_input_data: JSON-encoded request body.
        content_type: Content type of the request; only ``application/json``
            is supported.

    Returns:
        The value stored under the ``"inputs"`` key when the decoded payload is
        a dict containing it; otherwise the decoded payload itself.

    Raises:
        Exception: If ``content_type`` is not ``application/json``.
    """
    if content_type != JSON_CONTENT_TYPE:
        # str() keeps the message well-formed even if content_type is None.
        raise Exception("Requested unsupported ContentType: " + str(content_type))

    input_data = json.loads(serialized_input_data)
    # Accept both {"inputs": "<sequence>"} and a bare JSON value.
    if isinstance(input_data, dict):
        return input_data.get("inputs", input_data)
    return input_data


def predict_fn(input_data, model_and_tokenizer):
    """Score every vocabulary token at the first masked position.

    Args:
        input_data: Sequence string containing the tokenizer's mask token.
        model_and_tokenizer: ``(model, tokenizer)`` tuple as returned by
            ``model_fn``.

    Returns:
        Dict mapping each vocabulary token to the sigmoid of its logit at the
        first masked position, rounded to 3 decimals.

    Raises:
        ValueError: If no mask token is present after tokenization/truncation.

    NOTE(review): sigmoid yields independent per-token scores that do not sum
    to 1; if a probability distribution is intended, softmax would be needed -
    confirm the intent before changing.
    """
    model_neuron, tokenizer = model_and_tokenizer
    # Same length the notebook uses for the example input when tracing the model.
    max_length = 128
    tokenized_sequence = tokenizer.encode_plus(
        input_data,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    prediction_input = (
        tokenized_sequence["input_ids"],
        tokenized_sequence["attention_mask"],
    )
    # Call the model that was passed in (the original referenced an undefined global).
    output = model_neuron(*prediction_input)[0]
    mask_token_index = (tokenized_sequence.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
    if mask_token_index.numel() == 0:
        raise ValueError(
            f"Input must contain a mask token within the first {max_length} tokens"
        )
    mask_index_predictions = output[0, mask_token_index]
    probs = torch.sigmoid(mask_index_predictions)
    # Build the token list once instead of once per vocabulary entry.
    vocab_tokens = list(tokenizer.get_vocab().keys())
    return {
        vocab_tokens[idx]: round(v.item(), 3)
        for idx, v in enumerate(probs[0])
    }


def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
    """Serialize the prediction for the response.

    Args:
        prediction_output: JSON-serializable prediction result.
        accept: Requested response content type; only ``application/json``
            is supported.

    Returns:
        A ``(body, content_type)`` tuple with the JSON-encoded prediction.

    Raises:
        Exception: If ``accept`` is not ``application/json``.
    """
    # Guard clause: reject anything other than JSON up front.
    if accept != JSON_CONTENT_TYPE:
        raise Exception("Requested unsupported ContentType in Accept: " + accept)

    response_body = json.dumps(prediction_output)
    return response_body, accept


Overwriting scripts/inf2-deploy/inference.py


---
## 5. Deploy model endpoint

In [38]:
from sagemaker.pytorch.model import PyTorchModel

# PyTorch Neuron (neuronx) inference container from the us-west-2 registry,
# the same region the boto session is created in.
ecr_image = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.12.0-ubuntu20.04"

pytorch_model = PyTorchModel(
    model_data=s3_model_uri,
    role=sagemaker_execution_role,
    sagemaker_session=sagemaker_session,
    # Must be the directory the %%writefile cells write inference.py and
    # requirements.txt into; the previous value pointed at a different path.
    source_dir="scripts/inference/inf2",
    entry_point="inference.py",
    image_uri=ecr_image,
)

# Let SageMaker know that we've already compiled the model via neuronx-cc
pytorch_model._is_compiled_model = True

In [39]:
%%time

predictor = pytorch_model.deploy(
    instance_type="ml.inf2.xlarge", initial_instance_count=1
)

----------------------------------------------!CPU times: user 3.01 s, sys: 884 ms, total: 3.89 s
Wall time: 24min 1s


In [None]:
# Exchange JSON with the endpoint: requests are serialized to JSON and
# responses are parsed from JSON.
predictor.serializer = sagemaker.serializers.JSONSerializer()
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

# Test protein sequence containing one <mask> token. It is sent under the
# "inputs" key, which the inference script's input_fn extracts from the payload.
test_seq = "MAAAVVLAAGLRAARRAVAATGVRGGQVRGAAGVT<mask>GNEVAKAQQATPGGAAPTIFSRILDKSLPADILYEDQQCLVFRDVAPQAPVHFLVIPKKPIPRISQAEEEDQQLLGHLLLVAKQTAKAEGLGDGYRLVINDGKLGAQSVYHLHIHVLGGRQLQWPPG"
sample = {"inputs": test_seq}
# Last expression of the cell, so the endpoint's response is displayed as the cell output.
predictor.predict(sample)

In [34]:
# Best-effort cleanup: the notebook keeps running if deletion fails, but the failure
# is reported so a still-running (billable) endpoint does not go unnoticed.
try:
    predictor.delete_endpoint()
except Exception as err:
    print(f"Endpoint was not deleted: {err!r}")