# How to serve the `stabilityai/stable-cascade` model on Amazon SageMaker

Stable Cascade project links:
* [Project page](https://stability.ai/news/introducing-stable-cascade)
* [GitHub repository](https://github.com/Stability-AI/StableCascade)
* [HuggingFace model hub page](https://huggingface.co/stabilityai/stable-cascade) (model ID: `stabilityai/stable-cascade`)

## 1. Dependency installation

In [2]:
# Upgrade pip itself, then install the pinned notebook-side dependencies.
%pip install pip --upgrade --quiet
# diffusers is pinned to the same version as in the endpoint's requirements.txt written further down.
%pip install huggingface_hub==0.20.3 sagemaker==2.199.0 diffusers==0.27.2 sagemaker-ssh-helper --quiet

[0mNote: you may need to restart the kernel to use updated packages.
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
distributed 2022.7.0 requires tornado<6.2,>=6.0.3, but you have tornado 6.4 which is incompatible.[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


## 2. Imports & variable assignments

In [4]:
import base64
import io
import json
import os
from pathlib import Path
import PIL
import shutil
import tarfile
import time
from typing import Any, Dict, List

import boto3
import botocore
from diffusers.utils import make_image_grid
import huggingface_hub
import sagemaker
import sagemaker.utils

In [5]:
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0"  # Set to "1" to hide the download progress bars

In [6]:
# SageMaker execution role, session and artifact bucket shared by the whole notebook.
SM_DEFAULT_EXECUTION_ROLE_ARN = sagemaker.get_execution_role()
SM_SESSION = sagemaker.session.Session()
SM_ARTIFACT_BUCKET_NAME = SM_SESSION.default_bucket()  # We use SageMaker's default bucket

# Read the region through the public `boto_region_name` attribute instead of the
# private `_region_name` one, which is an implementation detail of the SDK.
REGION_NAME = SM_SESSION.boto_region_name

# Low-level boto3 clients, all bound to the session's region.
S3_CLIENT = boto3.client("s3", region_name=REGION_NAME)
SAGEMAKER_CLIENT = boto3.client("sagemaker", region_name=REGION_NAME)
SAGEMAKER_RUNTIME_CLIENT = boto3.client("sagemaker-runtime", region_name=REGION_NAME)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


In [9]:
# Resolve the home directory through pathlib so the cell also works when $HOME is not exported.
HOME_DIR = str(Path.home())

# HuggingFace local model storage
HF_LOCAL_CACHE_DIR = Path(HOME_DIR) / ".cache" / "huggingface" / "hub"
HF_LOCAL_DOWNLOAD_DIR = Path.cwd() / "model_repo"
HF_LOCAL_PRIOR_DOWNLOAD_DIR = Path.cwd() / "prior_model_repo"
HF_LOCAL_DOWNLOAD_DIR.mkdir(exist_ok=True)
HF_LOCAL_PRIOR_DOWNLOAD_DIR.mkdir(exist_ok=True)

# Selected HuggingFace models
HF_HUB_PRIOR_MODEL_NAME = "stabilityai/stable-cascade-prior"
HF_HUB_MODEL_NAME = "stabilityai/stable-cascade"

# HuggingFace remote model storage (Amazon S3)
MODEL_ARTIFACTS_KEY_PREFIX = f"photobooth/endpoint/{HF_HUB_MODEL_NAME}/model"
CODE_ARTIFACTS_KEY_PREFIX = f"photobooth/endpoint/{HF_HUB_MODEL_NAME}/code"

# The endpoint handler (handler.py) loads the prior pipeline from `<model location>/prior`,
# and `option.model_id` in serving.properties points at MODEL_ARTIFACTS_KEY_PREFIX. The prior
# artifacts therefore have to live in a `prior/` sub-prefix of the model artifacts; stored
# under an unrelated prefix they would never reach the endpoint's model directory.
PRIOR_MODEL_ARTIFACTS_KEY_PREFIX = f"{MODEL_ARTIFACTS_KEY_PREFIX}/prior"
PRIOR_CODE_ARTIFACTS_KEY_PREFIX = f"photobooth/endpoint/{HF_HUB_PRIOR_MODEL_NAME}/code"

## 3. Assets deployment: model artifacts

In [None]:
# Download a snapshot of the decoder repository into the local download directory.
# `allow_patterns` restricts the download to JSON/text files and safetensors weights.
huggingface_hub.snapshot_download(
    repo_id=HF_HUB_MODEL_NAME,
    revision="main",
    local_dir=HF_LOCAL_DOWNLOAD_DIR,
    local_dir_use_symlinks="auto",  # Files larger than 5MB are actually symlinked to the local HF cache
    allow_patterns=["*.json", "*.txt", "*.safetensors"],
);  # The trailing semicolon keeps the notebook from echoing the call's return value

Fetching 36 files:   0%|          | 0/36 [00:00<?, ?it/s]

decoder/config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

super_resolution.safetensors:   0%|          | 0.00/814M [00:00<?, ?B/s]

inpainting.safetensors:   0%|          | 0.00/218M [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/6.25G [00:00<?, ?B/s]

canny.safetensors:   0%|          | 0.00/218M [00:00<?, ?B/s]

stable_cascade_stage_b.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

stable_cascade_stage_c.safetensors:   0%|          | 0.00/9.22G [00:00<?, ?B/s]

diffusion_pytorch_model.bf16.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

decoder_lite/config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

diffusion_pytorch_model.bf16.safetensors:   0%|          | 0.00/1.40G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/2.80G [00:00<?, ?B/s]

effnet_encoder.safetensors:   0%|          | 0.00/81.5M [00:00<?, ?B/s]

model_index.json:   0%|          | 0.00/451 [00:00<?, ?B/s]

previewer.safetensors:   0%|          | 0.00/16.0M [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/117 [00:00<?, ?B/s]

stage_a.safetensors:   0%|          | 0.00/73.7M [00:00<?, ?B/s]

stage_b.safetensors:   0%|          | 0.00/6.25G [00:00<?, ?B/s]

stage_b_bf16.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

stage_b_lite.safetensors:   0%|          | 0.00/2.80G [00:00<?, ?B/s]

In [None]:
# Upload the whole local download directory to the artifact bucket under the model key prefix.
model_artifacts_url = SM_SESSION.upload_data(
    path=HF_LOCAL_DOWNLOAD_DIR.as_posix(),
    bucket=SM_ARTIFACT_BUCKET_NAME,
    key_prefix=MODEL_ARTIFACTS_KEY_PREFIX,
)
print(f"Model artifacts have been successfully uploaded to: {model_artifacts_url}")

In [None]:
# Download a snapshot of the prior repository into its own local download directory.
huggingface_hub.snapshot_download(
    repo_id=HF_HUB_PRIOR_MODEL_NAME,
    revision="main",
    local_dir=HF_LOCAL_PRIOR_DOWNLOAD_DIR,
    local_dir_use_symlinks="auto",  # Files larger than 5MB are actually symlinked to the local HF cache
    allow_patterns=["*.json", "*.txt", "*.safetensors"],  # JSON/text files and safetensors weights only
);  # The trailing semicolon keeps the notebook from echoing the call's return value

In [None]:
# Upload the prior model directory to the artifact bucket under the prior key prefix.
prior_model_artifacts_url = SM_SESSION.upload_data(
    path=HF_LOCAL_PRIOR_DOWNLOAD_DIR.as_posix(),
    bucket=SM_ARTIFACT_BUCKET_NAME,
    key_prefix=PRIOR_MODEL_ARTIFACTS_KEY_PREFIX,
)
# Report the URL of the prior upload (not the decoder's `model_artifacts_url`).
print(f"Prior Model artifacts have been successfully uploaded to: {prior_model_artifacts_url}")

## 4. Assets deployment: code artifacts

In [7]:
# Code artifacts local storage: the serving files written by the next cells
# (requirements.txt, serving.properties, handler.py) all go into this directory.
SOURCE_DIR = Path.cwd() / "code"
SOURCE_DIR.mkdir(exist_ok=True)

In [23]:
%%writefile {SOURCE_DIR}/requirements.txt
# Python dependencies shipped next to handler.py in the code archive.
transformers==4.39.1
# Earlier diffusers candidates, kept for reference:
#diffusers==0.26.1
#diffusers @ git+https://github.com/kashif/diffusers.git@wuerstchen-v3
#diffusers @ git+https://github.com/kashif/diffusers.git@a3dc21385b7386beb3dab3a9845962ede6765887
# Same diffusers version as the one installed in the notebook kernel.
diffusers==0.27.2
accelerate==0.28.0
peft==0.8.2

Overwriting /root/GenAI-PhotoBooth/Cascade/code/requirements.txt


In [15]:
# NOTE(review): scratch cell. `self` is not defined at notebook top level, and the recorded
# output shows that `torch` is not installed in this kernel, so the cell cannot run as written.
# The equivalent loading code lives in `InferenceService.initialize` in handler.py.
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import torch

self._pipeline = StableCascadeDecoderPipeline.from_pretrained("./", torch_dtype=torch.float16).to(self._device)

ModuleNotFoundError: No module named 'torch'

In [24]:
# Generate serving.properties from the notebook variables rather than hard-coding the bucket
# name: `option.model_id` must point at the prefix the model artifacts were uploaded to,
# whatever the AWS account or region the notebook runs in.
serving_properties_path = SOURCE_DIR / "serving.properties"
serving_properties_path.write_text(
    "\n".join(
        [
            "engine=Python",
            f"option.model_id=s3://{SM_ARTIFACT_BUCKET_NAME}/{MODEL_ARTIFACTS_KEY_PREFIX}",
            "option.dtype=bf16",
            "option.entryPoint=handler.py",
        ]
    )
    + "\n"
)
print(f"Wrote {serving_properties_path}")

Overwriting /root/GenAI-PhotoBooth/Cascade/code/serving.properties


In [30]:
%%writefile {SOURCE_DIR}/handler.py
import base64
import io
import os
import PIL
from typing import Dict, Optional

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

from djl_python import Input, Output
import torch

def encode_image(image: "PIL.Image.Image", format: str) -> str:
    """Serialize an image to a base64-encoded text string.

    The annotation is a string naming the image class (`PIL.Image.Image`); the bare
    `PIL.Image` used before is the module, and evaluating it at definition time only
    worked because another import had already loaded the `PIL.Image` submodule.

    Args:
        image: Image object exposing `save(fp, format=...)`.
        format: Image file format passed to `save`, e.g. "PNG" or "JPEG".

    Returns:
        The encoded image file content as base64 text.
    """
    buffer = io.BytesIO()
    image.save(buffer, format=format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")  # could be .decode("ascii") too

def decode_image(encoded_image: str) -> "PIL.Image.Image":
    """Decode a base64-encoded image file and open it as a PIL image.

    Args:
        encoded_image: Base64 text of the image file content.

    Returns:
        The image opened from the decoded bytes.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it explicitly
    # instead of relying on another package having done so as a side effect.
    import PIL.Image

    image_data_bytes = base64.b64decode(encoded_image)
    return PIL.Image.open(io.BytesIO(image_data_bytes))


def get_torch_dtype_from_str(dtype: str) -> torch.dtype:
    """Translate a short data type alias into the matching Torch data type.

    Recognized aliases are "fp32", "fp16", "bf16" and "int8". `None` is passed
    through unchanged.

    Raises:
        ValueError: If `dtype` is neither `None` nor a recognized alias.
    """
    if dtype is None:
        return None

    alias_table = (
        ("fp32", torch.float32),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
        ("int8", torch.int8),
    )
    for alias, torch_type in alias_table:
        if dtype == alias:
            return torch_type

    raise ValueError(f"Data type cannot be parsed as valid Torch data type: {dtype}")
    
    
class InferenceService:
    """Holds the Stable Cascade prior and decoder pipelines and serves generation requests.

    The pipelines are loaded by `initialize()`; the `initialized` flag tells the caller
    whether that has already happened.
    """

    def __init__(self) -> None:
        self._device = "cuda"
        self.initialized = False
        self.num_images_per_prompt = 1

    def initialize(self, properties: Dict[str, str]) -> None:
        """Load the decoder and prior pipelines from the configured model location.

        Args:
            properties: Serving properties; `model_id` is used when set, otherwise `model_dir`.
        """
        # The `option.model_id` variable can be either a HuggingFace ID, an S3 URL, or a local directory path.
        # - If `option.model_id` is a HuggingFace ID, models artifacts are downloaded by the
        #   Transformers/Diffusers library (slower).
        # - If `option.model_id` is an S3 URL, the DJL model server downloads the model artifacts locally using
        #   the highly efficient `s5cmd` tool and set `option.model_id` to the local download directory path.
        # - If `option.model_id` is not available, it is assumed that model artifacts are located in
        #   `option.model_dir`. If `option.model_dir` is not available, `/opt/ml/model` is used as default value.
        model_location = properties.get("model_id") or properties.get("model_dir")
        # The prior pipeline is expected in a `prior` sub-directory of the model location.
        prior_location = f"{model_location}/prior"
        print(f"model_location: {model_location}")
        print(f"model_location content: {os.listdir(model_location)}")
        print(f"prior content: {os.listdir(prior_location)}")

        # The decoder uses model CPU offload, which moves sub-models to the accelerator on
        # demand. It is therefore not moved to the GPU manually beforehand (the previous code
        # did so twice, only for the offload call to move everything back to the CPU).
        self._pipeline = StableCascadeDecoderPipeline.from_pretrained(model_location, torch_dtype=torch.float16)
        self._pipeline.enable_model_cpu_offload()
        print("main loaded")

        print("start loading prior")
        self._prior_pipeline = StableCascadePriorPipeline.from_pretrained(
            prior_location, variant="bf16", torch_dtype=torch.bfloat16
        ).to(self._device)
        print("prior loaded")

        self.initialized = True
        print("end initialize")

    def handle_request(self, inputs: Input) -> Output:
        """Run the prior then the decoder for one request and return the image as base64 PNG.

        Expected JSON payload keys:
            prompt (required), guidance_scale (required), num_inference_steps (required),
            negative_prompt (optional), image (optional base64-encoded image file),
            generation_parameters (optional dict; only its `seed` entry is used).
        """
        request_payload = inputs.get_as_json()
        prompt = request_payload["prompt"]
        negative_prompt = request_payload.get("negative_prompt")
        guidance_scale = request_payload["guidance_scale"]
        num_inference_steps = request_payload["num_inference_steps"]
        # Copy the dict so that popping the seed does not mutate the request payload.
        generation_parameters = dict(request_payload.get("generation_parameters") or {})
        seed = generation_parameters.pop("seed", None)

        # A missing or empty image means text-only conditioning of the prior.
        encoded_image = request_payload.get("image")
        image = decode_image(encoded_image=encoded_image) if encoded_image else None

        print(f"prompt: {prompt}")
        print(f"negative_prompt: {negative_prompt}")
        print(f"image: {image}")
        print(f"guidance_scale: {guidance_scale}")
        print(f"generation_parameters: {generation_parameters}")
        print(f"num_inference_steps: {num_inference_steps}")
        print(f"seed: {seed}")
        print(f"self.num_images_per_prompt: {self.num_images_per_prompt}")

        # `is not None` so that seed 0 is honoured. The generator is handed to both
        # pipelines; previously it was stored in a dict that was never passed on,
        # so the seed had no effect.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self._device).manual_seed(int(seed))

        prior_output = self._prior_pipeline(
            # The prior pipeline iterates over `images`, so a single image is wrapped in a list.
            images=[image] if image is not None else None,
            prompt=prompt,
            height=1024,
            width=1024,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=self.num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
        )
        print(f"Prior output: {prior_output}")

        decoder_output = self._pipeline(
            # The decoder was loaded in float16, so the prior embeddings are cast to half precision.
            image_embeddings=prior_output.image_embeddings.half(),
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            output_type="pil",
            num_inference_steps=num_inference_steps,
            generator=generator,
        )
        serialized_output_image = encode_image(image=decoder_output.images[0], format="PNG")
        return Output().add(serialized_output_image).add_property("content-type", "image/png")

    
# Single module-level service instance shared by every call to `handle`.
_service = InferenceService()


def handle(inputs: Input) -> Optional[Output]:
    """Module entry point: initialize the shared service on first use, then serve the request.

    An empty input only triggers the (one-time) initialization and yields no output.
    """
    if not _service.initialized:
        print("Initializing inference service")
        _service.initialize(properties=inputs.get_properties())

    return None if inputs.is_empty() else _service.handle_request(inputs=inputs)

Overwriting /root/GenAI-PhotoBooth/Cascade/code/handler.py


In [26]:
# Pack the content of the code directory into a gzip-compressed tar archive.
# `arcname="."` stores the files relative to the archive root instead of under the
# directory's absolute path.
archive_file_path = Path("model.tar.gz")
with tarfile.open(archive_file_path, mode="w:gz") as tar:
    tar.add(SOURCE_DIR, arcname=".")

# Upload the archive to the artifact bucket under the code key prefix.
code_artifacts_uri = SM_SESSION.upload_data(
    path=archive_file_path.as_posix(),
    bucket=SM_ARTIFACT_BUCKET_NAME,
    key_prefix=CODE_ARTIFACTS_KEY_PREFIX,
)

print(f"Code artifacts have been successfully uploaded to: {code_artifacts_uri}")

Code artifacts have been successfully uploaded to: s3://sagemaker-us-east-1-433808754371/photobooth/endpoint/stabilityai/stable-cascade/code/model.tar.gz


## 5. Endpoint deployment

In [27]:
# Check your Amazon SageMaker service quota for "<instance_type> for endpoint usage"
endpoint_instance_type = "ml.g5.2xlarge"
# We use the SageMaker Large Model Inference (LMI) Deep Learning Containers (DLC) image.
# The image URI is resolved for the session's region and the pinned framework version.
container_image_uri = sagemaker.image_uris.retrieve(
    framework="djl-deepspeed",
    region=REGION_NAME,
    version="0.25.0"
)

In [28]:
from sagemaker_ssh_helper.wrapper import SSHModelWrapper

# One timestamp is shared by the model, endpoint config and endpoint names of a deployment.
timestamp = sagemaker.utils.sagemaker_timestamp()

# Names are cut to 63 characters (presumably the SageMaker name length limit; confirm).
model_name = f"photobooth-stable-cascade-model-{timestamp}"[:63]

# The model pairs the serving container image with the code archive uploaded to S3.
create_model_response = SAGEMAKER_CLIENT.create_model(
    ModelName=model_name,
    ExecutionRoleArn=SM_DEFAULT_EXECUTION_ROLE_ARN,
    PrimaryContainer={"Image": container_image_uri, "ModelDataUrl": code_artifacts_uri},
)

model_arn = create_model_response["ModelArn"]
print(f"Created Model ARN: {model_arn}")

# NOTE: with this longer prefix the 63-character cut removes the end of the timestamp
# (visible in the recorded output below).
endpoint_config_name = f"photobooth-stable-cascade-endpoint-config-{timestamp}"[:63]
endpoint_config_response = SAGEMAKER_CLIENT.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "AllTraffic",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InitialVariantWeight": 1.0,
            "InstanceType": endpoint_instance_type,
            # Both timeouts are set to 15 minutes.
            "ContainerStartupHealthCheckTimeoutInSeconds": 15 * 60,
            "ModelDataDownloadTimeoutInSeconds": 15 * 60,
        },
    ],
)

endpoint_config_arn = endpoint_config_response["EndpointConfigArn"]
print(f"Created EndpointConfig ARN: {endpoint_config_arn}")

Created Model ARN: arn:aws:sagemaker:us-east-1:433808754371:model/photobooth-stable-cascade-model-2024-03-25-23-32-30-553
Created EndpointConfig ARN: arn:aws:sagemaker:us-east-1:433808754371:endpoint-config/photobooth-stable-cascade-endpoint-config-2024-03-25-23-32-30-5


In [29]:
endpoint_name = f"photobooth-stable-cascade-endpoint-{timestamp}"[:63]

create_endpoint_response = SAGEMAKER_CLIENT.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name
)

endpoint_arn = create_endpoint_response['EndpointArn']
print(f"Endpoint ARN: {endpoint_arn}")

describe_endpoint_response = SAGEMAKER_CLIENT.describe_endpoint(EndpointName=endpoint_name)
status = describe_endpoint_response["EndpointStatus"]

# Poll every 30 seconds until the endpoint leaves the "Creating" state.
while status == "Creating":
    time.sleep(30)
    describe_endpoint_response = SAGEMAKER_CLIENT.describe_endpoint(EndpointName=endpoint_name)
    status = describe_endpoint_response["EndpointStatus"]
    print(f"Status: {status}")
print(f"Final status: {status}")

# Surface the reason of a failed deployment instead of only reporting the bare status.
if status == "Failed":
    print(f"Failure reason: {describe_endpoint_response.get('FailureReason', 'not reported')}")

Endpoint ARN: arn:aws:sagemaker:us-east-1:433808754371:endpoint/photobooth-stable-cascade-endpoint-2024-03-25-23-32-30-553
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Failed
Final status: Failed


## 6. Endpoint invocation

Image data bytes are serialized into a base64-encoded text string. Base64 encodes every 3 bytes of binary data as 4 ASCII characters; since each ASCII character takes 1 byte, the encoded data is about 33% larger than the raw data. In exchange, the encoded data is plain text that can be embedded in a JSON payload, for example.

In [189]:
def encode_file_image(file_path: Path) -> str:
    """Read an image file from disk and return its raw content as base64 text."""
    raw_content = Path(file_path).read_bytes()
    encoded_content = base64.b64encode(raw_content)
    return encoded_content.decode("utf-8")  # could be .decode("ascii") too

    
def encode_image(image: "PIL.Image.Image", format: str) -> str:
    """Serialize an image to a base64-encoded text string.

    The annotation is a string naming the image class (`PIL.Image.Image`); the bare
    `PIL.Image` used before is the module, and evaluating it at definition time only
    worked because another import had already loaded the `PIL.Image` submodule.

    Args:
        image: Image object exposing `save(fp, format=...)`.
        format: Image file format passed to `save`, e.g. "PNG" or "JPEG".

    Returns:
        The encoded image file content as base64 text.
    """
    buffer = io.BytesIO()
    image.save(buffer, format=format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")  # could be .decode("ascii") too


def decode_image(encoded_image: str) -> "PIL.Image.Image":
    """Decode a base64-encoded image file and open it as a PIL image.

    Args:
        encoded_image: Base64 content of the image file (`base64.b64decode` also
            accepts the bytes read from an endpoint response body).

    Returns:
        The image opened from the decoded bytes.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it explicitly
    # instead of relying on another package having done so as a side effect.
    import PIL.Image

    image_data_bytes = base64.b64decode(encoded_image)
    return PIL.Image.open(io.BytesIO(image_data_bytes))

In [194]:
# Alternative prompt: "Anthropomorphic cat dressed as a pilot"
prompt = "portrait of a man in lineart style"  # suffix tried earlier: 1900's wearing sun glasses
negative_prompt = "text"

In [239]:
# Alternative input file: Path("portrait.png")
image_file_path = Path("portrait.jpg")
input_image = PIL.Image.open(image_file_path)
input_image.size  # not the last expression of the cell, so nothing is displayed
# Debug leftover: checks that a list holding the image can be iterated (prints "true" once).
images = [input_image]
for im in images:
    print("true")

true


In [240]:
# No trailing comma after the string: it would turn `prompt` into a 1-tuple, which is
# then serialized as a JSON list instead of a plain string.
prompt = "make it line art style"
generation_parameters = {
    "seed": 42,
    "num_inference_steps": 50,
    "image_guidance_scale":  1.5,
    "guidance_scale": 9.0,
}

payload = {
    "image": encode_image(image=input_image, format="JPEG"),
    "prompt": prompt,
    "generation_parameters": generation_parameters,
}

In [243]:
# Request payload matching the keys read by `InferenceService.handle_request` in handler.py.
# The commented-out entries are parameters that were tried earlier.
generation_parameters = {
    #"prompt": prompt,
    #"height": 1024,
    #"width": 1024,
    #"negative_prompt": negative_prompt,
    #"guidance_scale": 2.0,
    "num_images_per_prompt": 1,
    #"num_inference_steps": 40,
    #"seed": 456809
}
payload = {
    "image": encode_image(image=input_image, format="JPEG"),
    #"image":"",
    "prompt": prompt,
    "negative_prompt": negative_prompt,
    "generation_parameters": generation_parameters,
    "guidance_scale": 1,  # NOTE(review): the original note said "should be < 1" while the value is 1; confirm the intended range
    "num_inference_steps": 20,
    #"num_images_per_prompt": 1
}
# To target a previously deployed endpoint instead:
#endpoint_name= "photobooth-stable-cascade-endpoint-2024-02-27-11-03-02-293"
#json.dumps(payload)

In [244]:
# Warning: SageMaker endpoints have a fixed 60s timeout per invocation (non-configurable)
# The payload is sent as a JSON document; the response body is read in the next cell.
invoke_endpoint_response = SAGEMAKER_RUNTIME_CLIENT.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
)
#invoke_endpoint_response

ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (424) from primary with message "{
  "code":424,
  "message":"prediction failure",
  "error":"Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"
}". See https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logEventViewer:group=/aws/sagemaker/Endpoints/photobooth-stable-cascade-endpoint-2024-02-27-14-54-29-162 in account 433808754371 for more information.

In [None]:
# The handler returns the generated image as base64 text: decode it and show it next to the input.
response_body = invoke_endpoint_response['Body'].read()
output_image = decode_image(encoded_image=response_body)
make_image_grid([input_image, output_image], rows=1, cols=2)


## 7. Clean-up
#### Delete AWS resources

In [None]:
# Tear down the deployment: endpoint first, then its endpoint config, then the model.
delete_endpoint_response = SAGEMAKER_CLIENT.delete_endpoint(EndpointName=endpoint_name)
print(f"Deleted Endpoint: {endpoint_name}")
delete_endpoint_config_response = SAGEMAKER_CLIENT.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
print(f"Deleted EndpointConfig: {endpoint_config_name}")
delete_model_response = SAGEMAKER_CLIENT.delete_model(ModelName=model_name)
print(f"Deleted Model: {model_name}")

#### Delete remote assets

In [None]:
def delete_s3_objects_by_prefix(bucket_name: str, key_prefix: str) -> None:
    """Delete every object of a bucket whose key starts with the given prefix.

    Objects are deleted one listing page at a time, so that a single delete request
    never carries more keys than one page holds. A prefix matching no object is a no-op.

    Args:
        bucket_name: Name of the S3 bucket.
        key_prefix: Key prefix selecting the objects to delete.
    """
    paginator = S3_CLIENT.get_paginator("list_objects")
    for page in paginator.paginate(Bucket=bucket_name, Prefix=key_prefix):
        # Pages of an empty listing carry no "Contents" entry at all.
        objects = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
        if objects:
            S3_CLIENT.delete_objects(Bucket=bucket_name, Delete={"Objects": objects})

In [None]:
# Remove code & model artifacts (decoder and prior) from S3.
# The prior model artifacts were uploaded too and must not be left behind. Their prefix is
# handled first so that this call never runs against a prefix already emptied by another one.
delete_s3_objects_by_prefix(bucket_name=SM_ARTIFACT_BUCKET_NAME, key_prefix=PRIOR_MODEL_ARTIFACTS_KEY_PREFIX)
delete_s3_objects_by_prefix(bucket_name=SM_ARTIFACT_BUCKET_NAME, key_prefix=MODEL_ARTIFACTS_KEY_PREFIX)
delete_s3_objects_by_prefix(bucket_name=SM_ARTIFACT_BUCKET_NAME, key_prefix=CODE_ARTIFACTS_KEY_PREFIX)

#### Delete local assets

In [None]:
def get_local_model_cache_dir(hf_model_name: str) -> Path:
    """Return the local HuggingFace cache directory of a model.

    The directory is looked up in `HF_LOCAL_CACHE_DIR` by its name suffix, which is
    the model name with "/" replaced by "--".

    Args:
        hf_model_name: HuggingFace model name, e.g. "org/model".

    Returns:
        Path of the first matching cache directory.

    Raises:
        ValueError: If no directory in the cache matches the model name.
    """
    dir_suffix = hf_model_name.replace("/", "--")
    for dir_name in os.listdir(HF_LOCAL_CACHE_DIR):
        if dir_name.endswith(dir_suffix):
            return HF_LOCAL_CACHE_DIR / dir_name
    raise ValueError(f"Could not find HF local cache directory for model {hf_model_name}")

In [None]:
# Remove model artifacts from the local download directories (decoder and prior)
shutil.rmtree(HF_LOCAL_DOWNLOAD_DIR)
shutil.rmtree(HF_LOCAL_PRIOR_DOWNLOAD_DIR)

In [None]:
# Remove model artifacts from the local HuggingFace cache directory.
# Both downloaded repositories (decoder and prior) have their own cache directory.
for hf_model_name in (HF_HUB_MODEL_NAME, HF_HUB_PRIOR_MODEL_NAME):
    hf_local_cache_dir = get_local_model_cache_dir(hf_model_name=hf_model_name)
    shutil.rmtree(hf_local_cache_dir)

In [None]:
# Remove code artifacts from the local host
# NOTE: the `model.tar.gz` archive created in the working directory is not removed by this cell.
shutil.rmtree(SOURCE_DIR)