# Hosting ControlNet/LoRA models on SageMaker using the DJL container

In this notebook, we explore how to host ControlNet and LoRA models on a SageMaker real-time endpoint using the Large Model Inference (LMI) container, which packages the DJL model server.

Under the hood, the endpoint runs stable-diffusion-webui to generate images with LoRA and ControlNet support.

Note - Amazon Web Services has no control or authority over the third-party generative AI service referenced in this Workshop, and does not make any representations or warranties that the third-party generative AI service is secure, virus-free, operational, or compatible with your production environment and standards. You are responsible for making your own independent assessment of the content provided in this Workshop, and take measures to ensure that you comply with your own specific quality control practices and standards, and the local rules, laws, regulations, licenses and terms of use that apply to you, your content, and the third-party generative AI service referenced in this Workshop. The content of this Workshop: (a) is for informational purposes only, (b) represents current Amazon Web Services product offerings and practices, which are subject to change without notice, and (c) does not create any commitments or assurances from Beijing Sinnet Technology Co., Ltd. (“Sinnet”), Ningxia Western Cloud Data Technology Co., Ltd. (“NWCD”), Amazon Connect Technology Services (Beijing) Co., Ltd. (“Amazon”), or their respective affiliates, suppliers or licensors.  Amazon Web Services’ content, products or services are provided “as is” without warranties, representations, or conditions of any kind, whether express or implied.  The responsibilities and liabilities of Sinnet, NWCD or Amazon to their respective customers are controlled by the applicable customer agreements. 

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

---

## Build the Docker image and push it to ECR

Initialize variables for the SageMaker default bucket, the execution role, the AWS account ID, and the current AWS region.

In [None]:
import sagemaker

sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

import boto3

account_id = boto3.client("sts").get_caller_identity().get("Account")
region_name = boto3.session.Session().region_name

Run the following script to build the inference Docker image and push it to Amazon ECR for use by the SageMaker endpoint.

In [None]:
%%sh -s "$region_name"

# This script shows how to build the Docker image and push it to ECR to be ready for use
# by SageMaker.

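region=$1
echo "Region: ${region}"

# Log in to the AWS Deep Learning Containers registry (account 763104351884) so the build can pull the base image.
# --password-stdin keeps the password out of the process list and shell history.
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin 763104351884.dkr.ecr.${region}.amazonaws.com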

# Get the account number associated with the current IAM credentials
account=$(aws sts get-caller-identity --query Account --output text)

if [ $? -ne 0 ]
then
    exit 255
fi

inference_image=all-in-one-ai-stable-diffusion-webui-inference-api
inference_fullname=${account}.dkr.ecr.${region}.amazonaws.com/${inference_image}:latest

# Create the repository in ECR if it doesn't exist yet.
aws ecr describe-repositories --repository-names "${inference_image}" --region ${region} > /dev/null 2>&1 \
    || aws ecr create-repository --repository-name "${inference_image}" --region ${region}

# Log in to your own ECR registry so that the image can be pushed.
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com

aws ecr set-repository-policy \
    --repository-name "${inference_image}" \
    --policy-text "file://ecr-policy.json" \
    --region ${region}

# Build the docker image locally with the image name and then push it to ECR
# with the full name.

docker build -t ${inference_image} -f Dockerfile.inference . --build-arg REGION=${region}

docker tag ${inference_image} ${inference_fullname}

docker push ${inference_fullname}
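The shell cell above creates the ECR repository with `describe-repositories || create-repository`. If you prefer to do that step from Python, the sketch below performs the same check with boto3. `ensure_ecr_repository` is a hypothetical helper, not part of this notebook or of boto3; it assumes `ecr_client` is a boto3 ECR client.

```python
def ensure_ecr_repository(ecr_client, repository_name):
    """Return the URI of an ECR repository, creating the repository if it is missing.

    Assumes ecr_client is a boto3 ECR client, e.g. boto3.client("ecr").
    """
    try:
        response = ecr_client.describe_repositories(repositoryNames=[repository_name])
        return response["repositories"][0]["repositoryUri"]
    except ecr_client.exceptions.RepositoryNotFoundException:
        # The repository does not exist yet: create it and return its URI.
        response = ecr_client.create_repository(repositoryName=repository_name)
        return response["repository"]["repositoryUri"]
```

For example: `ensure_ecr_repository(boto3.client("ecr", region_name=region_name), "all-in-one-ai-stable-diffusion-webui-inference-api")`. Catching only `RepositoryNotFoundException` lets permission or network errors surface instead of being hidden by a blind retry.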


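SageMaker expects model data for the endpoint, but this container downloads its models at startup. Package an empty placeholder file as `model.tar.gz` and upload it to S3 to satisfy that requirement.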

In [None]:
model_data = f"s3://{bucket}/stable-diffusion-webui/data/model.tar.gz"
!touch dummy
!tar czvf model.tar.gz dummy
!rm dummy
!aws s3 cp model.tar.gz $model_data
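The `!touch`/`!tar` commands above can also be done in pure Python, which avoids depending on shell tools in the notebook environment. A minimal sketch using the standard `tarfile` module; `make_placeholder_model_tarball` is a hypothetical helper name.

```python
import io
import tarfile


def make_placeholder_model_tarball(path="model.tar.gz", member_name="dummy"):
    """Write a gzip-compressed tarball that contains a single empty file."""
    with tarfile.open(path, "w:gz") as tar:
        info = tarfile.TarInfo(name=member_name)
        info.size = 0  # the placeholder file has no content
        tar.addfile(info, io.BytesIO(b""))
    return path
```

The archive could then be uploaded with `sagemaker_session.upload_data("model.tar.gz", bucket=bucket, key_prefix="stable-diffusion-webui/data")`, which should return the same S3 URI as `model_data`.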

## Deploy to SageMaker Real-time Endpoint

Set the model name (`None` lets the SageMaker Python SDK generate one), the URI of the image pushed to ECR above, and a base name derived from that image.

In [None]:
model_name = None
image_uri = "{0}.dkr.ecr.{1}.amazonaws.com/all-in-one-ai-stable-diffusion-webui-inference-api:latest".format(
    account_id, region_name
)
base_name = sagemaker.utils.base_name_from_image(image_uri)
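The image URI above hard-codes the `amazonaws.com` domain. In AWS China Regions (names starting with `cn-`), ECR registry hostnames end in `amazonaws.com.cn` instead, so both this URI and the build script would need adjusting there. A small sketch of a region-aware builder; `ecr_image_uri` is a hypothetical helper, not a SageMaker SDK function.

```python
def ecr_image_uri(account_id, region, repository, tag="latest"):
    """Build an ECR image URI; China Regions (cn-*) use the amazonaws.com.cn domain."""
    domain = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    return f"{account_id}.dkr.ecr.{region}.{domain}/{repository}:{tag}"
```

For example: `image_uri = ecr_image_uri(account_id, region_name, "all-in-one-ai-stable-diffusion-webui-inference-api")`.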

Define which models the container downloads when the endpoint starts. Each model can come from HTTP, S3, or the Hugging Face Hub. In this example, the Stable Diffusion v1.5 checkpoint and the ControlNet model `control_sd15_canny.pth` are downloaded from Hugging Face, and the LoRA model `2bNierAutomataLora_v2b.safetensors` is downloaded from Civitai over HTTP, directly after the SageMaker endpoint is created.

In [None]:
import json

huggingface_models = [
    {
        "repo_id": "runwayml/stable-diffusion-v1-5",
        "filename": "v1-5-pruned.ckpt",
        "name": "Stable-diffusion",
    },
    {
        "repo_id": "lllyasviel/ControlNet",
        "filename": "models/control_sd15_canny.pth",
        "name": "ControlNet",
    },
]

http_models = [
    {
        "uri": "https://civitai.com/api/download/models/7627",
        "filename": "2bNierAutomataLora_v2b.safetensors",
        "name": "Lora",
    }
]

model_environment = {
    "ckpt": "/tmp/models/Stable-diffusion/v1-5-pruned.ckpt",
    "huggingface_models": json.dumps(huggingface_models),
    "http_models": json.dumps(http_models),
    "generated_images_s3uri": f"s3://{bucket}/stable-diffusion-webui/generated/",
    "embeddings_s3uri": f"s3://{bucket}/stable-diffusion-webui/embeddings/",
    "hypernetwork_s3uri": f"s3://{bucket}/stable-diffusion-webui/hypernetwork/",
}
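Because the model lists are passed to the container as JSON strings in environment variables, a typo in a key only shows up after a slow deployment. A minimal pre-flight check could look like the sketch below. `validate_model_environment` is hypothetical, and the required key sets are inferred from the entries in this notebook, not from the container's documentation. It also checks that every value is a string, which SageMaker expects for container environment variables.

```python
import json

# Assumed required keys, inferred from the example entries in this notebook.
REQUIRED_KEYS = {
    "huggingface_models": {"repo_id", "filename", "name"},
    "http_models": {"uri", "filename", "name"},
}


def validate_model_environment(env):
    """Return a list of problems found in env; an empty list means none were found."""
    problems = []
    for key, value in env.items():
        if not isinstance(value, str):
            problems.append(f"{key}: value must be a string, got {type(value).__name__}")
    for key, required in REQUIRED_KEYS.items():
        if not isinstance(env.get(key), str):
            continue  # absent or already reported as a non-string value
        try:
            entries = json.loads(env[key])
        except json.JSONDecodeError as err:
            problems.append(f"{key}: invalid JSON ({err})")
            continue
        if not isinstance(entries, list):
            problems.append(f"{key}: expected a JSON list")
            continue
        for i, entry in enumerate(entries):
            if not isinstance(entry, dict):
                problems.append(f"{key}[{i}]: expected a JSON object")
                continue
            missing = required - set(entry)
            if missing:
                problems.append(f"{key}[{i}]: missing {sorted(missing)}")
    return problems
```

Running `validate_model_environment(model_environment)` before `model.deploy(...)` should return an empty list for the configuration above.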

Define the SageMaker model, the instance type, and the initial instance count for the endpoint.

In [None]:
from sagemaker.model import Model
from sagemaker.predictor import Predictor

model = Model(
    name=model_name,
    model_data=model_data,
    role=role,
    image_uri=image_uri,
    env=model_environment,
    predictor_cls=Predictor,
)

instance_type = "ml.g4dn.2xlarge"
instance_count = 1

For simplicity, this notebook uses real-time inference, which has some inherent limitations. Real-time inference suits workloads with payloads of up to 6 MB that must be processed with low latency, in the order of milliseconds to seconds. Asynchronous inference is better suited to workloads with large payloads and long processing times. The deployment below attaches a 225 GB volume and allows up to 1800 seconds (30 minutes) for the container to pass its startup health check, because the models are downloaded when the container starts.

In [None]:
predictor = model.deploy(
    instance_type=instance_type,
    initial_instance_count=instance_count,
    volume_size_in_gb=225,
    container_startup_health_check_timeout=1800,
)
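Since real-time requests are capped at about 6 MB, it can help to measure a request before sending it, especially the ControlNet request, which embeds a base64-encoded image (base64 grows data by roughly 4/3). A minimal sketch; the helper names are hypothetical, and treating 6 MB as 6 × 1024 × 1024 bytes is an assumption, so leave some headroom. The size is computed with plain `json.dumps`, which approximates what `JSONSerializer` sends for a dictionary of plain Python values.

```python
import json
import math

REALTIME_MAX_PAYLOAD_BYTES = 6 * 1024 * 1024  # assumed interpretation of "6 MB"


def json_payload_size(payload):
    """Size in bytes of the UTF-8 JSON encoding of payload."""
    return len(json.dumps(payload).encode("utf-8"))


def base64_encoded_size(raw_bytes):
    """Length of padded base64 output for raw_bytes bytes of input."""
    return 4 * math.ceil(raw_bytes / 3)


def fits_realtime(payload, limit=REALTIME_MAX_PAYLOAD_BYTES):
    """True if the JSON-encoded payload is within the assumed real-time limit."""
    return json_payload_size(payload) <= limit
```

For example, `fits_realtime(inputs)` before `predictor.predict(inputs)`. If it returns `False`, asynchronous inference is the better fit.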

## Generate images using Lora models

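LoRA (Low-Rank Adaptation of Large Language Models) has become a widely used way to extend Stable Diffusion models. The request below applies the LoRA model through the `<lora:2bNierAutomataLora_v2b:0.5>` tag at the end of the prompt, where `0.5` is the LoRA weight.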

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()

inputs = {
    "task": "text-to-image",
    "model": "v1-5-pruned.ckpt",
    "txt2img_payload": {
        "enable_hr": False,
        "denoising_strength": 0,
        "firstphase_width": 0,
        "firstphase_height": 0,
        "hr_scale": 2,
        "hr_upscaler": "",
        "hr_second_pass_steps": 0,
        "hr_resize_x": 0,
        "hr_resize_y": 0,
        "prompt": "yorha no. 2 type b, 1girl, bangs, black blindfold, black dress, black gloves, black hairband, blindfold, blindfold removed, breasts, cleavage cutout, clothing cutout, commentary request, dress, gloves, hairband, half-closed eyes, hand up, highres, io (sinking=carousel), juliet sleeves, long sleeves, looking at viewer, medium breasts, mole, mole under mouth, nier (series), nier automata, no blindfold, parted lips, puffy sleeves, short hair, solo, thighhighs, turtleneck, upper body, white hair, bokeh <lora:2bNierAutomataLora_v2b:0.5>",
        "styles": [""],
        "seed": 141050431,
        "subseed": 3557256075,
        "subseed_strength": 0,
        "seed_resize_from_h": -1,
        "seed_resize_from_w": -1,
        "sampler_name": "",
        "batch_size": 1,
        "n_iter": 1,
        "steps": 20,
        "cfg_scale": 7,
        "width": 512,
        "height": 512,
        "restore_faces": False,
        "tiling": False,
        "do_not_save_samples": False,
        "do_not_save_grid": False,
        "negative_prompt": "(worst quality, low quality:1.3)",
        "eta": 0,
        "s_churn": 0,
        "s_tmax": 0,
        "s_tmin": 0,
        "s_noise": 1,
        "override_settings": {},
        "override_settings_restore_afterwards": True,
        "script_args": [],
        "sampler_index": "DPM++ SDE Karras",
        "script_name": "",
        "send_images": True,
        "save_images": False,
        "alwayson_scripts": {},
    },
}

prediction = predictor.predict(inputs)
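The prompt above activates the LoRA with `<lora:2bNierAutomataLora_v2b:0.5>`: the LoRA file name without its extension, followed by its weight. When experimenting with other LoRA files or weights, small helpers such as the hypothetical `lora_tag` and `add_lora` below keep the tag syntax consistent. They only build strings; they do not check that the LoRA file exists on the endpoint.

```python
def lora_tag(name, weight=1.0):
    """Return a stable-diffusion-webui LoRA prompt tag, e.g. <lora:name:0.5>."""
    return f"<lora:{name}:{weight:g}>"


def add_lora(prompt, name, weight=1.0):
    """Append a LoRA tag to a prompt, separated by a space."""
    return f"{prompt} {lora_tag(name, weight)}".strip()
```

For example: `inputs["txt2img_payload"]["prompt"] = add_lora("1girl, white hair, bokeh", "2bNierAutomataLora_v2b", 0.5)`.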

Define a helper that splits an S3 URI (`s3://bucket/key`) into its bucket and key.

In [None]:
import boto3

s3_resource = boto3.resource("s3")


def get_bucket_and_key(s3uri):
    pos = s3uri.find("/", 5)
    bucket = s3uri[5:pos]
    key = s3uri[pos + 1 :]
    return bucket, key
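`get_bucket_and_key` assumes the URI starts with `s3://` and has a `/` after the bucket. For `s3://bucket` with no key, `find` returns -1, so the slice drops the last character of the bucket name. A slightly stricter variant, sketched below as the hypothetical `parse_s3_uri`, validates the scheme and handles a missing key. It deliberately avoids `urllib.parse.urlparse`, which would treat `?` and `#` as query or fragment delimiters even though both can appear in S3 keys.

```python
def parse_s3_uri(s3uri):
    """Split s3://bucket/key into (bucket, key); raise ValueError for anything else."""
    prefix = "s3://"
    if not s3uri.startswith(prefix):
        raise ValueError(f"Not an S3 URI: {s3uri!r}")
    bucket, _, key = s3uri[len(prefix):].partition("/")
    if not bucket:
        raise ValueError(f"Missing bucket in S3 URI: {s3uri!r}")
    return bucket, key
```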

Download the generated images listed in the real-time inference result from S3, display them, and save local JPEG copies.

In [None]:
from PIL import Image
import io
import uuid
from datetime import datetime

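# Each entry in prediction["images"] is the S3 URI of a generated image.
# The loop variable is not named image_uri to avoid overwriting the container image URI.
for generated_uri in prediction["images"]:
    image_bucket, image_key = get_bucket_and_key(generated_uri)
    image_object = s3_resource.Object(image_bucket, image_key)
    image = Image.open(io.BytesIO(image_object.get()["Body"].read()))
    image.show()
    # Convert to RGB because JPEG cannot store an alpha channel.
    image.convert("RGB").save(datetime.now().strftime(f"%Y%m%d%H%M%S-{uuid.uuid4()}.jpg"))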

## Generate images using ControlNet models

ControlNet is a neural network structure to control diffusion models by adding extra conditions.
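The ControlNet request later in this section passes one ControlNet "unit" under `alwayson_scripts.controlnet.args`. As we understand the sd-webui-controlnet extension, `module: "none"` means no preprocessor runs, so the input image is used directly as the conditioning image, and `guidance_start`/`guidance_end` are fractions of the sampling steps during which ControlNet is applied. The hypothetical helpers below build such a unit with the field names used in this notebook; extra fields such as `low_vram` or `processor_res` pass through unchanged.

```python
def controlnet_unit(image_data_uri, model="control_sd15_canny [fef5e48e]", module="none",
                    weight=1.0, guidance_start=0.0, guidance_end=1.0, **extra):
    """Build one ControlNet unit dict in the shape used by this notebook."""
    if not 0.0 <= guidance_start <= guidance_end <= 1.0:
        raise ValueError("expected 0 <= guidance_start <= guidance_end <= 1")
    unit = {
        "enabled": True,
        "module": module,
        "model": model,
        "weight": weight,
        "image": image_data_uri,
        "guidance_start": guidance_start,
        "guidance_end": guidance_end,
    }
    unit.update(extra)  # e.g. low_vram, processor_res, threshold_a, threshold_b
    return unit


def with_controlnet(txt2img_payload, *units):
    """Return a copy of a txt2img payload with ControlNet units under alwayson_scripts."""
    payload = dict(txt2img_payload)
    scripts = dict(payload.get("alwayson_scripts") or {})
    scripts["controlnet"] = {"args": list(units)}
    payload["alwayson_scripts"] = scripts
    return payload
```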

In [None]:
from PIL import Image
import base64
import io


def encode_image_to_base64(image):
    with io.BytesIO() as output_bytes:
        image.save(output_bytes, format="JPEG")
        bytes_data = output_bytes.getvalue()

    encoded_string = base64.b64encode(bytes_data)

    base64_str = str(encoded_string, "utf-8")
    mimetype = "image/jpeg"
    # Build a data URI of the form "data:image/jpeg;base64,<base64 data>"
    return f"data:{mimetype};base64,{base64_str}"


def decode_base64_to_image(encoding):
    if encoding.startswith("data:image/"):
        encoding = encoding.split(";")[1].split(",")[1]
    try:
        image = Image.open(io.BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as e:
        print(e)
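The two helpers above combine PIL with base64 data URIs. To check the encoding step on its own, without PIL or a real image, a standard-library round trip is enough. `bytes_to_data_uri` and `data_uri_to_bytes` are hypothetical names for this sketch.

```python
import base64


def bytes_to_data_uri(data, mimetype="image/jpeg"):
    """Encode raw bytes as a base64 data URI."""
    return f"data:{mimetype};base64,{base64.b64encode(data).decode('ascii')}"


def data_uri_to_bytes(uri):
    """Inverse of bytes_to_data_uri; also accepts a bare base64 string."""
    if uri.startswith("data:"):
        header, _, uri = uri.partition(",")
        if not header.endswith(";base64"):
            raise ValueError("only base64-encoded data URIs are supported")
    return base64.b64decode(uri)
```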

In [None]:
from PIL import Image

image = Image.open("./images/inference/ControlNet/bal-source.png").convert("RGB")

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()

inputs = {
    "task": "text-to-image",
    "model": "v1-5-pruned.ckpt",
    "txt2img_payload": {
        "enable_hr": False,
        "denoising_strength": 0,
        "firstphase_width": 0,
        "firstphase_height": 0,
        "hr_scale": 2,
        "hr_upscaler": "",
        "hr_second_pass_steps": 0,
        "hr_resize_x": 0,
        "hr_resize_y": 0,
        "prompt": "balloon",
        "styles": [""],
        "seed": -1,
        "subseed": -1,
        "subseed_strength": 0,
        "seed_resize_from_h": -1,
        "seed_resize_from_w": -1,
        "sampler_name": "",
        "batch_size": 1,
        "n_iter": 1,
        "steps": 20,
        "cfg_scale": 7,
        "width": 512,
        "height": 512,
        "restore_faces": False,
        "tiling": False,
        "do_not_save_samples": False,
        "do_not_save_grid": False,
        "negative_prompt": "",
        "eta": 0,
        "s_churn": 0,
        "s_tmax": 0,
        "s_tmin": 0,
        "s_noise": 1,
        "override_settings": {},
        "override_settings_restore_afterwards": True,
        "script_args": [],
        "sampler_index": "DPM++ SDE Karras",
        "script_name": "",
        "send_images": True,
        "save_images": False,
        "alwayson_scripts": {
            "controlnet": {
                "args": [
                    {
                        "enabled": True,
                        "module": "none",
                        "model": "control_sd15_canny [fef5e48e]",
                        "weight": 1,
                        "image": encode_image_to_base64(image),
                        "low_vram": False,
                        "processor_res": 64,
                        "threshold_a": 64,
                        "threshold_b": 64,
                        "guidance_start": 0,
                        "guidance_end": 1,
                        "guess_mode": False,
                    }
                ]
            }
        },
    },
}

prediction = predictor.predict(inputs)
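The LoRA and ControlNet requests above repeat most of the same `txt2img_payload` fields. One way to reduce that duplication is a defaults dictionary plus per-request overrides, sketched below. `TXT2IMG_DEFAULTS` and `build_txt2img_request` are hypothetical, and the defaults cover only a subset of the fields used above; we have not verified that the container accepts a payload with the other fields omitted, so if in doubt, copy the full dictionary from the cells above into the defaults.

```python
import copy

# Hypothetical defaults: a subset of the txt2img fields used in this notebook.
TXT2IMG_DEFAULTS = {
    "prompt": "",
    "negative_prompt": "",
    "seed": -1,
    "sampler_index": "DPM++ SDE Karras",
    "steps": 20,
    "cfg_scale": 7,
    "width": 512,
    "height": 512,
    "batch_size": 1,
    "n_iter": 1,
    "send_images": True,
    "save_images": False,
    "alwayson_scripts": {},
}


def build_txt2img_request(model, overrides=None, defaults=TXT2IMG_DEFAULTS):
    """Return a request dict in the shape used by this notebook's endpoint."""
    payload = copy.deepcopy(defaults)  # never mutate the shared defaults
    payload.update(overrides or {})
    return {"task": "text-to-image", "model": model, "txt2img_payload": payload}
```

For example: `predictor.predict(build_txt2img_request("v1-5-pruned.ckpt", {"prompt": "balloon"}))`.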

Process the images generated by the ControlNet request the same way: download them from S3, display them, and save local JPEG copies.

In [None]:
from PIL import Image
import io
import uuid
from datetime import datetime

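# Each entry in prediction["images"] is the S3 URI of a generated image.
# The loop variable is not named image_uri to avoid overwriting the container image URI.
for generated_uri in prediction["images"]:
    image_bucket, image_key = get_bucket_and_key(generated_uri)
    image_object = s3_resource.Object(image_bucket, image_key)
    image = Image.open(io.BytesIO(image_object.get()["Body"].read()))
    image.show()
    # Convert to RGB because JPEG cannot store an alpha channel.
    image.convert("RGB").save(datetime.now().strftime(f"%Y%m%d%H%M%S-{uuid.uuid4()}.jpg"))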

## [Optional] Configure auto scaling for the SageMaker endpoint

Run this step if you want SageMaker to scale the endpoint automatically based on a metric, here the average number of invocations per instance.

In [None]:
def create_autoscaling_group_for_sagemaker_endpoint(
    endpoint_name, min_capacity=1, max_capacity=2, target_value=5
):
    # Application Auto Scaling client
    asg_client = boto3.client("application-autoscaling")

    # This is the format in which Application Auto Scaling references the endpoint variant
    resource_id = f"endpoint/{endpoint_name}/variant/AllTraffic"

    # Register the variant's instance count as a scalable target (never below min_capacity)
    asg_client.register_scalable_target(
        ServiceNamespace="sagemaker",
        ResourceId=resource_id,
        ScalableDimension="sagemaker:variant:DesiredInstanceCount",
        MinCapacity=min_capacity,
        MaxCapacity=max_capacity,
    )

    # Target tracking on the predefined invocations-per-instance metric
    response = asg_client.put_scaling_policy(
        PolicyName=f"Request-ScalingPolicy-{endpoint_name}",
        ServiceNamespace="sagemaker",
        ResourceId=resource_id,
        ScalableDimension="sagemaker:variant:DesiredInstanceCount",
        PolicyType="TargetTrackingScaling",
        TargetTrackingScalingPolicyConfiguration={
            "TargetValue": target_value,
            "PredefinedMetricSpecification": {
                "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance",
            },
            "ScaleInCooldown": 600,  # seconds after a scale-in before another scale-in can start
            "ScaleOutCooldown": 300,  # seconds after a scale-out before another scale-out can start
        },
    )
    return response


create_autoscaling_group_for_sagemaker_endpoint(predictor.endpoint_name)
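The predefined `SageMakerVariantInvocationsPerInstance` metric is the average number of invocations per instance per minute, so `target_value=5` asks Application Auto Scaling to keep each instance at roughly 5 requests per minute. The sketch below is a back-of-the-envelope estimate of the resulting instance count; it is an approximation of target tracking that ignores cooldowns, metric delay, and scale-in behavior, and `estimate_instance_count` is a hypothetical helper.

```python
import math


def estimate_instance_count(invocations_per_minute, target_per_instance=5,
                            min_capacity=1, max_capacity=2):
    """Rough steady-state instance count under target tracking, clamped to capacity."""
    if target_per_instance <= 0:
        raise ValueError("target_per_instance must be positive")
    desired = math.ceil(invocations_per_minute / target_per_instance)
    return max(min_capacity, min(max_capacity, desired))
```

With the defaults above, any load beyond about 10 requests per minute is capped at 2 instances, so raise `max_capacity` if you expect more traffic.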

## Resource cleanup.
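If you ran the optional auto-scaling step, remove that configuration before deleting the endpoint so that no scalable target is left pointing at a deleted variant. A minimal sketch; `remove_endpoint_autoscaling` is a hypothetical helper that assumes `asg_client` is a boto3 `application-autoscaling` client. As we understand the Application Auto Scaling documentation, deregistering a scalable target also deletes the scaling policies attached to it.

```python
def remove_endpoint_autoscaling(asg_client, endpoint_name, variant_name="AllTraffic"):
    """Deregister the variant's scalable target; return False if none was registered."""
    try:
        asg_client.deregister_scalable_target(
            ServiceNamespace="sagemaker",
            ResourceId=f"endpoint/{endpoint_name}/variant/{variant_name}",
            ScalableDimension="sagemaker:variant:DesiredInstanceCount",
        )
    except asg_client.exceptions.ObjectNotFoundException:
        return False  # auto scaling was never configured for this variant
    return True
```

For example: `remove_endpoint_autoscaling(boto3.client("application-autoscaling"), predictor.endpoint_name)`.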

In [None]:
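# Delete the endpoint (and, by default, its endpoint configuration), then the model
predictor.delete_endpoint()
model.delete_model()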

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/inference|generativeai|llm-workshop|lab9-hosting-controlnet-models-on-sagemaker|stable-diffusion-webui-sync-inference.ipynb)
