# Multimodal Agent (Advanced): Run EXAONE Atelier from AWS Marketplace

---

This demo notebook shows how to use the SageMaker Python SDK to deploy a multimodal agent that combines the EXAONE Atelier image-captioning model from AWS Marketplace with a large language model (LLM) from SageMaker JumpStart.

---

## Setup

***

In [None]:
%pip install --upgrade --quiet sagemaker
%pip install --upgrade --quiet sagemaker accelerate datasets tritonclient[all] gradio

In [None]:
import boto3

# Marketplace model package to use. To try the "Limited" listing instead, set:
#   model_package = "exaone-atelier-i2t-limited-85f3f0d181593a10b7aef9bea522a333"  # EXAONE Atelier - Image to Text - Limited
model_package = "exaone-atelier-i2t-76c77246a8343a23a36b2ce80c06f4f6" # EXAONE Atelier - Image to Text

# The listing is published under a different AWS account ID in each region, so
# the full model package ARN depends on the region this notebook runs in.
model_package_map = {
    "us-east-1": f"arn:aws:sagemaker:us-east-1:865070037744:model-package/{model_package}",
    "us-east-2": f"arn:aws:sagemaker:us-east-2:057799348421:model-package/{model_package}",
    "us-west-1": f"arn:aws:sagemaker:us-west-1:382657785993:model-package/{model_package}",
    "us-west-2": f"arn:aws:sagemaker:us-west-2:594846645681:model-package/{model_package}",
    "ca-central-1": f"arn:aws:sagemaker:ca-central-1:470592106596:model-package/{model_package}",
    "eu-central-1": f"arn:aws:sagemaker:eu-central-1:446921602837:model-package/{model_package}",
    "eu-west-1": f"arn:aws:sagemaker:eu-west-1:985815980388:model-package/{model_package}",
    "eu-west-2": f"arn:aws:sagemaker:eu-west-2:856760150666:model-package/{model_package}",
    "eu-west-3": f"arn:aws:sagemaker:eu-west-3:843114510376:model-package/{model_package}",
    "eu-north-1": f"arn:aws:sagemaker:eu-north-1:136758871317:model-package/{model_package}",
    "ap-southeast-1": f"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/{model_package}",
    "ap-southeast-2": f"arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/{model_package}",
    "ap-northeast-2": f"arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/{model_package}",
    "ap-northeast-1": f"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{model_package}",
    "ap-south-1": f"arn:aws:sagemaker:ap-south-1:077584701553:model-package/{model_package}",
    "sa-east-1": f"arn:aws:sagemaker:sa-east-1:270155090741:model-package/{model_package}",
}

region = boto3.Session().region_name
model_package_arn = model_package_map.get(region)
if model_package_arn is None:
    # Also reached when the session has no default region configured (region is None).
    supported_regions = ", ".join(sorted(model_package_map))
    raise ValueError(
        f"Current boto3 session region {region} is not supported. "
        f"Supported regions: {supported_regions}"
    )

In [None]:
model_name = "exaone-atelier-i2t"  # also used as the captioning endpoint's name at deploy time

In Pro Mode, this notebook captions multiple crops of one image (the center plus the four corners) to give the LLM more detailed information about the image.

We recommend the `ml.p4d.24xlarge` instance type (used below) for faster image processing.

In [None]:
from sagemaker import ModelPackage
from sagemaker import get_execution_role
import sagemaker

role = get_execution_role()
sagemaker_session = sagemaker.Session()

# Wrap the subscribed AWS Marketplace model package as a deployable model.
model = ModelPackage(
    model_package_arn=model_package_arn,
    role=role,
    sagemaker_session=sagemaker_session,
)

# Real-time endpoint named after `model_name`; the container start-up health
# check is allowed up to 3600 seconds.
deploy_kwargs = dict(
    endpoint_name=model_name,
    instance_type='ml.p4d.24xlarge',
    initial_instance_count=1,
    container_startup_health_check_timeout=3600,
)
model.deploy(**deploy_kwargs)

# Show the endpoint name as the cell's output.
model.endpoint_name

---
Here we use Meta's Llama-2 70B chat model (`meta-textgeneration-llama-2-70b-f`) as the LLM.

To perform inference on Llama-2 models, you need to pass `custom_attributes='accept_eula=true'` as part of the request header. Doing so means you have read and accepted the model's end-user license agreement (EULA), which can be found in the model card description or at https://ai.meta.com/resources/models-and-libraries/llama-downloads/. Note that the cells below already deploy with `accept_eula=True` and send `custom_attributes='accept_eula=true'` with every inference request, so running them constitutes acceptance of the EULA. Review the EULA before running them.

---

In [None]:
from sagemaker.jumpstart.model import JumpStartModel

# JumpStart identifier and pinned version of the Llama-2 70B model used as the LLM.
model_id, model_version = "meta-textgeneration-llama-2-70b-f", "3.0.0"

llm_model = JumpStartModel(model_id=model_id, model_version=model_version)
# accept_eula=True accepts the model's end-user license agreement; review it first.
llm_predictor = llm_model.deploy(accept_eula=True)

In [None]:
# Quick check of the LLM endpoint with a plain text prompt.
prompt = "suggest a good place to eat."

generation_params = {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6}
llm_payload = {
    "inputs": f"[INST] {prompt} [/INST] ",
    "parameters": generation_params,
}

# Llama-2 inference requests must carry accept_eula=true in the custom attributes.
response = llm_predictor.predict(llm_payload, custom_attributes="accept_eula=true")

first_generation = response[0]
print(first_generation['generated_text'])

## Invoke the Endpoint

In [None]:
import base64
from PIL import Image
from io import BytesIO
import numpy as np
import tritonclient.http as httpclient
import requests
import json
import boto3

# SageMaker runtime client used by invoke_endpoint() to call the captioning endpoint.
smr_client = boto3.client("sagemaker-runtime")

def encode_image(image):
    """Serialize a PIL image as base64-encoded JPEG bytes.

    Args:
        image: PIL image to encode. Images in a mode the JPEG writer cannot
            store (e.g. RGBA/LA uploads with transparency, palette "P" images)
            are converted to RGB first. The caller's image is never modified.

    Returns:
        bytes: base64-encoded JPEG data (ASCII bytes, not ``str``).
    """
    # Modes Pillow's JPEG writer accepts; saving any other mode raises OSError.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue())

def get_sample_binary(payload):
    """Build a Triton binary+JSON inference request from a KServe-style payload.

    Each entry of ``payload["inputs"]`` becomes one ``BYTES`` input of shape
    ``[1, 1]`` holding only the first item of its ``"data"`` list, UTF-8 encoded.
    The only requested output is ``generated_caption``, returned as binary data.

    Returns:
        tuple: ``(request_body, header_length)`` from
        ``InferenceServerClient.generate_request_body``; ``header_length`` is the
        size of the JSON header that precedes the binary tensor data.
    """
    triton_inputs = []
    for spec in payload["inputs"]:
        first_item = spec["data"][0]
        # (1, 1) object array of UTF-8 bytes, matching the declared [1, 1] shape.
        tensor = np.array([[first_item.encode('utf-8')]], dtype=np.object_)
        infer_input = httpclient.InferInput(spec["name"], [1, 1], "BYTES")
        infer_input.set_data_from_numpy(tensor)
        triton_inputs.append(infer_input)

    requested_outputs = [
        httpclient.InferRequestedOutput("generated_caption", binary_data=True)
    ]
    request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
        triton_inputs, outputs=requested_outputs
    )
    return request_body, header_length


def invoke_endpoint(endpoint_name, payload):
    import re
    request_body, header_length = get_sample_binary(payload)
    response = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(
            header_length
        ),
        Body=request_body
    )
    data = response["Body"].read()
    ptn = re.compile(rb'\{"binary_data_size":[0-9]*\}')
    match = json.loads(ptn.search(data).group().decode('utf-8'))
    binary_data_size = match['binary_data_size']
    binary_data = data[len(data)-binary_data_size+1:]
    binary_data = binary_data.replace(b'\x00', b'')
    binary_data = binary_data.replace(b'\x01', b'').decode('utf-8')    

    return eval(binary_data)

### Prepare image


In [None]:
def _square_box(left, top, size):
    """Return the PIL crop box (left, top, right, bottom) of a size-by-size square."""
    return (left, top, left + size, top + size)


def prepare_image(image, pro_mode=False, max_size=2048, min_size=256):
    """Prepare a PIL image for the captioning endpoint.

    Args:
        image: source PIL image. It is never modified in place.
        pro_mode: if True, return five square crops (center plus the four
            corners) so each region can be captioned separately; otherwise
            return one square version of the whole image.
        max_size: pro mode only. Before cropping, the image is rescaled to fit
            a ``max_size`` x ``max_size`` box.
        min_size: standard mode only. Images with either side below this are
            stretched to ``min_size`` x ``min_size``.

    Returns:
        Pro mode: tuple ``(center, upper_left, upper_right, lower_left, lower_right)``.
        Standard mode: a single square image. Standard mode squashes the image
        to a square rather than cropping it.
    """
    if not pro_mode:
        width, height = image.size
        if width < min_size or height < min_size:
            return image.resize((min_size, min_size))
        side = min(width, height)
        return image.resize((side, side))

    width, height = image.size
    if width < max_size or height < max_size:
        # Upscale so the shorter side equals max_size (aspect ratio kept); the
        # thumbnail() call below then shrinks the longer side back to max_size.
        # NOTE(review): for very elongated images the intermediate is huge
        # (e.g. 5000x100 -> 102400x2048); consider resizing directly.
        if width > height:
            image = image.resize((int(max_size * (width / height)), max_size))
        else:
            image = image.resize((max_size, int(max_size * (height / width))))
    else:
        # thumbnail() works in place; copy so the caller's image is untouched.
        image = image.copy()
    image.thumbnail((max_size, max_size))

    width, height = image.size
    side = min(width, height)

    # Center crop: the largest square centered in the image.
    center = image.crop(
        _square_box(round((width - side) / 2), round((height - side) / 2), side)
    )

    # Corner crops: about 2/3 of the center square's side, further capped at 2/3
    # of what the long edge leaves beyond that size. The cap may be a float and
    # is passed to Image.crop as-is.
    # NOTE(review): for square images the cap shrinks corners to roughly 2/9 of
    # the side; confirm this is intended.
    corner = round(side * 2 / 3)
    corner = min(corner, (max(width, height) - corner) * 2 / 3)
    upper_left = image.crop(_square_box(0, 0, corner))
    upper_right = image.crop(_square_box(width - corner, 0, corner))
    lower_left = image.crop(_square_box(0, height - corner, corner))
    lower_right = image.crop(_square_box(width - corner, height - corner, corner))
    return center, upper_left, upper_right, lower_left, lower_right

In [None]:
def generate_input(image):
    """Wrap a PIL image in the KServe-style request payload the endpoint expects.

    The image is JPEG/base64-encoded by ``encode_image`` and sent as a single
    ``BYTES`` input named ``"image"`` whose data is the base64 text.

    Returns:
        dict: ``{"inputs": [{"name": "image", "shape": [1, -1],
        "datatype": "BYTES", "data": [<base64 str>]}]}``
    """
    encoded_text = encode_image(image).decode('utf8')
    image_input = {
        "name": "image",
        "shape": [1, -1],
        "datatype": "BYTES",
        "data": [encoded_text],
    }
    return {"inputs": [image_input]}

### Run Multimodal Agent with Gradio
---
In this example, we use Gradio to run a demo page for the multimodal agent. Before you upload an image, you can turn on `Pro Mode` by selecting its checkbox. Every time you change the mode, press the `Upload` button again so the captions are regenerated.

In [None]:
import gradio as gr

# Name of the captioning endpoint deployed earlier in this notebook.
endpoint_name = model.endpoint_name

def captioning_fn(image, pro_mode=False):
    """Caption an uploaded image; callback for the Gradio "Upload" button.

    Args:
        image: PIL image from the Gradio image component; may be None when no
            image has been selected yet.
        pro_mode: if True, caption the center crop and the four corner crops
            separately (see ``prepare_image``); otherwise caption the whole image.

    Returns:
        tuple: ``(captions, status)``. ``captions`` is the text later passed to
        the LLM: in pro mode the center caption followed by a blank line, then
        one line per corner; otherwise a single line. ``status`` is a message
        for the output box.
    """
    if image is None:
        # Without this guard prepare_image fails on None (it reads image.size).
        return "", "Please choose an image before pressing Upload."

    def caption(img):
        # The decoded endpoint output is indexed [0][0] to get the caption text.
        return invoke_endpoint(endpoint_name, generate_input(img))[0][0]

    if pro_mode:
        center, *corners = prepare_image(image, pro_mode)
        parts = [caption(center) + '\n\n']
        # Corners in order: upper-left, upper-right, lower-left, lower-right.
        parts.extend(caption(crop) + '\n' for crop in corners)
        result = "".join(parts)
    else:
        result = caption(prepare_image(image, pro_mode)) + '\n'

    return result, 'image upload finished.'

def chat_fn(captions, question, pro_mode=False):
    """Answer the user's request about the uploaded image with the LLM endpoint.

    The captions are embedded in a text prompt, wrapped in ``[INST] ... [/INST]``
    and sent to ``llm_predictor``. The full prompt is also printed.

    Args:
        captions: caption text produced by ``captioning_fn`` (passed in from the
            hidden "Captions" textbox).
        question: the user's request from the "Input" textbox.
        pro_mode: if True, use a prompt that presents the captions as the center
            and surrounding regions of one image; otherwise use a prompt that
            presents them as a single description of the image.

    Returns:
        str: the ``generated_text`` field of the first element of the LLM response.
    """
    if pro_mode:
        # Multi-crop prompt: the captions start with the central focus and then
        # describe surrounding details.
        prompt = f"""Interpret the following elements as different aspects of a single image. 
The elements begin with the central focus of the image and extend to include various details surrounding it, creating a comprehensive view of the entire image:

{captions}
Each of these elements contributes to the overall understanding of the image. Consider them collectively to form a unified interpretation. 
Then, based on this integrated perspective, provide a response to the user's request.

Request: {question} 
        """
    else:
        # Single-description prompt for a whole-image caption.
        prompt = f"""Here's a description of an image.

{captions}
Based on given description, provide a response to the user's request.

Request: {question} 
        """

    # Generation request; the prompt is wrapped in Llama-2's [INST] ... [/INST] markers.
    llm_payload = {
    "inputs": f"[INST] {prompt} [/INST] ",
    "parameters": {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6},
    }
    # Echo the full prompt to stdout so it can be inspected.
    print(prompt)
    # Llama-2 inference requests must carry accept_eula=true in the custom attributes.
    response = llm_predictor.predict(llm_payload, custom_attributes="accept_eula=true") # set accept_eula to true

    return response[0]['generated_text']

import os
import secrets

# share=True below publishes the demo at a public Gradio share link, so the
# login must not be a fixed password written into the notebook. Read it from
# the environment; if unset, generate a random one for this session and print it.
demo_username = os.environ.get("GRADIO_USERNAME", "user")
demo_password = os.environ.get("GRADIO_PASSWORD")
if not demo_password:
    demo_password = secrets.token_urlsafe(12)
    print(f"Gradio demo login -> username: {demo_username}  password: {demo_password}")

# NOTE(review): these CSS selectors target Gradio-generated class names
# (svelte-*) and an auto-numbered element id (#component-21); they depend on the
# Gradio version and layout and may silently stop matching after an upgrade.
with gr.Blocks(css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """) as demo:
    title = """<h1 align=center>EXAONE Multimodal</h1>"""
    gr.Markdown(title)
    with gr.Row():
        # Left column: image upload, mode toggle, and the user's request.
        with gr.Column():
            input_image = gr.Image(type="pil")
            with gr.Row():
                pro_mode = gr.Checkbox(label="Pro Mode")
                predict_btn = gr.Button(value="Upload")
            with gr.Row():
                req_prompt = gr.Textbox(label="Input")
            llm_btn = gr.Button(value="Generate")

        # Right column: hidden caption holder and the visible output box.
        with gr.Column():
            best_cap = gr.Textbox(label="Captions", visible=False)
            answer = gr.Textbox(label="Output")

    # "Upload": caption the image; captions go to the hidden box, status to Output.
    predict_btn.click(
        captioning_fn,
        inputs=[input_image, pro_mode],
        outputs=[best_cap, answer],
    )
    # "Generate": answer the request from the stored captions via the LLM.
    llm_btn.click(
        chat_fn,
        inputs=[best_cap, req_prompt, pro_mode],
        outputs=[answer],
    )

demo.queue().launch(auth=(demo_username, demo_password), share=True)

## Clean Up the Endpoint

In [None]:
# Tear down the endpoints and models created by this notebook.
# Captioning model: endpoint, then its endpoint config (which shares the
# endpoint's name), then the model itself.
captioning_session = model.sagemaker_session
captioning_endpoint = model.endpoint_name
captioning_session.delete_endpoint(captioning_endpoint)
captioning_session.delete_endpoint_config(captioning_endpoint)
model.delete_model()

# LLM: the JumpStart predictor deletes its model, then its endpoint.
llm_predictor.delete_model()
llm_predictor.delete_endpoint()