# Deploy product design model on SageMaker

#### Install and upgrade dependencies

In [None]:
%pip install --upgrade pip --quiet
%pip install sagemaker boto3 awscli --upgrade --quiet

In [None]:
# init sagemaker parameters
import boto3
import sagemaker
from sagemaker import Model, serializers, deserializers

# Resolve the execution context shared by the remaining cells.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Read the region through the public ``boto_region_name`` attribute rather than
# the private ``_region_name`` field, which is an implementation detail of Session.
region = sess.boto_region_name
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

# Key prefix (inside ``bucket``) under which the packaged serving code is uploaded.
s3_code_prefix = "east-ai-models/product-design-sd/accelerate"

print(f"role: {role}")
print(f"bucket: {bucket}")

In [None]:
!mkdir mymodel

#### Write the SageMaker LMI serving properties and model.py

In [None]:
%%writefile ./mymodel/requirements.txt
transformers
diffusers==0.17.0
omegaconf
accelerate
boto3

In [None]:
%%writefile ./mymodel/serving.properties
engine=Python
option.s3url=s3://east-ai-workshop/product-design-sd/
option.tensor_parallel_degree=1

In [None]:
%%writefile ./mymodel/model.py
from djl_python import Input, Output
import os
import torch
from typing import Any, Dict, Tuple
import warnings
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers import EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler,DDIMScheduler
import io
from PIL import Image
import base64
import json
import boto3
from torch import autocast
import random
import uuid


model = None
img2img_model = None


def get_model(properties):
    """Load the text-to-image and image-to-image Stable Diffusion pipelines.

    Args:
        properties: Mapping of serving properties (``handle`` passes
            ``inputs.get_properties()``). When it contains ``model_id`` that
            value is used as the model location and is also stored in the
            ``model_id`` environment variable; otherwise the environment
            variable is used.

    Returns:
        A ``(text2img_pipeline, img2img_pipeline)`` tuple, both on CUDA.

    Raises:
        KeyError: If ``model_id`` is neither in ``properties`` nor in the
            environment.
    """
    print(properties)
    if "model_id" in properties:
        # ``.get`` so a missing model_dir cannot abort loading just for a log line.
        print("=========================model dir: {}============================".format(properties.get("model_dir")))

        model_id = properties["model_id"]
        os.environ["model_id"] = model_id

        # Diagnostic listings only: skip anything that is not an existing
        # directory (e.g. model_id is not a local path, or there is no vae
        # sub-folder) instead of failing the whole model load.
        for label, path in (
            ("model_id", model_id),
            ("model_id/vae", os.path.join(model_id, "vae")),
            ("/opt/ml/model", "/opt/ml/model"),
        ):
            if os.path.isdir(path):
                print("=========================files in {}============================".format(label))
                print(os.listdir(path))
    else:
        model_id = os.environ["model_id"]

    text2img = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
    # Build the img2img pipeline from the components that are already loaded
    # (and already on the GPU) instead of reading the same weights a second
    # time, which would double load time and GPU memory use.
    img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
    return text2img, img2img


# Scheduler classes selectable through the request's "sampler" field.
# "eular" is a historical misspelling that is kept so existing clients keep
# working; "euler" is the correctly spelled alias for the same scheduler.
_SAMPLERS = {
    "euler_a": EulerAncestralDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "eular": EulerDiscreteScheduler,
    "heun": HeunDiscreteScheduler,
    "lms": LMSDiscreteScheduler,
    "dpm2": KDPM2DiscreteScheduler,
    "dpm2_a": KDPM2AncestralDiscreteScheduler,
    "ddim": DDIMScheduler,
}


def _split_s3_uri(uri):
    """Split ``s3://bucket/key`` into ``(bucket, key)``; ``key`` may be empty."""
    scheme = "s3://"
    if not uri.startswith(scheme):
        raise ValueError("expected an s3:// URI, got {!r}".format(uri))
    bucket, _, key = uri[len(scheme):].partition("/")
    if not bucket:
        raise ValueError("missing bucket name in {!r}".format(uri))
    return bucket, key


def _load_init_image(source, width, height):
    """Return the img2img start image as an RGB PIL image resized to (width, height).

    ``source`` is either an ``s3://`` URI or a base64-encoded image.
    """
    if source.startswith("s3://"):
        bucket, key = _split_s3_uri(source)
        raw = boto3.client("s3").get_object(Bucket=bucket, Key=key)["Body"].read()
    else:
        raw = base64.b64decode(source)
    return Image.open(io.BytesIO(raw)).convert("RGB").resize((width, height))


def _make_generator(seed):
    """Return a CUDA generator seeded with ``seed``; ``-1`` picks a random seed."""
    if seed == -1:
        seed = random.randint(1, 10000000)
    return torch.Generator(device="cuda").manual_seed(seed)


def handle(inputs: Input) -> "Output | None":
    """Serve one text-to-image or image-to-image generation request.

    The JSON request is read with ``inputs.get_as_json()``. Fields:

    * ``prompt`` (required), ``height`` and ``width`` (required).
    * ``sampler`` (required): one of the ``_SAMPLERS`` keys.
    * ``output_image_dir`` (required): ``s3://bucket/prefix`` under which every
      generated image is stored as ``<uuid>.webp``. A trailing slash is optional.
    * ``input_image`` (optional): ``s3://`` URI or base64 image. When present
      the img2img pipeline is used and the image is resized to width x height.
    * ``negative_prompt`` (optional, default ``None``), ``steps`` (optional,
      default 50), ``count`` (optional, default 1), ``seed`` (optional,
      default -1, meaning a random seed).

    Returns:
        ``None`` for an empty (warm-up) request; otherwise an ``Output`` whose
        JSON body has ``images`` (base64 WEBP strings) and ``images_path``
        (the matching ``s3://`` locations).

    Raises:
        ValueError: If ``sampler`` is not a known name or an S3 URI is malformed.
        KeyError: If a required request field is missing.
    """
    global model, img2img_model

    # Load lazily on the first call. This deliberately happens before the
    # empty-input check so the warm-up call is what pays the loading cost.
    if not model:
        model, img2img_model = get_model(inputs.get_properties())

    if inputs.is_empty():
        # Model server makes an empty call to warmup the model on startup
        return None

    input_data = inputs.get_as_json()

    # Validate cheap things up front, before any S3 download or GPU work.
    sampler_name = input_data["sampler"]
    if sampler_name not in _SAMPLERS:
        raise ValueError(
            "unknown sampler {!r}; expected one of {}".format(sampler_name, sorted(_SAMPLERS))
        )
    scheduler_cls = _SAMPLERS[sampler_name]
    out_bucket, out_prefix = _split_s3_uri(input_data["output_image_dir"])
    # Treat output_image_dir as a directory whether or not it ends with "/";
    # previously a missing slash glued the prefix and the image id together.
    out_prefix = out_prefix.rstrip("/")
    if out_prefix:
        out_prefix += "/"

    if "input_image" in input_data:
        pipeline = img2img_model
        mode_kwargs = {
            "image": _load_init_image(
                input_data["input_image"], input_data["width"], input_data["height"]
            )
        }
    else:
        pipeline = model
        mode_kwargs = {"height": input_data["height"], "width": input_data["width"]}

    generator = _make_generator(input_data.get("seed", -1))
    with autocast("cuda"):
        # Swap in the requested scheduler, reusing the current scheduler's config.
        pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
        images = pipeline(
            input_data["prompt"],
            negative_prompt=input_data.get("negative_prompt"),
            num_inference_steps=input_data.get("steps", 50),
            num_images_per_prompt=input_data.get("count", 1),
            generator=generator,
            **mode_kwargs,
        ).images
    print("Prediction: " + str(images) + '.'*20)

    res = {'images': [], 'images_path': []}
    s3_bucket = boto3.resource('s3').Bucket(out_bucket)
    for image in images:
        buffer = io.BytesIO()
        image.save(buffer, "WEBP")
        encoded_bytes = buffer.getvalue()
        res['images'].append(base64.b64encode(encoded_bytes).decode('ascii'))

        s3_object_key = out_prefix + uuid.uuid4().hex + '.webp'
        s3_bucket.put_object(Key=s3_object_key, Body=encoded_bytes, ContentType='image/webp')
        res['images_path'].append('s3://{}/{}'.format(out_bucket, s3_object_key))

    return Output().add(json.dumps(res))

In [None]:
# compress code and upload to S3
# Remove any archive left over from a previous run of this cell.
!rm -f model.tar.gz
# Drop Jupyter checkpoint files so they are not included in the archive.
!rm -rf mymodel/.ipynb_checkpoints
# -C mymodel . : archive the directory's contents at the root of the tarball.
!tar czvf model.tar.gz -C mymodel .
# Upload the archive to ``bucket`` under ``s3_code_prefix``; the returned value
# is used later as the Model's ``model_data``.
s3_code_artifact = sess.upload_data("model.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar uploaded to --- > {s3_code_artifact}")

#### Model deployment

In [None]:
# Look up the URI of the SageMaker LMI (djl-deepspeed) container image for this region
image_uri = sagemaker.image_uris.retrieve(
    version="0.23.0",
    region=region,
    framework="djl-deepspeed",
)
print(image_uri)

# Pair the container image with the uploaded code archive and the execution role
model = Model(
    role=role,
    model_data=s3_code_artifact,
    image_uri=image_uri,
)

In [None]:
endpoint_name = "product-design-sd"

# One GPU is enough here because the tensor parallel degree is 1,
# so a single-GPU instance type is used.
instance_type = "ml.g5.2xlarge"

# Create the endpoint; allow up to 900 seconds for the container startup health check.
model.deploy(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
    container_startup_health_check_timeout=900,
)

#### Prediction

In [None]:
# Requests and responses are both JSON, so the predictor is given a JSON
# serializer for outgoing payloads and a JSON deserializer for the replies.
predictor = sagemaker.Predictor(
    deserializer=deserializers.JSONDeserializer(),
    serializer=serializers.JSONSerializer(),
    sagemaker_session=sess,
    endpoint_name=endpoint_name,
)

In [None]:
import io
from PIL import Image
import base64
import json

def predict_fn(predictor, inputs):
    """Send a generation request to the endpoint and show the results.

    Args:
        predictor: Object with a ``predict(payload)`` method returning a dict
            with ``images`` (base64-encoded images) and ``images_path``.
        inputs: Request dict. If it has an ``input_image`` that is not an
            ``s3://`` URI, the value is treated as a local image path and is
            sent as a base64-encoded PNG. ``inputs`` itself is never modified,
            so the same dict can be passed again.

    Returns:
        None. Images are rendered with ``display`` and their paths printed.
    """
    # Work on a shallow copy: replacing the path with base64 in the caller's
    # dict would make a second call try to open the base64 text as a file.
    payload = dict(inputs)

    if 'input_image' in payload and not payload['input_image'].startswith('s3://'):
        # ``with`` closes the underlying file once the image has been re-encoded.
        with Image.open(payload['input_image']) as img:
            buffer = io.BytesIO()
            img.save(buffer, "PNG")
        payload['input_image'] = base64.b64encode(buffer.getvalue()).decode('ascii')

    response = predictor.predict(payload)

    for encoded in response['images']:
        display(Image.open(io.BytesIO(base64.b64decode(encoded))))
    for path in response['images_path']:
        print(path)

In [None]:
# Text-to-image request: there is no "input_image" key, and a seed of -1
# lets the endpoint pick a random seed.
inputs = {
    "prompt": "3D product render, futuristic tent, finely detailed, purism, ue 5, a computer rendering, minimalism, octane render, 4k",
    "negative_prompt": "EasyNegative, (worst quality:2), (low quality:2), (normal quality:2), lowres, ((monochrome)), ((grayscale)), cropped, text, jpeg artifacts, signature, watermark, username, sketch, cartoon, drawing, anime, duplicate, blurry, semi-realistic, out of frame, ugly, deformed",
    "steps": 30,
    "sampler": "dpm2_a",
    "seed": -1,
    "height": 512,
    "width": 512,
    "count": 1,
    "output_image_dir": f"s3://{bucket}/product-design-output/",
}

predict_fn(predictor, inputs)

In [None]:
# Image-to-image request: "input_image" is either a local file path (encoded
# by predict_fn before sending) or an s3:// URI that the endpoint reads itself.
img2img_inputs = {
    "prompt": "3D product render, futuristic armchair, finely detailed, purism, ue 5, a computer rendering, minimalism, octane render, 4k",
    "negative_prompt": "EasyNegative, (worst quality:2), (low quality:2), (normal quality:2), lowres, ((monochrome)), ((grayscale)), cropped, text, jpeg artifacts, signature, watermark, username, sketch, cartoon, drawing, anime, duplicate, blurry, semi-realistic, out of frame, ugly, deformed",
    "steps": 30,
    "sampler": "euler_a",
    "seed": -1,
    "height": 512,
    "width": 512,
    "count": 2,
    "input_image": "chair.png",
    # "input_image": "s3://<IMAGE_LOCATION>",
    "output_image_dir": f"s3://{bucket}/product-design-output/",
}

predict_fn(predictor, img2img_inputs)