In [1]:
import numpy as np
from itertools import product
from collections import Counter
from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix

In [2]:
def metric_calculation(pred, gt):
    """Score predicted labels against ground-truth labels.

    Predictions may hold values other than 0/1 (``post_processing`` emits 2
    for output it cannot parse). Such entries count as wrong for accuracy and
    are attributed to neither of the two error shares.

    Args:
        pred: 1-D sequence of predicted labels.
        gt: 1-D sequence of ground-truth labels, same length as ``pred``.

    Returns:
        Tuple ``(acc, fpr, fnr)`` of floats, where
        ``acc`` is the fraction of samples with ``pred == gt``,
        ``fpr`` is the fraction of ALL samples with ``gt == 0`` and ``pred == 1``,
        ``fnr`` is the fraction of ALL samples with ``gt == 1`` and ``pred == 0``.
        Both error terms are normalised by ``len(gt)`` (not by the number of
        actual negatives/positives), so they are shares of the whole sample
        rather than conventional false-positive/false-negative rates.

    Raises:
        ValueError: if the inputs are empty or differ in shape.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)
    if pred.shape != gt.shape:
        raise ValueError(
            f'pred and gt must have the same shape, got {pred.shape} and {gt.shape}'
        )
    total = gt.size
    if total == 0:
        raise ValueError('cannot compute metrics on empty input')

    acc = float(np.mean(pred == gt))
    # Select the cells by label VALUE. Indexing an inferred confusion matrix by
    # position picks the wrong cell (or raises IndexError) whenever a class is
    # missing from both inputs, because the matrix rows/columns then shift.
    fpr = float(np.count_nonzero((gt == 0) & (pred == 1)) / total)  # predicted 1, actual 0
    fnr = float(np.count_nonzero((gt == 1) & (pred == 0)) / total)  # predicted 0, actual 1
    return acc, fpr, fnr

In [3]:
# Tokens accepted as a valid label, mapped to their integer class.
_LABEL_BY_TOKEN = {'0': 0, '0.0': 0, '1': 1, '1.0': 1}
# Class assigned to any generation that does not yield a valid label token.
_UNPARSEABLE = 2


def _extract_label_token(text):
    """Return the candidate label token of one lower-cased generation, or None.

    If the text contains 'response' (checked first) or 'output', the candidate
    is the SECOND whitespace-separated token of the text that follows the
    first occurrence of that marker (e.g. 'response: 1' -> tokens ':' and '1').
    Otherwise the candidate is the first token of the whole text. Any '</s>'
    in the candidate is stripped. None is returned when the token is missing.
    """
    for marker in ('response', 'output'):
        if marker in text:
            tokens = text.split(marker)[1].split()
            position = 1
            break
    else:
        tokens = text.split()
        position = 0
    # Explicit bounds check instead of a bare ``except`` so that unrelated
    # errors (and KeyboardInterrupt) are no longer silently swallowed.
    if len(tokens) <= position:
        return None
    return tokens[position].replace('</s>', '')


def post_processing(pred):
    """Convert raw model generations into integer class labels.

    Args:
        pred: iterable of strings (raw generated text, one per sample).

    Returns:
        1-D integer ``np.ndarray`` with one entry per input: 0 or 1 when the
        extracted token is one of '0', '0.0', '1', '1.0', and 2 otherwise
        (missing or unrecognised token). Empty input yields an empty integer
        array.
    """
    labels = [
        _LABEL_BY_TOKEN.get(_extract_label_token(text.lower()), _UNPARSEABLE)
        for text in pred
    ]
    return np.array(labels, dtype=int)

In [4]:
# Ground truth: the 'label' column of the test split of beanham/spatial_join_dataset.
ds = load_dataset("beanham/spatial_join_dataset")
test = ds['test']
gt = np.array(test['label'])

# Experiment grid.
models = ['llama3', '4o_mini']
methods = ['zero_shot', 'few_shot']
modes = ['no_heur', 'with_heur_hint', 'with_heur_value']
heuristics = ['angle', 'distance', 'comb']

# 'no_heur' contributes a single config per method (no heuristic suffix);
# every other mode is crossed with each heuristic. The final list is sorted.
configs = sorted(
    [f'{method}_no_heur' for method in methods]
    + [
        '_'.join(combo)
        for combo in product(methods, modes, heuristics)
        if combo[1] != 'no_heur'
    ]
)

In [5]:
# Score every saved prediction file, grouped by model.
for model in models:
    print('======================================')
    print(f'Model: {model}...')
    for config in configs:
        print('----------------------------------')
        # Load the raw generations and map them to labels in one step.
        pred = post_processing(np.load(f'base/{model}/{model}_{config}.npy'))
        print('  ', Counter(pred))
        # Ground truth is cut to the number of predictions that were loaded.
        scores = metric_calculation(pred, gt[:len(pred)])
        print(f'   {config}:', scores)

Model: llama3...
----------------------------------
   Counter({0: 1659, 1: 1263, 2: 147})
   few_shot_no_heur: (0.47214076246334313, 0.1603128054740958, 0.3196480938416422)
----------------------------------
   Counter({0: 2172, 1: 759, 2: 138})
   few_shot_with_heur_hint_angle: (0.46953405017921146, 0.07852720755946563, 0.4069729553600521)
----------------------------------
   Counter({0: 1872, 1: 1079, 2: 118})
   few_shot_with_heur_hint_comb: (0.508634734441186, 0.11534701857282502, 0.33756924079504724)
----------------------------------
   Counter({0: 1839, 1: 1150, 2: 80})
   few_shot_with_heur_hint_distance: (0.5099380905832519, 0.1283805799934832, 0.3356142065819485)
----------------------------------
   Counter({1: 1717, 0: 1155, 2: 197})
   few_shot_with_heur_value_angle: (0.598892147279244, 0.16650374714890845, 0.1704138155751059)
----------------------------------
   Counter({0: 1747, 1: 979, 2: 343})
   few_shot_with_heur_value_comb: (0.5148256761159987, 0.0824372759856630

### 4o

In [8]:
# Reload the test-split labels and restrict the run to the value-heuristic
# configs (angle / comb) for all three models.
ds = load_dataset("beanham/spatial_join_dataset")
test = ds['test']
gt = np.array(test['label'])
models = ['llama3', '4o_mini', '4o']
configs = [
    f'{shot}_with_heur_value_{heuristic}'
    for shot in ('zero_shot', 'few_shot')
    for heuristic in ('angle', 'comb')
]

In [9]:
# Score each selected config for every model, printing label counts and metrics.
for model in models:
    print('======================================')
    print(f'Model: {model}...')
    for config in configs:
        print('----------------------------------')
        # Load the raw generations and map them to labels in one step.
        pred = post_processing(np.load(f'base/{model}/{model}_{config}.npy'))
        print('  ', Counter(pred))
        # Ground truth is cut to the number of predictions that were loaded.
        scores = metric_calculation(pred, gt[:len(pred)])
        print(f'   {config}:', scores)

Model: llama3...
----------------------------------
   Counter({1: 2171, 0: 896, 2: 2})
   zero_shot_with_heur_value_angle: (0.5962854349951124, 0.25513196480938416, 0.1479309221244705)
----------------------------------
   Counter({1: 2031, 0: 1028, 2: 10})
   zero_shot_with_heur_value_comb: (0.5513196480938416, 0.25448028673835127, 0.19094167481264254)
----------------------------------
   Counter({1: 1717, 0: 1155, 2: 197})
   few_shot_with_heur_value_angle: (0.598892147279244, 0.16650374714890845, 0.1704138155751059)
----------------------------------
   Counter({0: 1747, 1: 979, 2: 343})
   few_shot_with_heur_value_comb: (0.5148256761159987, 0.08243727598566308, 0.2909742587161942)
Model: 4o_mini...
----------------------------------
   Counter({1: 2005, 0: 1064})
   zero_shot_with_heur_value_angle: (0.9338546757901597, 0.059302704463994785, 0.006842619745845552)
----------------------------------
   Counter({1: 2189, 0: 880})
   zero_shot_with_heur_value_comb: (0.8836754643206256