## Inpainting Inference using Online Endpoints

This sample shows how to deploy `inpainting` type stable diffusion models to an online endpoint for inference.

### Task
The `inpainting` task takes an original image, a mask image, and a text prompt as input. The model generates an inpainted image by modifying the masked region of the original image according to the prompt.

 
### Model
Models that can perform the `inpainting` task are tagged with `text-to-image`. This notebook uses the `runwayml-stable-diffusion-inpainting` model. If you opened this notebook from a different model card, replace the model name accordingly.


### Outline
1. Setup pre-requisites
2. Pick a model to deploy
3. Deploy the model to an online endpoint for real time inference
4. Test the endpoint using a sample text prompt, original image, and mask image
5. Clean up resources - delete the online endpoint

### 1. Setup pre-requisites
* Connect to your AzureML workspace. Learn more at [set up SDK authentication](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication?tabs=sdk). Replace `<AML_WORKSPACE_NAME>`, `<RESOURCE_GROUP>`, and `<SUBSCRIPTION_ID>` below.
* Connect to `azureml` system registry

In [None]:
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential

try:
    # Prefer the default credential chain and verify that it can obtain a token
    credential = DefaultAzureCredential()
    credential.get_token("https://management.azure.com/.default")
except Exception:
    # Fall back to an interactive browser login if the default chain fails
    credential = InteractiveBrowserCredential()

try:
    workspace_ml_client = MLClient.from_config(credential)
    subscription_id = workspace_ml_client.subscription_id
    resource_group = workspace_ml_client.resource_group_name
    workspace_name = workspace_ml_client.workspace_name
except Exception as ex:
    print(ex)
    # Enter details of your AML workspace
    subscription_id = "<SUBSCRIPTION_ID>"
    resource_group = "<RESOURCE_GROUP>"
    workspace_name = "<AML_WORKSPACE_NAME>"
workspace_ml_client = MLClient(
    credential, subscription_id, resource_group, workspace_name
)

# The models, fine tuning pipelines and environments are available in the AzureML system registry, "azureml"
registry_ml_client = MLClient(
    credential,
    subscription_id,
    resource_group,
    registry_name="azureml",
)
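
If `MLClient.from_config` fails, the cell above falls back to the placeholder strings, and a forgotten placeholder tends to surface only later as a confusing lookup or authentication error. A minimal sketch of a fail-fast check is shown below. `find_unreplaced_placeholders` is a hypothetical helper (not part of the Azure SDK), and it assumes placeholders keep the `<UPPER_CASE>` form used in this notebook.

```python
import re

# Matches tokens such as "<SUBSCRIPTION_ID>" (the placeholder convention assumed here).
_PLACEHOLDER = re.compile(r"^<[A-Z_]+>$")


def find_unreplaced_placeholders(settings: dict) -> list:
    """Return the names of settings whose value is still a <PLACEHOLDER>."""
    return [name for name, value in settings.items() if _PLACEHOLDER.match(str(value))]


# Example values; in the notebook these would be the variables set above.
example_settings = {
    "subscription_id": "<SUBSCRIPTION_ID>",
    "resource_group": "my-resource-group",
    "workspace_name": "<AML_WORKSPACE_NAME>",
}
missing = find_unreplaced_placeholders(example_settings)
print("Still placeholders:", missing)
```

Running a check like this before constructing the clients makes the failure point obvious.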

### 2. Pick a model to deploy

Browse models in the Model Catalog in AzureML Studio, filtering by the `text-to-image` task. This example uses the pre-trained `runwayml-stable-diffusion-inpainting` model. If you opened this notebook for a different model, replace the model name accordingly.

In [None]:
# Name of the inpainting model to be deployed
model_name = "runwayml-stable-diffusion-inpainting"

try:
    model = registry_ml_client.models.get(name=model_name, label="latest")
except Exception:
    print(
        f"No model named {model_name} found in registry. "
        "Please check that the model name exists in the Azure model catalog."
    )
    raise

print(
    f"\n\nUsing model name: {model.name}, version: {model.version}, id: {model.id} for inpainting."
)

### 3. Deploy the model to an online endpoint for real time inference
Online endpoints provide a durable REST API that applications can call to use the model.

In [None]:
import uuid
from azure.ai.ml.entities import ManagedOnlineEndpoint, ManagedOnlineDeployment

# Endpoint names must be unique within a region, so append the first 8 characters of a UUID
online_endpoint_name = (
    "inpainting-" + str(uuid.uuid4())[:8]
)  # Replace with your endpoint name
# Create an online endpoint
endpoint = ManagedOnlineEndpoint(
    name=online_endpoint_name,
    description="Online endpoint for " + model.name + ", for inpainting task",
    auth_mode="key",
)
workspace_ml_client.begin_create_or_update(endpoint).wait()
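
The name-generation step above can be wrapped in a small helper that also validates the result before any Azure call is made. This is a sketch under an assumption: the rule encoded here (starts with a letter, contains only letters, digits, and hyphens, 3 to 32 characters long) reflects the commonly documented constraints for online endpoint names, so check the current Azure documentation before relying on it. `make_endpoint_name` is a hypothetical helper name.

```python
import re
import uuid

# Assumed naming rule: a letter, then letters, digits, or hyphens; 3 to 32 characters in total.
_ENDPOINT_NAME = re.compile(r"^[A-Za-z][A-Za-z0-9-]{2,31}$")


def make_endpoint_name(prefix: str) -> str:
    """Append an 8-character UUID suffix to `prefix` and validate the result."""
    name = f"{prefix}-{str(uuid.uuid4())[:8]}"
    if not _ENDPOINT_NAME.match(name):
        raise ValueError(f"Invalid endpoint name: {name!r}")
    return name


example_name = make_endpoint_name("inpainting")
print(example_name)
```

Validating locally avoids waiting for a service-side rejection when a custom prefix is too long or starts with a digit.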

In [None]:
from azure.ai.ml.entities import OnlineRequestSettings, ProbeSettings

deployment_name = "inpainting-deploy"

# Create a deployment
demo_deployment = ManagedOnlineDeployment(
    name=deployment_name,
    endpoint_name=online_endpoint_name,
    model=model.id,
    instance_type="Standard_NC6s_v3",  # Use GPU instance type like Standard_NC6s_v3 or above
    instance_count=1,
    request_settings=OnlineRequestSettings(
        max_concurrent_requests_per_instance=1,
        request_timeout_ms=90000,
        max_queue_wait_ms=500,
    ),
    liveness_probe=ProbeSettings(
        failure_threshold=49,
        success_threshold=1,
        timeout=299,
        period=180,
        initial_delay=180,
    ),
    readiness_probe=ProbeSettings(
        failure_threshold=10,
        success_threshold=1,
        timeout=10,
        period=10,
        initial_delay=10,
    ),
)
workspace_ml_client.online_deployments.begin_create_or_update(demo_deployment).wait()
endpoint.traffic = {deployment_name: 100}
workspace_ml_client.begin_create_or_update(endpoint).result()
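
After the deployment finishes, it can be useful to confirm the endpoint state and traffic split before sending requests. The sketch below reads a few attributes from the endpoint returned by `online_endpoints.get`. `summarize_endpoint` is a hypothetical helper, and the stand-in client exists only so that the example runs without an Azure subscription.

```python
from types import SimpleNamespace


def summarize_endpoint(ml_client, endpoint_name: str) -> dict:
    """Fetch an online endpoint and return the fields most useful for a sanity check."""
    endpoint = ml_client.online_endpoints.get(name=endpoint_name)
    return {
        "provisioning_state": endpoint.provisioning_state,
        "scoring_uri": endpoint.scoring_uri,
        "traffic": endpoint.traffic,
    }


# Stand-in client with made-up values; pass `workspace_ml_client` in the notebook instead.
fake_endpoint = SimpleNamespace(
    provisioning_state="Succeeded",
    scoring_uri="https://example.invalid/score",
    traffic={"inpainting-deploy": 100},
)
fake_client = SimpleNamespace(
    online_endpoints=SimpleNamespace(get=lambda name: fake_endpoint)
)
summary = summarize_endpoint(fake_client, "inpainting-12345678")
print(summary)
```

If the deployment does not become healthy, `workspace_ml_client.online_deployments.get_logs(name=deployment_name, endpoint_name=online_endpoint_name, lines=100)` is one way to inspect the container logs.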

### 4. Test the endpoint

We will take a sample image and mask from the test data and submit them to the online endpoint for inference.

Sample input schema for the inpainting task:
```json
{
   "input_data": {
        "columns": ["prompt", "image", "mask_image"],
        "data": [
            {
                "prompt": "sample prompt",
                "image": "base image1",
                "mask_image": "mask image1"
            },
            {
                "prompt": "sample prompt",
                "image": "base image2",
                "mask_image": "mask image2"
            }
        ],
        "index": [0, 1]
    }
}
```
> - The base image and mask image strings (`base image1`, `mask image1`, and so on) should be base64-encoded images or publicly accessible URLs.
> - In the mask, white pixels mark the region to inpaint and black pixels mark the region to keep as is.
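
To make the mask convention concrete, the sketch below builds a rectangular mask that is white inside a box and black elsewhere. It writes an 8-bit grayscale PNG using only the standard library so that it runs anywhere; in practice an imaging library such as Pillow is simpler. Two assumptions apply: the mask should have the same dimensions as the base image, and the endpoint accepts a grayscale PNG. `make_rect_mask_png` is a hypothetical helper.

```python
import struct
import zlib


def _png_chunk(tag: bytes, data: bytes) -> bytes:
    """Build one PNG chunk: length, tag, data, and a CRC over tag + data."""
    crc = zlib.crc32(tag + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)


def make_rect_mask_png(width: int, height: int, box: tuple) -> bytes:
    """Return a grayscale PNG: white inside box (left, top, right, bottom), black elsewhere."""
    left, top, right, bottom = box
    raw = bytearray()
    for y in range(height):
        raw.append(0)  # every scanline starts with PNG filter type 0 (none)
        for x in range(width):
            inside = left <= x < right and top <= y < bottom
            raw.append(255 if inside else 0)  # white = inpaint, black = keep
    # IHDR: width, height, bit depth 8, color type 0 (grayscale), default compression/filter/interlace
    header = struct.pack(">IIBBBBB", width, height, 8, 0, 0, 0, 0)
    return (
        b"\x89PNG\r\n\x1a\n"
        + _png_chunk(b"IHDR", header)
        + _png_chunk(b"IDAT", zlib.compress(bytes(raw)))
        + _png_chunk(b"IEND", b"")
    )


# A 64x64 mask whose central 32x32 square is marked for inpainting.
mask_png = make_rect_mask_png(64, 64, box=(16, 16, 48, 48))
print(len(mask_png), "bytes")
```

The resulting bytes can be base64-encoded and used as the `mask_image` value in the request.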

Sample output schema for the inpainting task:
```json
[
    {
        "generated_image": "image1",
        "nsfw_content_detected": false
    },
    {
        "generated_image": "image2",
        "nsfw_content_detected": true
    }
]
```
> - If `nsfw_content_detected` is true, the generated image will be completely black.
> - The generated image strings (`image1`, `image2`) are base64-encoded.
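
Given this output schema, a caller will typically want to skip entries that were flagged, since their images are black. The sketch below parses a response string and returns the decoded image bytes, or `None` for flagged entries. `decode_generations` is a hypothetical helper, and the example response uses made-up byte strings in place of real images.

```python
import base64
import json


def decode_generations(response_text: str) -> list:
    """Return (image_bytes, flagged) pairs; image_bytes is None when the NSFW flag is set."""
    results = []
    for item in json.loads(response_text):
        flagged = bool(item["nsfw_content_detected"])
        image_bytes = None if flagged else base64.b64decode(item["generated_image"])
        results.append((image_bytes, flagged))
    return results


# Stand-in response with the documented shape; real responses carry full images.
example_response = json.dumps(
    [
        {
            "generated_image": base64.b64encode(b"fake-image-1").decode("ascii"),
            "nsfw_content_detected": False,
        },
        {
            "generated_image": base64.b64encode(b"fake-image-2").decode("ascii"),
            "nsfw_content_detected": True,
        },
    ]
)
decoded = decode_generations(example_response)
print([(flagged, image is not None) for image, flagged in decoded])
```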

In [None]:
# Create request json
import base64
import json


def read_image(image_path: str) -> bytes:
    """Reads an image from a file path into a byte array."""
    with open(image_path, "rb") as f:
        return f.read()


base_image = "inpainting_data/images/dog_on_bench.png"
mask_image = "inpainting_data/images/dog_on_bench_mask.png"

request_json = {
    "input_data": {
        "columns": ["image", "mask_image", "prompt"],
        "index": [0],
        "data": [
            {
                # b64encode yields a single-line base64 string (no embedded newlines)
                "image": base64.b64encode(read_image(base_image)).decode("utf-8"),
                "mask_image": base64.b64encode(read_image(mask_image)).decode("utf-8"),
                "prompt": "A yellow cat, high resolution, sitting on a park bench",
            }
        ],
    }
}

request_file_name = "sample_request_data.json"

with open(request_file_name, "w") as request_file:
    json.dump(request_json, request_file)
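
The request above holds a single row. For several rows, or for images referenced by URL, the same schema can be produced by a small builder. This is a sketch: `build_inpainting_request` and `to_image_field` are hypothetical helpers, and string inputs are assumed to already be a URL or base64 text.

```python
import base64
import json


def to_image_field(source) -> str:
    """Base64-encode raw bytes; pass strings through (assumed to be a URL or base64 text)."""
    if isinstance(source, (bytes, bytearray)):
        return base64.b64encode(bytes(source)).decode("utf-8")
    return source


def build_inpainting_request(rows: list) -> dict:
    """Build the request body from (prompt, image, mask_image) tuples."""
    data = [
        {"prompt": prompt, "image": to_image_field(image), "mask_image": to_image_field(mask)}
        for prompt, image, mask in rows
    ]
    return {
        "input_data": {
            "columns": ["prompt", "image", "mask_image"],
            "data": data,
            "index": list(range(len(data))),
        }
    }


# Made-up inputs: one row from raw bytes, one row from placeholder URLs.
example_request = build_inpainting_request(
    [
        ("A yellow cat sitting on a park bench", b"raw-image-bytes", b"raw-mask-bytes"),
        ("A red fox", "https://example.invalid/base.png", "https://example.invalid/mask.png"),
    ]
)
print(json.dumps(example_request)[:80])
```

The returned dictionary can be written to a file with `json.dump`, as in the cell above.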

In [None]:
response = workspace_ml_client.online_endpoints.invoke(
    endpoint_name=online_endpoint_name,
    deployment_name=demo_deployment.name,
    request_file=request_file_name,
)
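
Applications outside the notebook usually call the endpoint's REST API directly rather than going through `invoke`. The sketch below builds such a request for a key-authenticated endpoint using only the standard library. `build_scoring_request` is a hypothetical helper, the URL and key are placeholders, and the `azureml-model-deployment` header is the header commonly used to target one deployment; verify the details against the endpoint's consumption documentation.

```python
import json
import urllib.request


def build_scoring_request(scoring_uri: str, api_key: str, payload: dict, deployment_name: str = None):
    """Build (but do not send) a POST request for a key-authenticated online endpoint."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    if deployment_name:
        # Routes the call to one deployment instead of following the traffic split.
        headers["azureml-model-deployment"] = deployment_name
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(scoring_uri, data=body, headers=headers, method="POST")


# Placeholder values; in the notebook, take them from the endpoint and its keys.
example_req = build_scoring_request(
    "https://example.invalid/score",
    "placeholder-key",
    {"input_data": {"columns": [], "data": [], "index": []}},
    deployment_name="inpainting-deploy",
)
print(example_req.get_method(), example_req.full_url)
# To send the request: urllib.request.urlopen(example_req, timeout=120).read()
```

In the notebook, the scoring URI should be available as `workspace_ml_client.online_endpoints.get(name=online_endpoint_name).scoring_uri` and the key as `workspace_ml_client.online_endpoints.get_keys(name=online_endpoint_name).primary_key`.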

In [None]:
import io
import base64
from PIL import Image

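# The endpoint returns a JSON array with one entry per input row
generations = json.loads(response)
for generation in generations:
    print("NSFW content detected:", generation["nsfw_content_detected"])
    # Decode the base64 string back into an image and render it in the notebook
    img = Image.open(io.BytesIO(base64.b64decode(generation["generated_image"])))
    display(img)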

### 5. Clean up resources - delete the online endpoint
Delete the online endpoint when you are done; otherwise, the compute used by the endpoint continues to incur charges.

In [None]:
workspace_ml_client.online_endpoints.begin_delete(name=online_endpoint_name).wait()