## Part 1: Define metrics

In [1]:
from urllib.parse import urlparse, parse_qs

def parse_url(url):
    """
    Split a URL into the two pieces the structural metric compares.

    Returns a dict with the URL's 'path' and its 'params', i.e. the query
    string decoded by ``parse_qs`` (parameter name -> list of values).
    """
    components = urlparse(url)
    return dict(path=components.path, params=parse_qs(components.query))

def is_url_structure_matching(candidate, reference):
    """
    Decide whether two parsed URLs (as produced by ``parse_url``) are equivalent.

    They match when the paths are equal, both carry the same parameter names,
    and every parameter holds the same values irrespective of order.
    """
    if candidate['path'] != reference['path']:
        return False

    cand_params = candidate['params']
    ref_params = reference['params']
    if sorted(cand_params) != sorted(ref_params):
        return False

    # Key sets are identical at this point, so every reference key exists in
    # the candidate; only the (order-insensitive) values remain to be compared.
    return all(
        sorted(cand_params[name]) == sorted(ref_values)
        for name, ref_values in ref_params.items()
    )

def evaluate_get_request_accuracy(generated_url, reference_url):
    """
    Return True when the generated GET request is structurally equivalent to
    the reference one (same path, same query parameters and values).
    """
    return is_url_structure_matching(parse_url(generated_url), parse_url(reference_url))

def score_ast_batched(preds, refs):
    """
    Fraction of predictions that structurally match their reference URL.

    Predictions and references are paired positionally (pairing stops at the
    shorter of the two inputs). An empty batch scores 0.0 instead of raising
    ZeroDivisionError.
    """
    evals = [evaluate_get_request_accuracy(pred, ref) for pred, ref in zip(preds, refs)]
    if not evals:
        return 0.0
    return sum(evals) / len(evals)

In [2]:
import evaluate
# Metric backends loaded through the `evaluate` library.
bert_scorer = evaluate.load("bertscore")
meteor_scorer = evaluate.load("meteor")


def bert_score_fn(preds, refs):
    """BERTScore of `preds` against `refs`, using the CodeBERT base model on the GPU."""
    return bert_scorer.compute(
        predictions=preds,
        references=refs,
        lang="en",
        model_type="microsoft/codebert-base",
        num_layers=12,
        device="cuda",
    )


def meteor_score_fn(preds, refs):
    """METEOR score of `preds` against `refs`."""
    return meteor_scorer.compute(predictions=preds, references=refs)

[nltk_data] Downloading package wordnet to /home/atubati/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/atubati/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/atubati/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [3]:
# Same request with the query parameters in a different order: the structural
# metric should treat the two as identical.
preds = ["/v3/query/?q=symbol:ZFAND4&species=mouse"]
refs = ["/v3/query/?species=mouse&q=symbol:ZFAND4"]

for label, metric in (
    ("AST eval", score_ast_batched),
    ("BERT Score", bert_score_fn),
    ("METEOR", meteor_score_fn),
):
    print(label, metric(preds, refs))

AST eval 1.0
BERT Score {'precision': [0.9818707704544067], 'recall': [0.9818708300590515], 'f1': [0.9818708300590515], 'hashcode': 'microsoft/codebert-base_L12_no-idf_version=0.3.12(hug_trans=4.43.4)'}
METEOR {'meteor': 0.9067055393586005}


## Part 2: Load models

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Text that is interpolated into every prompt: the API docs and the schema description.
with open("gene_query_docs.txt") as docs_file:
    docs = docs_file.read()

with open("data/original/compact_desc_with_context.csv") as description_file:
    description = description_file.read()

In [None]:
import outlines

# Zero-shot prompt. The docstring below is NOT documentation: it is the prompt
# template handed to the `outlines.prompt` decorator, whose `{{ ... }}`
# placeholders are filled from the call arguments, so every character of it is
# runtime text.
#   instruction -- the user task; the template appends ". Write an API call."
#   docs        -- documentation text placed in the system turn
#   description -- schema text placed in the system turn
# The header/eot special tokens presumably follow the chat format of the
# Llama 3.1 model loaded later in this notebook -- TODO confirm.
@outlines.prompt
def default_prompt(instruction, docs, description):
    """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Use the documentation and schema to complete the user-given task.
Docs: {{ docs }}\n Schema: {{ description }}<|eot_id|><|start_header_id|>user<|end_header_id|>
{{ instruction }}. Write an API call.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

# Few-shot (in-context learning) prompt. As with any `outlines.prompt`
# function, the docstring is the prompt template itself, not documentation.
#   instruction -- the user task; the template appends ". Write an API call."
#   examples    -- iterable of items exposing `.instruction` and `.output`;
#                  each is rendered as a "Query: ... / API Call: ..." pair
#                  inside the user turn, after the instruction
#   docs        -- documentation text placed in the system turn
#   description -- schema text placed in the system turn
@outlines.prompt
def few_shot_prompt(instruction, examples, docs, description):
    """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Use the documentation and schema to complete the user-given task.
Docs: {{ docs }}\n Schema: {{ description }}<|eot_id|><|start_header_id|>user<|end_header_id|>
{{ instruction }}. Write an API call.

Examples
--------

{% for example in examples %}
Query: {{ example.instruction }}
API Call: {{ example.output }}

{% endfor %}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

In [None]:
# Render one zero-shot prompt and eyeball its beginning and end.
inst = "Find the UniProt ID for the ENSG00000103187 gene in human. Limit the search to Ensembl gene IDs."
prompt = default_prompt(instruction=inst, docs=docs, description=description)
for tag, excerpt in (("Start", prompt[:250]), ("End", prompt[-150:])):
    print(tag, excerpt)

#### Deprecated: plain `str.format` prompt template (superseded by the `outlines` prompts above)

In [None]:
import inspect
# Deprecated zero-shot template. It is filled with str.format (see the next
# cell), so the placeholders are single-brace {docs}, {description} and
# {instruction}. inspect.cleandoc strips the surrounding blank lines and
# common indentation of the literal.
prompt_template = inspect.cleandoc("""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Use the documentation and schema to complete the user-given task.
Docs: {docs}\n Schema: {description}<|eot_id|><|start_header_id|>user<|end_header_id|>
{instruction}. Write an API call.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
""")

In [None]:
def prompt_gen(inst):
    """Fill the str.format prompt template with the global docs/description and one instruction."""
    return prompt_template.format(docs=docs, description=description, instruction=inst)


prompt = prompt_gen("Find the UniProt ID for the ENSG00000103187 gene in human. Limit the search to Ensembl gene IDs.")
for tag, excerpt in (("Start", prompt[:250]), ("End", prompt[-150:])):
    print(tag, excerpt)

#### Latest: load the model and build the `outlines` generator

In [None]:
import outlines

# Local checkpoint directory shared by the model and its tokenizer.
MODEL_DIR = "models/meta_llama3_1"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.bfloat16,
    output_attentions=True,
)
model = model.to("cuda")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

In [None]:
def evaluate(api_call: str):
    """
    Signature-only stub that is passed to ``outlines.generate.json`` below.

    The body deliberately does nothing. Presumably outlines derives the JSON
    shape to generate (a single string field ``api_call``) from this
    signature -- TODO confirm against the outlines version in use.

    NOTE(review): this definition shadows the ``evaluate`` module imported in
    the metrics cell; after this cell runs, ``evaluate`` no longer refers to
    the module until it is imported again.
    """
    return None

# Wrap the Hugging Face model for outlines under its own name. Leaving the raw
# `model` binding untouched keeps this cell idempotent: re-running it wraps the
# original model again rather than an already-wrapped object.
outlines_model = outlines.models.Transformers(model, tokenizer)
generator = outlines.generate.json(outlines_model, evaluate)
# generator = outlines.generate.regex(outlines_model, r"/v3/.+/.+")

In [None]:
# Smoke test: run the generator on the single example prompt built earlier and
# show the result as the cell output.
sample_api_call = generator([prompt])
sample_api_call

## Part 3: Evaluate

In [4]:
from datasets import load_dataset

# Instruction / API-call pairs; the train split supplies in-context examples,
# the test split is what gets scored.
dataset = load_dataset("moltres23/biothings-query-instruction-pairs")
train_set = dataset["train"]
test_set = dataset["test"]
train_set

Dataset({
    features: ['output', 'instruction'],
    num_rows: 2854
})

In [None]:
import tqdm
import random

random.seed(42)
BATCH_SIZE = 1
N_SHOT = 10  # size of ICL examples
all_responses = []

with torch.no_grad():
    for idx in tqdm.tqdm(range(0, len(test_set), BATCH_SIZE)):
        # A fresh set of ICL examples is drawn for every batch (reproducible
        # through the seed above); all instructions in a batch share that set.
        icl_example_indices = random.sample(range(len(train_set)), N_SHOT)
        # Indexing with a list of indices gives a dict of columns; transpose it
        # into one dict per example.
        icl_columns = train_set[icl_example_indices]
        icl_examples = [dict(zip(icl_columns.keys(), values)) for values in zip(*icl_columns.values())]

        batch = test_set[idx:(idx + BATCH_SIZE)]
        # Build one prompt per instruction. Mapping over the one-element lists
        # [icl_examples], [docs], [description] would stop after the first
        # instruction whenever BATCH_SIZE > 1.
        batched_inputs = [
            few_shot_prompt(instruction, icl_examples, docs, description)
            for instruction in batch["instruction"]
        ]
        batch_responses = generator(batched_inputs)
        # With a single prompt the generator evidently returns one dict (the
        # previous code called .values() on it directly); for larger batches it
        # presumably returns a list of dicts -- TODO confirm for the outlines
        # version in use.
        if isinstance(batch_responses, dict):
            batch_responses = [batch_responses]
        for response in batch_responses:
            all_responses.extend(response.values())

In [17]:
# NOTE: using regex match to get the response in the expected format
import re
def regex_match(string):
    """Return the first "/v3/..." span of `string` (up to the end of its line), or "" if there is none."""
    found = re.search(r"/v3/.+", string)
    if found is None:
        return ""
    return found.group(0)

In [14]:
# Sanity check: the method and scheme/host prefix are dropped, everything from "/v3/" onward is kept.
regex_match("GET https://mygene.info/v3/query?fields=human")

'/v3/query?fields=human'

In [15]:
import pickle
import numpy as np

def run_eval(responses_path):
    """
    Score a pickled list of model responses against the test-set references.

    Every response is reduced to its "/v3/..." API call with ``regex_match``;
    responses without one become "" and are counted as empty matches. The
    first ten empty matches are printed (instruction, reference answer and raw
    response) for inspection.

    ``responses_path`` is the path of a pickle file holding the responses,
    which are scored positionally against ``test_set["output"]``.

    Returns a tuple of
    (AST accuracy, mean BERTScore recall, mean BERTScore recall over non-empty
    matches, METEOR, fraction of empty matches). The third entry is NaN when
    every response is an empty match.
    """
    # NOTE: pickle.load can execute arbitrary code -- only open response files
    # produced by a trusted run of this project.
    with open(responses_path, 'rb') as fd:
        raw_responses = pickle.load(fd)

    # Read each dataset column once instead of on every loop iteration.
    references = test_set["output"]
    instructions = test_set["instruction"]

    all_responses = [regex_match(response) for response in raw_responses]
    empty_idxs = [idx for idx, response in enumerate(all_responses) if response == ""]

    ast_eval = score_ast_batched(all_responses, references)
    bertscore_evals = bert_score_fn(all_responses, references)
    meteor_scores = meteor_score_fn(all_responses, references)

    # printing samples
    shown_idxs = empty_idxs[:10]
    print("empty insts")
    for eidx in shown_idxs:
        print(instructions[eidx])
    print("empty answers")
    for eidx in shown_idxs:
        print(references[eidx])
    print("empty resps")
    for eidx in shown_idxs:
        print(raw_responses[eidx])

    # we include 0s in mean because otherwise merely getting
    # one correct answer will skew the metric
    bertscores = np.array(bertscore_evals["recall"])  # recall because upper bound

    # Mean over the non-empty matches only; report NaN explicitly rather than
    # taking the mean of an empty array (which warns) when nothing matched.
    non_empty_scores = np.delete(bertscores, empty_idxs)
    bert_no_zeros = non_empty_scores.mean() if non_empty_scores.size else float("nan")

    return ast_eval, bertscores.mean(), bert_no_zeros, np.mean(meteor_scores["meteor"]), (len(empty_idxs) / bertscores.shape[0])

In [16]:
file_names = [
    "responses_openai_mini_rag_icl.pkl",
    # "responses_icl.pkl",
    # "responses_train_split.pkl"
]

# Labels in the order of the tuple returned by run_eval.
metric_labels = (
    "AST eval",
    "BERT Score",
    "\nBERT Score, excluding empty matches",
    "METEOR Score",
    "Frac empty matches",
)

for file_name in file_names:
    metric_values = run_eval(file_name)
    print("\n\n", file_name, sep="")
    for label, value in zip(metric_labels, metric_values):
        print(label, value)



empty insts
Retrieve the symbol, name, and summary for the gene CSMD3. Use the API to fetch the data
How can I find the ATOH1 gene in humans by symbol, and what ensembl information is available for this gene?
Retrieve the symbol, Entrez gene ID, and Ensembl gene ID for the human gene BCL7C.
Get the name, symbol, entrezgene, taxid, and genomic_pos for the human gene myocd.
What is the gene information for the KEGG pathway ID hsa04068?
What are the details about the human CD36 gene, including its symbol, aliases, name, and type of gene?
What are the KEGG pathway IDs and names for human genes?
Which genes are associated with pathways in Reactome, KEGG, WikiPathways, and BioCarta?
Retrieve the name, other names, symbol, and type of gene fields for the human gene symbols CCDC185, CAPN8, LOC105373281, CAPN2, LOC105373046, LOC105373041, TP53BP2, GTF2IP20, SEPTIN7P13, and LOC124905682,.
Get the uniprot and refseq fields for the mouse gene ADIPOR1 using the MyGene.info API. Search within the sy

In [9]:
import pickle
# Load one saved response list for manual inspection.
# (pickle files should only be opened when they come from a trusted source.)
with open('responses_openai_rag_icl.pkl', 'rb') as responses_file:
    all_responses = pickle.load(responses_file)

In [None]:
# Count the test examples whose BERTScore recall is exactly zero. The scores
# are recomputed here because `bertscore_evals` only exists inside run_eval(),
# so referencing it directly fails on a fresh kernel.
bertscore_evals = bert_score_fn([regex_match(response) for response in all_responses], test_set["output"])
sum(score == 0.0 for score in bertscore_evals["recall"])

In [10]:
# Peek at the first ten loaded responses.
all_responses[:10]

['/v3/query?q=symbol:(CDK2 OR ABHD15)&fields=uniprot&size=2',
 '/v3/query?q=symbol:plaur&fields=entrezgene,symbol,facets=all',
 '/v3/query/?q=symbol:TSPAN6&fields=symbol,ensembl.gene',
 '/v3/query?q=symbol:MC1R&species=human&size=1&fields=ensembl.gene',
 '/v3/query?q=symbol:LOC123388108*&species=9669&fields=symbol,ensembl.gene',
 '/v3/query?q=entrezgene:287731&species=rat',
 '/v3/query?q=symbol:CDK2&species=human&fields=HGNC,MIM,summary,name,exac,symbol&size=10&facet_size=10&dotfield=true',
 '/v3/query?q=symbol:Cd74&species=mouse&fields=entrezgene',
 '/v3/query?q=MTOR&fields=all&size=10&from=0&fetch_all=false&facet_size=10&entrezonly=false&ensemblonly=false&dotfield=false',
 '/v3/query/?fields=symbol,name,entrezgene&q=zfin:ZDB-GENE-041010-37&species=zebrafish']

In [11]:
# First ten reference API calls from the test split, one per line, for side-by-side comparison.
print(*test_set["output"][:10],sep="\n")

/v3/query?q=CDK2 OR ABHD15&fields=uniprot
/v3/query?q=Plaur&fields=entrezgene,symbol&size=10&from=0&fetch_all=false&facet_size=10&entrezonly=false&ensemblonly=false&dotfield=false
/v3/query/?fields=symbol,ensembl&q=TSPAN6
/v3/query?q=MC1R&species=human&size=1&fields=ensembl.ensembl_id
/v3/query?species=9669&fields=symbol,ensembl.gene&q=symbol:LOC123388108*
/v3/query/?species=rat&q=entrezgene:287731
/v3/query?q=symbol:CDK2&species=9606&size=10&from=0&fetch_all=false&facet_size=10&entrezonly=false&ensemblonly=false&dotfield=true&fields=HGNC,MIM,summary,name,exac,symbol
/v3/query/?species=mouse&scopes=symbol&fields=entrezgene&q=Cd74
/v3/query/?q=MTOR
/v3/query/?fields=symbol,name,entrezgene&q=ZFIN:ZDB-GENE-041010-37


In [None]:
# The first ten test instructions, one per line.
print(*test_set["instruction"][:10],sep="\n")