# Serve GPT-J-6B on SageMaker with DJL Serving using the SageMaker Python SDK

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

---

## Background

This notebook shows how to use [DJL Serving](https://sagemaker.readthedocs.io/en/stable/frameworks/djl/using_djl.html) to deploy text generation models such as GPT-J-6B on SageMaker for real-time inference.


### Update and install required dependencies

In [None]:
%%bash
pip install -U pip --quiet
pip install -U sagemaker --quiet
pip install -U boto3 --quiet
pip install -U transformers --quiet
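
After upgrading, it can help to confirm which versions were actually installed, since the SDK interface used later may differ between releases. The following is a minimal sketch using only the standard library; the helper name `installed_versions` is illustrative and not part of the original notebook.

```python
from importlib import metadata


def installed_versions(packages):
    """Return a dict mapping each package name to its installed version.

    Packages that are not installed map to None instead of raising.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Report the packages upgraded in the cell above.
print(installed_versions(["sagemaker", "boto3", "transformers"]))
```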

### Configure instance type, execution role, and session

In [None]:
import sagemaker

# Replace with your own settings - we recommend the ml.g5.12xlarge instance for this notebook
instance_type = "ml.g5.12xlarge"
role = sagemaker.get_execution_role()  # execution role for the endpoint
session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = session.boto_region_name  # public attribute; avoids relying on the private _region_name

### Specifying the Model

With DJL Serving on SageMaker, you can provide the model artifacts in two ways:
1. A HuggingFace Hub model ID
2. Model artifacts saved in the HuggingFace pretrained format and uploaded to S3

For more details on the model artifact structure, please see the [DJL Serving SageMaker docs](https://sagemaker.readthedocs.io/en/stable/frameworks/djl/using_djl.html#model-artifacts).

We highly recommend the S3 option for production use cases. DJL Serving implements a fast downloading mechanism for models stored in S3, which can significantly reduce endpoint startup time compared to using a HuggingFace Hub model ID. For experimentation and smaller models, a HuggingFace Hub model ID is the easiest way to get started.

We will demonstrate both options in this notebook.

#### Using a HuggingFace Hub model ID

The EleutherAI/gpt-j-6b model is available from HuggingFace [here](https://huggingface.co/EleutherAI/gpt-j-6b).

In [None]:
model_id = "EleutherAI/gpt-j-6b"

#### Using a Pretrained Model stored in S3

The following code demonstrates how to save a model and upload it to S3 so that it can be downloaded and served by SageMaker using DJL Serving.

In [None]:
# from sagemaker.s3 import S3Uploader
# from transformers import AutoTokenizer, AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(model_id)
# model.save_pretrained("gpt-j-6B")

# tokenizer = AutoTokenizer.from_pretrained(model_id)
# tokenizer.save_pretrained("gpt-j-6B")

# bucket = session.default_bucket()      # bucket to house artifacts
# s3_location = f"s3://{bucket}/djl-serving/gpt-j-6B"
# S3Uploader.upload("gpt-j-6B", s3_location)

For demo purposes, we already have a copy of this model available in S3 that we will use.

In [None]:
pretrained_model_location = f"s3://sagemaker-examples-files-prod-{region}/models/gpt-j-6b-model/"
print(f"Pretrained model will be downloaded from ---- > {pretrained_model_location}")
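
Because the same argument can hold either a Hub model ID or an S3 URI, a quick sanity check on the string can catch typos before a slow deployment starts. The sketch below is only an illustration: `split_s3_uri` and `model_source_kind` are hypothetical helpers, and the region value is an assumed example rather than one read from your session.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split an S3 URI into (bucket, key_prefix); raise ValueError otherwise."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def model_source_kind(source):
    """Classify a model source string as 's3' or 'hub' by its prefix."""
    return "s3" if source.startswith("s3://") else "hub"


example_region = "us-west-2"  # assumed value for illustration only
location = f"s3://sagemaker-examples-files-prod-{example_region}/models/gpt-j-6b-model/"

bucket, prefix = split_s3_uri(location)
print(model_source_kind(location), bucket, prefix)
print(model_source_kind("EleutherAI/gpt-j-6b"))
```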

### Deploy the model to SageMaker

The following code is all that is needed to host this model on SageMaker. We will use the `pretrained_model_location` pointing to artifacts in S3 to reduce the container startup time, but feel free to experiment with using `model_id` directly.

In [None]:
from sagemaker.djl_inference import DJLModel

model = DJLModel(
    pretrained_model_location,  # Can also use model_id here
    role,
    task="text-generation",
    number_of_partitions=2,
    data_type="fp16",
)

predictor = model.deploy(initial_instance_count=1, instance_type=instance_type)
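
To get a feel for why `data_type="fp16"` and `number_of_partitions=2` are reasonable here, a back-of-the-envelope estimate of weight memory is useful. The sketch below assumes roughly 6 billion parameters for GPT-J-6B and counts weights only; activations, the attention cache, and framework overhead are ignored, so real usage will be higher. The helper `weight_memory_gib` is illustrative.

```python
def weight_memory_gib(num_params, bytes_per_param, partitions=1):
    """Estimate weight memory per partition in GiB (weights only)."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / partitions / 2**30


GPT_J_PARAMS = 6_000_000_000  # approximate parameter count (assumption)

fp32_total = weight_memory_gib(GPT_J_PARAMS, 4)
fp16_total = weight_memory_gib(GPT_J_PARAMS, 2)
fp16_per_partition = weight_memory_gib(GPT_J_PARAMS, 2, partitions=2)

print(f"fp32 weights: {fp32_total:.1f} GiB")
print(f"fp16 weights: {fp16_total:.1f} GiB")
print(f"fp16 weights per partition (2 partitions): {fp16_per_partition:.1f} GiB")
```

For comparison, an ml.g5.12xlarge provides four NVIDIA A10G GPUs with 24 GB of memory each, so half-precision weights split across two partitions leave considerable headroom on each device.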

### Run inference using the endpoint

Once the endpoint is created and in service, we can send inference requests like this.

In [None]:
data = {
    "inputs": [
        "My favorite thing about Math is",
        "My least favorite thing about Math is",
    ],
    "parameters": {
        "max_length": 200,
        "temperature": 0.1,
    },
}
outputs = predictor.predict(data)
for output in outputs:
    print(output[0]["generated_text"])
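
The request and response handling above can be wrapped in two small helpers. This is a sketch under stated assumptions: `build_payload` and `extract_texts` are hypothetical names, and the response shape (a list with one entry per prompt, each a list of dicts with a `generated_text` key) is inferred from the loop above rather than from a documented contract. The stand-in response lets the code run without a live endpoint.

```python
def build_payload(prompts, **parameters):
    """Build a request dict with 'inputs' and 'parameters' keys."""
    if isinstance(prompts, str):
        prompts = [prompts]
    return {"inputs": list(prompts), "parameters": dict(parameters)}


def extract_texts(outputs):
    """Return the first generated text for each prompt.

    Assumes outputs looks like [[{"generated_text": ...}], ...].
    """
    return [candidates[0]["generated_text"] for candidates in outputs]


payload = build_payload(
    ["My favorite thing about Math is"],
    max_length=200,
    temperature=0.1,
)
print(payload)

# Stand-in for predictor.predict(payload); the text is a placeholder, not model output.
stand_in_outputs = [[{"generated_text": "<placeholder completion>"}]]
print(extract_texts(stand_in_outputs))
```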

### Clean up resources after testing

In [None]:
# Delete SageMaker endpoint and model
predictor.delete_endpoint()
model.delete_model()
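
If an earlier cell fails, the cleanup cell may never run and the endpoint keeps incurring charges. One way to guard against that is a context manager that always performs the two deletions shown above. This is a minimal sketch: `endpoint_cleanup` is a hypothetical helper, and the stub classes exist only so the example runs without AWS access.

```python
from contextlib import contextmanager


@contextmanager
def endpoint_cleanup(model, predictor):
    """Yield the predictor, then delete the endpoint and model on exit."""
    try:
        yield predictor
    finally:
        try:
            predictor.delete_endpoint()
        finally:
            # Runs even if endpoint deletion raised.
            model.delete_model()


# Stand-in objects so the sketch runs without AWS access.
class _StubPredictor:
    def delete_endpoint(self):
        print("endpoint deleted")


class _StubModel:
    def delete_model(self):
        print("model deleted")


with endpoint_cleanup(_StubModel(), _StubPredictor()) as p:
    print("running inference with", type(p).__name__)
```

With the real objects, the `with` body would contain the `predictor.predict(...)` calls, and both resources would be deleted even if one of those calls raised.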

### Extensions

In this notebook we demonstrated how to deploy the GPT-J-6B model for text generation on SageMaker. DJL Serving supports a wide variety of models and NLP tasks. This notebook is a great starting point to experiment with other models and NLP tasks.

For example, we can host the [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) model for translation by setting `model_id="google/flan-t5-xl"` and `task="text2text-generation"` in `DJLModel`.
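
When experimenting with several models, keeping the model ID and task together in one place avoids mismatched pairs. The sketch below is illustrative only: `PRESETS` and `djl_model_kwargs` are hypothetical names, and any extra settings such as partition count or data type are left to the caller because suitable values for other models are not covered in this notebook.

```python
# Model ID and task pairs mentioned in this notebook.
PRESETS = {
    "gpt-j-6b": {"model_id": "EleutherAI/gpt-j-6b", "task": "text-generation"},
    "flan-t5-xl": {"model_id": "google/flan-t5-xl", "task": "text2text-generation"},
}


def djl_model_kwargs(preset, **overrides):
    """Return a copy of the preset settings with caller overrides applied."""
    if preset not in PRESETS:
        raise KeyError(f"Unknown preset {preset!r}; choose from {sorted(PRESETS)}")
    settings = dict(PRESETS[preset])
    settings.update(overrides)
    return settings


settings = djl_model_kwargs("flan-t5-xl")
print(settings)

# Mirrors the call style used earlier in this notebook (not executed here):
# model = DJLModel(settings.pop("model_id"), role, **settings)
```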

For more information on using DJL Serving on SageMaker with the Python SDK, please see our [documentation](https://sagemaker.readthedocs.io/en/stable/frameworks/djl/index.html).

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/inference|generativeai|deepspeed|GPT-J-6B_DJLServing_with_PySDK.ipynb)
