
# Deploying Swiss LLM Apertus on SageMaker with LMI v15 powered by vLLM

This notebook demonstrates how to deploy the Apertus model and run inference against it. We will cover:

1. Installing the SageMaker Python SDK and setting up SageMaker resources and permissions
2. Deploying the model with the SageMaker LMI (Large Model Inference) container powered by vLLM
3. Invoking the model using streaming responses

## Environment Setup

First, we'll install the SageMaker SDK to ensure compatibility with the latest features, particularly those needed for large language model deployment and streaming inference.



In [None]:
%pip install sagemaker --upgrade --quiet --no-warn-conflicts

In [None]:
local_mode = False  # if you have a local GPU you can also run the model locally using SageMaker SDK, e.g. for debugging

In [None]:
if local_mode:
    %pip install sagemaker[local] --upgrade --quiet --no-warn-conflicts

In [None]:
from sagemaker import Model, Session, get_execution_role 
from sagemaker.utils import name_from_base
from botocore.exceptions import ClientError

role = get_execution_role()  # execution role for the endpoint

if local_mode:
    from sagemaker.local import entities, LocalSession

    # Extend LocalMode’s health-check timeout to 15 minutes
    entities.HEALTH_CHECK_TIMEOUT_LIMIT = 15 * 60  # seconds

    sess = LocalSession()
    sess.config = {"local": {"local_code": True}}
else:
    sess = Session() # sagemaker session for interacting with different AWS APIs

## Configure Model Container and Instance

For deploying Apertus, we'll use:
- **LMI (Large Model Inference) container with vLLM**: A container based on DJL Serving (Deep Java Library) and optimized for large language model inference
- **G6 instance**: An AWS GPU instance family with NVIDIA L4 GPUs, suited to large model inference

Key configurations:
- The container URI points to the DJL inference container in ECR (Elastic Container Registry)
- We use `ml.g6.48xlarge` instances which offer:
  - 8 NVIDIA L4 GPUs (24 GB each, 192 GB of GPU memory in total)
  - 768 GiB of system memory
  - Up to 100 Gbps of network bandwidth
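As a rough sanity check of whether a model fits on this instance, you can estimate the memory its weights need: parameter count × bytes per parameter (2 bytes for BF16/FP16). The sketch below assumes BF16 weights sharded evenly across all 8 GPUs with tensor parallelism, and it ignores the KV cache, activations, and CUDA overhead, all of which need extra headroom. The model sizes in the example are illustrative assumptions, not values taken from this notebook.

```python
# Rough weight-memory estimate; ignores KV cache, activations and runtime overhead.
GPU_COUNT = 8        # ml.g6.48xlarge: 8x NVIDIA L4
GPU_MEMORY_GB = 24   # per L4 GPU (8 x 24 GB = 192 GB total)


def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory needed for the weights alone, in decimal GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def per_gpu_weight_gb(num_params: float, tp_degree: int = GPU_COUNT, bytes_per_param: int = 2) -> float:
    """Weight memory per GPU, assuming even sharding with tensor parallelism."""
    return weight_memory_gb(num_params, bytes_per_param) / tp_degree


if __name__ == "__main__":
    for params in (8e9, 70e9):  # illustrative model sizes (assumption)
        per_gpu = per_gpu_weight_gb(params)
        headroom = GPU_MEMORY_GB - per_gpu
        print(
            f"{params / 1e9:.0f}B params: {weight_memory_gb(params):.0f} GB total, "
            f"{per_gpu:.1f} GB per GPU, {headroom:.1f} GB left per GPU for KV cache and overhead"
        )
```

Whatever remains per GPU after the weights is what vLLM can use for the KV cache, which is why a longer `OPTION_MAX_MODEL_LEN` or a larger batch size needs more headroom.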

> **Note**: The region in the container URI should match your AWS region. Replace `eu-central-2` with your region if it differs.

In [None]:
# Define region where you have capacity
REGION = "eu-central-2"
INSTANCE_TYPE = (
    "local_gpu" if local_mode else "ml.g6.48xlarge"
)  # Alternative: "ml.p5e.48xlarge" or "ml.p4d.24xlarge"

# Select the latest container. Check the link for the latest available version https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers
CONTAINER_VERSION = "0.33.0-lmi15.0.0-cu128"

# Construct container URI
if REGION == "eu-central-2":
    container_account = 380420809688
else:
    container_account = 763104351884

container_uri = f"{container_account}.dkr.ecr.{REGION}.amazonaws.com/djl-inference:{CONTAINER_VERSION}"


# Validate region and print configuration
if REGION != sess.boto_region_name:
    print(
        f"⚠️ Warning: Container region ({REGION}) differs from session region ({sess.boto_region_name})"
    )
else:
    print(f"✅ Region validation passed: {REGION}")

print(f"📦 Container URI: {container_uri}")
print(f"🖥️ Instance Type: {INSTANCE_TYPE}")

## Create SageMaker Model

Now we'll create a SageMaker Model object. This step defines the model configuration but doesn't deploy it yet. The Model object combines:

1. **Container Image** (`image_uri`): the LMI (DJL Serving) container optimized for LLMs
2. **Model Artifacts** (`model_data`): the S3 archive with the model weights, `requirements.txt`, and `entrypoint.py`
3. **Environment Variables** (`env`): the vLLM and LMI settings for the model server
4. **IAM Role** (`role`): permissions for model execution


In [None]:
%pip install huggingface_hub

In [None]:
# Download the model locally first
from huggingface_hub import snapshot_download

print("📥 Downloading model locally...")
model_name = "Saesara/swissai"
local_model_path = "./apertus"

try:
    model_path = snapshot_download(
        repo_id=model_name, local_dir=local_model_path, local_files_only=False
    )
    print(f"✅ Model downloaded to: {model_path}")
except Exception as e:
    print(f"❌ Error downloading model: {e}")
    # Fallback: you can manually download the model or use a different approach
    raise

> **Note**: Apertus support was added in transformers v4.56.0. At the time of writing, vLLM had not yet published a release on PyPI that includes the Apertus implementation, which is why we install a vLLM nightly wheel.

In [None]:
requirements = """git+https://github.com/huggingface/transformers.git@v4.56.0
https://vllm-wheels.s3.us-west-2.amazonaws.com/1cf3753b901ba874a830c19555bb31fe37f91231/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
"""

In [None]:
%store requirements >requirements.txt

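We need to patch the `VLLMHandler` in DJL Serving (Deep Java Library) so that it no longer passes engine arguments that newer vLLM versions have deprecated or removed (`device` and `use_v2_block_manager`), and so that it constructs vLLM's OpenAI-compatible chat service only with keyword arguments the nightly build accepts.

To do this, we write our own custom inference script, `entrypoint.py`.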
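The argument patch in the script relies on a common monkey-patching pattern: keep a reference to the original method, wrap it, post-process its return value, and assign the wrapper back onto the class. The self-contained sketch below shows that pattern with a hypothetical `EngineProperties` class standing in for DJL's `VllmRbProperties`; the class and its arguments are made up for illustration.

```python
class EngineProperties:
    """Hypothetical stand-in for a library class we cannot edit directly."""

    def generate_engine_arg_dict(self, passthrough_args: dict) -> dict:
        # Pretend the library always injects options that a newer engine rejects.
        return {"model": "./", "device": "auto", "use_v2_block_manager": True, **passthrough_args}


# 1. Keep a reference to the original (unbound) method.
_orig_generate = EngineProperties.generate_engine_arg_dict


# 2. Define a wrapper that calls the original and strips unsupported keys.
def _patched_generate(self, passthrough_args: dict) -> dict:
    args = _orig_generate(self, passthrough_args)
    for key in ("device", "use_v2_block_manager"):
        args.pop(key, None)  # pop with a default: no error if the key is absent
    return args


# 3. Install the wrapper on the class so every instance picks it up.
EngineProperties.generate_engine_arg_dict = _patched_generate

if __name__ == "__main__":
    print(EngineProperties().generate_engine_arg_dict({"max_model_len": 4096}))
```

Because the wrapper replaces the method on the class, it only has an effect if it is installed before the method is called, which is why the patch sits at the top of the entrypoint module.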

In [None]:
patched_vllm_service = """#!/usr/bin/env python

import logging
from typing import Optional, Union, AsyncGenerator

# Patch engine args: drop options that recent vLLM versions no longer accept
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
_orig_args = VllmRbProperties.generate_vllm_engine_arg_dict
def _patched_args(self, passthrough_args):
    args = _orig_args(self, passthrough_args)
    for key in ("device", "use_v2_block_manager"):
        args.pop(key, None)
    return args
VllmRbProperties.generate_vllm_engine_arg_dict = _patched_args

# Custom initialize(): rebuild the OpenAI-compatible serving objects with kwargs this vLLM build accepts
from djl_python.lmi_vllm.vllm_async_service import VLLMHandler
from vllm import AsyncLLMEngine
from djl_python.properties_manager.hf_properties import HuggingFaceProperties
from vllm.entrypoints.openai.serving_models import OpenAIServingModels, BaseModelPath
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

async def _custom_initialize(self, properties: dict):
    # 1. Load HF and vLLM properties
    self.hf_configs = HuggingFaceProperties(**properties)
    self.vllm_properties = VllmRbProperties(**properties)

    # 2. Build AsyncLLMEngine
    self.vllm_engine_args = self.vllm_properties.get_engine_args(async_engine=True)
    self.vllm_engine = AsyncLLMEngine.from_engine_args(self.vllm_engine_args)
    self.tokenizer = await self.vllm_engine.get_tokenizer()
    model_config = await self.vllm_engine.get_model_config()

    # 3. Prepare model registry and completion service
    model_names = self.vllm_engine_args.served_model_name or "lmi"
    if not isinstance(model_names, list):
        model_names = [model_names]
    # Users can provide multiple names that refer to the same model
    base_model_paths = [
        BaseModelPath(model_name, self.vllm_engine_args.model)
        for model_name in model_names
    ]
    
    # Use the first model name as default.
    # This is needed to be backwards compatible since LMI never required the model name in payload
    self.model_name = model_names[0]

    self.model_registry = OpenAIServingModels(
        self.vllm_engine,
        model_config,
        base_model_paths,
    )
    self.completion_service = OpenAIServingCompletion(
        self.vllm_engine,
        model_config,
        self.model_registry,
        request_logger=None,
    )

    
    
    # Construct OpenAIServingChat with only supported kwargs
    chat_kwargs = {
        'request_logger': None,
        'chat_template': getattr(self.vllm_properties, 'chat_template', None),
        'chat_template_content_format': getattr(
            self.vllm_properties, 'chat_template_content_format', None
        ),
        'enable_auto_tools': getattr(self.vllm_properties, 'enable_auto_tool_choice', False),
        'tool_parser': getattr(self.vllm_properties, 'tool_call_parser', None),
        # Do NOT pass enable_reasoning (no longer accepted); reasoning_parser is passed and
        # defaults to an empty string (no reasoning parser) unless it is configured
        'reasoning_parser': getattr(self.vllm_properties, 'reasoning_parser', ""),
    }

    logging.getLogger(__name__).info("Initializing OpenAIServingChat without the enable_reasoning flag")
    self.chat_completion_service = OpenAIServingChat(
        self.vllm_engine,
        model_config,
        self.model_registry,
        "assistant",  # response_role
        **chat_kwargs,
    )

    self.initialized = True

VLLMHandler.initialize = _custom_initialize

# Re-export DJL's handle() so the model server can call it from this entrypoint module
from djl_python.lmi_vllm.vllm_async_service import handle
"""

In [None]:
%store patched_vllm_service >entrypoint.py
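`%store var >file` is an IPython magic, so it is not available in a plain Python script. A roughly equivalent sketch, assuming the `requirements` and `patched_vllm_service` strings defined above, writes the files with `pathlib`; the helper name `write_text_files` is hypothetical.

```python
from pathlib import Path


def write_text_files(target_dir: str, files: dict[str, str]) -> list[Path]:
    """Write each {filename: content} pair into target_dir and return the paths written."""
    directory = Path(target_dir)
    directory.mkdir(parents=True, exist_ok=True)
    written = []
    for name, content in files.items():
        path = directory / name
        path.write_text(content, encoding="utf-8")
        written.append(path)
    return written


# Usage in the notebook (writes into the current directory, like the %store cells):
# write_text_files(".", {"requirements.txt": requirements, "entrypoint.py": patched_vllm_service})
```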

We bundle the model weights, `requirements.txt`, and `entrypoint.py` into a single archive and upload it to an Amazon S3 bucket. The Amazon SageMaker inference endpoint downloads the archive from Amazon S3 and extracts it into the inference container.

In [None]:
%%sh
mv requirements.txt apertus/
mv entrypoint.py apertus/
tar czvf apertus.tar.gz apertus/
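If the `tar` command is unavailable (for example on Windows), Python's standard-library `tarfile` module can build an equivalent gzip-compressed archive. This sketch mirrors the shell cell above and keeps `apertus/` as the top-level directory inside the archive; the function name `make_model_archive` is hypothetical.

```python
import tarfile
from pathlib import Path


def make_model_archive(source_dir: str, archive_path: str) -> list[str]:
    """Create a .tar.gz of source_dir (kept as the top-level folder) and return its member names."""
    source = Path(source_dir)
    if not source.is_dir():
        raise FileNotFoundError(f"{source} is not a directory")
    with tarfile.open(archive_path, "w:gz") as tar:
        # arcname sets the path stored in the archive; source.name keeps e.g. "apertus" as the root
        tar.add(str(source), arcname=source.name)
    with tarfile.open(archive_path, "r:gz") as tar:
        return tar.getnames()


# Usage (after moving requirements.txt and entrypoint.py into apertus/):
# make_model_archive("apertus", "apertus.tar.gz")
```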

Replace `<your-bucket-name>` with the name of your own Amazon S3 bucket, located in the same region in which you plan to deploy the endpoint.
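Before uploading, you may want to confirm that the bucket actually lives in the endpoint's region. A minimal sketch using the S3 `GetBucketLocation` API through boto3: S3 reports buckets in `us-east-1` with a `LocationConstraint` of `None` and legacy `eu-west-1` buckets as `EU`, which the sketch normalizes. The helper names are hypothetical.

```python
# S3 special cases for GetBucketLocation: null means us-east-1, legacy "EU" means eu-west-1.
_SPECIAL_LOCATIONS = {None: "us-east-1", "": "us-east-1", "EU": "eu-west-1"}


def normalize_bucket_region(location_constraint):
    """Map S3's LocationConstraint value to a region name."""
    return _SPECIAL_LOCATIONS.get(location_constraint, location_constraint)


def bucket_region(bucket_name: str) -> str:
    """Look up the region of an S3 bucket (needs AWS credentials and s3:GetBucketLocation)."""
    import boto3  # imported lazily so the pure helper above works without boto3

    response = boto3.client("s3").get_bucket_location(Bucket=bucket_name)
    return normalize_bucket_region(response.get("LocationConstraint"))


# Usage in the notebook, once `bucket` and `REGION` are set:
# if bucket_region(bucket) != REGION:
#     print(f"⚠️ Bucket {bucket} is not in {REGION}")
```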

In [None]:
# Upload model artifacts to S3
bucket = "<your-bucket-name>"  # REPLACE with the name of your Amazon S3 bucket

if not bucket or bucket == "<your-bucket-name>": # DO NOT replace this string
    raise ValueError("❌ Please set a valid S3 bucket name. Replace bucket='<your-bucket-name>'.")
s3_code_prefix = "apertus-lmi"
code_artifact = sess.upload_data("apertus.tar.gz", bucket, s3_code_prefix)

In [None]:
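# LMI/vLLM configuration, passed to the container as environment variables (values must be strings)
vllm_config = {
    "OPTION_MAX_MODEL_LEN": "4096",  # maximum context length (prompt + generated tokens)
    "OPTION_MAX_ROLLING_BATCH_SIZE": "8",  # maximum number of sequences processed concurrently
    "OPTION_MODEL_LOADING_TIMEOUT": "1500",  # seconds allowed for loading the model
    "SERVING_FAIL_FAST": "true",  # stop the model server if the model fails to load
    "OPTION_ROLLING_BATCH": "disable",  # turn off the rolling-batch handler...
    "OPTION_ASYNC_MODE": "true",  # ...and use the async vLLM handler instead
    "OPTION_ENTRYPOINT": "entrypoint",  # our patched entrypoint.py from the model archive
    "OPTION_TRUST_REMOTE_CODE": "true",  # allow custom modeling code shipped with the model
    "OPTION_MODEL_ID": "./",  # load the weights from the extracted archive instead of the Hub
    "VLLM_USE_PRECOMPILED": "1",
}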
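The CreateModel API defines the container `Environment` as a map of strings, so a stray integer or boolean (for example `4096` instead of `"4096"`) should be rejected by botocore's client-side parameter validation when the model is created. A small pre-flight check, sketched under that assumption; the function name is hypothetical.

```python
def find_non_string_env(env: dict) -> list[str]:
    """Return the keys whose name or value is not a string (SageMaker env must be str -> str)."""
    return [str(key) for key, value in env.items() if not isinstance(key, str) or not isinstance(value, str)]


if __name__ == "__main__":
    example = {"OPTION_MAX_MODEL_LEN": 4096, "OPTION_ASYNC_MODE": "true"}
    print(find_non_string_env(example))  # ['OPTION_MAX_MODEL_LEN']

# Usage in the notebook, before creating the Model:
# bad_keys = find_non_string_env(vllm_config)
# if bad_keys:
#     raise TypeError(f"Environment values must be strings: {bad_keys}")
```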

The Model object combines all the information on how to deploy the model to an endpoint.

In [None]:
model = Model(
    image_uri=container_uri,
    role=role,
    model_data=code_artifact,
    sagemaker_session=sess,
    env=vllm_config,
)

## Deploy Model to SageMaker Endpoint

Now we'll deploy our model to a SageMaker endpoint for real-time inference. This step:
1. Provisions the specified compute resources (G6 instance)
2. Deploys the model container
3. Sets up the endpoint for API access

### Deployment Configuration
- **Instance Count**: 1 instance for single-node deployment
- **Instance Type**: `ml.g6.48xlarge` for high-performance inference
- **Health Check Timeout**: 1800 seconds 
  - Extended timeout needed for large model loading
  - Includes time for container setup and model initialization

> ⚠️ **Important**: 
> - Deployment can take up to 15 minutes
> - Monitor the endpoint status in SageMaker Console and CloudWatch logs for progress
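Besides the console, you can poll the endpoint status programmatically. The sketch below takes a `describe` callable so it can be tested without AWS; in the notebook you would pass `boto3.client("sagemaker").describe_endpoint`, whose response includes `EndpointStatus` (for example `Creating`, `InService`, or `Failed`) and, on failure, `FailureReason`. The helper name, poll interval, and timeout are assumptions; boto3 also provides a built-in `endpoint_in_service` waiter for the same purpose.

```python
import time
from typing import Callable


def wait_for_endpoint(
    describe: Callable[..., dict],
    endpoint_name: str,
    poll_seconds: float = 30.0,
    timeout_seconds: float = 1800.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll until the endpoint is InService; raise if it fails or the timeout expires."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = describe(EndpointName=endpoint_name)
        status = response["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            reason = response.get("FailureReason", "unknown reason")
            raise RuntimeError(f"Endpoint {endpoint_name} failed: {reason}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} still {status} after {timeout_seconds} s")
        sleep(poll_seconds)


# Usage in the notebook (requires AWS credentials):
# import boto3
# wait_for_endpoint(boto3.client("sagemaker").describe_endpoint, endpoint_name)
```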

In [None]:
if local_mode:
    # To see progress
    !docker pull $container_uri

In [None]:
endpoint_name = name_from_base("Apertus")

print(endpoint_name)

try:
    model.deploy(
        initial_instance_count=1,
        instance_type=INSTANCE_TYPE,
        endpoint_name=endpoint_name,
        container_startup_health_check_timeout=1800,
    )
    print(f"\n✅ Endpoint '{endpoint_name}' deployed successfully")
except ClientError as e:
    error_code = e.response['Error']['Code']
    if error_code == 'ResourceLimitExceeded':
        print(
            "❌ Resource limit exceeded. "
            + f"Did you request the necessary Service Quotas for {INSTANCE_TYPE} in {REGION}? "
            + "See also https://repost.aws/knowledge-center/sagemaker-resource-limit-exceeded-error"
        )
    elif error_code == 'InsufficientInstanceCapacity':
        print(
            "❌ Insufficient instance capacity. Try a different AZ or instance type. "
            + "See also https://repost.aws/knowledge-center/sagemaker-insufficient-capacity-error"
        )
    else:
        print(f"❌ Deployment failed: {e}")
    raise
except Exception as e:
    print(f"❌ Unexpected deployment error: {e}")
    print("💡 Check CloudWatch logs for detailed error information")
    raise

## Run Inference Requests Against the Model

Once the model is deployed to the Amazon SageMaker inference endpoint, you can invoke it. Replace `<your_endpoint_name>` below with the name of your SageMaker inference endpoint (the name printed by the deployment cell).
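With response streaming, the endpoint returns an event stream of `PayloadPart` byte chunks, and a chunk is not guaranteed to hold exactly one server-sent event: one event can be split across chunks, or several can arrive in one chunk, so decoding each chunk on its own may drop tokens. The sketch below buffers bytes and yields one parsed JSON object per complete line, stripping an optional `data:` prefix and skipping the `[DONE]` sentinel that OpenAI-compatible servers such as vLLM send at the end. It assumes every non-empty line is either a `data:` line or bare JSON; the function names are hypothetical.

```python
import json
from typing import Iterable, Iterator


def iter_sse_json(chunks: Iterable[bytes]) -> Iterator[dict]:
    """Yield one JSON object per complete line of a server-sent-event byte stream."""
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        # Process every complete line; keep a trailing partial line in the buffer.
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            yield from _parse_line(line)
    # Flush whatever remains if the stream did not end with a newline.
    if buffer:
        yield from _parse_line(buffer)


def _parse_line(raw: bytes) -> Iterator[dict]:
    # Decoding whole lines is safe: the newline byte never occurs inside a UTF-8 multibyte sequence.
    line = raw.decode("utf-8").strip()
    if line.startswith("data:"):
        line = line[len("data:"):].strip()
    if not line or line == "[DONE]" or line.startswith(":"):  # blank, end sentinel, SSE comment
        return
    yield json.loads(line)


# Usage with the streaming response from invoke_endpoint_with_response_stream:
# parts = (e["PayloadPart"]["Bytes"] for e in resp["Body"] if "PayloadPart" in e)
# for data in iter_sse_json(parts):
#     if data.get("choices"):
#         print(data["choices"][0].get("delta", {}).get("content") or "", end="", flush=True)
```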

In [None]:
# Option 1: Invoke model with response streaming
from json import dumps as json_dumps, loads as json_loads, JSONDecodeError
from boto3 import client
from time import time

# Create SageMaker Runtime client
smr_client = client("sagemaker-runtime")

endpoint_name = "<your_endpoint_name>" # REPLACE with your endpoint

print(f"Endpoint name: {endpoint_name}")
if endpoint_name == "<your_endpoint_name>": # DO NOT replace this string
    raise ValueError("❌ Please set a valid endpoint name")

# Invoke with messages format
body = {
    "messages": [
        {"role": "user", "content": "Name popular places to visit in London?"}
    ],
    "temperature": 0.9,
    "max_tokens": 256,
    "stream": True,
}

start_time = time()
first_token_received = False
ttft = None
token_count = 0
full_response = ""

print(f"Prompt: {body['messages'][0]['content']}\n")
print("Response:", end=" ", flush=True)

# Invoke endpoint with streaming

try:
    resp = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json_dumps(body),
        ContentType="application/json",
    )
except ClientError as e:
    error_code = e.response['Error']['Code']
    if error_code == 'ValidationException':
        print("❌ Validation Exception. Invalid request format or parameters")
    elif error_code == 'ModelError':
        print("❌ Model error. Check model logs")
    else:
        print(f"❌ Inference failed: {e}")
    raise e
except Exception as e:
    print(f"❌ Unexpected inference error: {e}")
    raise e

# Process streaming response
for event in resp["Body"]:
    if "PayloadPart" in event:
        payload = event["PayloadPart"]["Bytes"].decode()

        try:

            if payload.startswith("data: "):
                data = json_loads(payload[6:])  # Skip "data: " prefix
            else:
                data = json_loads(payload)

            token_count += 1
            if not first_token_received:
                ttft = time() - start_time
                first_token_received = True

            # Handle different streaming response formats
            if data.get("choices"):
                # OpenAI-compatible chat completion chunk; "content" may be missing or null
                delta = data["choices"][0].get("delta", {})
                token_text = delta.get("content") or ""
                full_response += token_text
                print(token_text, end="", flush=True)
            elif "token" in data and "text" in data["token"]:
                # TGI format
                token_text = data["token"]["text"]
                full_response += token_text
                print(token_text, end="", flush=True)

        except JSONDecodeError:
            # Skip invalid JSON
            continue

end_time = time()
total_latency = end_time - start_time

print("\n\nMetrics:")
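if ttft is not None:
    print(f"Time to First Token (TTFT): {ttft:.2f} seconds")
else:
    print("No tokens received")
# Counts streamed chunks; vLLM usually sends one token per chunk, so this approximates the token count
print(f"Total Chunks Received (approx. tokens): {token_count}")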
print(f"Total Latency: {total_latency:.2f} seconds")
# print(f"\nFull Response:\n{full_response}")
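TTFT and total latency can be combined into an approximate decode throughput: the chunks received after the first one, divided by the time spent after the first chunk arrived. The sketch assumes one token per streamed chunk, which is typical for vLLM but not guaranteed (the first chunk may only carry the role, for instance), so treat the result as an estimate. The function name is hypothetical.

```python
from typing import Optional


def decode_tokens_per_second(token_count: int, ttft: Optional[float], total_latency: float) -> Optional[float]:
    """Approximate generation speed after the first token; None if it cannot be computed."""
    if ttft is None or token_count < 2 or total_latency <= ttft:
        return None
    return (token_count - 1) / (total_latency - ttft)


# Usage after the streaming cell above:
# tps = decode_tokens_per_second(token_count, ttft, total_latency)
# if tps is not None:
#     print(f"Approx. decode throughput: {tps:.1f} tokens/s")
```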

In [None]:
# # Option 2: Invoke without streaming
# from json import dumps as json_dumps, loads as json_loads
# from boto3 import client
# from botocore.exceptions import ClientError

# # Create SageMaker Runtime client for invocation
# smr_client = client("sagemaker-runtime")

# endpoint_name = "<your_endpoint_name>" # REPLACE with your endpoint

# print(f"Endpoint name: {endpoint_name}")
# if endpoint_name == "<your_endpoint_name>": # DO NOT replace this string
#     raise ValueError("❌ Please set a valid endpoint name")

# # Invoke with messages format
# body = {
#     "messages": [
#         {"role": "user", "content": "Name popular places to visit in London?"}
#     ],
#     "temperature": 0.9,
#     "max_tokens": 256,
#     "stream": False,
# }

# print(f"Prompt: {body['messages'][0]['content']}\n")

# try:
#     # Non-streaming invocation
#     response = smr_client.invoke_endpoint(
#         EndpointName=endpoint_name,
#         ContentType="application/json",
#         Body=json_dumps(body),
#     )
# except ClientError as e:
#     error_code = e.response["Error"]["Code"]
#     if error_code == "ValidationException":
#         print("❌ Validation Exception. Invalid request format or parameters")
#     elif error_code == "ModelError":
#         print("❌ Model error. Check model logs")
#     else:
#         print(f"❌ Inference failed: {e}")
#     raise
# except Exception as e:
#     print(f"❌ Unexpected inference error: {e}")
#     raise

# result = json_loads(response["Body"].read().decode())
# print(result["choices"][0]["message"]["content"])
# print(f"\nFull Response:\n{result}")

## Cleanup: Delete Endpoint

In [None]:
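# If you restarted the kernel, re-create the session and set `endpoint_name` first:
# from sagemaker import Session
# sess = Session()

print(f"Deleting SageMaker resources for endpoint: {endpoint_name}")
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)  # deploy() names the endpoint config after the endpoint
# Also delete the SageMaker model created by deploy() (needs the `model` object from this session)
model.delete_model()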

Remove the local artifacts, which contain the model weights:

In [None]:
!rm -rf apertus
!rm apertus.tar.gz