# Deploy Phi-3 Model on SageMaker using Large Model Inference (LMI) Container

This notebook demonstrates how to deploy Microsoft's Phi-3 model on Amazon SageMaker using the SageMaker Large Model Inference (LMI) container with the vLLM backend.

## Overview

- **Model**: microsoft/Phi-3-mini-4k-instruct (3.8B parameters)
- **Container**: DJL LMI Container with vLLM backend
- **Instance Type**: ml.g5.2xlarge (1 GPU)
- **Backend**: vLLM (paged KV-cache management and continuous batching)

## What is LMI?

SageMaker Large Model Inference (LMI) containers are purpose-built Docker containers for LLM inference. They provide:

- **Multiple Backend Support**: vLLM, TensorRT-LLM, Transformers NeuronX
- **Optimized Performance**: Continuous batching, quantization, tensor parallelism
- **Easy Configuration**: Environment variables or serving.properties file
- **Flexibility**: Works with HuggingFace models or custom S3 artifacts

## Prerequisites

- AWS Account with SageMaker access
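- An IAM execution role with SageMaker permissions (`sagemaker.get_execution_role()` works inside SageMaker notebook instances and Studio)
- Service quota for at least one `ml.g5.2xlarge` instance for endpoint usage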

## 1. Setup and Installation

In [None]:
!pip install sagemaker --upgrade --quiet
!pip install boto3 --upgrade --quiet

## 2. Initialize SageMaker Session

In [None]:
import sagemaker
import boto3
import json
from datetime import datetime
from sagemaker import Model

# Initialize session
sess = sagemaker.Session()
region = sess.boto_region_name
role = sagemaker.get_execution_role()
sm_client = boto3.client('sagemaker', region_name=region)
sagemaker_runtime = boto3.client('sagemaker-runtime', region_name=region)

print(f"SageMaker role: {role}")
print(f"AWS region: {region}")
print(f"SageMaker version: {sagemaker.__version__}")

## 3. Get LMI Container Image URI

This notebook pins a specific DJL LMI container release that includes the vLLM backend. Update `LMI_VERSION` to pick up newer releases.

In [None]:
# LMI container configuration
# Check for latest version at: https://github.com/aws/deep-learning-containers/blob/master/available_images.md
LMI_VERSION = '0.31.0-lmi13.0.0-cu124'

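# Construct the image URI (763104351884 is the Deep Learning Containers registry account
# for most commercial regions; some regions use a different account)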
inference_image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:{LMI_VERSION}"

print(f"Using LMI container: {inference_image_uri}")
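The image URI combines a registry account, a region, the `djl-inference` repository, and a tag. As a sketch, the hypothetical helper below splits a URI into those parts and, assuming the tag follows the `<DJL Serving version>-lmi<LMI version>-cu<CUDA version>` pattern seen in `LMI_VERSION`, into the individual versions. It only recognizes `amazonaws.com` hostnames, not the China partition's `amazonaws.com.cn`.

```python
import re

# Hypothetical helper; the tag pattern is an assumption based on LMI_VERSION above
_IMAGE_URI_RE = re.compile(
    r"^(?P<account>\d{12})\.dkr\.ecr\.(?P<region>[a-z0-9-]+)\.amazonaws\.com"
    r"/(?P<repo>[a-z0-9._/-]+):(?P<tag>[A-Za-z0-9._-]+)$"
)
_LMI_TAG_RE = re.compile(r"^(?P<djl>\d+\.\d+\.\d+)-lmi(?P<lmi>\d+\.\d+\.\d+)-(?P<cuda>cu\d+)$")


def parse_lmi_image_uri(uri):
    """Split an ECR image URI into account, region, repo and tag, plus the
    DJL/LMI/CUDA versions when the tag matches the assumed LMI pattern."""
    match = _IMAGE_URI_RE.match(uri)
    if match is None:
        raise ValueError(f"Not an ECR image URI: {uri}")
    parts = match.groupdict()
    tag_match = _LMI_TAG_RE.match(parts["tag"])
    if tag_match:
        parts.update(tag_match.groupdict())
    return parts


print(parse_lmi_image_uri(
    "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124"
))
```

In the notebook, `parse_lmi_image_uri(inference_image_uri)` shows at a glance which region and versions the endpoint will use.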

## 4. Method 1: Deploy using Environment Variables

Configuration is stored as environment variables on the SageMaker Model, so no model artifact has to be packaged or uploaded; changing a setting only requires creating a new Model.

In [None]:
# Model configuration
model_name_env = f"phi3-mini-lmi-env-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
endpoint_name_env = f"{model_name_env}-ep"

# LMI configuration via environment variables
lmi_env_config = {
    # Model configuration
    'HF_MODEL_ID': 'microsoft/Phi-3-mini-4k-instruct',
    
    # Backend configuration
    'OPTION_ROLLING_BATCH': 'vllm',  # Use vLLM backend
    'OPTION_MAX_ROLLING_BATCH_SIZE': '8',
    
    # GPU configuration
    'OPTION_TENSOR_PARALLEL_DEGREE': '1',  # Number of GPUs
    
    # Context and generation limits
    'OPTION_MAX_MODEL_LEN': '4096',
    'OPTION_MAX_INPUT_LEN': '3072',
    
    # Performance tuning
    'OPTION_DTYPE': 'fp16',  # 'bf16' matches Phi-3's native dtype: wider range than fp16, fewer mantissa bits
    'OPTION_GPU_MEMORY_UTILIZATION': '0.9',
    
    # Optional: Quantization for memory efficiency (requires a pre-quantized AWQ/GPTQ checkpoint)
    # 'OPTION_QUANTIZE': 'awq',  # or 'gptq'
    
    # Optional: HuggingFace token for gated models
    # 'HUGGING_FACE_HUB_TOKEN': '<YOUR_HF_TOKEN>',
}

print(f"Model name: {model_name_env}")
print(f"Endpoint name: {endpoint_name_env}")
print(f"\nLMI Configuration:")
print(json.dumps(lmi_env_config, indent=2))
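To see how `OPTION_MAX_MODEL_LEN`, `OPTION_MAX_ROLLING_BATCH_SIZE`, and `OPTION_GPU_MEMORY_UTILIZATION` interact, here is a back-of-the-envelope memory estimate. It assumes Phi-3-mini's published architecture (32 layers, 32 key/value heads, head dimension 96), 2 bytes per value for fp16/bf16, and roughly 22.5 GiB of usable memory on the single A10G GPU in `ml.g5.2xlarge`. It ignores activations and CUDA overhead, so treat the result as a rough worst case rather than what vLLM actually allocates; the function names are hypothetical.

```python
# Rough memory estimate for the Method 1 settings (assumed values, see lead-in)
GIB = 1024 ** 3


def kv_cache_bytes_per_token(num_layers=32, num_kv_heads=32, head_dim=96, bytes_per_value=2):
    # Each layer stores one K and one V vector of num_kv_heads * head_dim values per token
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


def memory_estimate_gib(params_billion, max_model_len, max_batch_size,
                        gpu_mem_gib, gpu_mem_utilization, bytes_per_value=2):
    weights = params_billion * 1e9 * bytes_per_value / GIB
    kv_per_sequence = kv_cache_bytes_per_token(bytes_per_value=bytes_per_value) * max_model_len / GIB
    kv_full_batch = kv_per_sequence * max_batch_size
    budget = gpu_mem_gib * gpu_mem_utilization
    return {
        "weights_gib": weights,
        "kv_per_sequence_gib": kv_per_sequence,
        "kv_full_batch_gib": kv_full_batch,
        "budget_gib": budget,
        "headroom_gib": budget - weights - kv_full_batch,
    }


estimate = memory_estimate_gib(
    params_billion=3.8,        # Phi-3-mini parameter count
    max_model_len=4096,        # OPTION_MAX_MODEL_LEN
    max_batch_size=8,          # OPTION_MAX_ROLLING_BATCH_SIZE
    gpu_mem_gib=22.5,          # assumed usable A10G memory
    gpu_mem_utilization=0.9,   # OPTION_GPU_MEMORY_UTILIZATION
)
for name, value in estimate.items():
    print(f"{name}: {value:.2f}")
```

Under these assumptions each token needs 2 × 32 × 32 × 96 × 2 = 393,216 bytes (384 KiB) of KV cache, so one 4,096-token sequence needs 1.5 GiB and eight need 12 GiB alongside about 7.1 GiB of weights. vLLM allocates KV cache in pages, so shorter requests use proportionally less; raising `OPTION_MAX_MODEL_LEN` or the batch size eats into the remaining headroom.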

In [None]:
# Create SageMaker Model with environment variables
model_env = Model(
    image_uri=inference_image_uri,
    model_data=None,  # Model will be downloaded from HuggingFace
    role=role,
    name=model_name_env,
    env=lmi_env_config,
    sagemaker_session=sess
)

print(f"Model object created: {model_name_env}")

In [None]:
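# Deploy the model
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

print("Deploying model... This typically takes 5-10 minutes.")

# A plain Model has no predictor_cls, so deploy() returns None; build a JSON Predictor instead
model_env.deploy(
    endpoint_name=endpoint_name_env,
    initial_instance_count=1,
    instance_type='ml.g5.2xlarge',
    container_startup_health_check_timeout=600,  # seconds; must cover the Hub download at startup
    model_data_download_timeout=900,
)

predictor_env = Predictor(
    endpoint_name=endpoint_name_env,
    sagemaker_session=sess,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)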

print(f"\n✅ Endpoint deployed successfully: {endpoint_name_env}")

## 5. Test the Deployed Model

Test basic inference with the deployed endpoint.

In [None]:
# Simple test
test_payload = {
    "inputs": "What is the capital of France?",
    "parameters": {
        "max_new_tokens": 100,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True
    }
}

response = predictor_env.predict(test_payload)

print("\n" + "="*50)
print("RESPONSE:")
print("="*50)
print(response['generated_text'])
print("="*50)

## 6. Method 2: Deploy using serving.properties File

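This method packages the configuration as a `serving.properties` file inside a `model.tar.gz` on S3. Because `option.model_id` still points to a Hugging Face Hub ID, the weights are downloaded from the Hub at startup; only the configuration travels with the artifact.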

In [None]:
import os
import tarfile

# Create directory for model artifacts
model_dir = "phi3_model_artifacts"
os.makedirs(model_dir, exist_ok=True)

# Create serving.properties file
serving_properties = """# Phi-3 Model Configuration
engine=Python
option.model_id=microsoft/Phi-3-mini-4k-instruct

# Backend configuration
option.rolling_batch=vllm
option.max_rolling_batch_size=8

# GPU configuration
option.tensor_parallel_degree=1

# Context limits
option.max_model_len=4096
option.max_input_len=3072

# Performance settings
option.dtype=fp16
option.gpu_memory_utilization=0.9

# Optional: Quantization
# option.quantize=awq

# Optional: HuggingFace token
# option.huggingface_token=<YOUR_HF_TOKEN>
"""

# Write serving.properties
with open(f"{model_dir}/serving.properties", 'w') as f:
    f.write(serving_properties)

print("serving.properties created:")
print(serving_properties)
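The two methods carry the same settings under different names: `OPTION_<NAME>` environment variables in Method 1 and `option.<name>` keys here, with `HF_MODEL_ID` playing the role of `option.model_id`. The sketch below converts an env-var dict into `serving.properties` text under that assumed naming convention. `env_to_serving_properties` is a hypothetical helper, and it skips any other variables (such as tokens) so secrets do not end up in the artifact.

```python
def env_to_serving_properties(env, engine="Python"):
    """Render LMI env vars as serving.properties text (assumed naming convention)."""
    lines = [f"engine={engine}"]
    for key, value in env.items():
        if key == "HF_MODEL_ID":
            lines.append(f"option.model_id={value}")
        elif key.startswith("OPTION_"):
            # OPTION_MAX_ROLLING_BATCH_SIZE -> option.max_rolling_batch_size
            lines.append(f"option.{key[len('OPTION_'):].lower()}={value}")
        # Anything else (e.g. HUGGING_FACE_HUB_TOKEN) is deliberately skipped
    return "\n".join(lines) + "\n"


example_env = {
    'HF_MODEL_ID': 'microsoft/Phi-3-mini-4k-instruct',
    'OPTION_ROLLING_BATCH': 'vllm',
    'OPTION_MAX_ROLLING_BATCH_SIZE': '8',
    'HUGGING_FACE_HUB_TOKEN': '<YOUR_HF_TOKEN>',
}
print(env_to_serving_properties(example_env))
```

Passing `lmi_env_config` from Method 1 would produce settings equivalent to the hand-written file above, which helps keep the two methods in sync.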

In [None]:
# Create tar.gz archive
tarball_name = "phi3_model.tar.gz"

with tarfile.open(tarball_name, "w:gz") as tar:
    tar.add(model_dir, arcname=".")

print(f"Created tarball: {tarball_name}")

# Upload to S3
bucket = sess.default_bucket()
prefix = "phi3-lmi-models"
s3_model_uri = f"s3://{bucket}/{prefix}/{tarball_name}"

s3_client = boto3.client('s3')
s3_client.upload_file(tarball_name, bucket, f"{prefix}/{tarball_name}")

print(f"Model artifacts uploaded to: {s3_model_uri}")
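LMI reads `serving.properties` from the root of the extracted model directory, so the file should sit at the top level of the archive (the `arcname="."` above produces `./serving.properties`). The hypothetical helper below checks that layout; the demo rebuilds a tiny archive in a temporary directory so it runs without S3.

```python
import os
import tarfile
import tempfile


def tarball_has_root_file(tar_path, filename="serving.properties"):
    """Return True if `filename` is at the archive root (with or without a leading './')."""
    with tarfile.open(tar_path, "r:gz") as tar:
        names = [os.path.normpath(name) for name in tar.getnames()]
    return filename in names


# Self-contained demo mirroring the packaging step above
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "artifacts")
    os.makedirs(src)
    with open(os.path.join(src, "serving.properties"), "w") as f:
        f.write("engine=Python\n")
    tar_path = os.path.join(tmp, "model.tar.gz")
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(src, arcname=".")
    print(tarball_has_root_file(tar_path))
```

In the notebook, `tarball_has_root_file(tarball_name)` can run before the upload; archiving the directory under its own name instead of `"."` would nest the file one level down.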

In [None]:
# Create model with serving.properties
model_name_props = f"phi3-mini-lmi-props-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
endpoint_name_props = f"{model_name_props}-ep"

model_props = Model(
    image_uri=inference_image_uri,
    model_data=s3_model_uri,
    role=role,
    name=model_name_props,
    sagemaker_session=sess
)

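print(f"Model object defined (created in SageMaker only on deploy): {model_name_props}")
# To deploy it (this creates a second billable endpoint), for example:
# model_props.deploy(endpoint_name=endpoint_name_props, initial_instance_count=1,
#                    instance_type='ml.g5.2xlarge', container_startup_health_check_timeout=600)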

## 7. Advanced Inference Examples

In [None]:
# Chat-based inference with Phi-3 format
def create_phi3_prompt(system_msg, user_msg):
    """
    Create a properly formatted Phi-3 prompt.
    """
    prompt = f"<|system|>\n{system_msg}<|end|>\n"
    prompt += f"<|user|>\n{user_msg}<|end|>\n"
    prompt += "<|assistant|>\n"
    return prompt

system_message = "You are a helpful AI assistant specialized in explaining technical concepts."
user_message = "Explain how transformers work in machine learning."

formatted_prompt = create_phi3_prompt(system_message, user_message)

chat_payload = {
    "inputs": formatted_prompt,
    "parameters": {
        "max_new_tokens": 300,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "do_sample": True,
        "stop": ["<|end|>", "<|endoftext|>"]
    }
}

response = predictor_env.predict(chat_payload)

print("\n" + "="*50)
print("CHAT RESPONSE:")
print("="*50)
print(response['generated_text'])
print("="*50)
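`create_phi3_prompt` handles exactly one system and one user message. The sketch below renders an arbitrary list of `{'role', 'content'}` messages with the same tags; it mirrors the format used in this notebook, while the tokenizer's own chat template (`tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` in Hugging Face `transformers`) remains the authoritative reference. `build_phi3_prompt` is a hypothetical name.

```python
def build_phi3_prompt(messages, add_generation_prompt=True):
    """Render [{'role': ..., 'content': ...}] with the Phi-3 tags used in this notebook."""
    allowed_roles = {"system", "user", "assistant"}
    prompt = ""
    for message in messages:
        role_name = message["role"]  # not named `role`, to avoid shadowing the IAM role
        if role_name not in allowed_roles:
            raise ValueError(f"Unsupported role: {role_name!r}")
        prompt += f"<|{role_name}|>\n{message['content']}<|end|>\n"
    if add_generation_prompt:
        # Leave the assistant tag open so the model writes the next reply
        prompt += "<|assistant|>\n"
    return prompt


print(build_phi3_prompt([
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What is Python?"},
]))
```

With one system and one user message it produces the same string as `create_phi3_prompt`, and it can also render the conversation history used in the multi-turn example.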

## 8. Streaming Inference with LMI

In [None]:
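import copy

def _extract_token_text(line):
    """Return the token text carried by one streamed line, or '' if none."""
    line = line.strip()
    if line.startswith('data:'):  # tolerate SSE-style framing as well
        line = line[len('data:'):].strip()
    if not line or line == '[DONE]':
        return ""
    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        return ""
    if not isinstance(data, dict):
        return ""
    return (data.get('token') or {}).get('text', "")


def stream_lmi_response(endpoint_name, payload):
    """
    Stream a response from an LMI endpoint, printing tokens as they arrive.

    LMI streams JSON Lines by default, e.g. {"token": {"text": "..."}} per line.
    A PayloadPart can end mid-line, so bytes are buffered until a newline arrives.
    """
    # Copy so the caller's dict is not mutated; LMI reads "stream" at the top level
    payload = copy.deepcopy(payload)
    payload['stream'] = True
    
    response = sagemaker_runtime.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType='application/json'
    )
    
    print("\nStreaming response:")
    print("-" * 50)
    
    full_text = ""
    buffer = b""
    
    for event in response['Body']:
        if 'PayloadPart' not in event:
            continue
        buffer += event['PayloadPart']['Bytes']
        # Keep the trailing (possibly partial) line in the buffer
        *lines, buffer = buffer.split(b'\n')
        for raw_line in lines:
            token_text = _extract_token_text(raw_line.decode('utf-8'))
            print(token_text, end='', flush=True)
            full_text += token_text
    
    # Flush a final line that arrived without a trailing newline
    token_text = _extract_token_text(buffer.decode('utf-8'))
    print(token_text, end='', flush=True)
    full_text += token_text
    
    print("\n" + "-" * 50)
    return full_text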

# Stream a response
stream_payload = {
    "inputs": "Write a short story about a robot learning to paint.",
    "parameters": {
        "max_new_tokens": 200,
        "temperature": 0.8,
        "top_p": 0.9,
        "do_sample": True
    }
}

result = stream_lmi_response(endpoint_name_env, stream_payload)
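Because a `PayloadPart` can end in the middle of a line (or even in the middle of a multi-byte character), the streaming loop has to buffer bytes until a newline arrives. The generator below isolates that logic so it can be checked offline with synthetic events shaped like those from `invoke_endpoint_with_response_stream` (`{'PayloadPart': {'Bytes': ...}}`). It assumes the JSON Lines token format (`{"token": {"text": ...}}`) and also tolerates `data:`-prefixed lines; `iter_stream_tokens` is a hypothetical helper.

```python
import json


def _token_text(line):
    """Extract token text from one line; '' for blanks, '[DONE]' or non-token JSON."""
    line = line.strip()
    if line.startswith("data:"):
        line = line[len("data:"):].strip()
    if not line or line == "[DONE]":
        return ""
    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        return ""
    if not isinstance(data, dict):
        return ""
    return (data.get("token") or {}).get("text", "")


def iter_stream_tokens(events):
    """Yield token texts from an iterable of event-stream dicts."""
    buffer = b""
    for event in events:
        part = event.get("PayloadPart")
        if part is None:
            continue  # ignore other event types
        buffer += part["Bytes"]
        # Splitting bytes on b"\n" is safe: 0x0A never occurs inside a multi-byte UTF-8 character
        *complete, buffer = buffer.split(b"\n")
        for raw in complete:
            text = _token_text(raw.decode("utf-8"))
            if text:
                yield text
    text = _token_text(buffer.decode("utf-8"))  # final line without trailing newline
    if text:
        yield text


synthetic_events = [
    {"PayloadPart": {"Bytes": b'{"token": {"text": "Hel'}},
    {"PayloadPart": {"Bytes": b'lo"}}\n{"token": {"text": " w\xc3'}},
    {"PayloadPart": {"Bytes": b'\xb6rld"}}\n'}},
    {"PayloadPart": {"Bytes": b'{"token": {"text": "!"}}'}},
]
print("".join(iter_stream_tokens(synthetic_events)))
```

Against the live endpoint, the same generator can consume `response['Body']` directly.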

## 9. Benchmark Performance

In [None]:
import time
import numpy as np

def benchmark_endpoint(endpoint_name, num_requests=10):
    """
    Measure sequential client-side latency (includes network time; one request at a time).
    """
    latencies = []
    
    test_prompt = "Explain artificial intelligence in simple terms."
    
    payload = {
        "inputs": test_prompt,
        "parameters": {
            "max_new_tokens": 100,
            "temperature": 0.7,
            "do_sample": True
        }
    }
    
    print(f"Running {num_requests} inference requests...\n")
    
    for i in range(num_requests):
        start_time = time.perf_counter()  # monotonic, high-resolution clock
        
        response = sagemaker_runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=json.dumps(payload),
            ContentType='application/json'
        )
        response['Body'].read()  # include reading the response body in the latency
        
        latency = (time.perf_counter() - start_time) * 1000  # Convert to ms
        latencies.append(latency)
        
        print(f"Request {i+1}: {latency:.2f} ms")
    
    print(f"\n{'='*50}")
    print("Benchmark Results:")
    print(f"{'='*50}")
    print(f"Mean Latency: {np.mean(latencies):.2f} ms")
    print(f"Median Latency: {np.median(latencies):.2f} ms")
    print(f"Min Latency: {np.min(latencies):.2f} ms")
    print(f"Max Latency: {np.max(latencies):.2f} ms")
    print(f"Std Dev: {np.std(latencies):.2f} ms")
    print(f"{'='*50}")

# Run benchmark
benchmark_endpoint(endpoint_name_env, num_requests=5)
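The benchmark above sends one request at a time, which measures single-request latency but never exercises continuous batching. A minimal concurrent harness, assuming any zero-argument callable that performs one request, could look like this; `benchmark_concurrent` is hypothetical, and it needs at least two requests to compute percentiles.

```python
import statistics
import time
from concurrent.futures import ThreadPoolExecutor


def benchmark_concurrent(invoke_fn, num_requests=16, concurrency=4):
    """Call invoke_fn() num_requests times with up to `concurrency` calls in flight.
    Returns per-request latencies (ms) and total wall-clock time (s)."""
    def timed_call(_):
        start = time.perf_counter()
        invoke_fn()
        return (time.perf_counter() - start) * 1000

    wall_start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        latencies = list(pool.map(timed_call, range(num_requests)))
    wall_s = time.perf_counter() - wall_start

    # 99 cut points: index 49 is the median, index 89 the 90th percentile
    cuts = statistics.quantiles(latencies, n=100, method="inclusive")
    print(f"requests={num_requests} concurrency={concurrency}")
    print(f"p50={cuts[49]:.1f} ms  p90={cuts[89]:.1f} ms  "
          f"throughput={num_requests / wall_s:.2f} req/s")
    return latencies, wall_s


# Offline demo with a fake 50 ms request
benchmark_concurrent(lambda: time.sleep(0.05), num_requests=8, concurrency=4)
```

Against the endpoint, pass something like `lambda: sagemaker_runtime.invoke_endpoint(EndpointName=endpoint_name_env, Body=json.dumps(test_payload), ContentType='application/json')['Body'].read()`; low-level boto3 clients can generally be shared across threads. Comparing concurrency 1 with concurrency 8 (the configured `OPTION_MAX_ROLLING_BATCH_SIZE`) shows how much batching raises throughput.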

## 10. Multi-turn Conversation Example

In [None]:
def multi_turn_chat(endpoint_name, conversation_history):
    """
    Handle multi-turn conversations.
    conversation_history: list of {'role': 'user'/'assistant', 'content': 'text'}
    """
    # Build the prompt
    prompt = "<|system|>\nYou are a helpful AI assistant.<|end|>\n"
    
    for turn in conversation_history:
        role = turn['role']
        content = turn['content']
        
        if role == 'user':
            prompt += f"<|user|>\n{content}<|end|>\n"
        elif role == 'assistant':
            prompt += f"<|assistant|>\n{content}<|end|>\n"
    
    # Add assistant prefix for next response
    prompt += "<|assistant|>\n"
    
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.7,
            "do_sample": True,
            "stop": ["<|end|>"]
        }
    }
    
    response = sagemaker_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType='application/json'
    )
    
    result = json.loads(response['Body'].read().decode('utf-8'))
    return result['generated_text']

# Example conversation
conversation = [
    {"role": "user", "content": "What is Python?"},
]

print("Turn 1:")
response1 = multi_turn_chat(endpoint_name_env, conversation)
print(f"Assistant: {response1}\n")

# Add to conversation history
conversation.append({"role": "assistant", "content": response1})
conversation.append({"role": "user", "content": "Can you give me an example?"})

print("Turn 2:")
response2 = multi_turn_chat(endpoint_name_env, conversation)
print(f"Assistant: {response2}")
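`multi_turn_chat` resends the full history on every turn, so the prompt keeps growing until it no longer fits in `OPTION_MAX_MODEL_LEN` (4096 tokens, shared with the generated tokens). One simple mitigation, sketched below, drops the oldest turns once the history exceeds a budget. It counts characters as a rough proxy for tokens, which is an assumption (a tokenizer gives exact counts), and `max_chars=8000` is an arbitrary example value; `trim_history` is a hypothetical helper.

```python
def trim_history(conversation, max_chars=8000, keep_last=1):
    """Return a copy of the conversation with the oldest turns dropped until the
    total content length is at most max_chars. The newest `keep_last` turns are kept."""
    trimmed = list(conversation)  # never mutate the caller's list
    total = sum(len(turn["content"]) for turn in trimmed)
    while total > max_chars and len(trimmed) > keep_last:
        total -= len(trimmed.pop(0)["content"])
    # Start on a user turn so the transcript does not open with an orphaned reply
    while len(trimmed) > keep_last and trimmed[0]["role"] != "user":
        trimmed.pop(0)
    return trimmed


history = [
    {"role": "user", "content": "What is Python?"},
    {"role": "assistant", "content": "Python is a programming language."},
    {"role": "user", "content": "Can you give me an example?"},
]
print(trim_history(history, max_chars=40))
```

In the notebook, calling `multi_turn_chat(endpoint_name_env, trim_history(conversation))` keeps the full `conversation` list intact while sending only the recent part.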

## 11. Error Handling and Retry Logic

In [None]:
from botocore.exceptions import ClientError
import time

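# Error codes treated as transient here (an assumption); others such as ValidationError
# or ModelError are raised immediately because resending the same request rarely helps
RETRYABLE_ERROR_CODES = {
    'ThrottlingException',
    'ServiceUnavailable',
    'InternalFailure',
    'ModelNotReadyException',
}

def invoke_with_retry(endpoint_name, payload, max_retries=3):
    """
    Invoke the endpoint, retrying transient errors with exponential backoff.
    """
    for attempt in range(max_retries):
        try:
            response = sagemaker_runtime.invoke_endpoint(
                EndpointName=endpoint_name,
                Body=json.dumps(payload),
                ContentType='application/json'
            )
            
            return json.loads(response['Body'].read().decode('utf-8'))
            
        except ClientError as e:
            error_code = e.response['Error']['Code']
            print(f"Attempt {attempt + 1} failed: {error_code}")
            
            if error_code not in RETRYABLE_ERROR_CODES or attempt == max_retries - 1:
                print(f"Not retrying. Error: {e}")
                raise
            
            wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s, ...
            print(f"Retrying in {wait_time} seconds...")
            time.sleep(wait_time)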

# Test retry logic
test_payload = {
    "inputs": "Test retry logic.",
    "parameters": {"max_new_tokens": 50}
}

result = invoke_with_retry(endpoint_name_env, test_payload)
print(f"\nResponse: {result['generated_text']}")
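The retry loop above waits a fixed 1, 2, 4 seconds. When many clients are throttled at once, identical schedules make them retry in lockstep; a common alternative is "full jitter", where each wait is drawn uniformly between zero and the exponential bound. A sketch with a hypothetical `backoff_delays` helper and an injectable random source so it can be tested:

```python
import random


def backoff_delays(max_retries=3, base=1.0, cap=20.0, rng=random.random):
    """Full-jitter backoff: the wait before retry i is rng() * min(cap, base * 2**i),
    i.e. uniform in [0, bound) when rng is random.random."""
    return [rng() * min(cap, base * 2 ** i) for i in range(max_retries)]


print([round(delay, 2) for delay in backoff_delays(max_retries=4)])
```

Inside `invoke_with_retry`, `time.sleep(backoff_delays(max_retries)[attempt])` could replace the fixed `2 ** attempt`. botocore also retries some errors on its own before your loop sees them; that behavior is configurable with `botocore.config.Config(retries={'mode': 'standard', 'max_attempts': 5})`.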

## 12. Monitor CloudWatch Metrics

In [None]:
import datetime as dt

cloudwatch = boto3.client('cloudwatch', region_name=region)

def get_endpoint_metrics(endpoint_name, metric_name, minutes=60):
    """
    Get CloudWatch metrics for the endpoint.
    """
    end_time = dt.datetime.now(dt.timezone.utc)  # utcnow() is deprecated since Python 3.12
    start_time = end_time - dt.timedelta(minutes=minutes)
    
    response = cloudwatch.get_metric_statistics(
        Namespace='AWS/SageMaker',
        MetricName=metric_name,
        Dimensions=[
            {'Name': 'EndpointName', 'Value': endpoint_name},
            {'Name': 'VariantName', 'Value': 'AllTraffic'}
        ],
        StartTime=start_time,
        EndTime=end_time,
        Period=300,  # 5 minutes
        Statistics=['Average', 'Sum', 'Maximum']
    )
    
    return response['Datapoints']

# Get invocation metrics. Counts are summed; latency is averaged
# (SageMaker reports ModelLatency in microseconds). Metrics can take a few minutes to appear.
print("Fetching endpoint metrics...\n")

metric_stats = {
    'Invocations': 'Sum',
    'ModelLatency': 'Average',
    'Invocation4XXErrors': 'Sum',
    'Invocation5XXErrors': 'Sum',
}

for metric, stat in metric_stats.items():
    datapoints = get_endpoint_metrics(endpoint_name_env, metric)
    if datapoints:
        print(f"\n{metric} ({stat}):")
        for dp in sorted(datapoints, key=lambda x: x['Timestamp'])[-5:]:
            print(f"  {dp['Timestamp']}: {dp[stat]}")
    else:
        print(f"\n{metric}: No data available")
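Since `ModelLatency` is reported in microseconds, a small conversion step makes the numbers easier to read. The hypothetical helper below turns datapoints from `get_endpoint_metrics` (which requests `Average` and `Maximum`) into millisecond rows, oldest first; the sample datapoints are synthetic, not real measurements.

```python
from datetime import datetime, timezone


def latency_rows_ms(datapoints):
    """Convert ModelLatency datapoints (microseconds) into millisecond rows, oldest first."""
    rows = []
    for dp in sorted(datapoints, key=lambda d: d["Timestamp"]):
        rows.append({
            "timestamp": dp["Timestamp"],
            "avg_ms": dp["Average"] / 1000.0,
            "max_ms": dp["Maximum"] / 1000.0,
        })
    return rows


# Synthetic datapoints in the shape returned by get_metric_statistics
sample = [
    {"Timestamp": datetime(2024, 1, 1, 12, 5, tzinfo=timezone.utc), "Average": 850000.0, "Maximum": 1200000.0},
    {"Timestamp": datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc), "Average": 900000.0, "Maximum": 1500000.0},
]
for row in latency_rows_ms(sample):
    print(f"{row['timestamp']:%H:%M} avg={row['avg_ms']:.1f} ms max={row['max_ms']:.1f} ms")
```

With the live endpoint, `latency_rows_ms(get_endpoint_metrics(endpoint_name_env, 'ModelLatency'))` gives the same view.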

## 13. Cleanup Resources

In [None]:
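# Delete the model first: Predictor.delete_model() looks up model names through the
# endpoint config, which delete_endpoint(delete_endpoint_config=True) removes
print(f"Deleting model: {model_name_env}")
predictor_env.delete_model()

# Delete endpoint and its endpoint configuration
print(f"Deleting endpoint: {endpoint_name_env}")
predictor_env.delete_endpoint(delete_endpoint_config=True)

# Optionally delete S3 artifacts
# s3_client.delete_object(Bucket=bucket, Key=f"{prefix}/{tarball_name}")

print("\n✅ Cleanup complete!")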

## Summary

In this notebook, we covered:

1. ✅ Understanding LMI containers and vLLM backend
2. ✅ Two deployment methods (environment variables vs serving.properties)
3. ✅ Chat-based inference with proper formatting
4. ✅ Streaming responses
5. ✅ Performance benchmarking
6. ✅ Multi-turn conversations
7. ✅ Error handling and retry logic
8. ✅ CloudWatch metrics monitoring
9. ✅ Resource cleanup

## Key Differences: LMI vs TGI

| Feature | LMI (vLLM) | TGI |
|---------|------------|-----|
| Backend | vLLM, TensorRT-LLM, Transformers | HuggingFace TGI |
| Configuration | serving.properties or env vars | Env vars |
| Flexibility | Multiple backends | Single backend |
| Quantization | AWQ, GPTQ, FP8 | AWQ, GPTQ, bitsandbytes, FP8 |
| Best For | Production, multi-backend | Quick deployment |

## Next Steps

- Deploy Phi-3 medium (14B) or small (7B) models
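- Configure auto scaling with Application Auto Scaling (for example, target tracking on `SageMakerVariantInvocationsPerInstance`)
- Test quantization (AWQ/GPTQ) for memory efficiency
- Set up A/B testing with multiple endpoints
- Call the endpoint from AWS Lambda to expose it through a serverless API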

## Additional Resources

- [DJL LMI Documentation](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/index.html)
- [vLLM Documentation](https://docs.vllm.ai/)
- [SageMaker LMI Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-inference.html)
- [Phi-3 on HuggingFace](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)