In [21]:
from nltk.translate.bleu_score import corpus_bleu
import argparse
import json
import logging
import nltk
import os
import pandas as pd
from evaluation_utils import failed_generation_index, eval_dataset, get_nested_values, load_dataset, safe_loc, compute_precision, compute_recall, corpus_meteor, average_precision

In [10]:
arguments = {
    "dataset": "../outputs/queries_with_execution_and_limits_new_2.parquet.gzip",
    "preprocess_gold": "../outputs/preprocess_gold_queries_with_execution_and_limits_new_2.json",
    "model": "Mistral-7B-Instruct-v0.2",
    "output": ".",
    "save_name": "test",
    "log_level": "warning",
    "log_file": "",
}

In [11]:
args = argparse.Namespace()
args.__dict__.update(arguments)
print(args)

numeric_log_level = getattr(logging, args.log_level.upper(), None)
if not isinstance(numeric_log_level, int):
    raise ValueError(f"Invalid log level: {args.log_level}.")
logging.basicConfig(filename=args.log_file if args.log_file else None, level=numeric_log_level)

if not os.path.exists(args.dataset):
    raise FileNotFoundError(f"The dataset file not found with path: {args.dataset}")

if args.preprocess_gold != None and not os.path.exists(args.preprocess_gold):
    raise FileNotFoundError(f"The preprocess gold dataset file not found with path: {args.preprocess_gold}")

nltk.download('wordnet', quiet=True)



True

In [13]:
df = load_dataset(args.dataset)
df_no_gen_fail = df # df.drop(failed_generation_index(df))
df_exec_timeout = df_no_gen_fail.loc[df_no_gen_fail['execution'] == 'timeout']
df_exec_fail = df_no_gen_fail.loc[df_no_gen_fail['execution'].str.startswith('exception')]
df_exec_empty = df_no_gen_fail.loc[df_no_gen_fail['execution'].isnull()]
df_exec_to_eval = df_no_gen_fail.drop(df_exec_timeout.index).drop(df_exec_fail.index).drop(df_exec_empty.index)
df_eval = eval_dataset(df_exec_to_eval)

In [15]:
# df_gold = None
# df_gold_exec_timeout = None
# df_gold_exec_fail = None
# df_gold_exec_empty = None
# df_gold_exec_to_eval = None
# df_gold_eval = None
# if args.gold != None:
#     df_gold = load_dataset(args.gold)
#     df_gold_exec_timeout = df_gold.loc[df_gold['execution'] == 'timeout']
#     df_gold_exec_fail = df_gold.loc[df_gold['execution'].str.startswith('exception')]
#     df_gold_exec_empty = df_gold.loc[df_gold['execution'].isnull()]
#     df_gold_exec_to_eval = df_gold.drop(df_gold_exec_timeout.index).drop(df_gold_exec_fail.index).drop(df_gold_exec_empty.index)
#     df_gold_eval = eval_dataset(df_gold_exec_to_eval, "gold_eval")
# else:
with open(args.preprocess_gold, "r") as f:
    data = json.load(f)
df_gold = pd.read_json(data['df_gold'])
df_gold_exec_timeout = pd.read_json(data['df_gold_exec_timeout'])
df_gold_exec_fail = pd.read_json(data['df_gold_exec_fail'])
df_gold_exec_empty = pd.read_json(data['df_gold_exec_empty'])
df_gold_exec_to_eval = pd.read_json(data['df_gold_exec_to_eval'])
df_gold_eval = pd.read_json(data['df_gold_eval'])

In [31]:
df_merged_eval = df_eval.copy()
df_merged_eval["gold_eval"] = df_merged_eval.apply(lambda x: safe_loc(x, df_gold_eval, "gold_eval", default=None), axis=1)
df_merged_eval["precision"] = df_merged_eval.apply(lambda x: compute_precision(get_nested_values(x['eval']), get_nested_values(x['gold_eval'])), axis=1)
df_merged_eval["recall"] = df_merged_eval.apply(lambda x: compute_recall(get_nested_values(x['eval']), get_nested_values(x['gold_eval'])), axis=1)

In [30]:
safe_loc(df_merged_eval.iloc[0], df_gold_eval, "gold_eval", default=None)

gold_eval    [{'property': {'type': 'uri', 'value': 'http:/...
Name: 0, dtype: object

In [19]:
df_gold_eval

Unnamed: 0,query,description,context,prompt,num_tokens,execution,executed_query,gold_eval
0,SELECT ?property ?propertyType ?propertyLabel ...,Wikidata properties in numerical order,Counting stuff on Wikidata\nAll Wikidata prope...,<s>[INST] <<SYS>>This is a conversation betwee...,267,"[{'property': {'type': 'uri', 'value': 'http:/...",SELECT ?property ?propertyType ?propertyLabel ...,"[{'property': {'type': 'uri', 'value': 'http:/..."
1,SELECT ?id ?idLabel ?idDescription ?new{\n?id ...,Wikidata properties excluding external IDs,Counting stuff on Wikidata\nVariation of the a...,<s>[INST] <<SYS>>This is a conversation betwee...,296,"[{'id': {'type': 'uri', 'value': 'http://www.w...",SELECT ?id ?idLabel ?idDescription ?new{\n?id ...,"[{'id': {'type': 'uri', 'value': 'http://www.w..."
3,SELECT (COUNT(DISTINCT ?article) AS ?count)\nW...,Count of fictional characters,Counting stuff on Wikidata\nCount of fictional...,<s>[INST] <<SYS>>This is a conversation betwee...,203,[{'count': {'datatype': 'http://www.w3.org/200...,SELECT (COUNT(DISTINCT ?article) AS ?count)\nW...,[{'count': {'datatype': 'http://www.w3.org/200...
4,SELECT (COUNT(?item) AS ?count)\nWHERE { ?item...,Count of items with coordinate locations,Counting stuff on Wikidata\nCount of items wit...,<s>[INST] <<SYS>>This is a conversation betwee...,186,[{'count': {'datatype': 'http://www.w3.org/200...,SELECT (COUNT(?item) AS ?count)\nWHERE { ?item...,[{'count': {'datatype': 'http://www.w3.org/200...
5,SELECT ?q ?langcode ?langLabel (COUNT(?label) ...,"Number of labels for an item, split by language",Counting stuff on Wikidata\nCount of labels fo...,<s>[INST] <<SYS>>This is a conversation betwee...,302,"[{'q': {'type': 'uri', 'value': 'http://www.wi...",SELECT ?q ?langcode ?langLabel (COUNT(?label) ...,"[{'q': {'type': 'uri', 'value': 'http://www.wi..."
...,...,...,...,...,...,...,...,...
2840,select ?feat ?en (count(*) as ?total)\nwith {\...,Grammatical features used on English nouns and...,,<s>[INST] <<SYS>>This is a conversation betwee...,309,"[{'feat': {'type': 'uri', 'value': 'http://www...",select ?feat ?en (count(*) as ?total)\nwith {\...,"[{'feat': {'type': 'uri', 'value': 'http://www..."
2841,SELECT ?country ?countryLabel ?lifeExpectancy\...,List of countries sorted by life expectancy,,<s>[INST] <<SYS>>This is a conversation betwee...,250,"[{'country': {'type': 'uri', 'value': 'http://...",SELECT ?country ?countryLabel ?lifeExpectancy\...,"[{'country': {'type': 'uri', 'value': 'http://..."
2842,SELECT ?item ?itemLabel ?coord ?image\nWHERE\n...,Railway bridges in Germany and Austria,,<s>[INST] <<SYS>>This is a conversation betwee...,271,"[{'item': {'type': 'uri', 'value': 'http://www...",SELECT ?item ?itemLabel ?coord ?image\nWHERE\n...,"[{'item': {'type': 'uri', 'value': 'http://www..."
2843,SELECT DISTINCT ?train_wreck ?train_wreckLabel...,Number of deaths in train wrecks per country,,<s>[INST] <<SYS>>This is a conversation betwee...,297,"[{'train_wreck': {'type': 'uri', 'value': 'htt...",SELECT DISTINCT ?train_wreck ?train_wreckLabel...,"[{'train_wreck': {'type': 'uri', 'value': 'htt..."


In [32]:
df_merged_eval

Unnamed: 0_level_0,query,description,context,prompt,num_tokens,execution,executed_query,eval,gold_eval,precision,recall
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
0,SELECT ?property ?propertyType ?propertyLabel ...,Wikidata properties in numerical order,Counting stuff on Wikidata\nAll Wikidata prope...,<s>[INST] <<SYS>>This is a conversation betwee...,267,"[{'property': {'type': 'uri', 'value': 'http:/...",SELECT ?property ?propertyType ?propertyLabel ...,"[{'property': {'type': 'uri', 'value': 'http:/...","[{'property': {'type': 'uri', 'value': 'http:/...",1.0,1.0
1,SELECT ?id ?idLabel ?idDescription ?new{\n?id ...,Wikidata properties excluding external IDs,Counting stuff on Wikidata\nVariation of the a...,<s>[INST] <<SYS>>This is a conversation betwee...,296,"[{'id': {'type': 'uri', 'value': 'http://www.w...",SELECT ?id ?idLabel ?idDescription ?new{\n?id ...,"[{'id': {'type': 'uri', 'value': 'http://www.w...","[{'id': {'type': 'uri', 'value': 'http://www.w...",1.0,1.0
3,SELECT (COUNT(DISTINCT ?article) AS ?count)\nW...,Count of fictional characters,Counting stuff on Wikidata\nCount of fictional...,<s>[INST] <<SYS>>This is a conversation betwee...,203,[{'count': {'datatype': 'http://www.w3.org/200...,SELECT (COUNT(DISTINCT ?article) AS ?count)\nW...,[{'count': {'datatype': 'http://www.w3.org/200...,[{'count': {'datatype': 'http://www.w3.org/200...,1.0,1.0
4,SELECT (COUNT(?item) AS ?count)\nWHERE { ?item...,Count of items with coordinate locations,Counting stuff on Wikidata\nCount of items wit...,<s>[INST] <<SYS>>This is a conversation betwee...,186,[{'count': {'datatype': 'http://www.w3.org/200...,SELECT (COUNT(?item) AS ?count)\nWHERE { ?item...,[{'count': {'datatype': 'http://www.w3.org/200...,[{'count': {'datatype': 'http://www.w3.org/200...,1.0,1.0
5,SELECT ?q ?langcode ?langLabel (COUNT(?label) ...,"Number of labels for an item, split by language",Counting stuff on Wikidata\nCount of labels fo...,<s>[INST] <<SYS>>This is a conversation betwee...,302,"[{'q': {'type': 'uri', 'value': 'http://www.wi...",SELECT ?q ?langcode ?langLabel (COUNT(?label) ...,"[{'q': {'type': 'uri', 'value': 'http://www.wi...","[{'q': {'type': 'uri', 'value': 'http://www.wi...",1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...
2840,select ?feat ?en (count(*) as ?total)\nwith {\...,Grammatical features used on English nouns and...,,<s>[INST] <<SYS>>This is a conversation betwee...,309,"[{'feat': {'type': 'uri', 'value': 'http://www...",select ?feat ?en (count(*) as ?total)\nwith {\...,"[{'feat': {'type': 'uri', 'value': 'http://www...","[{'feat': {'type': 'uri', 'value': 'http://www...",1.0,1.0
2841,SELECT ?country ?countryLabel ?lifeExpectancy\...,List of countries sorted by life expectancy,,<s>[INST] <<SYS>>This is a conversation betwee...,250,"[{'country': {'type': 'uri', 'value': 'http://...",SELECT ?country ?countryLabel ?lifeExpectancy\...,"[{'country': {'type': 'uri', 'value': 'http://...","[{'country': {'type': 'uri', 'value': 'http://...",1.0,1.0
2842,SELECT ?item ?itemLabel ?coord ?image\nWHERE\n...,Railway bridges in Germany and Austria,,<s>[INST] <<SYS>>This is a conversation betwee...,271,"[{'item': {'type': 'uri', 'value': 'http://www...",SELECT ?item ?itemLabel ?coord ?image\nWHERE\n...,"[{'item': {'type': 'uri', 'value': 'http://www...","[{'item': {'type': 'uri', 'value': 'http://www...",1.0,1.0
2843,SELECT DISTINCT ?train_wreck ?train_wreckLabel...,Number of deaths in train wrecks per country,,<s>[INST] <<SYS>>This is a conversation betwee...,297,"[{'train_wreck': {'type': 'uri', 'value': 'htt...",SELECT DISTINCT ?train_wreck ?train_wreckLabel...,"[{'train_wreck': {'type': 'uri', 'value': 'htt...","[{'train_wreck': {'type': 'uri', 'value': 'htt...",1.0,1.0


In [70]:
y = df_merged_eval.iloc[0]
average_precision(get_nested_values(y['eval']), get_nested_values(y['gold_eval']), max_k=1)

1.0

In [65]:
len(get_nested_values(y['eval']))

40

In [67]:
df_merged_eval.apply(lambda x: get_nested_values(x['eval']), axis=1).map(len).max()

608697

In [71]:
df_merged_eval["average_precision"] = df_merged_eval.apply(lambda x: average_precision(get_nested_values(x['eval']), get_nested_values(x['gold_eval']), max_k=100000), axis=1)

In [17]:
m_precision = df_merged_eval['precision'].mean()
m_recall = df_merged_eval['recall'].mean()
m_fscore = 2*m_precision*m_recall/(m_precision+m_recall)

bleu_score = corpus_bleu([[x.split()] for x in df_no_gen_fail['target']], [x.split() for x in df_no_gen_fail['translated_prompt']])
meteor_score = corpus_meteor(df_no_gen_fail['target'], df_no_gen_fail['translated_prompt'])


KeyError: 'target'

In [None]:
serie = pd.Series(data=
    {
        "model_name": args.model,
        "num_rows": len(df),
        "num_gen_fail": len(df.loc[df['has_error'] == True]),
        "num_exec_timeout": len(df_exec_timeout),
        "num_exec_fail": len(df_exec_fail),
        "num_exec_empty": len(df_exec_empty),
        "num_exec_to_eval": len(df_exec_to_eval),
        "num_eval": len(df_eval),
        "num_eval_empty": len(df_eval.loc[df_eval['eval'].map(len) == 0]),
        "bleu_score": bleu_score,
        "meteor_score": meteor_score,
        "precision": m_precision,
        "recall": m_recall,
        "f1score": m_fscore
    })