# How to deploy the Gemma 3 27B instruct model for inference using Amazon SageMaker AI
**Recommended kernel(s):** This notebook can be run with any Amazon SageMaker Studio kernel.

In this notebook, you will learn how to deploy the Gemma 3 27B instruct model (HuggingFace model ID: [google/gemma-3-27b-it](https://huggingface.co/google/gemma-3-27b-it)) using Amazon SageMaker AI. The inference image will be the SageMaker-managed [LMI (Large Model Inference)](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-container-docs.html) Docker image. LMI images feature a [DJL serving](https://github.com/deepjavalibrary/djl-serving) stack powered by the [Deep Java Library](https://djl.ai/).

Gemma 3 models are multimodal, handling text and image input and generating text output, with open weights for both pre-trained variants and instruction-tuned variants. Gemma 3 has a large, 128K context window, multilingual support in over 140 languages, and is available in more sizes than previous versions. Gemma 3 models are well-suited for a variety of text generation and image understanding tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as laptops, desktops or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.

### License agreement
* This model is gated on HuggingFace. Refer to the original [model card](https://huggingface.co/google/gemma-3-27b-it) for the license terms.
* This notebook is a sample notebook and not intended for production use.

### Execution environment setup
This notebook requires the following third-party Python dependencies:
* AWS [`sagemaker`](https://sagemaker.readthedocs.io/en/stable/index.html) with a version greater than or equal to 2.242.0

Let's install or upgrade these dependencies using the following command:

In [None]:
%pip install -Uq sagemaker

### Setup

In [None]:
import sagemaker
import boto3
from sagemaker.session import Session
import logging
from sagemaker.s3 import S3Uploader

print(sagemaker.__version__)
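
The printed version can be compared against the stated minimum of 2.242.0. The sketch below is one way to do that with the standard library only; the helper names `version_tuple` and `meets_minimum` are illustrative and not part of the `sagemaker` SDK.

```python
import re


def version_tuple(version: str) -> tuple:
    """Turn '2.242.0' into (2, 242, 0), stopping at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "2.242.0") -> bool:
    """Return True when the installed version is at least the minimum."""
    return version_tuple(installed) >= version_tuple(minimum)


# Example with literal version strings
print(meets_minimum("2.242.0"), meets_minimum("2.99.0"))
```

In the notebook, calling `meets_minimum(sagemaker.__version__)` after the import cell would flag an outdated installation before any AWS call is made.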

In [None]:
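# Create the session first so it exists even if the role lookup below fails
sagemaker_session = sagemaker.Session()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # Outside SageMaker (for example on a local machine): look the role up by name
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']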

In [None]:
HF_MODEL_ID = "google/gemma-3-27b-it"

base_name = HF_MODEL_ID.split('/')[-1].replace('.', '-').lower()
model_lineage = HF_MODEL_ID.split("/")[0]
base_name

## Configure Model Serving Properties

Now we'll create a `serving.properties` file that configures how the model will be served. This configuration is crucial for optimal performance and memory utilization.

Key configurations explained:
- **Engine**: Python backend for model serving
- **Model Settings**:
  - Using gemma-3-27b-it model from Hugging Face
  - Maximum sequence length of 32768 tokens
  - Model loading timeout of 1200 seconds (20 minutes)
- **Performance Optimizations**:
  - Tensor parallelism across all available GPUs
  - 87% GPU memory utilization target
  - vLLM rolling batch with max size of 16 for efficient batching
  
### Understanding KV Cache and Context Window

The `max_model_len` parameter controls the maximum sequence length the model can handle, which directly affects the size of the KV (Key-Value) cache in GPU memory. Because longer sequences need more cache memory, a practical way to tune this value is:

1. Start with a conservative value (current: 32768)
2. Monitor GPU memory usage
3. Incrementally increase if memory permits
4. Target the model's full context window 
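
To get a feel for why `max_model_len` matters, the sketch below estimates the KV cache needed by one full-length sequence using the standard formula for dense attention: keys plus values, per layer, per KV head. The layer count, head count and head dimension used here are illustrative placeholders, not confirmed Gemma 3 values; read the real ones from the model's `config.json`. Models that use sliding-window attention in some layers need less than this upper bound.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Upper-bound KV cache size in bytes for one sequence of seq_len tokens."""
    # Factor 2: one key tensor and one value tensor per layer.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value


# Hypothetical architecture values, with 2 bytes per value (16-bit precision)
estimate = kv_cache_bytes(num_layers=60, num_kv_heads=16, head_dim=128, seq_len=32768)
print(f"{estimate / 2**30:.1f} GiB per full-length sequence")
```

Multiplying the result by the number of concurrent sequences gives a rough idea of how much GPU memory must remain free after the model weights are loaded.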

In [None]:
# Create the directory that will contain the files
from pathlib import Path

model_dir = Path('code')
model_dir.mkdir(exist_ok=True)

In [None]:
%%writefile code/serving.properties
engine=Python
option.trust_remote_code=true
option.tensor_parallel_degree=max
option.gpu_memory_utilization=.87
option.model_loading_timeout=1200
option.max_model_len=32768
option.max_rolling_batch_size=16
option.rolling_batch=vllm
option.model_id=google/gemma-3-27b-it
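
Before uploading, it can help to sanity-check the values written to `serving.properties`. The sketch below parses simple `key=value` lines and checks two of the numeric options; `parse_properties` is an illustrative helper, not part of DJL Serving, and the sample text repeats only a few of the options above.

```python
def parse_properties(text: str) -> dict:
    """Parse simple key=value lines, skipping blank lines and '#' comments."""
    props = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition("=")
        props[key.strip()] = value.strip()
    return props


sample = """
engine=Python
option.gpu_memory_utilization=.87
option.max_model_len=32768
option.max_rolling_batch_size=16
"""

props = parse_properties(sample)
# The memory fraction must be in (0, 1] and the sequence length positive
assert 0.0 < float(props["option.gpu_memory_utilization"]) <= 1.0
assert int(props["option.max_model_len"]) > 0
print(props)
```

To check the real file, pass `Path("code/serving.properties").read_text()` instead of `sample`.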

## Configure vLLM Requirements

(Optional) The `requirements.txt` file specifies the vLLM version needed for model inference. The vLLM inference framework provides optimized serving capabilities.

### Version Considerations
- **vLLM 0.8.1**: Required for the Gemma 3 models
- **transformers 4.50.0**: Pinned as well, since we also update transformers

### Performance Impact
Different vLLM versions can affect:
- Inference speed
- Memory utilization
- Batch processing efficiency
- Compatibility with other libraries

### API Considerations

The vLLM API has changed in newer versions, so we need to patch the default DJL implementation of the `VLLMRollingBatch.inference` method so that it calls `add_request` as follows:

```self.engine.add_request(request_id, prompt=prompt_inputs, params=sampling_params, **request_params)```

We therefore provide a custom `model.py` file that contains the patch.

In [None]:
%%writefile code/requirements.txt
transformers==4.50.0
vllm==0.8.1

In [None]:
%%writefile code/model.py
#!/usr/bin/env python

from djl_python.huggingface import HuggingFaceService
from djl_python.inputs import Input
from djl_python.rolling_batch.vllm_rolling_batch import VLLMRollingBatch
from djl_python.rolling_batch.rolling_batch_vllm_utils import (
    update_request_cache_with_output, create_lora_request, get_lora_request,
    get_prompt_inputs)
import types
import logging

# Create the service
_service = HuggingFaceService()

# Define the patched inference method for VLLMRollingBatch
def patched_inference(self, new_requests):
    """
    Patched version of the inference method that works with vLLM 0.8.1
    """
    # Import necessary classes from vllm within the function to ensure they're available
    from vllm.sampling_params import RequestOutputKind
    from vllm import SamplingParams
    self.add_new_requests(new_requests)
    # step 0: register new requests to engine
    for request in new_requests:
        from vllm.utils import random_uuid
        request_id = random_uuid()
        # Chat completions request route
        if request.parameters.get("sampling_params") is not None:
            prompt_inputs = request.parameters.get("engine_prompt")
            sampling_params = request.parameters.get("sampling_params")
            sampling_params.output_kind = RequestOutputKind.DELTA
        # LMI request route
        else:
            prompt_inputs = get_prompt_inputs(request)
            params = self.translate_vllm_params(request.parameters)
            sampling_params = SamplingParams(**params)
        request_params = dict()
        if request.adapter is not None:
            adapter_name = request.adapter.get_property("name")
            request_params["lora_request"] = get_lora_request(
                adapter_name, self.lora_requests)
        
        # This is the key change: using the new API format for add_request
        # Changed from:
        # self.engine.add_request(request_id=request_id, inputs=prompt_inputs, params=sampling_params, **request_params)
        # To:
        self.engine.add_request(request_id, prompt=prompt_inputs, params=sampling_params, **request_params)
        
        self.request_cache[request_id] = {
            "request_output": request.request_output
        }
    request_outputs = self.engine.step()

    # step 1: put result to cache and request_output
    for request_output in request_outputs:
        self.request_cache = update_request_cache_with_output(
            self.request_cache, request_output, self.get_tokenizer())

    for request in self.active_requests:
        request_output = request.request_output
        if request_output.finished:
            request.last_token = True

    return self.postprocess_results()

def handle(inputs: Input):
    """
    Default handler function
    """
    if not _service.initialized:
        # Apply the monkey patch to VLLMRollingBatch.inference
        try:
            # Import the necessary modules from vllm
            import vllm
            from vllm import SamplingParams
            from vllm.sampling_params import RequestOutputKind
            
            # Patch for vLLM 0.8.1
            logging.info("Patching VLLMRollingBatch.inference for vLLM 0.8.1")
            VLLMRollingBatch.inference = patched_inference
            logging.info("Successfully patched VLLMRollingBatch.inference")
        except Exception as e:
            logging.error(f"Failed to patch VLLMRollingBatch.inference: {e}")
        
        # Initialize the service
        props = inputs.get_properties()
        _service.initialize(props)

    if inputs.is_empty():
        # initialization request
        return None

    return _service.inference(inputs)

## Upload Uncompressed Artifacts to S3
SageMaker AI allows us to provide [uncompressed files](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-uncompressed.html). We can therefore upload the folder that contains ``model.py``, ``serving.properties`` and ``requirements.txt`` directly to S3.

This process:
1. Determines the S3 bucket location (using SageMaker default bucket)
2. Defines a prefix path for organization
3. Uploads the packaged artifacts

> **Note**: The default SageMaker bucket follows the naming pattern: `sagemaker-{region}-{account-id}`

In [None]:
from sagemaker.s3 import S3Uploader

sagemaker_default_bucket = sagemaker_session.default_bucket()

code_model_uri = S3Uploader.upload(
    local_path="code",
    desired_s3_uri=f"s3://{sagemaker_default_bucket}/lmi/{base_name}/code"
)

print(f"code_model_uri: {code_model_uri}")

## Configure Model Container and Instance

For deploying Gemma-3-27B-it, we'll use:
- **LMI (Large Model Inference) container**: A DJL Serving based container optimized for large language model inference
- **[G6e Instance](https://aws.amazon.com/ec2/instance-types/g6e/)**: AWS's GPU instance type powered by NVIDIA L40S Tensor Core GPUs 

Key configurations:
- The container URI points to the DJL inference container in ECR (Elastic Container Registry)
- We use the `ml.g6e.48xlarge` instance, which offers:
  - 8 NVIDIA L40S Tensor Core GPUs
  - 384 GB of total GPU memory (48 GB of memory per GPU)
  - Up to 400 Gbps of network bandwidth
  - Up to 1.536 TB of system memory
  - Up to 7.6 TB of local NVMe SSD storage

> **Note**: The region in the container URI should match your AWS region.

In [None]:
gpu_instance_type = "ml.g6e.48xlarge"

In [None]:
image_uri = "763104351884.dkr.ecr.{}.amazonaws.com/djl-inference:0.32.0-lmi14.0.0-cu126".format(sagemaker_session.boto_session.region_name)
print(image_uri)

## Create SageMaker Model

Now we'll create a SageMaker Model object that combines our:
- Container image (LMI)
- Model artifacts (configuration files)
- IAM role (for permissions)

This step defines the model configuration but doesn't deploy it yet. The Model object represents the combination of:

1. **Container Image** (`image_uri`): DJL Inference optimized for LLMs
2. **Model Data** (`model_data`): Our configuration files in S3
3. **IAM Role** (`role`): Permissions for model execution

### Required Permissions
The IAM role needs:
- S3 read access for model artifacts
- CloudWatch permissions for logging
- ECR permissions to pull the container

#### HUGGING_FACE_HUB_TOKEN 
Gemma 3 27B instruct is a gated model, so you need to provide a Hugging Face access token for an account that has been granted access to it.

In [None]:
# Specify the S3 URI for your uncompressed code files 
model_data = {
    "S3DataSource": {
        "S3Uri": f"{code_model_uri}/",
        "S3DataType": "S3Prefix",
        "CompressionType": "None"
    }
}

In [None]:
HUGGING_FACE_HUB_TOKEN = "hf_"
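
Hard-coding a token in a notebook makes it easy to leak by accident. As an alternative, the sketch below reads the token from an environment variable and rejects an empty or placeholder value. The variable name `HF_TOKEN` and the helper `load_hf_token` are assumptions made for this example, and the `hf_` prefix check mirrors the placeholder used above.

```python
import os


def load_hf_token(env=os.environ, var_name="HF_TOKEN"):
    """Return the token stored in var_name, or raise if it is missing or a bare placeholder."""
    token = env.get(var_name, "").strip()
    if not token.startswith("hf_") or len(token) <= len("hf_"):
        raise ValueError(f"Set {var_name} to a valid Hugging Face access token.")
    return token


# Demonstration with a dummy mapping instead of the real environment
print(load_hf_token({"HF_TOKEN": "hf_dummy_value_for_demo"}))
```

In the notebook, `HUGGING_FACE_HUB_TOKEN = load_hf_token()` would then replace the literal assignment.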

In [None]:
from sagemaker.utils import name_from_base
from sagemaker.model import Model

model_name = name_from_base(base_name, short=True)

# Create model
gemma_3_model = Model(
    name = model_name,
    image_uri=image_uri,
    model_data=model_data,  # Path to uncompressed code files
    role=role,
    env={
        "HF_TASK": "Image-Text-to-Text",
        'HUGGING_FACE_HUB_TOKEN': HUGGING_FACE_HUB_TOKEN  # Gemma 3 27B instruct is gated, so pass your Hugging Face access token
    },
)

## Deploy Model to SageMaker Endpoint

Now we'll deploy our model to a SageMaker endpoint for real-time inference. This is a significant step that:
1. Provisions the specified compute resources (G6e instance)
2. Deploys the model container
3. Sets up the endpoint for API access

### Deployment Configuration
- **Instance Count**: 1 instance for single-node deployment
- **Instance Type**: `ml.g6e.48xlarge` for high-performance inference

> ⚠️ **Important**: 
> - Deployment will take 15-20 minutes
> - Monitor the CloudWatch logs for progress

In [None]:
%%time

from sagemaker.utils import name_from_base

endpoint_name = name_from_base(base_name, short=True)

gemma_3_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type=gpu_instance_type
)
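
`deploy()` waits for the endpoint by default. When reconnecting to a deployment that is still in progress, the status can be polled instead. The sketch below does this through the SageMaker `DescribeEndpoint` API; the helper name `wait_for_endpoint`, the polling interval and the timeout are illustrative choices, and the client is passed in so the function can be exercised without AWS access.

```python
import time


def wait_for_endpoint(sm_client, endpoint_name, poll_seconds=30, timeout_seconds=3600):
    """Poll DescribeEndpoint until the endpoint is InService, fails, or the timeout expires."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = response["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} still {status} after {timeout_seconds}s")
        time.sleep(poll_seconds)
```

A typical call would be `wait_for_endpoint(boto3.client("sagemaker"), endpoint_name)`.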

### Use the code below to create a predictor from an existing endpoint and run inference

In [None]:
from sagemaker.serializers import JSONSerializer, IdentitySerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor

# To reuse an existing endpoint, uncomment the next line and set your endpoint name
# endpoint_name = "gemma-3-27b-it-250328-1022"

gemma_3_predictor = Predictor(
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

## Text-only Inference

In [None]:
%%time

payload = {
    "messages" : [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Write me a poem about Machine Learning."}]
        }
    ],
    "max_tokens":300,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 250
}

response = gemma_3_predictor.predict(payload)
print(response['choices'][0]['message']['content'])

# Print usage statistics
print("=== Token Usage ===")
usage = response['usage']
print(f"Prompt Tokens: {usage['prompt_tokens']}")
print(f"Completion Tokens: {usage['completion_tokens']}")
print(f"Total Tokens: {usage['total_tokens']}")

## Multimodality

Gemma 3 accepts images alongside text. The next cells display a sample image and then send it to the endpoint together with a question, passing the image as an `image_url` content part in the user message.

In [None]:
from IPython.display import Image as IPyImage

IPyImage(url="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG")

In [None]:
import json

payload = {
  "messages": [
    {
      "role": "system",
      "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
      "role": "user",
      "content": [
        {
          "type": "image_url", 
          "image_url": {"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
        },
        {"type": "text", "text": "What animal is on the candy?"}
      ]
    }
  ]
}

response = gemma_3_predictor.predict(payload)
print(response['choices'][0]['message']['content'])

# Print usage statistics
print("=== Token Usage ===")
usage = response['usage']
print(f"Prompt Tokens: {usage['prompt_tokens']}")
print(f"Completion Tokens: {usage['completion_tokens']}")
print(f"Total Tokens: {usage['total_tokens']}")
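
The request above references an image by public URL. For a local image, one option is to embed the bytes as a base64 data URL in the same `image_url` field. This is a sketch under the assumption that the container accepts data URLs in that field, which this notebook does not verify; `image_message` is a hypothetical helper.

```python
import base64


def image_message(image_bytes: bytes, question: str, mime_type: str = "image/jpeg") -> dict:
    """Build a user message that embeds the image as a base64 data URL."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{encoded}"}},
            {"type": "text", "text": question},
        ],
    }


# Demonstration with placeholder bytes instead of a real image file
message = image_message(b"abc", "What animal is on the candy?")
print(message["content"][0]["image_url"]["url"])
```

With a real file, pass `Path("candy.jpg").read_bytes()` (a hypothetical local path) and send `{"messages": [message]}` through the predictor.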

## Implement Streaming

This section implements a streaming chat interface for real-time interaction with the Gemma 3 27B instruct model. The implementation includes:

1. **Streaming Infrastructure**:
   - Custom `LineIterator` for efficient stream processing
   - Real-time token processing
   - Performance monitoring (tokens per second)

2. **Performance Features**:
   - Live response streaming
   - Token speed monitoring
   - Memory-efficient processing

### Key Components

#### Streaming Parameters
- `max_tokens`: 8189 (default). Maximum number of tokens to generate.
- `temperature`: 0.7 (default). Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling.
- `top_p`: 0.9 (default). Float that controls the cumulative probability of the top tokens to consider.
- Real-time TPS (Tokens Per Second) monitoring

In [None]:
import io
import json
import time
import boto3
from IPython.display import clear_output

# SageMaker Runtime client
smr_client = boto3.client("sagemaker-runtime")
# To use a different endpoint, uncomment the next line and set your endpoint name
# endpoint_name = "gemma-3-27b-it-250328-1022"

class LineIterator:
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                # chunk is a dict, so format it rather than concatenating it to a string
                print(f"Unknown event type: {chunk}")
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])

def stream_chat_response(endpoint_name, inputs, max_tokens=8189, temperature=0.7, top_p=0.9):    
    body = {
      "messages": [
        {
          "role": "system",
          "content": [{"type": "text", "text": "You are a helpful assistant."}]
        },
        {
          "role": "user",
          "content": [
            {"type": "text", "text": inputs}
          ]
        }
      ],
        
        "max_tokens":max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
    }

    resp = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(body),
        ContentType="application/json",
    )

    event_stream = resp["Body"]
    start_json = b"{"
    full_response = ""
    start_time = time.time()
    token_count = 0

    for line in LineIterator(event_stream):
        if line != b"" and start_json in line:
            data = json.loads(line[line.find(start_json):].decode("utf-8"))
            token_text = data['choices'][0]['delta'].get('content', '')
            full_response += token_text
            token_count += 1

            # Calculate tokens per second
            elapsed_time = time.time() - start_time
            tps = token_count / elapsed_time if elapsed_time > 0 else 0

            # Clear the output and reprint everything
            clear_output(wait=True)
            print("Bot:", full_response)
            print(f"\nTokens per Second: {tps:.2f}", end="")

    print("\n") # Add a newline after response is complete
    
    return full_response

def chat(endpoint_name):
    print("Welcome to the SageMaker Streaming Chat! Type 'exit' to quit.")
    chat_history = []
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() == "exit":
            break
        bot_response = stream_chat_response(endpoint_name, user_input)
        
        # Update chat history
        chat_history.append({
            'user': user_input,
            'assistant': bot_response
        })


# Start the chat
chat(endpoint_name)
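
Note that `chat` records `chat_history` but `stream_chat_response` sends only the latest user input, so the model does not see earlier turns. One possible way to carry the conversation forward is sketched below. `build_messages` is a hypothetical helper that turns the recorded history into the `messages` list format used in this notebook; wiring it into `stream_chat_response` is left as an exercise.

```python
def build_messages(chat_history, user_input, system_prompt="You are a helpful assistant."):
    """Convert recorded turns plus the new user input into a chat `messages` list."""

    def text_message(role, text):
        return {"role": role, "content": [{"type": "text", "text": text}]}

    messages = [text_message("system", system_prompt)]
    for turn in chat_history:
        # Each recorded turn holds one user message and the assistant's reply
        messages.append(text_message("user", turn["user"]))
        messages.append(text_message("assistant", turn["assistant"]))
    messages.append(text_message("user", user_input))
    return messages


history = [{"user": "Hi", "assistant": "Hello! How can I help?"}]
for m in build_messages(history, "Tell me about SageMaker."):
    print(m["role"], "->", m["content"][0]["text"])
```

Longer conversations consume more of the context window configured through `max_model_len`, so older turns may need to be dropped or summarized.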

In [None]:
# Clean up
gemma_3_predictor.delete_model()
gemma_3_predictor.delete_endpoint(delete_endpoint_config=True)