### SageMaker Stable diffusion Quick Kit - Inference 部署(Stable Diffusion XL , SDXL LORA)
   [SageMaker Stable Diffusion Quick Kit](https://github.com/aws-samples/sagemaker-stablediffusion-quick-kit) 提供了一组开箱即用的代码和配置文件，可以帮助客户在亚马逊云上使用 Amazon SageMaker、Lambda、CloudFront 快速构建 Stable Diffusion AI 绘图服务。
   
   ![架构](https://raw.githubusercontent.com/aws-samples/sagemaker-stablediffusion-quick-kit/main/images/architecture.png)


#### 前提条件
1. 亚马逊云账号
2. 建议使用ml.g5.xlarge

### Notebook部署步骤
1. 升级boto3, sagemaker python sdk
2. 编译docker image
3. 部署AIGC推理服务
    * 配置模型参数
    * 配置异步推理
    * 部署SageMaker Endpoint 
4. 测试SDXL模型 (base / refiner / LoRA)
5. 清除资源


### 1. 升级boto3, sagemaker python sdk

In [None]:
# Upgrade boto3 and the SageMaker Python SDK with pip (IPython shell escape).
!pip install --upgrade boto3 sagemaker

In [None]:
#导入对应的库

import re
import os
import json
import uuid

import numpy as np
import pandas as pd
from time import gmtime, strftime


import boto3
import sagemaker

from sagemaker import get_execution_role,session

# Execution role reported by the SageMaker SDK for this notebook.
role = get_execution_role()


# SageMaker session and its default S3 bucket; later cells store the model
# archive and the async-inference input/output objects in this bucket.
sage_session = session.Session()
bucket = sage_session.default_bucket()
# Region name of the default boto3 session.
aws_region = boto3.Session().region_name


# Echo SDK version, role and bucket as a quick sanity check.
print(f'sagemaker sdk version: {sagemaker.__version__}\nrole:  {role}  \nbucket:  {bucket}')




### 2. 编译docker image (sdxl-inference-v2)

In [None]:
# Run build_push.sh from the notebook's working directory (IPython shell escape).
# NOTE(review): the script is not shown here; presumably it builds the
# sdxl-inference-v2 image and pushes it to ECR -- confirm in build_push.sh.
!./build_push.sh

### 3. 部署AIGC推理服务

#### 3.1 创建dummy model_data 文件(真正的模型使用code/inference.py进行加载)

In [None]:
# Create an empty placeholder file and pack it, together with the logo image,
# into model.tar.gz.
!touch dummy
!tar czvf model.tar.gz dummy sagemaker-logo-small.png
# S3 prefix for the assets and the full URI of the archive under that prefix.
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'stablediffusion')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'stablediffusion')
# Upload the archive to the assets prefix, then remove the local temporary files.
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

#### 3.2 创建 model 配置

In [None]:

# Look up the region and the AWS account id that make up the ECR image URI.
boto3_session = boto3.session.Session()
current_region = boto3_session.region_name

client = boto3.client("sts")
account_id = client.get_caller_identity()["Account"]

# From here on `client` is the SageMaker client used by the following cells.
client = boto3.client('sagemaker')

# Use the docker image built in step 2.
# Default repository name: sdxl-inference-v2
container = f'{account_id}.dkr.ecr.{current_region}.amazonaws.com/sdxl-inference-v2'

model_data = f's3://{bucket}/stablediffusion/assets/model.tar.gz'

model_name = 'AIGC-Quick-Kit-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
role = get_execution_role()

# Container definition: image, placeholder model archive, and the environment
# variables handed to the container.
primary_container = dict(
    Image=container,
    ModelDataUrl=model_data,
    Environment={
        's3_bucket': bucket,
        'model_name': 'stabilityai/stable-diffusion-xl-base-1.0',  # use SDXL 1.0
    },
)

# Register the model with SageMaker.
create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)

In [None]:
# One timestamp is shared by the variant name and the endpoint config name.
_time_tag = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
_variant_name = 'AIGC-Quick-Kit-' + _time_tag
endpoint_config_name = 'AIGC-Quick-Kit-' + _time_tag

# Endpoint config with a single production variant and async inference
# enabled; results are written under the asyncinvoke/out/ prefix of the bucket.
response = client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        dict(
            VariantName=_variant_name,
            ModelName=model_name,
            InitialInstanceCount=1,
            InstanceType='ml.g5.2xlarge',
            InitialVariantWeight=1,
        ),
    ],
    AsyncInferenceConfig=dict(
        OutputConfig=dict(
            S3OutputPath=f's3://{bucket}/stablediffusion/asyncinvoke/out/',
        ),
    ),
)

#### 3.3 部署SageMaker endpoint (这里只需要运行一次!!!)

In [None]:
# A random UUID suffix gives every run its own endpoint name.
endpoint_name = f'AIGC-Quick-Kit-{str(uuid.uuid4())}'

print(f'终端节点:{endpoint_name} 正在创建中，首次启动中会加载模型，请耐心等待, 请在控制台上查看状态')

# Start endpoint creation from the config built above; the call returns
# before the endpoint is ready.
response = client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)



### 4.测试

#### 4.1 创建测试辅助方法 

In [None]:
import time
import uuid
import io
import traceback
from PIL import Image


# Shared boto3 S3 resource handle; draw_image reads result objects through it.
s3_resource = boto3.resource('s3')

def get_bucket_and_key(s3uri):
    """Split an ``s3://bucket/key`` URI into its bucket name and object key.

    The first five characters (the ``s3://`` scheme) are skipped; the bucket
    is everything up to the next ``/`` and the key is the remainder.

    Args:
        s3uri: URI string of the form ``s3://<bucket>/<key>``.

    Returns:
        A ``(bucket, key)`` tuple. The key is an empty string when the URI
        names only a bucket (no ``/`` after the bucket name).
    """
    # partition() yields an empty key when no '/' follows the bucket name.
    # Slicing on the result of find() mis-parses that case because find()
    # returns -1 (e.g. "s3://bucket" -> ("bucke", "s3://bucket")).
    bucket, _, key = s3uri[5:].partition('/')
    return bucket, key


def predict_async(endpoint_name,payload):
    """Send ``payload`` to an async SageMaker endpoint and wait for the result.

    The payload is serialized to JSON, uploaded to the notebook's bucket under
    ``stablediffusion/asyncinvoke/input/`` with a random file name, and the
    endpoint is invoked asynchronously with that object as its input. The
    reported output location is then handed to ``wait_async_result``.

    Args:
        endpoint_name: Name of the SageMaker endpoint to invoke.
        payload: JSON-serializable request body.
    """
    sm_runtime = boto3.client('runtime.sagemaker')
    file_name = str(uuid.uuid4()) + ".json"

    # Upload the request body to S3; the async endpoint reads its input from there.
    request_object = boto3.resource('s3').Object(
        bucket, f'stablediffusion/asyncinvoke/input/{file_name}'
    )
    encoded_payload = json.dumps(payload).encode('utf-8')
    request_object.put(Body=encoded_payload)

    input_location = f's3://{bucket}/stablediffusion/asyncinvoke/input/{file_name}'
    print(f'input_location: {input_location}')

    invocation = sm_runtime.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=input_location,
    )
    # Falls back to an empty string when the response carries no OutputLocation.
    wait_async_result(invocation.get("OutputLocation", ''))


def s3_object_exists(s3_path):
    """Return True when the object at ``s3_path`` can be HEAD-ed, else False.

    Used to poll for an async inference result. Every failure, including the
    object not existing yet, is reported on stdout and yields False so that
    the caller can simply retry.

    Args:
        s3_path: Object URI of the form ``s3://<bucket>/<key>``.

    Returns:
        bool: True if ``head_object`` succeeded, False otherwise.
    """
    try:
        s3 = boto3.client('s3')
        bucket, key = get_bucket_and_key(s3_path)
        s3.head_object(Bucket=bucket, Key=key)
        return True
    except Exception as ex:
        # botocore's ClientError exposes the service error code through its
        # ``response`` dict; a missing object is reported as "404" for HEAD.
        response = getattr(ex, 'response', None)
        error_code = None
        if isinstance(response, dict):
            error_code = response.get('Error', {}).get('Code')

        if error_code in ('404', 'NoSuchKey', 'NotFound'):
            print("job is not completed, waiting...")
        else:
            # Still best effort, but surface unexpected problems (403, bad
            # credentials, malformed path) instead of hiding them.
            print(f"job is not completed, waiting... ({type(ex).__name__}: {ex})")
        return False
    
def draw_image(output_location):
    """Download and display the images listed in an async inference result.

    The object at ``output_location`` is read as UTF-8 JSON; its ``result``
    entry is iterated as a list of ``s3://`` image URIs. Each image is
    downloaded, resized to 50% and shown with PIL.

    Failures are reported (with a traceback) rather than raised, so a broken
    result does not abort the notebook cell.

    Args:
        output_location: ``s3://`` URI of the inference result JSON.
    """
    try:
        bucket, key = get_bucket_and_key(output_location)
        obj = s3_resource.Object(bucket, key)
        body = obj.get()['Body'].read().decode('utf-8')
        predictions = json.loads(body)
        print(predictions['result'])
        for image_uri in predictions['result']:
            bucket, key = get_bucket_and_key(image_uri)
            # Named image_bytes so the builtin ``bytes`` is not shadowed.
            image_bytes = s3_resource.Object(bucket, key).get()['Body'].read()
            image = Image.open(io.BytesIO(image_bytes))
            # resize image to 50% size
            half = 0.5
            out_image = image.resize([int(half * s) for s in image.size])
            out_image.show()
    except Exception:
        print("result is not completed, waiting...")
        # Show what actually went wrong (missing key, bad JSON, S3 error, ...).
        traceback.print_exc()
    

    
def wait_async_result(output_location,timeout=60):
    """Poll S3 until the async inference result appears, then display it.

    Checks every 5 seconds whether ``output_location`` exists. Once it does,
    the result is rendered with ``draw_image``. Polling stops after
    ``timeout`` seconds and a hint is printed.

    Args:
        output_location: ``s3://`` URI where the endpoint writes its result.
        timeout: Maximum number of seconds to keep polling.
    """
    poll_interval = 5
    # A monotonic deadline makes ``timeout`` effective; a counter that is
    # never advanced would let this loop run forever.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if s3_object_exists(output_location):
            print("have async result")
            draw_image(output_location)
            return
        time.sleep(poll_interval)

    print(f"no async result after {timeout}s: {output_location}")
    print("call wait_async_result(<output_location>, timeout=...) again to keep waiting")

            
        
def check_sendpoint_status(endpoint_name,timeout=600):
    """Poll a SageMaker endpoint until it is ``InService`` or ``timeout`` elapses.

    The endpoint is described every 10 seconds. Polling stops when the status
    is ``InService``, when the status is ``Failed`` (waiting longer cannot
    help), or when ``timeout`` seconds have passed. Progress is printed;
    nothing is raised or returned.

    Args:
        endpoint_name: Name of the endpoint to check.
        timeout: Maximum number of seconds to keep polling.
    """
    client = boto3.client('sagemaker')  # one client is enough for all polls
    poll_interval = 10
    # A monotonic deadline makes ``timeout`` effective.
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        try:
            response = client.describe_endpoint(EndpointName=endpoint_name)
        except Exception as ex:
            # Best effort: keep polling, but show why the call failed.
            print(f'{endpoint_name} not ready , please wait.... ({ex})')
            time.sleep(poll_interval)
            continue

        status = response['EndpointStatus']
        if status == 'InService':
            print(f'{endpoint_name} is ready, status: {status}')
            return
        if status == 'Failed':
            reason = response.get('FailureReason', '')
            print(f'{endpoint_name} failed, status: {status}, reason: {reason}')
            return

        print(f'{endpoint_name} not ready , please wait....')
        time.sleep(poll_interval)

    print(f'{endpoint_name} still not ready after {timeout}s')
        

#### 检查endpoint 状态

In [None]:
# Poll the endpoint created in step 3.3 until it reports InService.
check_sendpoint_status(endpoint_name)

In [None]:
def wait_endpoint_ready(endpoint_name,timeout=600):
    """Block until the endpoint is ``InService``; raise if it does not get there.

    The endpoint is described every 5 seconds for at most ``timeout`` seconds.

    Args:
        endpoint_name: Name of the endpoint to wait for.
        timeout: Maximum number of seconds to wait.

    Raises:
        RuntimeError: The endpoint reported status ``Failed``.
        TimeoutError: The endpoint was not ``InService`` within ``timeout``.
            Both are ``Exception`` subclasses, so callers catching
            ``Exception`` keep working.
    """
    client = boto3.client('sagemaker')
    poll_interval = 5
    status = None
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        response = client.describe_endpoint(EndpointName=endpoint_name)
        status = response['EndpointStatus']
        if status == 'InService':
            print(f'{endpoint_name} is ready, status: {status}')
            return
        if status == 'Failed':
            reason = response.get('FailureReason', '')
            raise RuntimeError(f'{endpoint_name} failed: {reason}')
        # Sleep and poll again; raising here would make the wait unreachable.
        time.sleep(poll_interval)

    raise TimeoutError(
        f'{endpoint_name} not ready after {timeout}s (last status: {status}), please wait....'
    )

In [None]:
# Re-check that the endpoint is InService before sending test requests.
wait_endpoint_ready(endpoint_name)

### 4.1 测试 SDXL without refiner
首次执行的时候SageMaker会从Hugging Face拉取stabilityai/stable-diffusion-xl-base-1.0 模型，需要稍微等待一下


In [None]:

# Request with the refiner switched off ("SDXL_REFINER": "disable").
# NOTE(review): field semantics are defined by the container's inference
# code, which is not shown here.
payload = {
    "prompt": "a fantasy creaturefractal dragon",
    "steps": 20,
    "sampler": "euler_a",
    "count": 1,
    "SDXL_REFINER": "disable",
}

# Submit the request, then wait for and display the result.
predict_async(endpoint_name, payload)


### 4.2 测试SDXL with refiner
首次执行的时候SageMaker会从Hugging Face拉取stabilityai/stable-diffusion-xl-refiner-1.0 模型，需要稍微等待一下

SDXL_REFINER设置为enable


In [None]:
# Same request as the previous test, with the refiner switched on
# ("SDXL_REFINER": "enable").
payload = {
    "prompt": "a fantasy creaturefractal dragon",
    "steps": 20,
    "sampler": "euler_a",
    "count": 1,
    "SDXL_REFINER": "enable",
}

# Submit the request, then wait for and display the result.
predict_async(endpoint_name, payload)

### 4.3 测试SDXL with Kohya-style LORA

In [None]:
# Refiner-enabled request that additionally names a LoRA and the URL to
# fetch it from; "control_net_enable" is set to "disable".
# NOTE(review): how the container downloads and applies the LoRA is not
# visible here -- confirm in the inference code.
payload = {
    "prompt": "a fantasy creaturefractal dragon",
    "steps": 20,
    "sampler": "euler_a",
    "count": 1,
    "control_net_enable": "disable",
    "SDXL_REFINER": "enable",
    "lora_name": "dragon",
    "lora_url": "https://civitai.com/api/download/models/129363",
}

# Submit the request, then wait for and display the result.
predict_async(endpoint_name, payload)

### 5 清除资源

In [None]:
# Tear down the endpoint first, then the endpoint config it was created from.
# NOTE(review): the model registered with create_model in step 3.2 is not
# deleted by this cell.
response = client.delete_endpoint(EndpointName=endpoint_name)
response = client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

print(f'终端节点:{endpoint_name} 已经被清除，请在控制台上查看状态')
