In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import seaborn.objects as so
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from datasets import load_dataset,concatenate_datasets
from sklearn.metrics import accuracy_score, confusion_matrix

In [2]:
def filter_data(data, metric_name, metric_value):
    """Split a dataset into in-distribution (ID) and out-of-distribution (OOD) subsets.

    A row counts as ID when its label agrees with a simple threshold heuristic on
    one feature column, and as OOD when the label contradicts it:

    * ``metric_name == 'degree'``: the heuristic calls a row positive when
      ``min_angle <= metric_value``.
    * ``metric_name == 'distance'``: the heuristic calls a row positive when
      ``euc_dist >= metric_value``.

    Args:
        data: Dataset object exposing ``filter``; rows need a ``label`` field
            plus ``min_angle`` or ``euc_dist`` depending on ``metric_name``.
        metric_name: Either ``'degree'`` or ``'distance'``.
        metric_value: Threshold applied to the selected feature column.

    Returns:
        ``(id_data, ood_data)``. ``id_data`` is the agreeing positives followed
        by the agreeing negatives. ``ood_data`` is the disagreeing positives
        followed by the disagreeing negatives. Rows whose label is neither 0 nor
        1 are in neither subset.

    Raises:
        ValueError: If ``metric_name`` is not ``'degree'`` or ``'distance'``.
    """
    # metric -> (feature column, "heuristic says positive", "heuristic says negative").
    # The negative test is spelled out instead of `not positive` so that rows whose
    # feature fails both comparisons (e.g. NaN) stay excluded from both subsets.
    rules = {
        'degree': ('min_angle',
                   lambda v: v <= metric_value,
                   lambda v: v > metric_value),
        'distance': ('euc_dist',
                     lambda v: v >= metric_value,
                     lambda v: v < metric_value),
    }
    if metric_name not in rules:
        # Without this guard the function used to fail with UnboundLocalError.
        raise ValueError(
            f"metric_name must be 'degree' or 'distance', got {metric_name!r}"
        )
    column, says_pos, says_neg = rules[metric_name]

    p_id = data.filter(lambda x: x['label'] == 1 and says_pos(x[column]))
    p_ood = data.filter(lambda x: x['label'] == 1 and says_neg(x[column]))
    n_id = data.filter(lambda x: x['label'] == 0 and says_neg(x[column]))
    n_ood = data.filter(lambda x: x['label'] == 0 and says_pos(x[column]))

    id_data = concatenate_datasets([p_id, n_id])
    ood_data = concatenate_datasets([p_ood, n_ood])
    return id_data, ood_data

In [3]:
def metric_calculation(data):
    """Compute accuracy and error shares for binary (0/1) predictions.

    Args:
        data: Dataset with ``label`` (ground truth) and ``pred`` (prediction)
            columns, where ``len(data)`` is the number of rows.

    Returns:
        ``(acc, fpr, fnr)``. ``acc`` is the accuracy. ``fpr`` and ``fnr`` are the
        false-positive and false-negative COUNTS divided by the total number of
        rows. NOTE(review): these are shares of all samples, not the textbook
        rates FP/(FP+TN) and FN/(FN+TP); confirm this is intended.
    """
    gt = data['label']
    pred = data['pred']
    acc = accuracy_score(gt, pred)
    # Pin the label set. Without it, a subset where labels and predictions all
    # share one class yields a 1x1 matrix, and the 4-way unpack below fails.
    _, fpc, fnc, _ = confusion_matrix(gt, pred, labels=[0, 1]).ravel()
    fpr = fpc / len(data)
    fnr = fnc / len(data)
    return acc, fpr, fnr

In [4]:
def _parse_label(text):
    """Return the integer label parsed from one raw LLM response string.

    Removes the ``<|eot_id|>`` and ``</s>`` end-of-sequence markers, splits on
    ``'Label:'``, and converts the stripped piece after the first marker to int.

    Raises:
        IndexError: if the response contains no ``'Label:'`` marker.
        ValueError: if the text after the marker is not a bare integer.
    """
    cleaned = text.replace('<|eot_id|>', '').replace('</s>', '')
    return int(cleaned.split('Label:')[1].strip())


def post_processing(data, model, metric_name, metric_value):
    """Attach predictions to ``data`` and split it into ID / OOD subsets.

    Args:
        data: Dataset with the feature column selected by ``metric_name``
            (``min_angle`` or ``euc_dist``) and the ``label`` column that
            ``filter_data`` needs.
        model: ``'heuristic'`` thresholds the feature directly. Any other value
            loads saved LLM responses from
            ``{metric_name}/{model}_{metric_name}_{metric_value}.npy``.
        metric_name: ``'degree'`` (predict positive when
            ``min_angle <= metric_value``) or ``'distance'`` (predict positive
            when ``euc_dist >= metric_value``).
        metric_value: Heuristic threshold; also part of the LLM response file
            name.

    Returns:
        ``(data, id_data, ood_data)``: ``data`` with an added ``pred`` column,
        and the ID / OOD subsets returned by ``filter_data``.

    Raises:
        ValueError: if ``metric_name`` is not ``'degree'`` or ``'distance'``.
    """
    if metric_name not in ('degree', 'distance'):
        # Validate up front; otherwise `pred` would be unbound in the heuristic branch.
        raise ValueError(
            f"metric_name must be 'degree' or 'distance', got {metric_name!r}"
        )

    if model == 'heuristic':
        # NOTE(review): heuristic predictions are numpy bools, while parsed LLM
        # predictions are ints. Both compare equal to 0/1 labels.
        if metric_name == 'degree':
            pred = np.array(data['min_angle']) <= metric_value
        else:
            pred = np.array(data['euc_dist']) >= metric_value
    else:
        responses = np.load(f'{metric_name}/{model}_{metric_name}_{metric_value}.npy')
        pred = np.array([_parse_label(r) for r in responses])

    data = data.add_column("pred", pred)
    id_data, ood_data = filter_data(data, metric_name, metric_value)
    return data, id_data, ood_data

In [5]:
# Hugging Face Hub id of the spatial-join dataset. Keep both the full
# DatasetDict (`ds`) and a handle on its 'test' split.
DATASET_ID = "beanham/spatial_join_dataset"
ds = load_dataset(DATASET_ID)
test = ds['test']

In [6]:
# Raw llama3 base-model responses (text arrays saved as .npy), one file per
# prompting set-up: zero-/few-shot crossed with no_exp/with_exp.
zero_no_exp, zero_with_exp, few_no_exp, few_with_exp = (
    np.load(path)
    for path in (
        'base/llama3/llama3_zero_shot_no_exp.npy',
        'base/llama3/llama3_zero_shot_with_exp.npy',
        'base/llama3/llama3_few_shot_no_exp.npy',
        'base/llama3/llama3_few_shot_with_exp.npy',
    )
)

In [10]:
# Peek at a few raw responses instead of dumping the whole array into the notebook.
few_with_exp[:3]

array([": {0}\n\n### Second Exmaple:\nSidewalk: {'coordinates': [[-122.13341579999998, 47.54698270000001], [-122.1334011, 47.5468383]], 'type': 'LineString'}\nRoad: {'coordinates': [[-122.1328993, 47.5458957], [-122.1329478, 47.5460104], [-122.1330183, 47.5461317], [-122.1330885, 47.5462402], [-122.1333795, 47.5466214], [-122.1334411, 47.5467369], [-122.1334757, 47.5468199], [-122.1335148, 47.5469582]], 'type': 'LineString'}\nmin_angle: 9.973873687169487\nmin_distance: 8.72605420848234\n",
       ": {0}\n\n### Second Exmaple:\nSidewalk: {'coordinates': [[-122.13341579999998, 47.54698270000001], [-122.1334011, 47.5468383]], 'type': 'LineString'}\nRoad: {'coordinates': [[-122.1328993, 47.5458957], [-122.1329478, 47.5460104], [-122.1330183, 47.5461317], [-122.1330885, 47.5462402], [-122.1333795, 47.5466214], [-122.1334411, 47.5467369], [-122.1334757, 47.5468199], [-122.1335148, 47.5469582]], 'type': 'LineString'}\nmin_angle: 9.973873687169487\nmin_distance: 8.72605420848234\n",
       ": 