In [1]:
import gc
import torch
import argparse
import numpy as np

from utils import *
from tqdm import tqdm
from os import path, makedirs
from datasets import load_dataset
#from unsloth import FastLanguageModel
from huggingface_hub import login as hf_login

## evaluation function
def evaluate(model, tokenizer, data, max_new_tokens=10, device=None):
    """Generate a short completion for every prompt in ``data['text']``.

    Args:
        model: generation model exposing ``generate(**inputs, max_new_tokens=...)``
            (presumably a Hugging Face causal LM).
        tokenizer: tokenizer paired with ``model``; called with ``return_tensors="pt"``
            and used to ``decode`` the generated ids.
        data: mapping/dataset whose ``'text'`` column holds fully formatted prompts.
        max_new_tokens: generation budget per prompt (default 10, the old fixed value).
        device: where to place the tokenized inputs. ``None`` (default) uses
            ``model.device`` when the model exposes it, else ``"cuda"`` (the old value).

    Returns:
        list[str]: one string per prompt. It holds the decoded text following the last
        ``'Response'`` marker that belongs to the prompt itself. If the prompt has no such
        marker, it holds just the newly generated tokens. ``decode`` is called with default
        arguments, so special tokens are not stripped here.
    """
    if device is None:
        # Put inputs where the model lives rather than assuming a single CUDA device.
        device = getattr(model, 'device', 'cuda')
    outputs = []
    for text in tqdm(data['text']):
        inputs = tokenizer(text, return_tensors="pt").to(device)
        generated = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Few-shot prompts may contain 'Response' labels for their worked examples
        # before the query's own header. Skip every marker present in the prompt
        # instead of always splitting on the first one.
        n_markers = text.count('Response')
        if n_markers:
            response = tokenizer.decode(generated[0]).split('Response', n_markers)[-1]
        else:
            # No marker to anchor on: decode only the tokens generated after the
            # prompt instead of failing with an IndexError.
            prompt_len = inputs['input_ids'].shape[-1]
            response = tokenizer.decode(generated[0][prompt_len:])
        outputs.append(response)
    return outputs

In [2]:
from transformers import AutoTokenizer

# Tokenizer for the Llama-3.1-8B-Instruct checkpoint; used further down to count prompt tokens.
# NOTE(review): meta-llama repos are gated on the Hugging Face Hub, but this cell runs before
# `hf_login()` in the next cell. On a fresh kernel it presumably relies on a cached token (confirm).
tokenizer_checkpoint = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [3]:
#-------------------
# Parameters
#-------------------
parser = argparse.ArgumentParser()
parser.add_argument('--model_id', type=str, default='llama3')
parser.add_argument('--dataset', type=str, default='beanham/spatial_join_dataset')
parser.add_argument('--max_seq_length', type=int, default=2048)
parser.add_argument('--device', type=str, default='auto')
# Parse an empty argv: inside Jupyter, sys.argv typically holds the kernel launcher's own
# flags, which this parser would reject.
args = parser.parse_args(args=[])
args.save_path = f'inference_results/base/{args.model_id}/'
# exist_ok avoids the check-then-create race and keeps the cell idempotent on re-run.
makedirs(args.save_path, exist_ok=True)
data = load_dataset(args.dataset)
# NOTE(review): the few-shot cell later sets method='few_shot', which is not in this list.
# formatting_prompts_func treats every value other than 'zero_shot' (including 'one_shot')
# as the two-example few-shot template. Confirm which label is intended.
methods = ['zero_shot', 'one_shot']
modes = ['no_exp', 'with_exp']
hf_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [20]:
def formatting_prompts_func(example, method=None, mode=None):
    """Render one (sidewalk, road) record as an Alpaca-style prompt.

    Args:
        example: dataset row with 'sidewalk' and 'road' fields. When ``mode`` is not
            'no_exp', it must also have 'min_angle' and 'euc_dist'.
        method: 'zero_shot' selects ``zero_shot_alpaca_prompt``. Any other value selects
            the two-example ``few_shot_alpaca_prompt``. Defaults to the notebook-global
            ``method``, so ``Dataset.map(formatting_prompts_func)`` keeps working. Pass it
            explicitly (e.g. ``map(..., fn_kwargs={...})``) to avoid hidden state.
        mode: 'no_exp' includes only the two geometries. Any other value also appends the
            min_angle / min_distance statistics. Defaults to the notebook-global ``mode``.

    Returns:
        dict: ``{"text": prompt}`` with an empty response slot, ready for ``Dataset.map``.
    """
    if method is None:
        method = globals()['method']
    if mode is None:
        mode = globals()['mode']

    output = ""
    input_text = "Sidewalk: " + str(example['sidewalk']) + "\nRoad: " + str(example['road'])
    if mode == 'no_exp':
        instruction = instruction_no_exp
    else:
        # The 'euc_dist' column is shown under the name "min_distance", which the
        # with-explanation instruction text refers to.
        input_text += ("\nmin_angle: " + str(example['min_angle'])
                       + "\nmin_distance: " + str(example['euc_dist']))
        instruction = instruction_with_exp

    if method == 'zero_shot':
        text = zero_shot_alpaca_prompt.format(instruction, input_text, output)
    elif mode == 'no_exp':
        text = few_shot_alpaca_prompt.format(
            instruction, example_one_no_exp, example_two_no_exp, input_text, output)
    else:
        text = few_shot_alpaca_prompt.format(
            instruction, example_one_with_exp, example_two_with_exp, input_text, output)
    return {"text": text}

In [24]:
# formatting_prompts_func reads these two globals: use the two-example template with the
# angle/distance statistics appended, then format every test row.
method, mode = 'few_shot', 'with_exp'
test = data['test'].map(formatting_prompts_func)

Map:   0%|          | 0/3069 [00:00<?, ? examples/s]

In [29]:
# Fetch one row rather than the whole 'text' column; print() so the newlines render.
sample_prompt = test[10]['text']
print(sample_prompt)

### Instruction:
You are a helpful geospatial analysis assistant. I will provide you with a pair of (sidewalk, road) GeoJSON, along with two key statistics:

1. min_angle: The minimum angle (in degrees) between the sidewalk and the road.
2. min_distance: The minimum distance between the sidewalk and the road.

Your task is to determine whether the sidewalk runs alongside the road by evaluating the following conditions:

1. Adjacency: The sidewalk and road should be in close proximity, meaning they are near each other but do not overlap or intersect. The min_distance value helps quantify this proximity.
2. Parallelism: The sidewalk should be approximately parallel to the road, with only a small angle difference between their orientations. The min_angle value provides a measure of this alignment.

If both conditions are satisfied, return 1. Otherwise, return 0.

### First Exmaple:
Sidewalk: {'coordinates': [[-122.15646960000001, 47.58741259999999], [-122.1562564, 47.58744089999999]], 'ty

In [31]:
# Token length of the same sample prompt, presumably to check it fits args.max_seq_length.
sample_token_ids = tokenizer(test[10]['text'])['input_ids']
len(sample_token_ids)

687