# Deploy pre-trained ESM-2 model to Inferentia2

Note: This notebook was last tested in SageMaker Studio on the PyTorch 1.13 Python 3.9 CPU Optimized image on an ml.c5.4xlarge instance.

---
## 1. Install neuronx and dependencies

Install the Neuron system packages, then the `neuronx-cc` compiler and the Python dependencies. NOTE: You will need to restart your notebook kernel after running these cells.

In [None]:
%%sh
# Register the AWS Neuron apt repository and install the 2.x Neuron driver,
# collectives, runtime library, and tools packages.
# NOTE(review): there is no `set -e`, so a failing step does not stop the cell;
# check the output of every command.
apt-get update -y
apt-get install gpg-agent -y
# Trust the Neuron repository signing key, then add the repository itself.
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
add-apt-repository https://apt.repos.neuron.amazonaws.com
# Refresh the package index so packages from the new repository can be resolved.
apt-get update -y
apt-get install aws-neuronx-dkms=2.* aws-neuronx-collectives=2.* aws-neuronx-runtime-lib=2.* aws-neuronx-tools=2.* -y

In [None]:
%pip install -q --upgrade pip
%pip install -q --upgrade --extra-index-url https://pip.repos.neuron.amazonaws.com \
  neuronx-cc==2.* torch sagemaker boto3 awscli transformers accelerate boto3 datasets \
  torch-neuronx=='1.13.1.1.10.1'

In [None]:
import boto3
import sagemaker

# Create the boto3 and SageMaker sessions shared by the rest of the notebook.
# profile_name/region_name are passed as None, i.e. no explicit profile or region is chosen here.
boto_session = boto3.session.Session(profile_name=None, region_name=None)
sagemaker_session = sagemaker.session.Session(boto_session)
# Bucket used later as the upload target for the compiled model archive.
S3_BUCKET = sagemaker_session.default_bucket()
sagemaker_client = boto_session.client("sagemaker")
# Role passed as `role=` to both model definitions further down.
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
print(f"Assumed SageMaker role is {sagemaker_execution_role}")

---
## 2. Compile pretrained model

In [None]:
import torch
import torch_neuronx
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig
import timeit
from timeit import default_timer as timer

# Checkpoint to compile. The larger ESM-2 variants are kept as commented-out
# alternatives; scripts/inference.py hard-codes its own MODEL_ID, so keep the two in sync.
# MODEL_ID="facebook/esm2_t33_650M_UR50D"
# MODEL_ID="facebook/esm2_t12_35M_UR50D"
MODEL_ID = "facebook/esm2_t6_8M_UR50D"


tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Load the masked-LM head model with the torchscript flag set, then switch to eval mode.
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID, torchscript=True)
model.eval()

# Example input containing one <mask> token; used both for the sanity check and as the tracing input.
sequence = "QVQLVESGGGVVQPRSLTLSCAASGFTFSSYGL<mask>HWVRQAPGKGLEWVANIWYDGANKYYGDSVKGRFTISRDNSRNTLYLQMNSLTAEDTAVYYCARWIEYGSGKDAFDVWGQGTMVIVSS"
# The example is padded/truncated to this length before tracing;
# scripts/inference.py pads requests to the same length (128).
max_length = 128
tokenized_sequence = tokenizer.encode_plus(
    sequence,
    max_length=max_length,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
# Positional (input_ids, attention_mask) tuple passed to both the model and the tracer.
tracing_input = tokenized_sequence["input_ids"], tokenized_sequence["attention_mask"]

# Sanity check: run the untraced model once on the example input.
print("Testing model inference")
print(model(*tracing_input)[0])


# Trace the model with torch_neuronx, save the result to disk, and report how long it took.
print("Beginning model trace")
model_trace_start_time = timer()
neuron_model = torch_neuronx.trace(model, tracing_input)
neuron_model.save("traced_esm.pt")
print(f"Model trace completed in {round(timer() - model_trace_start_time, 0)} seconds.")

---
## 3. Assemble model package

In [None]:
# Bundle the traced model into model.tar.gz, the archive uploaded to S3 in the next cell.
!tar -czvf model.tar.gz traced_esm.pt

In [None]:
# Key prefix under the session's default bucket where the archive is stored.
S3_PREFIX = "compiled-model"
# S3_PATH is only used for the informational print below.
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")
# Upload the archive; the returned URI is passed as `model_data` when the inf2 model is defined.
s3_model_uri = sagemaker_session.upload_data("model.tar.gz", S3_BUCKET, S3_PREFIX)
print(f"Model artifact uploaded to {s3_model_uri}")

---
## 4. Define inference script

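We define the code needed to load our Neuron-compiled model in an `inference.py` file. The cells below write into a `scripts/` directory, which must already exist because `%%writefile` does not create missing directories (create it with `mkdir -p scripts` if needed).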

In [None]:
%%writefile scripts/requirements.txt
--extra-index-url=https://pip.repos.neuron.amazonaws.com
transformers
torch-neuronx==1.13.1.1.10.1

In [None]:
%%writefile scripts/inference.py

import os
import json
import torch
import torch_neuronx
from transformers import AutoTokenizer

JSON_CONTENT_TYPE = "application/json"
# MODEL_ID = "facebook/esm2_t33_650M_UR50D"
# MODEL_ID = "facebook/esm2_t12_35M_UR50D"
MODEL_ID = "facebook/esm2_t6_8M_UR50D"

def model_fn(model_dir):
    """Load the traced model from ``model_dir`` together with its tokenizer.

    The tokenizer is fetched for ``MODEL_ID``; the model is the TorchScript
    file ``traced_esm.pt`` located inside ``model_dir``.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    print(f"torch-neuronx version is {torch_neuronx.__version__}")
    hub_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    traced_model = torch.jit.load(os.path.join(model_dir, "traced_esm.pt"))
    return traced_model, hub_tokenizer

def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
    """Deserialize the request payload.

    Parameters:
        serialized_input_data: JSON document received in the request body.
        content_type: content type of the request; only JSON_CONTENT_TYPE
            is supported.

    Returns:
        The value stored under the "inputs" key when the payload is a JSON
        object containing that key; otherwise the decoded payload itself
        (e.g. a bare JSON string holding the sequence).

    Raises:
        ValueError: if ``content_type`` is not JSON_CONTENT_TYPE.
    """
    # Guard clause first; the original had an unreachable `return` after the raise.
    if content_type != JSON_CONTENT_TYPE:
        raise ValueError(
            "Requested unsupported ContentType in Content-Type: " + str(content_type)
        )

    input_data = json.loads(serialized_input_data)
    # Only JSON objects can carry an "inputs" key. Lists and strings are passed
    # through unchanged instead of failing inside `.pop`, and the decoded dict
    # is no longer mutated.
    if isinstance(input_data, dict):
        return input_data.get("inputs", input_data)
    return input_data


def predict_fn(input_data, model_and_tokenizer):
    """Run model inference for the first ``<mask>`` token in a sequence.

    Parameters:
        input_data: sequence text containing the tokenizer's mask token.
        model_and_tokenizer: ``(model, tokenizer)`` tuple as returned by
            ``model_fn``.

    Returns:
        dict mapping each vocabulary token to its sigmoid-transformed logit
        at the first masked position, rounded to 3 decimals.

    Raises:
        ValueError: if the tokenized (and possibly truncated) input contains
            no mask token.
    """
    # The original unpacked into `model_bert` but then called the undefined
    # name `neuron_model`, so every request failed with a NameError.
    neuron_model, tokenizer = model_and_tokenizer

    # Must match the padded length used for the example input when the model was traced.
    max_length = 128
    tokenized_sequence = tokenizer.encode_plus(
        input_data,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokenized_sequence["input_ids"]
    prediction_input = (input_ids, tokenized_sequence["attention_mask"])
    output = neuron_model(*prediction_input)[0]

    # Positions of mask tokens in the (single) input row.
    mask_token_index = (input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
    if mask_token_index.numel() == 0:
        raise ValueError(
            "Input does not contain a mask token within the first "
            f"{max_length} tokens."
        )
    mask_index_predictions = output[0, mask_token_index]

    # NOTE(review): sigmoid scores each token independently, so the values do
    # not sum to 1; switch to softmax if a probability distribution is wanted.
    probs = torch.nn.Sigmoid()(mask_index_predictions)

    # Build the id -> token lookup once, keyed by the actual token id, instead
    # of re-listing the vocabulary per token and relying on dict ordering.
    id_to_token = {token_id: token for token, token_id in tokenizer.get_vocab().items()}

    # Only the first masked position is reported, as before.
    return {
        id_to_token.get(idx, str(idx)): round(v.item(), 3)
        for idx, v in enumerate(probs[0])
    }


def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
    """Serialize the prediction for the response.

    Returns a ``(body, content_type)`` tuple with the prediction encoded as
    JSON when ``accept`` is JSON_CONTENT_TYPE; any other ``accept`` value
    raises an Exception.
    """
    if accept != JSON_CONTENT_TYPE:
        raise Exception("Requested unsupported ContentType in Accept: " + accept)

    response_body = json.dumps(prediction_output)
    return response_body, accept


---
## 5. Deploy inf2 model endpoint

In [None]:
%%time
from sagemaker.pytorch.model import PyTorchModel

# PyTorch Neuron (neuronx) inference container image, resolved for the session's region.
ecr_image = f"763104351884.dkr.ecr.{sagemaker_session.boto_region_name}.amazonaws.com/pytorch-inference-neuronx:1.13.0-neuronx-py38-sdk2.9.0-ubuntu20.04"

# Model definition: the uploaded archive plus the handler code written to scripts/.
inf2_model = PyTorchModel(
    model_data=s3_model_uri,
    role=sagemaker_execution_role,
    sagemaker_session=sagemaker_session,
    source_dir="scripts",
    entry_point="inference.py",
    image_uri=ecr_image,
)

# The model was already traced above with torch_neuronx.trace, so flag it as compiled.
# NOTE(review): `_is_compiled_model` is a private SDK attribute — confirm that the
# installed sagemaker version still honors it.
inf2_model._is_compiled_model = True

# Create a single-instance endpoint on ml.inf2.xlarge.
inf2_predictor = inf2_model.deploy(
    instance_type="ml.inf2.xlarge", initial_instance_count=1
)

In [None]:
# %%time

# from sagemaker.huggingface.model import HuggingFaceModel

# ecr_image = f"763104351884.dkr.ecr.{sagemaker_session.boto_region_name}.amazonaws.com/huggingface-pytorch-inference-neuronx:1.13.0-transformers4.28.1-neuronx-py38-sdk2.9.1-ubuntu20.04"
# inf2_model = HuggingFaceModel(
#     model_data = s3_model_uri,
#     role=sagemaker_execution_role,
#     source_dir="scripts",
#     entry_point="inference.py",
#     image_uri=ecr_image
# )
# inf2_model._is_compiled_model = True

# inf2_predictor = inf2_model.deploy(
#     instance_type="ml.inf2.xlarge", initial_instance_count=1
# )

In [None]:
# Smoke-test the inf2 endpoint with a short sequence containing one <mask> token.
example = "QVQLVESGGGVVQ<mask>PGRSLTLSCAASGFTFSSYGLHWVRQAPGKGLE"
inf2_predictor.predict({"inputs": example})

---
## 6. Benchmark endpoint

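For comparison, deploy the same ESM-2 checkpoint (uncompiled, served as a `fill-mask` task by the Hugging Face inference container) to an ml.g5.xlarge GPU endpoint.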

In [None]:
%%time

from sagemaker.huggingface.model import HuggingFaceModel

# Environment for the Hugging Face container: which Hub model to serve and for which task.
# MODEL_ID is the checkpoint selected in the compilation cell.
hub = {"HF_MODEL_ID": MODEL_ID, "HF_TASK": "fill-mask"}

# Hugging Face GPU inference container image, resolved for the session's region.
ecr_image = f"763104351884.dkr.ecr.{sagemaker_session.boto_region_name}.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04"

# create Hugging Face Model Class
# NOTE(review): unlike the inf2 model above, no sagemaker_session is passed here.
g5_model = HuggingFaceModel(env=hub, role=sagemaker_execution_role, image_uri=ecr_image)

# Create a single-instance endpoint on ml.g5.xlarge.
g5_predictor = g5_model.deploy(instance_type="ml.g5.xlarge", initial_instance_count=1)

In [None]:
# Smoke-test the g5 endpoint with the same masked sequence used for the inf2 endpoint.
example = "QVQLVESGGGVVQ<mask>PGRSLTLSCAASGFTFSSYGLHWVRQAPGKGLE"
g5_predictor.predict({"inputs": example})

Define some benchmarking helpers.

In [None]:
import numpy as np
import datetime
import math
import time
import boto3
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import numpy as np
from tqdm import tqdm
import random
from datasets import load_dataset, DatasetDict


def inference_latency(model, inputs):
    """
    Time a single call of ``model(inputs)``.

        Parameters:
            model: any callable taking one positional argument
                (the benchmark passes ``predictor.predict``)
            inputs: the argument handed to ``model``

        Returns:
            dict with keys
                "latency": elapsed time of the call in seconds
                "error":   True if the call raised an exception
                "result":  the call's return value, or [] on error
    """
    error = False
    # perf_counter is monotonic, so the interval is immune to system clock adjustments.
    start = time.perf_counter()
    try:
        results = model(inputs)
    except Exception:
        # Failed calls are counted rather than re-raised so the benchmark keeps
        # going. Catching Exception (not a bare except) lets KeyboardInterrupt
        # and SystemExit still stop the run.
        error = True
        results = []
    return {"latency": time.perf_counter() - start, "error": error, "result": results}


# Pool of benchmark inputs: the "text" column of the dataset's test split
# (presumably protein sequences — confirm against the dataset card).
# random_sequence() samples from this pool.
# NOTE(review): shuffle() is called without a seed, so the order differs between runs.
uniref = (
    load_dataset("bloyal/small-uniref30", split="test")
    .remove_columns(["id", "num"])
    .shuffle()["text"]
)


def random_sequence(max_length=128, sequences=None):
    """Return a randomly chosen sequence with one position replaced by ``<mask>``.

    Parameters:
        max_length: maximum number of characters kept from the chosen
            sequence before masking; must be at least 1.
        sequences: optional pool to sample from; defaults to the
            module-level ``uniref`` list.

    Returns:
        The truncated sequence with exactly one character replaced by
        "<mask>". An empty source sequence yields just "<mask>".

    Raises:
        ValueError: if ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    pool = uniref if sequences is None else sequences
    residues = list(random.choice(pool)[:max_length])
    if not residues:
        # The original crashed in randint(0, -1) on an empty sequence.
        return "<mask>"

    # NOTE(review): the endpoint truncates to 128 tokens; if the tokenizer adds
    # special tokens, a mask placed in the last positions of a 128-character
    # sequence may be cut off — confirm and lower max_length if so.
    residues[random.randrange(len(residues))] = "<mask>"
    return "".join(residues)


def benchmark(predictor, number_of_clients=2, number_of_runs=10000):
    """Load-test an endpoint and print client-side and CloudWatch latency statistics.

    Sends ``number_of_runs`` requests, each built with ``random_sequence()``,
    through ``predictor.predict`` using ``number_of_clients`` parallel threads.
    It prints throughput, latency percentiles and the error percentage measured
    on the client, then polls CloudWatch for the endpoint's ``ModelLatency``
    metric and prints its percentiles.

    Parameters:
        predictor: object exposing ``predict(payload)`` and ``endpoint_name``.
        number_of_clients: number of concurrent client threads.
        number_of_runs: total number of requests to send; must be at least 1.

    Returns:
        None. All results are printed.

    Raises:
        ValueError: if ``number_of_runs`` is smaller than 1.
    """
    if number_of_runs < 1:
        raise ValueError("number_of_runs must be at least 1")

    progress = tqdm(range(number_of_runs), position=0, leave=True)

    # Starting parallel clients
    cw_start = datetime.datetime.utcnow()

    results = Parallel(n_jobs=number_of_clients, prefer="threads")(
        delayed(inference_latency)(predictor.predict, {"inputs": random_sequence()})
        for _ in progress
    )
    avg_throughput = progress.total / progress.format_dict["elapsed"]

    cw_end = datetime.datetime.utcnow()

    # Computing metrics and print
    latencies = [res["latency"] for res in results]
    errors = [res["error"] for res in results]
    error_p = sum(errors) / len(errors) * 100
    # The quantiles now match the printed labels; the original computed the
    # 0.95 and 0.99 quantiles but reported them as the 90th and 95th percentiles.
    # Latencies are in seconds; multiply by 1000 for milliseconds.
    p50, p90, p95 = np.quantile(latencies[-10000:], [0.50, 0.90, 0.95]) * 1000

    print(f"Avg Throughput: :{avg_throughput:.1f}\n")
    print(f"50th Percentile Latency:{p50:.1f} ms")
    print(f"90th Percentile Latency:{p90:.1f} ms")
    print(f"95th Percentile Latency:{p95:.1f} ms\n")
    print(f"Errors percentage: {error_p:.1f} %\n")

    # Querying CloudWatch
    print("Getting Cloudwatch:")
    cloudwatch = boto3.client("cloudwatch")
    statistics = ["SampleCount", "Average", "Minimum", "Maximum"]
    extended = ["p50", "p90", "p95", "p100"]

    # Give 5 minute buffer to end
    cw_end += datetime.timedelta(minutes=5)

    # Period must be 1, 5, 10, 30, or multiple of 60
    # Round the elapsed time up to the next multiple of 60 seconds
    factor = math.ceil((cw_end - cw_start).total_seconds() / 60)
    period = factor * 60
    print("Time elapsed: {} seconds".format((cw_end - cw_start).total_seconds()))
    print("Using period of {} seconds\n".format(period))

    cloudwatch_ready = False
    # Keep polling CloudWatch metrics until datapoints are available.
    # NOTE(review): there is no upper bound on the number of polls.
    while not cloudwatch_ready:
        print("Waiting 30 seconds ...")
        time.sleep(30)
        # Must use default units of microseconds
        model_latency_metrics = cloudwatch.get_metric_statistics(
            MetricName="ModelLatency",
            Dimensions=[
                {"Name": "EndpointName", "Value": predictor.endpoint_name},
                {"Name": "VariantName", "Value": "AllTraffic"},
            ],
            Namespace="AWS/SageMaker",
            StartTime=cw_start,
            EndTime=cw_end,
            Period=period,
            Statistics=statistics,
            ExtendedStatistics=extended,
        )

        datapoints = model_latency_metrics["Datapoints"]
        if len(datapoints) > 0:
            # Only the first returned datapoint is reported.
            datapoint = datapoints[0]
            print("{} latency datapoints ready".format(datapoint["SampleCount"]))

            # Convert the metric's microseconds to milliseconds.
            percentiles = datapoint["ExtendedStatistics"]
            side_p50 = percentiles["p50"] / 1000
            side_p90 = percentiles["p90"] / 1000
            side_p95 = percentiles["p95"] / 1000

            print(f"50th Percentile Latency:{side_p50:.1f} ms")
            print(f"90th Percentile Latency:{side_p90:.1f} ms")
            print(f"95th Percentile Latency:{side_p95:.1f} ms\n")

            cloudwatch_ready = True

In [None]:
benchmark(g5_predictor)

In [None]:
benchmark(inf2_predictor)

In [None]:
# inf2_predictor.delete_endpoint()

In [None]:
# Delete both endpoints so they stop incurring charges. Each endpoint is handled
# on its own, so a failure (or a predictor that was never created) does not
# prevent the other one from being deleted, and every problem is reported
# instead of being silently swallowed.
for predictor_name in ("inf2_predictor", "g5_predictor"):
    predictor = globals().get(predictor_name)
    if predictor is None:
        print(f"{predictor_name} is not defined; nothing to delete.")
        continue
    try:
        predictor.delete_endpoint()
        print(f"Deleted endpoint for {predictor_name}.")
    except Exception as err:
        print(f"Could not delete endpoint for {predictor_name}: {err}")