# Create an Azure AI Content Safety __(AACS)__ enabled Inpainting online endpoint (Preview)
### This notebook is under preview.

### Steps to create an __AACS__ enabled __inpainting__ online endpoint
1. Create an __AACS__ resource to moderate requests from users and responses from the __inpainting__ online endpoint.
2. Create a new __AACS__-enabled __inpainting__ online endpoint with a custom [score_online.py](./aacs-scoring-files/score/score_online.py) that integrates with the __AACS__ resource to moderate both the user's request and the response from the __text-to-image__ inpainting model. For the custom [score_online.py](./aacs-scoring-files/score/score_online.py) to authenticate to the __AACS__ resource, there are 2 options:
    1. __UAI__ (recommended, but more complex): create a User Assigned Identity (UAI) and assign the appropriate roles to it. The custom [score_online.py](./aacs-scoring-files/score/score_online.py) then obtains an access token for the UAI from the AAD server and uses it to access the AACS resource. Use [this notebook](aacs-prepare-uai.ipynb) to create the UAI account after creating the AACS resource and before running step 2.3 below.
    2. __Environment variable__ (simpler, but less secure): pass the access key of the AACS resource to the custom [score_online.py](./aacs-scoring-files/score/score_online.py) through an environment variable, and the script uses the key directly to access the AACS resource. This is less secure than the first option: anyone in your organization with access to the endpoint can read the access key from the environment variable and use it to access the AACS resource.
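
The two options can be sketched from the scoring script's point of view. This is a minimal sketch, not the actual logic of `score_online.py`: `build_aacs_auth_headers` is a hypothetical helper, and it assumes the environment variable names that step 4 sets on the deployment (`UAI_CLIENT_ID`, `CONTENT_SAFETY_KEY`). The `Ocp-Apim-Subscription-Key` header and the `https://cognitiveservices.azure.com/.default` token scope are the standard ones for Azure AI services.

```python
from typing import Mapping


def build_aacs_auth_headers(env: Mapping[str, str]) -> dict:
    """Hypothetical helper: pick an AACS auth method from environment variables.

    Assumes the variable names this notebook sets on the deployment
    (UAI_CLIENT_ID, CONTENT_SAFETY_KEY); not the actual logic of score_online.py.
    """
    uai_client_id = env.get("UAI_CLIENT_ID") or ""
    if uai_client_id:
        # Option 1 (UAI): request an AAD token for the user-assigned identity.
        # Needs azure-identity and Azure compute with the UAI attached.
        from azure.identity import ManagedIdentityCredential

        credential = ManagedIdentityCredential(client_id=uai_client_id)
        token = credential.get_token("https://cognitiveservices.azure.com/.default")
        return {"Authorization": f"Bearer {token.token}"}

    key = env.get("CONTENT_SAFETY_KEY")
    if key:
        # Option 2 (environment variable): send the access key with every request.
        return {"Ocp-Apim-Subscription-Key": key}

    raise RuntimeError("Set either UAI_CLIENT_ID or CONTENT_SAFETY_KEY.")


# Key-based example; inside a deployment you would pass os.environ instead.
print(build_aacs_auth_headers({"UAI_CLIENT_ID": "", "CONTENT_SAFETY_KEY": "<key>"}))
# {'Ocp-Apim-Subscription-Key': '<key>'}
```

The UAI branch imports `azure.identity` lazily so that the key-only path has no extra dependency.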


### Task
The `inpainting` task takes an original image, a text prompt, and a mask image as input. The model generates an inpainted image by modifying the region of the original image indicated by the mask, guided by the prompt.
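
As a conceptual illustration only (a diffusion model does not literally compute its output this way), the role of the mask can be shown on tiny grayscale grids: pixels under the white part of the mask receive new content, and pixels under the black part keep the original values. The white-means-repaint convention is an assumption based on common Stable Diffusion inpainting pipelines, and `composite_inpainting` is a hypothetical helper.

```python
from typing import List

Grid = List[List[int]]  # grayscale pixels, 0 (black) to 255 (white)


def composite_inpainting(original: Grid, generated: Grid, mask: Grid) -> Grid:
    """Keep original pixels where the mask is dark, take generated pixels where it is light.

    Conceptual sketch only: assumes white (>= 128) marks the region to repaint.
    """
    return [
        [g if m >= 128 else o for o, g, m in zip(orow, grow, mrow)]
        for orow, grow, mrow in zip(original, generated, mask)
    ]


original = [[10, 10], [10, 10]]
generated = [[99, 99], [99, 99]]
mask = [[0, 255], [0, 255]]  # repaint only the right-hand column
print(composite_inpainting(original, generated, mask))  # [[10, 99], [10, 99]]
```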

 
### Model
Models that can perform the `inpainting` task are tagged with `text-to-image`. We will use the `runwayml-stable-diffusion-inpainting` model in this notebook. If you opened this notebook from a specific model card, replace the model name accordingly.

### Outline
1. Set up prerequisites
2. Create AACS resource
3. Pick a model to deploy
4. Deploy the model to an online endpoint for real time inference
5. Test the endpoint
6. Clean up resources - delete the online endpoint

### 1. Set up prerequisites
* Check List
* Install dependencies
* Connect to AzureML Workspace. Learn more at [set up SDK authentication](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication?tabs=sdk). If no workspace config file is found, replace `<SUBSCRIPTION_ID>`, `<RESOURCE_GROUP>` and `<AML_WORKSPACE_NAME>` below.
* Connect to `azureml` system registry

> [x] The identity you use to run this notebook (yourself or your VM) needs the __Contributor__ role on the resource group that contains the AML workspace you specified, because this notebook creates an AACS resource using that identity.

In [None]:
# Install the required packages
%pip install azure-identity==1.13.0
%pip install azure-mgmt-cognitiveservices==13.4.0
%pip install azure-ai-ml==1.11.1
%pip install azure-mgmt-msi==7.0.0
%pip install azure-mgmt-authorization==3.0.0

In [None]:
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential

try:
    credential = DefaultAzureCredential()
    credential.get_token("https://management.azure.com/.default")
except Exception as ex:
    credential = InteractiveBrowserCredential()

try:
    workspace_ml_client = MLClient.from_config(credential)
    subscription_id = workspace_ml_client.subscription_id
    resource_group = workspace_ml_client.resource_group_name
    workspace_name = workspace_ml_client.workspace_name
except Exception as ex:
    print(ex)
    # Enter details of your AML workspace
    subscription_id = "<SUBSCRIPTION_ID>"
    resource_group = "<RESOURCE_GROUP>"
    workspace_name = "<AML_WORKSPACE_NAME>"
workspace_ml_client = MLClient(
    credential, subscription_id, resource_group, workspace_name
)

print(f"Connected to workspace {workspace_name}")

In [None]:
# The models, fine tuning pipelines and environments are available in the AzureML system registry, "azureml"

registry_name = "azureml"

registry_ml_client = MLClient(
    credential,
    subscription_id,
    resource_group,
    registry_name=registry_name,
)

### 2. Create AACS resource

#### 2.1 Assign variables for Azure Content Safety
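Currently, AACS is available in a limited set of regions. The code below picks a region from `available_aacs_locations` (East US and West Europe); check the Azure AI Content Safety documentation for the current list of supported regions.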


__NOTE__: Before you choose the region in which to deploy AACS, be aware that your data will be transferred to the region you choose. By selecting a region outside your current location, you may be allowing the transmission of your data to regions outside your jurisdiction, and data protection and privacy laws vary between jurisdictions. Before proceeding, we strongly advise you to familiarize yourself with the local laws and regulations governing data transfer and to ensure that you are legally permitted to transmit your data to an overseas location for processing. By continuing with the selection of a different region, you acknowledge that you have understood and accepted any potential risks associated with such data transmission. Please proceed with caution.

In [None]:
# Severity level that causes content to be blocked
# See the Azure AI Content Safety documentation for more details:
# https://learn.microsoft.com/en-us/azure/cognitive-services/content-safety/concepts/harm-categories
content_severity_threshold = "2"

# UAI to be used for endpoint if you choose to use UAI as authentication method.
# Use the default name "aacs-uai", which is the name used in the UAI preparation notebook

# uai_name = "aacs-uai"

# If you use an environment variable to authenticate to the AACS resource, set uai_name to an empty string ("")
uai_name = ""
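
AACS reports a severity level for each harm category (Hate, SelfHarm, Sexual, Violence), and the scoring script compares those levels against `content_severity_threshold`. The sketch below shows one plausible comparison; `should_block` is a hypothetical helper, and treating a level equal to the threshold as blocked (`>=` rather than `>`) is an assumption, not confirmed behavior of `score_online.py`.

```python
from typing import Mapping


def should_block(category_severities: Mapping[str, int], threshold: str) -> bool:
    """Hypothetical helper: block if any harm category reaches the threshold.

    `threshold` is a string, like content_severity_threshold above. Using ">="
    (rather than ">") is an assumption about how score_online.py compares levels.
    """
    limit = int(threshold)
    return any(severity >= limit for severity in category_severities.values())


# Example severities per category, in the shape an AACS analysis might produce.
print(should_block({"Hate": 0, "SelfHarm": 0, "Sexual": 0, "Violence": 4}, "2"))  # True
print(should_block({"Hate": 0, "SelfHarm": 0, "Sexual": 0, "Violence": 0}, "2"))  # False
```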

In [None]:
from uuid import uuid4
from azure.mgmt.cognitiveservices import CognitiveServicesManagementClient

aacs_client = CognitiveServicesManagementClient(credential, subscription_id)


# Settings for the Azure AI Content Safety (AACS) resource.
# An existing AACS resource is reused if one is found; otherwise a new one is created.
# The name of a new AACS resource has to be unique.

aacs_name = f"aacs-inpainting-{str(uuid4())[:8]}"
available_aacs_locations = ["east us", "west europe"]

# create a new Cognitive Services Account
kind = "ContentSafety"
aacs_sku_name = "S0"
aacs_location = available_aacs_locations[0]


print("Available SKUs:")
aacs_skus = aacs_client.resource_skus.list()
print("SKU Name\tSKU Tier\tLocations")
for sku in aacs_skus:
    if sku.kind == "ContentSafety":
        locations = ",".join(sku.locations)
        print(sku.name + "\t\t" + sku.tier + "\t\t" + locations)

print(f"A new AACS resource, if needed, will use location {aacs_location} and SKU {aacs_sku_name}")

#### 2.2 Create AACS Resource
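
The cell below reuses an AACS account when it can and creates one only as a last resort: first by exact name, then any ContentSafety account in the resource group with the chosen region and SKU. That decision order can be sketched with injected callables and fake accounts so it runs without Azure. `get_or_create_account`, `normalize_region` and `is_match` are hypothetical, `SimpleNamespace` stands in for the SDK's `Account` objects, and the assumption that the service reports regions as e.g. `"eastus"` (no space) is why regions are compared after normalization.

```python
from types import SimpleNamespace
from typing import Callable, Iterable, Tuple


def get_or_create_account(
    get_by_name: Callable[[], object],
    list_in_resource_group: Callable[[], Iterable[object]],
    create: Callable[[], object],
    matches: Callable[[object], bool],
) -> Tuple[object, str]:
    """Hypothetical helper mirroring the lookup order of the cell below."""
    try:
        # 1. Reuse the account with the exact name, if it exists.
        return get_by_name(), "found by name"
    except Exception:
        pass
    # 2. Otherwise reuse any matching account in the same resource group.
    for account in list_in_resource_group():
        if matches(account):
            return account, "found in resource group"
    # 3. Otherwise create a new account.
    return create(), "created"


def normalize_region(location: str) -> str:
    # Assumed: the service reports "eastus" while this notebook uses "east us".
    return location.replace(" ", "").lower()


def is_match(account) -> bool:
    return (
        account.kind == "ContentSafety"
        and normalize_region(account.location) == normalize_region("east us")
        and account.sku.name == "S0"
    )


def not_found():
    raise LookupError("no account with that name")


existing = SimpleNamespace(
    name="aacs-old", kind="ContentSafety", location="eastus", sku=SimpleNamespace(name="S0")
)

account, how = get_or_create_account(not_found, lambda: [existing], lambda: None, is_match)
print(account.name, how)  # aacs-old found in resource group
```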

In [None]:
from azure.mgmt.cognitiveservices.models import Account, Sku, AccountProperties

parameters = Account(
    sku=Sku(name=aacs_sku_name),
    kind=kind,
    location=aacs_location,
    properties=AccountProperties(
        custom_sub_domain_name=aacs_name, public_network_access="Enabled"
    ),
)


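def find_acs(accounts):
    # Return the first ContentSafety account with the chosen region and SKU;
    # raises StopIteration (handled by the caller) if there is none.
    def normalize(location):
        # Regions may be reported as "eastus" while aacs_location is "east us"
        return location.replace(" ", "").lower()

    return next(
        x
        for x in accounts
        if x.kind == "ContentSafety"
        and normalize(x.location) == normalize(aacs_location)
        and x.sku.name == aacs_sku_name
    )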


try:
    # check if AACS exists
    aacs = aacs_client.accounts.get(resource_group, aacs_name)
    print(f"Found existing AACS Account {aacs.name}.")
except Exception:
    try:
        # check if there is an existing AACS resource within same resource group
        aacs = find_acs(aacs_client.accounts.list_by_resource_group(resource_group))
        print(
            f"Found existing AACS Account {aacs.name} in resource group {resource_group}."
        )
    except Exception:
        print(f"Creating AACS Account {aacs_name}.")
        aacs_client.accounts.begin_create(resource_group, aacs_name, parameters).wait()
        print("Resource created.")
        aacs = aacs_client.accounts.get(resource_group, aacs_name)

In [None]:
aacs_endpoint = aacs.properties.endpoint
aacs_resource_id = aacs.id
aacs_name = aacs.name
print(
    f"AACS name is {aacs.name}.\nUse this name in the UAI preparation notebook to create the UAI."
)
print(f"AACS endpoint is {aacs_endpoint}")
print(f"AACS ResourceId is {aacs_resource_id}")

aacs_access_key = aacs_client.accounts.list_keys(
    resource_group_name=resource_group, account_name=aacs.name
).key1

#### 2.3 Check if UAI is used (Required for using UAI authentication method)

In [None]:
uai_id = ""
uai_client_id = ""
if uai_name != "":
    from azure.mgmt.msi import ManagedServiceIdentityClient
    from azure.mgmt.msi.models import Identity

    try:
        msi_client = ManagedServiceIdentityClient(
            subscription_id=subscription_id,
            credential=credential,
        )
        uai_resource = msi_client.user_assigned_identities.get(resource_group, uai_name)
        uai_id = uai_resource.id
        uai_client_id = uai_resource.client_id
    except Exception as ex:
        print("Please run aacs-prepare-uai.ipynb notebook and re-run the cell.")
        raise ex

### 3. Pick a model to deploy

Browse models in the Model Catalog in the AzureML Studio, filtering by the `text-to-image` task. In this example, we use the `runwayml-stable-diffusion-inpainting` model. If you have opened this notebook for a different model, replace the model name accordingly. This is a pre-trained model.

In [None]:
# Name of the inpainting model to be deployed
model_name = "runwayml-stable-diffusion-inpainting"

try:
    model = registry_ml_client.models.get(name=model_name, label="latest")
except Exception as ex:
    print(
        f"No model named {model_name} found in registry. "
        "Please check model name present in Azure model catalog"
    )
    raise ex

print(
    f"\n\nUsing model name: {model.name}, version: {model.version}, id: {model.id} for inpainting."
)

#### 3.1 Register Model in Workspace

The model retrieved above from the `azureml` registry will be registered in your workspace. The registration keeps the original name of the model, assigns a unique version identifier (corresponding to the first field of a UUID), and labels it as the “latest” version. Note that this step takes several minutes.
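
The "first field of the UUID" used as the version is the UUID's 32-bit `time_low` field, so `str(uuid4().fields[0])` is the decimal form of the same 32 bits that appear as the first 8 hex characters of the UUID string (the part used for the AACS and endpoint names in this notebook). A small sketch to make that relationship concrete; the variable names are illustrative.

```python
from uuid import uuid4

u = uuid4()
version = str(u.fields[0])  # decimal form of the first UUID field (32 bits)
first_group = str(u)[:8]    # the same 32 bits as 8 hex characters

# Both views describe the same number.
assert int(first_group, 16) == u.fields[0]
print(f"UUID {u} -> version {version}")
```

Because the value is a 32-bit integer, the version string has at most 10 digits.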

In [None]:
local_model_path = "local_model"

registry_ml_client.models.download(
    name=model.name, version=model.version, download_path=local_model_path
)

In [None]:
from azure.ai.ml.entities import Model
from azure.ai.ml.constants import AssetTypes
import os

local_model = Model(
    path=os.path.join(local_model_path, model.name, "mlflow_model_folder"),
    type=AssetTypes.MLFLOW_MODEL,
    name=model.name,
    version=str(uuid4().fields[0]),
    description="Model created from local file for text to image deployment.",
)

model = workspace_ml_client.models.create_or_update(local_model)

### 4. Deploy the model to an online endpoint for real time inference
Online endpoints give a durable REST API that can be used to integrate with applications that need to use the model.
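
Besides the SDK's `invoke` used in step 5, any HTTP client can call the endpoint once it and its deployment exist. A minimal sketch using only the standard library; `build_scoring_request` is hypothetical, and the scoring URI and key would come from the endpoint (for example `workspace_ml_client.online_endpoints.get(endpoint_name).scoring_uri` and, for key-based auth, `workspace_ml_client.online_endpoints.get_keys(endpoint_name).primary_key`). The `azureml-model-deployment` header routes the call to a specific deployment behind the endpoint.

```python
import json
import urllib.request


def build_scoring_request(
    scoring_uri: str, endpoint_key: str, deployment_name: str, body: dict
) -> urllib.request.Request:
    """Build (but do not send) a POST request to an online endpoint's REST API."""
    return urllib.request.Request(
        scoring_uri,
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {endpoint_key}",
            # Send the request to a specific deployment behind the endpoint.
            "azureml-model-deployment": deployment_name,
        },
        method="POST",
    )


# Placeholder URI and key; nothing is sent until urllib.request.urlopen(request) is called.
request = build_scoring_request(
    "https://example.invalid/score", "<endpoint-key>", "inpainting-deploy", {"input_data": {}}
)
print(request.get_method(), request.full_url)
```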

Create an online endpoint

In [None]:
# Endpoint names must be unique within a region,
# so use the first 8 characters of a UUID to create a unique endpoint name

endpoint_name = f"safe-inpainting-{str(uuid4())[:8]}"  # Replace with your endpoint name
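
Besides being unique in the region, the endpoint name must follow Azure ML's naming rules. A minimal pre-check is sketched below; the pattern (3 to 32 characters, starting with a letter, followed by letters, digits, or hyphens) is an assumption, so verify it against the current Azure ML documentation. `is_valid_endpoint_name` is hypothetical.

```python
import re
from uuid import uuid4

# Assumed naming rule: 3-32 characters, a leading letter, then letters, digits or hyphens.
ENDPOINT_NAME_PATTERN = re.compile(r"[A-Za-z][A-Za-z0-9-]{2,31}")


def is_valid_endpoint_name(name: str) -> bool:
    return ENDPOINT_NAME_PATTERN.fullmatch(name) is not None


candidate = f"safe-inpainting-{str(uuid4())[:8]}"  # same scheme as the cell above
print(candidate, len(candidate), is_valid_endpoint_name(candidate))
```

With this scheme the name is always 24 characters: the 16-character prefix plus 8 hex characters.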

In [None]:
from azure.ai.ml.entities import (
    ManagedOnlineEndpoint,
    IdentityConfiguration,
    ManagedIdentityConfiguration,
)

# Check if the endpoint already exists in the workspace
try:
    endpoint = workspace_ml_client.online_endpoints.get(endpoint_name)
    print("---Endpoint already exists---")
except Exception:
    # Create an online endpoint if it doesn't exist

    # Define the endpoint
    endpoint = ManagedOnlineEndpoint(
        name=endpoint_name,
        description=f"Test endpoint for {model.name}",
        identity=IdentityConfiguration(
            type="user_assigned",
            user_assigned_identities=[ManagedIdentityConfiguration(resource_id=uai_id)],
        )
        if uai_id != ""
        else None,
    )

    # Trigger the endpoint creation
    try:
        workspace_ml_client.begin_create_or_update(endpoint).wait()
        print("\n---Endpoint created successfully---\n")
    except Exception as err:
        raise RuntimeError(
            f"Endpoint creation failed. Detailed Response:\n{err}"
        ) from err

Create a deployment. This step may take several minutes.

In [None]:
# Initialize deployment parameters

deployment_name = "inpainting-deploy"
sku_name = "STANDARD_NC6S_V3"  # Name of the sku(instance type). Check the model card in catalog to get the most optimal sku for model.

REQUEST_TIMEOUT_MS = 90000

deployment_env_vars = {
    "CONTENT_SAFETY_ACCOUNT_NAME": aacs_name,
    "CONTENT_SAFETY_ENDPOINT": aacs_endpoint,
    "CONTENT_SAFETY_KEY": aacs_access_key if uai_client_id == "" else None,
    "CONTENT_SAFETY_THRESHOLD": content_severity_threshold,
    "SUBSCRIPTION_ID": subscription_id,
    "RESOURCE_GROUP_NAME": resource_group,
    "UAI_CLIENT_ID": uai_client_id,
}

In [None]:
from azure.ai.ml.entities import (
    OnlineRequestSettings,
    CodeConfiguration,
    ManagedOnlineDeployment,
    ProbeSettings,
)

code_configuration = CodeConfiguration(
    code="./aacs-scoring-files/score/", scoring_script="score_online.py"
)

deployment = ManagedOnlineDeployment(
    name=deployment_name,
    endpoint_name=endpoint_name,
    model=model.id,
    instance_type=sku_name,
    instance_count=1,
    code_configuration=code_configuration,
    environment_variables=deployment_env_vars,
    request_settings=OnlineRequestSettings(request_timeout_ms=REQUEST_TIMEOUT_MS),
    liveness_probe=ProbeSettings(
        failure_threshold=30,
        success_threshold=1,
        period=100,
        initial_delay=500,
    ),
    readiness_probe=ProbeSettings(
        failure_threshold=30,
        success_threshold=1,
        period=100,
        initial_delay=500,
    ),
)

# Trigger the deployment creation
try:
    workspace_ml_client.begin_create_or_update(deployment).wait()
    print("\n---Deployment created successfully---\n")
except Exception as err:
    raise RuntimeError(
        f"Deployment creation failed. Detailed Response:\n{err}"
    ) from err
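
The long `initial_delay` and high `failure_threshold` in the probe settings above are one reason a deployment can take a while to succeed or fail. The arithmetic is sketched below, assuming Kubernetes-style probe semantics (first probe after `initial_delay` seconds, failure declared after `failure_threshold` consecutive failed probes spaced `period` seconds apart, probe timeouts ignored). `probe_give_up_seconds` is hypothetical and the result is only a rough bound.

```python
def probe_give_up_seconds(initial_delay: int, period: int, failure_threshold: int) -> int:
    """Rough upper bound on how long probes keep retrying before declaring failure."""
    return initial_delay + failure_threshold * period


# Values from the liveness/readiness ProbeSettings above (seconds).
seconds = probe_give_up_seconds(initial_delay=500, period=100, failure_threshold=30)
print(f"about {seconds} s (~{seconds / 60:.0f} min)")  # about 3500 s (~58 min)

# OnlineRequestSettings above: REQUEST_TIMEOUT_MS = 90000, i.e. 90 s per scoring request.
print(90000 / 1000)  # 90.0
```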

### 5. Test the endpoint

We will use a sample image and mask from the `inpainting_data` folder and submit them to the online endpoint for inference.

Sample input schema for the inpainting task:
```json
{
   "input_data": {
        "columns": ["prompt", "image", "mask_image"],
        "data": [
            {
                "prompt": "sample prompt",
                "image": "base image1",
                "mask_image": "mask image1"
            },
            {
                "prompt": "sample prompt",
                "image": "base image2",
                "mask_image": "mask image2"
            }
        ],
        "index": [0, 1]
    }
}
```
> The base and mask image strings (1 and 2) must be base64-encoded or publicly accessible URLs.
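
The schema above can be produced for any number of rows with a small helper. This is a sketch: `build_inpainting_request` is hypothetical, it uses `base64.b64encode` (no line breaks) whereas the cells below use `base64.encodebytes` (which inserts a newline every 76 characters), and it assumes the endpoint accepts either form.

```python
import base64
import json
from typing import List, Tuple


def build_inpainting_request(rows: List[Tuple[str, bytes, bytes]]) -> dict:
    """Build a request body in the input schema shown above.

    Each row is (prompt, base_image_bytes, mask_image_bytes); images are
    base64-encoded and rows are indexed 0..n-1.
    """
    return {
        "input_data": {
            "columns": ["prompt", "image", "mask_image"],
            "data": [
                {
                    "prompt": prompt,
                    "image": base64.b64encode(image).decode("utf-8"),
                    "mask_image": base64.b64encode(mask).decode("utf-8"),
                }
                for prompt, image, mask in rows
            ],
            "index": list(range(len(rows))),
        }
    }


# Placeholder bytes; in practice read the PNG files with open(path, "rb").read().
request = build_inpainting_request(
    [("A cat sitting on a park bench.", b"<image bytes>", b"<mask bytes>")]
)
print(json.dumps(request, indent=2))
```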

Sample output schema for the inpainting task:
```json
[
    {
        "generated_image": "image1",
        "nsfw_content_detected": false
    },
    {
        "generated_image": "image2",
        "nsfw_content_detected": true
    }
]
```
```
> - If `nsfw_content_detected` is `true`, the generated image is completely black.
> - The generated images "image1" and "image2" are base64-encoded strings.
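
A response in this schema can be checked without displaying images. A minimal sketch; `summarize_generations` is hypothetical, and the fake response below uses placeholder bytes instead of real PNG data.

```python
import base64
import json
from typing import List, Tuple


def summarize_generations(response_text: str) -> List[Tuple[bool, int]]:
    """Return (nsfw_content_detected, decoded image size in bytes) per generation.

    Expects the output schema shown above: a JSON list of objects with
    "generated_image" (base64 string) and "nsfw_content_detected" (boolean).
    """
    generations = json.loads(response_text)
    return [
        (bool(g["nsfw_content_detected"]), len(base64.b64decode(g["generated_image"])))
        for g in generations
    ]


# Fake response with placeholder bytes instead of real images.
fake_response = json.dumps(
    [
        {"generated_image": base64.b64encode(b"abc").decode("utf-8"), "nsfw_content_detected": False},
        {"generated_image": base64.b64encode(b"de").decode("utf-8"), "nsfw_content_detected": True},
    ]
)
print(summarize_generations(fake_response))  # [(False, 3), (True, 2)]
```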

#### 5.1 Sample input for a safe prompt

In [None]:
# Create request json
import base64
import json


def read_image(image_path: str) -> bytes:
    """Reads an image from a file path into a byte array."""
    with open(image_path, "rb") as f:
        return f.read()


base_image = "inpainting_data/images/dog_on_bench.png"
mask_image = "inpainting_data/masks/dog_on_bench.png"

request_json = {
    "input_data": {
        "columns": ["image", "mask_image", "prompt"],
        "index": [0],
        "data": [
            {
                "image": base64.encodebytes(read_image(base_image)).decode("utf-8"),
                "mask_image": base64.encodebytes(read_image(mask_image)).decode(
                    "utf-8"
                ),
                "prompt": "A cat sitting on a park bench in high resolution.",
            }
        ],
    }
}

request_file_name = "sample_request_data.json"

with open(request_file_name, "w") as request_file:
    json.dump(request_json, request_file)

In [None]:
# Invoke the endpoint

response = workspace_ml_client.online_endpoints.invoke(
    endpoint_name=endpoint.name,
    deployment_name=deployment.name,
    request_file=request_file_name,
)

In [None]:
# Visualize the model output

import io
import base64
from PIL import Image

generations = json.loads(response)
for generation in generations:
    print("nsfw content detected:", generation["nsfw_content_detected"])
    img = Image.open(io.BytesIO(base64.b64decode(generation["generated_image"])))
    display(img)

#### 5.2 Sample input for an unsafe prompt

In [None]:
# Create request json
import base64
import json


def read_image(image_path):
    with open(image_path, "rb") as f:
        return f.read()


base_image = "inpainting_data/images/dog_on_bench.png"
mask_image = "inpainting_data/masks/dog_on_bench.png"

request_json = {
    "input_data": {
        "columns": ["image", "mask_image", "prompt"],
        "index": [0],
        "data": [
            {
                "image": base64.encodebytes(read_image(base_image)).decode("utf-8"),
                "mask_image": base64.encodebytes(read_image(mask_image)).decode(
                    "utf-8"
                ),
                "prompt": "a dog with severed leg and bleeding profusely from deep laceration to the lower extremities, exposing tissues",
            }
        ],
    }
}

request_file_name = "sample_request_data.json"

with open(request_file_name, "w") as request_file:
    json.dump(request_json, request_file)

In [None]:
# Invoke the endpoint

response = workspace_ml_client.online_endpoints.invoke(
    endpoint_name=endpoint.name,
    deployment_name=deployment.name,
    request_file=request_file_name,
)

In [None]:
# Model response should be empty because it is blocked by the Azure AI Content Safety (AACS) service.
print(response)

### 6. Clean up resources - delete the online endpoint
Don't forget to delete the online endpoint; otherwise the billing meter keeps running for the compute used by the endpoint.

In [None]:
workspace_ml_client.online_endpoints.begin_delete(name=endpoint.name).wait()