# Deploy Stable Cascade for Real-Time Image Generation on SageMaker
---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![ This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

---

For further reading & reference materials:

Sources: https://www.philschmid.de/sagemaker-stable-diffusion#2-create-sagemaker-modeltargz-artifact

Further Reading: https://huggingface.co/stabilityai/stable-cascade

In [None]:
!pip install sagemaker huggingface_hub diffusers transformers accelerate safetensors tokenizers torch --upgrade --q
!pip install python-dotenv --upgrade --q

In [None]:
import base64
import boto3
import json
import matplotlib.pyplot as plt
import os
import random
import sagemaker
import tarfile
import time
import torch

from diffusers import (
    StableCascadePriorPipeline,
    StableCascadeDecoderPipeline,
    StableCascadeUNet,
)
from distutils.dir_util import copy_tree
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from io import BytesIO
from pathlib import Path
from IPython.display import display
from PIL import Image
from sagemaker import get_execution_role
from sagemaker.s3 import S3Uploader, S3Downloader
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.async_inference import AsyncInferenceConfig

# Load optional settings from a local .env file into the process environment
# before anything below reads environment variables.
load_dotenv()

# SageMaker session used for the default-bucket lookup and, later in the
# notebook, for resolving the execution role.
sess = sagemaker.Session()
default_bucket = sess.default_bucket()
session_region = sess.boto_region_name
print(f"Sagemaker bucket: {default_bucket}")
print(f"Sagemaker session region: {session_region}")

In [None]:
# Hugging Face Hub repositories for the two Stable Cascade stages.
HF_PRIOR_ID = "stabilityai/stable-cascade-prior"
HF_DECODER_ID = "stabilityai/stable-cascade"
# Download cache location; can be overridden with the CACHE_DIR environment variable.
CACHE_DIR = os.environ.get("CACHE_DIR", "cache_dir")

# UNets stored under the "prior_lite" / "decoder_lite" subfolders of each repository.
prior_unet = StableCascadeUNet.from_pretrained(HF_PRIOR_ID, subfolder="prior_lite")
decoder_unet = StableCascadeUNet.from_pretrained(HF_DECODER_ID, subfolder="decoder_lite")

# Loading options shared by both pipelines: bf16 weight variant, bfloat16 dtype,
# and the common download cache.
pipeline_kwargs = {
    "variant": "bf16",
    "torch_dtype": torch.bfloat16,
    "cache_dir": CACHE_DIR,
}

# Build each pipeline around the UNet loaded above.
prior = StableCascadePriorPipeline.from_pretrained(
    HF_PRIOR_ID,
    prior=prior_unet,
    **pipeline_kwargs,
)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    HF_DECODER_ID,
    decoder=decoder_unet,
    **pipeline_kwargs,
)

In [None]:
# Local working directories: saved pipelines go under model/, the inference
# script and its requirements under code/.
model_path = "model/"
prior_path = "model/prior/"
decoder_path = "model/decoder/"
code_path = "code/"
cache_dir = "cache_dir/"

# makedirs(exist_ok=True) creates any missing parent and is a no-op when the
# directory is already there, so the cell can be re-run safely and there is no
# gap between an existence check and the creation.
for directory in (model_path, code_path, cache_dir, prior_path, decoder_path):
    os.makedirs(directory, exist_ok=True)

# Persist both pipelines to disk so they can be packaged into model.tar.gz.
prior.save_pretrained(save_directory=prior_path)
decoder.save_pretrained(save_directory=decoder_path)

In [None]:
# Smoke test inside the notebook: reload both pipelines from the saved
# directories and generate one image to validate loading and the inference call.

prior = StableCascadePriorPipeline.from_pretrained(prior_path, local_files_only=True)
decoder = StableCascadeDecoderPipeline.from_pretrained(decoder_path, local_files_only=True)
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""

# Stage 1: the prior produces image embeddings for the prompt.
# Uncomment to run on GPU
# prior.enable_model_cpu_offload()
prior_output = prior(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=1024,
    width=1024,
    guidance_scale=4.0,
    num_inference_steps=20,
    num_images_per_prompt=1,
)

# Stage 2: the decoder turns those embeddings into a PIL image.
# Uncomment to run on GPU
# decoder.enable_model_cpu_offload()
decoder_result = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    num_inference_steps=10,
    output_type="pil",
)
decoder_output = decoder_result.images[0]
decoder_output.save("cascade.png")

In [None]:
%%writefile code/requirements.txt
--find-links https://download.pytorch.org/whl/torch_stable.html
accelerate>=0.25.0
torch==2.1.2+cu118
torchvision==0.16.2+cu118
transformers>=4.30.0
numpy>=1.23.5
kornia>=0.7.0
insightface>=0.7.3
opencv-python>=4.8.1.78
tqdm>=4.66.1
matplotlib>=3.7.4
webdataset>=0.2.79
wandb>=0.16.2
munch>=4.0.0
onnxruntime>=1.16.3
einops>=0.7.0
onnx2torch>=1.5.13
warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
torchtools @ git+https://github.com/pabloppp/pytorch-tools
diffusers

In [None]:
%%writefile code/inference.py
import base64
import json
import os
import torch

from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
from io import BytesIO


def model_fn(model_dir):
    """
    Load the Stable Cascade prior and decoder pipelines for inference.

    Args:
        model_dir: Directory containing the "prior" and "decoder" subfolders
            of the unpacked model archive.

    Returns:
        A dict with the keys "prior" and "decoder" mapping to the loaded pipelines.
    """
    print("Entering model_fn...")
    print(f"Model Directory is {model_dir}")

    # local_files_only=True: load strictly from the packaged files.
    pipelines = {
        "prior": StableCascadePriorPipeline.from_pretrained(
            f"{model_dir}/prior", local_files_only=True
        ),
        "decoder": StableCascadeDecoderPipeline.from_pretrained(
            f"{model_dir}/decoder", local_files_only=True
        ),
    }
    print(f"model dictionary: {pipelines}")
    return pipelines


def predict_fn(input_data, model_dict):
    """
    Generate one image for the incoming request and return it base64-encoded.

    Args:
        input_data: Deserialized request payload. Must contain "prompt";
            "negative_prompt" is optional and defaults to an empty string.
        model_dict: Dictionary with the "prior" and "decoder" pipelines, as
            returned by model_fn.

    Returns:
        {"generated_images": [<base64 string of the JPEG-encoded image>]}

    Raises:
        KeyError: If "prompt" is missing from input_data.
    """
    print("Entering predict_fn...")
    prior = model_dict["prior"]
    decoder = model_dict["decoder"]

    print(f"Processing input_data {input_data}")
    prompt = input_data["prompt"]
    # A negative prompt is optional; a request without one should not fail
    # with a KeyError.
    negative_prompt = input_data.get("negative_prompt", "")
    print(f"Prompt = {prompt}")
    print(f"Negative Prompt = {negative_prompt}")

    # Stage 1: the prior produces image embeddings for the prompt.
    prior.enable_model_cpu_offload()
    prior_output = prior(
        prompt=prompt,
        height=1024,
        width=1024,
        negative_prompt=negative_prompt,
        guidance_scale=4.0,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )

    # Stage 2: the decoder turns the embeddings into a PIL image.
    decoder.enable_model_cpu_offload()
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    ).images[0]

    # Serialize as JPEG in memory and base64-encode so the image fits in a
    # JSON response; the buffer is released as soon as the bytes are read.
    encoded_images = []
    with BytesIO() as buffered:
        decoder_output.save(buffered, format="JPEG")
        encoded_images.append(base64.b64encode(buffered.getvalue()).decode())
    print("Finished encodeing returned images.")
    return {"generated_images": encoded_images}

In [None]:
# assemble model package
# shutil.copytree replaces distutils' copy_tree: distutils is deprecated
# (PEP 632) and no longer part of the standard library from Python 3.12 on.
import shutil

# Staging directory with a random suffix; its contents become the archive root.
model_tar = Path(f"model-{random.getrandbits(16)}")
model_tar.mkdir(exist_ok=True)

# dirs_exist_ok=True merges into an existing target, so re-running the cell
# against an already-populated staging directory still works.
for source_dir, target_name in (
    (prior_path, "prior"),
    (decoder_path, "decoder"),
    (code_path, "code"),
):
    shutil.copytree(source_dir, str(model_tar.joinpath(target_name)), dirs_exist_ok=True)

In [None]:
# helper to create the model.tar.gz
def compress(tar_dir=None, output_file="model.tar.gz"):
    """
    Create a gzip-compressed tarball from the contents of ``tar_dir``.

    Each top-level entry of ``tar_dir`` is added under its own name, so the
    archive has no leading directory component. The process working directory
    is never changed, so a failure part-way through cannot leave the caller in
    a different directory.

    Args:
        tar_dir: Directory whose contents are archived. Required despite the
            ``None`` default.
        output_file: Archive path; a relative path is resolved against the
            current working directory.

    Raises:
        TypeError: If ``tar_dir`` is None.
        OSError: If ``tar_dir`` cannot be listed or the archive cannot be written.
    """
    if tar_dir is None:
        raise TypeError("tar_dir must be the path of the directory to archive, not None")

    archive_path = os.path.abspath(output_file)
    with tarfile.open(archive_path, "w:gz") as tar:
        for item in os.listdir(tar_dir):
            print(item)
            # arcname drops the tar_dir prefix so entries sit at the archive root.
            tar.add(os.path.join(tar_dir, item), arcname=item)


# Package the staging directory; with the default output_file this writes
# model.tar.gz into the current working directory.
compress(str(model_tar))

In [None]:
# Upload model.tar.gz to the session's default bucket under the "stable-cascade" prefix.
model_s3_prefix = f"s3://{sess.default_bucket()}/stable-cascade"
s3_model_uri = S3Uploader.upload(local_path="model.tar.gz", desired_s3_uri=model_s3_prefix)
print(f"model uploaded to: {s3_model_uri}")

In [None]:
%cd

In [None]:
# helper decoder
def decode_base64_image(image_string):
    """Decode a base64-encoded image and open it as a PIL image."""
    return Image.open(BytesIO(base64.b64decode(image_string)))


# display PIL images as grid
def display_images(images=None, columns=3, width=100, height=100):
    """
    Show images in a grid on a new matplotlib figure.

    Args:
        images: Iterable of images accepted by ``plt.imshow``. ``None`` is
            treated as no images.
        columns: Number of grid columns.
        width: Figure width passed to ``figsize``.
        height: Figure height passed to ``figsize``.
    """
    images = [] if images is None else list(images)
    plt.figure(figsize=(width, height))
    # Ceiling division gives exactly as many rows as the images need;
    # len/columns + 1 would add an empty row whenever the count is a
    # multiple of `columns`.
    rows = -(-len(images) // columns)
    for position, image in enumerate(images, start=1):
        plt.subplot(rows, columns, position)
        plt.axis("off")
        plt.imshow(image)

In [None]:
# Describe the model: the uploaded archive, the execution role, and the
# framework versions used to select the Hugging Face container.
execution_role = get_execution_role(sess)
huggingface_model = HuggingFaceModel(
    model_data=s3_model_uri,
    role=execution_role,
    transformers_version="4.17",
    pytorch_version="1.10",
    py_version="py38",
)

# Deploy the model to an endpoint backed by a single instance.
predictor = huggingface_model.deploy(
    instance_type="ml.g5.48xlarge",
    initial_instance_count=1,
)

In [None]:
# Timer start: the measured interval covers everything in this cell up to
# end_time below, not only the endpoint call.
start_time = time.time()

# Synchronous invoke_endpoint call through the SageMaker runtime client
client = boto3.client("sagemaker-runtime")
prompt = "A dog trying to catch a flying pizza art"
# NOTE(review): this value is not part of the payload below, so it is never
# sent to the endpoint; inference.py hard-codes num_images_per_prompt=1.
num_images_per_prompt = 1
payload = {"prompt": prompt, "negative_prompt": ""}

# Keep a copy of the request body on disk as payload.json
serialized_payload = json.dumps(payload, indent=4)
with open("payload.json", "w") as outfile:
    outfile.write(serialized_payload)

response = client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    ContentType="application/json",  # Specify the format of the payload
    Accept="application/json",
    Body=serialized_payload,
)
print(f"inference response: {response}")

# The response body is a stream; read it fully and parse the JSON document
response_payload = json.loads(response["Body"].read().decode("utf-8"))

# decode the base64 strings under "generated_images" into PIL images
decoded_images = [decode_base64_image(image) for image in response_payload["generated_images"]]

# visualize generation
display_images(decoded_images)

# Elapsed wall-clock time for the whole cell: payload write, endpoint round
# trip, decoding, and the display_images call.
end_time = time.time()
inference_time = end_time - start_time
print(f"total inference time = {inference_time}")

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![ This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)

![ This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/inference|generativeai|huggingface-multimodal|stability-cascade|DeployStableCascade.ipynb)
