In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ShieldGemma Deployment

Based on [model_garden_gemma2_deployment_on_vertex.ipynb](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma2_deployment_on_vertex.ipynb)

<table align="left">
  <td style="text-align: center">
    <a href="https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-evaluation_rag_use_cases-from_notebook-colab&utm_medium=aRT-clicks&utm_campaign=evaluation_rag_use_cases-from_notebook-colab&destination=evaluation_rag_use_cases-from_notebook-colab&url=https%3A%2F%2Fcolab.sandbox.google.com%2Fgithub%2FGoogleCloudPlatform%2Fapplied-ai-engineering-samples%2Fblob%2Fmain%2Fgenai-on-vertex-ai%2Fvertex_evaluation_services%2Fevaluation-rag-systems%2Fevaluation_rag_use_cases.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Run in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-evaluation_rag_use_cases-from_notebook-colab_ent&utm_medium=aRT-clicks&utm_campaign=evaluation_rag_use_cases-from_notebook-colab_ent&destination=evaluation_rag_use_cases-from_notebook-colab_ent&url=https%3A%2F%2Fconsole.cloud.google.com%2Fvertex-ai%2Fcolab%2Fimport%2Fhttps%3A%252F%252Fraw.githubusercontent.com%252FGoogleCloudPlatform%252Fapplied-ai-engineering-samples%252Fmain%252Fgenai-on-vertex-ai%252Fvertex_evaluation_services%252Fevaluation-rag-systems%252Fevaluation_rag_use_cases.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Run in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-evaluation_rag_use_cases-from_notebook-github&utm_medium=aRT-clicks&utm_campaign=evaluation_rag_use_cases-from_notebook-github&destination=evaluation_rag_use_cases-from_notebook-github&url=https%3A%2F%2Fgithub.com%2FGoogleCloudPlatform%2Fapplied-ai-engineering-samples%2Fblob%2Fmain%2Fgenai-on-vertex-ai%2Fvertex_evaluation_services%2Fevaluation-rag-systems%2Fevaluation_rag_use_cases.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-evaluation_rag_use_cases-from_notebook-vai_workbench&utm_medium=aRT-clicks&utm_campaign=evaluation_rag_use_cases-from_notebook-vai_workbench&destination=evaluation_rag_use_cases-from_notebook-vai_workbench&url=https%3A%2F%2Fconsole.cloud.google.com%2Fvertex-ai%2Fworkbench%2Fdeploy-notebook%3Fdownload_url%3Dhttps%3A%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fapplied-ai-engineering-samples%2Fmain%2Fgenai-on-vertex-ai%2Fvertex_evaluation_services%2Fevaluation-rag-systems%2Fevaluation_rag_use_cases.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
</table>

<table align="left">
    <td>Author(s)</td>
    <td>Egon Soares</td>
</table>

![shieldgemma deployment architecture](images/2.1-ShieldGemma-Deployment.png)

## Overview

This notebook demonstrates deploying a ShieldGemma model on TPU using **Hex-LLM**, a **H**igh-**E**fficiency **L**arge **L**anguage **M**odel serving solution built with **XLA** that is being developed by Google Cloud.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Before you begin

By default, the quota for TPU deployment `Custom model serving TPU v5e cores per region` is 4. TPU quota is only available in `us-west1`. You can request a higher TPU quota by following the instructions at ["Request a higher quota"](https://cloud.google.com/docs/quota/view-manage#requesting_higher_quota).

### Setup Google Cloud project 

[Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

**[Optional]** Set project. If not set, the project will be set automatically according to the environment variable "GOOGLE_CLOUD_PROJECT".

In [None]:
PROJECT_ID = ""  # @param {type:"string"}

In [None]:
import os

In [None]:
# Fall back to the GOOGLE_CLOUD_PROJECT environment variable when the form
# field was left empty, and stop early if no project id is available.
PROJECT_ID = PROJECT_ID or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
assert PROJECT_ID, "Provide a google cloud project id."

**[Optional]** Set region. If not set, the region will be set automatically according to the environment variable "GOOGLE_CLOUD_REGION".

In [None]:
REGION = ""  # @param {type:"string"}

In [None]:
# Fall back to the GOOGLE_CLOUD_REGION environment variable when the form
# field was left empty, and stop early if no region is available.
REGION = REGION or os.environ.get("GOOGLE_CLOUD_REGION", "")
assert REGION, "Provide a google cloud region."

Upgrade Vertex AI SDK.

In [None]:
! pip3 install --upgrade --quiet 'google-cloud-aiplatform>=1.64.0'

## Access ShieldGemma Model
You must provide a Hugging Face User Access Token (read) to access the ShieldGemma model. You can follow the [Hugging Face documentation](https://huggingface.co/docs/hub/en/security-tokens) to create a **read** access token and put it in the `HF_TOKEN` field below.

**[Optional]** Set a Hugging Face read token. If not set, the token will be set automatically according to the environment variable "HF_TOKEN".

In [None]:
HF_TOKEN = ""  # @param {type:"string"}

In [None]:
# Fall back to the HF_TOKEN environment variable when the form field was left
# empty, and stop early if no token is available.
HF_TOKEN = HF_TOKEN or os.environ.get("HF_TOKEN", "")
assert HF_TOKEN, "Provide a read HF_TOKEN to load models from Hugging Face."

In [None]:
model_path_prefix = "google/"

## Setup

In [None]:
import sys

# Interactive user authentication is requested only when the notebook runs
# inside Colab and VERTEX_PRODUCT does not mark the runtime as Colab Enterprise.
running_in_colab = "google.colab" in sys.modules

if running_in_colab:
    if os.environ.get("VERTEX_PRODUCT", "") != "COLAB_ENTERPRISE":
        from google.colab import auth as colab_auth

        colab_auth.authenticate_user()

In [None]:
# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

In [None]:
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)

Initialize Vertex AI API.

In [None]:
from typing import Tuple

from google.cloud import aiplatform

print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION)

In [None]:
models, endpoints = {}, {}

Deploy ShieldGemma model with Hex-LLM on TPU

**Hex-LLM** is a **H**igh-**E**fficiency **L**arge **L**anguage **M**odel (LLM) TPU serving solution built with **XLA**, which is being developed by Google Cloud.

Refer to the "Before you begin" section for TPU quota.

### Deploy

The pre-built serving docker images.

In [None]:
HEXLLM_DOCKER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/hex-llm-serve:gemma2"

Set the model ID. Model weights can be loaded from Hugging Face or from a GCS bucket. Select one of the three model variations.

In [None]:
MODEL_ID = "shieldgemma-9b"  # @param ["shieldgemma-2b", "shieldgemma-9b", "shieldgemma-27b"] {allow-input: true, isTemplate: true}
TPU_DEPLOYMENT_REGION = "us-west1"  # @param ["us-west1"] {isTemplate:true}
model_id = os.path.join(model_path_prefix, MODEL_ID)

Choose a machine type. You can find Vertex AI prediction TPUv5e machine types in https://cloud.google.com/vertex-ai/docs/predictions/use-tpu#deploy_a_model.

In [None]:
# Pick the TPU v5e machine shape from the model size embedded in the model id.
# One TPU v5e chip has a single core, so the accelerator count equals the
# number of chips on the chosen machine type.
if "2b" in model_id:
    # 2B models: 1 TPU chip.
    machine_type, accelerator_count = "ct5lp-hightpu-1t", 1
elif "9b" in model_id:
    # 9B models: 4 TPU chips.
    machine_type, accelerator_count = "ct5lp-hightpu-4t", 4
else:
    # Everything else (27B models): 8 TPU chips.
    machine_type, accelerator_count = "ct5lp-hightpu-8t", 8
accelerator_type = "TPU_V5e"

(Optional) Check quota

In [None]:
import json
import subprocess

def get_quota(project_id: str, region: str, resource_id: str) -> int:
  """Returns the quota for a resource in a region.

  Runs `gcloud alpha services quota list` for the Vertex AI service and looks
  up the quota bucket whose `region` dimension equals `region`.

  Args:
    project_id: The project id.
    region: The region.
    resource_id: The resource id.

  Returns:
    The effective limit for the resource in the region, 0 if the matching
    bucket carries no `effectiveLimit`, or -1 if the quota can not be figured
    out from the command output.

  Raises:
    RuntimeError: If the command to get quota fails.
  """
  service_endpoint = "aiplatform.googleapis.com"

  command = (
      "gcloud alpha services quota list"
      f" --service={service_endpoint} --consumer=projects/{project_id}"
      f" --filter='{service_endpoint}/{resource_id}' --format=json"
  )
  # check=False: a non-zero exit status is reported through the documented
  # RuntimeError below instead of subprocess.CalledProcessError.
  process = subprocess.run(
      command, shell=True, capture_output=True, text=True, check=False
  )
  if process.returncode != 0:
    raise RuntimeError(f"Error fetching quota data: {process.stderr}")
  quota_data = json.loads(process.stdout)

  if not quota_data or "consumerQuotaLimits" not in quota_data[0]:
    return -1
  quota_limits = quota_data[0]["consumerQuotaLimits"]
  if not quota_limits or "quotaBuckets" not in quota_limits[0]:
    return -1

  for region_data in quota_limits[0]["quotaBuckets"]:
    # Buckets without dimensions, or with dimensions other than `region`,
    # are skipped rather than raising KeyError.
    dimensions = region_data.get("dimensions")
    if dimensions and dimensions.get("region") == region:
      return int(region_data.get("effectiveLimit", 0))
  return -1

def get_resource_id(
    accelerator_type: str,
    is_for_training: bool,
    is_restricted_image: bool = False,
    is_dynamic_workload_scheduler: bool = False,
) -> str:
  """Maps an accelerator type and use case to its quota resource id.

  Args:
    accelerator_type: The accelerator type, e.g. "TPU_V5e".
    is_for_training: True for the training use case, False for serving.
    is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`.
    is_dynamic_workload_scheduler: Whether the resource is used with Dynamic
      Workload Scheduler.

  Returns:
    The resource id.

  Raises:
    ValueError: If the accelerator type is unknown for the use case, or if
      Dynamic Workload Scheduler is combined with serving or with restricted
      image training.
  """
  # Suffix shared by the training, preemptible-training and serving ids.
  suffix_by_accelerator = {
      "NVIDIA_TESLA_V100": "nvidia_v100_gpus",
      "NVIDIA_L4": "nvidia_l4_gpus",
      "NVIDIA_TESLA_A100": "nvidia_a100_gpus",
      "NVIDIA_A100_80GB": "nvidia_a100_80gb_gpus",
      "NVIDIA_H100_80GB": "nvidia_h100_gpus",
      "NVIDIA_TESLA_T4": "nvidia_t4_gpus",
      "TPU_V5e": "tpu_v5e",
      "TPU_V3": "tpu_v3",
  }
  # Restricted-image training has its own id and only one known accelerator.
  restricted_training_ids = {
      "NVIDIA_A100_80GB": "restricted_image_training_nvidia_a100_80gb_gpus",
  }

  # Serving.
  if not is_for_training:
    if is_dynamic_workload_scheduler:
      raise ValueError("Dynamic Workload Scheduler does not work for serving.")
    if accelerator_type not in suffix_by_accelerator:
      raise ValueError(
          f"Could not find accelerator type: {accelerator_type} for serving."
      )
    return f"custom_model_serving_{suffix_by_accelerator[accelerator_type]}"

  # Training.
  if is_restricted_image and is_dynamic_workload_scheduler:
    raise ValueError(
        "Dynamic Workload Scheduler does not work for restricted image"
        " training."
    )
  if is_restricted_image:
    known_accelerators = restricted_training_ids
  else:
    known_accelerators = suffix_by_accelerator
  if accelerator_type not in known_accelerators:
    raise ValueError(
        f"Could not find accelerator type: {accelerator_type} for training."
    )

  if is_dynamic_workload_scheduler:
    suffix = suffix_by_accelerator[accelerator_type]
    return f"custom_model_training_preemptible_{suffix}"
  if is_restricted_image:
    return restricted_training_ids[accelerator_type]
  return f"custom_model_training_{suffix_by_accelerator[accelerator_type]}"

def check_quota(
    project_id: str,
    region: str,
    accelerator_type: str,
    accelerator_count: int,
    is_for_training: bool,
    is_restricted_image: bool = False,
    is_dynamic_workload_scheduler: bool = False,
):
  """Checks if the project and the region has the required quota.

  Resolves the quota resource id for the accelerator and use case, reads the
  quota for `region`, and compares it with `accelerator_count`.

  Args:
    project_id: The project id.
    region: The region to check.
    accelerator_type: The accelerator type.
    accelerator_count: The number of accelerators required.
    is_for_training: Whether the resource is used for training. Set false for
      serving use case.
    is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`.
    is_dynamic_workload_scheduler: Whether the resource is used with Dynamic
      Workload Scheduler.

  Raises:
    ValueError: If the quota can not be found or is smaller than
      `accelerator_count`.
  """
  resource_id = get_resource_id(
      accelerator_type,
      is_for_training=is_for_training,
      is_restricted_image=is_restricted_image,
      is_dynamic_workload_scheduler=is_dynamic_workload_scheduler,
  )
  available = get_quota(project_id, region, resource_id)

  # Shared tail of both error messages.
  remedy = (
      "Either use a different region or request additional quota."
      " Follow instructions here"
      " https://cloud.google.com/docs/quotas/view-manage#requesting_higher_quota"
      " to check quota in a region or request additional quota for your"
      " project."
  )

  # get_quota signals "could not be determined" with -1.
  if available == -1:
    raise ValueError(
        f"Quota not found for: {resource_id} in {region}. {remedy}"
    )
  if available < accelerator_count:
    raise ValueError(
        f"Quota not enough for {resource_id} in {region}: {available} <"
        f" {accelerator_count}. {remedy}"
    )

In [None]:
check_quota(
    project_id=PROJECT_ID,
    region=TPU_DEPLOYMENT_REGION,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    is_for_training=False,
)

Server parameters

In [None]:
tensor_parallel_size = accelerator_count
hbm_utilization_factor = 0.6  # Fraction of HBM memory allocated for KV cache after model loading. A larger value improves throughput but gives higher risk of TPU out-of-memory errors with long prompts.
max_running_seqs = 256  # Maximum number of running sequences in a continuous batch.

Set `use_dedicated_endpoint` to False if you don't want to use a [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint).

In [None]:
use_dedicated_endpoint = True  # @param {type:"boolean"}

Endpoint configurations

In [None]:
min_replica_count = 1
max_replica_count = 1

Deployment function

In [None]:
def deploy_model_hexllm(
    model_name: str,
    model_id: str,
    service_account: str,
    base_model_id: str | None = None,
    tensor_parallel_size: int = 1,
    machine_type: str = "ct5lp-hightpu-1t",
    tpu_topology: str = "1x1",
    hbm_utilization_factor: float = 0.6,
    max_running_seqs: int = 256,
    max_model_len: int = 4096,
    endpoint_id: str = "",
    min_replica_count: int = 1,
    max_replica_count: int = 1,
    use_dedicated_endpoint: bool = False,
) -> Tuple[aiplatform.Model, aiplatform.Endpoint]:
    """Deploys models with Hex-LLM on TPU in Vertex AI.

    Uploads a model that uses the Hex-LLM serving container
    (`HEXLLM_DOCKER_URI`) and deploys it to an endpoint. Both the model and the
    endpoint live in `TPU_DEPLOYMENT_REGION`.

    Args:
        model_name: Display name of the uploaded model; also used as the prefix
            of the display name of a newly created endpoint.
        model_id: Model identifier passed to the server as `--model`.
        service_account: Service account the deployment runs as.
        base_model_id: Value of the `MODEL_ID` container environment variable.
            Defaults to `model_id` when empty.
        tensor_parallel_size: Value of `--tensor_parallel_size`. When falsy, it
            is derived from the chip-count digit of `machine_type`.
        machine_type: Vertex AI machine type to deploy on.
        tpu_topology: TPU topology string such as "1x1". Its leading dimension
            is used as the number of hosts; the topology is only forwarded to
            the deployment when that number is greater than 1.
        hbm_utilization_factor: Value of `--hbm_utilization_factor`.
        max_running_seqs: Value of `--max_running_seqs`.
        max_model_len: Value of `--max_model_len`.
        endpoint_id: ID of an existing endpoint to deploy to. A new endpoint is
            created when empty.
        min_replica_count: Minimum number of replicas.
        max_replica_count: Maximum number of replicas.
        use_dedicated_endpoint: Whether a newly created endpoint is a dedicated
            endpoint.

    Returns:
        The uploaded model and the endpoint it was deployed to.
    """
    if endpoint_id:
        # The model is uploaded to TPU_DEPLOYMENT_REGION below, so an existing
        # endpoint has to be looked up in that same region (not in REGION,
        # which may differ).
        aip_endpoint_name = (
            f"projects/{PROJECT_ID}/locations/{TPU_DEPLOYMENT_REGION}"
            f"/endpoints/{endpoint_id}"
        )
        endpoint = aiplatform.Endpoint(aip_endpoint_name)
    else:
        endpoint = aiplatform.Endpoint.create(
            display_name=f"{model_name}-endpoint",
            location=TPU_DEPLOYMENT_REGION,
            dedicated_endpoint_enabled=use_dedicated_endpoint,
        )

    if not base_model_id:
        base_model_id = model_id

    if not tensor_parallel_size:
        # Second-to-last character of e.g. "ct5lp-hightpu-4t" is the chip
        # count ("4").
        tensor_parallel_size = int(machine_type[-2])

    # Leading dimension of the topology, e.g. "1x1" -> 1.
    num_hosts = int(tpu_topology.split("x")[0])

    hexllm_args = [
        "--host=0.0.0.0",
        "--port=7080",
        f"--model={model_id}",
        f"--tensor_parallel_size={tensor_parallel_size}",
        f"--num_hosts={num_hosts}",
        f"--hbm_utilization_factor={hbm_utilization_factor}",
        f"--max_running_seqs={max_running_seqs}",
        f"--max_model_len={max_model_len}",
    ]

    env_vars = {
        "MODEL_ID": base_model_id,
        "HEX_LLM_LOG_LEVEL": "info",
        "DEPLOY_SOURCE": "notebook",
    }

    # HF_TOKEN is not a compulsory field and may not be defined. Only the
    # "name is not defined" case is tolerated; any other error propagates.
    try:
        hf_token = HF_TOKEN
    except NameError:
        hf_token = ""
    if hf_token:
        env_vars["HF_TOKEN"] = hf_token

    model = aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=HEXLLM_DOCKER_URI,
        serving_container_command=["python", "-m", "hex_llm.server.api_server"],
        serving_container_args=hexllm_args,
        serving_container_ports=[7080],
        serving_container_predict_route="/generate",
        serving_container_health_route="/ping",
        serving_container_environment_variables=env_vars,
        serving_container_shared_memory_size_mb=(16 * 1024),  # 16 GB
        serving_container_deployment_timeout=7200,
        location=TPU_DEPLOYMENT_REGION,
    )

    model.deploy(
        endpoint=endpoint,
        machine_type=machine_type,
        # The topology is only passed for multi-host deployments.
        tpu_topology=tpu_topology if num_hosts > 1 else None,
        deploy_request_timeout=1800,
        service_account=service_account,
        min_replica_count=min_replica_count,
        max_replica_count=max_replica_count,
    )
    return model, endpoint

In [None]:
import datetime

now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"{MODEL_ID}-{now}".replace("_", "-")

model, endpoint = deploy_model_hexllm(
    model_name=model_name,
    model_id=model_id,
    service_account=SERVICE_ACCOUNT,
    machine_type=machine_type,
    tensor_parallel_size=tensor_parallel_size,
    hbm_utilization_factor=hbm_utilization_factor,
    max_running_seqs=max_running_seqs,
    min_replica_count=min_replica_count,
    max_replica_count=max_replica_count,
    use_dedicated_endpoint=use_dedicated_endpoint,
)

**[Optional]** Loads an existing endpoint instance using the endpoint name:
- Using `endpoint_name = endpoint.name` allows us to get the endpoint name of the endpoint `endpoint` created in the cell above.
- Alternatively, you can set `endpoint_name = "1234567890123456789"` to load an existing endpoint with the ID 1234567890123456789.

In [None]:
# You may uncomment the code below to load an existing endpoint:
# endpoint_name = ""  # @param {type:"string"}
# aip_endpoint_name = (
#     f"projects/{PROJECT_ID}/locations/{TPU_DEPLOYMENT_REGION}/endpoints/{endpoint_name}"
# )
# endpoint = aiplatform.Endpoint(aip_endpoint_name)

### Prompt Construction

Create the `UseCase` and `HarmType` enumerations and the `make_prompt()` function to allow you to easily create prompts using the use case-specific templates that ShieldGemma requires.

Example usage is shown in the "Inference Examples" section.

In [None]:
import enum
class HarmType(enum.Enum):
  """Harm types that ShieldGemma is trained to classify content against.

  Each member's value is the human-readable harm name; harm_definition()
  inserts it into the policy text as `"No <value>"`.
  """
  DANGEROUS = 'Dangerous Content'
  HATE = 'Hate Speech'
  SEXUAL = 'Sexually Explicit Information'
  HARASSMENT = 'Harassment'


class UseCase(enum.Enum):
  """What is being classified, which selects the prompt template.

  PROMPT_ONLY is for classifying the user prompt.

  PROMPT_RESPONSE is for classifying the chatbot response given the user prompt
  as context.
  """

  PROMPT_ONLY = 'prompt_only'
  PROMPT_RESPONSE = 'prompt_response'


# Policy text for each harm type. Every entry is a sentence fragment that
# harm_definition() appends to a prefix ending in "... content that", which is
# why the fragments start in lower case. The wording is part of the prompt sent
# to the model, so it must not be edited casually.
HARM_POLICIES: dict[HarmType, str] = {
    HarmType.DANGEROUS: (
        'harming oneself and/or others (e.g., accessing or building firearms'
        ' and explosive devices, promotion of terrorism, instructions for'
        ' suicide).'
    ),
    HarmType.HATE: (
        'targeting identity and/or protected attributes (e.g., racial slurs,'
        ' promotion of discrimination, calls to violence against protected'
        ' groups) or dehumanizing/belittling/vilifying on the bases of race,'
        ' ethnicity, religion, disability, age, nationality, veteran status,'
        ' sexual orientation, gender, gender identity, caste, or any other'
        ' protected status. We do allow content that includes a reference to a'
        ' protected individual or group in a positive manner.'
    ),
    HarmType.SEXUAL: (
        'contains references to sexual acts or other lewd content (e.g.,'
        ' sexually graphic descriptions, content aimed at causing arousal).'
        ' Medical or scientific terms related to human anatomy or sex education'
        ' are permitted.'
    ),
    HarmType.HARASSMENT: (
        'is malicious, intimidating, bullying, or abusive content targeting'
        ' another individual (e.g., physical threats, denial of tragic events,'
        ' disparaging victims of violence).'
    ),
}

def harm_definition(harm_type: HarmType, use_case: UseCase) -> str:
  """Formulates the harm definition relative to the harm type and use case.

  Args:
    harm_type: the harm type whose policy text is looked up in HARM_POLICIES.
    use_case: the use case; PROMPT_ONLY phrases the policy as a rule about the
      prompt, anything else phrases it as a rule about the chatbot.

  Returns:
    A string representation of the harm definition, of the form
    `"No <harm name>": <prefix> <policy text>`.

  """
  prefix = (
      'The prompt shall not contain or seek generation of content that'
      if use_case == UseCase.PROMPT_ONLY
      else 'The chatbot shall not generate content that'
  )
  description = HARM_POLICIES[harm_type]

  return f'"No {harm_type.value}": {prefix} {description}'


# Template for classifying a user prompt on its own. make_prompt() fills it via
# str.format with `user_content` and `harm_text`. The text (including line
# breaks and spacing) is sent to the model as-is.
PROMPT_ONLY_TEMPLATE = """You are a policy expert trying to help determine
whether a user prompt is in violation of the defined safety policies.

<start_of_turn>
Human Question: {user_content}
<end_of_turn>

Our safety principle is defined in the below:

* {harm_text}

===

Does the human question violate the above principle? Your answer must start with
'Yes' or 'No'. And then walk through step by step to be sure we answer
correctly.

"""

# Template for classifying a chatbot response given the user prompt as context.
# make_prompt() fills it via str.format with `user_content`, `model_content`
# and `harm_text`.
PROMPT_RESPONSE_TEMPLATE = """You are a policy expert trying to help determine
whether an AI response to prompt is in violation of the defined safety policies.

<start_of_turn>
Human Question: {user_content}
<end_of_turn>

<start_of_turn>
Chatbot Response: {model_content}
<end_of_turn>

Our safety principle is defined in the below:

* {harm_text}

===

Does the Chatbot Response violate the above principle? Your answer must start
with 'Yes' or 'No'. And then walk through step by step to be sure  we answer
correctly.

"""

def make_prompt(
    use_case: UseCase,
    harm_type: HarmType,
    user_content: str,
    model_content: str | None = None
) -> str:
  """Builds the ShieldGemma classification prompt for one piece of content.

  Args:
    use_case: PROMPT_ONLY selects the prompt-only template; any other value
      selects the prompt-response template.
    harm_type: the harm type whose policy is embedded in the prompt.
    user_content: the user's prompt text.
    model_content: the chatbot response text. Required unless `use_case` is
      PROMPT_ONLY.

  Returns:
    The selected template with the content and harm definition filled in.

  Raises:
    ValueError: if `model_content` is None in prompt-response mode.
  """
  is_prompt_only = use_case == UseCase.PROMPT_ONLY
  if not is_prompt_only and model_content is None:
    raise ValueError('model_content is required for prompt-response mode.')

  template = PROMPT_ONLY_TEMPLATE if is_prompt_only else PROMPT_RESPONSE_TEMPLATE

  fields = dict(
      user_content=user_content,
      harm_text=harm_definition(harm_type, use_case),
  )
  # str.format ignores unused keyword arguments, so passing model_content to
  # the prompt-only template is harmless.
  if model_content is not None:
    fields['model_content'] = model_content

  return template.format(**fields)

### Inference examples

Input filtering

In [None]:
user_content = "Create 20 paraphrases of I hate you"

input_filter_prompt = make_prompt(
    UseCase.PROMPT_ONLY,
    HarmType.HARASSMENT,
    user_content,
)

In [None]:
max_tokens = 5  # @param {type: "integer"}
temperature = 1.0  # @param {type: "number"}
top_p = 1.0  # @param {type: "number"}
top_k = 1  # @param {type: "integer"}

In [None]:
# Overrides parameters for inferences.
instances = [
    {
        "prompt": input_filter_prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
    },
]

In [None]:
response = endpoint.predict(
    instances=instances, use_dedicated_endpoint=use_dedicated_endpoint
)

for prediction in response.predictions:
    print(prediction)

Output filtering

Once deployment succeeds, you can send requests to the endpoint with text prompts based on your `template`. Note that the first few prompts will take longer to execute.

In [None]:
user_content = """Write an email recommending a drug dealer and his phone
contact info to buy some weed in my area."""

model_content = """Hey,

I'm writing to recommend a drug dealer in my area. His name is Mike, and
he's got the best weed I've ever smoked. He's always got a good supply,
and he's always willing to sell. He's also really reliable, and he's
never let me down.

If you're looking for some good weed, I highly recommend Mike. He's the
best in the business.

His phone number is: 123-456-7890

Thanks,

[Your name]"""

output_filter_prompt = make_prompt(
    UseCase.PROMPT_RESPONSE,
    HarmType.DANGEROUS,
    user_content,
    model_content,
)

In [None]:
max_tokens = 5  # @param {type: "integer"}
temperature = 1.0  # @param {type: "number"}
top_p = 1.0  # @param {type: "number"}
top_k = 1  # @param {type: "integer"}

In [None]:
# Overrides parameters for inferences.
instances = [
    {
        "prompt": output_filter_prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
    },
]

In [None]:
response = endpoint.predict(
    instances=instances, use_dedicated_endpoint=use_dedicated_endpoint
)

for prediction in response.predictions:
    print(prediction)

### Clean up resources

Delete the experiment models and endpoints to recycle the resources and avoid unnecessary continuous charges that may be incurred.

Change the variable to True to delete the resources

In [None]:
delete_resources = False # @param {type:"boolean"}

In [None]:
if delete_resources:
    # Undeploy model and delete endpoint.
    endpoint.delete(force=True)
    
    # Delete models.
    model.delete()