# AWS Machine Learning Purpose-built Accelerators Tutorial
## Learn how to use [AWS Trainium](https://aws.amazon.com/machine-learning/trainium/) and [AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/) with [Amazon SageMaker](https://aws.amazon.com/sagemaker/) to optimize your ML workloads
## Part 3/3 - Compiling and deploying a BERT model to AWS Inferentia2 with SageMaker + [Hugging Face Optimum Neuron](https://huggingface.co/docs/optimum-neuron/index)

**SageMaker Studio kernel: PyTorch 1.13 Python 3.9 CPU - ml.t3.medium**

In this tutorial, you'll learn how to compile a model for AWS Inferentia2 and then deploy it to a SageMaker real-time endpoint powered by AWS Inferentia2. First, we'll kick off a SageMaker job to compile the model; this only needs to be done once. After that, we can deploy the compiled model to a SageMaker endpoint and get some predictions.

In section 2, you extract metadata from the Optimum Neuron API and render a table of the currently tested/supported models (similar models not listed there may also be compatible, but you need to verify that yourself). This table shows which models can be selected for deployment. If you also need to fine-tune your model, check the similar table in the **Part 2** notebook to see which models can be fine-tuned on AWS Trainium with HF Optimum Neuron. That way, you can plan your end-to-end solution and start implementing it right away.

## 1) Install some required packages

In [None]:
%pip install -U optimum-neuron==0.0.8 "onnx>=1.14.0"
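
After the install cell finishes, it can help to confirm which versions the kernel actually picked up. The following is a minimal sketch that uses only the Python standard library; the helper name `installed_version` is an assumption introduced for illustration and is not part of Optimum Neuron.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Report the two packages installed in the cell above.
for name in ("optimum-neuron", "onnx"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

If a package had already been imported before it was upgraded, restarting the kernel is usually needed for the new version to load.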

## 2) Supported models/tasks for Inference

In [None]:
import re
import pandas as pd
from IPython.display import HTML, Markdown
from optimum.exporters.tasks import TasksManager
# importing this module registers the Neuron export configs with TasksManager
from optimum.exporters.neuron.model_configs import *
from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
from optimum.neuron.utils.training_utils import (
    _SUPPORTED_MODEL_NAMES,
    _SUPPORTED_MODEL_TYPES,
    _generate_supported_model_class_names
)

In [None]:
# retrieve supported models for Tensor Parallelism
tp_support = list(ParallelizersManager._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS.keys())

# build compatibility table for inference
meta = [(k,list(v['neuron'].keys())) for k,v in TasksManager._SUPPORTED_MODEL_TYPE.items() if v.get('neuron') is not None]
data_inference = {'Model': []}
for m,t in meta:
    model_id = len(data_inference['Model'])
    model_link = f'<a target="_new" href="https://huggingface.co/models?sort=trending&search={m}">{m}</a>'
    data_inference['Model'].append(f"{model_link} <font style='color: red;'><b>[TP]</b></font>" if m in tp_support else model_link)
    for task in t:
        if data_inference.get(task) is None: data_inference[task] = [''] * len(meta)
        data_inference[task][model_id] = f'<a target="_new" href="https://huggingface.co/models?pipeline_tag={task}&sort=trending&search={m}">list</a>'

df_inference = pd.DataFrame.from_dict(data_inference).set_index('Model')

Each new release of HF Optimum Neuron adds support for more models, so expect the following tables to show different values after you upgrade the library.

Models with **[TP]** after their name support Tensor Parallelism.
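
The table-building cell above packs several steps into a few lines: it filters model types that have a `neuron` backend, creates one row per model, and fills one column per task. The sketch below walks through the same pivot on a small hand-written dictionary. The dictionary contents, the `tp_support` list, and the function name `build_compat_table` are all hypothetical stand-ins, not data or APIs from Optimum Neuron, and plain `x` marks replace the HTML links.

```python
# Hypothetical stand-in for TasksManager._SUPPORTED_MODEL_TYPE (not real library data).
supported = {
    "bert": {"neuron": {"fill-mask": None, "text-classification": None},
             "onnx": {"fill-mask": None}},
    "gpt2": {"neuron": {"text-generation": None}},
    "t5": {"onnx": {"text2text-generation": None}},  # no "neuron" backend: skipped
}
tp_support = ["bert"]  # hypothetical list of model types with Tensor Parallelism


def build_compat_table(supported, tp_support):
    """Return {column name: list of cells}, one row per model with a 'neuron' backend."""
    # Keep only the models that expose a "neuron" backend, with their task names.
    meta = [(model, list(backends["neuron"]))
            for model, backends in supported.items()
            if backends.get("neuron") is not None]
    table = {"Model": []}
    for row, (model, tasks) in enumerate(meta):
        # Tag models that support Tensor Parallelism.
        table["Model"].append(f"{model} [TP]" if model in tp_support else model)
        for task in tasks:
            # Create the task column on first use (all rows empty), then mark this row.
            table.setdefault(task, [""] * len(meta))[row] = "x"
    return table


table = build_compat_table(supported, tp_support)
for name, cells in table.items():
    print(f"{name:20} {cells}")
```

Because every task column is pre-filled with empty strings for all rows, all columns end up with the same length, which is what lets the notebook pass the dictionary straight to `pd.DataFrame.from_dict`.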

In [106]:
Markdown(df_inference.to_markdown())

| Model                                                                                                                                       | feature-extraction                                                                                                                  | fill-mask                                                                                                              | multiple-choice                                                                                                              | question-answering                                                                                                              | text-classification                                                                                                              | token-classification                                                                                                              | zero-shot-image-classification                                                                                                       | stable-diffusion                                                                                                                  | text-generation                                                                                                       | semantic-segmentation                                                                                                              |
|:--------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------|
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=albert">albert</a>                                                | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=albert">list</a>          | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=albert">list</a>      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=albert">list</a>      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=albert">list</a>      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=albert">list</a>      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=albert">list</a>      |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=bert">bert</a> <font style='color: red;'><b>[TP]</b></font>       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=bert">list</a>            | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=bert">list</a>        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=bert">list</a>        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=bert">list</a>        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=bert">list</a>        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=bert">list</a>        |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=camembert">camembert</a>                                          | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=camembert">list</a>       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=camembert">list</a>   | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=camembert">list</a>   | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=camembert">list</a>   | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=camembert">list</a>   | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=camembert">list</a>   |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=clip">clip</a>                                                    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=clip">list</a>            |                                                                                                                        |                                                                                                                              |                                                                                                                                 |                                                                                                                                  |                                                                                                                                   | <a target="_new" href="https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&sort=trending&search=clip">list</a> |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=clip-text-model">clip-text-model</a>                              | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=clip-text-model">list</a> |                                                                                                                        |                                                                                                                              |                                                                                                                                 |                                                                                                                                  |                                                                                                                                   |                                                                                                                                      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=stable-diffusion&sort=trending&search=clip-text-model">list</a> |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=convbert">convbert</a>                                            | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=convbert">list</a>        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=convbert">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=convbert">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=convbert">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=convbert">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=convbert">list</a>    |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=deberta">deberta</a>                                              | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=deberta">list</a>         | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=deberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=deberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=deberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=deberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=deberta">list</a>     |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=deberta-v2">deberta-v2</a>                                        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=deberta-v2">list</a>      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=deberta-v2">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=deberta-v2">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=deberta-v2">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=deberta-v2">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=deberta-v2">list</a>  |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=distilbert">distilbert</a>                                        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=distilbert">list</a>      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=distilbert">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=distilbert">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=distilbert">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=distilbert">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=distilbert">list</a>  |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=electra">electra</a>                                              | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=electra">list</a>         | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=electra">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=electra">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=electra">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=electra">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=electra">list</a>     |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=flaubert">flaubert</a>                                            | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=flaubert">list</a>        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=flaubert">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=flaubert">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=flaubert">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=flaubert">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=flaubert">list</a>    |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=gpt2">gpt2</a>                                                    |                                                                                                                                     |                                                                                                                        |                                                                                                                              |                                                                                                                                 |                                                                                                                                  |                                                                                                                                   |                                                                                                                                      |                                                                                                                                   | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-generation&sort=trending&search=gpt2">list</a> |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=mobilebert">mobilebert</a>                                        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=mobilebert">list</a>      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=mobilebert">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=mobilebert">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=mobilebert">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=mobilebert">list</a>  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=mobilebert">list</a>  |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=mpnet">mpnet</a>                                                  | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=mpnet">list</a>           | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=mpnet">list</a>       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=mpnet">list</a>       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=mpnet">list</a>       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=mpnet">list</a>       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=mpnet">list</a>       |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=roberta">roberta</a> <font style='color: red;'><b>[TP]</b></font> | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=roberta">list</a>         | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=roberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=roberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=roberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=roberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=roberta">list</a>     |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=roformer">roformer</a>                                            | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=roformer">list</a>        | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=roformer">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=roformer">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=roformer">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=roformer">list</a>    | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=roformer">list</a>    |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=unet">unet</a>                                                    |                                                                                                                                     |                                                                                                                        |                                                                                                                              |                                                                                                                                 |                                                                                                                                  |                                                                                                                                   |                                                                                                                                      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=stable-diffusion&sort=trending&search=unet">list</a>            |                                                                                                                       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=semantic-segmentation&sort=trending&search=unet">list</a>        |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=vae-encoder">vae-encoder</a>                                      |                                                                                                                                     |                                                                                                                        |                                                                                                                              |                                                                                                                                 |                                                                                                                                  |                                                                                                                                   |                                                                                                                                      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=stable-diffusion&sort=trending&search=vae-encoder">list</a>     |                                                                                                                       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=semantic-segmentation&sort=trending&search=vae-encoder">list</a> |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=vae-decoder">vae-decoder</a>                                      |                                                                                                                                     |                                                                                                                        |                                                                                                                              |                                                                                                                                 |                                                                                                                                  |                                                                                                                                   |                                                                                                                                      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=stable-diffusion&sort=trending&search=vae-decoder">list</a>     |                                                                                                                       | <a target="_new" href="https://huggingface.co/models?pipeline_tag=semantic-segmentation&sort=trending&search=vae-decoder">list</a> |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=xlm">xlm</a>                                                      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=xlm">list</a>             | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=xlm">list</a>         | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=xlm">list</a>         | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=xlm">list</a>         | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=xlm">list</a>         | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=xlm">list</a>         |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |
| <a target="_new" href="https://huggingface.co/models?sort=trending&search=xlm-roberta">xlm-roberta</a>                                      | <a target="_new" href="https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending&search=xlm-roberta">list</a>     | <a target="_new" href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=trending&search=xlm-roberta">list</a> | <a target="_new" href="https://huggingface.co/models?pipeline_tag=multiple-choice&sort=trending&search=xlm-roberta">list</a> | <a target="_new" href="https://huggingface.co/models?pipeline_tag=question-answering&sort=trending&search=xlm-roberta">list</a> | <a target="_new" href="https://huggingface.co/models?pipeline_tag=text-classification&sort=trending&search=xlm-roberta">list</a> | <a target="_new" href="https://huggingface.co/models?pipeline_tag=token-classification&sort=trending&search=xlm-roberta">list</a> |                                                                                                                                      |                                                                                                                                   |                                                                                                                       |                                                                                                                                    |

## 3) Compiling a pre-trained model to AWS Inferentia

In [None]:
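import os
import boto3
import sagemaker
from packaging import version

print(sagemaker.__version__)
# Compare parsed versions: a plain string comparison would rank "2.9.0" above "2.146.0"
if version.parse(sagemaker.__version__) < version.parse("2.146.0"):
    print("You need to upgrade the SageMaker SDK, or restart the kernel if you already upgraded")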

sess = sagemaker.Session()
role = sagemaker.get_execution_role()
bucket = sess.default_bucket()
region = sess.boto_region_name

os.makedirs('src', exist_ok=True)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")

### 3.1) Re-pack the checkpoints from the fine tuning job (previous notebook)

In this step, you'll download the checkpoints created by the fine-tuning job and re-pack them so they can be compiled for Inferentia2.

Copy the SageMaker training job name from the previous notebook **02_ModelFineTuning** or from your AWS Console/SageMaker and set the variable **training_job_name**.

In [None]:
import os
import io
import tarfile

training_job_name=""
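if os.path.isfile("training_job_name.txt"):
    # Strip the trailing newline so the name can be used safely in the S3 path below
    with open("training_job_name.txt", "r") as f:
        training_job_name = f.read().strip()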
assert len(training_job_name)>0, "Please copy the name of the training_job you ran in the previous notebook and set training_job_name"

# Extract the artifacts
if not os.path.isdir("checkpoints"):
    downloader = sagemaker.s3.S3Downloader()
    print("Download checkpoints from S3...")
    data = downloader.read_bytes(f"s3://{bucket}/output/{training_job_name}/output/model.tar.gz")
    print("Extracting package...")
    with tarfile.open(fileobj=io.BytesIO(data), mode='r:gz') as tar:
        tar.extractall("checkpoints")    
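
The download-and-extract step above can be tried locally without touching S3. The following is a minimal sketch, assuming the archive fits in memory as it does in the cell above. The helper name `extract_tar_gz_bytes` and the file name inside the demo archive are hypothetical, and the path check is an extra precaution that the original cell does not perform.

```python
import io
import os
import tarfile
import tempfile

def extract_tar_gz_bytes(data: bytes, target_dir: str) -> list:
    """Extract an in-memory .tar.gz archive into target_dir and return the member names."""
    root = os.path.realpath(target_dir)
    with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as tar:
        members = tar.getmembers()
        for member in members:
            # Refuse entries that would be written outside target_dir
            dest = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root, dest]) != root:
                raise ValueError(f"Unsafe path in archive: {member.name}")
        tar.extractall(root)
    return [member.name for member in members]

# Build a tiny archive in memory that mimics a checkpoint layout (hypothetical content)
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
    payload = b"{}"
    info = tarfile.TarInfo(name="checkpoint-100/config.json")
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))

demo_dir = tempfile.mkdtemp()
extracted = extract_tar_gz_bytes(buf.getvalue(), demo_dir)
print(extracted)
```

In the notebook, `data` would be the bytes returned by `downloader.read_bytes(...)` and `target_dir` would be `"checkpoints"`.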

#### 3.1.1) Upload the new re-packed checkpoint to SageMaker
These artifacts will be used by our compilation script (defined below) to compile the model.

In [None]:
import glob
checkpoints = glob.glob("checkpoints/checkpoint-*")
assert len(checkpoints)>0, "No checkpoint found in the directory checkpoints"
print(f"Uploading checkpoint: {checkpoints[0]} ...")
s3_checkpoint_uri = sess.upload_data(checkpoints[0], bucket=bucket, key_prefix="models/spam-classifier")
print(f"S3 URI: {s3_checkpoint_uri}")
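
Note that `glob.glob` returns paths in arbitrary order, so `checkpoints[0]` is not guaranteed to be the most recent checkpoint when the fine-tuning job saved more than one. Below is a minimal sketch of one way to pick the latest, assuming the directories follow the `checkpoint-<step>` naming seen above. The helper `latest_checkpoint` and the example paths are hypothetical.

```python
import re

def latest_checkpoint(paths):
    """Return the path whose 'checkpoint-<step>' suffix has the highest step number."""
    def step(path):
        match = re.search(r"checkpoint-(\d+)$", path.rstrip("/"))
        return int(match.group(1)) if match else -1

    candidates = [p for p in paths if step(p) >= 0]
    if not candidates:
        raise ValueError("No checkpoint-<step> directory found")
    # Compare step numbers as integers: sorting the strings would rank 500 above 1500
    return max(candidates, key=step)

example = [
    "checkpoints/checkpoint-500",
    "checkpoints/checkpoint-1500",
    "checkpoints/checkpoint-1000",
]
print(latest_checkpoint(example))
```

With this helper, the upload call could take `latest_checkpoint(checkpoints)` instead of `checkpoints[0]`.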

### 3.2) Compiling script that will be invoked by SageMaker

In [None]:
%%writefile src/compile.py
import os
os.environ['NEURON_RT_NUM_CORES'] = '1'
import sys
import glob
import json
import torch
import shutil
import logging
import argparse
import traceback
import optimum.neuron
from transformers import AutoTokenizer

TASK="<<TASK>>"

def model_fn(model_dir, context=None):
    global TASK
    if "TASK" in TASK: raise Exception("Invalid TASK. You need to invoke the compilation job once to set TASK variable")
        
    # Look up the class by name instead of calling eval on a string
    NeuronModel = getattr(optimum.neuron, f"NeuronModelFor{TASK}")
    
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    
    model = NeuronModel.from_pretrained(model_dir)
    return model,tokenizer

def input_fn(input_data, content_type, context=None):
    if content_type == 'application/json':
        req = json.loads(input_data)
        prompt = req.get('prompt')
        if prompt is None or len(prompt) < 3:
            raise ValueError("Invalid prompt. Provide an input like: {'prompt': 'text text text'}")
        return prompt
    else:
        raise Exception(f"Unsupported mime type: {content_type}. Supported: application/json")    

def predict_fn(input_object, model_tokenizer, context=None):
    model,tokenizer = model_tokenizer
    inputs = tokenizer(input_object, truncation=True, return_tensors="pt")
    logits = model(**inputs).logits
    idx = logits.argmax(1, keepdim=True)
    conf = torch.gather(logits, 1, idx)
    return torch.cat([idx,conf], 1)    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.    
    parser.add_argument("--task", type=str, default="")
    parser.add_argument("--input_shapes", type=str, required=True)
    
    parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])    
    parser.add_argument("--checkpoint_dir", type=str, default=os.environ["SM_CHANNEL_CHECKPOINT"])
    
    try:
        args, _ = parser.parse_known_args()
        
        # Set up logging        
        logging.basicConfig(
            level=logging.getLevelName("INFO"),
            handlers=[logging.StreamHandler(sys.stdout)],
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )
        logger = logging.getLogger(__name__)
        logger.info(args)

        # Look up the class by name instead of calling eval on a hyperparameter value
        NeuronModel = getattr(optimum.neuron, f"NeuronModelFor{args.task}" if len(args.task) > 0 else "NeuronModel")
        logger.info(f"Checkpoint files: {os.listdir(args.checkpoint_dir)}")
        
        input_shapes = json.loads(args.input_shapes)
        model = NeuronModel.from_pretrained(args.checkpoint_dir, export=True, **input_shapes)
        model.save_pretrained(args.model_dir)

        code_path = os.path.join(args.model_dir, 'code')
        os.makedirs(code_path, exist_ok=True)
        
        with open(__file__, 'r') as f:
            content = f.read()
            content = content.replace("<<TASK>>", args.task)
            with open(os.path.join(code_path, "inference.py"), "w") as i:
                i.write(content)
        shutil.copyfile('requirements.txt', os.path.join(code_path, 'requirements.txt'))
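    except Exception:
        # Exit non-zero so SageMaker marks the compilation job as failed
        print(traceback.format_exc())
        sys.exit(1)

    # Only reached on success. A sys.exit(0) inside `finally` would replace the exit code above.
    print("Done!")
    sys.exit(0)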
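
The same script doubles as the endpoint's `inference.py`, so it helps to check the request contract locally before deploying. The following is a standalone sketch of the validation that `input_fn` performs, written so it runs without the Neuron runtime. The name `parse_request` is hypothetical, and raising `ValueError` for both failure cases is a choice made for this sketch.

```python
import json

def parse_request(body, content_type):
    """Validate a request body the way input_fn does and return the prompt text."""
    if content_type != "application/json":
        raise ValueError(f"Unsupported mime type: {content_type}. Supported: application/json")
    request = json.loads(body)
    prompt = request.get("prompt")
    if prompt is None or len(prompt) < 3:
        raise ValueError("Invalid prompt. Provide an input like: {'prompt': 'text text text'}")
    return prompt

# A well-formed request returns the prompt unchanged
print(parse_request(json.dumps({"prompt": "Is this message spam?"}), "application/json"))

# Requests with a missing or too-short prompt are rejected
for bad_body in ('{"prompt": "hi"}', '{"text": "wrong key"}'):
    try:
        parse_request(bad_body, "application/json")
    except ValueError as err:
        print(f"Rejected: {err}")
```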

In [None]:
%%writefile src/requirements.txt
neuronx-distributed
## 4.30 or higher is required
transformers==4.30.0
optimum-neuron==0.0.8

### 3.3) SageMaker Estimator
This object will help you to configure the compilation job (SageMaker Training Job).

This job will invoke the **compile.py** script, which compiles our model to Inferentia2 and then saves the artifacts for deployment.

In [None]:
import json
from sagemaker.pytorch import PyTorch

input_shapes={"batch_size": 1, "sequence_length": 512}

estimator = PyTorch(
    entry_point="compile.py", # Script that compiles the model
    source_dir="src",
    role=role,
    sagemaker_session=sess,
    instance_count=1,
    instance_type='ml.trn1.2xlarge',
    output_path=f"s3://{bucket}/output",
    disable_profiler=True,
    
    image_uri=f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.12.0-ubuntu20.04",
    
    volume_size = 512,
    hyperparameters={     
        "task": "SequenceClassification",
        "input_shapes": f"'{json.dumps(input_shapes)}'",
    }
)
estimator.framework_version = '1.13.1' # workaround when using image_uri
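
Inside **compile.py**, the `task` hyperparameter is appended to `NeuronModelFor` to select the model class, so a typo in `task` only surfaces after the job has started. Below is a minimal sketch of that name resolution with an explicit error message. It uses a stand-in namespace instead of importing `optimum.neuron`, so it runs anywhere; `resolve_model_class` and `fake_neuron` are hypothetical names, and the classes in the stand-in are empty placeholders.

```python
import types

def resolve_model_class(namespace, task=""):
    """Map a task name to a class attribute on namespace, as compile.py does by name."""
    name = f"NeuronModelFor{task}" if task else "NeuronModel"
    cls = getattr(namespace, name, None)
    if cls is None:
        raise ValueError(f"{name} not found; check the 'task' hyperparameter")
    return cls

# Stand-in for the optimum.neuron module (placeholder classes, for illustration only)
fake_neuron = types.SimpleNamespace(
    NeuronModel=type("NeuronModel", (), {}),
    NeuronModelForSequenceClassification=type("NeuronModelForSequenceClassification", (), {}),
)

print(resolve_model_class(fake_neuron, "SequenceClassification").__name__)
print(resolve_model_class(fake_neuron).__name__)
```

In an environment where `optimum-neuron` is installed, passing the real `optimum.neuron` module as `namespace` should behave the same way, although that is not verified here.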

In [None]:
estimator.fit({"checkpoint": s3_checkpoint_uri})

## 4) Deploy a SageMaker real-time endpoint

In [None]:
import logging
from sagemaker.utils import name_from_base
from sagemaker.pytorch.model import PyTorchModel

# The number of NeuronCores depends on the inf2 instance type you deploy to.
# We'll ask SageMaker to launch 1 worker per NeuronCore.

model_data=estimator.model_data
print(f"Model data: {model_data}")

instance_type_idx=0 # default ml.inf2.xlarge
instance_types=['ml.inf2.xlarge', 'ml.inf2.8xlarge', 'ml.inf2.24xlarge','ml.inf2.48xlarge']
num_workers=[2,2,12,24]

print(f"Instance type: {instance_types[instance_type_idx]}. Num SM workers: {num_workers[instance_type_idx]}")
pytorch_model = PyTorchModel(
    image_uri=f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.12.0-ubuntu20.04",
    model_data=model_data,
    role=role,    
    name=name_from_base('bert-spam-classifier'),
    sagemaker_session=sess,
    container_log_level=logging.DEBUG,
    model_server_workers=num_workers[instance_type_idx], # 1 worker per NeuronCore
    framework_version="1.13.1",
    env = {
        'SAGEMAKER_MODEL_SERVER_TIMEOUT' : '3600' 
    },
    # for production it is important to define vpc_config and use a vpc_endpoint
    #vpc_config={
    #    'Subnets': ['<SUBNET1>', '<SUBNET2>'],
    #    'SecurityGroupIds': ['<SECURITYGROUP1>', '<DEFAULTSECURITYGROUP>']
    #}
)
pytorch_model._is_compiled_model = True
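
The two parallel lists above (`instance_types` and `num_workers`) must stay aligned by index. As a sketch of an alternative, a single dictionary keeps each instance type next to its worker count. The values below mirror the `num_workers` list in this notebook; `NEURON_CORES` and `workers_for` are hypothetical names.

```python
# Worker count per instance type, copied from the num_workers list in this notebook
NEURON_CORES = {
    "ml.inf2.xlarge": 2,
    "ml.inf2.8xlarge": 2,
    "ml.inf2.24xlarge": 12,
    "ml.inf2.48xlarge": 24,
}

def workers_for(instance_type):
    """Return the number of model server workers to launch for an inf2 instance type."""
    try:
        return NEURON_CORES[instance_type]
    except KeyError:
        known = ", ".join(sorted(NEURON_CORES))
        raise ValueError(f"Unknown instance type: {instance_type}. Known: {known}") from None

print(workers_for("ml.inf2.xlarge"))
```

The result could then be passed as `model_server_workers`, with the same instance type string reused in the `deploy` call.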

In [None]:
predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type=instance_types[instance_type_idx],
    model_data_download_timeout=3600, # it takes some time to download all the artifacts and load the model
    container_startup_health_check_timeout=1800
)

## 5) Run a simple test

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()

In [105]:
import time

labels={0: "not spam", 1: "spam"}
not_spam=" Deezer.com 10,406,168 Artist DB\n\nWe have scraped the Deezer Artist DB, right now there are 10,406,168 listings according to Deezer.com\n\nPlease note in going through part of the list, it is obvious there are mistakes inside their system.\n\nExamples include and Artist with &amp; in its name might also be found with "and" but the Albums for each have different totals etc. Have no clue if there are duplicate albums etc do this error in their system. Even a comma in a name could mean the Artist shows up more than once, I saw in 1 instance that 1 Artist had 6 different ArtistIDs due to spelling errors.\n\nSo what is this DB, very simple, it gives you the ArtistID and the actual name of the Artist in another column. If you want to see the artist you add the baseurl to the ArtistID\n\nAn example is ArtistID 115 is AC/DC\n\n[https://www.deezer.com/us/artist/115](https://www.deezer.com/us/artist/115)\n\nYou do not have to use [https://www.deezer.com/us/artist/](https://www.deezer.com/us/artist/) if your first language is other than English, just see if Deezer supports your language and use that baseref\n\nFrench for example is [https://www.deezer.com/fr/artist/115](https://www.deezer.com/fr/artist/115)\n\nI am providing the DB in 3 different formats:\n\n \n\nI tried posting download links here but it seems Reddit does not like that so get them here:\n\n[https://pastebin\\[DOT\\]com/V3KJbgif](https://pastebin.com/V3KJbgif)\n\n&amp;#x200B;\n\n**Special thanks go to** [**/user/KoalaBear84**](https://www.reddit.com/user/KoalaBear84) **for writing the scraper.**\n\n&amp;#x200B;\n\n**Cross Posted to related Reddit Groups**"
spam="🚨 ATTENTION ALL USERS! 🚨\n\n🆘 Are you looking for a way to GET RICH QUICK? 🆘\n\n💰 Don't waste your time with boring old jobs! 💰\n\n💸 Join our CRAZY MONEY-MAKING SYSTEM today! 💸\n\n🤑 Just sign up and start earning BIG BUCKS right away! 🤑\n\n👉 Plus, if you refer your friends, you'll get even MORE CASH! 👈\n\n🔥 This is the HOTTEST OFFER of the year! 🔥\n\n👍 Don't wait"

for i,text in enumerate([not_spam, spam]):
    t=time.time()
    pred = predictor.predict({"prompt": text})
    elapsed = (time.time()-t)*1000
    print(f"Elapsed time: {elapsed}")
    print(f"Pred: {i} - {labels[pred[0][0]]} / score: {pred[0][1]}")

Elapsed time: 43.573856353759766
Pred: 0 - not spam / score: 4.785090923309326
Elapsed time: 33.95414352416992
Pred: 1 - spam / score: 4.689364910125732
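
The score printed above is the raw logit of the predicted class, as returned by `predict_fn`, not a probability. Because `predict_fn` concatenates the class index with that logit into one tensor, the index most likely arrives on the client as a float. The following is a minimal sketch of a client-side helper that makes both points explicit, assuming the response has the `[[class_index, logit]]` shape seen above. The name `parse_prediction` is hypothetical.

```python
LABELS = {0: "not spam", 1: "spam"}

def parse_prediction(pred, labels=LABELS):
    """Turn a [[class_index, logit]] response into a (label, logit) pair."""
    class_idx, logit = pred[0]
    # Cast explicitly: the index may be deserialized as a float such as 1.0
    return labels[int(class_idx)], float(logit)

# Sample shaped like the second prediction in the output above
sample = [[1.0, 4.689364910125732]]
label, logit = parse_prediction(sample)
print(f"{label} / logit: {logit}")
```

Turning the logit into a probability would require the logits of both classes, which this endpoint does not return.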
