In [1]:
import string
import numpy as np
import pandas as pd
from itertools import product
from collections import Counter
from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

In [2]:
def metric_calculation(pred, gt):
    """Score predicted labels against ground-truth binary labels.

    Parameters
    ----------
    pred : array-like
        Predicted labels, one per sample. Values other than 0/1 are allowed
        (elsewhere in this notebook `post_processing` emits 2 for answers it
        cannot parse); they count as errors for accuracy and are ignored by
        the false-positive / false-negative counts below.
    gt : array-like
        Ground-truth labels, same length as `pred`; expected to be 0/1.

    Returns
    -------
    tuple
        `(acc, f1, fpr, fnr)`:
        * `acc` - accuracy.
        * `f1`  - macro-averaged F1.
        * `fpr` - share of ALL samples with gt == 0 and pred == 1.
        * `fnr` - share of ALL samples with gt == 1 and pred == 0.

    Notes
    -----
    `fpr` / `fnr` are normalised by `len(gt)`, not by the number of actual
    negatives / positives, so they are not the conventional false-positive /
    false-negative rates. The normalisation is kept as-is for compatibility.

    NOTE(review): `f1_score` is called without `labels`, so an extra label
    value in `pred` (such as the 2 sentinel) presumably counts as its own class
    in the macro average - confirm that this penalty is intended.
    """
    acc = accuracy_score(gt, pred)
    f1 = f1_score(gt, pred, average='macro')
    # Pin the label order. Without `labels`, the matrix is built from whatever
    # values occur in gt/pred, so index [0, 1] only means "actual 0, predicted 1"
    # when both 0 and 1 happen to be present.
    confusion = confusion_matrix(gt, pred, labels=[0, 1])
    n_samples = len(gt)
    fpr = confusion[0, 1] / n_samples  # predicted 1, actual 0
    fnr = confusion[1, 0] / n_samples  # predicted 0, actual 1
    return acc, f1, fpr, fnr

In [3]:
# Label assigned to any generation from which no 0/1 answer can be recovered.
_INVALID_LABEL = 2
# Tokens accepted as a valid binary answer.
_VALID_TOKENS = ('0', '0.0', '1', '1.0')
# Marker words searched for, in priority order; only the first one found is used.
_ANSWER_KEYWORDS = ('response', 'output', 'return', 'result', 'plaintext', 'json')
_STRIP_PUNCTUATION = str.maketrans('', '', string.punctuation)


def _first_token(text):
    """Return the first whitespace-delimited token of `text`, or the invalid label if there is none."""
    tokens = text.split()
    return tokens[0] if tokens else _INVALID_LABEL


def _extract_answer(raw):
    """Pull the answer token out of one free-form (non-mistral) generation."""
    leading = _first_token(raw)
    if leading in ('0', '1'):
        return leading

    # Otherwise look only at the last three non-empty lines, lower-cased, with
    # the end-of-sequence marker, the word "boxed" and all punctuation removed.
    text = raw.lower().replace('</s>', '').replace('boxed', '')
    lines = [line for line in text.split('\n') if line != '']
    text = ' '.join(lines[-3:]).translate(_STRIP_PUNCTUATION)

    for keyword in _ANSWER_KEYWORDS:
        if keyword in text:
            # First numeric token after the LAST occurrence of the keyword.
            # There is deliberately no fall-through to later keywords.
            numeric = [tok for tok in text.split(keyword)[-1].split() if tok.isnumeric()]
            return numeric[0] if numeric else _INVALID_LABEL

    return _first_token(text)


def _to_label(token):
    """Map an extracted token to 0/1, or to the invalid label for anything else."""
    return int(float(token)) if token in _VALID_TOKENS else _INVALID_LABEL


def post_processing(pred, model):
    """Convert raw model generations into integer labels.

    Parameters
    ----------
    pred : iterable of str
        Raw text generations, one per sample.
    model : str
        Model identifier. `'mistral'` uses the simple rule "first token after
        removing `</s>`"; every other value uses the keyword-based extraction.

    Returns
    -------
    numpy.ndarray
        Integer array with one entry per generation: 0 or 1 when an answer was
        recovered, 2 otherwise. Empty or whitespace-only generations map to 2
        instead of raising IndexError.
    """
    if model == 'mistral':
        tokens = [_first_token(p.replace('</s>', '')) for p in pred]
    else:
        tokens = [_extract_answer(p) for p in pred]
    # dtype=int keeps the result an integer array even when `pred` is empty.
    return np.array([_to_label(tok) for tok in tokens], dtype=int)

In [4]:
# Fetch the benchmark via the `datasets` library (presumably from the Hugging Face
# Hub, so the first run needs network access or a local cache - confirm).
ds = load_dataset("beanham/spatial_join_dataset")
# Keep the held-out test split.
test=ds['test']
# Ground-truth labels as a NumPy array; the evaluation loop below passes this to metric_calculation.
gt=np.array(test['label'])
# Every configuration name has the form "<shot>_with_heur_value_<heuristics>".
# The full grid is generated with all zero-shot settings first, then all few-shot
# settings, each in the heuristic order listed below.
shot_settings = ['zero_shot', 'few_shot']
heuristic_sets = [
    'angle',
    'distance',
    'area',
    'angle_distance',
    'angle_area',
    'distance_area',
    'all',
]
configs = [
    f'{shot}_with_heur_value_{heuristics}'
    for shot, heuristics in product(shot_settings, heuristic_sets)
]

In [5]:
# Score every (model, config) pair: load the saved `_cot` predictions, normalise
# them to integer labels, and record accuracy (rounded to 3 decimals) and macro F1.
models = ['mistral', '4o_mini', 'qwen_plus', '4o']
records = []
for model in models:
    print(f'Model: {model}...')
    for config in configs:
        prediction_file = f'base/{model}/{model}_{config}_cot.npy'
        pred = post_processing(np.load(prediction_file), model)
        metrics = metric_calculation(pred, gt)
        accuracy, macro_f1 = metrics[0], metrics[1]
        records.append({
            'config': config,
            'model': model,
            'acc': round(accuracy, 3),
            'f1': macro_f1,
        })
# One row per (model, config), in loop order.
results = pd.DataFrame(records, columns=['config', 'model', 'acc', 'f1'])

Model: mistral...
Model: 4o_mini...
Model: qwen_plus...
Model: 4o...


In [None]:
# NOTE(review): scratch cell with no execution count. It defines neither `model`
# nor `config`, so it only works with whatever values the evaluation loop left
# behind, and it overwrites `pred` with the raw (not post-processed) array.
# It fails with NameError on a fresh kernel unless the loop cell has run first.
pred=np.load(f'base/{model}/{model}_{config}_cot.npy')

In [6]:
# Display only the rows scored for the qwen_plus model.
results.loc[results['model'].eq('qwen_plus')]

Unnamed: 0,config,model,acc,f1
28,zero_shot_with_heur_value_angle,qwen_plus,0.921,0.616536
29,zero_shot_with_heur_value_distance,qwen_plus,0.723,0.437245
30,zero_shot_with_heur_value_area,qwen_plus,0.605,0.272416
31,zero_shot_with_heur_value_angle_distance,qwen_plus,0.927,0.62244
32,zero_shot_with_heur_value_angle_area,qwen_plus,0.887,0.607156
33,zero_shot_with_heur_value_distance_area,qwen_plus,0.725,0.451601
34,zero_shot_with_heur_value_all,qwen_plus,0.789,0.558434
35,few_shot_with_heur_value_angle,qwen_plus,0.935,0.620474
36,few_shot_with_heur_value_distance,qwen_plus,0.816,0.525895
37,few_shot_with_heur_value_area,qwen_plus,0.716,0.456935
