In [None]:
import torch
from collections import defaultdict
import numpy as np
import os

datasets = {"multiwoz":"multiwoz_results", "taskmaster":"taskmaster_results", "abcd":"abcd_results"}
for d in datasets:
    cmd = f"python plot.py --eval_type real --save_dir {datasets[d]}"
    !{cmd}
    
# results_dict[dataset][condition][metric] -> score; unseen entries read as 0.
results_dict = {d: defaultdict(lambda: defaultdict(int)) for d in datasets}
# Display (LaTeX) name for each condition key looked up in the result pickles.
# The WFirst label is a raw string: "\m" is an invalid escape sequence that
# triggers a SyntaxWarning on Python 3.12+; the resulting text is unchanged.
condition_names = {
    "Prompting": "Prompt",
    "DirectedBeamSearch": "DBS",
    "CGMH": "CGMH",
    'Retrieval': 'FOP-retrieval',
    'FuturesOfThePast(no-window)': 'FOP-guided(no-window)',
    'FuturesOfThePast': 'FOP-guided',
    'WindowFuturesOfThePast': 'WindowFOP',
    'WFirst': r'$\mathcal{W}_{first}$',
    'FinetunedModel': 'Finetuned',
}
# Best score (rounded to 2 decimals) per dataset and metric, later used to bold
# the winning table cell; -1 means no score has been seen yet.
best_scores = {
    dataset: {'precision': -1, 'recall': -1, 'f1-score': -1}
    for dataset in datasets
}

final_strs = {}  # NOTE(review): not used in the visible code
for dataset, result_dir in datasets.items():
    # A None or "" directory marks a dataset without results: skip it.
    if not result_dir:
        continue
    for metric in ['precision', 'recall', 'f1-score']:
        # Presumably written by the plot.py runs above -- confirm.
        data_filename = f"{result_dir}/keywords_{metric}_results_dict.pkl"
        if not os.path.exists(data_filename):
            continue  # metric missing for this dataset; keep the defaults
        # torch.load unpickles the file: only point this at trusted results.
        tmp_dict = torch.load(data_filename)
        for condition in condition_names:
            if condition not in tmp_dict:
                continue
            value = tmp_dict[condition][-1]
            # value[0] is the keyword count and value[1][0] the score; only the
            # 9-keyword setting is reported. Raise instead of assert so the
            # check is not stripped under `python -O`.
            if value[0] != 9:
                raise ValueError(
                    f"{data_filename}: last entry for {condition!r} is for "
                    f"{value[0]!r} keywords, expected 9"
                )
            metric_value = value[1][0]
            results_dict[dataset][condition][metric] = metric_value
            if metric_value > best_scores[dataset][metric]:
                best_scores[dataset][metric] = round(metric_value, 2)

# Print the results as LaTeX table rows of the form
#   "<name> & P & R & F1 (per dataset) ... & <mean F1> \\"
# A cell equal to its dataset's best (rounded) score is wrapped in \textbf{}.
print("Precision, recall, F1-score\n")
# Two rows display results stored under a different key: the "FOP-guided"
# row shows WindowFuturesOfThePast and the "FOP-guided(no-window)" row shows
# FuturesOfThePast. Every other row reads its own key.
row_sources = {
    "FuturesOfThePast": "WindowFuturesOfThePast",
    "FuturesOfThePast(no-window)": "FuturesOfThePast",
}
for condition in ['WFirst', 'FinetunedModel', 'Prompting', 'DirectedBeamSearch', 'CGMH', 'Retrieval', 'FuturesOfThePast', 'FuturesOfThePast(no-window)']:
    tmp_condition = row_sources.get(condition, condition)
    cells = [condition_names[condition]]
    all_f1_scores = []
    for dataset, result_dir in datasets.items():
        if result_dir is None:
            # No results directory: placeholder cells, excluded from the mean.
            cells += ["0.000", "0.000", "0.000"]
            continue
        for m in ['precision', 'recall', 'f1-score']:
            metric_value = round(results_dict[dataset][tmp_condition][m], 2)
            if m == "f1-score":
                all_f1_scores.append(metric_value)
            if metric_value == best_scores[dataset][m]:
                cells.append("\\textbf{" + str(metric_value) + "}")
            else:
                cells.append(f"{metric_value}")
    # Last column: mean F1 over the datasets that have a results directory.
    average_f1_score = round(np.mean(all_f1_scores), 2)
    cells.append(f"{average_f1_score}")
    # Joining cells avoids stripping a trailing "&" by re-splitting on
    # whitespace, which would also mangle any cell containing spaces.
    s = " & ".join(cells)
    # NOTE(review): as before, the FOP-guided row gets no "\\" row terminator;
    # confirm this is intentional before pasting the output into LaTeX.
    if condition != "FuturesOfThePast":
        s += "\\\\"
    # Raw string: "\m" is an invalid escape sequence (SyntaxWarning on 3.12+).
    if condition in ("CGMH", "FinetunedModel"):
        s += r" \midrule"
    print(s)
print("\n")