## Part 1: Define metrics

In [1]:
from urllib.parse import urlparse, parse_qs

def parse_url(url):
    """
    Split a URL into the pieces used for structural comparison.

    Args:
        url: URL (or bare path-plus-query string) to parse.

    Returns:
        dict with two keys:
          - 'path': the path component as returned by ``urlparse``.
          - 'params': mapping of each query-parameter name to the list of
            its values, in order of appearance.
    """
    parsed_url = urlparse(url)
    # keep_blank_values=True: by default parse_qs silently drops parameters
    # whose value is empty (e.g. "fields="), which would make a URL carrying
    # such a parameter look identical to one that omits it entirely.
    query_params = parse_qs(parsed_url.query, keep_blank_values=True)
    return {
        'path': parsed_url.path,
        'params': query_params
    }

def is_url_structure_matching(candidate, reference):
    """
    Decide whether two parsed URLs are structurally equivalent.

    Both arguments are dicts with a 'path' entry and a 'params' entry
    (parameter name -> list of values). They match when the paths are equal,
    both carry the same set of parameter names, and every parameter has the
    same values. Neither the order of the parameters nor the order of the
    values within one parameter matters.

    Returns:
        True if the two structures match, False otherwise.
    """
    if candidate['path'] != reference['path']:
        return False

    got, want = candidate['params'], reference['params']

    # Same parameter names on both sides, irrespective of order.
    if sorted(got) != sorted(want):
        return False

    # The name sets are equal at this point, so every name in `want` is
    # guaranteed to be present in `got`.
    return all(sorted(got[name]) == sorted(want[name]) for name in want)

def evaluate_get_request_accuracy(generated_url, reference_url):
    """
    Check a generated GET request against its reference.

    Both URLs are parsed with ``parse_url`` and the parsed structures are
    compared with ``is_url_structure_matching``.

    Args:
        generated_url: URL string produced by the model.
        reference_url: ground-truth URL string.

    Returns:
        The result of ``is_url_structure_matching`` for the two parsed URLs.
    """
    return is_url_structure_matching(
        parse_url(generated_url),
        parse_url(reference_url),
    )

def score_ast_batched(preds, refs):
    """
    Fraction of prediction/reference pairs whose URLs match structurally.

    Args:
        preds: iterable of generated URL strings.
        refs: iterable of reference URL strings, aligned positionally with
            ``preds``.

    Returns:
        Float in [0, 1]: the share of pairs for which
        ``evaluate_get_request_accuracy`` is true. Returns 0.0 when there
        are no pairs to score.

    Note:
        Pairs are formed by ``map`` over both iterables, so when the lengths
        differ the surplus items of the longer one are ignored.
    """
    evals = tuple(map(evaluate_get_request_accuracy, preds, refs))
    if not evals:
        # Nothing to score: avoid a ZeroDivisionError on an empty batch.
        return 0.0
    return sum(evals) / len(evals)

In [2]:
import evaluate
# Load the BERTScore metric once; the scorer object is reused by every call.
bert_scorer = evaluate.load("bertscore")


def bert_score_fn(preds, refs):
    """
    Compute BERTScore for predictions against references.

    Uses the `microsoft/codebert-base` encoder (layer 12) and runs on the
    "cuda" device, so a GPU is required.

    Returns:
        The dict produced by ``bert_scorer.compute`` (per-pair 'precision',
        'recall' and 'f1' lists plus a 'hashcode' string).
    """
    return bert_scorer.compute(
        predictions=preds,
        references=refs,
        lang="en",
        model_type="microsoft/codebert-base",
        num_layers=12,
        device="cuda",
    )

In [3]:
# Sanity check for both metrics: the prediction and the reference carry the
# same two query parameters, only in a different order.
preds = ["/v3/query/?q=symbol:ZFAND4&species=mouse"]
refs = ["/v3/query/?species=mouse&q=symbol:ZFAND4"]

# The structural metric compares sorted parameter names/values, so parameter
# order does not affect it; BERTScore is reported alongside for comparison.
print("AST eval", score_ast_batched(preds, refs))
print("BERT Score", bert_score_fn(preds, refs))

AST eval 1.0
BERT Score {'precision': [0.9818707704544067], 'recall': [0.9818708300590515], 'f1': [0.9818708300590515], 'hashcode': 'microsoft/codebert-base_L12_no-idf_version=0.3.12(hug_trans=4.43.4)'}


## Part 2: Load models

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [5]:
# Both files are read as raw text and later substituted verbatim into the
# prompt (as "Docs" and "Schema"); the CSV is deliberately not parsed.
# The encoding is given explicitly because the docs contain non-ASCII
# punctuation (e.g. a typographic apostrophe) and the platform default
# encoding is not guaranteed to be UTF-8.
with open("gene_query_docs.txt", "r", encoding="utf-8") as doc_fd:
    docs = doc_fd.read()

with open("data/original/compact_desc_with_context.csv", "r", encoding="utf-8") as desc_fd:
    description = desc_fd.read()

In [6]:
import inspect
# Chat-style prompt with three placeholders: {docs}, {description} and
# {instruction}. It consists of a system turn (task statement, documentation
# and schema), a user turn (the instruction) and an opening assistant header
# for the model to continue from. Lines are joined with single newlines and
# there is no trailing newline after the assistant header.
# NOTE(review): the single-newline separators after the header tokens and
# the leading space before "Schema:" are kept exactly as they were - confirm
# they are what the model's chat format expects.
prompt_template = "\n".join([
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>",
    "Use the documentation and schema to complete the user-given task.",
    "Docs: {docs}",
    " Schema: {description}<|eot_id|><|start_header_id|>user<|end_header_id|>",
    "Write an API call to do the following - {instruction}",
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
])

In [7]:
def prompt_gen(inst):
    """Fill the prompt template with the loaded docs, the schema and `inst`."""
    return prompt_template.format(docs=docs, description=description, instruction=inst)


prompt = prompt_gen("Find the UniProt ID for the ENSG00000103187 gene in human. Limit the search to Ensembl gene IDs.")
# The full prompt embeds the entire documentation and schema text, so only
# its head and tail are printed as a check of the template boundaries.
print("Start", prompt[:250])
print("End", prompt[-150:])

Start <|begin_of_text|><|start_header_id|>system<|end_header_id|>
Use the documentation and schema to complete the user-given task.
Docs: Gene query service

This page describes the reference for MyGene.info gene query web service. It’s also recommended to
End nd the UniProt ID for the ENSG00000103187 gene in human. Limit the search to Ensembl gene IDs.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>


In [8]:
import outlines

# Load the causal LM from the local checkpoint directory in bfloat16 and
# place it on the GPU.
# NOTE(review): output_attentions=True is requested here, but no visible
# cell reads attention outputs - confirm it is needed.
model = AutoModelForCausalLM.from_pretrained(
    "models/meta_llama3_1",
    torch_dtype=torch.bfloat16,
    output_attentions=True,
)
model = model.to("cuda")
model.eval()  # switch the module to evaluation mode
# The tokenizer is loaded from the same checkpoint directory as the model.
tokenizer = AutoTokenizer.from_pretrained("models/meta_llama3_1")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [9]:
def evaluate(api_call: str):
    """
    Signature-only stub that describes the structured output to generate.

    The function does no work and always returns None. It is passed to
    ``outlines.generate.json``, which presumably derives the output schema
    from the annotated parameter list (the sample generation is a dict with
    a single 'api_call' string).

    NOTE(review): this definition rebinds the module-level name `evaluate`,
    which the metrics section binds to the `evaluate` package through
    `import evaluate`. After this cell runs, `evaluate` refers to this
    function, not the package.
    """

# Wrap the Hugging Face model/tokenizer pair for outlines. The wrapper is
# stored back into `model`, so guard against wrapping an already-wrapped
# model when this cell is executed more than once in the same kernel.
if not isinstance(model, outlines.models.Transformers):
    model = outlines.models.Transformers(model, tokenizer)
# JSON-constrained generator whose output structure is described by the
# `evaluate` stub (a single `api_call` string field).
generator = outlines.generate.json(model, evaluate)

In [10]:
# Smoke test: run the constrained generator on the single example prompt.
# For a one-element prompt list the result is displayed as one dict with an
# 'api_call' key rather than a list.
sample_api_call = generator([prompt])
sample_api_call

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


{'api_call': 'GET /v3/query?q=ensembl.Gene:ENSG00000103187\\\\[symbol"]=\\'}

## Part 3: Evaluate

In [11]:
from datasets import load_dataset

# Test split of the instruction/query pairs; each row has an 'instruction'
# (model input) and an 'output' column.
dataset = load_dataset("moltres23/biothings-query-instruction-pairs", split="test")
dataset

Dataset({
    features: ['output', 'instruction'],
    num_rows: 712
})

In [12]:
import tqdm

BATCH_SIZE = 1
all_responses = []
# `prompt_gen` is reused from the cell that previews the prompt; it is not
# redefined here.
with torch.no_grad():
    for idx in tqdm.tqdm(range(0, len(dataset), BATCH_SIZE)):
        # Slicing the dataset yields a dict of columns for the batch rows.
        batch = dataset[idx:(idx + BATCH_SIZE)]
        batched_inputs = list(map(prompt_gen, batch["instruction"]))
        batch_responses = generator(batched_inputs)
        # For a single-prompt batch the generator returns one bare dict
        # ({'api_call': ...}); for larger batches it is expected to return a
        # list of such dicts. Normalise to a list so the loop works for any
        # BATCH_SIZE, including a final batch that holds a single prompt.
        if isinstance(batch_responses, dict):
            batch_responses = [batch_responses]
        for response in batch_responses:
            # Collect the generated field values (the api_call strings).
            all_responses.extend(response.values())

 12%|██████▉                                                    | 84/712 [13:48<1:43:13,  9.86s/it]


KeyboardInterrupt: 

In [19]:
import pickle

# Persist the responses collected so far; the list may be partial if the
# generation loop was interrupted before finishing.
with open('responses.pkl', 'wb') as out_file:
    pickle.dump(all_responses, out_file)