In [1]:
import string
import numpy as np
import pandas as pd
from itertools import product
from collections import Counter
from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

In [2]:
def metric_calculation(pred, gt):
    """Score predicted labels against ground-truth binary labels.

    Args:
        pred: Sequence of predicted integer labels. Besides 0 and 1 it may
            contain other values (``post_processing`` emits 2 for outputs it
            cannot parse).
        gt: Sequence of ground-truth labels, same length as ``pred``.

    Returns:
        Tuple ``(acc, f1, fpr, fnr)``:
            acc: accuracy.
            f1:  macro-averaged F1 score.
            fpr: count of (actual 0, predicted 1) divided by ``len(gt)``.
            fnr: count of (actual 1, predicted 0) divided by ``len(gt)``.

    Note:
        ``fpr`` and ``fnr`` are fractions of *all* samples, not the
        conventional rates normalised by the number of actual negatives /
        positives. The denominators are kept as-is so previously reported
        numbers stay comparable.
    """
    acc = accuracy_score(gt, pred)
    f1 = f1_score(gt, pred, average='macro')
    # Pin the label order: without `labels`, the matrix is built from whatever
    # labels happen to occur, so a missing class (or an extra label such as 2)
    # can shrink or shift the matrix and make [0, 1] / [1, 0] point at the
    # wrong cells or raise IndexError.
    confusion = confusion_matrix(gt, pred, labels=[0, 1])
    fpr = confusion[0, 1] / len(gt)  # predicted 1, actual 0
    fnr = confusion[1, 0] / len(gt)  # predicted 0, actual 1
    return acc, f1, fpr, fnr

In [3]:
# Keywords that may introduce the final answer, checked in this priority order.
_ANSWER_KEYWORDS = ('response', 'output', 'return', 'result', 'plaintext', 'json')
# Raw tokens accepted as a binary label; anything else becomes _INVALID_LABEL.
_VALID_TOKENS = ('0', '0.0', '1', '1.0')
_INVALID_LABEL = 2


def _first_token(text):
    """Return the first whitespace-separated token of ``text``, or None if there is none."""
    tokens = text.split()
    return tokens[0] if tokens else None


def _extract_answer_token(text):
    """Pull the candidate answer token out of one free-form model output.

    Returns a string token, or None when no candidate can be found.
    """
    first = _first_token(text)
    if first in ('0', '1'):
        # The output already starts with a bare label.
        return first

    # Only the last three non-empty lines are searched for the answer.
    cleaned = text.lower().replace('</s>', '').replace('boxed', '')
    lines = [line for line in cleaned.split('\n') if line != '']
    tail = ' '.join(lines[-3:]).translate(str.maketrans('', '', string.punctuation))

    for keyword in _ANSWER_KEYWORDS:
        if keyword in tail:
            # Take the first numeric token after the LAST occurrence of the
            # keyword. If there is none, the output is treated as unparseable
            # (lower-priority keywords are deliberately not tried).
            after = tail.split(keyword)[-1].split()
            return next((tok for tok in after if tok.isnumeric()), None)

    # No keyword at all: fall back to the first token of the cleaned tail.
    return _first_token(tail)


def _to_label(token):
    """Map a raw token to 0 or 1, or to _INVALID_LABEL if it is not a valid label."""
    return int(float(token)) if token in _VALID_TOKENS else _INVALID_LABEL


def post_processing(pred, model):
    """Convert raw model outputs into integer labels.

    Args:
        pred: Iterable of raw output strings, one per example.
        model: Model name. ``'mistral'`` outputs are parsed by taking the first
            token after stripping ``'</s>'``; every other model goes through
            the keyword-based extraction of free-form (chain-of-thought) text.

    Returns:
        Integer ``np.ndarray`` with one entry per output: 0 or 1 when a label
        could be parsed, 2 otherwise. Empty or whitespace-only outputs are
        mapped to 2 instead of raising ``IndexError``.
    """
    if model == 'mistral':
        tokens = [_first_token(p.replace('</s>', '')) for p in pred]
    else:
        tokens = [_extract_answer_token(p) for p in pred]
    return np.array([_to_label(tok) for tok in tokens], dtype=int)

In [4]:
# Load the spatial-join dataset and keep a handle on its test split.
ds = load_dataset("beanham/spatial_join_dataset")
test = ds['test']

# Ground-truth labels of the test split as an array.
gt = np.array(test['label'])

# Prompt configurations whose saved predictions are evaluated.
configs = [
    "few_shot_no_heur_cot",
    "few_shot_with_heur_hint_all_cot",
    "few_shot_with_heur_value_all_cot",
]

In [9]:
# Evaluate every (model, config) pair from its saved raw predictions and
# collect one row per pair: config, model, accuracy (3 d.p.), macro F1.
results = []
models = ['4o_mini', 'qwen_plus']
for model in models:
    print(f'Model: {model}...')
    for config in configs:
        pred_path = f'base/{model}_cot/{model}_{config}.npy'
        pred = post_processing(np.load(pred_path), model)
        metrics = metric_calculation(pred, gt)
        acc, f1 = metrics[0], metrics[1]
        results.append([config, model, round(acc, 3), f1])

# Turn the accumulated rows into a table.
results = pd.DataFrame(results, columns=['config', 'model', 'acc', 'f1'])

Model: 4o_mini...
Model: qwen_plus...


In [26]:
# Reload one model/config run to inspect outputs that could not be parsed.
model = '4o_mini'
config = 'few_shot_with_heur_hint_all_cot'
pred = np.load(f'base/{model}_cot/{model}_{config}.npy')
proc_pred = post_processing(pred, model)

# Positions whose processed label is 2.
index = np.flatnonzero(proc_pred == 2)

In [30]:
# Show the raw output of the 7th example (position 6 in `index`) whose processed label is 2.
print(pred[index[6]])

To determine whether the sidewalk runs alongside the road based on the provided conditions, we will evaluate the following steps:

1. **Extract Coordinates**: 
   - Sidewalk coordinates: 
     ```
     [[-122.10419000000002, 47.55392949999999], 
      [-122.1041552, 47.553988100000005], 
      [-122.1040971, 47.554099499999985], 
      [-122.1040119, 47.5542421], 
      [-122.1039511, 47.5543572], 
      [-122.10393710000001, 47.554414699999995], 
      [-122.1039292, 47.5544353], 
      [-122.10391910000001, 47.55444569999999], 
      [-122.103895, 47.554466499999975], 
      [-122.10387349999999, 47.5544882], 
      [-122.10386210000001, 47.5545036], 
      [-122.10385670000001, 47.55451719999999], 
      [-122.10385610000002, 47.55454609999999], 
      [-122.10386080000002, 47.55457099999999], 
      [-122.1038695, 47.55459230000001], 
      [-122.1038782, 47.5546077], 
      [-122.1039043, 47.55462709999999], 
      [-122.10393649999999, 47.554646999999996], 
      [-122.1039681, 4