## Creative Content Assisted by Generative AI using Amazon SageMaker: Magic Fill/Replace
---
In this notebook, we will extend the inpainting eraser example by replacing the LaMa model with **[Stable Diffusion (SD) Inpaint](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting)**. Besides erasing an object from an image, this example can also fill in an object or replace the background from a simple text prompt. The diagram below illustrates these capabilities.

![magic_fill_replace](https://raw.github.com/geekyutao/Inpaint-Anything/main/example/MainFramework.png)

To generate segmentation masks, we use a foundation model developed by Meta Research called **[Segment Anything Model (SAM)](https://segment-anything.com/) - Apache-2.0 license**. This model is trained on a massive dataset called SA-1B, with over 11 million images and 1.1 billion segmentation masks. This scale gives SAM the ability to identify and isolate objects in an image out of the box, without additional training.

To fill or replace content, we will use the **[Stable Diffusion (SD) Inpaint](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) - CreativeML Open RAIL++-M License** model from Stability AI. Like an image-to-image model, it takes an input image, but it also accepts a mask. The model uses the entire image as context and generates new content for the masked region according to the text prompts.

**Note: please run the `0_setup.ipynb` notebook before starting on this example. We recommend using the PyTorch kernel on a SageMaker Notebook Instance of type `ml.g4dn.xlarge`.**

### Setup

In [None]:
import base64
import io
import json
import time

import boto3
import numpy as np
import sagemaker
from matplotlib import pyplot as plt
from PIL import Image
from sagemaker import get_execution_role

%matplotlib inline

role = get_execution_role()

sm_client = boto3.client(service_name="sagemaker")
runtime_sm_client = boto3.client("sagemaker-runtime")
s3 = boto3.client('s3')

sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
region = sagemaker_session.boto_region_name
account = sagemaker_session.account_id()
bucket = sagemaker_session.default_bucket()
prefix = 'magic-fill-replace'

%store -r extended_triton_image_uri

## Serve models with Triton Inference Server

We will use the Triton Python backend to deploy and host these models on a SageMaker multi-model endpoint (MME). Triton requires the models to be packaged in the following folder structure; these files are already provided in the `model_repo` folder.
```
|-model_repo
    |---sam
        |----1
             |--model.py
        |----config.pbtxt
    |---sd_inpaint
        |----1
             |--model.py
             |--mask_processing.py
        |----config.pbtxt
```
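Before packaging, it can help to confirm that each model folder contains the two files the Triton Python backend needs. The sketch below is a hypothetical helper (`check_model_repo` is not part of Triton or SageMaker) that assumes exactly the layout shown above, with a version folder named `1`.

```python
from pathlib import Path


def check_model_repo(repo_dir, models, version="1"):
    """Return the paths of required files that are missing from a Triton model repository.

    Assumes the layout <repo>/<model>/config.pbtxt and <repo>/<model>/<version>/model.py.
    """
    repo = Path(repo_dir)
    missing = []
    for name in models:
        required = [repo / name / "config.pbtxt", repo / name / version / "model.py"]
        # Keep every required path that is not an existing regular file
        missing.extend(str(p) for p in required if not p.is_file())
    return missing
```

Running `check_model_repo("model_repo", ["sam", "sd_inpaint"])` from this folder should return an empty list when the repository is complete.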

We are using the Python backend to load our models. To use the Python backend, you need at least a Triton config file and a Python file named `model.py` that serves as the entry point for your model. Let's explore each file.

`config.pbtxt` is a mandatory configuration file for Triton that configures the backend type, batch size, input and output formats, etc.

In [None]:
!cat model_repo/sd_inpaint/config.pbtxt
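For orientation only, the sketch below renders a minimal Python-backend config of the same general shape. It is hypothetical: `render_python_backend_config` is not a Triton API, and its choices (every tensor declared as `TYPE_STRING`, which is how Triton represents the `BYTES` datatype used in our requests; `max_batch_size: 0`; fixed `dims` of `[1, 1]` to match the request shape) are assumptions. The file printed above is the one actually deployed.

```python
def render_python_backend_config(name, inputs, outputs, dims=(1, 1)):
    """Render a minimal Triton config.pbtxt for a Python-backend model.

    Hypothetical helper: every tensor is TYPE_STRING with fixed dims and
    no server-side batching (max_batch_size: 0).
    """
    dims_str = ", ".join(str(d) for d in dims)

    def block(kind, names):
        # One protobuf text-format message per tensor, comma-separated in a list
        entries = [
            f'  {{\n    name: "{n}"\n    data_type: TYPE_STRING\n    dims: [ {dims_str} ]\n  }}'
            for n in names
        ]
        return f"{kind} [\n" + ",\n".join(entries) + "\n]"

    return "\n".join([
        f'name: "{name}"',
        'backend: "python"',
        "max_batch_size: 0",
        block("input", inputs),
        block("output", outputs),
    ])
```

For example, `render_python_backend_config("sd_inpaint", ["image", "mask_image", "prompt", "negative_prompt", "gen_args"], ["generated_image"])` uses the input names sent later in this notebook; the output name `generated_image` is a placeholder.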

The Python backend script needs to define a `TritonPythonModel` class with up to four functions; only `execute` is required. Refer to the [Triton Python backend documentation](https://github.com/triton-inference-server/python_backend) for more details.

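```python
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """
    @staticmethod
    def auto_complete_config(auto_complete_model_config):
        return auto_complete_model_config  # optional: complete missing config.pbtxt fields

    def initialize(self, args):
        pass  # optional: called once when the model is loaded

    def execute(self, requests):
        raise NotImplementedError  # required: return one pb_utils.InferenceResponse per request

    def finalize(self):
        pass  # optional: called once when the model is unloaded
```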

In [None]:
!cat model_repo/sd_inpaint/1/model.py

## Deploy Models to MME

In [None]:
!rm -rf `find -type d -name .ipynb_checkpoints`  
!find . | grep -E "(__pycache__|\.pyc$)" | xargs sudo rm -rf

In [None]:
model_dir = "model_repo"
models = ["sam", "sd_inpaint"]
v_ = 0

model_targets = dict()
for m in models:
    
    tar_name = f"{m}-v{v_}.tar.gz"
    model_targets[m] = tar_name

    !tar -C $model_dir -zcvf $tar_name $m
    
    sagemaker_session.upload_data(path=tar_name, key_prefix=f"{prefix}/models")

    
print(model_targets)
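The cell above shells out to `tar`. If you prefer to stay in Python, the standard-library `tarfile` module can produce an equivalent archive. This is a sketch with a hypothetical helper name (`package_model`); it also skips `__pycache__`, `.ipynb_checkpoints`, and `.pyc` entries, which the cleanup cell removes from disk.

```python
import tarfile
from pathlib import Path


def _skip_caches(info):
    # Drop notebook checkpoints and compiled Python files from the archive
    parts = Path(info.name).parts
    if "__pycache__" in parts or ".ipynb_checkpoints" in parts or info.name.endswith(".pyc"):
        return None
    return info


def package_model(model_dir, model_name, version=0, out_dir="."):
    """Create <model_name>-v<version>.tar.gz with <model_name>/ as its top-level folder.

    Mirrors `tar -C model_dir -zcvf <tar_name> <model_name>`.
    """
    tar_path = Path(out_dir) / f"{model_name}-v{version}.tar.gz"
    with tarfile.open(tar_path, "w:gz") as tar:
        # arcname stores paths relative to model_dir, like `tar -C model_dir`
        tar.add(Path(model_dir) / model_name, arcname=model_name, filter=_skip_caches)
    return tar_path
```

You could then upload the result with `sagemaker_session.upload_data(path=str(package_model(model_dir, "sam")), key_prefix=f"{prefix}/models")`.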

**Define the Serving Container**

Start with a container definition. Set `ModelDataUrl` to the S3 prefix that contains all the models the SageMaker multi-model endpoint will load and serve. Set `Mode` to `MultiModel` so that SageMaker creates the endpoint with the MME container specification.

In [None]:
model_data_url = f"s3://{bucket}/{prefix}/models/"
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": extended_triton_image_uri,
    "ModelDataUrl": model_data_url,
    "Mode": "MultiModel",
}
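On an MME, the `TargetModel` passed at invocation time is a path relative to `ModelDataUrl`. The hypothetical helper below (not a SageMaker API) only illustrates that mapping, which is handy when checking that an archive was uploaded where the endpoint expects it.

```python
def resolve_target_model(model_data_url, target_model):
    """Return the S3 URI an MME would load for a given TargetModel.

    TargetModel is interpreted relative to the ModelDataUrl prefix.
    """
    if not model_data_url.endswith("/"):
        model_data_url += "/"
    return model_data_url + target_model.lstrip("/")
```

For example, `resolve_target_model(model_data_url, model_targets["sam"])` gives the S3 URI of the SAM archive uploaded earlier.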


**Set Up the SageMaker Model**

Using the SageMaker boto3 client, create the model with the [create_model](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model) API. We pass the container definition to the API along with `ModelName` and `ExecutionRoleArn`.

In [None]:
sm_model_name = f"{prefix}-models-{ts}"

create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

**Create a SageMaker endpoint configuration.**

Create a multi-model endpoint configuration using the `create_endpoint_config` boto3 API. Specify an accelerated GPU computing instance in `InstanceType`; this example uses `ml.g5.2xlarge`. For real-life use cases, we recommend configuring your endpoints with at least two instances. This allows SageMaker to provide a highly available set of predictions across multiple Availability Zones for the models.
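Following that recommendation, a production variant could request two instances. The sketch below builds one `ProductionVariants` entry with the same keys as the cell below; `production_variant` is a hypothetical helper, and the default of two instances is an assumption to adjust for your traffic.

```python
def production_variant(model_name, instance_type, instance_count=2, variant_name="AllTraffic"):
    """Build one ProductionVariants entry for create_endpoint_config.

    Defaults to two instances for higher availability.
    """
    if instance_count < 1:
        raise ValueError("instance_count must be at least 1")
    return {
        "VariantName": variant_name,
        "ModelName": model_name,
        "InstanceType": instance_type,
        "InitialInstanceCount": instance_count,
        "InitialVariantWeight": 1,
    }
```

Pass it as `ProductionVariants=[production_variant(sm_model_name, instance_type)]`.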

In [None]:
endpoint_config_name = f"{prefix}-config-{ts}"
instance_type = 'ml.g5.2xlarge'

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": instance_type,
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

**Create endpoint**

Using the above endpoint configuration, we create a new SageMaker endpoint and wait for the deployment to finish. The status changes to **InService** once the deployment is successful.

In [None]:
endpoint_name = f"{prefix}-ep-{ts}"

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
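The loop above stops as soon as the status is no longer `Creating`, including when it is `Failed`, and it never times out. A slightly more defensive sketch (the helper name `wait_for_endpoint` and its defaults are assumptions) raises on failure using the `FailureReason` field returned by `describe_endpoint`, and gives up after a timeout:

```python
import time


def wait_for_endpoint(sm_client, endpoint_name, delay=60, timeout=3600):
    """Poll describe_endpoint until the endpoint is InService.

    Raises RuntimeError if the endpoint fails, TimeoutError if it is not
    InService before the timeout elapses.
    """
    deadline = time.monotonic() + timeout
    while True:
        resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = resp["EndpointStatus"]
        if status == "InService":
            return resp
        if status == "Failed":
            raise RuntimeError(f"{endpoint_name} failed: {resp.get('FailureReason', 'unknown')}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"{endpoint_name} still {status} after {timeout}s")
        time.sleep(delay)
```

boto3 also provides a built-in waiter for the same purpose: `sm_client.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)`.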

## Invoking the models

Now we can test our models. We first call the SAM model to generate a segmentation mask.

---

### Invoke SAM model
The primary inputs for this model are the image and the [x, y] coordinates of an image pixel that locates the object. We need to encode the image as a base64 string before sending it to the endpoint.

Optionally, you can also pass in `point_labels` to mark each point as foreground (1) or background (0), or `dilate_kernel_size` to control how far the mask is expanded beyond the object boundary.

In [None]:
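def encode_image(img):
    """Encode a PIL image as a base64 JPEG string for the JSON payload."""
    # JPEG has no alpha channel, so convert to RGB first (handles RGBA and palette PNGs)
    with io.BytesIO() as output:
        img.convert("RGB").save(output, format="JPEG")
        img_bytes = output.getvalue()

    return base64.b64encode(img_bytes).decode("utf8")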

Here is how to build the request payload for the SAM model hosted on the MME:

In [None]:
# pixel coordinates (x, y) of the dog in dog.jpg: 200, 450
# pixel coordinates (x, y) of the dog in sample1.png: 750, 500
img_file='statics/sample1.png'
original_image = Image.open(img_file)

print("Original Image")
display(original_image)
original_image_bytes = encode_image(original_image)

gen_args = json.dumps(dict(point_coords=[750, 500], point_labels=1, dilate_kernel_size=15))

inputs = dict(image=original_image_bytes,
              gen_args = gen_args)

payload = {
    "inputs":
        [{"name": name, "shape": [1,1], "datatype": "BYTES", "data": [data]} for name, data in inputs.items()]
}
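The same request structure (one `BYTES` tensor of shape `[1, 1]` per named input, in Triton's KServe v2 JSON format) is built again for the inpainting model below. A small hypothetical helper can factor it out:

```python
def make_triton_payload(inputs):
    """Wrap named string inputs in Triton's JSON inference request format.

    Each value becomes a BYTES tensor of shape [1, 1].
    """
    return {
        "inputs": [
            {"name": name, "shape": [1, 1], "datatype": "BYTES", "data": [data]}
            for name, data in inputs.items()
        ]
    }
```

`payload = make_triton_payload(inputs)` produces the same dictionary as the cell above.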

Notice that the first invocation of a model has much higher latency due to the cold start: SageMaker downloads the model archive from S3 and loads it on first use. Subsequent calls are much faster because the model stays cached in memory.
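To see the cold-start effect directly, you can time several consecutive calls. This sketch (the helper `time_calls` is hypothetical) measures wall-clock latency for any zero-argument callable:

```python
import time


def time_calls(fn, n=3):
    """Call fn() n times and return each call's wall-clock latency in seconds."""
    latencies = []
    for _ in range(n):
        start = time.perf_counter()
        fn()
        latencies.append(time.perf_counter() - start)
    return latencies
```

For example, wrap the `invoke_endpoint` call from the next cell in a `lambda` and pass it to `time_calls`; if the model has not been invoked yet, the first value should be noticeably larger than the rest.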

In [None]:
%%time
response = runtime_sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/octet-stream",
        Body=json.dumps(payload),
        TargetModel=model_targets["sam"], 
    )

In [None]:
output = json.loads(response["Body"].read().decode("utf8"))["outputs"]
mask_decoded = io.BytesIO(base64.b64decode(output[0]["data"][0]))
mask_rgb = Image.open(mask_decoded).convert("RGB")

print("Object Mask")
display(mask_rgb)
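The response-parsing steps in the cell above are repeated for the inpainting model. Here they are factored into a hypothetical helper, assuming, as the notebook does, that the first element of the first output tensor is a base64-encoded image:

```python
import base64
import json


def decode_first_output(body_bytes):
    """Return the raw bytes of the first output tensor in a Triton JSON response."""
    outputs = json.loads(body_bytes.decode("utf8"))["outputs"]
    # The first element of the first output holds a base64-encoded image
    return base64.b64decode(outputs[0]["data"][0])
```

Note that `response["Body"]` is a stream that can be read only once, so call `decode_first_output(response["Body"].read())` a single time and then open the bytes with `Image.open(io.BytesIO(...))`.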

### Invoke Stable Diffusion Inpaint Model

We need to pass `mask_rgb`, which indicates which regions of the image should be filled. Alongside the mask, you can use a text prompt to control what is generated in the masked area. If you leave the prompt as an empty string, the model effectively removes the object from the image. If you invert the black and white colors of the mask, the model replaces the background instead of filling the object.
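One way to invert the mask is with NumPy. The sketch below assumes an 8-bit mask where white (255) marks the region to regenerate, which is the convention Stable Diffusion inpainting uses; `invert_mask` is a hypothetical helper. `PIL.ImageOps.invert(mask_rgb)` is an equivalent one-liner for `RGB` images.

```python
import numpy as np


def invert_mask(mask):
    """Swap the black and white regions of an 8-bit mask array.

    With the inverted mask, the region outside the segmented object is
    regenerated instead of the object itself.
    """
    mask = np.asarray(mask, dtype=np.uint8)
    return 255 - mask
```

To try background replacement, swap `mask_image = encode_image(mask_rgb)` in the next cell for `mask_image = encode_image(Image.fromarray(invert_mask(np.array(mask_rgb))))`.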

In [None]:
# Inputs ==================
# original_image_bytes
mask_image = encode_image(mask_rgb)
prompt = "a teddy bear on a bench"

nprompt = "ugly, distorted"

gen_args = json.dumps(dict(num_inference_steps=50, guidance_scale=10, seed=1))

inputs = dict(image=original_image_bytes,
              mask_image=mask_image,
              prompt = prompt,
              negative_prompt = nprompt,
              gen_args = gen_args)

payload = {
    "inputs":
        [{"name": name, "shape": [1,1], "datatype": "BYTES", "data": [data]} for name, data in inputs.items()]
}

In [None]:
%%time
response = runtime_sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/octet-stream",
        Body=json.dumps(payload),
        TargetModel=model_targets["sd_inpaint"], 
    )

In [None]:
output = json.loads(response["Body"].read().decode("utf8"))["outputs"]
filled_decoded = io.BytesIO(base64.b64decode(output[0]["data"][0]))
# Use a new name so that re-running the inpaint cells still uses the SAM mask
filled_image = Image.open(filled_decoded).convert("RGB")

print("Object Filled")
display(filled_image)

### [Optional] Gradio UI
Write the configuration file for the Gradio app:

In [None]:
config = dict()
config["endpoint_name"] = endpoint_name
config["models"] = model_targets

with open("config.json", 'w') as f:
    json.dump(config, f)

1) Open a system terminal and navigate to this folder.

2) Install the required packages; the most important one is `gradio`:

```
pip install -r requirements.txt
```
3) Run the following command to launch the app:
```
python run.py
```

4) Click the public link to open the UI in your browser.

### Clean Up
When you are done, delete the endpoint to stop incurring charges.

In [None]:
response = sm_client.delete_endpoint(
    EndpointName=endpoint_name
)
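The cell above deletes only the endpoint. The endpoint configuration and model records do not incur charges but remain in your account, and the uploaded archives incur S3 storage charges. Below is a sketch for removing them too (the helper name is hypothetical; the boto3 calls are standard). Because endpoint deletion is asynchronous, wait for it to finish first, for example with `sm_client.get_waiter("endpoint_deleted").wait(EndpointName=endpoint_name)`.

```python
def clean_up_resources(sm_client, s3_client, endpoint_config_name, model_name, bucket, model_keys=()):
    """Remove the endpoint config, the model, and (optionally) uploaded model archives.

    Run after the endpoint itself has been deleted.
    """
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    sm_client.delete_model(ModelName=model_name)
    for key in model_keys:
        # Each key is an S3 object key such as "<prefix>/models/sam-v0.tar.gz"
        s3_client.delete_object(Bucket=bucket, Key=key)
```

For this notebook: `clean_up_resources(sm_client, s3, endpoint_config_name, sm_model_name, bucket, [f"{prefix}/models/{t}" for t in model_targets.values()])`.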