In [1]:
import numpy as np
import pandas as pd
import scienceplots
import seaborn as sns
import seaborn.objects as so
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from datasets import load_dataset,concatenate_datasets
from sklearn.metrics import accuracy_score, confusion_matrix

In [2]:
def filter_data(data, metric_name, metric_value):
    """Split ``data`` into in-distribution (ID) and out-of-distribution (OOD) rows.

    A row is ID when the threshold rule on the chosen metric agrees with its
    label, and OOD when the rule contradicts it:

    * ``'degree'``   -- the rule predicts positive when ``min_angle <= metric_value``.
    * ``'distance'`` -- the rule predicts positive when ``euc_dist >= metric_value``.

    Rows whose ``label`` is neither 0 nor 1 are placed in neither subset.

    Args:
        data: Object exposing ``filter(callable)`` whose rows are mappings with
            a ``'label'`` key and the metric column (``'min_angle'`` or
            ``'euc_dist'``).
        metric_name: ``'degree'`` or ``'distance'``.
        metric_value: Threshold compared against the metric column.

    Returns:
        ``(id_data, ood_data)``; each is the concatenation of the matching
        positive rows followed by the matching negative rows.

    Raises:
        ValueError: If ``metric_name`` is not one of the supported metrics.
    """
    # Both directions of each comparison are spelled out (instead of negating
    # one) so that values for which neither comparison holds, such as NaN,
    # stay out of both subsets.
    if metric_name == 'degree':
        column = 'min_angle'
        predicts_positive = lambda value: value <= metric_value
        predicts_negative = lambda value: value > metric_value
    elif metric_name == 'distance':
        column = 'euc_dist'
        predicts_positive = lambda value: value >= metric_value
        predicts_negative = lambda value: value < metric_value
    else:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            f"Unsupported metric_name {metric_name!r}; expected 'degree' or 'distance'."
        )

    def select(label, rule):
        """Rows carrying ``label`` for which ``rule`` holds on the metric column."""
        return data.filter(lambda x: (x['label'] == label) and rule(x[column]))

    p_id = select(1, predicts_positive)
    p_ood = select(1, predicts_negative)
    n_id = select(0, predicts_negative)
    n_ood = select(0, predicts_positive)

    id_data = concatenate_datasets([p_id, n_id])
    ood_data = concatenate_datasets([p_ood, n_ood])
    return id_data, ood_data

In [3]:
def metric_calculation(data):
    """Compute accuracy and the false-positive / false-negative fractions.

    Note that ``fpr`` and ``fnr`` are normalised by the *total* number of
    samples, not by the number of actual negatives / positives. They are the
    share of all rows that are false positives / false negatives, so
    ``acc + fpr + fnr == 1`` whenever every prediction is 0 or 1.

    The counts are taken directly from the label arrays rather than by indexing
    a confusion matrix, whose shape depends on which classes happen to occur
    (a single-class subset would make ``confusion[0, 1]`` raise ``IndexError``).

    Args:
        data: Mapping-like object with equally long ``'label'`` and ``'pred'``
            sequences. Predictions other than 0/1 count as errors for accuracy
            but as neither false positives nor false negatives.

    Returns:
        Tuple ``(acc, fpr, fnr)`` of Python floats.

    Raises:
        ValueError: If the two sequences differ in shape or are empty.
    """
    gt = np.asarray(data['label'])
    pred = np.asarray(data['pred'])
    if gt.shape != pred.shape:
        raise ValueError(
            f"'label' and 'pred' must have the same shape, got {gt.shape} and {pred.shape}."
        )
    total = gt.size
    if total == 0:
        raise ValueError("Cannot compute metrics on an empty dataset.")

    acc = float(np.mean(gt == pred))
    fpr = float(np.sum((gt == 0) & (pred == 1)) / total)  ## predict to be 1; actual 0
    fnr = float(np.sum((gt == 1) & (pred == 0)) / total)  ## predict to be 0; actual 1
    return acc, fpr, fnr

In [4]:
def _parse_generated_label(text, fallback=2):
    """Extract the integer following ``'Label:'`` from one generated string.

    The end-of-sequence markers ``'<|eot_id|>'`` and ``'</s>'`` are stripped
    first. ``fallback`` is returned when ``text`` is not a string, contains no
    ``'Label:'`` marker, or what follows the marker is not an integer.
    """
    try:
        cleaned = text.replace('<|eot_id|>', '').replace('</s>', '')
        return int(cleaned.split('Label:')[1].strip())
    except (AttributeError, IndexError, TypeError, ValueError):
        # Only parsing failures are absorbed; KeyboardInterrupt and friends
        # propagate (the previous bare ``except`` swallowed them too).
        return fallback


def post_processing(data, model, metric_name, metric_value):
    """Attach a ``'pred'`` column for ``model`` and split into ID / OOD subsets.

    Predictions come from one of three sources:

    * ``model == 'heuristic'``: the threshold rule itself
      (``min_angle <= metric_value`` for ``'degree'``,
      ``euc_dist >= metric_value`` for ``'distance'``).
    * ``model == 'bert'``: the array stored in
      ``{metric_name}/{model}_{metric_name}_{metric_value}_weak.npy``, used as is.
    * any other model: the same file is read as generated strings, each parsed
      for the integer after ``'Label:'``; unparseable outputs become ``2``.

    Args:
        data: Dataset exposing column access by name and ``add_column``.
        model: Model identifier (see above).
        metric_name: ``'degree'`` or ``'distance'``.
        metric_value: Threshold; also part of the prediction file name.

    Returns:
        ``(data, id_data, ood_data)`` where ``data`` carries the new ``'pred'``
        column and the two subsets come from ``filter_data``.

    Raises:
        ValueError: If ``metric_name`` is not supported (previously surfaced as
            an ``UnboundLocalError`` further down).
    """
    if metric_name not in ('degree', 'distance'):
        raise ValueError(
            f"Unsupported metric_name {metric_name!r}; expected 'degree' or 'distance'."
        )

    ## load predictions
    if model == 'heuristic':
        if metric_name == 'degree':
            pred = np.array(data['min_angle']) <= metric_value
        else:
            pred = np.array(data['euc_dist']) >= metric_value
    else:
        saved = np.load(f'{metric_name}/{model}_{metric_name}_{metric_value}_weak.npy')
        if model == 'bert':
            pred = saved
        else:
            pred = [_parse_generated_label(text) for text in saved]

    data = data.add_column("pred", np.array(pred))
    id_data, ood_data = filter_data(data, metric_name, metric_value)

    return data, id_data, ood_data

In [12]:
# Evaluate each model on the test split using the angle ('degree') criterion.
ds = load_dataset("beanham/spatial_join_dataset")
test = ds['test']

metric_name = 'degree'
models = ['heuristic', 'llama3']
metric_values = [1]

result_columns = ['model', 'metric', 'metric_value', 'acc', 'fpr', 'fnr']
results = []
for model in models:
    for metric_value in metric_values:
        data, id_data, ood_data = post_processing(test, model, metric_name, metric_value)
        acc, fpr, fnr = metric_calculation(data)
        # One record per (model, threshold) pair, keyed by output column.
        results.append(
            dict(zip(result_columns, (model, metric_name, metric_value, acc, fpr, fnr)))
        )

results = pd.DataFrame(results, columns=result_columns)
results

Unnamed: 0,model,metric,metric_value,acc,fpr,fnr
0,heuristic,degree,1,0.857283,0.030629,0.112089
1,llama3,degree,1,0.825676,0.037797,0.134572


# Distance

In [13]:
# Same evaluation as for the angle criterion, now thresholding on Euclidean
# distance. Relies on `test` having been loaded by an earlier cell.
metric_name = 'distance'
models = ['heuristic', 'llama3']
metric_values = [1]

result_columns = ['model', 'metric', 'metric_value', 'acc', 'fpr', 'fnr']
results = []
for model in models:
    for metric_value in metric_values:
        data, id_data, ood_data = post_processing(test, model, metric_name, metric_value)
        acc, fpr, fnr = metric_calculation(data)
        # One record per (model, threshold) pair, keyed by output column.
        results.append(
            dict(zip(result_columns, (model, metric_name, metric_value, acc, fpr, fnr)))
        )

results = pd.DataFrame(results, columns=result_columns)
results

Unnamed: 0,model,metric,metric_value,acc,fpr,fnr
0,heuristic,distance,1,0.846204,0.152493,0.001303
1,llama3,distance,1,0.854676,0.117628,0.025741
