# TODO
 
* ## https://medium.com/analytics-vidhya/deploy-huggingface-s-bert-to-production-with-pytorch-serve-27b068026d18 

* ## https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers

# Deploying our BERT PyTorch Model as REST EndPoint

In [None]:
!pip install -q transformers==2.8.0
!pip install -q torch==1.5.0 --upgrade --ignore-installed

In [None]:
!pip install torchserve

In [None]:
import boto3
import sagemaker
import pandas as pd

# SageMaker session, execution role and default bucket used by the rest of the notebook.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()
bucket = sess.default_bucket()

# Region and account id of the current AWS credentials.
boto_session = boto3.Session()
region = boto_session.region_name
caller_identity = boto3.client('sts').get_caller_identity()
account_id = caller_identity.get('Account')

# boto3 SageMaker client bound to the region resolved above.
sm = boto_session.client(service_name='sagemaker', region_name=region)

# Clone the TorchServe repository and install torch-model-archiver

You'll use `torch-model-archiver` to create a model archive file (.mar). The .mar model archive file contains the model checkpoint along with its `state_dict` (a dictionary object that maps each layer to its parameter tensor).

In [None]:
# Install torch-model-archiver from a local source checkout.
# NOTE(review): this assumes the TorchServe repository is already present at
# ./src_torchserve/serve — no clone step is visible in this notebook; confirm.
!pip install ./src_torchserve/serve/model-archiver/

# Retrieve PyTorch Models

In [None]:
%store -r s3_pytorch_model_path

In [None]:
print(s3_pytorch_model_path)

In [None]:
%store -r s3_transformer_pytorch_model_path

In [None]:
print(s3_transformer_pytorch_model_path)

In [None]:
!aws s3 cp --recursive $s3_transformer_pytorch_model_path ./Transformer_model/

# Create TorchServe Model Archive File

Once `setup_config.json`, `sample_text.txt`, and `index_to_name.json` are set up properly, we can go ahead and package the model and start serving it. The artifacts related to each operation mode (such as `sample_text.txt` and `index_to_name.json`) can be placed in their respective folders.

In [None]:
# !torch-model-archiver 
#    --model-name "bert" \
#    --version 1.0 \
#    --serialized-file ./bert_model/pytorch_model.bin \
#    --extra-files "./bert_model/config.json" \
#    --handler "./transformers_classifier_torchserve_handler.py"

In [None]:
model_name = 'DistilBertForSequenceClassification'

In [None]:
# Package the serialized weights, the generalized Transformer handler and the extra
# files (model config, setup config, index_to_name mapping) into an archive named
# after model_name.
!torch-model-archiver \
    --model-name $model_name \
    --version 1.0 \
    --serialized-file Transformer_model/pytorch_model.bin \
    --handler ./src_torchserve/Transformer_handler_generalized.py \
    --extra-files "./Transformer_model/config.json,./src_torchserve/setup_config.json,./src_torchserve/Seq_classification_artifacts/index_to_name.json"

In [None]:
!ls ./*.mar

# Registering the Model on TorchServe and Running Inference

To register the model on TorchServe using the above model archive file, we run the following commands:

In [None]:
!mkdir -p ./model_store

In [None]:
!mv ./DistilBertForSequenceClassification.mar ./model_store/

# TorchServe requires Java 11, which is not installed by default on SageMaker Notebook Instances
https://tecadmin.net/install-java-on-amazon-linux/

In [None]:
# %%bash

# sudo amazon-linux-extras install java-openjdk11

In [None]:
# %%bash 

# torchserve \
# --start \
# --model-store ./model_store \
# --models distilbert-pytorch=DistilBertForSequenceClassification.mar &

## To run the inference using our registered model, open a new terminal and run: 

In [None]:
# !curl -X POST http://127.0.0.1:8080/predictions/distilbert-pytorch -T ./src_torchserve/Seq_classification_artifacts/sample_text.txt

# Prepare the Model for SageMaker Deployment

## Upload .mar to S3

In [None]:
torchserve_mar = 'DistilBertForSequenceClassification.mar'

In [None]:
# S3 destination for the .mar archive inside the session's default bucket.
s3_torchserve_mar = f's3://{bucket}/models/torchserve/{torchserve_mar}'
print(s3_torchserve_mar)

In [None]:
!aws s3 cp ./model_store/$torchserve_mar $s3_torchserve_mar

In [None]:
%store s3_torchserve_mar

In [None]:
# Bundle the .mar into a gzipped tarball; it is uploaded to S3 below and later passed
# to SageMaker as model_data.
# NOTE(review): the archive keeps the ./model_store/ prefix, so after extraction the
# .mar sits inside a model_store/ sub-directory — confirm this matches the
# --model-store path the container starts TorchServe with.
!tar cvfz ./DistilBertForSequenceClassification.tar.gz \
    ./model_store/DistilBertForSequenceClassification.mar


In [None]:
# S3 destination for the gzipped tarball that wraps the .mar archive.
s3_torchserve_tar = f's3://{bucket}/models/torchserve/DistilBertForSequenceClassification.tar.gz'

In [None]:
!aws s3 cp ./DistilBertForSequenceClassification.tar.gz $s3_torchserve_tar

In [None]:
%store s3_torchserve_tar

### Create an Amazon ECR registry
Create a new docker container registry for your torchserve container images.

In [None]:
registry_name = 'torchserve'
!aws ecr create-repository --repository-name {registry_name}

### Build a TorchServe Docker container and push it to Amazon ECR

In [None]:
# Fully qualified ECR image URI: <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>
image_label = 'v1'
ecr_host = f'{account_id}.dkr.ecr.{region}.amazonaws.com'
image = f'{ecr_host}/{registry_name}:{image_label}'

In [None]:
!docker build -t {registry_name}:{image_label} -f ./src_torchserve/Dockerfile ./src_torchserve
!$(aws ecr get-login --no-include-email --region {region})
!docker tag {registry_name}:{image_label} {image}
!docker push {image}

### Deploy endpoint and make prediction using Amazon SageMaker SDK

In [None]:
print(s3_torchserve_tar)

In [None]:
from sagemaker.model import Model
from sagemaker.predictor import RealTimePredictor

# NOTE(review): `RealTimePredictor` and the `image=` keyword look like SageMaker
# Python SDK v1 names (v2 uses `Predictor` and `image_uri=`) — confirm which SDK
# version this notebook runs against.
sm_model_name = 'distilbert-pytorch'

# Describe the SageMaker model: the tarball uploaded to S3 as model data, served by
# the custom image, under the notebook's execution role.
torchserve_model = Model(model_data = s3_torchserve_tar, 
                         image = image,
                         role  = role,
                         predictor_cls=RealTimePredictor,
                         name  = sm_model_name)

In [None]:
import time

# Suffix the endpoint name with the current UTC timestamp.
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
endpoint_name = 'torchserve-endpoint-' + timestamp
print(endpoint_name)

# Deploy the model to a single ml.c5.4xlarge instance and keep the returned predictor.
predictor = torchserve_model.deploy(
    initial_instance_count=1,
    instance_type='ml.c5.4xlarge',
    endpoint_name=endpoint_name,
)

In [None]:
print(endpoint_name)

# _Wait Until the ^^ Endpoint ^^ is Deployed_

## Test the TorchServe hosted model

In [None]:
# Show the sample input text.
# NOTE(review): the local curl example in this notebook reads sample_text.txt from
# ./src_torchserve/Seq_classification_artifacts/ — confirm which location is correct.
!cat ./src_torchserve/sample_text.txt

In [None]:
# file_name = './src_torchserve/sample_text.txt'
# with open(file_name, 'rb') as f:
#    payload = f.read()
#    payload = payload
   
# response = predictor.predict(data=payload)
# print(json.loads(response), sep = '\n')

In [None]:
import json
    
# reviews = ["This is great!", 
#            "This is terrible."]

# Send one review to the deployed endpoint; the response is decoded as UTF-8
# before it is printed.
predicted_classes = predictor.predict("This is great!")
print(predicted_classes.decode('utf-8'))

In [None]:
# Sample reviews to score. They are defined here because the earlier definition is
# commented out, which left `reviews` undefined.
reviews = ["This is great!",
           "This is terrible."]

# Request one prediction per review. The endpoint is called with a single text at a
# time, and zipping over one response would iterate over its byte values rather
# than over per-review predictions.
for review in reviews:
    predicted_class = predictor.predict(review).decode('utf-8')
    print('[Predicted Star Rating: {}]'.format(predicted_class), review)