In [1]:
import numpy as np
from itertools import product
from collections import Counter
from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix

In [2]:
def metric_calculation(pred, gt):
    """Score predicted labels against ground-truth labels.

    Parameters
    ----------
    pred : array-like of int
        Predicted labels, e.g. the output of ``post_processing``: 0, 1, or 2
        for a generation that could not be parsed.
    gt : array-like of int
        Ground-truth labels with the same length as ``pred``
        (presumably 0/1 — confirm against the dataset schema).

    Returns
    -------
    tuple of float
        ``(acc, fpr, fnr)``:
        - ``acc``: fraction of samples where ``pred == gt``.
        - ``fpr``: fraction of ALL samples with actual 0 and predicted 1.
        - ``fnr``: fraction of ALL samples with actual 1 and predicted 0.
        Both fpr and fnr are divided by ``len(gt)``, not by the class sizes,
        so they are not the conventional FP/(FP+TN) and FN/(FN+TP) rates.
        A prediction of 2 lowers accuracy but counts in neither fpr nor fnr.

    Raises
    ------
    ValueError
        If ``pred`` and ``gt`` differ in shape, or are empty.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)
    if pred.shape != gt.shape:
        raise ValueError(
            f'pred and gt must have the same shape, got {pred.shape} and {gt.shape}'
        )
    if gt.size == 0:
        raise ValueError('cannot score an empty set of predictions')

    n = gt.size
    # Count by label value instead of indexing confusion_matrix positionally:
    # the matrix only has rows/cols for labels that occur, so [0, 1] is not
    # "actual 0, predicted 1" when label 1 is absent, and it raises IndexError
    # when only a single label occurs.
    acc = float(np.mean(pred == gt))
    fpr = float(np.sum((gt == 0) & (pred == 1)) / n)  # predict to be 1; actual 0
    fnr = float(np.sum((gt == 1) & (pred == 0)) / n)  # predict to be 0; actual 1
    return acc, fpr, fnr

In [3]:
def _extract_answer_token(text):
    """Return the candidate label token from one lower-cased generation, or None.

    If ``text`` contains 'response', the token is the second whitespace-separated
    word after its first occurrence (the first word is presumably the ':' of a
    'Response:' header). Otherwise, if it contains 'output', the same rule is
    applied to 'output'. With neither marker, the first word of ``text`` is used.
    A marker that is present but not followed by two words yields None; there is
    no fallback to the next rule. Any literal '</s>' is removed from the token.
    """
    for marker in ('response', 'output'):
        if marker in text:
            words = text.split(marker)[1].split()
            return words[1].replace('</s>', '') if len(words) > 1 else None
    words = text.split()
    return words[0].replace('</s>', '') if words else None


def post_processing(pred):
    """Convert raw model generations into integer labels.

    Parameters
    ----------
    pred : iterable of str
        Raw generated texts, one per example. Markers are matched
        case-insensitively because each text is lower-cased first.

    Returns
    -------
    numpy.ndarray of int
        One entry per input. It is 0 or 1 when the extracted token is exactly
        '0', '0.0', '1' or '1.0', and 2 otherwise (unparseable generation).
    """
    valid_tokens = {'0', '0.0', '1', '1.0'}
    labels = []
    for text in pred:
        token = _extract_answer_token(text.lower())
        # None (nothing extracted) is never in valid_tokens, so it maps to 2.
        labels.append(int(float(token)) if token in valid_tokens else 2)
    # Explicit dtype so an empty input yields an int array rather than float64.
    return np.array(labels, dtype=int)

In [4]:
# Load the spatial-join benchmark dataset; in the cells shown, only its
# 'test' split is used for scoring.
ds = load_dataset("beanham/spatial_join_dataset")
test=ds['test']
# Ground-truth labels as a NumPy array, so they can be sliced (gt[:n]) and
# compared elementwise with the post-processed predictions.
# NOTE(review): assumes the 'label' column holds 0/1 ints — confirm against the dataset card.
gt=np.array(test['label'])
# Experiment grid: every prompting method crossed with every heuristic mode and
# heuristic type. The 'no_heur' runs carry no heuristic suffix in their file
# names, so they contribute one config per method instead of one per heuristic.
models = ['llama3', 'mistral', '4o_mini']
methods = ['zero_shot', 'few_shot']
modes = ['no_heur', 'with_heur_hint', 'with_heur_value']
heuristics = ['angle', 'distance', 'comb']
configs = sorted(
    ['_'.join(combo) for combo in product(methods, modes, heuristics) if combo[1] != 'no_heur']
    + [f'{method}_no_heur' for method in methods]
)

In [6]:
# Score every (model, config) prediction file against the test labels.
# Each file is read from base/<model>/<model>_<config>.npy and passed through
# post_processing, so it presumably holds one raw generation string per test
# example, in test-split order — confirm against the generation script.
for model in models:
    print('======================================')
    print(f'Model: {model}...')
    for config in configs:
        print('----------------------------------')
        raw_pred = np.load(f'base/{model}/{model}_{config}.npy')
        pred = post_processing(raw_pred)
        n_scored = len(pred)
        if n_scored > len(gt):
            raise ValueError(
                f'{model}/{config}: {n_scored} predictions but only {len(gt)} labels'
            )
        if n_scored < len(gt):
            # A partial run is scored on a prefix of gt; flag it so its numbers
            # are not compared as if they covered the whole test split.
            print(f'   WARNING: only {n_scored}/{len(gt)} examples predicted; scoring that prefix')
        # .tolist() keeps the Counter keys as plain ints, so they print the same
        # regardless of how the installed NumPy version reprs its scalars.
        print('  ', Counter(pred.tolist()))
        print(f'   {config}:', metric_calculation(pred, gt[:n_scored]))

Model: llama3...
----------------------------------
   Counter({0: 1635, 1: 1282, 2: 152})
   few_shot_no_heur: (0.4701857282502444, 0.1638970348647768, 0.31638970348647766)
----------------------------------
   Counter({0: 2135, 1: 775, 2: 159})
   few_shot_with_heur_hint_angle: (0.47539915281850764, 0.0765721733463669, 0.3962202671880091)
----------------------------------
   Counter({0: 2154, 1: 690, 2: 225})
   few_shot_with_heur_hint_comb: (0.4496578690127077, 0.07331378299120235, 0.40371456500488756)
----------------------------------
   Counter({0: 1949, 1: 868, 2: 252})
   few_shot_with_heur_hint_distance: (0.44639947865754315, 0.09938090583251874, 0.37210817855979145)
----------------------------------
   Counter({1: 1712, 0: 1137, 2: 220})
   few_shot_with_heur_value_angle: (0.6001955034213099, 0.16422287390029325, 0.1638970348647768)
----------------------------------
   Counter({0: 1644, 1: 1029, 2: 396})
   few_shot_with_heur_value_comb: (0.45356793743890517, 0.11665037471