# <a id='toc1_'></a>[Using Stable Diffusion with AWS JumpStart](#toc0_)

This sample notebook shows you how to deploy SDXL from Stability AI as an endpoint on Amazon SageMaker.

> **Note**: This is a reference notebook. It cannot run until you make the changes suggested in the notebook.


## <a id='toc1_1_'></a>[Prerequisites](#toc0_)

1. **Note**: This notebook contains elements that render correctly in the Jupyter interface. Open this notebook from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.
1. Ensure that the IAM role used has the **AmazonSageMakerFullAccess** policy attached.
1. To deploy this ML model successfully, ensure that:
    1. Either your IAM role has these three permissions and you have authority to make AWS Marketplace subscriptions in the AWS account used: 
        1. **aws-marketplace:ViewSubscriptions**
        1. **aws-marketplace:Unsubscribe**
        1. **aws-marketplace:Subscribe**  
    2. or your AWS account has a subscription to <model_link>. If so, skip step: [Subscribe to the model package](#1.-Subscribe-to-the-model-package) <<<<<<<<< TODO >>>>>>>>>>
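
The three Marketplace permissions listed above could be granted with an IAM policy along the following lines. This is a minimal sketch, not an official policy: the variable name `marketplace_policy` and the wildcard `Resource` are assumptions, and your organization may require a narrower scope.

```python
import json

# Hypothetical policy document granting the three Marketplace actions
# named in the prerequisites. "Resource": "*" is an assumption.
marketplace_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "aws-marketplace:ViewSubscriptions",
                "aws-marketplace:Unsubscribe",
                "aws-marketplace:Subscribe",
            ],
            "Resource": "*",
        }
    ],
}

# Serialize to JSON so it can be pasted into the IAM console or a template.
policy_json = json.dumps(marketplace_policy, indent=2)
print(policy_json)
```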


## <a id='toc1_2_'></a>[Resources](#toc0_)


1. [Stability SDK documentation](https://api.stability.ai/docs#tag/v1generation)

2. [Documentation on real-time inference with Amazon SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html).


## <a id='toc1_3_'></a>[Usage instructions](#toc0_)
You can run this notebook one cell at a time (press Shift+Enter to run a cell).


   
- [1. Subscribe to the SDXL Model Package](#toc3_)    
- [2: Create an endpoint and perform real-time inference](#toc4_)    
  - [A: Text to image](#toc4_1_)    
  - [B: Image to image](#toc4_2_)    
- [3: Delete the endpoint](#toc5_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc3_'></a>[1. Subscribe to the SDXL Model Package](#toc0_)

To subscribe to the SDXL Model Package:
1. Open the SDXL Model Package listing page <model_link>  <<<TODO>>>
1. On the AWS Marketplace listing, click on the **Continue to subscribe** button.
1. On the **Subscribe to this software** page, review the EULA, pricing, and support terms, and click **Accept Offer** if you and your organization accept them.
1. Once you click the **Continue to configuration** button and choose a **region**, a **Product Arn** is displayed. This ARN identifies SDXL and is used to create your endpoint using Boto3. Copy the ARN corresponding to your region and specify it in the following cell. <<<TODO: check whether we need this>>>

In [None]:
import sagemaker
from sagemaker import ModelPackage
from stability_sdk_sagemaker.predictor import StabilityPredictor
from stability_sdk_sagemaker.models import get_model_package_arn
from stability_sdk.api import GenerationRequest, TextPrompt
import boto3

from PIL import Image
import io
import os
import base64

# <a id='toc4_'></a>[2: Create an endpoint and perform real-time inference](#toc0_)

In [None]:

# Choose your endpoint name
endpoint_name='jumpstart-sdxl-stability-sdk-kate-rc2' # change this for each deployment
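
Because the endpoint name must change for each deployment, one option is to append a timestamp to a fixed prefix. The sketch below is illustrative only: `unique_endpoint_name` is a hypothetical helper, and it assumes the SageMaker naming rules of letters, digits, and hyphens with a maximum length of 63 characters.

```python
import re
import time

# SageMaker endpoint names: alphanumerics and hyphens, at most 63 characters.
_VALID_NAME = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9])*$")


def unique_endpoint_name(prefix, max_length=63):
    """Append a UTC timestamp to `prefix` (hypothetical helper)."""
    suffix = time.strftime("%Y%m%d-%H%M%S", time.gmtime())
    # Trim the prefix so that prefix + "-" + suffix fits within the limit.
    room = max_length - len(suffix) - 1
    name = f"{prefix[:room].rstrip('-')}-{suffix}"
    if not _VALID_NAME.match(name):
        raise ValueError(f"invalid endpoint name: {name!r}")
    return name


print(unique_endpoint_name("jumpstart-sdxl"))
```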


Once you have subscribed to Stability SDXL, pass your Account ID and the Model Package name, `'stable-diffusion-xl-beta-v2-2-2-rc2'`, to `get_model_package_arn` to get the Model Package ARN.


In [None]:
# Replace the region, account ID, and role ARN below with your own values
package_arn = get_model_package_arn(model_package_name='stable-diffusion-xl-beta-v2-2-2-rc2', region_name='us-east-1', account_id='740929234339')
role_arn = 'arn:aws:iam::740929234339:role/stability-api-sagemaker-execution-role-us-east-1'

sagemaker_session = sagemaker.Session()

In [None]:
! aws sagemaker list-model-packages

Create a deployable `ModelPackage`:

In [None]:
model = ModelPackage(
    role=role_arn,
    model_package_arn=package_arn,
    sagemaker_session=sagemaker_session,
    predictor_cls=StabilityPredictor,
)

# Deploy the ModelPackage. This will take ~5 minutes to run
# Instance type ml.g5.xlarge is sufficient for SDXL
deployed_model = model.deploy(initial_instance_count=1, instance_type='ml.g5.xlarge', endpoint_name=endpoint_name)


If you have already deployed your model, you can also access it via your chosen `endpoint_name` and `sagemaker_session`:


In [None]:

deployed_model = StabilityPredictor(endpoint_name=endpoint_name, sagemaker_session=sagemaker_session)
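
Before attaching to an existing endpoint, it can help to confirm that the endpoint is actually in service. The following is a minimal sketch under stated assumptions: `wait_until_in_service` is a hypothetical helper, and `client` is assumed to behave like `boto3.client("sagemaker")`, whose `describe_endpoint` call returns an `EndpointStatus` field.

```python
import time


def wait_until_in_service(client, endpoint_name, poll_seconds=30, timeout_seconds=900):
    """Poll describe_endpoint until the endpoint is InService (hypothetical helper).

    `client` is expected to behave like boto3.client("sagemaker").
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = client.describe_endpoint(EndpointName=endpoint_name)
        status = response["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"endpoint {endpoint_name} failed to deploy")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"endpoint {endpoint_name} is still {status}")
        # Wait before asking again to avoid hammering the API.
        time.sleep(poll_seconds)
```

In a notebook this might be called as `wait_until_in_service(boto3.client("sagemaker"), endpoint_name)`. Boto3 also ships a built-in `endpoint_in_service` waiter that serves the same purpose.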


You can call `predict` on the deployed model to return model outputs. For the full list of parameters, see the [Stability.ai SDK documentation](https://api.stability.ai/docs#tag/v1generation).

## <a id='toc4_1_'></a>[A: Text to image](#toc0_)


In [None]:
output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text="A photograph of fresh pizza with basil and tomatoes, from a traditional oven")],
        style_preset="cinematic",
        seed=42,
    )
)


Output images are included in the response's `artifacts` as base64-encoded strings. Below is a helper function for decoding and displaying these images:

In [None]:
def decode_and_show(model_response):
    """Decode the first artifact in the response and display it inline."""
    image_b64 = model_response.artifacts[0].base64
    image_data = base64.b64decode(image_b64)
    image = Image.open(io.BytesIO(image_data))
    display(image)

decode_and_show(output)
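
If you want to keep a generated image rather than only display it, the same base64 string can be decoded and written to disk. This is a sketch using only the standard library; `save_base64_image` is a hypothetical helper, and the choice of file extension is left to the caller because the image format returned by the endpoint is assumed here, not confirmed.

```python
import base64
from pathlib import Path


def save_base64_image(b64_string, path):
    """Decode a base64 string and write the raw bytes to `path` (hypothetical helper)."""
    target = Path(path)
    # b64decode accepts an ASCII string and returns the original bytes.
    target.write_bytes(base64.b64decode(b64_string))
    return target
```

For example, `save_base64_image(output.artifacts[0].base64, "pizza.png")` would store the first artifact, assuming the endpoint returns PNG data.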


In [None]:
output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text="teapot")],
        style_preset="origami",
        seed=1234,
    )
)
decode_and_show(output)


## <a id='toc4_2_'></a>[B: Image to image](#toc0_)

To perform inference that takes an image as input, you must pass in the image as `init_image` in the form of a base64-encoded string. Images must be of size (512, 512).

Below is a helper function for converting images to base64-encoded strings:

In [None]:
image_path = "sun.png"

def encode_image(image_path, resize=True):
    """Return the image at `image_path` as a base64-encoded PNG string."""
    if not os.path.exists(image_path):
        raise FileNotFoundError(image_path)

    image = Image.open(image_path)
    if resize:
        # The endpoint expects 512x512 input
        image = image.resize((512, 512))
    if image.size != (512, 512):
        raise ValueError(f"expected a 512x512 image, got {image.size}")

    # Serialize to PNG in memory and encode the bytes as a Base64 string
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

img_base64_str = encode_image(image_path)
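
As a quick sanity check before calling the endpoint, the dimensions of an encoded PNG can be read straight from its header without any imaging library. The sketch below assumes the encoded data is a PNG; `base64_png_size` is a hypothetical helper and raises `ValueError` for any other format.

```python
import base64
import struct

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def base64_png_size(b64_string):
    """Return (width, height) from the IHDR chunk of a base64-encoded PNG."""
    data = base64.b64decode(b64_string)
    # A PNG starts with an 8-byte signature followed by the IHDR chunk.
    if len(data) < 24 or data[:8] != PNG_SIGNATURE or data[12:16] != b"IHDR":
        raise ValueError("not a PNG file")
    # Width and height are big-endian unsigned 32-bit integers at bytes 16-23.
    return struct.unpack(">II", data[16:24])
```

Calling `base64_png_size(img_base64_str)` should then return `(512, 512)` for an image that meets the size requirement stated above.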
    

In [None]:
output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text="birds flying in front of the sun")],
        init_image=img_base64_str,
    )
)

decode_and_show(output)

# <a id='toc5_'></a>[3: Delete the endpoint](#toc0_)

When you've finished working, you can delete the endpoint to release the EC2 instance(s) associated with it, and stop billing.

List your SageMaker endpoints using the AWS CLI:

In [None]:
!aws sagemaker list-endpoints

In [None]:
# Delete an endpoint
deployed_model.sagemaker_session.delete_endpoint(endpoint_name)


To also delete the SageMaker model that was created from the Model Package and deployed to the endpoint, you can use:

In [None]:
deployed_model.delete_model()
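
At the API level, deleting an endpoint does not remove the endpoint configuration it was created from. If you also want to clean that up, a sketch along these lines may help. `delete_endpoint_and_config` is a hypothetical helper, and `client` is assumed to behave like `boto3.client("sagemaker")`.

```python
def delete_endpoint_and_config(client, endpoint_name):
    """Delete an endpoint and its endpoint configuration (hypothetical helper).

    `client` is expected to behave like boto3.client("sagemaker").
    """
    # Look up the configuration name while the endpoint still exists.
    description = client.describe_endpoint(EndpointName=endpoint_name)
    config_name = description["EndpointConfigName"]

    client.delete_endpoint(EndpointName=endpoint_name)
    client.delete_endpoint_config(EndpointConfigName=config_name)
    return config_name
```

Use this in place of the endpoint deletion cell above, not after it, since the lookup fails once the endpoint is gone.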