In [25]:
import os
import json
import torch
import openai
from transformers import pipeline
from datasets import load_dataset
from collections import OrderedDict

def format_baseline_results(bart_results, other_results):
    """Reshape baseline predictions into the per-sample layout of the other result sets.

    Args:
        bart_results: Sequence of samples. Each sample's ``"labels"`` entry is
            read at positions 0, 1 and 2 (presumably ranked best-first --
            TODO confirm against the script that produced the file).
        other_results: Mapping keyed by the stringified sample index; each
            entry supplies the ``"gold_intent"`` for that sample.

    Returns:
        dict: ``{str(index): {"gold_intent": ..., "processed_pred": {...}}}``
        where ``processed_pred`` holds ``"top1"``, ``"top2"`` and ``"top3"``
        with spaces in the label names replaced by underscores.
    """
    rank_names = ("top1", "top2", "top3")
    formatted = {}
    for position, entry in enumerate(bart_results):
        key = str(position)
        ranked_labels = entry["labels"]
        # Label names use spaces; the gold intents they are compared with use underscores.
        top_predictions = {
            rank: ranked_labels[offset].replace(" ", "_")
            for offset, rank in enumerate(rank_names)
        }
        formatted[key] = {
            "gold_intent": other_results[key]["gold_intent"],
            "processed_pred": top_predictions,
        }
    return formatted

def print_samples(results, intent):
    """Print every sample whose gold intent is ``intent``, then a top-1 summary.

    For each matching sample the utterance text and the ``processed_pred``
    dict are printed on separate lines. The last line reports how many of the
    matching samples have ``processed_pred["top1"]`` equal to the gold intent.

    Args:
        results: Mapping of sample id to a dict with ``"gold_intent"`` and
            ``"processed_pred"``, and optionally ``"text"``.
        intent: Gold intent label to filter on.

    Returns:
        None. Everything is written to stdout.
    """
    correct_count = 0
    sample_count = 0
    for sample in results.values():
        if sample["gold_intent"] != intent:
            continue
        sample_count += 1
        # Results reshaped by format_baseline_results carry no "text" field, so
        # fall back to a placeholder rather than raising KeyError.
        print(sample["text"] if "text" in sample else "<text unavailable>")
        predictions = sample["processed_pred"]
        print(predictions)
        # A prediction dict without "top1" counts as a miss instead of crashing.
        if "top1" in predictions and predictions["top1"] == sample["gold_intent"]:
            correct_count += 1

    print(f"Total of {correct_count} predicted correctly out of {sample_count}")


In [8]:
# Root folder of the stored prediction files. Set INSTRUCTOD_RESULTS_DIR to use
# another checkout instead of editing the default path below.
RESULTS_DIR = os.environ.get("INSTRUCTOD_RESULTS_DIR", "/home/willy/instructod/src/IC/results")


def _load_json(*path_parts):
    """Parse one JSON file located under RESULTS_DIR, closing the handle afterwards."""
    with open(os.path.join(RESULTS_DIR, *path_parts), "r") as handle:
        return json.load(handle)


#Banking77 (swap with the CLINC150 assignments below to analyse this dataset instead)
# gpt3_top3 = _load_json("banking77", "full_banking77_top3_gpt3.5_processed.json")
# gpt3_top1 = _load_json("banking77", "full_banking77_top1_gpt3.5_processed.json")
# gpt4_top3 = _load_json("banking77", "full_banking77_top3_gpt4_processed.json")
# gpt4_top1 = _load_json("banking77", "full_banking77_top1_gpt4_processed.json")
# bart_corr = format_baseline_results(_load_json("baselines", "bart-large-mnli-intent-correct.json"), gpt3_top1)
# bart = format_baseline_results(_load_json("baselines", "bart-large-mnli.json"), gpt3_top1)

#CLINC150
gpt3_top3 = _load_json("clinc150", "full_clinc150_top3_gpt3.5_processed.json")
gpt3_top1 = _load_json("clinc150", "full_clinc150_top1_gpt3.5_processed.json")
gpt4_top3 = _load_json("clinc150", "full_clinc150_top3_gpt4_processed.json")
gpt4_top1 = _load_json("clinc150", "full_clinc150_top1_gpt4_processed.json")
# The baseline files carry no gold intents; they are taken from gpt3_top1 by sample index.
bart_corr = format_baseline_results(_load_json("baselines", "bart-large-mnli_intent-correct_clinc150.json"), gpt3_top1)
bart = format_baseline_results(_load_json("baselines", "bart-large-mnli_clinc150.json"), gpt3_top1)

In [9]:
# Every loaded result set, keyed by a short label.
all_results = dict(
    gpt3_top3=gpt3_top3,
    gpt3_top1=gpt3_top1,
    gpt4_top3=gpt4_top3,
    gpt4_top1=gpt4_top1,
    bart=bart,
    bart_corr=bart_corr,
)

def update(dict_res, intent, setting):
    """Increment the counter ``dict_res[intent][setting]`` in place.

    Missing levels are created on demand: an unseen ``intent`` gets an empty
    inner dict, and an unseen ``setting`` starts counting at 1.

    Args:
        dict_res: Nested ``{intent: {setting: count}}`` mapping; mutated in place.
        intent: Outer key.
        setting: Inner key whose counter is incremented.

    Returns:
        The same ``dict_res`` object, for call chaining.

    Raises:
        TypeError: If the stored counter is not a number. Such a value is
            reported instead of being silently overwritten with 1.
    """
    per_intent = dict_res.setdefault(intent, {})
    per_intent[setting] = per_intent.get(setting, 0) + 1
    return dict_res

def get_repartition(results):
    """Count, per gold intent, at which rank the gold label was predicted.

    Each sample lands in exactly one bucket: the first of ``top1``, ``top2``,
    ``top3`` whose prediction equals the gold intent, or the failure bucket
    when none matches. Ranks missing from ``processed_pred`` (and empty
    predictions) never match.

    Args:
        results: Mapping of sample id to a dict with ``"gold_intent"`` and
            ``"processed_pred"`` (a container of ``"top1"``/``"top2"``/``"top3"``).

    Returns:
        dict: ``{gold_intent: {bucket: count}}`` with buckets ``"gpt3_top1"``,
        ``"gpt3_top2"``, ``"gpt3_top3"`` and ``"gpt3_multi_fail"``. The
        ``gpt3_`` prefix is fixed and is used no matter which model produced
        ``results``. Buckets that received no sample are absent.
    """
    ranked_buckets = (
        ("top1", "gpt3_top1"),
        ("top2", "gpt3_top2"),
        ("top3", "gpt3_top3"),
    )
    comparative_results = {}
    for sample in results.values():
        gold = sample["gold_intent"]
        predictions = sample["processed_pred"]

        # Completely misclassified unless one of the ranks below matches.
        bucket = "gpt3_multi_fail"
        for rank, rank_bucket in ranked_buckets:
            predicted = predictions[rank] if rank in predictions else None
            if predicted and gold == predicted:
                bucket = rank_bucket
                break

        # One tally for hits and misses alike; no exception-driven control flow.
        per_intent = comparative_results.setdefault(gold, {})
        per_intent[bucket] = per_intent.get(bucket, 0) + 1
    return comparative_results

def print_fail(processed_results, threshold):
    """Print the intents whose complete-miss count reaches ``threshold``.

    One ``<intent> <count>`` line is printed for every intent whose
    ``"gpt3_multi_fail"`` counter is greater than or equal to ``threshold``
    (the comparison is inclusive). Two summary lines follow: the number of
    such intents and the sum of their miss counts.

    Args:
        processed_results: ``{intent: {bucket: count}}`` mapping; intents
            without a ``"gpt3_multi_fail"`` bucket are skipped.
        threshold: Minimum miss count for an intent to be reported.

    Returns:
        None. Everything is written to stdout.
    """
    count_total = 0  # number of intents reported
    total_fail_count = 0  # misses summed over the reported intents
    for intent, buckets in processed_results.items():
        if "gpt3_multi_fail" not in buckets:
            continue
        fail_count = buckets["gpt3_multi_fail"]
        if fail_count >= threshold:
            print(intent, fail_count)
            count_total += 1
            total_fail_count += fail_count
    print(f"There are {count_total} misclassified for a threshold of {threshold}")
    print(f"There are {total_fail_count} total fail count for mistakes with a threshold above {threshold}")


In [10]:
# Per-intent breakdown (matched at rank 1/2/3 or missed entirely) for each result set.
(
    processed_results_gpt3,
    processed_results_gpt4,
    processed_results_gpt3_1,
    processed_results_gpt4_1,
    processed_results_bart,
    processed_results_bart_corr,
) = (
    get_repartition(result_set)
    for result_set in (gpt3_top3, gpt4_top3, gpt3_top1, gpt4_top1, bart, bart_corr)
)
# To inspect a full breakdown: print(json.dumps(processed_results_gpt3, indent=2))

In [20]:
#Print all intents whose number of complete misses is >= threshold
threshold = 23
labelled_breakdowns = (
    ("gpt3-3", processed_results_gpt3),
    ("gpt4-3", processed_results_gpt4),
    ("gpt3-1", processed_results_gpt3_1),
    ("gpt4-1", processed_results_gpt4_1),
    ("gpt_bart", processed_results_bart),
    ("gpt_bart_corr", processed_results_bart_corr),
)
for position, (label, breakdown) in enumerate(labelled_breakdowns):
    # Separator goes between reports only, not before the first one.
    if position:
        print("===============")
    print(label)
    print_fail(breakdown, threshold)

gpt3-3
calendar 25
reminder_update 26
calories 23
how_busy 30
oos 1000
There are 5 misclassified for a threshold of 23
There are 1104 total fail count for mistakes with a thrshold above 23
gpt4-3
oos 113
There are 1 misclassified for a threshold of 23
There are 113 total fail count for mistakes with a thrshold above 23
gpt3-1
distance 25
insurance 25
reminder 28
food_last 24
calories 30
tire_change 24
how_busy 26
gas 26
oos 1000
There are 9 misclassified for a threshold of 23
There are 1208 total fail count for mistakes with a thrshold above 23
gpt4-1
distance 25
spending_history 30
no 25
pto_used 25
reminder_update 30
ingredient_substitution 24
todo_list 24
goodbye 23
cancel 25
gas 26
oos 379
There are 11 misclassified for a threshold of 23
There are 636 total fail count for mistakes with a thrshold above 23
gpt_bart
insurance_change 24
improve_credit_score 27
fun_fact 26
change_user_name 27
shopping_list_update 30
rollover_401k 24
user_name 29
next_song 23
restaurant_suggestion 23
re

In [31]:
# Bucket counts for the gold intent "how_busy" in the BART baseline breakdown.
# The "gpt3_" key prefix is hard-coded by get_repartition for every result set.
processed_results_bart["how_busy"]

{'gpt3_multi_fail': 21, 'gpt3_top1': 8, 'gpt3_top2': 1}

In [34]:
# Print every sample whose gold intent is "how_busy" from the GPT-4 top-1
# results (utterance text, then its prediction dict), followed by the number
# of top-1 hits.
print_samples(gpt4_top1, "how_busy")

is the resataurant busy at 5:00 pm
{'top1': 'recipe'}
how busy is the cafe at 7:00
{'top1': 'how_busy'}
would tio's be crowded at 7
{'top1': 'how_busy'}
can i get a table for four at 8:00
{'top1': 'restaurant_reservation'}
is the friday's full after 4
{'top1': 'how_busy'}
how long is the wait at fridays
{'top1': 'weather'}
is the mexican place crowded at night
{'top1': 'how_busy'}
are cool people at the bar at 9:00 pm
{'top1': 'application_status'}
what is the best time to go to get a burger without a line
{'top1': 'recipe'}
how long before i can eat at chic fil a
{'top1': 'restaurant_reviews'}
how long to be seated at carrabas
{'top1': 'accept_reservations'}
is the wait at pizza hut long
{'top1': 'recipe'}
is the breakfast place full in the mornings
{'top1': 'how_busy'}
is the wait more than an hour at the italian place
{'top1': 'weather'}
tell me how busy red robin is at 5 pm
{'top1': 'how_busy'}
i wanna know how busy denny's is at 5 am
{'top1': 'how_busy'}
i need to know how busy de

In [43]:
# Bucket counts for the gold intent "oos" in the bart_corr baseline breakdown
# ("oos" is presumably the out-of-scope class -- TODO confirm).
processed_results_bart_corr["oos"]

{'gpt3_multi_fail': 992, 'gpt3_top2': 2, 'gpt3_top3': 4, 'gpt3_top1': 2}

In [48]:
# Bucket counts for the gold intent "oos" in the GPT-4 top-3 breakdown.
processed_results_gpt4["oos"]

{'gpt3_top3': 293, 'gpt3_top1': 583, 'gpt3_multi_fail': 113, 'gpt3_top2': 11}

In [50]:
# Bucket counts for the gold intent "oos" in the GPT-4 top-1 breakdown
# (only "top1" predictions were compared, so each sample is a rank-1 hit or a miss).
processed_results_gpt4_1["oos"]

{'gpt3_multi_fail': 379, 'gpt3_top1': 621}