# Deploy the Model as a SageMaker Endpoint and Invoke It

In [None]:
import boto3
import yaml
# AWS region targeted by the SageMaker client created here.
region = "us-west-2"
# SageMaker control-plane client (model, endpoint-config and endpoint APIs).
sm_client = boto3.client(service_name="sagemaker", region_name=region)

# IAM execution role that SageMaker assumes to reach resources in your account
# (S3 model artifacts, etc.). Replace this ARN with a role from your own account.
sagemaker_role = (
    "arn:aws:iam::345967381662:role/service-role/"
    "AmazonSageMaker-ExecutionRole-20180829T140091"
)

In [None]:
### Create the SageMaker model, endpoint config and endpoint
# https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-deployment.html

# S3 location of the packaged model artifact (a .tar.gz archive).
# NOTE(review): per the SageMaker CreateModel docs the bucket should be in the same
# region as the endpoint (`region` above) — confirm when changing either one.
model_url = "s3://sagemaker-us-west-2-345967381662/stable-diffusion/text-to-image/sm_model_g5.tar.gz"

# Prebuilt PyTorch 2.0.1 GPU inference image (py310 / CUDA 11.8, per the tag) in us-west-2.
container = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310-cu118-ubuntu20.04-sagemaker"

# ==== Create model ====
model_name = "stable-diffusion-2-1-base-g5"

# Register the (container image, model artifact) pair as a SageMaker model that
# runs under `sagemaker_role`.
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=sagemaker_role,
    Containers=[
        {
            "Image": container,
            "Mode": "SingleModel",
            "ModelDataUrl": model_url,
        }
    ],
)
print(yaml.dump(create_model_response))

In [None]:
# ==== Create Endpoint Config ====
# An endpoint config lists the model(s) to host and the hardware to host them on.
# The CreateEndpoint request refers to it by `endpoint_config_name`.
endpoint_config_name = "stable-diffusion-2-1-base-g5"
instance_type = "ml.g5.2xlarge"

# One production variant: a single instance of `instance_type` serving `model_name`.
production_variant = dict(
    VariantName="variant1",
    ModelName=model_name,
    InstanceType=instance_type,
    InitialInstanceCount=1,
)

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
print(yaml.dump(endpoint_config_response))

In [None]:
# ==== Create Endpoint ====
# Request a real-time endpoint built from the endpoint config defined earlier.
# Poll its status with describe_endpoint before invoking it.
endpoint_name = "stable-diffusion-2-1-base-g5"

create_endpoint_response = sm_client.create_endpoint(
    EndpointConfigName=endpoint_config_name,
    EndpointName=endpoint_name,
)
print(yaml.dump(create_endpoint_response))

In [None]:
# ==== Check Endpoint Status ====
# create_endpoint only starts provisioning, so invoking straight away fails while
# the endpoint is still "Creating". Block here until it is InService. The boto3
# waiter polls describe_endpoint and stops early if the endpoint reaches "Failed".
from botocore.exceptions import WaiterError

try:
    sm_client.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)
except WaiterError as err:
    # Don't abort the cell: fall through and print the current description so the
    # status (and any failure details it contains) is still visible.
    print(f"Endpoint did not reach InService: {err}")

desc_endpoint_response = sm_client.describe_endpoint(EndpointName=endpoint_name)
print(f"EndpointStatus: {desc_endpoint_response.get('EndpointStatus', None)}")
print("==========================================================")
print(yaml.dump(desc_endpoint_response))

In [None]:
# Invoke the Endpoint via the Boto3 SageMaker Runtime client

import boto3
from botocore.config import Config
import json
import yaml

content_type = "application/json"
# The prompt list may hold several prompts only if the model was compiled with a
# batch size greater than one (or with a variable batch size).
request_body = {
    "prompt": ["a photo of an astronaut riding a horse on mars"]
}

# Serialize the request body to a JSON string for the endpoint.
payload = json.dumps(request_body)

# Data-plane client. Reuse the `region` used for the control-plane client so the
# endpoint is looked up in the region where it was created.
config = Config(region_name=region)
sm_runtime_client = boto3.client("sagemaker-runtime", config=config)
response = sm_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,  # endpoint created earlier in this notebook
    ContentType=content_type,    # single source of truth for the request media type
    Body=payload,
)
print(yaml.dump(response["ResponseMetadata"]))
# Raw response bytes; parsed and displayed in the visualization step.
res = response["Body"].read()

In [None]:
# Visualize the Generated Image

import matplotlib.pyplot as plt
import base64
import json
import numpy as np

# Parse the endpoint response as JSON rather than with eval(). eval() would execute
# any Python expression contained in the response body. json.loads accepts bytes.
response_payload = json.loads(res)

for img_encoded in response_payload["images"]:
    # Each entry is a base64 string of raw uint8 pixel bytes laid out as (H, H, 3).
    pred_decoded_byte = base64.decodebytes(bytes(img_encoded, encoding="utf-8"))
    pixels = np.frombuffer(pred_decoded_byte, dtype=np.uint8)
    # Derive the square side from the buffer size instead of hard-coding it
    # (base model editions produce 512, regular ones 768).
    H = round((pixels.size / 3) ** 0.5)
    if H * H * 3 != pixels.size:
        raise ValueError(
            f"Decoded image has {pixels.size} bytes, which is not H*H*3 for any H"
        )
    pred_decoded = pixels.reshape(H, H, 3)
    plt.imshow(pred_decoded)
    plt.axis("off")
    plt.show()