# How to create an Azure AI Content Safety enabled LLaMA online endpoint
### This notebook walks you through the steps to create an __Azure AI Content Safety__ enabled __LLaMA__ online endpoint.
### The steps are:
1. Create an __Azure AI Content Safety__ resource to moderate both the user's request and the response from the __LLaMA__ online endpoint.
2. Create a new __LLaMA__ online endpoint.
3. Create a new __Azure AI Content Safety__ enabled __LLaMA__ online endpoint with a custom score.py, which integrates with the __Azure AI Content Safety__ resource to moderate the request from the user and the response from the __LLaMA__ model. For the custom score.py to successfully authenticate to the __Azure AI Content Safety__ resource, we have 2 options:
    1. __UAI__ (recommended, but more complex): create a __User Assigned Identity (UAI)__ and assign the appropriate roles to it. The custom score.py can then obtain an access token for the __UAI__ from the AAD server and use it to access the Azure AI Content Safety resource.
    2. __Environment variable__ (simpler, but less secure): pass the access key of the __Azure AI Content Safety__ resource to the custom score.py through an environment variable, so that score.py can use the key directly to access the Azure AI Content Safety resource. This option is less secure than the first: anyone in your org with access to the endpoint can read the access key from the environment variable and use it to access the Azure AI Content Safety resource.
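
The choice between the two options comes down to which HTTP header score.py sends to Azure AI Content Safety. The following is a minimal sketch, not the notebook's actual implementation: `build_aacs_auth_headers` and its `get_uai_token` parameter are hypothetical names introduced for illustration, while the environment variable names `UAI_CLIENT_ID` and `AACS_ACCESS_KEY` match the ones used later in this notebook.

```python
import os


def build_aacs_auth_headers(env=None, get_uai_token=None):
    """Return the auth header for an Azure AI Content Safety request.

    Hypothetical helper: prefers the UAI (AAD bearer token) when
    UAI_CLIENT_ID is set, otherwise falls back to the access key.
    """
    env = os.environ if env is None else env

    uai_client_id = env.get("UAI_CLIENT_ID")
    if uai_client_id:
        if get_uai_token is None:
            raise ValueError("get_uai_token is required when UAI_CLIENT_ID is set")
        # Option 1: AAD token obtained on behalf of the UAI
        return {"Authorization": f"Bearer {get_uai_token(uai_client_id)}"}

    access_key = env.get("AACS_ACCESS_KEY")
    if not access_key:
        raise ValueError("Neither UAI_CLIENT_ID nor AACS_ACCESS_KEY is set")
    # Option 2: resource access key read from the environment
    return {"Ocp-Apim-Subscription-Key": access_key}
```

In a real score.py, `get_uai_token` would wrap a call such as `ManagedIdentityCredential(client_id=...).get_token(...)`; it is injected here so the selection logic can be exercised without Azure access.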
  

### 1. Prerequisites
#### 1.1 Check List:
- [x] You have created a new Python virtual environment for this notebook.
- [x] The identity you use to run this notebook (yourself or your VM) needs the __Contributor__ role on the resource group that contains the AML Workspace you specified, because this notebook creates an Azure AI Content Safety resource using that identity.
- [x] Required only if you choose the UAI approach: the identity running this notebook (yourself or your VM) needs the __Owner__ role on the resource group that contains the specified AML Workspace. This is because the notebook creates a new UAI and assigns it the roles required to successfully create the Azure AI Content Safety enabled LLaMA endpoint.

#### 1.2 Install Dependencies

In [None]:
%pip install azure-identity==1.13.0
%pip install azure-mgmt-cognitiveservices==13.4.0
%pip install azure-ai-ml==1.8.0
%pip install azure-mgmt-msi==7.0.0
%pip install azure-mgmt-authorization==3.0.0

#### 1.3 Assign variables for the workspace and deployment

In [None]:
# NOTE: Update following workspace information to contain
#       your subscription ID, resource group name, and workspace name
subscription_id = "c830bb7a-83f5-45e3-81fc-3c2053e7d16f"
resource_group = "bowgong-test"
workspace_name = "bowgong-aml-test"

#### 1.4 Decide on a name for your Content Safety enabled LLaMA online endpoint

In [None]:
import random
rand = random.randint(0, 10000)

endpoint_name = f"safetyllama{rand}" # the final endpoint name of the safety enabled llama endpoint
# with the given name, there will be 4 resources created in the resource group of your AML workspace:
# 1. an Azure AI Content Safety resource: {endpoint_name}-aacs
# 2. a LLaMA online endpoint: {endpoint_name}-llama
# 3. a UAI (User Assigned Identity): {endpoint_name}-uai
# 4. an Azure AI Content Safety enabled LLaMA online endpoint, which you will use to do your AI work: {endpoint_name}
aacs_name = f"{endpoint_name}-aacs"
uai_name = f"{endpoint_name}-uai"
llama_endpoint_name = f"{endpoint_name}-llama"
print("going to create the following resources:")
print(f"-  Azure AI Content Safety resource: {aacs_name}")
print(f"-  LLaMA online endpoint: {llama_endpoint_name}")
print(f"-  UAI: {uai_name}")
print(f"-  Azure AI Content Safety enabled LLaMA online endpoint: {endpoint_name}")


### 2. Connect to your AML Workspace

In [None]:
import os, json
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential

try:
    credential = DefaultAzureCredential()
    # Check if given credential can get token successfully.
    credential.get_token("https://management.azure.com/.default")
except Exception as ex:
    # Fall back to InteractiveBrowserCredential in case DefaultAzureCredential does not work
    credential = InteractiveBrowserCredential()

try:
    ml_client = MLClient(credential=credential, subscription_id=subscription_id, resource_group_name=resource_group, workspace_name=workspace_name)
except Exception as ex:
    client_config = {
        "subscription_id": subscription_id,
        "resource_group": resource_group,
        "workspace_name": workspace_name,
    }
    # write and reload from config file
    config_path = "./config.json"
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    with open(config_path, "w") as fo:
        fo.write(json.dumps(client_config))
    ml_client = MLClient.from_config(credential=credential, path=config_path)
    
workspace_location = ml_client.workspaces.get(ml_client.workspace_name).location
workspace_resource_id = ml_client.workspaces.get(ml_client.workspace_name).id
subscription_id = ml_client.subscription_id
resource_group_name = ml_client.resource_group_name
workspace_name = ml_client.workspace_name
print(f"Connected to workspace {workspace_resource_id}")
print(f"Workspace location is {workspace_location}") 

### 4. Create Azure AI Content Safety

#### 4.1 Choose a region for your Azure AI Content Safety
Currently, Azure AI Content Safety is only available in the following regions:
- East US
- West Europe

__NOTE__: Before you choose the region in which to deploy Azure AI Content Safety, be aware that your data will be transferred to that region. By selecting a region outside your current location, you may be allowing your data to be transmitted to regions outside your jurisdiction. Data protection and privacy laws may vary between jurisdictions. Before proceeding, we strongly advise you to familiarize yourself with the local laws and regulations governing data transfer and to ensure that you are legally permitted to transmit your data to an overseas location for processing. By continuing with the selection of a different region, you acknowledge that you have understood and accepted any potential risks associated with such data transmission. Please proceed with caution.

In [None]:
# location for the Azure AI Content Safety resource
available_acs_locations = ['east us', 'west europe']
aacs_location = available_acs_locations[0] 

print(f"will create Azure AI Content Safety `{aacs_name}` in {aacs_location}")

In [None]:
from azure.mgmt.cognitiveservices import CognitiveServicesManagementClient
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.mgmt.cognitiveservices.models import Account, Sku, AccountProperties
import time

try:
    credential = DefaultAzureCredential()
    # Check if given credential can get token successfully.
    credential.get_token("https://management.azure.com/.default")
except Exception as ex:
    # Fall back to InteractiveBrowserCredential in case DefaultAzureCredential does not work
    credential = InteractiveBrowserCredential()

client = CognitiveServicesManagementClient(credential, subscription_id)

# create a new Cognitive Services Account
kind = "ContentSafety"
sku_name = "S0"
parameters = Account(sku=Sku(name=sku_name), kind=kind, location=aacs_location, properties= AccountProperties(custom_sub_domain_name=aacs_name, public_network_access="Enabled"))
# How many seconds to wait between checking the status of an async operation.
wait_time = 10

poller = client.accounts.begin_create(resource_group_name, aacs_name, parameters)
while not poller.done():
    print(f"Waiting {wait_time} seconds for operation to finish.")
    time.sleep(wait_time)

# This will raise an exception if the server responded with an error.
result = poller.result()


print("Resource created.")

aacs=client.accounts.get(resource_group_name, aacs_name)
aacs_endpoint = aacs.properties.endpoint
aacs_resource_id = aacs.id
print(f"AACS endpoint is {aacs_endpoint}")
print(f"AACS ResourceId is {aacs_resource_id}")

aacs_access_key = client.accounts.list_keys(resource_group_name=resource_group_name, account_name=aacs_name).key1
# Do not print the key itself: notebook output is saved together with the notebook.
print("AACS access key retrieved.")

### 5. Create LLaMA online endpoint

#### 5.1 Decide on SKU and instance count for the LLaMA online endpoint.

In [None]:
compute_sku_for_llama="Standard_DS5_v2" # the sku of the compute instance for LLaMA endpoint
compute_instance_count_for_llama=1 # the number of compute instance
llama_endpoint_name=f"{endpoint_name}-llama"
print(f"Will create LLaMA endpoint {llama_endpoint_name} using {compute_instance_count_for_llama} {compute_sku_for_llama} compute instance(s)")

#### 5.2 Check if LLaMA model is available in the AML registry.

In [None]:
available_llama_models_pre_trained = ["Llama-2-7b", "Llama-2-13b"]
available_llama_models_fine_tuned = ["Llama-2-7b-chat", "Llama-2-13b-chat"]

model_name = "gpt2" # TODO(mingtwan) change to LLaMA

registry_ml_client = MLClient(credential, registry_name="azureml")
version_list = list(registry_ml_client.models.list(model_name)) # list available versions of the model
foundation_model = None
if len(version_list) == 0:
    print("Model not found in registry")
else:
    model_version = version_list[0].version
    foundation_model = registry_ml_client.models.get(model_name, model_version)
    print(
        f"Using model name: {foundation_model.name}, version: {foundation_model.version}, id: {foundation_model.id} for inferencing"
    )

#### 5.3 Create LLaMA online endpoint
This step may take a few minutes.

In [None]:
from azure.ai.ml.entities import ManagedOnlineEndpoint, ManagedOnlineDeployment, OnlineRequestSettings

auth_mode_for_llama = "key" # please DO NOT change this. 
# create an online endpoint
llama_endpoint = ManagedOnlineEndpoint(
        name=llama_endpoint_name,
        description="Online endpoint for LLaMA",
        auth_mode=auth_mode_for_llama,
    )
ml_client.begin_create_or_update(llama_endpoint).result()

llama_deployment_name="demo"
llama_deployment = ManagedOnlineDeployment(
    name=llama_deployment_name,
    endpoint_name=llama_endpoint_name,
    model=foundation_model.id,
    instance_type=compute_sku_for_llama,
    instance_count=compute_instance_count_for_llama,
    request_settings=OnlineRequestSettings(
        request_timeout_ms=60000,
    )
)
ml_client.online_deployments.begin_create_or_update(llama_deployment).wait()
# route 100% of the traffic to this deployment and wait for the update to finish
llama_endpoint.traffic = {llama_deployment_name: 100}
ml_client.online_endpoints.begin_create_or_update(llama_endpoint).result()

llama_endpoint = ml_client.online_endpoints.get(name=llama_endpoint_name)
llama_endpoint_id = llama_endpoint.id
llama_score_uri = llama_endpoint.scoring_uri
print(f"LLaMA endpoint scoring uri is {llama_score_uri}")
llama_access_key = ml_client.online_endpoints.get_keys(name=llama_endpoint_name).primary_key
# Do not print the key itself: notebook output is saved together with the notebook.
print("LLaMA endpoint access key retrieved.")

### 6. Create `score.py` for the Azure AI Content Safety enabled LLaMA endpoint


#### 6.1 Create a folder to save the score.py and conda dependencies file.
First create a source folder for the score.py file and conda dependencies file:

In [None]:
import os

scoring_src_dir = "./safety-llama"
os.makedirs(scoring_src_dir, exist_ok=True)
print(f"Scoring script directory: {scoring_src_dir}")

#### 6.2 Create the score.py
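
The `run()` function in the score.py written below is still a placeholder that returns a fixed string. As a rough sketch of the moderation flow described at the top of this notebook (moderate the request, call LLaMA, moderate the response), something like the following could be used. Everything here is an assumption for illustration: `moderated_run`, the injected `analyze_text` and `call_llama` callables, the `max_severity` threshold of 2, and the shape of the returned JSON are hypothetical and are not part of the Azure SDKs or of this notebook's score.py.

```python
import json


def moderated_run(raw_data, analyze_text, call_llama, max_severity=2):
    """Moderate the request, call the model, then moderate the response.

    analyze_text(text) -> int   hypothetical: highest severity found in text
    call_llama(prompt) -> str   hypothetical: completion from the LLaMA endpoint
    """
    prompt = json.loads(raw_data)["data"]

    # 1. Reject harmful requests before they reach the model
    if analyze_text(prompt) > max_severity:
        return json.dumps({"blocked": "request"})

    # 2. Forward the request to the LLaMA endpoint
    completion = call_llama(prompt)

    # 3. Reject harmful completions before they reach the user
    if analyze_text(completion) > max_severity:
        return json.dumps({"blocked": "response"})

    return json.dumps({"text": completion})
```

Injecting the two callables keeps the control flow testable locally; in a deployed score.py they would wrap the HTTP calls to the Azure AI Content Safety resource and to the LLaMA scoring URI.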

In [None]:
%%writefile {scoring_src_dir}/score.py
import logging
import json
from azure.identity import ManagedIdentityCredential
from azure.ai.ml import MLClient
import os

# environment variable names
env_key_of_aacs_endpoint = "AACS_ENDPOINT"
env_key_of_uai_id = "UAI_CLIENT_ID" # if provided, the script uses the UAI to obtain the access key of the LLaMA online endpoint, and uses the UAI's AAD token to authenticate to the AACS resource directly.
env_key_of_aacs_key = "AACS_ACCESS_KEY" # if UAI_CLIENT_ID is not provided, the script falls back to the access key of the AACS resource.
env_key_of_llama_score_uri = "LLAMA_SCORE_URI"
env_key_of_subscription_id = "SUBSCRIPTION_ID"
env_key_of_resource_group_name = "RESOURCE_GROUP_NAME"
env_key_of_workspace_name = "WORKSPACE_NAME"
env_key_of_llama_endpoint_name = "LLAMA_ENDPOINT_NAME"


def init():
    """
    This function is called when the container is initialized/started, typically after create/update of the deployment.
    You can write the logic here to perform init operations like caching the model in memory
    """
    aacs_endpoint = os.environ.get(env_key_of_aacs_endpoint)
    llama_score_uri = os.environ.get(env_key_of_llama_score_uri)
    uai_id = os.environ.get(env_key_of_uai_id) 
    
    logging.info("AACS endpoint: %s", aacs_endpoint)
    logging.info("LLaMA score uri: %s", llama_score_uri)
    logging.info("UAI ID: %s", uai_id)
    logging.info("Init complete")

def run(raw_data):
    """
    This function is called for every invocation of the endpoint to perform the actual scoring/prediction.
    In this example we only extract the data from the JSON input and return a fixed placeholder response;
    the moderation and LLaMA calls are not implemented here yet.
    """
    data = json.loads(raw_data)["data"]
    logging.info("Request processed")
    return '{"text":"Hello World"}'

def _get_aad_token_for_aacs():
    """
    Get an AAD access token for Azure AI Content Safety using the UAI
    """
    credential = ManagedIdentityCredential(client_id=os.environ.get(env_key_of_uai_id))
    aacs_token = credential.get_token("https://cognitiveservices.azure.com/.default") # get token for AACS
    return aacs_token.token

def _get_llama_access_key():
    """
    Helper function to get the access key for LLaMA endpoint
    """
    credential = ManagedIdentityCredential(client_id=os.environ.get(env_key_of_uai_id))
    subscription_id = os.environ.get(env_key_of_subscription_id)
    resource_group =  os.environ.get(env_key_of_resource_group_name)
    workspace_name = os.environ.get(env_key_of_workspace_name)

    ml_client = MLClient(
        credential,
        subscription_id=subscription_id,
        resource_group_name=resource_group,
        workspace_name=workspace_name,
    )
    llama_endpoint_name = os.environ.get(env_key_of_llama_endpoint_name)
    if not llama_endpoint_name:
        raise ValueError("LLaMA endpoint name is not provided")
    
    llama_access_key = ml_client.online_endpoints.get_keys(name=llama_endpoint_name).primary_key
    return llama_access_key

#### 6.3 Create the conda.yaml

In [None]:
%%writefile {scoring_src_dir}/conda.yaml
name: aacs-conda
channels:
  - defaults
dependencies:
  - python=3.9
  - pip:
    - azure-identity==1.13.0
    - azure-ai-ml==1.8.0
    - azureml-inference-server-http==0.8.4

### 7. Create the Azure AI Content Safety enabled LLaMA online endpoint
Before creating the Azure AI Content Safety enabled LLaMA online endpoint, decide which approach the custom score.py will use to authenticate to the Azure AI Content Safety resource. If you choose the UAI approach, complete steps 7.1 and 7.2. If you choose the environment variable approach, skip to step 7.3.

#### 7.1 Create a Managed Identity for the Azure AI Content Safety enabled LLaMA endpoint
NOTE: Azure AI Content Safety supports AAD token based authentication by default, so we create a new UAI for the Azure AI Content Safety enabled LLaMA endpoint, which lets the endpoint access Azure AI Content Safety with an AAD token.

##### 7.1.1 Get a handle to the ManagedServiceIdentityClient

In [None]:
from azure.mgmt.msi import ManagedServiceIdentityClient
from azure.mgmt.msi.models import Identity

msi_client = ManagedServiceIdentityClient(
    subscription_id=subscription_id,
    credential=credential,
)

##### 7.1.2 Create the User Assigned Identity:

In [None]:
uai_name = f"{endpoint_name}-uai"
print(f"Will create UAI(User Assigned Identity) {uai_name} for the Azure AI Content Safety enabled LLaMA endpoint.")

msi_client.user_assigned_identities.create_or_update(
    resource_group_name=resource_group_name,
    resource_name=uai_name,
    parameters=Identity(location=workspace_location),
)

##### 7.1.3 Retrieve the identity object
We need to retrieve the identity object so that we can use it to deploy the Azure AI Content Safety enabled LLaMA online endpoint.

In [None]:
uai_identity = msi_client.user_assigned_identities.get(
    resource_group_name=resource_group_name,
    resource_name=uai_name,
)
uai_principal_id = uai_identity.principal_id
uai_client_id = uai_identity.client_id
uai_id = uai_identity.id
print(f"UAI principal id: {uai_principal_id}")
print(f"UAI client id: {uai_client_id}")
print(f"UAI id: {uai_id}")

#### 7.2 Grant appropriate roles to the UAI we created above.
Note: To run the scripts in this step successfully, you must have the Owner role on the AACS resource and the LLaMA endpoint that we created in the previous steps.

##### 7.2.1 Get an AuthorizationManagementClient to list Role Definitions

In [None]:
from azure.mgmt.authorization import AuthorizationManagementClient
from azure.mgmt.authorization.v2020_10_01_preview.models import RoleAssignmentCreateParameters
import uuid

role_definition_client = AuthorizationManagementClient(
    credential=credential,
    subscription_id=subscription_id,
    api_version="2018-01-01-preview",
)

role_assignment_client = AuthorizationManagementClient(
    credential=credential,
    subscription_id=subscription_id,
    api_version="2020-10-01-preview",
)

uai_role_check_list = {
    "Cognitive Services User": {"step": "7.2.2", "description":"assign the role Cognitive Services User to the UAI on the Azure AI Content Safety resource."},
    "AzureML Data Scientist": {"step": "7.2.3", "description":"assign the role AzureML Data Scientist to the UAI on the workspace."},
    "AcrPull": {"step": "7.2.4", "description":"assign the role AcrPull to the UAI on the Azure Container Registry."},
    "Storage Blob Data Reader": {"step": "7.2.5", "description":"assign the role Storage Blob Data Reader to the UAI on the Azure Storage account."},
}

##### 7.2.2 Grant the user identity access to the Azure Content Safety resource
Cognitive Services User role is required to access the Azure Content Safety resource.
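
The role assignment cells in this section look up the role definition with `next(...)`, which raises a bare `StopIteration` when no role with that name exists at the given scope. A small lookup helper with a clearer error could be sketched as follows. `find_role_definition` is a hypothetical name, not part of the Azure SDK, and the `SimpleNamespace` objects only stand in for the role definition objects returned by `role_definitions.list(...)`.

```python
from types import SimpleNamespace


def find_role_definition(role_defs, role_name):
    """Return the first role definition whose role_name matches.

    Hypothetical helper: raises LookupError with a readable message
    instead of letting next() raise StopIteration.
    """
    for role_def in role_defs:
        if role_def.role_name == role_name:
            return role_def
    raise LookupError(f"Role definition {role_name!r} not found at this scope")


# Stand-in objects for illustration; real ones come from the Azure SDK.
sample_defs = [
    SimpleNamespace(role_name="AcrPull", id="role-1"),
    SimpleNamespace(role_name="Cognitive Services User", id="role-2"),
]
role_def = find_role_definition(sample_defs, "Cognitive Services User")
```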

In [None]:
role_name = "Cognitive Services User" # minimum role required for accessing AACS
scope = aacs_resource_id

role_defs = role_definition_client.role_definitions.list(scope=scope)
role_def = next((r for r in role_defs if r.role_name == role_name))

from azure.core.exceptions import ResourceExistsError
try:
    role_assignment_client.role_assignments.create(
        scope=scope,
        role_assignment_name=str(uuid.uuid4()),
        parameters=RoleAssignmentCreateParameters(
            role_definition_id=role_def.id,
            principal_id=uai_principal_id,
            principal_type="ServicePrincipal",
        ),
    )
except ResourceExistsError as ex:
    pass
except Exception as ex:
    print(ex)
    raise ex

if role_name in uai_role_check_list:
    del uai_role_check_list[role_name] 
print(f"Role assignment for {role_name} at the Azure AI Content Safety resource level completed.")

##### 7.2.3 Grant the user identity access to the LLaMA online endpoint.
To retrieve the key of the LLaMA online endpoint, the UAI must have the `AzureML Data Scientist` role at the workspace level.

In [None]:
role_name = "AzureML Data Scientist"
scope = workspace_resource_id

role_defs = role_definition_client.role_definitions.list(scope=scope)
role_def = next((r for r in role_defs if r.role_name == role_name))
from azure.core.exceptions import ResourceExistsError
try:
    role_assignment_client.role_assignments.create(
        scope=scope,
        role_assignment_name=str(uuid.uuid4()),
        parameters=RoleAssignmentCreateParameters(
            role_definition_id=role_def.id,
            principal_id=uai_principal_id,
            principal_type="ServicePrincipal",
        ),
    )
except ResourceExistsError as ex:
    pass
except Exception as ex:
    print(ex)
    raise ex

if role_name in uai_role_check_list:
    del uai_role_check_list[role_name] 
print(f"Role assignment for {role_name} at the workspace level completed.")

##### 7.2.4 Assign AcrPull at the workspace container registry scope
Because the Content Safety enabled LLaMA endpoint is created with a User Assigned Identity, that identity must have the Storage Blob Data Reader role on the workspace storage account and the AcrPull role on the workspace Azure Container Registry (ACR). Make sure your User Assigned Identity has both permissions.

In [None]:
workspace = ml_client.workspaces.get(workspace_name)
container_registry = workspace.container_registry

role_name = "AcrPull"

role_defs = role_definition_client.role_definitions.list(scope=container_registry)
role_def = next((r for r in role_defs if r.role_name == role_name))

from azure.core.exceptions import ResourceExistsError
try:
    role_assignment_client.role_assignments.create(
        scope=container_registry,
        role_assignment_name=str(uuid.uuid4()),
        parameters=RoleAssignmentCreateParameters(
            role_definition_id=role_def.id,
            principal_id=uai_principal_id,
            principal_type="ServicePrincipal",
        ),
    ) 
except ResourceExistsError as ex:
    pass
except Exception as ex:
    print(ex)
    raise ex

if role_name in uai_role_check_list:
    del uai_role_check_list[role_name] 
print("Role assignment for AcrPull at the workspace container registry completed.")

##### 7.2.5 Assign Storage Blob Data Reader at the workspace storage account scope

In [None]:
role_name = "Storage Blob Data Reader"
blob_scope = workspace.storage_account

role_defs = role_definition_client.role_definitions.list(scope=blob_scope)
role_def = next((r for r in role_defs if r.role_name == role_name))

from azure.core.exceptions import ResourceExistsError
try:
    role_assignment_client.role_assignments.create(
        scope=blob_scope,
        role_assignment_name=str(uuid.uuid4()),
        parameters=RoleAssignmentCreateParameters(
            role_definition_id=role_def.id,
            principal_id=uai_principal_id,
            principal_type="ServicePrincipal",
        ),
    )
except ResourceExistsError as ex:
    pass
except Exception as ex:
    print(ex)
    raise ex

if role_name in uai_role_check_list:
    del uai_role_check_list[role_name]  
print("Role assignment for `Storage Blob Data Reader` at the workspace storage account completed.")

#### 7.3 Create Content Safety enabled LLaMA online endpoint using above score.py

##### 7.3.1 Decide on SKU and instance count for the Content Safety enabled LLaMA online endpoint.
TODO: Add more details about SKU and instance count recommendations.

In [None]:
compute_sku_for_safety_proxy = "Standard_DS5_v2"
compute_count = 1

safety_llama_auth_mode = "key" # currently, "key" and "aml_token" are supported
aacs_authentication_mode = "UAI" # "UAI" or "EnvVar"; with "EnvVar", the access keys of AACS and the LLaMA online endpoint are passed to score.py as environment variables.

##### 7.3.2 Create the Safety-Enabled LLaMA Online Endpoint
This step may take a few minutes.

__Before we proceed to the next step, let's make sure we didn't miss anything in the previous steps. Please execute the following script to check:__

In [None]:
# Check everything is properly done before creating the Azure AI Content Safety Enabled LLaMA online endpoint
missing_steps = []
if aacs_authentication_mode == "UAI":
    print("You selected UAI to deploy the Azure AI Content Safety Enabled LLaMA online endpoint, checking if the UAI has the required roles assigned...")
    if uai_role_check_list:
        for key, value in uai_role_check_list.items():
            missing_steps.append(f'Please go to step {value["step"]} to {value["description"]}')

if missing_steps:
    print("It seems you missed some steps above.")
    steps = "\n".join(missing_steps)
    raise Exception(f"Please complete the missing steps before proceeding:\n{steps}")
else:
    print("All steps are completed, proceeding to create the Azure AI Content Safety Enabled LLaMA online endpoint...")

In [None]:
# environment variables that will be used in the scoring script
# environment variable names
env_key_of_aacs_endpoint = "AACS_ENDPOINT"
env_key_of_uai_id = "UAI_CLIENT_ID" # if provided, the script uses the UAI to obtain the access key of the LLaMA online endpoint, and uses the UAI's AAD token to authenticate to the AACS resource directly.
env_key_of_aacs_key = "AACS_ACCESS_KEY" # if UAI_CLIENT_ID is not provided, the script falls back to the access key of the AACS resource.
env_key_of_llama_key = "LLAMA_ACCESS_KEY" # if UAI_CLIENT_ID is not provided, the script falls back to the access key of the LLaMA online endpoint.
env_key_of_llama_score_uri = "LLAMA_SCORE_URI"
env_key_of_subscription_id = "SUBSCRIPTION_ID"
env_key_of_resource_group_name = "RESOURCE_GROUP_NAME"
env_key_of_workspace_name = "WORKSPACE_NAME"
env_key_of_llama_endpoint_name = "LLAMA_ENDPOINT_NAME"

from azure.ai.ml.entities import (
    ManagedOnlineDeployment,
    ManagedOnlineEndpoint,
    CodeConfiguration,
    Environment,
    ManagedIdentityConfiguration,
    IdentityConfiguration
)

if not aacs_endpoint:
    raise Exception("AACS Endpoint is not valid.")
else:
    print(f"AACS Endpoint: {aacs_endpoint}")

environment_variables = { 
    env_key_of_aacs_endpoint: aacs_endpoint,
    env_key_of_llama_score_uri: llama_score_uri,
    env_key_of_subscription_id: subscription_id,
    env_key_of_resource_group_name: resource_group_name,
    env_key_of_workspace_name: workspace_name,
    env_key_of_llama_endpoint_name: llama_endpoint_name
}

uai_identity_config = None
if aacs_authentication_mode == "UAI" and uai_client_id:
    # UAI approach: pass the client ID to score.py and attach the identity to the endpoint
    environment_variables[env_key_of_uai_id] = uai_client_id
    uai_identity_config = IdentityConfiguration(
            type="user_assigned",
            user_assigned_identities=[
                ManagedIdentityConfiguration(resource_id=uai_id)
            ],
        )
else:
    print("UAI is not used. The script will fall back to the access keys of the AACS resource and the LLaMA online endpoint.")
    # use environment variables to pass access keys
    environment_variables[env_key_of_aacs_key] = aacs_access_key
    environment_variables[env_key_of_llama_key] = llama_access_key

deployment = ManagedOnlineDeployment(
        name="blue",
        endpoint_name=endpoint_name,
        code_configuration=CodeConfiguration(
            code=f"{scoring_src_dir}", scoring_script="score.py"
        ),
        environment=Environment(
            conda_file=f"{scoring_src_dir}/conda.yaml",
            image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest",
        ),
        instance_type=compute_sku_for_safety_proxy,
        instance_count=compute_count,
        environment_variables=environment_variables,
    )


endpoint = ManagedOnlineEndpoint(
        name=endpoint_name,
        description="Azure AI Content Safety enabled LLaMA online endpoint",
        auth_mode=safety_llama_auth_mode,
        identity=uai_identity_config,
    )
# create online endpoint
ml_client.online_endpoints.begin_create_or_update(endpoint).result()

endpoint = ml_client.online_endpoints.get(endpoint_name)
print(endpoint.identity.type)
print(endpoint.identity.user_assigned_identities)

# create deployment
ml_client.online_deployments.begin_create_or_update(deployment).result()
# check status
deployment = ml_client.online_deployments.get(
        endpoint_name=endpoint_name, name=deployment.name
    )
print(deployment)
# Set traffic to 100% for deployment
endpoint.traffic = {str(deployment.name): 100}
ml_client.begin_create_or_update(endpoint).result()


### 8. Test the Safety Enabled LLaMA online endpoint.

#### 8.1 Test endpoint with normal sample request
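
The cells in this section invoke the endpoint through `ml_client`. For clients that do not use the Azure ML SDK, a key-authenticated online endpoint can also be called over plain HTTPS by sending the key as a bearer token. The sketch below only builds the request object; `build_scoring_request` is a hypothetical helper name, and the URI and key in the example are placeholders, not real values.

```python
import json
import urllib.request


def build_scoring_request(scoring_uri, access_key, payload):
    """Build a POST request for a key-authenticated online endpoint.

    Hypothetical helper: the payload is serialized as JSON and the
    endpoint key is sent in the Authorization header.
    """
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        scoring_uri,
        data=body,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {access_key}",
        },
        method="POST",
    )


# Placeholder values for illustration only
request = build_scoring_request(
    "https://example.invalid/score", "<endpoint-key>", {"data": "Hello World"}
)
# To actually send it: urllib.request.urlopen(request).read()
```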

In [None]:
import os

test_src_dir = "./safety-llama-test"
os.makedirs(test_src_dir, exist_ok=True)
print(f"test script directory: {test_src_dir}")

In [None]:
%%writefile {test_src_dir}/sample-request.json
{"data": "Hello World"}

In [None]:
sample_data = f"{test_src_dir}/sample-request.json"
ml_client.online_endpoints.invoke(endpoint_name=endpoint_name, request_file=sample_data)

#### 8.2 Test endpoint with harmful sample request

In [None]:
%%writefile {test_src_dir}/sample-harmful-request.json
{"data": "I wanna kill you!"}

In [None]:
sample_data = f"{test_src_dir}/sample-harmful-request.json"
ml_client.online_endpoints.invoke(endpoint_name=endpoint_name, request_file=sample_data)