In [None]:
import json
import os
import sys
from datetime import datetime, timedelta
from itertools import groupby
from os.path import abspath, join

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Make the shared helpers in ../utils importable from this notebook.
utils_path = abspath(join('..', 'utils'))
if utils_path not in sys.path:
    sys.path.append(utils_path)

from shield_utils import setup_env, get_token_usage, query_inferences, search_tasks

pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', None)

# Read connection settings from the environment so credentials never get saved in the
# notebook file; the original placeholders remain as fallbacks for manual editing.
setup_env(
    base_url=os.environ.get("SHIELD_BASE_URL", "<URL>"),
    api_key=os.environ.get("SHIELD_API_KEY", "<API_KEY>"),
)

---
### Local Helper Functions

In [2]:
def bucket_inferences_by_day(inferences):
    """Group inferences by the calendar date of their ``created_at`` timestamp.

    Args:
        inferences: iterable of inference dicts, each with a ``created_at`` value in
            milliseconds since the epoch.

    Returns:
        dict mapping ``datetime.date`` -> list of the inferences created on that date,
        in input order. Dates come from ``datetime.fromtimestamp``, i.e. the kernel's
        local timezone.
    """
    by_day = {}
    for inference in inferences:
        created = datetime.fromtimestamp(inference["created_at"] / 1000.0)
        by_day.setdefault(created.date(), []).append(inference)
    return by_day

def get_prompt_and_response_inf_per_day(task_inferences_map_by_day):
    """Split each day's inferences into their prompt and response parts.

    Args:
        task_inferences_map_by_day: mapping of day -> list of inference dicts, each
            having ``"inference_prompt"`` and ``"inference_response"`` keys.

    Returns:
        ``(prompts_by_day, responses_by_day)``: two dicts keyed by the same days, each
        holding the truthy (non-empty / non-None) prompt or response objects for that
        day, in input order. Days with no inferences are omitted from both dicts.
    """
    prompts_by_day = {}
    responses_by_day = {}
    for day, day_inferences in task_inferences_map_by_day.items():
        day_inferences = list(day_inferences)
        if not day_inferences:
            continue
        prompts_by_day[day] = [
            inf["inference_prompt"] for inf in day_inferences if inf["inference_prompt"]
        ]
        responses_by_day[day] = [
            inf["inference_response"] for inf in day_inferences if inf["inference_response"]
        ]
    return prompts_by_day, responses_by_day

---

### 1. Retrieve relevant inference and task information for a given period of time

Enter the following: 

- Start time (`start`)
- End time (`end`)
- Task IDs (`task_ids`), or set `ALL_TASKS = True` to include all tasks in the environment
- Rule types (`rule_types`)



In [None]:
# Query window: `end` is one day past now (presumably so all of today is included —
# confirm), and `start` is 30 days before `end`.
end = datetime.now() + timedelta(days=1)
start = end - timedelta(days=30)

# Task ids the inference query is restricted to when ALL_TASKS is False.
task_ids = ["<TASK1>", "<TASK2>"]

ALL_TASKS = True

# NOTE(review): rule_types is not passed to query_inferences below, so it does not
# currently filter anything — confirm whether it should be.
rule_types = ["ModelHallucinationRuleV2", "ModelSensitiveDataRule", "PIIDataRule", "ToxicityRule", "RegexRule", "KeywordRule", "PromptInjectionRule"]

if ALL_TASKS: 
    inferences = query_inferences(start=start, end=end)
else: 
    inferences = query_inferences(start=start, end=end, task_ids=task_ids)

# Get additional task information, like name, for enrichment purposes
tasks = search_tasks(all_tasks=ALL_TASKS, task_ids=task_ids)

all_tasks_info_map = {task["id"]: task for task in tasks}

# Print a summary rather than the raw inferences: they contain user prompts and model
# responses (potentially PII) and can be very large.
print(f"Retrieved {len(inferences)} inferences for {len(all_tasks_info_map)} tasks between {start} and {end}")

In [4]:
def custom_sort(item):
    """Return the key used to sort and group an inference by task.

    Inferences with no ``task_id`` (missing or None) are keyed as "N/A".
    """
    key = item.get('task_id')
    return key if key is not None else "N/A"

# itertools.groupby only merges *adjacent* items, so sort by the same key first.
sorted_inferences = sorted(inferences, key=custom_sort)

# task key -> list of that task's inferences
tasks_inferences_map = dict(
    (task_key, list(task_group))
    for task_key, task_group in groupby(sorted_inferences, key=custom_sort)
)

In [None]:
# Token usage information, grouped by task and by rule type
usage = get_token_usage(start=start, end=end, groupby_rule_type=True, groupby_task=True)

# Nest the usage rows as {task_id: {rule_type: count}}. If a (task, rule type) pair occurs
# more than once, the first count that is not None is kept.
usage_per_task_map = {}
for usage_metric in usage:
    task_id, rule_type = usage_metric["task_id"], usage_metric["rule_type"]
    rule_counts = usage_per_task_map.setdefault(task_id, {})
    if rule_counts.get(rule_type) is None:
        rule_counts[rule_type] = usage_metric["count"]

print(usage_per_task_map)


---
### 2. Total Inference & Token Usage across all tasks specified for the given time period 

In [None]:
all_inferences_by_day = bucket_inferences_by_day(inferences)

all_inf_propmt_day_map, all_inf_response_day_map = get_prompt_and_response_inf_per_day(all_inferences_by_day)


print(f'''
*********************************************************************************************
      
Usage metrics across all selected tasks. Start: {start} to End: {end}

Tasks: {list(all_tasks_info_map.keys())}
      
*********************************************************************************************
''')

if not all_inferences_by_day:
    # min()/max() below would raise ValueError on an empty window.
    print("No inferences found in the selected time period; nothing to plot.")
else:
    # Continuous date axis from the first to the last day with data, so days without any
    # inferences are drawn as explicit zeros instead of being skipped on the x-axis.
    min_date = min(all_inferences_by_day)
    max_date = max(all_inferences_by_day)
    all_dates = [min_date + timedelta(days=i) for i in range((max_date - min_date).days + 1)]
    all_dates_str = [str(day) for day in all_dates]

    prompt_inf_counts = [len(all_inf_propmt_day_map.get(day, [])) for day in all_dates]
    prompt_inf_counts_failed = [sum(1 for r in all_inf_propmt_day_map.get(day, []) if r['result'] == "Fail") for day in all_dates]
    prompt_inf_counts_pass = [sum(1 for r in all_inf_propmt_day_map.get(day, []) if r['result'] == "Pass") for day in all_dates]

    response_inf_counts = [len(all_inf_response_day_map.get(day, [])) for day in all_dates]
    response_inf_counts_failed = [sum(1 for r in all_inf_response_day_map.get(day, []) if r['result'] == "Fail") for day in all_dates]
    response_inf_counts_pass = [sum(1 for r in all_inf_response_day_map.get(day, []) if r['result'] == "Pass") for day in all_dates]

    # One figure for prompts, one for responses: total, failures and successes per day.
    plot_series = (
        ("Prompt", prompt_inf_counts, prompt_inf_counts_failed, prompt_inf_counts_pass),
        ("Response", response_inf_counts, response_inf_counts_failed, response_inf_counts_pass),
    )
    for kind, counts, failed, passed in plot_series:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(all_dates_str, counts, marker='o', color='black', linestyle='-', label=f'{kind} Inf Count')
        ax.plot(all_dates_str, failed, marker='o', color='red', linestyle='--', label='Failures')
        ax.plot(all_dates_str, passed, marker='o', color='green', linestyle='--', label='Success')
        ax.set_ylabel('Inferences')
        ax.set_title(f"{kind} inferences")
        ax.tick_params(axis='x', labelrotation=45)
        ax.grid(True)
        ax.legend()
        fig.tight_layout()
        plt.show()

---
### 3. Inference & Token Usage information per task 

In [None]:
print(f'''
*********************************************************************************************
      
Usage metrics per task. Start: {start} to End: {end}
      
*********************************************************************************************''')

# Tasks with the most inferences first.
sorted_tasks_inferences_map = dict(sorted(tasks_inferences_map.items(), key=lambda x: len(x[1]), reverse=True))


for key, value in sorted_tasks_inferences_map.items():
    print(f"Task: {key}, Num inferences: {len(value)}")


# `task_inferences` is task-scoped on purpose: reusing the name `inferences` here would
# overwrite the notebook-wide list, so re-running section 2 would only plot the last task.
for task_id, task_inferences in sorted_tasks_inferences_map.items():
    print("*********************************************************************************************")
    task_info = all_tasks_info_map.get(task_id)
    # Fall back to the id when no task metadata was found for this key.
    task_name = task_id if task_info is None else task_info["name"]
    print(f"Task Name: {task_name}")
    print(f"Task Id: {task_id}")

    print(f"Total num of inferences over selected time period {len(task_inferences)}")
    inferences_by_day = bucket_inferences_by_day(task_inferences)
    if not inferences_by_day:
        continue

    prompt_inf_day_map, response_inf_day_map = get_prompt_and_response_inf_per_day(inferences_by_day)

    # Continuous date axis (first to last day with data) so empty days show as zeros,
    # consistent with the all-tasks plots.
    first_day = min(inferences_by_day)
    last_day = max(inferences_by_day)
    sorted_days = [first_day + timedelta(days=i) for i in range((last_day - first_day).days + 1)]
    sorted_days_str = [str(day) for day in sorted_days]

    # Total counts
    prompt_inf_counts = [len(prompt_inf_day_map.get(day, [])) for day in sorted_days]
    response_inf_counts = [len(response_inf_day_map.get(day, [])) for day in sorted_days]

    # Specific counts
    response_inf_counts_failed = [sum(1 for r in response_inf_day_map.get(day, []) if r['result'] == "Fail") for day in sorted_days]
    response_inf_counts_pass = [sum(1 for r in response_inf_day_map.get(day, []) if r['result'] == "Pass") for day in sorted_days]

    prompt_inf_counts_failed = [sum(1 for r in prompt_inf_day_map.get(day, []) if r['result'] == "Fail") for day in sorted_days]
    prompt_inf_counts_pass = [sum(1 for r in prompt_inf_day_map.get(day, []) if r['result'] == "Pass") for day in sorted_days]

    plot_series = (
        ("Prompt", prompt_inf_counts, prompt_inf_counts_failed, prompt_inf_counts_pass),
        ("Response", response_inf_counts, response_inf_counts_failed, response_inf_counts_pass),
    )
    for kind, counts, failed, passed in plot_series:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(sorted_days_str, counts, marker='o', color='black', linestyle='-', label=f'{kind} Inf Count')
        ax.plot(sorted_days_str, failed, marker='o', color='red', linestyle='--', label='Failures')
        ax.plot(sorted_days_str, passed, marker='o', color='green', linestyle='--', label='Success')
        ax.set_ylabel('Inferences')
        ax.set_title(f"\"{task_name}\" {kind.lower()} inferences")
        ax.tick_params(axis='x', labelrotation=45)
        ax.grid(True)
        ax.legend()
        fig.tight_layout()
        plt.show()

    # Token Usage Data
    print("Token usage information")
    print('''
    User input (raw input by user) - metric produced by Arthur
    Prompt (raw input by user + Shield prompt augmentation) - metric produced by Azure
    Completion (tokens generated by the model in response to the prompt) - metric produced by Azure
    ''')
    print(json.dumps(usage_per_task_map.get(task_id), indent=4))



---
### 4. Rule status information per task and rule type

In [None]:
print(f'''
*********************************************************************************************
      
Rule status details per task (prompt + response). Start: {start} to End: {end}
      
*********************************************************************************************
''')

# Rule statuses that are always reported and plotted, in bar order, with their bar colors.
RULE_STATUS_COLORS = {'Pass': 'green', 'Fail': 'red', 'Skipped': 'grey', 'Unavailable': 'black'}
RULE_STATUS_BAR_WIDTH = 0.2


def _rule_status_counts_by_day(inf_day_map, rule_results_key):
    """Count rule results per day and per rule type.

    Args:
        inf_day_map: {date: [prompt or response inference dict, ...]}, as returned by
            get_prompt_and_response_inf_per_day.
        rule_results_key: key holding the list of rule results on each inference,
            e.g. "prompt_rule_results" or "response_rule_results".

    Returns:
        {str(date): {rule_type: {status: count}}}. Each rule starts with every status in
        RULE_STATUS_COLORS at 0; any other status string gets its own key rather than
        raising KeyError.
    """
    counts_by_day = {}
    for day, day_infs in inf_day_map.items():
        buckets = {}
        for inf in day_infs:
            for rule_result in inf[rule_results_key]:
                rule_counts = buckets.setdefault(rule_result["rule_type"], dict.fromkeys(RULE_STATUS_COLORS, 0))
                status = rule_result["result"]
                rule_counts[status] = rule_counts.get(status, 0) + 1
        counts_by_day[str(day)] = buckets
    return counts_by_day


def _plot_rule_statuses(status_counts_by_day, side):
    """Draw one grouped bar chart per rule type: one bar per status for each day.

    ``side`` (e.g. "Prompt" / "Response") is used only in the chart title. Per-status
    totals over the whole period are shown in the legend labels.
    """
    days = sorted(status_counts_by_day)
    # Sorted so chart order does not depend on set iteration order.
    rules = sorted({rule for day_results in status_counts_by_day.values() for rule in day_results}, key=str)
    positions = range(len(days))
    for rule in rules:
        fig, ax = plt.subplots()
        for offset, (status, color) in enumerate(RULE_STATUS_COLORS.items()):
            counts = [status_counts_by_day[day].get(rule, {}).get(status, 0) for day in days]
            ax.bar([p + offset * RULE_STATUS_BAR_WIDTH for p in positions], counts, RULE_STATUS_BAR_WIDTH,
                   label=f'{status} (total: {sum(counts)})', color=color)
        ax.set_xlabel('Date')
        ax.set_ylabel('Count')
        ax.set_title(f'{rule} Statuses ({side})')
        # Center each date label under its group of len(RULE_STATUS_COLORS) bars.
        ax.set_xticks([p + (len(RULE_STATUS_COLORS) - 1) / 2 * RULE_STATUS_BAR_WIDTH for p in positions])
        ax.set_xticklabels(days, rotation=45, ha='right')
        ax.legend()
        fig.tight_layout()
        plt.show()


for task, task_inferences in tasks_inferences_map.items():
    print("*********************************************************************************************")
    task_info = all_tasks_info_map.get(task)
    # Fall back to this task's own id when no task metadata was found.
    task_name = task if task_info is None else task_info["name"]
    print(f"Task Name: {task_name}")
    print(f"Task Id: {task}")
    print(f"Total num of inferences {len(task_inferences)}")
    inferences_by_day = bucket_inferences_by_day(task_inferences)

    prompt_inf_day_map, response_inf_day_map = get_prompt_and_response_inf_per_day(inferences_by_day)

    print("******************** Prompt rule statuses per day by rule type ***************************")
    results_by_day_prompt = _rule_status_counts_by_day(prompt_inf_day_map, "prompt_rule_results")
    print(json.dumps(results_by_day_prompt, indent=4))
    _plot_rule_statuses(results_by_day_prompt, "Prompt")

    print("******************** Response rule statuses per day by rule type ***************************")
    results_by_day = _rule_status_counts_by_day(response_inf_day_map, "response_rule_results")
    print(json.dumps(results_by_day, indent=4))
    _plot_rule_statuses(results_by_day, "Response")

### Further analysis 

In [13]:
# Fetch hallucination-rule inferences whose status was Unavailable or Skipped for closer
# inspection. The task filter is only applied when ALL_TASKS is False.
HALLUCINATION_RULE = "ModelHallucinationRuleV2"
task_filter = {} if ALL_TASKS else {"task_ids": task_ids}

unavailables = query_inferences(start, end, **task_filter, rule_types=[HALLUCINATION_RULE], rule_statuses=["Unavailable"])
skipped = query_inferences(start, end, **task_filter, rule_types=[HALLUCINATION_RULE], rule_statuses=["Skipped"])