## Deploy Pixtral 12b on SageMaker using LMI v12 for Realtime inference

This notebook demonstrates how to deploy the gated [Pixtral-12b](https://huggingface.co/mistralai/Pixtral-12B-2409) model using the LMI v12 container. The deployment uses [DJL Serving](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-models-frameworks-djl-serving.html) with LMI v12 (0.30.0) or later.

The model is deployed to a SageMaker endpoint for real-time inference on an ml.g5.12xlarge instance.

Note: This model is not stored in the typical HuggingFace pretrained format, so additional configuration is required to deploy it successfully. Community versions of this model have been converted into the HuggingFace pretrained format, but those versions are not compatible with LMI v12 because they are not compatible with vLLM.

If you have fine-tuned this model and saved the artifacts in the HuggingFace pretrained format, you will need to convert the artifacts back into the mistral format. You can read more about that process in this discussion: https://huggingface.co/mistral-community/pixtral-12b/discussions/4.

### Install Required dependencies

In [None]:
%pip install -Uq sagemaker boto3

## Create the SageMaker model object

We will initialize the variables that are used throughout the deployment.

In [None]:
import sagemaker
import boto3

sess = sagemaker.Session() # sagemaker session for interacting with different AWS APIs

sagemaker_session_bucket = None # bucket to house artifacts
if sagemaker_session_bucket is None and sess is not None:
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role() # execution role for the endpoint
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
region = sess.boto_region_name
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {region}")

The LMI container for djl-inference supports deploying popular vision language models such as Pixtral with minimal setup. Refer to the documentation for [Vision Language Models in LMI](https://github.com/deepjavalibrary/djl-serving/blob/master/serving/docs/lmi/user_guides/vision_language_models.md) for more details. The `mistralai/Pixtral-12B-2409` model is downloaded from HuggingFace and is a gated model.

To access gated models from HuggingFace, you need an access token, which is set in the `HF_TOKEN` environment variable in the following code. Refer to [Accessing Private/Gated Models](https://huggingface.co/docs/transformers.js/en/guides/private) to obtain the access token.
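
Hard-coding the token in a notebook cell makes it easy to leak through version control or shared notebooks. A minimal sketch of one alternative, assuming the token has already been exported as an `HF_TOKEN` environment variable in the notebook kernel (the helper name `get_hf_token` is hypothetical and not part of the SageMaker SDK):

```python
import os


def get_hf_token(var_name: str = "HF_TOKEN") -> str:
    """Return the HuggingFace token from the environment, or fail loudly."""
    token = os.environ.get(var_name, "").strip()
    if not token:
        raise RuntimeError(
            f"Environment variable {var_name} is not set; "
            "export your HuggingFace access token before deploying."
        )
    return token
```

The returned value could then be passed as `"HF_TOKEN": get_hf_token()` in the `env` dictionary of the model definition.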

In [None]:
from sagemaker.djl_inference import DJLModel

image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.30.0-lmi12.0.0-cu124"

# You can also obtain the image_uri programmatically as follows.
# from sagemaker import image_uris
# image_uri = image_uris.retrieve(framework="djl-lmi", version="0.30.0", region=region)

model = DJLModel(
    role=role,
    image_uri=image_uri,
    env={
        "HF_MODEL_ID": "mistralai/Pixtral-12B-2409",
        "HF_TOKEN": "<REPLACE_WITH_YOUR_HF_TOKEN>",  # required because mistralai/Pixtral-12B-2409 is a gated model
        "OPTION_ENGINE": "Python",
        "OPTION_MPI_MODE": "true",
        "OPTION_ROLLING_BATCH": "lmi-dist",
        "OPTION_MAX_MODEL_LEN": "8192", # this can be tuned depending on instance type + memory available
        "OPTION_MAX_ROLLING_BATCH_SIZE": "16", # this can be tuned depending on instance type + memory available
        "OPTION_TOKENIZER_MODE": "mistral",
        "OPTION_ENTRYPOINT": "djl_python.huggingface",
        "OPTION_TENSOR_PARALLEL_DEGREE": "max",
        "OPTION_LIMIT_MM_PER_PROMPT": "image=4", # this can be tuned to control how many images per prompt are allowed
    }
)

## Deploy the model

We will deploy the model by providing the necessary arguments. For the full list of deployment parameters on SageMaker, see https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy

In [None]:
predictor = model.deploy(instance_type="ml.g5.12xlarge", initial_instance_count=1)
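
Downloading and loading a 12B-parameter model can take several minutes, so it may help to name the endpoint explicitly and allow extra time for the container health check. The sketch below collects the `deploy` arguments in one place. `endpoint_name` and `container_startup_health_check_timeout` are parameters of `Model.deploy` in the SageMaker Python SDK; the helper `build_deploy_kwargs`, the name prefix, and the 900-second timeout are illustrative assumptions, not values from this notebook.

```python
import time


def build_deploy_kwargs(
    instance_type: str = "ml.g5.12xlarge",
    name_prefix: str = "pixtral-12b-lmi",  # hypothetical naming scheme
    health_check_timeout_s: int = 900,  # assumed value; tune for your setup
) -> dict:
    """Collect keyword arguments for model.deploy() in one place."""
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
    return {
        "instance_type": instance_type,
        "initial_instance_count": 1,
        "endpoint_name": f"{name_prefix}-{timestamp}",
        "container_startup_health_check_timeout": health_check_timeout_s,
    }


deploy_kwargs = build_deploy_kwargs()
print(deploy_kwargs)
# predictor = model.deploy(**deploy_kwargs)
```

A timestamped name keeps repeated deployments from colliding with an endpoint left over from an earlier run.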

## Test prompts

The following prompts demonstrate how to use the `mistralai/Pixtral-12B-2409` model for:
- Text only inference
- Single image inference
- Multi image inference
- Local image inference

For the multi image inference use-case, we use two images. However, the model is configured to accept up to 4 images in a single prompt. This setting can be tuned with the `OPTION_LIMIT_MM_PER_PROMPT` configuration.

In [None]:
IMAGE_1_KITTEN = "https://resources.djl.ai/images/kitten.jpg"
IMAGE_2_TRUCK = "https://resources.djl.ai/images/truck.jpg"

text_only_payload = {
    "messages": [
        {
            "role": "user",
            "content": "I would like to get better at basketball. Can you provide me a 3 month plan to improve my skills?"
        }
    ],
    "max_tokens": 1024,
    "temperature": 0.6,
    "top_p": 0.9,
}

single_image_payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Can you describe the following image and tell me what it contains?",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": IMAGE_1_KITTEN
                    }
                }
            ]
        }
    ],
    "max_tokens": 1024,
    "temperature": 0.6,
    "top_p": 0.9,
}

multi_image_payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Can you describe the following images and tell me what they have in common? If they have nothing in common, please explain why.",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": IMAGE_1_KITTEN
                    }
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": IMAGE_2_TRUCK
                    }
                }
            ]
        }
    ],
    "max_tokens": 1024,
    "temperature": 0.6,
    "top_p": 0.9,
}
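
The single-image and multi-image payloads above differ only in the number of `image_url` entries, so they can be generated by one function. A minimal sketch, assuming the same message schema as above; the helper name `build_image_payload` is hypothetical, and its `max_images` default mirrors the `image=4` limit set through `OPTION_LIMIT_MM_PER_PROMPT`:

```python
def build_image_payload(prompt, image_urls, max_images=4,
                        max_tokens=1024, temperature=0.6, top_p=0.9):
    """Build a chat payload with one text part followed by N image parts."""
    if len(image_urls) > max_images:
        raise ValueError(
            f"Got {len(image_urls)} images, but the endpoint is "
            f"configured for at most {max_images} per prompt."
        )
    content = [{"type": "text", "text": prompt}]
    content += [
        {"type": "image_url", "image_url": {"url": url}} for url in image_urls
    ]
    return {
        "messages": [{"role": "user", "content": content}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }


payload = build_image_payload(
    "What do these images have in common?",
    ["https://resources.djl.ai/images/kitten.jpg",
     "https://resources.djl.ai/images/truck.jpg"],
)
print(len(payload["messages"][0]["content"]))
```

Checking the image count on the client side gives a clearer error than waiting for the endpoint to reject the request.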

### Text Only Inference

In [None]:
print(f"Prompt is:\n {text_only_payload['messages'][0]['content']}")
text_only_output = predictor.predict(text_only_payload)
print("Response is:\n")
print(text_only_output['choices'][0]['message']['content'])
print('----------------------------')
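
Every test cell in this notebook reads the reply through `['choices'][0]['message']['content']`. A small helper can centralize that lookup and fail with a readable message when the response has an unexpected shape. This is a sketch: the function name `extract_reply` is hypothetical, and the `finish_reason` check assumes the response follows the OpenAI-style chat completions schema, which this notebook does not confirm.

```python
def extract_reply(response: dict) -> str:
    """Return the assistant text from a chat-completions style response."""
    choices = response.get("choices") or []
    if not choices:
        raise ValueError(f"Response contains no choices: {response!r}")
    choice = choices[0]
    # Assumption: "finish_reason" follows the OpenAI-style schema, where
    # "length" means generation stopped because max_tokens was reached.
    if choice.get("finish_reason") == "length":
        print("Warning: reply was truncated at max_tokens")
    return choice["message"]["content"]


# Hand-written sample mirroring the fields used above, not real model output.
sample = {
    "choices": [
        {"message": {"role": "assistant", "content": "Hello"},
         "finish_reason": "stop"}
    ]
}
print(extract_reply(sample))
```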

### Single Image Inference

In [None]:
from PIL import Image
import requests
from io import BytesIO

response_kitten = requests.get(IMAGE_1_KITTEN)
img_kitten = Image.open(BytesIO(response_kitten.content))

In [None]:
print("This is the image provided to the model")
img_kitten.show()
single_image_output = predictor.predict(single_image_payload)
print(single_image_output['choices'][0]['message']['content'])
print('----------------------------')

### Multi Image Inference

In [None]:
response_truck = requests.get(IMAGE_2_TRUCK)
img_truck = Image.open(BytesIO(response_truck.content))

In [None]:
print("These are the images provided to the model")
img_kitten.show()
img_truck.show()
multi_image_output = predictor.predict(multi_image_payload)
print(multi_image_output['choices'][0]['message']['content'])
print('----------------------------')
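
Applications that do not use the SageMaker Python SDK can send the same payloads through the boto3 `sagemaker-runtime` client. The sketch below wraps `invoke_endpoint`; the helper name `invoke_chat` is hypothetical, and it assumes the endpoint accepts and returns `application/json`, as the predictor calls above suggest.

```python
import json


def invoke_chat(runtime_client, endpoint_name: str, payload: dict) -> dict:
    """Send a chat payload with the low-level runtime client and decode the JSON reply."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps(payload),
    )
    # The response body is a stream; read it fully before parsing.
    return json.loads(response["Body"].read())


# Usage (requires AWS credentials and a running endpoint):
# import boto3
# runtime = boto3.client("sagemaker-runtime")
# output = invoke_chat(runtime, predictor.endpoint_name, multi_image_payload)
```

Passing the client in as an argument keeps the function easy to exercise with a stub when no endpoint is available.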

### Local image inference

This example demonstrates sending a local file as part of the input payload to the Pixtral model for inference. The input image contains the Amazon Q2 2024 earnings release, and we extract the financial data from it in a format of choice, such as JSON, CSV, or a markdown table. Here, we ask Pixtral to generate a markdown table that mirrors the table in the image.

In [None]:
import base64

local_file_path = 'Pixtral_data/AMZN-Q2-2024-Earnings-Release.jpg'
with open(local_file_path, "rb") as image_file:
    base64_encoded = base64.b64encode(image_file.read()).decode('utf-8')


local_image_payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Extract the text from the provided image. Represent in a single markdown table as the source image. Ensure you verify the values before responding.",
                    #"text": "Extract the text from the provided image. Try building a json response. Ensure you verify the values before responding.",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_encoded}"
                    }
                }
            ]
        }
    ],
    "max_tokens": 1024,
    "temperature": 0.0,
    "top_p": 0.9,
}
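
The encoding step above writes the MIME type into the data URL by hand. When the input may be a PNG or another format, the type can be inferred from the file extension with the standard library instead. A minimal sketch; the helper name `image_to_data_url` is hypothetical:

```python
import base64
import mimetypes
from pathlib import Path


def image_to_data_url(path) -> str:
    """Encode a local image file as a base64 data URL with a guessed MIME type."""
    mime_type, _ = mimetypes.guess_type(str(path))
    if mime_type is None or not mime_type.startswith("image/"):
        raise ValueError(f"Cannot determine an image MIME type for {path}")
    encoded = base64.b64encode(Path(path).read_bytes()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
```

The result can be used directly as the `url` value of an `image_url` entry, for example `{"url": image_to_data_url(local_file_path)}`.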

In [None]:
local_image_output = predictor.predict(local_image_payload)

In [None]:
print('Input image to the model')
Image.open(local_file_path).show()
print('----------------------------')
print('Response from the model\n\n')
print(local_image_output['choices'][0]['message']['content'])

In [None]:
# clean up resources
predictor.delete_endpoint()
model.delete_model()