# Deploy pre-trained ESM-2 model to Inferentia2

Note: This notebook was last tested in SageMaker Studio on the PyTorch 1.13 Python 3.9 CPU Optimized image on an ml.c5.4xlarge instance.

---
## 1. Install neuronx and dependencies

Install the Neuron system packages and the neuronx compiler. NOTE: You will need to restart your notebook kernel after running the installation cells.

In [None]:
%%sh
# Refresh the package index and install gpg-agent ahead of the repository key import.
apt-get update -y
apt-get install gpg-agent -y
# Add the AWS Neuron apt signing key, then register the Neuron apt repository.
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
add-apt-repository https://apt.repos.neuron.amazonaws.com
# Re-read the package index so the Neuron packages are visible, then install the
# 2.x driver, collectives, runtime library, and tools packages.
apt-get update -y
apt-get install aws-neuronx-dkms=2.* aws-neuronx-collectives=2.* aws-neuronx-runtime-lib=2.* aws-neuronx-tools=2.* -y

In [None]:
%pip install -q --upgrade pip
%pip install -q --upgrade --extra-index-url https://pip.repos.neuron.amazonaws.com \
  neuronx-cc==2.* torch-neuronx torch sagemaker boto3 awscli transformers accelerate boto3 --no-cache

---
## 2. Compile pretrained model

In [26]:
import torch
import torch_neuronx
from transformers import AutoTokenizer, AutoModelForMaskedLM
import timeit
from timeit import default_timer as timer

MODEL_ID = "facebook/esm2_t6_8M_UR50D"

# Load the pretrained tokenizer and masked-LM model, then switch the model to eval mode.
# torchscript=True is the Hugging Face option intended for models that will be traced.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID, torchscript=True)
model.eval()

# Example protein sequence with a single masked residue, used only to build the tracing input.
sequence = "QVQLVESGGGVVQPRSLTLSCAASGFTFSSYGL<mask>HWVRQAPGKGLEWVANIWYDGANKYYGDSVKGRFTISRDNSRNTLYLQMNSLTAEDTAVYYCARWIEYGSGKDAFDVWGQGTMVIVSS"

# Inputs are padded/truncated to a fixed length; the inference script uses the same value.
max_length = 128
tokenized_sequence = tokenizer.encode_plus(
    sequence,
    truncation=True,
    padding="max_length",
    max_length=max_length,
    return_tensors="pt",
)
tracing_input = (
    tokenized_sequence["input_ids"],
    tokenized_sequence["attention_mask"],
)

# Sanity check: run the untraced model once on the example input and show its first output.
print("Testing model inference")
print(model(*tracing_input)[0])

# Trace the model with torch_neuronx and save the traced artifact for packaging.
print("Beginning model trace")
model_trace_start_time = timer()
neuron_model = torch_neuronx.trace(model, tracing_input)
neuron_model.save("traced_esm.pt")
trace_seconds = round(timer() - model_trace_start_time, 0)
print(f"Model trace completed in {trace_seconds} seconds.")


Some weights of EsmForMaskedLM were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['lm_head.decoder.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Testing model inference
tensor([[[ 12.4939,  -7.4286,  -6.1493,  ..., -15.5151, -15.7623,  -7.3883],
         [ -7.1894, -15.7947,  -7.4696,  ..., -15.7256, -15.9145, -15.7955],
         [-10.4963, -17.5363, -10.8571,  ..., -16.3073, -16.2947, -17.5346],
         ...,
         [-10.6348, -18.9364,  -9.7578,  ..., -16.1243, -16.1494, -18.9432],
         [-10.5313, -18.7101,  -9.6842,  ..., -16.1191, -16.1363, -18.7131],
         [-10.5828, -18.6295,  -9.8384,  ..., -16.1243, -16.1400, -18.6290]]],
       grad_fn=<AddBackward0>)
Beginning model trace
Model trace completed in 39.0 seconds.


---
## 3. Assemble model package

In [27]:
# Bundle the traced model into model.tar.gz, the archive uploaded to S3 in the next cell.
!tar -czvf model.tar.gz traced_esm.pt

traced_esm.pt


In [None]:
import boto3
import sagemaker

# Sessions and clients used for the rest of the notebook.
boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session)
sagemaker_client = boto_session.client("sagemaker")

# Default SageMaker bucket and the prefix the compiled model is stored under.
S3_BUCKET = sagemaker_session.default_bucket()
S3_PREFIX = "compiled-model"
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)

# IAM role the endpoint will run as.
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
print(f"Assumed SageMaker role is {sagemaker_execution_role}")
print(f"S3 path is {S3_PATH}")

# Upload the packaged model and keep its S3 URI for the deployment step.
s3_model_uri = sagemaker_session.upload_data(
    path="model.tar.gz",
    bucket=S3_BUCKET,
    key_prefix=S3_PREFIX,
)
print(f"Model artifact uploaded to {s3_model_uri}")

---
## 4. Define inference script

In [30]:
%%writefile scripts/inference/requirements.txt

# Needed by inference.py, which imports AutoTokenizer from transformers.
transformers

Overwriting scripts/inference/inf2/requirements.txt


In [34]:
%%writefile scripts/inference/inference.py

import os
import json
import torch
import torch_neuronx
from transformers import AutoTokenizer

JSON_CONTENT_TYPE = "application/json"
MODEL_ID = "facebook/esm2_t6_8M_UR50D"


def model_fn(model_dir):
    """Load the traced Neuron model and the matching tokenizer.

    Args:
        model_dir: Directory holding the unpacked model artifact; it must
            contain ``traced_esm.pt``.

    Returns:
        A ``(neuron_model, tokenizer)`` tuple.
    """
    # The tokenizer is resolved by its Hugging Face model id. ``device_map`` is a
    # model-loading option and has no meaning for a tokenizer, so it is not passed.
    tokenizer_init = AutoTokenizer.from_pretrained(MODEL_ID)
    # The model itself comes from the TorchScript file shipped in the model archive.
    model_file = os.path.join(model_dir, "traced_esm.pt")
    neuron_model = torch.jit.load(model_file)
    return (neuron_model, tokenizer_init)


def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
    """Deserialize the request payload.

    Args:
        serialized_input_data: Raw request body, parsed with ``json.loads``.
        content_type: Content type of the request body. Only
            ``application/json`` is supported.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If ``content_type`` is not ``application/json``.
    """
    if content_type != JSON_CONTENT_TYPE:
        # This is the request body's content type, not the Accept header.
        raise ValueError(f"Unsupported request Content-Type: {content_type}")
    return json.loads(serialized_input_data)


def predict_fn(input_data, model_and_tokenizer):
    """Score every vocabulary token at the first masked position of a sequence.

    Args:
        input_data: Sequence string to tokenize; it should contain a mask token.
        model_and_tokenizer: ``(model, tokenizer)`` tuple.

    Returns:
        Dict mapping each vocabulary token to the sigmoid of its logit at the
        first masked position, rounded to 3 decimal places. Only the first
        masked position is reported if the input contains several.

    Raises:
        ValueError: If no mask token is present after tokenization/truncation.
    """

    model, tokenizer = model_and_tokenizer
    # Inputs are padded/truncated to a fixed length before being passed to the model.
    max_length = 128
    tokenized_sequence = tokenizer.encode_plus(
        input_data,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokenized_sequence["input_ids"]
    attention_mask = tokenized_sequence["attention_mask"]

    # Locate masked positions in the (single) input row and fail clearly when
    # there are none, instead of hitting an opaque IndexError further down.
    mask_token_index = (input_ids == tokenizer.mask_token_id)[0].nonzero(
        as_tuple=True
    )[0]
    if mask_token_index.numel() == 0:
        raise ValueError(
            f"No mask token found in the first {max_length} tokens of the input sequence."
        )

    output = model(input_ids, attention_mask)[0]
    mask_index_predictions = output[0, mask_token_index]
    # NOTE(review): sigmoid scores each token independently, so the values do
    # not sum to 1; switch to softmax if a probability distribution is wanted.
    probs = torch.sigmoid(mask_index_predictions)

    # Build the id -> token lookup once rather than once per vocabulary entry.
    # Assumes get_vocab() iterates in token-id order -- TODO confirm.
    vocab_tokens = list(tokenizer.get_vocab())
    return {
        vocab_tokens[idx]: round(score, 3)
        for idx, score in enumerate(probs[0].tolist())
    }


def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
    """Serialize the prediction for the response.

    Args:
        prediction_output: JSON-serializable prediction result.
        accept: Response content type requested by the client.

    Returns:
        A ``(body, content_type)`` tuple with the JSON-encoded prediction.

    Raises:
        Exception: If ``accept`` is anything other than JSON.
    """
    if accept != JSON_CONTENT_TYPE:
        raise Exception("Requested unsupported ContentType in Accept: " + accept)

    response_body = json.dumps(prediction_output)
    return response_body, accept


Overwriting scripts/inference/inf2/inference.py


---
## 5. Deploy model endpoint

In [32]:
from sagemaker.pytorch.model import PyTorchModel

# Neuronx PyTorch inference container image, resolved for the session's region.
ecr_image = f"763104351884.dkr.ecr.{sagemaker_session.boto_region_name}.amazonaws.com/pytorch-inference-neuronx:1.13.0-neuronx-py38-sdk2.9.0-ubuntu20.04"

# Combine the uploaded model artifact with the inference script directory.
pytorch_model = PyTorchModel(
    image_uri=ecr_image,
    model_data=s3_model_uri,
    entry_point="inference.py",
    source_dir="scripts/inference",
    role=sagemaker_execution_role,
    sagemaker_session=sagemaker_session,
)

# Let SageMaker know that the model has already been compiled (traced with torch_neuronx above)
pytorch_model._is_compiled_model = True

In [33]:
%%time

predictor = pytorch_model.deploy(
    instance_type="ml.inf2.xlarge", initial_instance_count=1
)

-------------------!CPU times: user 727 ms, sys: 66.7 ms, total: 794 ms
Wall time: 10min 4s


In [42]:
# Sequence with one masked residue for the endpoint to score.
test_seq = "QVQLVESGGGVVQ<mask>PGRSLTLSCAASGFTFSSYGLHWVRQAPGKGLE"

# Send and receive JSON, matching the content type handled by the inference script.
predictor.serializer = sagemaker.serializers.JSONSerializer()
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

predictor.predict(test_seq)

{'<cls>': 0.0,
 '<pad>': 0.0,
 '<eos>': 0.0,
 '<unk>': 0.0,
 'L': 0.746,
 'A': 0.83,
 'G': 0.73,
 'V': 0.764,
 'S': 0.672,
 'E': 0.809,
 'R': 0.711,
 'T': 0.618,
 'I': 0.362,
 'D': 0.546,
 'P': 0.757,
 'K': 0.531,
 'Q': 0.731,
 'N': 0.355,
 'F': 0.354,
 'Y': 0.28,
 'M': 0.324,
 'H': 0.422,
 'W': 0.216,
 'C': 0.18,
 'X': 0.015,
 'B': 0.0,
 'U': 0.0,
 'Z': 0.0,
 'O': 0.0,
 '.': 0.0,
 '-': 0.0,
 '<null_1>': 0.0,
 '<mask>': 0.0}

In [None]:
# Best-effort cleanup: a failure here must not abort the notebook, but it is
# reported so a still-running endpoint does not go unnoticed.
try:
    predictor.delete_endpoint()
except Exception as err:
    print(f"Endpoint deletion failed; delete it manually if it still exists: {err}")