In [25]:
import numpy as np
import pandas as pd
from itertools import product
from collections import Counter
from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

In [26]:
def metric_calculation(pred, gt):
    """Score binary predictions against ground-truth labels.

    Parameters
    ----------
    pred : array-like
        Predicted labels. Values other than 0/1 are passed unchanged to
        ``accuracy_score`` / ``f1_score``; they are never counted as false
        positives or false negatives.
    gt : array-like
        True labels, where 0 is the negative class and 1 the positive class.

    Returns
    -------
    tuple
        ``(accuracy, macro_f1, fpr, fnr)``.
        ``fpr`` = (actual 0, predicted 1) / (number of actual 0s).
        ``fnr`` = (actual 1, predicted 0) / (number of actual 1s).
        A rate whose denominator is zero is reported as ``0.0``.
    """
    acc = accuracy_score(gt, pred)
    f1 = f1_score(gt, pred, average='macro')

    # Pin the label order: without ``labels`` the matrix shape depends on which
    # classes happen to occur (1x1 if only one class is present), which makes
    # the fixed [0, 1] / [1, 0] lookups below fail.
    confusion = confusion_matrix(gt, pred, labels=[0, 1])

    # Rates are normalised by the size of the relevant true class, not by the
    # total sample count. Counts come from ``gt`` directly so that rows whose
    # prediction is outside {0, 1} still contribute to the denominator.
    gt_arr = np.asarray(gt)
    n_negative = int(np.sum(gt_arr == 0))
    n_positive = int(np.sum(gt_arr == 1))

    fpr = confusion[0, 1] / n_negative if n_negative else 0.0  # predicted 1; actual 0
    fnr = confusion[1, 0] / n_positive if n_positive else 0.0  # predicted 0; actual 1
    return acc, f1, fpr, fnr

In [27]:
def post_processing(pred):
    """Convert raw model output strings into integer class labels.

    Each string is lower-cased and one token is extracted from it:

    * if it contains ``'response'`` (checked first) or ``'output'``, the
      second whitespace-separated token following the first occurrence of
      that marker (the first token is typically the ``':'`` after the marker);
    * otherwise, the first whitespace-separated token of the string.

    Any ``'</s>'`` in the token is removed. Tokens equal to ``'0'``/``'0.0'``
    map to 0, ``'1'``/``'1.0'`` map to 1, and everything else -- including
    strings from which no token can be extracted -- maps to 2.

    Parameters
    ----------
    pred : iterable of str
        Raw generations, one per example.

    Returns
    -------
    numpy.ndarray
        Integer array with values in {0, 1, 2}, one entry per input string.
    """
    label_of = {'0': 0, '0.0': 0, '1': 1, '1.0': 1}
    markers = ('response', 'output')  # order matters: 'response' wins
    unparseable = 2

    labels = []
    for raw in pred:
        text = raw.lower()
        marker = next((m for m in markers if m in text), None)
        try:
            if marker is None:
                token = text.split()[0]
            else:
                token = text.split(marker)[1].split()[1]
        except IndexError:
            # Too few tokens to contain an answer. Only IndexError is expected
            # here; anything else should surface instead of being hidden.
            labels.append(unparseable)
            continue
        labels.append(label_of.get(token.replace('</s>', ''), unparseable))

    # Explicit dtype keeps the result integer-typed even for empty input.
    return np.array(labels, dtype=int)

In [28]:
# Ground-truth labels are read from the test split of the spatial-join dataset.
ds = load_dataset("beanham/spatial_join_dataset")
test = ds['test']
gt = np.array(test['label'])

# Models whose saved predictions are evaluated further down.
models = ['llama3', '4o_mini', 'qwen']

# Prompt configurations, in evaluation order: for each shot setting, the
# no-heuristic baseline followed by every (hint | value) x heuristic variant.
configs = [
    f'{shot}_{suffix}'
    for shot in ('zero_shot', 'few_shot')
    for suffix in (
        'no_heur',
        *(
            f'with_heur_{mode}_{feature}'
            for mode in ('hint', 'value')
            for feature in ('angle', 'distance', 'comb', 'all')
        ),
    )
]

## evaluate on a subset
# Seed the global NumPy RNG so the same 1000 positions (drawn with replacement
# from [0, 3069)) are selected on every run.
np.random.seed(100)
index = np.random.randint(low=0, high=3069, size=1000)

### Heuristics

In [31]:
# Materialise each heuristic column once. Indexing ``test[...]`` inside the
# loops rebuilt the same arrays on every iteration (up to three per iteration
# in the 125-combination sweep).
dist_values = np.array(test['min_euc_dist'])
angle_values = np.array(test['min_angle'])
area_values = np.array(test['max_area'])
test_gt = np.array(test['label'])

# Threshold grids shared by the sweeps below.
angles = [1, 2, 5, 10, 20]
distances = [1, 2, 3, 4, 5]
max_area = [.1, .2, .3, .4, .5]

## distance
for i in [1, 2, 3, 4, 5]:
    pred = dist_values >= i
    metrics = metric_calculation(pred, gt)
    print(i, metrics[0], metrics[1])

## area
for i in [0.1, 0.2, 0.3, 0.4, 0.5]:
    pred = area_values >= i
    metrics = metric_calculation(pred, gt)
    print(i, metrics[0], metrics[1])

## combination
# Predict positive when the distance is at least ``d`` AND the angle is at most ``a``.
combs = list(product(angles, distances))
accuracy = []
for a, d in combs:
    pred = (dist_values >= d) & (angle_values <= a)
    accuracy.append([a, d, accuracy_score(test_gt, pred)])
accuracy_comb = pd.DataFrame(accuracy, columns=['angle', 'distance', 'acc'])

## all
# Same rule as above, additionally requiring the area to be at least ``m``.
combs = list(product(angles, distances, max_area))
accuracy = []
for a, d, m in combs:
    pred = (dist_values >= d) & (angle_values <= a) & (area_values >= m)
    accuracy.append([a, d, m, accuracy_score(test_gt, pred)])
accuracy_all = pd.DataFrame(accuracy, columns=['angle', 'distance', 'area', 'acc'])

0.1 0.655 0.5076991570953394
0.2 0.701 0.6018636509502675
0.3 0.735 0.6673357611546085
0.4 0.753 0.7164566494359554
0.5 0.654 0.6494441765180283


### LLMs

In [9]:
# Score every saved prediction file for each model/config pair on one fixed
# set of rows, and collect accuracy and macro-F1 into ``results``.
# (``models`` and ``configs`` come from the setup cell.)
n_subset = len(index)

# ``gt`` may already be restricted to the evaluation subset; if it is longer,
# select the same rows that ``index`` picks from the predictions so that both
# arrays always have equal length.
# NOTE(review): assumes subset-sized prediction files were generated on the
# rows selected by ``index`` -- confirm against the inference script.
gt_eval = gt if len(gt) == n_subset else gt[index]

rows = []
for model in models:
    print(f'Model: {model}...')
    for config in configs:
        pred = np.load(f'base/{model}/{model}_{config}.npy')
        pred = post_processing(pred)
        # Files holding exactly one prediction per subset row are used as-is;
        # any other length is subsampled with ``index``.
        pred_eval = pred if len(pred) == n_subset else pred[index]
        metrics = metric_calculation(pred_eval, gt_eval)
        rows.append([config, model, metrics[0], metrics[1]])
results = pd.DataFrame(rows, columns=['config', 'model', 'acc', 'f1'])

Model: llama3...
Model: 4o_mini...
Model: qwen...


In [12]:
# Show only the rows scored for the qwen model.
results.loc[results['model'] == 'qwen']

Unnamed: 0,config,model,acc,f1
36,zero_shot_no_heur,qwen,0.603,0.39211
37,zero_shot_with_heur_hint_angle,qwen,0.645,0.496562
38,zero_shot_with_heur_hint_distance,qwen,0.609,0.405572
39,zero_shot_with_heur_hint_comb,qwen,0.676,0.564495
40,zero_shot_with_heur_hint_all,qwen,0.696,0.607906
41,zero_shot_with_heur_value_angle,qwen,0.946,0.942522
42,zero_shot_with_heur_value_distance,qwen,0.732,0.65684
43,zero_shot_with_heur_value_comb,qwen,0.95,0.946998
44,zero_shot_with_heur_value_all,qwen,0.948,0.944533
45,few_shot_no_heur,qwen,0.633,0.565963
