# Deploy Llama 4 with Amazon SageMaker JumpStart

In this notebook, we use the SageMaker JumpStart `JumpStartModel` class to deploy the [Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/) Scout model to a SageMaker endpoint.

[HuggingFace Model Card](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)

## Prerequisites
You will need access to an **ml.p4de.24xlarge** or **ml.g6e.48xlarge** instance for the endpoint (the deployment below uses **ml.p4de.24xlarge**).


In [None]:
# Install/upgrade the SageMaker Python SDK into this kernel's environment
# (`%pip`, unlike `!pip`, targets the interpreter running this kernel).
%pip install --upgrade --quiet --no-warn-conflicts sagemaker

In [None]:
import json
import sagemaker
import boto3

role = sagemaker.get_execution_role()  # IAM execution role the endpoint will use
sess = sagemaker.session.Session()  # SageMaker session for interacting with AWS APIs
bucket = sess.default_bucket()  # S3 bucket that houses artifacts
# Public property; avoids reaching into the private `_region_name` attribute.
region = sess.boto_region_name

# Build the low-level clients from the SageMaker session's own boto3 session so they
# share its region and credentials instead of whatever the global default happens to be.
sm_client = sess.boto_session.client("sagemaker")  # client to interact with SageMaker
smr_client = sess.boto_session.client("sagemaker-runtime")  # client to interact with SageMaker endpoints
s3_client = sess.boto_session.client("s3")  # client to interact with S3

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"sagemaker version: {sagemaker.__version__}")

## Deploy Llama-4-Scout-17B-16E-Instruct from Amazon SageMaker JumpStart

In [None]:
from sagemaker.jumpstart.model import JumpStartModel

model_id, model_version = "meta-vlm-llama-4-scout-17b-16e-instruct", "1.0.0"

# Tunables for the deployment, kept in one place.
# Set ACCEPT_EULA to True only after reading and accepting the Meta Llama 4 EULA.
ACCEPT_EULA = False
# Instance options from the Prerequisites section: "ml.p4de.24xlarge" or "ml.g6e.48xlarge".
INSTANCE_TYPE = "ml.p4de.24xlarge"
# Seconds SageMaker waits for the container to start and pass its health check.
STARTUP_HEALTH_CHECK_TIMEOUT_S = 1800

# The model and the endpoint share one unique, timestamped name.
model_name = endpoint_name = sagemaker.utils.name_from_base("llama-4-scout-17b")

jumpstart_model = JumpStartModel(
    model_id=model_id,
    model_version=model_version,
    name=model_name,
)

# Keep the returned predictor so the endpoint can also be used through the SDK object.
predictor = jumpstart_model.deploy(
    accept_eula=ACCEPT_EULA,
    instance_type=INSTANCE_TYPE,
    initial_instance_count=1,
    container_startup_health_check_timeout=STARTUP_HEALTH_CHECK_TIMEOUT_S,
    endpoint_name=endpoint_name,
)

## Inference examples

In [None]:
import boto3
from PIL import Image
import requests
from io import BytesIO
 
# SageMaker Runtime client used by `predict`. It is created from the SageMaker session's
# boto3 session so requests go to the same region and credentials used for deployment.
runtime = sess.boto_session.client("sagemaker-runtime")

def get_image_urls(payload):
    """Collect every image URL referenced in a chat request payload.

    Args:
        payload: The request body, either as a JSON string/bytes (the form passed to
            ``predict``) or as an already-parsed dict.

    Returns:
        A list of the ``image_url.url`` values in message order. Messages whose
        ``content`` is not a list (e.g. plain-text messages) contribute nothing.
    """
    if isinstance(payload, (str, bytes, bytearray)):
        payload = json.loads(payload)

    image_urls = []
    for message in payload.get("messages", []):
        content = message.get("content")
        # Multimodal messages carry a list of typed parts; text-only ones carry a string.
        if not isinstance(content, list):
            continue
        for part in content:
            if part.get("type") == "image_url":
                image_urls.append(part["image_url"]["url"])
    return image_urls
    
def _fetch_image_bytes(url, timeout=30):
    """Return the raw bytes of an image given an http(s) URL or a ``data:`` URI."""
    if url.startswith("data:"):
        # requests has no adapter for data: URIs, so decode them locally.
        import base64
        from urllib.parse import unquote_to_bytes

        header, _, data = url.partition(",")
        if header.endswith(";base64"):
            return base64.b64decode(data)
        return unquote_to_bytes(data)
    response = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    return response.content


def display_images(image_paths, timeout=30):
    """Display several images side by side as one combined PIL image.

    Args:
        image_paths: Image URLs to show, either http(s) URLs or ``data:`` URIs
            (such as the base64 URIs built elsewhere in this notebook).
        timeout: Seconds to wait for each HTTP download.

    Returns:
        None. An empty ``image_paths`` is a no-op, so text-only payloads are fine.
    """
    if not image_paths:
        return

    images = [Image.open(BytesIO(_fetch_image_bytes(url, timeout))) for url in image_paths]
    widths, heights = zip(*(image.size for image in images))

    # A canvas wide enough for all images laid out left to right, as tall as the tallest.
    canvas = Image.new("RGB", (sum(widths), max(heights)))
    x_offset = 0
    for image in images:
        canvas.paste(image, (x_offset, 0))
        x_offset += image.size[0]

    canvas.show()
    
def predict(payload, endpoint_name, imgs=False):
    """Invoke a SageMaker endpoint with a chat payload and return the reply text.

    Args:
        payload: Request body as a JSON string, or a dict that is serialized to JSON.
        endpoint_name: Name of the deployed SageMaker endpoint to invoke.
        imgs: If True, also display the images referenced in the payload side by side.

    Returns:
        The ``content`` of the first choice's message in the endpoint's JSON response.
    """
    if isinstance(payload, dict):
        payload = json.dumps(payload)

    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=payload,
    )
    if imgs:
        display_images(get_image_urls(payload))

    result = json.loads(response['Body'].read().decode())
    return result["choices"][0]["message"]["content"]

### Simple text generation

In [None]:
# A single user turn with no system prompt: plain text generation.
messages = [{"role": "user", "content": "Describe Amazon SageMaker AI service"}]
payload = dict(messages=messages)

res = predict(json.dumps(payload), endpoint_name)
print(res)

### Multilingual

In [None]:
# Same request shape as above; the prompt asks for output in another language.
messages = [{"role": "user", "content": "Write a haiku about springtime, but in Tagalog"}]
payload = dict(messages=messages)

res = predict(json.dumps(payload), endpoint_name)
print(res)

### Multimodal

Let's ask the model to describe this image:

<img src="./img/rabbit.jpg" style="width:400px;"/>

In [None]:
import base64

def image_to_base64_data_uri(file_path, include_prefix=False):
    """Read an image file and return its contents base64-encoded.

    Despite the name, by default this returns only the base64 payload (no
    ``data:<mime>;base64,`` prefix), so callers build the data URI themselves.

    Args:
        file_path: Path of the image file to read.
        include_prefix: If True, return a complete data URI whose MIME type is
            guessed from the file extension (``application/octet-stream`` if unknown).

    Returns:
        The base64 text, or a full data URI when ``include_prefix`` is True.
    """
    with open(file_path, "rb") as img_file:
        encoded = base64.b64encode(img_file.read()).decode('utf-8')
    if not include_prefix:
        return encoded

    import mimetypes

    mime_type, _ = mimetypes.guess_type(file_path)
    return f"data:{mime_type or 'application/octet-stream'};base64,{encoded}"

In [None]:
file_path = './img/rabbit.jpg'
# Bare base64 text; the "data:" prefix is added in the message below.
data_uri = image_to_base64_data_uri(file_path)

data = {
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the image in detail please.",
                },
                {
                    "type": "image_url",
                    # "image/jpeg" is the registered MIME type for JPEG ("image/jpg" is not).
                    "image_url": {"url": f"data:image/jpeg;base64,{data_uri}"},
                },
            ],
        },
    ],
    "temperature": 0.6,
    "top_p": 0.9,
}
payload = json.dumps(data)
print(predict(payload, endpoint_name))

Another image-understanding example:

<img src="./img/trip.png" style="width:600px;"/>

In [None]:
file_path = './img/trip.png'
# Bare base64 text; the "data:" prefix is added in the message below.
data_uri = image_to_base64_data_uri(file_path)

data = {
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the location and season in detail please.",
                },
                {
                    "type": "image_url",
                    # The file is a PNG, so label the data URI as image/png (not image/jpg).
                    "image_url": {"url": f"data:image/png;base64,{data_uri}"},
                },
            ],
        },
    ],
    "temperature": 0.6,
    "top_p": 0.9,
}
payload = json.dumps(data)
print(predict(payload, endpoint_name, imgs=False))

## Cleanup

In [None]:
# Look up the endpoint's config name instead of assuming it equals the endpoint name.
# This must happen before the endpoint is deleted, while it can still be described.
endpoint_config_name = sess.sagemaker_client.describe_endpoint(
    EndpointName=endpoint_name
)["EndpointConfigName"]

sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_config_name)
sess.delete_model(model_name)