# Arthur Shield Test Harness

The code below tests the PII, Prompt Injection, Sensitivity, Toxicity, and Hallucination detection capabilities of Arthur Shield.

### Pre-Requisites: 
- You will need to configure which checks you would like to perform within Arthur Shield. See https://shield.docs.arthur.ai/docs/rule-configuration-guide
- For non-hallucination checks:
  - You will need to compile a CSV file containing test prompts for each of the rules you would like to evaluate. You can reference the main template CSV in this folder for formatting, and use it to create your own list of prompts. The columns your input CSV should include are as follows (case-sensitive): 
    * *id* - A unique ID for the prompt
    * *prompt* - The text you are sending to Shield
    * *flag* - The flag you expect this prompt to trigger. Must be exactly one of the following values: pii, prompt_injection, sensitivity, toxicity, control
- For hallucination checks: 
  - You will need to compile a separate CSV file from the one above containing test prompts for hallucination rules.
  - For each example, you will provide both a prompt (e.g. a question about an employee manual) and context (e.g. a paragraph from the employee manual). You should include two types of prompts:
      - *Test Prompts*: Factual questions which can be directly answered by the information provided in the context
      - *Control Prompts*: Factual questions requesting information that is unrelated to the information provided in the context (i.e. the context does not provide the answer to the question)
  - You can reference the hallucination template CSV in this folder for formatting, and use it to create your own list of prompts. The columns your input CSV should include are as follows (case-sensitive):
    * *id* - A unique ID for the prompt
    * *prompt* - The prompt you are sending to the LLM (should be a question related to the context, or an unrelated control question)
    * *context* - The context that the LLM should reference to answer the prompt question (or unrelated if control)
    * *context_answers_prompt* - Whether the answer to the question can be found in the prompt context or not. Bool value ("True" or "False")
- Aim for at least 15 prompts (preferably more), including controls, for each rule you would like to test, so that the analysis results are meaningful
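
The column and flag requirements for the main input CSV can be checked before any Shield calls are made. The following is a minimal sketch, assuming only the column names and flag values listed above; the function name `validate_main_input` and the sample rows are hypothetical.

```python
import csv
import io

# Column names and allowed flag values are taken from the prerequisites above.
REQUIRED_COLUMNS = {"id", "prompt", "flag"}
ALLOWED_FLAGS = {"pii", "prompt_injection", "sensitivity", "toxicity", "control"}


def validate_main_input(csv_text):
    """Return a list of human-readable problems found in the main input CSV text."""
    reader = csv.DictReader(io.StringIO(csv_text))
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        return [f"missing columns: {sorted(missing)}"]
    problems = []
    seen_ids = set()
    # Row numbers count data rows only (the header is not counted).
    for row_no, row in enumerate(reader, start=1):
        if row["id"] in seen_ids:
            problems.append(f"row {row_no}: duplicate id {row['id']!r}")
        seen_ids.add(row["id"])
        # Flags are case-sensitive, so "PII" is rejected.
        if row["flag"] not in ALLOWED_FLAGS:
            problems.append(f"row {row_no}: unknown flag {row['flag']!r}")
    return problems


sample = "id,prompt,flag\n1,What is the weather today?,control\n2,My SSN is 123-45-6789,PII\n"
print(validate_main_input(sample))  # ["row 2: unknown flag 'PII'"]
```

To check a file on disk, pass the result of `open(path, newline='').read()` to the function.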

### Output:
This harness will output the following CSVs:
- Test output files (one for the non-hallucination checks and one for the hallucination checks) containing info on the prompt, Shield flags, and latency
- Metrics file containing various evaluation performance metrics for each of the Shield rules


### Configuration
Fill out the config below with the details specific to your Shield instance

In [None]:
import csv
import json
import pandas as pd
import math
import requests 
from datetime import datetime

SHIELD_VAL_ENDPOINT = "<your-shield-endpoint>/api/v2/validate_prompt"
SHIELD_API_KEY = "<insert-shield-api-key>"

MAIN_INPUT_FILE = 'shield_main_input.csv' # Change this to whatever your input file name is
MAIN_OUTPUT_FILE = 'shield_main_output.csv'
METRICS_FILE = 'shield_test_metrics.csv'

shield_headers = {
    'Authorization': f'Bearer {SHIELD_API_KEY}'
}


## i. Non-Hallucination Checks

### Support Functions

Shield's response schema can be found in the API docs: https://shield.jpmc-poc.arthur.ai/docs#/Default%20Validation/default_validate_prompt_api_v2_validate_prompt_post

In [None]:
def get_firewall_flags(resp_dict):
    # Determines which rules (PII, Prompt Injection, Toxicity, Sensitivity) were flagged
    flags = []
    for rule in resp_dict["rule_results"]:
        if rule["result"] == "Fail":
            if rule["rule_type"] == "ModelHallucinationRuleV2":
                # Only used for responses - does not check for other flags
                flags.append("hallucination")
                break
            elif rule["rule_type"] == "PIIDataRule" and "pii" not in flags:
                flags.append("pii")
            elif rule["rule_type"] == "PromptInjectionRule" and "prompt_injection" not in flags:
                flags.append("prompt_injection")
            elif rule["rule_type"] == "ToxicityRule" and "toxicity" not in flags:
                flags.append("toxicity")
            elif rule["rule_type"] == "ModelSensitiveDataRule" and "sensitivity" not in flags:
                flags.append("sensitivity")
    return flags

def get_gt_output_match(row):
    # Returns True if the ground-truth label matches the Shield flags and False otherwise
    if row["ground_truth"] == "control":
        # Control prompts pass only when Shield raised no flags
        return not row["shield_response"]
    return row["ground_truth"] in row["shield_response"]
    

def get_pii_triggers(shield_response):
    # Isolate the text which triggered the PII rule to fire
    output = []
    for rule in shield_response["rule_results"]:
        if rule["rule_type"]== "PIIDataRule" and rule["result"]=="Fail":
            for flag in rule["details"]["pii_entities"]:
                output.append({
                    "string_trigger": flag["span"],
                    "flag_type": flag["entity"]
                })
    return output

def get_shield_response(row):
    # Calls Shield API and updates the row dictionary based upon the results from Shield
    shield_start = datetime.now()
    response = requests.post(SHIELD_VAL_ENDPOINT, headers=shield_headers, json={"prompt": row["prompt"]})
    shield_end = datetime.now()
    # total_seconds() covers the whole interval; .microseconds alone drops any whole seconds
    row["latency_ms"] = int((shield_end - shield_start).total_seconds() * 1000)
    row["shield_response"] = get_firewall_flags(response.json())
    row["gt_output_match"] = get_gt_output_match(row)
    row["sub_flags"] = get_pii_triggers(response.json()) if "pii" in row["shield_response"] else None
    return row
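
The flag-mapping logic can be exercised without calling Shield. The sketch below runs a compact, dictionary-based equivalent of the prompt-rule branch of `get_firewall_flags` against a hand-written response. The response dictionary is a hypothetical example reduced to the fields the functions above read (`rule_results`, `rule_type`, `result`); real Shield responses contain more fields, and the name `flags_from_response` is an assumption made for illustration.

```python
# Rule type names are the ones checked by get_firewall_flags above.
RULE_TYPE_TO_FLAG = {
    "PIIDataRule": "pii",
    "PromptInjectionRule": "prompt_injection",
    "ToxicityRule": "toxicity",
    "ModelSensitiveDataRule": "sensitivity",
}


def flags_from_response(resp_dict):
    """Return the de-duplicated flags for every failed prompt rule, in order."""
    flags = []
    for rule in resp_dict.get("rule_results", []):
        flag = RULE_TYPE_TO_FLAG.get(rule.get("rule_type"))
        if rule.get("result") == "Fail" and flag and flag not in flags:
            flags.append(flag)
    return flags


# Hypothetical response, reduced to the fields this notebook reads.
mock_response = {
    "inference_id": "00000000-0000-0000-0000-000000000000",
    "rule_results": [
        {"rule_type": "PIIDataRule", "result": "Fail"},
        {"rule_type": "ToxicityRule", "result": "Pass"},
        {"rule_type": "PIIDataRule", "result": "Fail"},
    ],
}

print(flags_from_response(mock_response))  # ['pii']
```

A control prompt is expected to produce an empty list, which is what `get_gt_output_match` treats as a pass.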


### Generate Output File

The output file generated will contain: 
  * *id* - The ID of the corresponding input file prompt
  * *prompt* - The prompt that was passed
  * *ground_truth* - The flag that was expected
  * *shield_response* - The list of flags that Shield raised
  * *gt_output_match* - Whether the ground truth label matches the flags that Shield raised (True or False)
  * *sub_flags* - (For PII) Which text spans were flagged and which PII entity type was raised for each
  * *latency_ms* - Latency of Shield call in ms
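
Note that `csv.DictWriter` converts non-string values with `str()`, so list-valued columns such as *shield_response* are stored as their Python representation (for example `['pii']`), and `None` is stored as an empty string. The following is a minimal sketch of reading such cells back into Python objects; the helper name `parse_cell` is hypothetical, and the in-memory buffer stands in for the output file.

```python
import ast
import csv
import io


def parse_cell(value):
    """Turn a stringified list from the CSV back into a list ('' becomes None)."""
    return ast.literal_eval(value) if value else None


# Round-trip a row shaped like the harness output through an in-memory CSV.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=["id", "shield_response", "sub_flags"])
writer.writeheader()
writer.writerow({"id": "1", "shield_response": ["pii"], "sub_flags": None})

buffer.seek(0)
row = next(csv.DictReader(buffer))
print(repr(row["shield_response"]))        # "['pii']"  (a string, not a list)
print(parse_cell(row["shield_response"]))  # ['pii']
print(parse_cell(row["sub_flags"]))        # None
```

This is also why the analysis section can test for a flag with a plain substring check on the *shield_response* column.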

In [None]:
with open(MAIN_INPUT_FILE, newline='') as infile:
    reader= csv.DictReader(infile)
    with open(MAIN_OUTPUT_FILE, 'w', newline='') as outfile:
        # List of columns for the output file
        fieldnames = [
            "id",
            "prompt",
            "ground_truth",
            "shield_response",
            "gt_output_match",
            "sub_flags",
            "latency_ms"
        ]
        writer= csv.DictWriter(outfile, fieldnames=fieldnames)
        writer.writeheader()
        
        for row in reader:
            out_row = {
                "id": row["id"],
                "ground_truth": row["flag"],
                "prompt": row["prompt"]
                }

            out_row = get_shield_response(out_row)
            writer.writerow(out_row)
            

## ii. Hallucination Checks

In order to perform hallucination checks, you will need to connect to an LLM. Below is a simple API connection; if your LLM requires a different setup, you may have to modify this block to fit your setup.

In [None]:
SHIELD_RESP_ENDPOINT = "<your-shield-endpoint>/api/v2/validate_response/"

LLM_ENDPOINT = "<llm-endpoint>"
LLM_API_KEY = "<llm-api-key>"

HALLUCINATION_INPUT = 'shield_hallucination_input_template.csv' # Change this to whatever your input file name is 
HALLUCINATION_OUTPUT = 'shield_hallucination_output.csv'

#Note - you may need to customize how the LLM connection works per your setup
llm_headers = {
    'Authorization': f'Bearer {LLM_API_KEY}',
    'Content-Type': 'application/json'
}

### Support Functions

In [None]:
def get_claims(response):
    # Extracts the claims from the Shield hallucination rule result
    for rule in response["rule_results"]:
        if rule["rule_type"] != "ModelHallucinationRuleV2":
            continue
        if rule["result"] == "Pass":
            return []
        return rule["details"]["claims"]
    # No hallucination rule was present in the response
    return []

def get_inference_id(prompt):
    # Conduct an initial Shield call to get an inference ID for the response endpoint
    response = (requests.post(
        SHIELD_VAL_ENDPOINT, 
        headers=shield_headers, 
        json={"prompt": prompt})).json()
    return str(response["inference_id"])

## NOTE: May need to re-write the below function based upon your LLM setup. This function is 
## compatible with OpenAI's API
def get_llm_response(prompt, context):
    # Get the LLM response to the initial prompt, passing along the context provided
    # Note: This data JSON format is specific to OpenAI; may need modification for other LLM providers
    data = {
        "model": "gpt-4o",
        "messages": [
            {"role": "system", "content": context},
            {"role": "user", "content": prompt}
        ]
    }

    response = requests.post(LLM_ENDPOINT, headers=llm_headers, json=data)
    
    if response.status_code == 200:
        return response.json()['choices'][0]['message']['content']
    else:
        raise Exception(f"Error: {response.status_code}, {response.text}")    

        
def get_hallucination_shield_response(row):
    # Calls Shield API and updates the row dictionary based upon the results from Shield
    resp_endpoint = SHIELD_RESP_ENDPOINT + get_inference_id(row["prompt"])
    row["llm_response"] = get_llm_response(row["prompt"], row["context"])
    shield_start = datetime.now()
    response = requests.post(
        resp_endpoint, 
        headers=shield_headers, 
        json={"response": row["llm_response"], "context": row["context"]}
    )
    shield_end = datetime.now()
    # total_seconds() covers the whole interval; .microseconds alone drops any whole seconds
    row["latency_ms"] = int((shield_end - shield_start).total_seconds() * 1000)
    row["shield_response"] = get_firewall_flags(response.json())
    row["valid_response"] = ""
    row["claims"] = get_claims(response.json())
    return row
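
The Shield calls above are timed with wall-clock timestamps. As an alternative, here is a minimal sketch using `time.perf_counter`, a monotonic clock from the standard library intended for measuring intervals. The helper name `timed_call` is hypothetical, and `time.sleep` stands in for the network request.

```python
import time


def timed_call(func, *args, **kwargs):
    """Call func and return (result, elapsed time in whole milliseconds)."""
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed_ms = int((time.perf_counter() - start) * 1000)
    return result, elapsed_ms


# Stand-in for a slow network call; in the harness this would wrap requests.post.
result, latency_ms = timed_call(time.sleep, 0.05)
print(latency_ms)
```

In the harness, the same helper could wrap the `requests.post` call so that only the request itself is included in the measurement.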
    

### Generate Output File

In [None]:
with open(HALLUCINATION_INPUT, newline='') as infile:
    reader= csv.DictReader(infile)
    with open(HALLUCINATION_OUTPUT, 'w', newline='') as outfile:
        # List of columns for the output file
        fieldnames = [
            "id",
            "prompt",
            "context",
            "context_answers_prompt",
            "llm_response",
            "shield_response",
            "valid_response",
            "claims",
            "latency_ms"
        ]
        writer= csv.DictWriter(outfile, fieldnames=fieldnames)
        writer.writeheader()
        
        for row in reader:
            out_row = {
                "id": row["id"],
                "context_answers_prompt": row["context_answers_prompt"],
                "prompt": row["prompt"],
                "context": row["context"]
                }

            out_row = get_hallucination_shield_response(out_row)
            writer.writerow(out_row)

**NOTE: In order to run analysis on the hallucination output, you will need to manually classify each output example as valid or invalid.**

Under the "valid_response" column of the output file, you will assess each statement's results to determine whether or not the Shield response is valid. You will enter "True" or "False" (case sensitive). Here's a guide for classifying the responses:
- valid_response = True if:
    - The "shield_response" flagged the statement as a hallucination, and the "llm_response" did indeed hallucinate when answering the question (i.e. LLM responded with information not included in the context).
    - The "shield_response" did not flag the statement as a hallucination, and the "llm_response" either a) did not attempt to answer the question, or b) correctly answered the question with information it retrieved from the context
- valid_response = False if:
    - The "shield_response" flagged the statement as a hallucination, but the "llm_response" either a) correctly responded to the question using the provided context or b) did not answer the question and did not provide any information not included in the context
    - The "shield_response" did not flag the statement as a hallucination, but the "llm_response" attempted to answer the question with information that was not included in the provided context
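
Combined with whether Shield flagged the response, the manual classification determines which confusion-matrix cell each row contributes to in the analysis section. The following is a minimal sketch of that mapping, matching how the analysis code counts hallucination rows; the function name `confusion_cell` is hypothetical.

```python
def confusion_cell(flagged, valid_response):
    """Map (Shield flagged?, reviewer's valid_response) to tp / fp / fn / tn."""
    if flagged:
        # Flagged and correct is a true positive; flagged and wrong is a false positive.
        return "tp" if valid_response else "fp"
    # Not flagged and correct is a true negative; not flagged and wrong is a false negative.
    return "tn" if valid_response else "fn"


for flagged in (True, False):
    for valid in (True, False):
        print(f"flagged={flagged}, valid_response={valid} -> {confusion_cell(flagged, valid)}")
```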

## iii. Analysis

The metrics file generated will contain:
  * True Positive Count
  * False Positive Count
  * Precision
  * Recall
  * Specificity
  * Miss Rate
  * False Positive Rate
  * F1 Score
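
For reference, the rate metrics listed above follow from the four confusion-matrix counts. The following is a minimal sketch of the standard definitions, returning `None` where a denominator is zero (the harness itself writes "N/A" in that case); the names `safe_div` and `compute_metrics` and the example counts are hypothetical.

```python
def safe_div(numerator, denominator):
    """Return numerator / denominator, or None when the denominator is zero."""
    return numerator / denominator if denominator else None


def compute_metrics(tp, fp, fn, tn):
    """Compute the rate metrics from true/false positive/negative counts."""
    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    f1 = None
    if precision is not None and recall is not None and (precision + recall) > 0:
        # F1 is the harmonic mean of precision and recall.
        f1 = 2 * precision * recall / (precision + recall)
    return {
        "precision": precision,
        "recall": recall,
        "specificity": safe_div(tn, tn + fp),
        "miss_rate": safe_div(fn, fn + tp),
        "false_positive_rate": safe_div(fp, fp + tn),
        "f1": f1,
    }


example = compute_metrics(tp=12, fp=3, fn=3, tn=12)
for name, value in example.items():
    print(name, round(value, 3))
```

Note that miss rate is 1 minus recall and false positive rate is 1 minus specificity, so each pair carries the same information.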

In [None]:
def round_row(data_row, precision=3):
    return [round(item, precision) if isinstance(item, (float, int)) else item for item in data_row]

def create_metrics_dict(metric_name):
    if metric_name == "hallucination":
        hdf= pd.read_csv(HALLUCINATION_OUTPUT)
        metrics = {
            "tp": hdf[(hdf["valid_response"]==True) & (hdf['shield_response'].apply(lambda x: metric_name in x))].shape[0],
            "fp": hdf[(hdf["valid_response"]==False) & (hdf['shield_response'].apply(lambda x: metric_name in x))].shape[0],
            "fn": hdf[(hdf["valid_response"]==False) & (hdf['shield_response'].apply(lambda x: metric_name not in x))].shape[0],
            "tn": hdf[(hdf["valid_response"]==True) & (hdf['shield_response'].apply(lambda x: metric_name not in x))].shape[0]
        }
    else:
        df = pd.read_csv(MAIN_OUTPUT_FILE)
        metrics = {
            "tp": df[(df['gt_output_match']==True) & (df['ground_truth']==metric_name)].shape[0],
            "fp": df[(df['ground_truth']!=metric_name) & (df['shield_response'].apply(lambda x: metric_name in x))].shape[0],
            "fn": df[(df['ground_truth']==metric_name) & (df['gt_output_match']==False)].shape[0],
            "tn": df[(df['ground_truth']!=metric_name) & (df['shield_response'].apply(lambda x: metric_name not in x))].shape[0]
        }
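    # Precision and recall are undefined when their denominators are zero
    try:
        metrics["prec"] = metrics["tp"]/(metrics["tp"]+metrics["fp"])
    except ZeroDivisionError:
        metrics["prec"] = "N/A"
    try:
        metrics["recall"] = metrics["tp"]/(metrics["tp"]+metrics["fn"])
    except ZeroDivisionError:
        metrics["recall"] = "N/A"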
    return metrics

def run_analysis(checks=("pii", "prompt_injection", "toxicity", "sensitivity", "hallucination")):
    # Runs analysis on the selected checks. Runs on all by default
    checks = list(checks)  # Immutable default above; a list is needed for the header row below
    with open(METRICS_FILE, 'w', newline='') as outfile:
        writer = csv.writer(outfile)
        metrics = []
        for check in checks:
            metrics.append(create_metrics_dict(check))
        writer.writerow([""]+checks)
    
        # True Positive Count
        tpc= ["True Positive Count"]
        for metric in metrics:
            tpc.append(metric["tp"])
        writer.writerow(round_row(tpc))
    
        # False Positive Count
        fpc= ["False Positive Count"]
        for metric in metrics:
            fpc.append(metric["fp"])
        writer.writerow(round_row(fpc))
                        
        # Precision
        prec = ["Precision"]
        for metric in metrics:
            prec.append(metric["prec"])
        writer.writerow(round_row(prec))
        
        # Recall
        recall = ["Recall"]
        for metric in metrics:
            recall.append(metric["recall"])
        writer.writerow(round_row(recall))

        # Specificity
        spec = ["Specificity"]
        for metric in metrics:
            try:
                spec.append(metric["tn"]/(metric["tn"]+metric["fp"]))
            except ZeroDivisionError:
                spec.append("N/A")
        writer.writerow(round_row(spec))
    
        # Miss Rate
        miss = ["Miss Rate"]
        for metric in metrics:
            try:
                miss.append(metric["fn"]/(metric["fn"]+metric["tp"]))
            except ZeroDivisionError:
                miss.append("N/A")
        writer.writerow(round_row(miss))
        
        # False Positive Rate
        fpr = ["False Positive Rate"]
        for metric in metrics:
            try:
                fpr.append(metric["fp"]/(metric["fp"]+metric["tn"]))
            except ZeroDivisionError:
                fpr.append("N/A")
        writer.writerow(round_row(fpr))
    
        # F1 Score
        f1 = ["F1 Score"]
        for metric in metrics:
            try:
                f1.append((2*metric["prec"]*metric["recall"])/(metric["prec"]+metric["recall"]))
            except (TypeError, ZeroDivisionError):  # precision or recall is "N/A", or both are 0
                f1.append("N/A")
        writer.writerow(round_row(f1))


The cell below will run analysis on all of the flags included in the list (the list can contain "pii", "prompt_injection", "toxicity", "sensitivity", and/or "hallucination"). **If any of these flags are not enabled in the Shield instance you are evaluating, omit them from the input list to avoid generating errors.**

In [None]:
run_analysis([
    "pii",
    "prompt_injection",
    "toxicity",
    "sensitivity",
    "hallucination"
])