# Deploying Llama 3.1 to SageMaker with Custom Model Parser for Strands Agent

### What You'll Learn

By the end of this notebook, you will understand how to:

1. **Set up a SageMaker environment** - Initialize sessions, retrieve credentials, and configure AWS resources automatically
2. **Use ml-container-creator** - Generate complete deployment projects with infrastructure-as-code
3. **Build and push containers** - Use AWS CodeBuild to create and publish Docker containers to ECR
4. **Deploy models to SageMaker** - Create real-time inference endpoints for LLMs
5. **Build custom model providers** - Extend SageMakerAIModel with custom response parsing logic
6. **Create Strands agents** - Build lightweight AI agents that use your deployed models

## Prerequisites

Before running this notebook, execute the `prerequisites.sh` script (a cell below runs it for you). The script installs [Node.js](https://nodejs.org/en), [Yeoman](https://yeoman.io/) and the zip utility. It also clones the [ml-container-creator](https://github.com/awslabs/ml-container-creator) utility for deploying large language models to [SageMaker AI](https://aws.amazon.com/sagemaker/ai/).

**Remember to delete your endpoint when finished!** Instructions are provided at the end of this notebook.

We load the `autoreload` extension so that changes to imported modules are picked up without restarting the notebook kernel.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%%bash
./prerequisites.sh

## Section 1: Environment Setup

In the next cell, we'll set up our SageMaker environment.

In [None]:
# Import required libraries
import os
import re
import json
import time
import boto3
import sagemaker

from datetime import datetime
from pathlib import Path
from sagemaker import get_execution_role

print("Initializing SageMaker environment...\n")

sagemaker_session = sagemaker.Session()

try:
    role = get_execution_role()
except ValueError:
    print("⚠️  Not running in SageMaker environment")
    print("   Please set role manually: role = 'arn:aws:iam::ACCOUNT:role/ROLE_NAME'")
    raise

region = sagemaker_session.boto_region_name
bucket = sagemaker_session.default_bucket()
boto_session = boto3.Session(region_name=region)

# Display summary of configuration
print("=" * 60)
print("Environment Configuration Summary")
print("=" * 60)
print(f"Region:          {region}")
print(f"S3 Bucket:       {bucket}")
print(f"Execution Role:  {role}")
print("=" * 60)
print("\n✅ Environment setup complete! Ready to proceed.\n")

## Section 2: Generate Deployment Project

### What is ml-container-creator?

[ml-container-creator](https://github.com/awslabs/ml-container-creator) is an open-source tool from AWS Labs that automates the creation of SageMaker "Bring Your Own Container" (BYOC) deployment projects. It's a Yeoman generator that creates infrastructure code based on your ML framework and serving requirements. Review the [ml-container-creator documentation site](https://awslabs.github.io/ml-container-creator/) for more details on how it works.

In [None]:
%%bash
yo --version
yo --generators

In [None]:
timestamp = int(time.time())
project_name = f"llama-31-deployment-{timestamp}"

output_dir = Path("./generated_projects")
output_dir.mkdir(exist_ok=True)
project_dir = output_dir / project_name

In [None]:
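# `!export` runs in a throwaway subshell, so the value would not persist.
# Set the variables on the kernel process so later `!` commands inherit them.
os.environ["HF_TOKEN"] = "<HUGGING_FACE_TOKEN>"
os.environ["AWS_ROLE"] = role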

In [None]:
# %%bash -s project_name region HF_TOKEN
!yo ml-container-creator $project_name \
--project-dir=$project_dir \
--framework=transformers \
--model-server=sglang \
--model-name=meta-llama/Llama-3.1-8B-Instruct \
--deploy-target=codebuild \
--region=$region \
--include_testing=true \
--test_types=hosted-model-endpoint \
--hf_token=$HF_TOKEN \
--role_arn=$role \
--skip-install \
--skip-prompts \
--force

## Section 3: Build and Push Container

Let's start the build! 🛠️

In [None]:
!cd ./$project_dir && ./deploy/submit_build.sh

## Section 4: Deploy to SageMaker

### Cost Awareness

⚠️ **Important:** Once your endpoint is InService, you will begin to be charged for it!
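
Because charges start at the `InService` transition, it can help to know exactly when the endpoint gets there. The following is a minimal sketch, not part of the generated project, that polls the status after the deploy script has been run. The helper name `wait_until_in_service` and its default intervals are assumptions chosen for illustration; `describe_endpoint` and its `EndpointStatus` field are part of the boto3 SageMaker client.

```python
import time


def wait_until_in_service(sm_client, endpoint_name, poll_seconds=30, timeout_seconds=1800):
    """Poll describe_endpoint until the endpoint is InService, fails, or times out.

    `sm_client` is any object exposing describe_endpoint(EndpointName=...),
    for example boto3.client("sagemaker").
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Endpoint {endpoint_name} still {status} after {timeout_seconds}s"
            )
        time.sleep(poll_seconds)
```

A call such as `wait_until_in_service(boto3.client("sagemaker", region_name=region), endpoint_name)` would block until the endpoint is ready. boto3 also ships a built-in `endpoint_in_service` waiter that covers the same need with less control over the messages.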

Let's deploy! 🚀

In [None]:
!./$project_dir/deploy/deploy.sh

In [None]:
# Copy the endpoint name from the deploy.sh output above
endpoint_name="<ENDPOINT_NAME_HERE>"

### Test Endpoint
1. Non-Streaming Test
2. Multi-Turn Chat
3. Streaming Response
4. Agent Response (this should fail)
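
The streaming test reads the response as server-sent events, where each event is a line beginning with `data: `. A payload part is not guaranteed to end on a line boundary, so a more defensive reader can buffer bytes until a full line is available. The sketch below illustrates that idea under a few assumptions: the helper names are hypothetical, the event shape (`choices[0].delta.content`) follows the OpenAI-compatible format used in this notebook, and the `[DONE]` sentinel is assumed from that convention rather than confirmed for this endpoint.

```python
import json


def _parse_sse_line(line: bytes):
    """Return the delta text carried by one SSE line, or None if there is none."""
    text = line.decode("utf-8", errors="replace").strip()
    if not text.startswith("data:"):
        return None
    data = text[len("data:"):].strip()
    if not data or data == "[DONE]":  # assumed end-of-stream sentinel
        return None
    try:
        event = json.loads(data)
    except json.JSONDecodeError:
        return None
    if not isinstance(event, dict):
        return None
    choices = event.get("choices") or [{}]
    return (choices[0].get("delta") or {}).get("content")


def iter_sse_content(byte_chunks):
    """Yield text deltas from an iterable of raw byte chunks.

    Bytes are buffered so that a line split across two chunks is parsed
    only once it is complete.
    """
    buffer = b""
    for raw in byte_chunks:
        buffer += raw
        # Everything before the last newline is complete; keep the remainder.
        *lines, buffer = buffer.split(b"\n")
        for line in lines:
            content = _parse_sse_line(line)
            if content:
                yield content
    # Flush a final line that arrived without a trailing newline.
    content = _parse_sse_line(buffer)
    if content:
        yield content
```

With a streaming response object, the chunks could be supplied as `(event["PayloadPart"]["Bytes"] for event in response["Body"] if "PayloadPart" in event)`.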

In [None]:
# Initialize SageMaker Runtime client
runtime_client = boto3.client('sagemaker-runtime', region_name=region)

# Prepare the payload for sglang (OpenAI-compatible format)
payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Hello, how are you?"}
    ],
    "max_tokens": 1000,
    "temperature": 0.7,
    "top_p": 0.9,
    "stream": False
}

print("Sending request to SageMaker endpoint...")
print(f"Endpoint: {endpoint_name}")
print(f"Payload: {json.dumps(payload, indent=2)}\n")

try:
    # Invoke the endpoint
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Accept='application/json',
        Body=json.dumps(payload)
    )
    
    # Parse the response
    response_body = response['Body'].read().decode('utf-8')
    result = json.loads(response_body)
    
    print("✓ Response received successfully!\n")
    print("=" * 60)
    print("Full Response:")
    print("=" * 60)
    print(json.dumps(result, indent=2))
    print("\n" + "=" * 60)
        
except Exception as e:
    print(f"❌ Error invoking endpoint: {e}")
    print(f"\nError type: {type(e).__name__}")

In [None]:
conversation_payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "What are the three primary colors?"},
        {"role": "assistant", "content": "The three primary colors are red, blue, and yellow."},
        {"role": "user", "content": "What happens if you mix the first two?"}
    ],
    "max_tokens": 500,
    "temperature": 0.7,
    "top_p": 0.9,
    "stream": False
}

print("Sending request to SageMaker endpoint...")
print(f"Endpoint: {endpoint_name}")
print(f"Payload: {json.dumps(conversation_payload, indent=2)}\n")

try: 
    # Multi-turn conversation example    
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Accept='application/json',
        Body=json.dumps(conversation_payload)
    )

    response_body = response['Body'].read().decode('utf-8')
    result = json.loads(response_body)

    print("✓ Response received successfully!\n")
    print("=" * 60)
    print("Full Response:")
    print("=" * 60)
    print(json.dumps(result, indent=2))
    print("\n" + "=" * 60)
           
except Exception as e:
    print(f"❌ Error invoking endpoint: {e}")
    print(f"\nError type: {type(e).__name__}")

In [None]:
# Streaming inference
streaming_payload = {
    "messages": [
        {"role": "user", "content": "Write a short poem about AI"}
    ],
    "max_tokens": 200,
    "temperature": 0.8,
    "stream": True
}

print("Sending request to SageMaker endpoint...")
print(f"Endpoint: {endpoint_name}")
print(f"Payload: {json.dumps(streaming_payload, indent=2)}\n")

try:
    response = runtime_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Accept='application/json',
        Body=json.dumps(streaming_payload)
    )

    # Process streaming response
    print("Streaming response:")
    print("=" * 60)
    for event in response['Body']:
        chunk = event['PayloadPart']['Bytes'].decode('utf-8')
        # Handle empty chunks
        if not chunk.strip():
            continue
            
        # Handle multiple JSON objects in the chunk
        for line in chunk.split('\n'):
            if line.startswith('data: '):
                try:
                    # Strip only the leading SSE prefix, not later occurrences
                    json_str = line[len('data: '):].strip()
                    if json_str:  # Check if there's actual content
                        chunk_data = json.loads(json_str)
                        if 'choices' in chunk_data and len(chunk_data['choices']) > 0:
                            delta = chunk_data['choices'][0].get('delta', {})
                            if delta.get('content'):
                                print(delta['content'], end='', flush=True)
                except json.JSONDecodeError:
                    # Skip malformed JSON without breaking the stream
                    continue
    print("\n")
    print("=" * 60)
    print("\n")
    print("✓ Response received successfully!\n")
except Exception as e:
    print(f"❌ Error invoking endpoint: {e}")
    print(f"\nError type: {type(e).__name__}")

In [None]:
from strands.agent import Agent
from strands.models.sagemaker import SageMakerAIModel

provider = SageMakerAIModel(
    endpoint_config={
        "endpoint_name": endpoint_name,
        "region_name": region,
    },
    payload_config={
        "max_tokens": 1000,
        "temperature": 0.7,
        "stream": True,
    }
)

agent = Agent(
    name="llama-assistant",    
    model=provider,
    system_prompt=(
        "You are a helpful AI assistant powered by Llama 3.1, "
        "deployed on Amazon SageMaker. You provide clear, accurate, "
        "and friendly responses to user questions. When you don't know "
        "something, you say so honestly rather than making things up."
    )
)

test_message = "Hello! Can you tell me what you are and how you can help me?"
response = agent(test_message)
    
# Display the response
print("-" * 60)
print(f"Agent: {response.content}\n")
print("-" * 60)

# Display metadata about the response
print("\nResponse Metadata:")
if hasattr(response, 'usage'):
    print(f"  Tokens Used: {response.usage.get('total_tokens', 'N/A')}")
    print(f"  Prompt Tokens: {response.usage.get('prompt_tokens', 'N/A')}")
    print(f"  Completion Tokens: {response.usage.get('completion_tokens', 'N/A')}")
if hasattr(response, 'finish_reason'):
    print(f"  Finish Reason: {response.finish_reason}")

print("\n❌ Test 0 failed! The agent should not have successfully responded.\n")

The cell above is expected to raise an error: the default `SageMakerAIModel` response parser cannot parse this model's responses. To use the model through Strands, we need a custom model parser.

## Section 5: Create Custom Model Provider

### How Custom Parsing Works

We'll use the `LlamaModelProvider` class that:

1. **Extends SageMakerAIModel**: Inherits all the endpoint invocation logic
2. **Overrides parse_response()**: Implements custom parsing for sglang format
3. **Adds error handling**: Provides meaningful error messages
4. **Validates responses**: Checks structure before accessing fields

The class is defined in `./code/llama_model_provider.py`
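
To make the "validate before accessing fields" step concrete, here is a minimal standalone sketch. It is not the contents of `llama_model_provider.py`. The function name `extract_assistant_text` is hypothetical, and the response shape (`choices[0].message.content`) is assumed from the OpenAI-compatible format that the non-streaming tests send and receive.

```python
def extract_assistant_text(result):
    """Validate an OpenAI-style chat completion dict and return the reply text.

    Raises ValueError with a specific message when an expected field is
    missing, instead of letting a KeyError or TypeError surface later.
    """
    if not isinstance(result, dict):
        raise ValueError(f"Expected a JSON object, got {type(result).__name__}")

    choices = result.get("choices")
    if not isinstance(choices, list) or not choices:
        raise ValueError("Response has no non-empty 'choices' list")

    first = choices[0]
    if not isinstance(first, dict):
        raise ValueError("First choice is not a JSON object")

    message = first.get("message")
    if not isinstance(message, dict):
        raise ValueError("First choice has no 'message' object")

    content = message.get("content")
    if not isinstance(content, str):
        raise ValueError("Message has no text 'content'")

    return content
```

Each check produces its own error message, which makes it easier to tell a malformed response apart from an endpoint that returned an error body.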

### Custom Parser Agent Tests
1. Simple Chat
2. Multi-Turn Chat
3. Complex Reasoning

Let's build it! 🛠️

In [None]:
import sys
sys.path.append(os.path.join(os.getcwd(), 'code'))
from llama_model_provider import LlamaModelProvider

In [None]:
llama_provider = LlamaModelProvider(
    endpoint_name=endpoint_name,  # The SageMaker endpoint to invoke
    region_name=region,  # AWS region (from environment setup)
    max_tokens=1000,  # Maximum response length (roughly 750 words)
    temperature=0.7,  # Creativity level (0.0 = deterministic, 2.0 = very random)
    top_p=0.9  # Diversity control (0.9 is a good default)
)

agent = Agent(
    name="llama-assistant",
    model=llama_provider,
    system_prompt=(
        "You are a helpful AI assistant powered by Llama 3.1, "
        "deployed on Amazon SageMaker. You provide clear, accurate, "
        "and friendly responses to user questions. When you don't know "
        "something, you say so honestly rather than making things up."
    )
)

test_message = "Hello! Can you tell me what you are and how you can help me?"
response = agent(test_message)

# Display the response
print("-" * 60)
print(f"Agent: {response}\n")
print("-" * 60)

# Display metadata about the response
print("\nResponse Metadata:")
if hasattr(response, 'usage'):
    print(f"  Tokens Used: {response.usage.get('total_tokens', 'N/A')}")
    print(f"  Prompt Tokens: {response.usage.get('prompt_tokens', 'N/A')}")
    print(f"  Completion Tokens: {response.usage.get('completion_tokens', 'N/A')}")
if hasattr(response, 'finish_reason'):
    print(f"  Finish Reason: {response.finish_reason}")

print("\n✅ Test 1 passed! The agent successfully responded.\n")

In [None]:
# Test multi-turn conversation

message1 = "What are the three primary colors?"
print(f"User: {message1}\n")
print("⏳ Waiting for response...\n")

response1 = agent(message1)
print("-" * 60)
print(f"Agent: {response1}\n")
print("-" * 60)

# Second message: Follow-up that requires context from first message
# This tests if the agent remembers what we just talked about
message2 = "Can you mix them to create other colors?"
print(f"\nUser: {message2}\n")
print("⏳ Waiting for response...\n")

response2 = agent(message2)
print("-" * 60)
print(f"Agent: {response2}\n")
print("-" * 60)

# Third message: Another follow-up
message3 = "What color do you get if you mix the first two you mentioned?"
print(f"\nUser: {message3}\n")
print("⏳ Waiting for response...\n")

response3 = agent(message3)
print("-" * 60)
print(f"Agent: {response3}\n")
print("-" * 60)

print("\n✅ Test 2 passed! The agent successfully maintained conversation context.\n")

In [None]:
complex_prompt = """
I'm building a Python web application and need to decide between Flask and FastAPI.
Can you compare these two frameworks and recommend which one I should use for a
REST API that needs to handle high traffic and support async operations?
""".strip()

print(f"User: {complex_prompt}\n")
print("⏳ Waiting for response...\n")

response = agent(complex_prompt)

print("-" * 60)
print(f"Agent: {response}\n")
print("-" * 60)

print("\n✅ Test 3 passed! The agent handled a complex reasoning task.\n")

## Section 6: Cleanup

### Cleanup Instructions

When you're done with your endpoint, delete it to avoid ongoing charges. Follow these steps to clean up all resources created in this notebook.

**What Needs to be Deleted:**

1. **SageMaker Endpoint** - The running inference service (costs money!)
2. **Endpoint Configuration** - The configuration (no cost, but good practice)
3. **SageMaker Model** - The model definition (no cost, but good practice)
4. **ECR Image** (Optional) - The container image (minimal storage cost)
5. **CodeBuild Project** (Optional) - The build project (no cost when not building)

**Important Notes:**

- Deleting the endpoint stops the endpoint's instance charges; retained resources such as the ECR image still incur their own (minimal) storage cost
- Deletion is permanent - you'll need to redeploy to use the model again
- You can keep the ECR image and CodeBuild project for future deployments
- The generated project files remain on disk for reuse
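
The cleanup cell in this notebook assumes the endpoint configuration and model are named `<endpoint_name>-config` and `<endpoint_name>-model`. If the deploy script used different names, they can be looked up from the endpoint itself, as long as this happens before the endpoint is deleted. The sketch below shows one way to do that; the helper name `resolve_endpoint_resources` is hypothetical, while `describe_endpoint` and `describe_endpoint_config` are boto3 SageMaker client calls.

```python
def resolve_endpoint_resources(sm_client, endpoint_name):
    """Return (config_name, model_names) attached to an existing endpoint.

    `sm_client` is any object exposing describe_endpoint and
    describe_endpoint_config, for example boto3.client("sagemaker").
    Must be called while the endpoint still exists.
    """
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]

    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    # Each production variant references the model it serves.
    model_names = [
        variant["ModelName"]
        for variant in config.get("ProductionVariants", [])
        if "ModelName" in variant
    ]
    return config_name, model_names
```

The returned names can then be passed to `delete_endpoint_config` and `delete_model` in place of the assumed `-config` and `-model` suffixes.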

In [None]:
PROCEED_WITH_DELETION = False

In [None]:
# Delete SageMaker resources
print("=" * 60)
print("Resource Cleanup")
print("=" * 60)
print("\n⚠️  WARNING: This will delete your SageMaker endpoint and stop its charges.")
print("You will need to redeploy if you want to use the model again.\n")

if not PROCEED_WITH_DELETION:
    print("❌ Deletion not confirmed.")
    print("\nTo delete resources, set PROCEED_WITH_DELETION = True in the cell above.")
    print("\nAlternatively, use the AWS CLI commands below for manual deletion.\n")
else:
    print("✅ Deletion confirmed. Proceeding with cleanup...\n")
    
    # Create SageMaker client
    sagemaker_client = boto3.client('sagemaker', region_name=region)
    
    # Step 1: Delete the endpoint
    print("Step 1: Deleting SageMaker Endpoint...")
    try:
        sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
        print(f"  ✓ Endpoint '{endpoint_name}' deletion initiated")
        
        # Wait for endpoint to be deleted
        print("  ⏳ Waiting for endpoint to be deleted (this may take 2-3 minutes)...")
        waiter = sagemaker_client.get_waiter('endpoint_deleted')
        waiter.wait(EndpointName=endpoint_name)
        print("  ✓ Endpoint deleted successfully\n")
        
    except sagemaker_client.exceptions.ClientError as e:
        if 'Could not find endpoint' in str(e):
            print(f"  ℹ️  Endpoint '{endpoint_name}' not found (may already be deleted)\n")
        else:
            print(f"  ❌ Error deleting endpoint: {e}\n")
    
    # Step 2: Delete the endpoint configuration
    print("Step 2: Deleting Endpoint Configuration...")
    try:
        # Assumes the config is named '<endpoint_name>-config';
        # adjust this if the deploy script used a different name
        config_name = f"{endpoint_name}-config"
        sagemaker_client.delete_endpoint_config(EndpointConfigName=config_name)
        print(f"  ✓ Endpoint configuration '{config_name}' deleted\n")
        
    except sagemaker_client.exceptions.ClientError as e:
        if 'Could not find endpoint configuration' in str(e):
            print("  ℹ️  Endpoint configuration not found (may already be deleted)\n")
        else:
            print(f"  ❌ Error deleting endpoint configuration: {e}\n")
    
    # Step 3: Delete the model
    print("Step 3: Deleting SageMaker Model...")
    try:
        # Assumes the model is named '<endpoint_name>-model'
        model_name = f"{endpoint_name}-model"
        sagemaker_client.delete_model(ModelName=model_name)
        print(f"  ✓ Model '{model_name}' deleted\n")
        
    except sagemaker_client.exceptions.ClientError as e:
        if 'Could not find model' in str(e):
            print("  ℹ️  Model not found (may already be deleted)\n")
        else:
            print(f"  ❌ Error deleting model: {e}\n")
    
    print("=" * 60)
    print("✅ Cleanup Complete!")
    print("=" * 60)
    print("\nResources deleted:")
    print("  ✓ SageMaker Endpoint (charges stopped)")
    print("  ✓ Endpoint Configuration")
    print("  ✓ SageMaker Model")
    print("\nResources retained (optional cleanup):")
    print("  • ECR container image (minimal storage cost)")
    print("  • CodeBuild project (no cost when not building)")
    print("  • Generated project files (local, no cost)")
    print("\n💡 You can redeploy anytime using the generated project files!\n")