In [1]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics import cohen_kappa_score
import pandas as pd
import csv



In [2]:
def get_test_subset(file_path="splitted_data.txt"):
    """Return the ``turn_id****passage_id`` keys of every row in the test split.

    The file is read as tab-separated rows with eight columns:
    turn_id, user_utterance, response, passage_id, passage_txt, score, ptkb,
    and a split label. Only rows whose split label equals ``'test'`` are kept.

    Args:
        file_path: path to the split file (defaults to the original
            hard-coded ``"splitted_data.txt"``).

    Returns:
        list[str]: one ``turn_id + '****' + passage_id`` string per test row,
        in file order.

    Raises:
        OSError: if the file cannot be opened.
        pandas raises an error if a row does not have exactly eight fields.
    """
    columns = ['turn_id', 'user_utterance', 'response', 'passage_id',
               'passage_txt', 'score', 'ptkb', 'label']

    # newline='' is what the csv module docs require for files handed to csv.reader.
    with open(file_path, 'r', newline='') as file:
        reader = csv.reader(file, delimiter='\t')
        df = pd.DataFrame(reader, columns=columns)

    test_set = df[df['label'] == 'test']
    # Vectorized string concatenation instead of a Python-level iterrows loop.
    return (test_set['turn_id'] + '****' + test_set['passage_id']).tolist()

In [3]:
def get_pool(paths, delimiter='\t'):
    """Load a four-column pool / qrels file into a ``{pair_id: grade}`` mapping.

    Each non-blank line must split into exactly four fields:
    ``turn_id``, an ignored column, ``passage_id`` and an integer ``score``.
    The key is ``turn_id + '****' + passage_id``.

    Args:
        paths: path to a single pool file (only one path is read despite the
            plural name; kept for backward compatibility).
        delimiter: field separator, tab by default. Pass ``None`` to split on
            any run of whitespace (e.g. space-separated qrels).

    Returns:
        collections.defaultdict(int): a lookup of an unknown pair returns 0
        and inserts that key.

    Raises:
        ValueError: if a non-blank line does not have exactly four fields, or
            if the score is not an integer.
    """
    pool = defaultdict(int)

    with open(paths, 'r') as f:
        # Stream the file instead of buffering every line with readlines().
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines instead of crashing the unpack
            turn_id, _, passage_id, score = line.split(delimiter)
            pool[turn_id + '****' + passage_id] = int(score.strip())
    return pool


In [24]:
def _latex_rows(matrix):
    """Format each row of a 2-D matrix as one LaTeX tabular row string."""
    # Cells are joined with '&' and the row ends with a LaTeX line break, so the
    # output pastes into a tabular environment without a spurious empty column.
    return [' & '.join(str(cell) for cell in row) + r' \\' for row in matrix]


def confusion_matrix_print(path_gpt, path_nist, subset, labels=None, threshold=2):
    """Print graded and binarized confusion matrices of model vs. human labels.

    Both pool files are loaded with ``get_pool``. Only pairs that appear in
    the model pool, in ``subset``, AND in the human (NIST) pool are compared.
    A pair with no human judgment is skipped rather than counted as grade 0.

    Rows of every printed matrix are the model labels (passed to sklearn as
    ``y_true``); columns are the NIST labels. Each matrix is printed twice:
    once as the numpy repr and once as LaTeX table rows.

    Args:
        path_gpt: path to the model-generated pool file.
        path_nist: path to the human-judged pool file.
        subset: iterable of ``turn_id****passage_id`` keys to restrict to.
        labels: optional explicit grade list for the graded matrix (e.g.
            ``[0, 1, 2, 3, 4]``) so rows/columns stay aligned even when a
            grade never occurs. ``None`` keeps sklearn's default (sorted
            union of observed grades).
        threshold: grades ``>= threshold`` are mapped to 1 (relevant), others
            to 0, for the binary matrix. Default 2 matches the original cutoff.

    Returns:
        None. Output is printed only.
    """
    gpt_pool = get_pool(path_gpt)
    nist_pool = get_pool(path_nist)
    subset_ids = set(subset)  # O(1) membership instead of scanning a list per pair

    gpt_label = []
    nist_label = []
    for pair_id, gpt_score in gpt_pool.items():
        # Test membership explicitly: indexing the defaultdict would silently
        # turn a missing human judgment into 0 and insert it into nist_pool.
        if pair_id in subset_ids and pair_id in nist_pool:
            gpt_label.append(gpt_score)
            nist_label.append(nist_pool[pair_id])

    if not gpt_label:
        print('No overlapping judged pairs in subset; nothing to compare.')
        return

    gpt_binary = [int(score >= threshold) for score in gpt_label]
    nist_binary = [int(score >= threshold) for score in nist_label]

    conf_mat_graded = confusion_matrix(gpt_label, nist_label, labels=labels)
    print(conf_mat_graded)
    for line in _latex_rows(conf_mat_graded):
        print(line)

    # Fix the label set so the binary matrix is always 2x2, even if one class is absent.
    conf_mat_binary = confusion_matrix(gpt_binary, nist_binary, labels=[0, 1])
    print(conf_mat_binary)
    for line in _latex_rows(conf_mat_binary):
        print(line)


In [25]:
# Reference (human / NIST) judgments every model pool is compared against.
path_nist = 'pools/human_qrels_tab'
# turn_id****passage_id keys of the held-out test split.
subset = get_test_subset()

# Pool file stem -> short display name. Not used by the loop in this cell.
models_main_name = {
    'gpt3.5-one-shot-pool-V2': 'one-shot',
    'gpt3.5-one-shot-pool-V2-temp0': 'one-shot (tmp=0)',
    'gpt3.5-two-shot-pool': 'two-shot',
    'gpt3.5-two-shot-pool-V2-temp0': 'two-shot (tmp=0)',
    'gpt3.5-zero-shot-pool': 'zero-shot',
    'gpt3.5-zero-shot-paul-pool-temp0-': 'zero-shot (tmp=0)',
    'Llama-3-FT-pool': 'Llama-3 FT',
    'Llama-3-inst-FT-pool': 'Llama-3-inst FT',
    'Llama-3-zero-pool': 'Llama-3 zero-shot',
    'Llama-3-inst-zero-pool': 'Llama-3-inst zero-shot',
}

# Pools whose confusion matrices are printed, each followed by a separator line.
model_names = [
    'gpt3.5-one-shot-pool-V2-temp0',
    'Llama-3-FT-pool',
    'Llama-3-inst-FT-pool',
]
for model_name in model_names:
    print(model_name)
    path_gpt = f'outputs/{model_name}.txt'
    confusion_matrix_print(path_gpt, path_nist, subset)
    print('*****************')



gpt3.5-one-shot-pool-V2-temp0
[[221   7  21   9   0]
 [ 71  20  37   8   4]
 [ 42   8  57  22   4]
 [ 32  17  75  37   3]
 [ 23   8 105  74  12]]
221  &  7  &  21  &  9  &  0  &  
71  &  20  &  37  &  8  &  4  &  
42  &  8  &  57  &  22  &  4  &  
32  &  17  &  75  &  37  &  3  &  
23  &  8  &  105  &  74  &  12  &  
[[319  79]
 [130 389]]
319  &  79  &  
130  &  389  &  
*****************
Llama-3-FT-pool
[[359  22  42   5   1]
 [  2   1   9   2   0]
 [ 22  30 198  52   4]
 [  6   7  42  79  11]
 [  0   0   4  12   7]]
359  &  22  &  42  &  5  &  1  &  
2  &  1  &  9  &  2  &  0  &  
22  &  30  &  198  &  52  &  4  &  
6  &  7  &  42  &  79  &  11  &  
0  &  0  &  4  &  12  &  7  &  
[[384  59]
 [ 65 409]]
384  &  59  &  
65  &  409  &  
*****************
Llama-3-inst-FT-pool
[[310  13  12   2   1]
 [  4   0   6   0   0]
 [ 68  39 232  62   6]
 [  7   8  43  80  13]
 [  0   0   2   6   3]]
310  &  13  &  12  &  2  &  1  &  
4  &  0  &  6  &  0  &  0  &  
68  &  39  &  232  &  62  &  6 