# <a id='toc1_'></a>[Using SDXL 0.9 with AWS JumpStart](#toc0_)

This sample notebook shows you how to deploy Stable Diffusion SDXL 0.9 from Stability AI as an endpoint on Amazon SageMaker.

> **Note**: This is a reference notebook; it will not run until you make the changes suggested in the notebook.


## <a id='toc1_1_'></a>[Prerequisites](#toc0_)

1. **Note**: Open this notebook from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.

1. Ensure that the IAM role you use has the **AmazonSageMakerFullAccess** policy attached.

1. To deploy the ML model successfully using the steps in this notebook, ensure that either:
    1. Your IAM role has the following three permissions and you have authority to make AWS Marketplace subscriptions in the AWS account used: 
        1. **aws-marketplace:ViewSubscriptions**
        1. **aws-marketplace:Unsubscribe**
        1. **aws-marketplace:Subscribe**  
    2. Or your AWS account already has a subscription to [SDXL 0.9 on JumpStart](https://aws.amazon.com/marketplace/pp/prodview-wqewmgjyf7h7o). If so, skip the step [Subscribe to the SDXL Model Package](#toc3_).
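
The three Marketplace permissions listed above could be granted with an IAM policy document along these lines. This is a minimal sketch: the `"Resource": "*"` scope and the variable name `marketplace_policy` are assumptions for illustration, so adapt them to your organization's policies before attaching anything to a role.

```python
import json

# Hypothetical policy document granting the three AWS Marketplace actions
# named in the prerequisites. The wildcard resource is an assumption.
marketplace_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "aws-marketplace:ViewSubscriptions",
                "aws-marketplace:Unsubscribe",
                "aws-marketplace:Subscribe",
            ],
            "Resource": "*",
        }
    ],
}

# Print the JSON so it can be pasted into the IAM console or a template.
print(json.dumps(marketplace_policy, indent=2))
```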




## <a id='toc1_2_'></a>[Resources](#toc0_)


1. [Stability SDK documentation](https://api.stability.ai/docs#tag/v1generation)

2. [Documentation on real-time inference with Amazon SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html).


## <a id='toc1_3_'></a>[Usage instructions](#toc0_)
You can run this notebook one cell at a time (press Shift+Enter to run a cell).


   
- [1. Subscribe to the SDXL Model Package](#toc3_)    
- [2: Create an endpoint and perform real-time inference](#toc4_)    
  - [A: Text to image](#toc4_1_)    
  - [B: Image to image](#toc4_2_)    
- [3: Delete the endpoint](#toc5_)    


# <a id='toc3_'></a>[1. Subscribe to the SDXL Model Package](#toc0_)

To subscribe to the SDXL Model Package:
1. Open the SDXL Model Package listing page: https://aws.amazon.com/marketplace/pp/prodview-wqewmgjyf7h7o
1. On the AWS Marketplace listing, click on the **Continue to subscribe** button.
1. On the **Subscribe to this software** page, review and click on **"Accept Offer"** if you and your organization accept the EULA, pricing, and support terms.

In [None]:
!pip install 'stability-sdk[sagemaker] @ git+https://github.com/Stability-AI/stability-sdk.git@sagemaker'
! pip install protobuf==3.20


import sagemaker
from sagemaker import ModelPackage, get_execution_role
from stability_sdk_sagemaker.predictor import StabilityPredictor
from stability_sdk_sagemaker.models import get_model_package_arn
from stability_sdk.api import GenerationRequest, GenerationResponse, TextPrompt

from PIL import Image
from typing import Union
import io
import os
import base64
import boto3

# <a id='toc4_'></a>[2: Create an endpoint and perform real-time inference](#toc0_)

In [None]:
# Choose your endpoint name
from sagemaker.utils import name_from_base
endpoint_name=name_from_base('sdxl-0-9-jumpstart') # change this as desired


Once you have subscribed to Stability SDXL, get the Model Package ARN using the map below:


In [None]:

model_package_map = {
    "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f",
    "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f"
}


region = boto3.Session().region_name
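if region not in model_package_map:
    # Raising a plain string is a TypeError in Python 3, so raise a real exception
    raise ValueError(f"Unsupported region: {region}")
package_arn = model_package_map[region]  # ARN of the model package for the current region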

role_arn = get_execution_role()
sagemaker_session = sagemaker.Session()
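
Each ARN in the map encodes its own region, so a quick sanity check can catch a copy-paste mistake before deployment. The sketch below splits an ARN on the standard `arn:partition:service:region:account-id:resource` layout. `parse_model_package_arn` is a hypothetical helper, not part of the SageMaker SDK.

```python
def parse_model_package_arn(arn: str) -> dict:
    """Split a SageMaker model package ARN into its named components."""
    # An ARN has six colon-separated fields; the last one is the resource.
    parts = arn.split(":", 5)
    if len(parts) != 6 or parts[0] != "arn" or parts[2] != "sagemaker":
        raise ValueError(f"Not a SageMaker ARN: {arn}")
    resource = parts[5]
    if not resource.startswith("model-package/"):
        raise ValueError(f"Not a model package ARN: {arn}")
    return {
        "partition": parts[1],
        "region": parts[3],
        "account_id": parts[4],
        "name": resource.split("/", 1)[1],
    }


# Example using the us-east-1 entry from the map above.
example_arn = (
    "arn:aws:sagemaker:us-east-1:865070037744:"
    "model-package/sdxl-v0-9-2042286-feeb547f21a83a53a3dc9a9bf08f660f"
)
parsed = parse_model_package_arn(example_arn)
print(parsed["region"], parsed["name"])
```

In the notebook, one could compare `parse_model_package_arn(package_arn)["region"]` against `region` before calling `deploy`.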

Create a deployable `ModelPackage`. For SDXL 0.9, use one of the following instance types: ml.g5.2xlarge or ml.p4d.24xlarge, and specify it as `instance_type` below.


In [None]:
model = ModelPackage(
    role=role_arn,
    model_package_arn=package_arn,
    sagemaker_session=sagemaker_session,
    predictor_cls=StabilityPredictor,
)


# Deploy the ModelPackage. This will take 5-10 minutes to run

instance_type = "ml.g5.2xlarge"  # valid instance types for this model are ml.g5.2xlarge, ml.p4d.24xlarge, and ml.p4de.24xlarge
deployed_model = model.deploy(initial_instance_count=1,instance_type=instance_type,endpoint_name=endpoint_name)


If you have already deployed your model, you can also access it via your chosen `endpoint_name` and `sagemaker_session`:


In [None]:
deployed_model = StabilityPredictor(endpoint_name=endpoint_name, sagemaker_session=sagemaker_session)


We can call `predict` on our deployed model to return model outputs. For the full list of parameters, see the [Stability.ai SDK documentation](https://api.stability.ai/docs#tag/v1generation).

## <a id='toc4_1_'></a>[A: Text to image](#toc0_)


In [None]:
output = deployed_model.predict(GenerationRequest(text_prompts=[TextPrompt(text="jaguar in the Amazon rainforest")],
                                             # style_preset="cinematic",
                                             seed=12345,
                                             width=1024,
                                             height=1024,
                                             ))


Output images are included in the response's `artifacts` as base64-encoded strings. Below is a helper function for decoding and displaying these images:

In [None]:
def decode_and_show(model_response: GenerationResponse) -> None:
    """
    Decodes and displays an image from SDXL output

    Args:
        model_response (GenerationResponse): The response object from the deployed SDXL model.

    Returns:
        None
    """
    image = model_response.artifacts[0].base64
    image_data = base64.b64decode(image.encode())
    image = Image.open(io.BytesIO(image_data))
    display(image)

decode_and_show(output)
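
To keep generated images on disk rather than only displaying them, the same base64 decoding can be followed by a file write. This is a minimal sketch: `save_artifacts` is a hypothetical helper, it assumes each artifact exposes a `.base64` string as used in `decode_and_show`, and the `.png` suffix assumes the payload is PNG data.

```python
import base64
from pathlib import Path
from typing import Iterable, List


def save_artifacts(artifacts: Iterable, out_dir: str, prefix: str = "sdxl") -> List[Path]:
    """Decode each artifact's base64 payload and write it into out_dir."""
    target = Path(out_dir)
    target.mkdir(parents=True, exist_ok=True)
    written = []
    for index, artifact in enumerate(artifacts):
        # Assumption: the payload is PNG data, hence the .png suffix.
        path = target / f"{prefix}_{index}.png"
        path.write_bytes(base64.b64decode(artifact.base64))
        written.append(path)
    return written
```

In the notebook this could be called as `save_artifacts(output.artifacts, "outputs")`.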


In [None]:
text = "photograph of latte art of a cat"

output = deployed_model.predict(GenerationRequest(text_prompts=[TextPrompt(text=text)],
                                            seed=5,
                                            height=640,
                                            width=1536,
                                            sampler="DDIM",
                                             ))
decode_and_show(output)

Let's try passing in a `style_preset`. See the [Stability SDK documentation](https://api.stability.ai/docs#tag/v1generation) for a full list of available presets.

In [None]:
output = deployed_model.predict(GenerationRequest(text_prompts=[TextPrompt(text="teapot")],
                                            style_preset="origami",
                                            seed = 3,
                                            height = 1024,
                                            width = 1024
                                             ))

decode_and_show(output)


SDXL can render short snippets of text, like single words. Let's try an example below.

In [None]:
text = "the word go written in neon lights"

output = deployed_model.predict(GenerationRequest(text_prompts=[TextPrompt(text=text)],
                                            style_preset="neon-punk",
                                            seed=111,
                                            height=640,
                                            width=1536,
                                            sampler="DDIM",
                                             ))
decode_and_show(output)

## <a id='toc4_2_'></a>[B: Image to image](#toc0_)

To perform inference that takes an image as input, you must pass the image into `init_image` as a base64-encoded string. Like output images, input images must use one of the supported resolutions; that is, (height, width) must be one of (1024, 1024), (1152, 896), (896, 1152), (1216, 832), (832, 1216), (1344, 768), (768, 1344), (1536, 640), (640, 1536).
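
When an input image has arbitrary dimensions, one way to choose a target size is to pick the supported resolution with the nearest aspect ratio. The sketch below does that by comparing log aspect ratios. The function names are hypothetical, and this selection rule is an illustrative choice rather than anything prescribed by the model.

```python
import math
from typing import Tuple

# Supported (height, width) pairs, copied from the list above.
SUPPORTED_RESOLUTIONS = [
    (1024, 1024), (1152, 896), (896, 1152), (1216, 832), (832, 1216),
    (1344, 768), (768, 1344), (1536, 640), (640, 1536),
]


def is_supported(height: int, width: int) -> bool:
    """Return True if (height, width) is one of the supported resolutions."""
    return (height, width) in SUPPORTED_RESOLUTIONS


def closest_supported_resolution(height: int, width: int) -> Tuple[int, int]:
    """Return the supported (height, width) with the nearest aspect ratio."""
    if height <= 0 or width <= 0:
        raise ValueError("height and width must be positive")
    # Compare in log space so that 2:1 and 1:2 are treated symmetrically.
    target = math.log(height / width)
    return min(
        SUPPORTED_RESOLUTIONS,
        key=lambda hw: abs(math.log(hw[0] / hw[1]) - target),
    )


# A 1080p landscape frame maps to the 768 x 1344 (height x width) option.
print(closest_supported_resolution(1080, 1920))
```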


Below is a helper function for converting images to base64-encoded strings:

In [None]:
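from typing import Tuple


def encode_image(image_path: str, resize: bool = True, size: Tuple[int, int] = (1024, 1024)) -> Union[str, None]:
    """
    Encode an image as a base64 string, optionally resizing it to `size` first.

    Args:
        image_path (str): The path to the image file.
        resize (bool, optional): Whether to resize the image. If False, the image must already match `size`. Defaults to True.
        size (Tuple[int, int], optional): Target (width, height) in pixels, in the order PIL expects. Defaults to (1024, 1024).

    Returns:
        Union[str, None]: The encoded image as a string, or None if encoding failed.
    """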
    assert os.path.exists(image_path)

    if resize:
        image = Image.open(image_path)
        image = image.resize(size)
        image.save("image_path_resized.png")
        image_path = "image_path_resized.png"
    image = Image.open(image_path)
    assert image.size == size
    with open(image_path, "rb") as image_file:
        img_byte_array = image_file.read()
        # Encode the byte array as a Base64 string
        try:
            base64_str = base64.b64encode(img_byte_array).decode("utf-8")
            return base64_str
        except Exception as e:
            print(f"Failed to encode image {image_path} as base64 string.")
            print(e)
            return None
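
Because the endpoint only accepts certain resolutions, it can help to check the dimensions of an encoded image before sending it. The sketch below reads the width and height from the PNG header using only the standard library. `png_dimensions` is a hypothetical helper, and it applies only to PNG data, which is what `encode_image` produces when `resize=True` because it re-saves the image as a `.png` file.

```python
import base64
import struct
from typing import Tuple

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_dimensions(b64_png: str) -> Tuple[int, int]:
    """Return (width, height) read from the IHDR chunk of a base64-encoded PNG."""
    data = base64.b64decode(b64_png)
    # Layout: 8-byte signature, 4-byte chunk length, 4-byte chunk type "IHDR",
    # then width and height as big-endian unsigned 32-bit integers.
    if len(data) < 24 or data[:8] != PNG_SIGNATURE or data[12:16] != b"IHDR":
        raise ValueError("not a PNG image")
    width, height = struct.unpack(">II", data[16:24])
    return width, height
```

For example, `png_dimensions(encode_image(path, size=(1536, 640)))` would be expected to return `(1536, 640)`, that is, width first and height second.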
    

Let's feed an image into the model along with the prompt this time. We can set `image_strength` to weight the relative importance of the image and the prompt. For the demo, we'll use a [picture of a cat, taken from Wikimedia Commons](https://commons.wikimedia.org/wiki/File:Cat_August_2010-4.jpg), which the next cell downloads.

In [None]:
! wget https://platform.stability.ai/Cat_August_2010-4.jpg

In [None]:
# Here is the original image:
display(Image.open('Cat_August_2010-4.jpg'))

In [None]:
cat_path = "Cat_August_2010-4.jpg"

size = (1536, 640)
cat_data = encode_image(cat_path, size=size)

output = deployed_model.predict(GenerationRequest(text_prompts=[TextPrompt(text="cat in embroidery")],
                                                  init_image= cat_data,
                                                  cfg_scale=9,
                                                  image_strength=0.8,
                                                  seed=42,
                                                  height=size[1],  # PIL sizes are (width, height)
                                                  width=size[0],
                                                  init_image_mode="STEP_SCHEDULE"
                                                  ))
decode_and_show(output)

# <a id='toc5_'></a>[3: Delete the endpoint](#toc0_)

When you've finished working, you can delete the endpoint to release the EC2 instance(s) associated with it, and stop billing.

Get your list of SageMaker endpoints using the AWS CLI like this:

In [None]:
!aws sagemaker list-endpoints

# Delete an endpoint

In [None]:
deployed_model.sagemaker_session.delete_endpoint(endpoint_name)
# Rerun the AWS CLI command above to confirm that it's gone.
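
The same confirmation can be done from Python instead of the CLI. This is a minimal sketch: `endpoint_exists` is a hypothetical helper that takes a SageMaker client (for example `boto3.client("sagemaker")`) and uses its `list_endpoints` call. It inspects only the first page of results, which is assumed to be enough when filtering by name.

```python
def endpoint_exists(sm_client, endpoint_name: str) -> bool:
    """Return True if an endpoint with exactly this name is still listed."""
    # NameContains does substring matching, so filter for an exact match.
    response = sm_client.list_endpoints(NameContains=endpoint_name, MaxResults=100)
    return any(
        endpoint["EndpointName"] == endpoint_name
        for endpoint in response["Endpoints"]
    )


# Usage in the notebook (requires AWS credentials):
#   import boto3
#   print(endpoint_exists(boto3.client("sagemaker"), endpoint_name))
```

Deletion is asynchronous, so the endpoint may still be listed with a `Deleting` status for a short time after the call returns.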