# Change background using Grounded SAM and Stable Diffusion

在本节实验中，我们将会部署 Grounded SAM 和 Stable Diffusion 2 inpainting 模型。

Grounded SAM 用于根据提示词来找到对象所在的位置，然后进行语义分割，对物体背景进行蒙版（MASK）。然后将蒙版后的图片给到 SD inpainting 模型进行局部重绘，实现背景的替换

## Deploy Grounded Segment Anything

>注意：执行单元格代码框后，若左侧中括号中的符号为'*'，表示代码正在运行过程中；若为数字，则表示代码已执行完成。

In [None]:
# init sagemaker parameters
import boto3
import sagemaker
from sagemaker import serializers, deserializers
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor

# Resolve the execution context used by the cells below.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Read the region through the public `boto_region_name` attribute instead of
# the private `_region_name`, which is an SDK implementation detail.
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

# S3 key prefix under which the Grounded SAM archive is uploaded.
s3_model_prefix = "east-ai-models/grounded-sam"

print(f"role: {role}")
print(f"bucket: {bucket}")

##### 压缩 dummy 文件并上传至 S3

In [None]:
# compress dummy model and upload to S3
# An empty placeholder file is packaged as model.tar.gz; presumably the real
# weights are obtained by the inference code at runtime — TODO confirm.
!touch dummy
!rm -f model.tar.gz
!tar czvf model.tar.gz dummy
# The value returned by upload_data is kept and later passed as `model_data`.
s3_model_artifact = sess.upload_data("model.tar.gz", bucket, s3_model_prefix)
print(f"S3 Code or Model tar uploaded to --- > {s3_model_artifact}")
# Remove the local placeholder once the archive has been uploaded.
!rm -f dummy

#### Model deployment

In [None]:
# Container versions, instance size and endpoint name for the Grounded SAM endpoint.
framework_version = "2.0.1"
py_version = "py310"
instance_type = "ml.g4dn.xlarge"
endpoint_name = "grounded-sam"

# The handler is ./code/inference.py; the uploaded archive is passed as the model data.
model = PyTorchModel(
    role=role,
    model_data=s3_model_artifact,
    source_dir="./code/",
    entry_point="inference.py",
    framework_version=framework_version,
    py_version=py_version,
)

print("模型部署过程大约需要 7~8 分钟，请等待", "." * 20, sep="")

# The call's return value is not kept; the next cell attaches a predictor by endpoint name.
model.deploy(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
)

print("模型部署已完成，可以继续执行后续步骤", "." * 20, sep="")

In [None]:
# our requests and responses will be in json format so we specify the serializer and the deserializer
# The predictor is attached to the Grounded SAM endpoint by its name.
sam_predictor = PyTorchPredictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

## Deploy inpainting stable diffusion

In [None]:
s3_code_prefix = "east-ai-models/inpainting-sd/accelerate"

!mkdir inpaintmodel

#### Writing SageMaker LMI code properties and model.py

In [None]:
%%writefile ./inpaintmodel/requirements.txt
transformers
diffusers==0.17.0
omegaconf
accelerate
boto3

In [None]:
%%writefile ./inpaintmodel/serving.properties
engine=Python
option.model_id=stabilityai/stable-diffusion-2-inpainting
option.tensor_parallel_degree=1

In [None]:
%%writefile ./inpaintmodel/model.py
from djl_python import Input, Output
import os
import torch
from typing import Any, Dict, Tuple
import warnings
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionInpaintPipeline
from diffusers import EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler,DDIMScheduler
import io
from PIL import Image
import base64
import json
import boto3
from torch import autocast
import random
import numpy as np


model = None

def image_read(image_file):
    """Open ``image_file`` with PIL and return it converted to RGB mode."""
    opened = Image.open(image_file)
    rgb_image = opened.convert("RGB")
    return rgb_image


def mask_read(mask_file):
    """Open ``mask_file`` with PIL and return it converted to mode ``'1'``."""
    opened = Image.open(mask_file)
    return opened.convert('1')
    
    
def image_fuser(new_image_lst, org_image, mask):
    """Overlay the masked pixels of each generated image onto the original.

    Args:
        new_image_lst: Iterable of generated images (anything ``np.array`` accepts).
        org_image: The original image the generated pixels are pasted onto.
        mask: Selector used to index both images; pixels it selects are taken
            from the generated image, all others from ``org_image``.
            Assumes it converts to a boolean array (e.g. a PIL mode ``'1'``
            image) with the same height/width as the images — TODO confirm.

    Returns:
        A list with one fused PIL image per entry of ``new_image_lst``.
    """
    base_pixels = np.array(org_image)
    selector = np.array(mask)

    def _fuse(candidate):
        # Work on a private copy so every result starts from the same original.
        fused = base_pixels.copy()
        fused[selector] = np.array(candidate)[selector]
        return Image.fromarray(fused)

    return [_fuse(candidate) for candidate in new_image_lst]


def generate_image(image, mask, prompt, negative_prompt, generator, pipe, num_inference_steps, num_images_per_prompt):
    """Inpaint ``image`` at 512x512 and return the results at the input size.

    Args:
        image: Source image; its ``size`` determines the size of the results.
        mask: Mask image passed to the pipeline and to ``image_fuser``.
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        generator: Random generator forwarded to the pipeline.
        pipe: Callable pipeline whose result exposes an ``images`` attribute.
        num_inference_steps: Forwarded to the pipeline.
        num_images_per_prompt: Forwarded to the pipeline.

    Returns:
        A list of fused images resized back to the original width and height.
    """
    width, height = image.size

    # The pipeline is fed fixed-size inputs; both image and mask are resized alike.
    working_size = (512, 512)
    resized_image = image.resize(working_size)
    resized_mask = mask.resize(working_size)

    pipeline_output = pipe(
        image=resized_image,
        mask_image=resized_mask,
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=generator,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images_per_prompt,
    )

    # Keep the original pixels outside the mask, then restore the input size.
    fused_images = image_fuser(pipeline_output.images, resized_image, resized_mask)
    return [picture.resize((width, height)) for picture in fused_images]


def get_model(properties):
    """Load the inpainting pipeline named by ``properties["model_id"]``.

    The pipeline is loaded with ``torch.float16`` weights and moved to the
    ``"cuda"`` device before being returned. ``properties`` is printed first.
    """
    print(properties)
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        properties["model_id"],
        torch_dtype=torch.float16,
    )
    return pipeline.to("cuda")


# Scheduler classes selectable through the request's "sampler" field.
# "eular" is the original (misspelled) key and is kept so existing clients
# keep working; "euler" is accepted as the correctly spelled alias.
_SAMPLERS = {
    "euler_a": EulerAncestralDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "eular": EulerDiscreteScheduler,
    "heun": HeunDiscreteScheduler,
    "lms": LMSDiscreteScheduler,
    "dpm2": KDPM2DiscreteScheduler,
    "dpm2_a": KDPM2AncestralDiscreteScheduler,
    "ddim": DDIMScheduler,
}


def _split_s3_uri(s3_uri):
    """Split a ``s3://bucket/key`` URI into ``(bucket, key)``."""
    # ['s3:', '', '<bucket>', '<key, possibly containing slashes>']
    parts = s3_uri.split('/', 3)
    if len(parts) < 4 or not parts[2] or not parts[3]:
        raise ValueError(f"Expected an S3 URI of the form s3://bucket/key, got: {s3_uri!r}")
    return parts[2], parts[3]


def _download_s3_object(s3_client, s3_uri):
    """Fetch the object at ``s3_uri`` and return its bytes as a ``BytesIO``."""
    bucket_name, object_key = _split_s3_uri(s3_uri)
    response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
    return io.BytesIO(response['Body'].read())


# The return annotation is quoted so it is not evaluated at definition time.
def handle(inputs: Input) -> "Output | None":
    """Serve one inpainting request.

    The pipeline is loaded on the first call and cached in the module-level
    ``model``. An empty request (the server's warm-up call) returns ``None``
    after the model has been loaded.

    The JSON request must contain:
        input_image: S3 URI of the source image.
        input_mask_image: S3 URI of the mask image.
        prompt / negative_prompt: Text prompts passed to the pipeline.
        sampler: One of the keys of ``_SAMPLERS``.
        seed: Integer seed; ``-1`` selects a random seed.
        steps: Number of inference steps.
        count: Number of images generated per prompt.

    Returns:
        ``None`` for an empty request, otherwise an ``Output`` holding a JSON
        document ``{"images": [...]}`` with base64-encoded WEBP images.

    Raises:
        ValueError: If ``sampler`` is not supported or an S3 URI is malformed.
        KeyError: If a required request field is missing.
    """
    global model
    print("print inputs: " + str(inputs) + '.'*20)

    # Load lazily (the warm-up call triggers this) and reuse on later calls.
    if model is None:
        model = get_model(inputs.get_properties())

    if inputs.is_empty():
        # Model server makes an empty call to warmup the model on startup
        return None

    input_data = inputs.get_as_json()

    # Reject an unsupported sampler before any S3 download is attempted.
    sampler_name = input_data["sampler"]
    if sampler_name not in _SAMPLERS:
        raise ValueError(
            f"Unsupported sampler {sampler_name!r}; expected one of {sorted(_SAMPLERS)}"
        )

    # One client serves both downloads.
    s3_client = boto3.client('s3')
    org_image = image_read(_download_s3_object(s3_client, input_data['input_image']))
    seg_mask = mask_read(_download_s3_object(s3_client, input_data['input_mask_image']))

    seed = input_data['seed']
    if seed == -1:
        seed = random.randint(1, 10000000)
    generator = torch.Generator(device='cuda').manual_seed(seed)

    # Swap the scheduler while keeping the pipeline's current scheduler config.
    model.scheduler = _SAMPLERS[sampler_name].from_config(model.scheduler.config)

    inpainted_images = generate_image(
        org_image,
        seg_mask,
        input_data['prompt'],
        input_data['negative_prompt'],
        generator,
        model,
        input_data['steps'],
        input_data['count'],
    )
    print("Prediction Complete" + '.'*20)

    encoded_images = []
    for image in inpainted_images:
        buffer = io.BytesIO()
        image.save(buffer, "WEBP")
        encoded_images.append(base64.b64encode(buffer.getvalue()).decode('ascii'))

    return Output().add(json.dumps({'images': encoded_images}))

In [None]:
# compress code and upload to S3
!rm -f model.tar.gz
# Remove Jupyter checkpoint files so they are not packaged into the archive.
!rm -rf inpaintmodel/.ipynb_checkpoints
# -C archives the contents of inpaintmodel/ rather than the directory itself.
!tar czvf model.tar.gz -C inpaintmodel .
s3_code_artifact = sess.upload_data("model.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar uploaded to --- > {s3_code_artifact}")

#### Model deployment

In [None]:
# retrieve SageMaker LMI container image URI
from sagemaker import Model

# Look up the container image for the "djl-deepspeed" framework in the session's region.
image_uri = sagemaker.image_uris.retrieve(
    framework="djl-deepspeed", region=region, version="0.23.0"
)


print(image_uri)

# NOTE: this rebinds the notebook-level name `model`, previously bound to the Grounded SAM PyTorchModel.
model = Model(image_uri=image_uri, model_data=s3_code_artifact, role=role)

In [None]:
# Endpoint settings for the inpainting model; both names are rebound from the Grounded SAM cells.
instance_type = "ml.g4dn.xlarge"
endpoint_name = "inpainting-sd"

print("模型部署过程大约需要 7~8 分钟，请等待", "." * 20, sep="")

# Deploy with an explicit container startup health-check timeout.
model.deploy(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
    container_startup_health_check_timeout=900,
)

print("模型部署已完成，可以继续执行后续步骤", "." * 20, sep="")

In [None]:
# our requests and responses will be in json format so we specify the serializer and the deserializer
# The predictor is attached to the inpainting endpoint by its name.
sd_predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

## Prediction (Optional)

#### Predict using grounded_sam to generate mask image

In [None]:
# Replace <object_uuid> with the name of one of the images generated in the
# inference part of the previous notebook (product_design_sd.ipynb).
input_image_path = 's3://{}/product-design-output/<object_uuid>.webp'.format(bucket)

In [None]:
# generate mask image to s3
# Request for the Grounded SAM endpoint: the source image, a text prompt
# (presumably the object to locate — confirm against inference.py), and the
# S3 location for the mask output.
input_data = {
                'input_image': input_image_path,
                'prompt': 'tent',
                'output_mask_image_dir': 's3://{}/mask-images/'.format(bucket)
             }

mask_res = sam_predictor.predict(input_data)
# Bare expression so the notebook displays the response.
mask_res

In [None]:
# View the masked image
import io
from PIL import Image

# mask_res['result'] is split as an s3://bucket/key URI: element 2 is the bucket, the rest is the key.
dir_lst = mask_res['result'].split('/')
s3_client = boto3.client('s3')
s3_response_object = s3_client.get_object(Bucket=dir_lst[2], Key='/'.join(dir_lst[3:]))
img_bytes = s3_response_object['Body'].read()
# Last expression of the cell, so the notebook renders the image.
Image.open(io.BytesIO(img_bytes)).convert("RGB")

#### Predict using inpainting SD model

In [None]:
import base64
import json

def predict_fn(predictor, inputs):
    """Invoke ``predictor`` with ``inputs`` and display every returned image.

    Each entry of ``response['images']`` is base64-decoded, opened with PIL
    and shown with ``display``. Nothing is returned.
    """
    response = predictor.predict(inputs)
    # Lazy generator: each payload is decoded right before it is displayed.
    decoded_payloads = (base64.b64decode(encoded) for encoded in response['images'])
    for payload in decoded_payloads:
        display(Image.open(io.BytesIO(payload)))

In [None]:
# Request for the inpainting endpoint. The keys match what model.py's handle()
# reads: prompts, S3 URIs of the image and its mask, step count, sampler name,
# seed (-1 makes the handler pick a random seed) and number of images.
inputs = {
    "prompt": "tent on the ground, mountain and snow, high quality, 4k",
    "negative_prompt": "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, disfigured, gross proportions",
    "input_image": input_image_path,
    "input_mask_image": mask_res['result'],
    "steps": 30,
    "sampler": "ddim",
    "seed": -1,
    "count": 2
}

predict_fn(sd_predictor, inputs)