# Deploy Mistral Small 3.1 on Amazon SageMaker with LMI

**Recommended kernel(s):** This notebook can be run with any Amazon SageMaker Studio kernel.

In this notebook, you will learn how to deploy the Mistral Small 3.1 24B instruct model (Hugging Face model ID: [mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)) using Amazon SageMaker. The inference image is the SageMaker-managed [LMI (Large Model Inference)](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-container-docs.html) Docker image. LMI images feature a [DJL Serving](https://github.com/deepjavalibrary/djl-serving) stack powered by the [Deep Java Library](https://djl.ai/).

## Mistral Small 3.1: State-of-the-Art in a Compact Package

Mistral Small 3.1 is the latest model from Mistral AI, featuring improved text performance, multimodal capabilities, and an expanded 128K token context window. With just 24 billion parameters, this model achieves state-of-the-art performance while being compact enough to run on a single RTX 4090 or a Mac with 32GB RAM when quantized.

### Key Features and Capabilities

- **Top-Tier Performance**: Outperforms comparable models like Gemma 2 27B and GPT-4o Mini across multiple benchmarks including MMLU, HumanEval, and various multimodal tasks
- **Multimodal Understanding**: Handles both text and image inputs with advanced vision capabilities for tasks such as document verification, diagnostics, and visual inspection
- **Long Context Window**: Supports up to 128K tokens for processing lengthy documents and long conversations
- **Multilingual Support**: Fluent in dozens of languages including English, French, German, Chinese, Arabic, and many more, making it suitable for global applications
- **Efficiency**: Mistral AI reports inference speeds of about 150 tokens per second, with far lower computational requirements than larger proprietary models
- **Open Source**: Released under the Apache 2.0 license for both research and commercial use, offering flexibility for developers and enterprises

![Performance Comparison](imgs/image-mistral.webp)

### Use Cases

Mistral Small 3.1 is well-suited for a variety of applications, including:

- Fast-response conversational assistants
- Low-latency function calling in automated workflows
- Domain-specific expert systems via fine-tuning
- Programming and math reasoning
- Document analysis and processing
- Image understanding and analysis

### License agreement
* This model is available under the Apache 2.0 license, as detailed on the original [model card](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503).
* This notebook is a sample notebook and not intended for production use.

### Execution environment setup
This notebook requires the following third-party Python dependencies:
* AWS [`sagemaker`](https://sagemaker.readthedocs.io/en/stable/index.html) with a version greater than or equal to 2.242.0

Let's install or upgrade these dependencies using the following command:

In [None]:
%pip install -Uq "sagemaker>=2.242.0"

### Setup

In [None]:
import sagemaker
import boto3
from sagemaker.session import Session
import logging
from sagemaker.s3 import S3Uploader

print(sagemaker.__version__)

In [None]:
try:
    role = sagemaker.get_execution_role()
    sagemaker_session = sagemaker.Session()
    
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

In [None]:
HF_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

base_name = HF_MODEL_ID.split('/')[-1].replace('.', '-').lower()
model_lineage = HF_MODEL_ID.split("/")[0]
base_name

## Configure Model Serving Properties

Now we'll create a `serving.properties` file that configures how the model will be served. This configuration is crucial for optimal performance and memory utilization.

Key configurations explained:
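- **Engine**: Python backend for model serving
- **Model Settings**:
  - Mistral Small 3.1 24B Instruct model pulled from Hugging Face (`option.model_id`)
  - Maximum sequence length (`option.max_model_len`), which bounds the KV cache size
  - Mistral-native tokenizer, config, and weight formats (`option.tokenizer_mode`, `option.config_format`, `option.load_format`)
  - Mistral tool-call parser with automatic tool choice (`option.tool_call_parser`, `option.enable_auto_tool_choice`)
- **Performance Optimizations**:
  - Tensor parallelism across all available GPUs (`option.tensor_parallel_degree=max`)
  - GPU memory utilization target of 87% (`option.gpu_memory_utilization`)
  - vLLM rolling batch with up to 16 concurrent requests (`option.rolling_batch`, `option.max_rolling_batch_size`)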
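LMI also accepts each `option.<name>` entry as an environment variable named `OPTION_<NAME>`; this notebook relies on that later when it passes `OPTION_LIMIT_MM_PER_PROMPT` in the model's `env` instead of the commented-out `option.limit_mm_per_prompt` line. The helper below is a hypothetical sketch (not part of LMI) that applies this naming convention to a `serving.properties` text, assuming simple `key=value` lines and skipping non-`option.` keys such as `engine`.

```python
def serving_properties_to_env(text: str) -> dict:
    """Convert `option.*` lines from serving.properties text into env vars.

    Assumes the LMI convention option.<name> -> OPTION_<NAME>. Keys that do
    not start with "option." (such as `engine`) are skipped.
    """
    env = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        # Skip blank lines and comments
        if not line or line.startswith("#"):
            continue
        # Split on the first "=" only, so values like "image=4" stay intact
        key, sep, value = line.partition("=")
        if not sep or not key.startswith("option."):
            continue
        name = key[len("option."):].strip()
        env["OPTION_" + name.upper()] = value.strip()
    return env


example = """engine=Python
option.tensor_parallel_degree=max
option.max_model_len=8192
#option.limit_mm_per_prompt='image=4'
"""
print(serving_properties_to_env(example))
```

Commented-out lines are ignored, so the multimodal limit has to be supplied some other way, which is what the `env` setting later in the notebook does.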

### Understanding KV Cache and Context Window

The `max_model_len` parameter controls the maximum sequence length the model can handle, which directly affects the size of the KV (Key-Value) cache in GPU memory.

A practical way to size it:

1. Start with a conservative value (current: 8192)
2. Monitor GPU memory usage
3. Increase it incrementally if memory permits
4. Work toward the model's full context window (up to 128K tokens) if your use case needs it

Although Mistral Small 3.1 supports a context window of up to 128K tokens, as stated in Mistral AI's official blog post, we start with a smaller value for an efficient deployment; adjust it to your requirements and available resources.
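To see why `max_model_len` matters, the KV cache for one token stores a key and a value vector for every layer and KV head. The sketch below does that arithmetic. The architecture numbers (40 layers, 8 KV heads, head dimension 128, 2-byte bf16 values) are assumptions to check against the model's `config.json`, not values taken from this notebook.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    # Factor 2 = one key vector and one value vector per layer and KV head
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


# ASSUMED architecture values; verify them in the model's config.json
NUM_LAYERS = 40
NUM_KV_HEADS = 8
HEAD_DIM = 128

per_token = kv_cache_bytes_per_token(NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM)
print(f"KV cache per token: {per_token / 1024:.0f} KiB")
for max_model_len in (8192, 32768, 131072):
    gib = per_token * max_model_len / 1024**3
    print(f"max_model_len={max_model_len:>6}: {gib:.2f} GiB per full-length sequence")
```

Under these assumptions a full 8,192-token sequence needs 1.25 GiB of KV cache and a full 128K-token sequence needs 20 GiB, so 16 concurrent full-length sequences (`option.max_rolling_batch_size=16`) at 8,192 tokens could need up to 16 × 1.25 = 20 GiB in total. With tensor parallelism the KV heads are split across GPUs, so each GPU holds roughly its share of that total.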

In [None]:
# Create the directory that will contain the files
from pathlib import Path

model_dir = Path('code')
model_dir.mkdir(exist_ok=True)

In [None]:
%%writefile code/serving.properties
engine=Python
option.tensor_parallel_degree=max
option.gpu_memory_utilization=.87
option.model_id=mistralai/Mistral-Small-3.1-24B-Instruct-2503
option.rolling_batch=vllm
option.max_model_len=8192
option.tool_call_parser=mistral
option.enable_auto_tool_choice=true
option.trust_remote_code=true
option.max_rolling_batch_size=16
option.tokenizer_mode=mistral
option.config_format=mistral
option.load_format=mistral
#option.limit_mm_per_prompt='image=4'

## Configure vLLM Requirements

The `requirements.txt` file pins the vLLM and transformers versions that are installed into the LMI container at startup. vLLM is the inference framework that provides the optimized serving capabilities.

### Version Considerations
- **vLLM 0.8.1**: Required for the Mistral Small 3.1 model
- **transformers 4.50.0**: The transformers release pinned alongside vLLM 0.8.1

### Performance Impact
Different vLLM versions can affect:
- Inference speed
- Memory utilization
- Batch processing efficiency
- Compatibility with other libraries

### API Considerations

In newer vLLM versions, the engine's `add_request` method takes the prompt as `prompt=` instead of `inputs=`, so we need to patch DJL's default implementation of the `VLLMRollingBatch.inference` method to make the following call:

```python
self.engine.add_request(request_id, prompt=prompt_inputs, params=sampling_params, **request_params)
```
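As an alternative to hard-coding the new keyword, the call could pick the keyword from the installed engine's signature. The helper below is a hypothetical sketch (not part of DJL or vLLM) that uses `inspect.signature` to choose between `prompt=` and the older `inputs=` that the default DJL code assumed; it is demonstrated with two fake engine classes so it runs without vLLM.

```python
import inspect


def add_request_compat(engine, request_id, prompt_inputs, sampling_params, **request_params):
    """Call engine.add_request with whichever prompt keyword it accepts.

    Hypothetical helper: prefers `prompt=` (newer vLLM) and falls back to
    `inputs=` (the keyword the default DJL implementation used).
    """
    params = inspect.signature(engine.add_request).parameters
    prompt_kw = "prompt" if "prompt" in params else "inputs"
    return engine.add_request(request_id, **{prompt_kw: prompt_inputs},
                              params=sampling_params, **request_params)


# Fake engines standing in for old and new vLLM APIs
class OldEngine:
    def add_request(self, request_id, inputs, params, **kwargs):
        return ("inputs", request_id, inputs)


class NewEngine:
    def add_request(self, request_id, prompt, params, **kwargs):
        return ("prompt", request_id, prompt)


print(add_request_compat(OldEngine(), "req-1", "Hello", None))
print(add_request_compat(NewEngine(), "req-2", "Hello", None))
```

This trades a little introspection overhead per request for not having to edit the patch when the vLLM version changes; the notebook's patch below keeps the simpler fixed call.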

The custom `model.py` below applies this patch. The next two cells write `requirements.txt` and `model.py` to the `code` directory:

In [None]:
%%writefile code/requirements.txt
transformers==4.50.0
vllm==0.8.1

In [None]:
%%writefile code/model.py
#!/usr/bin/env python

from djl_python.huggingface import HuggingFaceService
from djl_python.inputs import Input
from djl_python.rolling_batch.vllm_rolling_batch import VLLMRollingBatch
from djl_python.rolling_batch.rolling_batch_vllm_utils import (
    update_request_cache_with_output, get_lora_request, get_prompt_inputs)
import types
import logging

# Create the service
_service = HuggingFaceService()

# Define the patched inference method for VLLMRollingBatch
def patched_inference(self, new_requests):
    """
    Patched version of the inference method that works with vLLM 0.8.1
    """
    # Import vLLM classes inside the function so they are resolved at call time
    from vllm import SamplingParams
    from vllm.sampling_params import RequestOutputKind
    from vllm.utils import random_uuid

    self.add_new_requests(new_requests)
    # step 0: register new requests to engine
    for request in new_requests:
        request_id = random_uuid()
        # Chat completions request route
        if request.parameters.get("sampling_params") is not None:
            prompt_inputs = request.parameters.get("engine_prompt")
            sampling_params = request.parameters.get("sampling_params")
            sampling_params.output_kind = RequestOutputKind.DELTA
        # LMI request route
        else:
            prompt_inputs = get_prompt_inputs(request)
            params = self.translate_vllm_params(request.parameters)
            sampling_params = SamplingParams(**params)
        request_params = dict()
        if request.adapter is not None:
            adapter_name = request.adapter.get_property("name")
            request_params["lora_request"] = get_lora_request(
                adapter_name, self.lora_requests)
        
        # This is the key change: using the new API format for add_request
        # Changed from:
        # self.engine.add_request(request_id=request_id, inputs=prompt_inputs, params=sampling_params, **request_params)
        # To:
        self.engine.add_request(request_id, prompt=prompt_inputs, params=sampling_params, **request_params)
        
        self.request_cache[request_id] = {
            "request_output": request.request_output
        }
    request_outputs = self.engine.step()

    # step 1: put result to cache and request_output
    for request_output in request_outputs:
        self.request_cache = update_request_cache_with_output(
            self.request_cache, request_output, self.get_tokenizer())

    for request in self.active_requests:
        request_output = request.request_output
        if request_output.finished:
            request.last_token = True

    return self.postprocess_results()

def handle(inputs: Input):
    """
    Default handler function
    """
    if not _service.initialized:
        # Apply the monkey patch to VLLMRollingBatch.inference
        try:
            # Import vLLM up front: if it is missing or broken, the exception
            # is logged below and the default (unpatched) inference is kept
            import vllm  # noqa: F401
            
            # Patch for vLLM 0.8.1
            logging.info("Patching VLLMRollingBatch.inference for vLLM 0.8.1")
            VLLMRollingBatch.inference = patched_inference
            logging.info("Successfully patched VLLMRollingBatch.inference")
        except Exception as e:
            logging.error(f"Failed to patch VLLMRollingBatch.inference: {e}")
        
        # Initialize the service
        props = inputs.get_properties()
        _service.initialize(props)

    if inputs.is_empty():
        # initialization request
        return None

    return _service.inference(inputs)


## Upload Uncompressed Artifacts to S3
SageMaker can load [uncompressed model artifacts](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-uncompressed.html) directly from S3, so we upload the folder that contains `model.py`, `serving.properties`, and `requirements.txt` as-is, without packaging it into a `model.tar.gz` archive.

This process:
1. Determines the S3 bucket location (using SageMaker default bucket)
2. Defines a prefix path for organization
3. Uploads the local `code` directory as individual files

> **Note**: The default SageMaker bucket follows the naming pattern: `sagemaker-{region}-{account-id}`
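Before uploading, it can help to check exactly which objects will land in S3. The sketch below assumes that `S3Uploader.upload` keeps each file's path relative to the uploaded folder under the destination prefix; `default_bucket_name` and `planned_s3_uris` are hypothetical helpers, and `111122223333` is a placeholder account ID. Note that hidden folders such as `.ipynb_checkpoints` inside `code` would be uploaded too.

```python
from pathlib import Path


def default_bucket_name(region: str, account_id: str) -> str:
    # Naming pattern from the note above: sagemaker-{region}-{account-id}
    return f"sagemaker-{region}-{account_id}"


def planned_s3_uris(local_dir, bucket: str, prefix: str) -> list:
    """List the S3 URIs that files under local_dir would be uploaded to."""
    root = Path(local_dir)
    return sorted(
        f"s3://{bucket}/{prefix}/{path.relative_to(root).as_posix()}"
        for path in root.rglob("*")  # includes hidden files and folders
        if path.is_file()
    )


bucket = default_bucket_name("us-east-1", "111122223333")  # placeholder account
print(bucket)
if Path("code").exists():
    for uri in planned_s3_uris("code", bucket, "lmi/mistral-small-3-1-24b-instruct-2503/code"):
        print(uri)
```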

In [None]:
from sagemaker.s3 import S3Uploader

sagemaker_default_bucket = sagemaker_session.default_bucket()

code_model_uri = S3Uploader.upload(
    local_path="code",
    desired_s3_uri=f"s3://{sagemaker_default_bucket}/lmi/{base_name}/code"
)

print(f"code_model_uri: {code_model_uri}")

## Configure Model Container and Instance

For deploying Mistral Small 3.1 24B Instruct, we'll use:
- **LMI (Large Model Inference) container**: a DJL Serving-based container optimized for large language model inference
- **G6 or G6e instance**: GPU instance types with NVIDIA L4 (G6) or NVIDIA L40S (G6e) GPUs; the `ml.g6.12xlarge` used below has four L4 GPUs with 24 GB of memory each

Key configurations:
- The container URI points to the LMI (DJL) inference container in Amazon ECR (Elastic Container Registry)
- We use a GPU instance that offers:
  - Sufficient NVIDIA GPUs for tensor parallelism
  - Adequate GPU memory for model weights and KV cache
  - High network bandwidth
  - Sufficient system memory

> **Note**: The cell below fills in the region from your SageMaker session. The region in the container URI must match the AWS region you deploy to.
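A rough memory budget shows why a multi-GPU instance is used. This sketch assumes about 24B parameters stored in bf16 (2 bytes each) and ignores activation and CUDA-graph overheads, so the real headroom for the KV cache is smaller than what it prints; `kv_headroom_gb` is a hypothetical helper.

```python
def kv_headroom_gb(num_gpus, gpu_memory_gb, utilization, num_params, bytes_per_param=2):
    """Approximate memory (GB) left for KV cache after loading the weights."""
    usable_gb = num_gpus * gpu_memory_gb * utilization  # what vLLM may allocate
    weights_gb = num_params * bytes_per_param / 1e9     # bf16 weights
    return usable_gb - weights_gb


# ml.g6.12xlarge: 4x NVIDIA L4 with 24 GB each; 0.87 = option.gpu_memory_utilization
headroom = kv_headroom_gb(num_gpus=4, gpu_memory_gb=24, utilization=0.87, num_params=24e9)
print(f"Approximate KV-cache headroom on ml.g6.12xlarge: {headroom:.1f} GB")
```

Here 96 GB × 0.87 ≈ 83.5 GB is usable, of which about 48 GB goes to the weights. On a single 48 GB GPU the same formula gives 48 × 0.87 − 48 ≈ −6.2 GB, meaning the bf16 weights alone would not fit.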

In [None]:
gpu_instance_type = "ml.g6.12xlarge"

In [None]:
image_uri = "763104351884.dkr.ecr.{}.amazonaws.com/djl-inference:0.32.0-lmi14.0.0-cu126".format(sagemaker_session.boto_session.region_name)
print(image_uri)

## Create SageMaker Model

Now we'll create a SageMaker `Model` object. This step defines the model configuration but doesn't deploy it yet. The `Model` object combines:

1. **Container Image** (`image_uri`): the LMI container, optimized for LLMs
2. **Model Data** (`model_data`): our configuration files in S3
3. **IAM Role** (`role`): permissions for model execution

### Required Permissions
The IAM role needs:
- S3 read access for model artifacts
- CloudWatch permissions for logging
- ECR permissions to pull the container
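The permissions above can be expressed as an IAM policy document. The sketch below is a hypothetical minimal policy, built as a Python dict so it can be printed or passed to IAM tooling; the bucket name is a placeholder, and in practice you should narrow the `Resource` entries for logs and ECR to your account and region.

```python
import json


def minimal_inference_policy(bucket: str) -> dict:
    """Hypothetical minimal policy covering S3 read, CloudWatch Logs, and ECR pull."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {   # Read the uncompressed model artifacts
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:ListBucket"],
                "Resource": [f"arn:aws:s3:::{bucket}", f"arn:aws:s3:::{bucket}/*"],
            },
            {   # Write container logs to CloudWatch Logs
                "Effect": "Allow",
                "Action": ["logs:CreateLogGroup", "logs:CreateLogStream", "logs:PutLogEvents"],
                "Resource": "*",
            },
            {   # Pull the LMI container image from ECR
                "Effect": "Allow",
                "Action": ["ecr:GetAuthorizationToken", "ecr:BatchGetImage",
                           "ecr:GetDownloadUrlForLayer", "ecr:BatchCheckLayerAvailability"],
                "Resource": "*",
            },
        ],
    }


print(json.dumps(minimal_inference_policy("sagemaker-us-east-1-111122223333"), indent=2))
```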

#### HUGGING_FACE_HUB_TOKEN
Mistral Small 3.1 24B Instruct is hosted on Hugging Face. If the repository is gated (you must accept its terms on the model page) or private, provide a Hugging Face access token so the container can download the weights.

In [None]:
# Specify the S3 URI for your uncompressed code files 
model_data = {
    "S3DataSource": {
        "S3Uri": f"{code_model_uri}/",
        "S3DataType": "S3Prefix",
        "CompressionType": "None"
    }
}

In [None]:
HUGGING_FACE_HUB_TOKEN = "REPLACE WITH YOUR HF TOKEN"

In [None]:
from sagemaker.utils import name_from_base
from sagemaker.model import Model

model_name = name_from_base(base_name, short=True) + "-1"
print(model_name)

env = {
    "HF_TASK": "Image-Text-to-Text",
    "OPTION_LIMIT_MM_PER_PROMPT": "image=4",  # allow up to 4 images per prompt (multimodal)
}
# Only pass the Hugging Face token if a real one was provided
if HUGGING_FACE_HUB_TOKEN and HUGGING_FACE_HUB_TOKEN != "REPLACE WITH YOUR HF TOKEN":
    env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN

# Create model
mistral_3_model = Model(
    name=model_name,
    image_uri=image_uri,
    model_data=model_data,  # Path to uncompressed code files
    role=role,
    env=env,
)

## Deploy Model to SageMaker Endpoint

Now we'll deploy our model to a SageMaker endpoint for real-time inference. This is a significant step that:
1. Provisions the specified compute resources (G6 instance)
2. Deploys the model container
3. Sets up the endpoint for API access

### Deployment Configuration
- **Instance Count**: 1 instance for single-node deployment
- **Instance Type**: GPU instance for high-performance inference

![Accuracy Comparison](imgs/mistral-instruct-knowledge.webp)

> ⚠️ **Important**: 
> - Deployment will take 15-20 minutes
> - Monitor the CloudWatch logs for progress

In [None]:
%%time

from sagemaker.utils import name_from_base

endpoint_name = name_from_base(base_name, short=True)

mistral_3_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type=gpu_instance_type
)

### Use the code below to create a predictor for the endpoint (new or existing) and run inference

In [None]:
from sagemaker.serializers import JSONSerializer, IdentitySerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor

# endpoint_name = "mistral-small-3-1-24b-instruct-2503-250401-1938"  # uncomment and set to reuse an existing endpoint

small_3_predictor = Predictor(
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

## Text-only Inference

Mistral Small 3.1 handles standard text generation tasks with high quality. Mistral AI reports that, at 24B parameters, it matches or exceeds much larger models while reaching inference speeds of about 150 tokens per second; actual throughput on this endpoint depends on the instance type, batch size, and sequence lengths.
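The payloads in this notebook all follow the same OpenAI-style chat structure, and each response carries a `usage` block. The helpers below are hypothetical conveniences for building those payloads and summarizing responses; `tokens_per_second` measures end-to-end time (network and queueing included), so it understates the model's pure decoding speed.

```python
def build_chat_payload(text, image_urls=(), **generation_params):
    """Build a single-turn chat payload: optional image parts, then the text part."""
    content = [{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
    content.append({"type": "text", "text": text})
    payload = {"messages": [{"role": "user", "content": content}]}
    payload.update(generation_params)  # e.g. max_tokens, temperature, top_p
    return payload


def format_usage(response):
    """Summarize the token counts reported in a chat completion response."""
    usage = response["usage"]
    return (f"Prompt Tokens: {usage['prompt_tokens']}, "
            f"Completion Tokens: {usage['completion_tokens']}, "
            f"Total Tokens: {usage['total_tokens']}")


def tokens_per_second(response, elapsed_seconds):
    # End-to-end rate: includes network latency and server-side queueing
    return response["usage"]["completion_tokens"] / elapsed_seconds


payload = build_chat_payload("Write me a poem about Machine Learning.",
                             max_tokens=300, temperature=0.7, top_p=0.9)
print(payload)
```

With a live endpoint, wrapping `small_3_predictor.predict(payload)` between two `time.perf_counter()` calls and passing the difference to `tokens_per_second` gives a rough throughput figure to compare against the published number.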

In [None]:
%%time

payload = {
    "messages" : [
        {
            "role": "user",
            "content": [{
                "type": "text", 
                "text": "Write me a poem about Machine Learning."
            }]
        }
    ],
    "max_tokens":300,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 250
}

response = small_3_predictor.predict(payload)
print(response['choices'][0]['message']['content'])

# Print usage statistics
print("=== Token Usage ===")
usage = response['usage']
print(f"Prompt Tokens: {usage['prompt_tokens']}")
print(f"Completion Tokens: {usage['completion_tokens']}")
print(f"Total Tokens: {usage['total_tokens']}")

## Multimodality

Mistral Small 3.1 models are multimodal, handling both text and image input while generating text output. As described by Mistral AI, this multimodal capability makes the model well-suited for various image understanding tasks like document verification, diagnostics, visual inspection for quality checks, and object detection.
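The examples in this notebook pass images as public URLs. For local images, one option is a base64 data URL in the same `image_url` field; this assumes the LMI/vLLM chat endpoint in your container version accepts data URLs, which you should confirm. Keep in mind that SageMaker real-time invocations are limited to a 6 MB request payload. `image_bytes_to_data_url` is a hypothetical helper.

```python
import base64


def image_bytes_to_data_url(image_bytes: bytes, mime_type: str = "image/jpeg") -> str:
    """Encode raw image bytes as a data URL for an `image_url` content part."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


content_part = {
    "type": "image_url",
    "image_url": {"url": image_bytes_to_data_url(b"abc")},  # toy bytes for illustration
}
print(content_part)
```

For a real file, read it with `open(path, "rb").read()` and pass the bytes (and the matching MIME type, such as `image/png`) to the helper.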

Here's an example of how to query the model with an image:

In [None]:
from IPython.display import Image as IPyImage

IPyImage(url="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG")

In [None]:
import json

payload = {
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "image_url", 
          "image_url": {"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
        },
        {"type": "text", "text": "What animal is on the candy?"}
      ]
    }
  ]
}

response = small_3_predictor.predict(payload)
print(response['choices'][0]['message']['content'])

# Print usage statistics
print("=== Token Usage ===")
usage = response['usage']
print(f"Prompt Tokens: {usage['prompt_tokens']}")
print(f"Completion Tokens: {usage['completion_tokens']}")
print(f"Total Tokens: {usage['total_tokens']}")

In [None]:
IMAGE_1_KITTEN = "https://resources.djl.ai/images/kitten.jpg"
IMAGE_2_TRUCK = "https://resources.djl.ai/images/truck.jpg"


In [None]:
from PIL import Image
import requests
from io import BytesIO

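response_kitten = requests.get(IMAGE_1_KITTEN, timeout=30)
response_kitten.raise_for_status()  # fail early on HTTP errors
img_kitten = Image.open(BytesIO(response_kitten.content))

response_truck = requests.get(IMAGE_2_TRUCK, timeout=30)
response_truck.raise_for_status()
img_truck = Image.open(BytesIO(response_truck.content))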

In [None]:
multi_image_payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Can you describe the following images and tell me what they have in common? If they have nothing in common, please explain why.",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": IMAGE_1_KITTEN
                    }
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": IMAGE_2_TRUCK
                    }
                }
            ]
        }
    ],
    "max_tokens": 1024,
    "temperature": 0.6,
    "top_p": 0.9,
}
from IPython.display import display

print("These are the images provided to the model")
display(img_kitten)  # render inline; Image.show() opens an external viewer
display(img_truck)
multi_image_output = small_3_predictor.predict(multi_image_payload)
print(multi_image_output['choices'][0]['message']['content'])
print('----------------------------')

In [None]:
# Clean up
small_3_predictor.delete_model()
small_3_predictor.delete_endpoint(delete_endpoint_config=True)

## Conclusion

You've successfully deployed the Mistral Small 3.1 24B Instruct model on Amazon SageMaker and learned how to interact with it for both text and image inputs. This powerful model combines efficient performance with state-of-the-art capabilities in a relatively compact package.

Key takeaways:

1. Mistral Small 3.1 offers exceptional performance comparable to much larger models
2. The model supports both text and image inputs with its multimodal capabilities
3. SageMaker provides a robust platform for deploying and scaling inference
4. LMI also supports streaming responses for real-time interactive applications (not covered in this notebook)

For production deployments, consider:
- Adjusting instance types based on traffic patterns
- Implementing auto-scaling for cost optimization
- Fine-tuning the model for your specific domain
- Experimenting with different inference parameters for optimal user experience

For more information about Mistral Small 3.1, visit [Mistral AI's official page](https://mistral.ai/news/mistral-small-3-1/).

#### Distributed by:
- AWS
- Mistral