# Deploying Stable Diffusion using Stability AI DLC on AWS SageMaker

## Example: Stable Diffusion XL v1.0 on PyTorch 2.0.1

This example, based on https://github.com/Stability-AI/aws-dlc-examples/tree/main, deploys an endpoint running Stable Diffusion XL on AWS SageMaker using the Stability AI DLC. It can provide inference as-is or serve as a basis for custom development and deployment scenarios.

If you are looking for a production-ready, turnkey solution for inference with a full-featured API, check out [SDXL on AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?id=seller-mybtdwpr2puau) and the related [JumpStart notebooks](https://github.com/Stability-AI/aws-jumpstart-examples).

## Supported regions for this example

This example uses G5 instances and requires at least a g5.4xlarge.

G5 instances are not available in every region, so we recommend using one of the following regions for this example:

- us-east-1 (Virginia)
- us-east-2 (Ohio)
- us-west-2 (Oregon)
- ca-central-1 (Canada)
- eu-west-1 (Ireland)
- eu-central-1 (Frankfurt)
- eu-west-2 (London)
- ap-northeast-1 (Tokyo)
- ap-south-1 (Mumbai)
- ap-northeast-2 (Seoul)
- ap-southeast-2 (Sydney)
- sa-east-1 (Sao Paulo)

Source: https://aws.amazon.com/about-aws/whats-new/2023/06/amazon-ec2-g5-instances-additional-regions/
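
A quick guard can catch a region typo before anything is deployed. The sketch below copies the list above into a hypothetical `G5_REGIONS` set; `check_region` is an illustrative helper, not part of any AWS SDK, and the set is only a snapshot that may lag behind actual G5 availability.

```python
# Regions copied from the list above. This is an illustrative snapshot,
# not something queried from AWS.
G5_REGIONS = {
    "us-east-1", "us-east-2", "us-west-2", "ca-central-1",
    "eu-west-1", "eu-central-1", "eu-west-2",
    "ap-northeast-1", "ap-south-1", "ap-northeast-2", "ap-southeast-2",
    "sa-east-1",
}


def check_region(region: str) -> str:
    """Return the region unchanged, or raise if it is not in the list above."""
    if region not in G5_REGIONS:
        raise ValueError(f"{region!r} is not in the region list for this example")
    return region


aws_region = check_region("us-east-1")
print(aws_region)
```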

In [None]:
# NOTE: You may have to restart your kernel after installing boto3
!pip install "sagemaker" "huggingface_hub" "boto3" --upgrade --quiet

import sagemaker
from sagemaker import ModelPackage, get_execution_role

from PIL import Image
from typing import Union, Tuple
import io
import os
import base64
import boto3

aws_region = "us-east-1"
# To upload the model artifacts, we need a bucket in the same region as the SageMaker endpoint, with permission to read and write to it.
sagemaker_model_bucket = "sagemak-develop-modelbucket-dwv1"


## 1. Download the model weights

In [None]:
import os
from huggingface_hub import snapshot_download
local_dir = './model'
snapshot_download(
    repo_id="stabilityai/stable-diffusion-xl-base-1.0",
    allow_patterns="sd_xl_base_1.0.safetensors",
    local_dir=local_dir,
    local_dir_use_symlinks=False)
snapshot_download(
    repo_id="stabilityai/stable-diffusion-xl-refiner-1.0",
    allow_patterns="sd_xl_refiner_1.0.safetensors",
    local_dir=local_dir,
    local_dir_use_symlinks=False)


## 2. Custom Inference Script Creation

In [None]:
!mkdir -p model/code

### Inference Script: Text2Image, Image2Image

In [None]:
%%writefile model/code/inference.py
import base64
from io import BytesIO
from einops import rearrange
import json
from pathlib import Path
from PIL import Image
from pytorch_lightning import seed_everything
import numpy as np
from sagemaker_inference.errors import BaseInferenceToolkitError
import sgm
from sgm.inference.api import (
    ModelArchitecture,
    SamplingParams,
    SamplingPipeline,
    Sampler,
    get_sampler_config,
)
from sgm.inference.helpers import (
    get_input_image_tensor,
    embed_watermark,
    do_img2img,
    Img2ImgDiscretizationWrapper,
)
import os


def model_fn(model_dir, context=None):
    # Enable the refiner by default
    disable_refiner = os.environ.get("SDXL_DISABLE_REFINER", "false").lower() == "true"

    sgm_path = os.path.dirname(sgm.__file__)
    config_path = os.path.join(sgm_path, "configs/inference")
    if not os.path.exists(config_path):
        config_path = os.path.join(sgm_path, "../configs/inference")
    base_pipeline = SamplingPipeline(
        ModelArchitecture.SDXL_V1_BASE, model_path=model_dir, config_path=config_path
    )
    if disable_refiner:
        print("Refiner model disabled by SDXL_DISABLE_REFINER environment variable")
        refiner_pipeline = None
    else:
        refiner_pipeline = SamplingPipeline(
            ModelArchitecture.SDXL_V1_REFINER,
            model_path=model_dir,
            config_path=config_path,
        )

    return {"base": base_pipeline, "refiner": refiner_pipeline}


def input_fn(request_body, request_content_type):
    if request_content_type == "application/json":
        model_input = json.loads(request_body)
        if "text_prompts" not in model_input:
            raise BaseInferenceToolkitError(
                400, "Invalid Request", "text_prompts missing"
            )
        return model_input
    else:
        raise BaseInferenceToolkitError(
            400, "Invalid Request", "Content-type must be application/json"
        )


def predict_fn(data, model, context=None):
    # Only a single positive and optionally a single negative prompt are supported by this example.
    prompts = []
    negative_prompts = []
    if "text_prompts" in data:
        for text_prompt in data["text_prompts"]:
            if "text" not in text_prompt:
                raise BaseInferenceToolkitError(
                    400, "Invalid Request", "text missing from text_prompt"
                )
            if "weight" not in text_prompt:
                text_prompt["weight"] = 1.0
            if text_prompt["weight"] < 0:
                negative_prompts.append(text_prompt["text"])
            else:
                prompts.append(text_prompt["text"])

    if len(prompts) != 1:
        raise BaseInferenceToolkitError(
            400,
            "Invalid Request",
            "One prompt with positive or default weight must be supplied",
        )
    if len(negative_prompts) > 1:
        raise BaseInferenceToolkitError(
            400, "Invalid Request", "Only one negative weighted prompt can be supplied"
        )

    seed = 0
    height = 1024
    width = 1024
    sampler_name = "DPMPP2MSampler"
    cfg_scale = 7.0
    steps = 40
    use_refiner = model["refiner"] is not None
    init_image = None
    image_strength = 0.35
    refiner_steps = 40
    refiner_strength = 0.2

    if "height" in data:
        height = data["height"]
    if "width" in data:
        width = data["width"]
    if "sampler" in data:
        sampler_name = data["sampler"]
    if "cfg_scale" in data:
        cfg_scale = data["cfg_scale"]
    if "steps" in data:
        steps = data["steps"]
    if "seed" in data:
        seed = data["seed"]
        seed_everything(seed)
    if "use_refiner" in data:
        use_refiner = data["use_refiner"]
    if use_refiner:
        if "refiner_steps" in data:
            refiner_steps = data["refiner_steps"]
        if "refiner_strength" in data:
            refiner_strength = data["refiner_strength"]
    if "init_image" in data:
        if "image_strength" in data:
            image_strength = data["image_strength"]
        try:
            init_image_bytes = BytesIO(base64.b64decode(data["init_image"]))
            init_image_bytes.seek(0)
            if init_image_bytes is not None:
                init_image = get_input_image_tensor(Image.open(init_image_bytes))
        except Exception as e:
            raise BaseInferenceToolkitError(
                400, "Invalid Request", "Unable to decode init_image"
            )

    if model["refiner"] is None and use_refiner:
        raise BaseInferenceToolkitError(
            400, "Invalid Request", "Refiner pipeline is not available"
        )

    try:
        if init_image is not None:
            img_height, img_width = init_image.shape[2], init_image.shape[3]
            output = model["base"].image_to_image(
                params=SamplingParams(
                    width=img_width,
                    height=img_height,
                    steps=steps,
                    sampler=Sampler(sampler_name),
                    scale=cfg_scale,
                    img2img_strength=image_strength,
                ),
                image=init_image,
                prompt=prompts[0],
                negative_prompt=negative_prompts[0]
                if len(negative_prompts) > 0
                else "",
                return_latents=use_refiner,
            )
        else:
            output = model["base"].text_to_image(
                params=SamplingParams(
                    width=width,
                    height=height,
                    steps=steps,
                    sampler=Sampler(sampler_name),
                    scale=cfg_scale,
                ),
                prompt=prompts[0],
                negative_prompt=negative_prompts[0]
                if len(negative_prompts) > 0
                else "",
                return_latents=use_refiner,
            )

        if isinstance(output, (tuple, list)):
            samples, samples_z = output
        else:
            samples = output
            samples_z = None

        if use_refiner and samples_z is not None:
            print("Running Refinement Stage")
            samples = refiner(
                model=model["refiner"].model,
                params=SamplingParams(
                    steps=refiner_steps,
                    sampler=Sampler.EULER_EDM,
                    scale=5.0,
                    img2img_strength=refiner_strength,
                ),
                image=samples_z,
                prompt=prompts[0],
                negative_prompt=negative_prompts[0]
                if len(negative_prompts) > 0
                else "",
            )

        samples = embed_watermark(samples)
        images = []
        for sample in samples:
            sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
            image_bytes = BytesIO()
            Image.fromarray(sample.astype(np.uint8)).save(image_bytes, format="PNG")
            image_bytes.seek(0)
            images.append(image_bytes.read())

        return images

    except ValueError as e:
        raise BaseInferenceToolkitError(400, "Invalid Request", str(e))


# fixed version of refiner function from sgm 0.1.1
def wrap_discretization(discretization, strength=1.0):
    if not isinstance(discretization, Img2ImgDiscretizationWrapper) and strength < 1.0:
        return Img2ImgDiscretizationWrapper(discretization, strength=strength)
    return discretization


def refiner(
    model,
    image,
    prompt: str,
    negative_prompt: str = "",
    params: SamplingParams = SamplingParams(
        sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.2
    ),
    samples: int = 1,
    return_latents: bool = False,
):
    sampler = get_sampler_config(params)
    value_dict = {
        "orig_width": image.shape[3] * 8,
        "orig_height": image.shape[2] * 8,
        "target_width": image.shape[3] * 8,
        "target_height": image.shape[2] * 8,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "crop_coords_top": 0,
        "crop_coords_left": 0,
        "aesthetic_score": 6.0,
        "negative_aesthetic_score": 2.5,
    }

    sampler.discretization = wrap_discretization(
        sampler.discretization, strength=params.img2img_strength
    )

    return do_img2img(
        image,
        model,
        sampler,
        value_dict,
        samples,
        skip_encode=True,
        return_latents=return_latents,
        filter=None,
    )


def output_fn(prediction, accept):
    # This only returns a single image since that's all the example code supports
    if accept != "image/png":
        raise BaseInferenceToolkitError(
            400, "Invalid Request", "Accept header must be image/png"
        )
    return prediction[0], accept
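
The prompt handling in `predict_fn` can be tried locally without the model or the `sgm` package. The following is a minimal sketch of that partitioning step only; `split_prompts` is a hypothetical name, and unlike the inference script it raises a plain `KeyError` when `text` is missing.

```python
from typing import Dict, List, Tuple


def split_prompts(text_prompts: List[Dict]) -> Tuple[List[str], List[str]]:
    """Partition prompts by the sign of their weight, as predict_fn does."""
    positive: List[str] = []
    negative: List[str] = []
    for item in text_prompts:
        weight = item.get("weight", 1.0)  # a missing weight defaults to 1.0
        if weight < 0:
            negative.append(item["text"])
        else:
            positive.append(item["text"])
    return positive, negative


positive, negative = split_prompts([
    {"text": "jaguar in the Amazon rainforest"},
    {"text": "blurry, low quality", "weight": -1.0},
])
print(positive, negative)
```

A request is accepted by the script only when this yields exactly one positive prompt and at most one negative prompt.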

## 3. Package and upload model archive

In [None]:
import sagemaker
import boto3
sess = sagemaker.Session()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_model_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


In [None]:
# Rerun this cell only if you need to re-upload the weights. Otherwise, reuse the existing model_package_name and upload only your new code.
from sagemaker.utils import name_from_base
model_package_name = name_from_base("sdxl-v1")  # You may want to make this a fixed name of your choosing instead
model_uri = f's3://{sagemaker_model_bucket}/{model_package_name}/'

print(f'Uploading base model to {model_uri}, this will take a while...')
!aws s3 cp model/sd_xl_base_1.0.safetensors {model_uri}
print(f'Uploading refiner model to {model_uri}, this will take a while...')
!aws s3 cp model/sd_xl_refiner_1.0.safetensors {model_uri}

In [None]:
# Rerun this cell when you have changed the code or are uploading a fresh copy of the weights
print(f'Uploading code to {model_uri}code')
!aws s3 cp model/code/inference.py {model_uri}code/inference.py
print("Done!")
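
After both upload cells have run, the S3 prefix should hold the two weight files and the inference script under `code/`. As a sketch of that layout, the hypothetical helper below builds the expected object URIs; the bucket and package names in the usage lines are placeholders, not values from this notebook.

```python
from typing import List


def expected_s3_objects(bucket: str, model_package_name: str) -> List[str]:
    """List the object URIs the upload cells are expected to create."""
    prefix = f"s3://{bucket}/{model_package_name}/"
    return [
        prefix + "sd_xl_base_1.0.safetensors",
        prefix + "sd_xl_refiner_1.0.safetensors",
        prefix + "code/inference.py",
    ]


# Placeholder names for illustration only.
for uri in expected_s3_objects("my-model-bucket", "sdxl-v1-example"):
    print(uri)
```

Comparing this list with the output of `aws s3 ls --recursive` on the prefix is one way to confirm the upload before creating the model.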

## 4. Create and deploy a model and perform real-time inference

boto3 is used to deploy the model here in order to take advantage of [uncompressed model downloads](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-uncompressed.html).

In [None]:
# Please only use regions with g5 instance support, mentioned at the top of this page
inference_image_uri_region = aws_region

# AWS Account for AWS's publicly accessible DLC (Deep Learning Containers) ECR
inference_image_uri_region_acct = "763104351884"

inference_image_uri = f"{inference_image_uri_region_acct}.dkr.ecr.{inference_image_uri_region}.amazonaws.com/stabilityai-pytorch-inference:2.0.1-sgm0.1.0-gpu-py310-cu118-ubuntu20.04-sagemaker"


print("You will need the inference_image_uri for your model creation in Massdriver:")
print(inference_image_uri)

In [None]:
endpoint_name = name_from_base("sdxl-v1")
sagemaker_client = boto3.client('sagemaker')
# Creates the SageMaker model resource, which points at the container image and the model data in S3
create_model_response = sagemaker_client.create_model(
    ModelName = endpoint_name,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        "Image": inference_image_uri,
        "ModelDataSource": {
            "S3DataSource": {               # S3 Data Source configuration:
                "S3Uri": model_uri,         # path to your model and script
                "S3DataType": "S3Prefix",   # causes SageMaker to download from a prefix
                "CompressionType": "None"   # disables compression
            }
        }
    }
)

# Creates the endpoint configuration for the model
create_endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName = endpoint_name,
    ProductionVariants = [{
        "ModelName": endpoint_name,
        "VariantName": "sdxl",
        "InitialInstanceCount": 1,
        "InstanceType": "ml.g5.4xlarge",     # 4xlarge is required to load the model
    }]
)
        
# Creates the inference endpoint
deploy_model_response = sagemaker_client.create_endpoint(
    EndpointName = endpoint_name,
    EndpointConfigName = endpoint_name
)
    
print('Waiting for the endpoint to be in service, this can take 5-10 minutes...')
waiter = sagemaker_client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=endpoint_name)
print(f'Endpoint {endpoint_name} is in service, but the model is still loading. This may take another 5-10 minutes.')

In [None]:
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import BytesDeserializer

# Create a predictor with proper serializers
deployed_model = Predictor(
    endpoint_name=endpoint_name, 
    sagemaker_session=sess,
    serializer=JSONSerializer(),
    deserializer=BytesDeserializer(accept="image/png")

)
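
The `Predictor` above is a convenience wrapper. The same request can be sent with the low-level `sagemaker-runtime` client, which makes the content-type and accept headers explicit. This is a minimal sketch; `invoke_sdxl` is a hypothetical helper name, and the client is passed in so the function can be exercised without AWS credentials.

```python
import json


def invoke_sdxl(runtime_client, endpoint_name: str, payload: dict) -> bytes:
    """Send a JSON request to the endpoint and return the raw PNG bytes."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",  # the only type input_fn accepts
        Accept="image/png",              # the only type output_fn accepts
        Body=json.dumps(payload),
    )
    return response["Body"].read()


# Usage against a live endpoint (requires AWS credentials):
# import boto3
# runtime = boto3.client("sagemaker-runtime", region_name=aws_region)
# png = invoke_sdxl(runtime, endpoint_name, {"text_prompts": [{"text": "a lighthouse"}]})
```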

## A. Text to image

**Note**: The endpoint will be "InService" before the model has finished loading, so this request will initially time out. Check the endpoint logs in CloudWatch for status.
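
One way to cope with the initial timeouts is to retry the first request until the model has loaded. The sketch below is an assumption about how you might do that, not part of the SageMaker SDK: `predict_with_retry` is a hypothetical helper, the attempt count and delay are arbitrary defaults, and it catches every exception, which you may want to narrow.

```python
import time
from typing import Any, Callable


def predict_with_retry(
    predict: Callable[[Any], Any],
    payload: Any,
    attempts: int = 10,
    delay_seconds: float = 60.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Any:
    """Call predict(payload), retrying after a fixed delay when it raises."""
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            return predict(payload)
        except Exception as error:  # broad on purpose; narrow this in real code
            last_error = error
            print(f"Attempt {attempt}/{attempts} failed: {error}")
            if attempt < attempts:
                sleep(delay_seconds)
    raise RuntimeError("Endpoint did not respond in time") from last_error


# Usage with the predictor created above:
# output = predict_with_retry(deployed_model.predict,
#                             {"text_prompts": [{"text": "jaguar in the Amazon rainforest"}]})
```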

In [None]:
# Helper to display images
def decode_and_show(model_response) -> None:
    """
    Decodes and displays an image from SDXL output.

    Args:
        model_response (bytes): The PNG image bytes returned by the deployed SDXL model.

    Returns:
        None
    """
    image = Image.open(io.BytesIO(model_response))
    display(image)

In [None]:
output = deployed_model.predict({
    "text_prompts": [{"text": "jaguar in the Amazon rainforest"}],
    "seed": 133,
    "width": 1024,
    "height": 1024,
})
decode_and_show(output)                                             
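
Because the endpoint returns raw PNG bytes, the result can also be written straight to disk. The helper below is a small sketch with a hypothetical name, `save_png`; it checks the standard eight-byte PNG signature first, so that an error body is not saved as an image by mistake.

```python
from pathlib import Path

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # first eight bytes of every PNG file


def save_png(data: bytes, path: str) -> Path:
    """Write PNG bytes to path after checking the file signature."""
    if not data.startswith(PNG_SIGNATURE):
        raise ValueError("response does not look like a PNG image")
    target = Path(path)
    target.write_bytes(data)
    return target


# Usage with the response from the cell above:
# save_png(output, "jaguar.png")
```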


Available samplers are:
```
"EulerEDMSampler",
"HeunEDMSampler",
"EulerAncestralSampler",
"DPMPP2SAncestralSampler",
"DPMPP2MSampler",
"LinearMultistepSampler",
```
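
A misspelled sampler name only fails once the request reaches the endpoint, so it can be worth validating on the client. The sketch below builds a request dictionary in the shape the inference script expects; `build_request` and `SAMPLERS` are hypothetical names, and the weight of `-1.0` for the negative prompt is an arbitrary choice (the script only checks that it is below zero).

```python
# Sampler names copied from the list above.
SAMPLERS = (
    "EulerEDMSampler",
    "HeunEDMSampler",
    "EulerAncestralSampler",
    "DPMPP2SAncestralSampler",
    "DPMPP2MSampler",
    "LinearMultistepSampler",
)


def build_request(prompt, negative_prompt=None, sampler="DPMPP2MSampler", **options):
    """Build a request payload, rejecting sampler names not in the list above."""
    if sampler not in SAMPLERS:
        raise ValueError(f"Unknown sampler {sampler!r}; choose one of {SAMPLERS}")
    text_prompts = [{"text": prompt}]
    if negative_prompt:
        # Any negative weight marks the prompt as negative in predict_fn.
        text_prompts.append({"text": negative_prompt, "weight": -1.0})
    return {"text_prompts": text_prompts, "sampler": sampler, **options}


request = build_request("photograph of latte art of a cat",
                        sampler="EulerEDMSampler", seed=45, height=640, width=1536)
print(request)
```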

In [None]:
text = "photograph of latte art of a cat"

output = deployed_model.predict({
    "text_prompts": [{"text": text}],
    "seed": 45,
    "height": 640,
    "width": 1536,
    "sampler": "EulerEDMSampler",
})
decode_and_show(output)

SDXL can render short snippets of text, like single words. Let's try an example below.

In [None]:
text = "the word go written in neon lights"

output = deployed_model.predict({
    "text_prompts": [{"text": text}],
    "seed": 142,
    "height": 640,
    "width": 1536,
    "sampler": "LinearMultistepSampler",
})
decode_and_show(output)

## B. Image to image

To perform inference that takes an image as input, you must pass the image into `init_image` as a base64-encoded string.

Below is a helper function for converting images to base64-encoded strings:

In [None]:
def encode_image(image_path: str, resize: bool = True, size: Tuple[int, int] = (1024, 1024)) -> Union[str, None]:
    """
    Encode an image as a base64 string, optionally resizing it to a supported resolution.

    Args:
        image_path (str): The path to the image file.
        resize (bool, optional): Whether to resize the image. Defaults to True.
        size (Tuple[int, int], optional): Target (width, height) in pixels. Defaults to (1024, 1024).

    Returns:
        Union[str, None]: The encoded image as a string, or None if encoding failed.
    """
    assert os.path.exists(image_path)

    if resize:
        image = Image.open(image_path)
        image = image.resize(size)
        image.save("image_path_resized.png")
        image_path = "image_path_resized.png"
    image = Image.open(image_path)
    assert image.size == size
    with open(image_path, "rb") as image_file:
        img_byte_array = image_file.read()
        # Encode the byte array as a Base64 string
        try:
            base64_str = base64.b64encode(img_byte_array).decode("utf-8")
            return base64_str
        except Exception as e:
            print(f"Failed to encode image {image_path} as base64 string.")
            print(e)
            return None
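
The base64 step itself can be checked in isolation. The sketch below shows the round trip between the client-side encoding and the `base64.b64decode` call the inference script applies to `init_image`, plus the size of the encoded string, since base64 grows the data by about a third. The function names are hypothetical.

```python
import base64


def encode_bytes(raw: bytes) -> str:
    """Encode raw image bytes as a base64 string suitable for init_image."""
    return base64.b64encode(raw).decode("utf-8")


def decode_init_image(encoded: str) -> bytes:
    """Reverse the encoding, as the inference script does with init_image."""
    return base64.b64decode(encoded)


def encoded_size(raw_length: int) -> int:
    """Length of the padded base64 string for raw_length input bytes."""
    return 4 * ((raw_length + 2) // 3)


raw = b"\x89PNG\r\n\x1a\n" + b"example image bytes"
encoded = encode_bytes(raw)
print(len(raw), len(encoded), decode_init_image(encoded) == raw)
```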
    

In [None]:
! wget https://platform.stability.ai/Cat_August_2010-4.jpg

Let's feed an image into the model as well as the prompt this time. We can set `image_strength` to weight the relative importance of the image and the prompt. For the demo, we'll use a [picture of a cat, taken from Wikimedia Commons](https://commons.wikimedia.org/wiki/File:Cat_August_2010-4.jpg), provided along with this notebook.

In [None]:
# Here is the original image:
display(Image.open('Cat_August_2010-4.jpg'))

In [None]:
cat_path = "Cat_August_2010-4.jpg"

size = (1536, 640)
cat_data = encode_image(cat_path, size=size)

output = deployed_model.predict({
    "text_prompts": [{"text": "cat in embroidery"}],
    "init_image": cat_data,
    "cfg_scale": 9,
    "image_strength": 0.8,
    "seed": 42,
})
decode_and_show(output)

## 5. Delete the endpoint

When you've finished working, you can delete the endpoint to release the EC2 instance(s) associated with it, and stop billing.

List your SageMaker endpoints using the AWS CLI:

In [None]:
!aws sagemaker list-endpoints

### Delete an endpoint

In [None]:
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
# Rerun the AWS CLI command above to confirm that it's gone.
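
Deleting the endpoint releases the instances, but the endpoint configuration and the model created in section 4 remain until they are deleted separately. Since this notebook uses the same name for all three, a cleanup could look like the sketch below. `delete_endpoint_resources` is a hypothetical helper that takes the client as an argument and assumes the three resources share one name.

```python
def delete_endpoint_resources(client, name: str) -> None:
    """Delete the endpoint, its configuration, and the model sharing this name."""
    client.delete_endpoint(EndpointName=name)
    client.delete_endpoint_config(EndpointConfigName=name)
    client.delete_model(ModelName=name)


# Usage with the client and name from section 4:
# delete_endpoint_resources(sagemaker_client, endpoint_name)
```

The weights uploaded to S3 are not affected and have to be removed from the bucket separately if they are no longer needed.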