# Stable Diffusion Inpainting with ClipSeg

Deploy a Stable Diffusion inpainting pipeline with [ClipSeg](https://huggingface.co/blog/clipseg-zero-shot).

Users can generate an inpainted image without creating their own mask image; the region to mask is specified with text instead.

In [None]:
import sagemaker, boto3, json
from sagemaker import get_execution_role
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.huggingface import HuggingFace
from sagemaker.pytorch import PyTorch

# SageMaker session and its default S3 bucket.
sess = sagemaker.Session()
bucket = sess.default_bucket()

# Execution role and the region reported by the default boto3 session.
role = get_execution_role()
region = boto3.Session().region_name

# Kept as the last expression so the notebook displays the SDK version.
sagemaker.__version__

In [None]:
%%writefile scripts/code/requirements.txt
transformers
diffusers
accelerate

In [None]:
%%writefile scripts/code/inference.py
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import DiffusionPipeline
from torch import autocast

import json
import base64
from PIL import Image
from io import BytesIO
import numpy as np

def model_fn(model_dir):
    """Build the text-guided inpainting pipeline and move it to the GPU.

    A CLIPSeg processor/model pair is loaded and passed to the
    ``text_inpainting`` custom pipeline created from
    ``runwayml/stable-diffusion-inpainting``.

    Args:
        model_dir: Not used by this function; every component is loaded by
            name through ``from_pretrained``.

    Returns:
        The assembled pipeline after ``.to("cuda")``.
    """
    clipseg_id = "CIDAS/clipseg-rd64-refined"
    seg_processor = CLIPSegProcessor.from_pretrained(clipseg_id)
    seg_model = CLIPSegForImageSegmentation.from_pretrained(clipseg_id)

    inpaint_pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        custom_pipeline="text_inpainting",
        segmentation_model=seg_model,
        segmentation_processor=seg_processor,
    )
    return inpaint_pipe.to("cuda")

def input_fn(data, content_type):
    """Deserialize the request body.

    Args:
        data: JSON document to parse.
        content_type: Content type of the request; only
            ``application/json`` is accepted.

    Returns:
        The object produced by ``json.loads(data)``.

    Raises:
        TypeError: If ``content_type`` is anything other than
            ``application/json``.
    """
    # Guard clause: reject unsupported content types before parsing.
    if content_type != 'application/json':
        raise TypeError('content_type is only allowed application/json')
    return json.loads(data)

def predict_fn(data, model):
    """Run the inpainting pipeline and return the result as base64 JPEG text.

    Args:
        data: Request payload. ``data['image']`` must be a base64-encoded
            image string. Every entry is forwarded to the pipeline as a
            keyword argument, with ``image`` replaced by the decoded picture.
        model: The pipeline object to call (as returned by ``model_fn``).

    Returns:
        Base64 text of the first generated image, encoded as JPEG.

    Raises:
        KeyError: If ``data`` has no ``image`` entry.
    """
    image_bytes = base64.b64decode(data['image'].encode())
    source_image = Image.open(BytesIO(image_bytes)).convert("RGB")

    # Build a fresh kwargs dict instead of writing the decoded image back
    # into ``data``, so the caller's payload is left untouched.
    pipe_kwargs = {**data, "image": source_image}
    with autocast("cuda"):
        result_image = model(**pipe_kwargs).images[0]

    # Serialize to JPEG in memory, then base64 so it can be embedded in JSON.
    buffered = BytesIO()
    result_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


def output_fn(data, accept_type):
    """Wrap the prediction in a JSON response body.

    Args:
        data: Prediction to return; must be JSON-serializable.
        accept_type: Response type requested by the client; only
            ``application/json`` is supported.

    Returns:
        JSON text of the form ``{"generated_image": <data>}``.

    Raises:
        TypeError: If ``accept_type`` is not ``application/json``.
    """
    if accept_type != 'application/json':
        # The message names the argument that was actually rejected
        # (previously it blamed ``content_type``).
        raise TypeError('accept_type is only allowed application/json')
    return json.dumps({'generated_image': data})

In [None]:
# Remove scripts/model first (presumably so stale contents are not archived).
!rm -rf scripts/model
# Create the archive from inside scripts/ so entries are relative to it.
%cd scripts
!tar -czvf ../package.tar.gz *
# Switch back to the previous working directory.
%cd -

In [None]:
# Upload the packaged archive to the session's bucket; the returned value is
# what the deployment cell passes as ``model_data``.
model_path = sess.upload_data(
    "package.tar.gz",
    bucket=bucket,
    # Plain literal: the string has no placeholders, so no f-prefix is needed.
    key_prefix="StableDiffusionInpainting-ClipSeg",
)
model_path

In [None]:
endpoint_name = "StableDiffusionInpainting-CLIPSeg"

# The same name is used for both the SageMaker model and the endpoint.
huggingface_model = PyTorchModel(
    name=endpoint_name,
    role=role,
    model_data=model_path,
    framework_version="2.0",
    py_version="py310",
)

# Deploy the model to a single-instance endpoint.
predictor = huggingface_model.deploy(
    endpoint_name=endpoint_name,
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
)

In [None]:
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

import base64
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np

# Client bound to the deployed endpoint; requests and responses are JSON.
predictor_client = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

prompt = "a cat"
mask_text = "a dog"  # text that selects the region to mask
input_img_file_name = "dog_suit.jpg"

with open(input_img_file_name, "rb") as f:
    input_img_image_bytes = f.read()
# b64encode takes the bytes object directly; the result is decoded to text
# so it can be placed in the JSON payload.
encoded_input_image = base64.b64encode(input_img_image_bytes).decode()

data = {
    "prompt": prompt,
    "text": mask_text,
    "image": encoded_input_image,
}
response = predictor_client.predict(data=data)

# Despite its name, ``mask`` holds the value of the "generated_image" field.
mask = response["generated_image"]

In [None]:
def display_img_and_prompt(img, prmpt):
    """Show ``img`` in a 12x12 figure with axes hidden and ``prmpt`` as title.

    Args:
        img: Image to display; it is converted with ``np.array`` first.
        prmpt: Text used as the figure title.
    """
    pixels = np.array(img)

    plt.figure(figsize=(12, 12))
    plt.imshow(pixels)
    plt.axis("off")
    plt.title(prmpt)
    plt.show()

# Turn the base64 text from the response back into raw image bytes.
generated_image_bytes = base64.b64decode(mask.encode())
generated_image_decoded = BytesIO(generated_image_bytes)

# Open as RGB and show it with the prompt as the title.
generated_image_rgb = Image.open(generated_image_decoded).convert("RGB")
display_img_and_prompt(generated_image_rgb, prompt)

In [None]:
# Clean up: delete the model, then the endpoint.
# NOTE(review): the model is deleted before the endpoint; presumably the order
# matters to the SDK's model lookup — confirm before reordering these calls.
predictor_client.delete_model()
predictor_client.delete_endpoint()