# Deploying Stable Diffusion using Stability AI DLC on AWS SageMaker

## Example: Stable Diffusion XL v1.0 on PyTorch 2.0.1

This example deploys an endpoint running Stable Diffusion XL (SDXL) on AWS SageMaker using the Stability AI Deep Learning Container (DLC). You can use it for inference as-is or as a starting point for custom development and deployment scenarios.

If you are looking for a production-ready, turnkey solution for inference with a full-featured API, check out [SDXL on AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?id=seller-mybtdwpr2puau) and the related [JumpStart notebooks](https://github.com/Stability-AI/aws-jumpstart-examples).

In [None]:
# NOTE: You may have to restart your kernel after installing boto3
!pip install "sagemaker>=2.173.0" "boto3>=1.28.9" --upgrade --quiet

import sagemaker
from sagemaker import get_execution_role
from IPython.display import display  # used by decode_and_show below

from PIL import Image
from typing import Union, Tuple
import io
import os
import base64
import boto3

In [None]:
import sagemaker
import boto3
sess = sagemaker.Session()
# SageMaker session bucket, used for uploading data, models, and logs.
# SageMaker creates this bucket automatically if it does not exist.
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


## 1. Copy the prebuilt model archive
You can skip this step when redeploying, as long as `model_uri` points at an existing copy of the archive.

In [None]:
model_filename = "sdxlv1-sgm0.1.0.tar.gz"
model_source_uri = f"s3://stabilityai-public-packages/model-packages/sdxl-v1-0-dlc/sgm0.1.1/{model_filename}"
model_uri = f's3://{sagemaker_session_bucket}/stabilityai/sdxl-v1-0-dlc/sgm0.1.1/{model_filename}'

!aws s3 cp {model_source_uri} {model_filename}
!aws s3 cp {model_filename} {model_uri}
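The two `aws s3 cp` commands above download the archive to the notebook's disk and then upload it again. A server-side S3 copy avoids that round trip. Below is a minimal sketch using boto3's managed `copy` transfer; `split_s3_uri` and `copy_model_archive` are hypothetical helper names, and the sketch assumes your execution role can read the public source bucket.

```python
from typing import Tuple


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split an s3://bucket/key URI into (bucket, key)."""
    if not uri.startswith("s3://"):
        raise ValueError(f"Not an S3 URI: {uri}")
    bucket, _, key = uri[len("s3://"):].partition("/")
    if not bucket or not key:
        raise ValueError(f"S3 URI must include a bucket and a key: {uri}")
    return bucket, key


def copy_model_archive(source_uri: str, dest_uri: str) -> None:
    """Copy the model archive from one S3 location to another without a local download."""
    import boto3  # imported lazily so the URI helper works without AWS dependencies

    src_bucket, src_key = split_s3_uri(source_uri)
    dest_bucket, dest_key = split_s3_uri(dest_uri)
    s3 = boto3.client("s3")
    # client.copy is boto3's managed transfer; it switches to multipart copy for large objects
    s3.copy({"Bucket": src_bucket, "Key": src_key}, dest_bucket, dest_key)
```

With the variables defined above, `copy_model_archive(model_source_uri, model_uri)` replaces the two CLI commands.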

## 2. Create and deploy a model and perform real-time inference

In [None]:
# The DLC image is published in us-east-1 and us-west-2; use the URI for your region.
inference_image_uri = '188650660114.dkr.ecr.us-east-1.amazonaws.com/stabilityai-pytorch-inference:2.0.1-sgm0.1.1-gpu-py310-cu118-ubuntu20.04-sagemaker'
#inference_image_uri = '188650660114.dkr.ecr.us-west-2.amazonaws.com/stabilityai-pytorch-inference:2.0.1-sgm0.1.1-gpu-py310-cu118-ubuntu20.04-sagemaker'
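Rather than commenting one URI in or out, you can derive it from the session's region. This is a minimal sketch; `image_uri_for_region` is a hypothetical helper, and it only knows about the two regions named in the comment above.

```python
# Image tag and registry account copied from the URIs above
IMAGE_TAG = "stabilityai-pytorch-inference:2.0.1-sgm0.1.1-gpu-py310-cu118-ubuntu20.04-sagemaker"
SUPPORTED_REGIONS = ("us-east-1", "us-west-2")


def image_uri_for_region(region: str) -> str:
    """Return the Stability AI DLC image URI for one of the supported regions."""
    if region not in SUPPORTED_REGIONS:
        raise ValueError(f"The DLC image is not published in {region}; use one of {SUPPORTED_REGIONS}")
    return f"188650660114.dkr.ecr.{region}.amazonaws.com/{IMAGE_TAG}"
```

Usage: `inference_image_uri = image_uri_for_region(sess.boto_region_name)`.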

In [None]:
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import BytesDeserializer
from sagemaker.utils import name_from_base


endpoint_name = name_from_base("sdxl-v1")

pytorch_model = PyTorchModel(
    name=endpoint_name,
    model_data=model_uri,
    image_uri=inference_image_uri,
    role=role,
)

deployed_model = pytorch_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type="ml.g5.4xlarge", # 4xlarge is required to load the model
    serializer=JSONSerializer(),
    deserializer=BytesDeserializer(accept="image/png")
)

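If the endpoint already exists (for example, after a kernel restart), uncomment the code below to create a predictor for it instead of redeploying. Set `endpoint_name` to the existing endpoint's name first.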

In [None]:
# from sagemaker.predictor import Predictor
# from sagemaker.serializers import JSONSerializer
# from sagemaker.deserializers import BytesDeserializer

# # Create a predictor with matching serializers
# deployed_model = Predictor(
#     endpoint_name=endpoint_name,
#     sagemaker_session=sess,
#     serializer=JSONSerializer(),
#     deserializer=BytesDeserializer(accept="image/png")
# )
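Before attaching a predictor to an existing endpoint, it can help to confirm that the endpoint exists and check its state. A minimal sketch using the SageMaker `DescribeEndpoint` API through a boto3 client; `endpoint_status` is a hypothetical helper name.

```python
from typing import Any


def endpoint_status(endpoint_name: str, sm_client: Any) -> str:
    """Return the EndpointStatus (e.g. "Creating", "InService", "Failed") of a SageMaker endpoint."""
    response = sm_client.describe_endpoint(EndpointName=endpoint_name)
    return response["EndpointStatus"]
```

Usage: `endpoint_status(endpoint_name, boto3.client("sagemaker"))`. A nonexistent endpoint makes `describe_endpoint` raise a `ClientError`.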

## A. Text to image

**Note**: The endpoint reports "InService" before the model has finished loading, so the first requests will time out until loading completes. Check the endpoint's logs in CloudWatch for status.
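Because of this, a small retry wrapper around `predict` can save you from rerunning cells by hand. This is a sketch only: `predict_with_retry` is a hypothetical helper, and since the exact exception raised on a timeout is an assumption here, it retries on any exception.

```python
import time
from typing import Any, Callable


def predict_with_retry(predict: Callable[[Any], Any], payload: Any,
                       attempts: int = 5, delay_seconds: float = 60.0) -> Any:
    """Call predict(payload), retrying while the model is still loading."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return predict(payload)
        except Exception as err:  # exact timeout exception type is an assumption; catch broadly
            if attempt == attempts:
                raise  # give up and surface the last error
            print(f"Attempt {attempt}/{attempts} failed: {err!r}; retrying in {delay_seconds}s")
            time.sleep(delay_seconds)
```

Usage: `output = predict_with_retry(deployed_model.predict, payload)`, where `payload` is one of the request dictionaries below.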

In [None]:
# Helper to display images
def decode_and_show(model_response: bytes) -> None:
    """
    Decode and display a PNG image returned by the SDXL endpoint.

    Args:
        model_response (bytes): Raw PNG bytes returned by the predictor's BytesDeserializer.

    Returns:
        None
    """
    image = Image.open(io.BytesIO(model_response))
    display(image)

In [None]:
output = deployed_model.predict({
    "text_prompts": [{"text": "jaguar in the Amazon rainforest"}],
    "seed": 133,
    "width": 1024,
    "height": 1024,
})
decode_and_show(output)


Available samplers (pass one as `sampler` in the request):
```
"EulerEDMSampler",
"HeunEDMSampler",
"EulerAncestralSampler",
"DPMPP2SAncestralSampler",
"DPMPP2MSampler",
"LinearMultistepSampler"
```
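To catch typos in sampler names before a request reaches the endpoint, you can build the request body with a small helper. This is a minimal sketch: `text_to_image_payload` is a hypothetical name, and it only uses the fields that appear in this notebook's requests.

```python
from typing import Any, Dict, Optional

# Sampler names as listed above
SAMPLERS = {
    "EulerEDMSampler",
    "HeunEDMSampler",
    "EulerAncestralSampler",
    "DPMPP2SAncestralSampler",
    "DPMPP2MSampler",
    "LinearMultistepSampler",
}


def text_to_image_payload(prompt: str, seed: int, width: int = 1024, height: int = 1024,
                          sampler: Optional[str] = None) -> Dict[str, Any]:
    """Build a text-to-image request body, validating the sampler name if one is given."""
    if sampler is not None and sampler not in SAMPLERS:
        raise ValueError(f"Unknown sampler {sampler!r}; expected one of {sorted(SAMPLERS)}")
    payload: Dict[str, Any] = {
        "text_prompts": [{"text": prompt}],
        "seed": seed,
        "width": width,
        "height": height,
    }
    if sampler is not None:
        payload["sampler"] = sampler  # left out entirely when None, as in the first request above
    return payload
```

Usage: `deployed_model.predict(text_to_image_payload("photograph of latte art of a cat", seed=45, width=1536, height=640, sampler="EulerEDMSampler"))`.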

In [None]:
text = "photograph of latte art of a cat"

output = deployed_model.predict({
    "text_prompts": [{"text": text}],
    "seed": 45,
    "height": 640,
    "width": 1536,
    "sampler": "EulerEDMSampler",
})
decode_and_show(output)

SDXL can render short snippets of text, such as single words. Let's try an example.

In [None]:
text = "the word go written in neon lights"

output = deployed_model.predict({
    "text_prompts": [{"text": text}],
    "seed": 142,
    "height": 640,
    "width": 1536,
    "sampler": "LinearMultistepSampler",
})
decode_and_show(output)

## B. Image to image

To perform inference that takes an image as input, you must pass the image into `init_image` as a base64-encoded string.

Below is a helper function for converting images to base64-encoded strings:

In [None]:
def encode_image(image_path: str, resize: bool = True, size: Tuple[int, int] = (1024, 1024)) -> Union[str, None]:
    """
    Encode an image as a base64 string, optionally resizing it to a supported resolution.

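    Args:
        image_path (str): The path to the image file.
        resize (bool, optional): Whether to resize the image. Defaults to True.
        size (Tuple[int, int], optional): Target (width, height) in pixels. Defaults to (1024, 1024).

    Returns:
        Union[str, None]: The encoded image as a string, or None if encoding failed.
    """
    assert os.path.exists(image_path), f"Image not found: {image_path}"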

    if resize:
        image = Image.open(image_path)
        image = image.resize(size)
        image.save("image_path_resized.png")
        image_path = "image_path_resized.png"
    image = Image.open(image_path)
    assert image.size == size, f"Expected image size {size}, got {image.size}; set resize=True to resize it"
    with open(image_path, "rb") as image_file:
        img_byte_array = image_file.read()
        # Encode the byte array as a Base64 string
        try:
            base64_str = base64.b64encode(img_byte_array).decode("utf-8")
            return base64_str
        except Exception as e:
            print(f"Failed to encode image {image_path} as base64 string.")
            print(e)
            return None
    

In [None]:
! wget https://platform.stability.ai/Cat_August_2010-4.jpg

This time, let's feed an image into the model along with the prompt. `image_strength` weights the relative influence of the input image versus the prompt, and `cfg_scale` sets how closely the output follows the prompt. For the demo, we'll use a [picture of a cat from Wikimedia Commons](https://commons.wikimedia.org/wiki/File:Cat_August_2010-4.jpg), downloaded in the previous cell.
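The image-to-image request is the text-to-image body plus a base64-encoded `init_image` and the two weighting parameters. A minimal sketch of building it from raw image bytes follows; `image_to_image_payload` is a hypothetical helper, and the `[0, 1]` range check on `image_strength` is an assumption borrowed from Stability AI's hosted API, not something this notebook states.

```python
import base64
from typing import Any, Dict


def image_to_image_payload(prompt: str, image_bytes: bytes, image_strength: float = 0.8,
                           cfg_scale: float = 9, seed: int = 42) -> Dict[str, Any]:
    """Build an image-to-image request body; init_image is sent as a base64 string."""
    # Assumption: image_strength is a fraction between 0 and 1
    if not 0.0 <= image_strength <= 1.0:
        raise ValueError(f"image_strength must be between 0 and 1, got {image_strength}")
    return {
        "text_prompts": [{"text": prompt}],
        "init_image": base64.b64encode(image_bytes).decode("utf-8"),
        "cfg_scale": cfg_scale,
        "image_strength": image_strength,
        "seed": seed,
    }
```

The bytes are sent unchanged, so they should already be at the target resolution, for example the resized PNG that `encode_image` writes.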

In [None]:
# Here is the original image:
display(Image.open('Cat_August_2010-4.jpg'))

In [None]:
cat_path = "Cat_August_2010-4.jpg"

size = (1536, 640)
cat_data = encode_image(cat_path, size=size)

output = deployed_model.predict({
    "text_prompts": [{"text": "cat in embroidery"}],
    "init_image": cat_data,
    "cfg_scale": 9,
    "image_strength": 0.8,
    "seed": 42,
})
decode_and_show(output)

## 3. Delete the endpoint

When you have finished, delete the endpoint to release the instance(s) associated with it and stop billing.

List your SageMaker endpoints with the AWS CLI:

In [None]:
!aws sagemaker list-endpoints

Then delete the endpoint:

In [None]:
deployed_model.delete_endpoint()
# Rerun the AWS CLI command above to confirm that the endpoint is gone.
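`delete_endpoint()` removes the endpoint and, by default, its endpoint configuration, but the SageMaker model resource created by `deploy()` remains. To remove both, here is a minimal sketch; `clean_up` is a hypothetical helper name. Run it instead of the cell above, because `Predictor.delete_model()` finds the model names through the endpoint configuration, which `delete_endpoint()` deletes.

```python
from typing import Any


def clean_up(predictor: Any) -> None:
    """Delete the SageMaker model, then the endpoint and its endpoint configuration."""
    # delete_model resolves model names via the endpoint config, so it must run first
    predictor.delete_model()
    predictor.delete_endpoint()  # deletes the endpoint config too (delete_endpoint_config=True by default)
```

Usage: `clean_up(deployed_model)`. The copied archive in your session bucket is not touched; `!aws s3 rm {model_uri}` removes it.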