In [1]:
import string
import re
import os
import glob
import json
from tqdm import trange
import pprint
from collections import defaultdict, OrderedDict

# for evaluation
from collections import Counter 
from rouge import Rouge

In [2]:
# answer nomalization
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

In [3]:
# prompt preparation
def flex_format_batch(template, history_org, eval_mode):

    print(f'template: {template}\n\
          history_org: {history_org}\n\
            eval_mode: {eval_mode}')

    used_history = {}
    cnt_used_key = 0

    for k, v in history_org.items():
        if k in template:
            cnt_used_key += 1
            if k in ['prompt', 'response', 'wiki_id_title']:
                # transpose
                v = [list(x) for x in zip(*v)]
            used_history[k] = v

    used_history_mod = [{k: v[i] for k, v in used_history.items()} for i in range(len(history_org['question']))]
    
    print(f'used_history: {used_history}\n\
          used_history_mod: {used_history_mod}')

    if eval_mode == 'eval':
        out = []
        for i in range(len(history_org['question'])):
            for k, v in used_history_mod[i].items():
                if k == 'question':
                    question = v
                elif k == 'prompt':
                    prompt = v
                elif k == 'response':
                    response = v
                elif k == 'wiki_id_title':
                    wiki_id_title = v
                else:
                    raise NotImplementedError
            out.append(eval(template))
    elif eval_mode == 'f-strings':
        out = [template.format(**(used_history_mod[i])) for i in range(len(history_org['question']))]
    else:
        raise ValueError
    
    print(f'out: {out}')
    
    return out

In [4]:
# load data from filename
def load_data(filename):
    data = []
    with open(filename, "r") as fin:
        lines = fin.readlines()
        for line in lines:
            data.append(json.loads(line))
    return data

In [5]:
# validate input
def validate_input(gold_records, guess_records):

    if len(gold_records) != len(guess_records):
        print(
            "WARNING: DIFFERENT SIZE gold: {} guess: {}".format(
                len(gold_records), len(guess_records)
            )
        )

    # align order
    gold_ids = []
    for gold in gold_records:
        assert str(gold["id"]).strip() not in gold_ids, "Gold IDs should be unique"
        gold_ids.append(str(gold["id"]).strip())

    id2guess_record = {}
    for guess in guess_records:
        assert (
            str(guess["id"]).strip() not in id2guess_record
        ), "Prediction IDs should be unique"
        id2guess_record[str(guess["id"]).strip()] = guess

    print(f'id2guess_record: {id2guess_record}\n\
          gold_ids: {gold_ids}')
    
    guess_records = []
    for id in gold_ids:
        if id in id2guess_record:
            guess_records.append(id2guess_record[id])
        else:
            raise ValueError("ERROR: no prediction provided for id: {}".format(id))
        
    print(f'guess_records: {guess_records}')

    return gold_records, guess_records

In [6]:
# utility to get gold answers
def get_gold_answers(gold):
    ground_truths = set()
    for item in gold["output"]:
        if "answer" in item and item["answer"] and len(item["answer"].strip()) > 0:
            ground_truths.add(item["answer"].strip())

    print(f'gold: {gold}\n\
          ground_truths: {ground_truths}')
    
    return ground_truths

In [7]:
# utility to get max
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)

    print(f'metric_fn: {metric_fn}\n\
          prediction: {prediction}\n\
          ground_truths: {ground_truths}\n\
          scores_for_ground_truths: {scores_for_ground_truths}')
    
    return max(scores_for_ground_truths)

In [8]:
def get_ids_list(datapoint, rank_keys, verbose=False):
    # collect all gold ids
    ids_list = []
    for output in datapoint["output"]:
        current_ids_list = []
        if "provenance" in output:
            for provenance in output["provenance"]:
                if any(rank_key not in provenance for rank_key in rank_keys):
                    missing = set(rank_keys) - set(
                        list(provenance.keys())
                    ).intersection(set(rank_keys))
                    if verbose:
                        print(
                            f"WARNING: missing key(s) {missing} in provenance, unable to compute retrieval for those."
                        )
                else:
                    current_ids_list.append(
                        "+".join(
                            [
                                str(provenance[rank_key]).strip()
                                for rank_key in rank_keys
                            ]
                        )
                    )
        ids_list.append(list(set(current_ids_list)))  # remove duplicates

    # consider only unique ids
    return ids_list

In [9]:
def computeRprec(guess_ids, gold_ids):

    R = len(gold_ids)
    num = 0

    for prediction in guess_ids[:R]:
        if str(prediction).strip() in gold_ids:
            num += 1

    Rprec = num / R if R > 0 else 0
    return Rprec

In [10]:
# R-precision https://link.springer.com/referenceworkentry/10.1007%2F978-0-387-39940-9_486
def rprecision(guess_item, gold_item, rank_keys):
    gold_ids_list = get_ids_list(gold_item, rank_keys)
    guess_ids = get_ids_list(guess_item, rank_keys)[0]
    Rprec_vector = []
    for gold_ids in gold_ids_list:
        Rprec = computeRprec(guess_ids, gold_ids)
        Rprec_vector.append(Rprec)

    print(f'gold_ids_list: {gold_ids_list}\n\
          guess_ids: {guess_ids}\n\
          Rprec_vector: {Rprec_vector}')
    
    return max(Rprec_vector)

In [11]:
# utility to get gold titles
def get_gold_titles(gold):
    titles = set()
    for item in gold["output"]:
        if "provenance" in item:
            for provenance in item["provenance"]:
                if (
                    "title" in provenance
                    and provenance["title"]
                    and len(provenance["title"].strip()) > 0
                ):
                    titles.add(provenance["title"].strip())
    return titles

In [12]:
# 1. Precision computation
def precision_at_k(rank, k):

    # precision @ k
    p = rank[:k].count(True) / k

    return p

In [13]:
# 2. Recall computation
def recall_at_k(rank, num_distinct_evidence_sets, k):

    r = rank[:k].count(True) / num_distinct_evidence_sets

    return r

In [14]:
# 3. Success rate computation
def success_rate_at_k(rank, k):

    # success rate @ k
    p = int(True in rank[:k])

    return p

In [15]:
# 4. Answer in context computation
def answer_in_context_at_k(guess_item, gold_item, k):

    answers = get_gold_answers(gold_item)

    if "provenance" in guess_item["output"][0]:
        provenance = guess_item["output"][0]["provenance"]
        for i in range(0, min(k, len(provenance))):
            if "text" in provenance[i]:
                normalized_text = normalize_answer(
                    provenance[i]["text"]
                )
                for a in answers:
                    if normalize_answer(a) in normalized_text:
                        return 1
    return 0

In [16]:
# 5. Answer+entity in context computation
def answer_and_ent_in_context_at_k(guess_item, gold_item, k):

    answers = get_gold_answers(gold_item)
    titles = get_gold_titles(gold_item)

    if "provenance" in guess_item["output"][0]:
        provenance = guess_item["output"][0]["provenance"]
        for i in range(0, min(k, len(provenance))):
            if "text" in provenance[i]:
                normalized_text = normalize_answer(
                    provenance[i]["text"]
                )
                has_answer = False
                for a in answers:
                    if normalize_answer(a) in normalized_text:
                        has_answer = True
                        break
                if has_answer:
                    for t in titles:
                        if normalize_answer(t) in normalized_text:
                            return 1

    return 0

In [17]:
# 6. Entity in input
def entity_in_input(gold_item):

    input = normalize_answer(gold_item["input"])
    titles = get_gold_titles(gold_item)

    for t in titles:
        if normalize_answer(t) in input:
            return 1
    return 0

In [18]:
# 7. Entity in context
def ent_in_context_at_k(guess_item, gold_item, k):

    titles = get_gold_titles(gold_item)

    if "provenance" in guess_item["output"][0]:
        provenance = guess_item["output"][0]["provenance"]
        for i in range(0, min(k, len(provenance))):
            if "text" in provenance[i]:
                normalized_text = normalize_answer(
                    provenance[i]["text"]
                )
  
                for t in titles:
                    if normalize_answer(t) in normalized_text:
                        return 1

    return 0

In [19]:
def get_rank(guess_item, gold_item, k, rank_keys, verbose=False):
    """
    The main idea is to consider each evidence set as a single point in the rank.
    The score in the rank for an evidence set is given by the lowest scored evidence in the set.
    """

    assert k > 0, "k must be a positive integer grater than 0."

    rank = []
    num_distinct_evidence_sets = 0

    guess_ids = get_ids_list(guess_item, rank_keys)[0]

    if guess_ids and len(guess_ids) > 0:

        # 1. collect evidence sets and their sizes
        evidence_sets = []
        e_size = defaultdict(int)
        for output in gold_item["output"]:
            if "provenance" in output:
                e_set = {
                    "+".join(
                        [
                            str(provenance[rank_key]).strip()
                            for rank_key in rank_keys
                            if rank_key in provenance
                        ]
                    )
                    for provenance in output["provenance"]
                }
                if e_set not in evidence_sets:  # no duplicate evidence set
                    evidence_sets.append(e_set)
                    e_size[len(e_set)] += 1
        num_distinct_evidence_sets = len(evidence_sets)

        # 2. check what's the minimum number of predicted pages needed to get a robust P/R@k
        min_prediction_size = 0
        c = 0
        for size, freq in sorted(e_size.items(), reverse=True):
            for _ in range(freq):
                min_prediction_size += size
                c += 1
                if c == k:
                    break
            if c == k:
                break
        # if the number of evidence sets is smaller than k
        min_prediction_size += k - c

        if verbose and len(guess_ids) < min_prediction_size:
            print(
                f"WARNING: you should provide at least {min_prediction_size} provenance items for a robust recall@{k} computation (you provided {len(guess_ids)} item(s))."
            )

        # 3. rank by gruping pages in each evidence set (each evidence set count as 1),
        # the position in the rank of each evidence set is given by the last page in guess_ids
        # non evidence pages counts as 1
        rank = []
        for guess_id in guess_ids:
            guess_id = str(guess_id).strip()
            found = False
            for idx, e_set in enumerate(evidence_sets):

                e_set_id = f"evidence_set:{idx}"

                if guess_id in e_set:
                    found = True

                    # remove from the rank previous points referring to this evidence set
                    if e_set_id in rank:
                        rank.remove(e_set_id)

                    # remove the guess_id from the evidence set
                    e_set.remove(guess_id)

                    if len(e_set) == 0:
                        # it was the last evidence, it counts as true in the rank
                        rank.append(True)
                    else:
                        # add a point for this partial evidence set
                        rank.append(e_set_id)

            if not found:
                rank.append(False)

    return rank, num_distinct_evidence_sets

In [20]:
def get_ranking_metrics(guess_item, gold_item, ks, rank_keys):

    Rprec = 0.0
    P_at_k = {"precision@{}".format(k): 0 for k in sorted(ks) if k > 0}
    R_at_k = {"recall@{}".format(k): 0 for k in sorted(ks) if k > 1}
    S_at_k = {"success_rate@{}".format(k): 0 for k in sorted(ks) if k > 1}
    A_at_k = {"answer_in_context@{}".format(k): 0 for k in sorted(ks) if k > 0}
    AE_at_k = {"answer_and_ent_in_context@{}".format(k): 0 for k in sorted(ks) if k > 0}
    E_at_k = {"entity_in_context@{}".format(k): 0 for k in sorted(ks) if k > 0}

    assert (
        "output" in guess_item and len(guess_item["output"]) == 1
    ), f"guess should provide exactly one output for {guess_item['id']}"

    Rprec = rprecision(guess_item, gold_item, rank_keys=rank_keys)
    eii = entity_in_input(gold_item)
    for k in ks:

        # 0. get rank
        rank, num_distinct_evidence_sets = get_rank(
            guess_item, gold_item, k, rank_keys=rank_keys
        )

        if num_distinct_evidence_sets > 0:

            # 1. precision
            P_at_k["precision@{}".format(k)] = precision_at_k(rank, k)

            # 2. recall
            R_at_k["recall@{}".format(k)] = recall_at_k(
                rank, num_distinct_evidence_sets, k
            )

            # 3. success rate
            S_at_k["success_rate@{}".format(k)] = success_rate_at_k(rank, k)

        # 4. answer in context
        A_at_k["answer_in_context@{}".format(k)] = answer_in_context_at_k(
            guess_item, gold_item, k
        )

        AE_at_k[
            "answer_and_ent_in_context@{}".format(k)
        ] = answer_and_ent_in_context_at_k(guess_item, gold_item, k)

        E_at_k[
            "entity_in_context@{}".format(k)
        ] = ent_in_context_at_k(guess_item, gold_item, k)

    return {
        "Rprec": Rprec,
        **P_at_k,
        **R_at_k,
        **S_at_k,
        **A_at_k,
        **AE_at_k,
        **E_at_k,
        "entity_in_input": eii,
    }


In [30]:
prediction = "The quick, brown fox jumps over the lazy dog."
ground_truth = "The fox jumps over the dog."

file_path = "/hcds_vol/private/skes/RaLLe/jupyter"
config_exp_path = "../scripts/configs/experiment_settings/test_evaluation"
qa_path = "../KILT/data/nq-dev-kilt.jsonl"
guess_path = "output_reduced.jsonl"

ks = [1, 5]
rank_keys = ["wikipedia_id"]

# evaluation metrics

In [22]:
# F1 score definition
def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    print(f'prediction_tokens: {prediction_tokens}\n\
        ground_truth_tokens: {ground_truth_tokens}\n\
        common: {common}\n\
        num_same: {num_same}')

    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)

    print(f'f1: {f1}')

    return f1

In [23]:
f1_score(prediction, ground_truth)

prediction_tokens: ['quick', 'brown', 'fox', 'jumps', 'over', 'lazy', 'dog']
        ground_truth_tokens: ['fox', 'jumps', 'over', 'dog']
        common: Counter({'fox': 1, 'jumps': 1, 'over': 1, 'dog': 1})
        num_same: 4
f1: 0.7272727272727273


0.7272727272727273

In [24]:
# EM score definition
def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

In [25]:
exact_match_score(prediction, ground_truth)

False

In [26]:
# ROUGEL score definition
def rougel_score(prediction, ground_truth):
    rouge = Rouge()

    print(f'rouge: {rouge}')
    
    # no normalization
    try:
        scores = rouge.get_scores(prediction, ground_truth, avg=True)
    except ValueError:  # "Hypothesis is empty."
        return 0.0
    
    print(f'scores: {scores}')
    
    return scores["rouge-l"]["f"]

In [27]:
rougel_score(prediction, ground_truth)

rouge: <rouge.rouge.Rouge object at 0x7fd3303a0110>
scores: {'rouge-1': {'r': 1.0, 'p': 0.6666666666666666, 'f': 0.7999999952000001}, 'rouge-2': {'r': 0.6, 'p': 0.375, 'f': 0.4615384568047337}, 'rouge-l': {'r': 1.0, 'p': 0.6666666666666666, 'f': 0.7999999952000001}}


0.7999999952000001

# dismantling ralle evaluation

In [31]:
os.chdir(file_path)
cwd = os.getcwd()
cwd

'/hcds_vol/private/skes/RaLLe/jupyter'

In [32]:
os.chdir(config_exp_path)
config_paths = glob.glob("./**/*.json", recursive=True)
config_paths = sorted(config_paths)
config_paths = [os.path.join(config_exp_path, os.path.basename(c)) for c in config_paths]
os.chdir(cwd)
config_paths

['../scripts/configs/experiment_settings/test_evaluation/test_nq.json']

In [33]:
tmp_cfg_path = config_paths[0]
with open(tmp_cfg_path, "r", encoding="utf-8") as reader:
    text = reader.read()
config_exp = json.loads(text)
config_exp

{'chain_config': {'dataset': {'dataset_name': 'NQ',
   'num_evaluate': 10,
   'batch_size': 1},
  'len_chain': 2,
  'chain': [{'prompt_template': '{question}',
    'function': 'Retriever',
    'retriever_name': 'flat_subset_499992',
    'npassage': 3,
    'f-strings_or_eval': 'f-strings'},
   {'prompt_template': 'Referring to the following document, answer "{question}?" in 5 words or less.\n\n{response[0]}\n\nAnswer: ',
    'function': 'LLM',
    'llm_name': 'oasst-sft-1-pythia-12b',
    'f-strings_or_eval': 'f-strings'}]}}

In [34]:
used_llm = set()
used_ret = set()

for chain in config_exp["chain_config"]["chain"]:
    if "llm_name" in chain:
        used_llm.add(chain["llm_name"])
    if "retriever_name" in chain:
        used_ret.add(chain["retriever_name"])

used_llm = list(used_llm)
used_ret = list(used_ret)

print(f'used_llm: {used_llm}\n\
used_ret: {used_ret}')

used_llm: ['oasst-sft-1-pythia-12b']
used_ret: ['flat_subset_499992']


In [35]:
num_qa = config_exp['chain_config']['dataset']['num_evaluate']

In [36]:
pp = pprint.PrettyPrinter(indent=4)

gold_records = load_data(qa_path)[:num_qa]
guess_records = load_data(guess_path)[:num_qa]

print(f'gold_records: {len(gold_records)}\n\
      guess_records: {guess_records}')

gold_records: 10
      guess_records: [{'id': '6915606477668963399', 'output': [{'provenance': [{'wikipedia_id': 208157, 'doc_id': 379552}, {'wikipedia_id': 1216013, 'doc_id': 47596}, {'wikipedia_id': 18973788, 'doc_id': 465915}], 'answer': '\nThree dots (...) represent the number 3 in mathematics. They are often used to show repetition or to indicate something is being referred back to earlier information.'}]}, {'id': '-8366545547296627039', 'output': [{'provenance': [{'wikipedia_id': 1404364, 'doc_id': 105029}, {'wikipedia_id': 74983, 'doc_id': 269986}, {'wikipedia_id': 74983, 'doc_id': 269997}], 'answer': '\n"Who wrote the song \'Photograph\' by Ringo Starr?"\n\nis derived from Indian music and is based on a high-pitched drone created by tambura drums.'}]}, {'id': '-5004457603684974952', 'output': [{'provenance': [{'wikipedia_id': 512449, 'doc_id': 156640}, {'wikipedia_id': 1406500, 'doc_id': 220782}, {'wikipedia_id': 48030881, 'doc_id': 162284}], 'answer': '\nMaroon 5'}]}, {'id': '

In [37]:
gold_records, guess_records = validate_input(gold_records, guess_records)

id2guess_record: {'6915606477668963399': {'id': '6915606477668963399', 'output': [{'provenance': [{'wikipedia_id': 208157, 'doc_id': 379552}, {'wikipedia_id': 1216013, 'doc_id': 47596}, {'wikipedia_id': 18973788, 'doc_id': 465915}], 'answer': '\nThree dots (...) represent the number 3 in mathematics. They are often used to show repetition or to indicate something is being referred back to earlier information.'}]}, '-8366545547296627039': {'id': '-8366545547296627039', 'output': [{'provenance': [{'wikipedia_id': 1404364, 'doc_id': 105029}, {'wikipedia_id': 74983, 'doc_id': 269986}, {'wikipedia_id': 74983, 'doc_id': 269997}], 'answer': '\n"Who wrote the song \'Photograph\' by Ringo Starr?"\n\nis derived from Indian music and is based on a high-pitched drone created by tambura drums.'}]}, '-5004457603684974952': {'id': '-5004457603684974952', 'output': [{'provenance': [{'wikipedia_id': 512449, 'doc_id': 156640}, {'wikipedia_id': 1406500, 'doc_id': 220782}, {'wikipedia_id': 48030881, 'doc_

## Calculate matrics

In [38]:
# downstream matrics
accuracy = 0
normalized_em = 0
normalized_f1 = 0
rougel = 0

# kilt matrics
kilt_accuracy = 0
kilt_em = 0
kilt_f1 = 0
kilt_rougel = 0

In [39]:
total_count = 0

for guess_item, gold_item in zip(guess_records, gold_records):

    print(f'guess_item: {guess_item}\n\
            gold_item: {gold_item}')
    
    # check ids
    assert(
        str(guess_item["id"]).strip() == str(gold_item["id"]).strip()
    ), "IDs should match"

    total_count += 1

    # check if each output of guess file exist in set of candidate answers
    gold_candidate_answers = get_gold_answers(gold_item)

    conditions = (len(guess_item["output"]) == 1) and (
        "answer" in guess_item["output"][0]
    )

    assert(
        conditions
    ), f"you should provide exactly one valid answer for {guess_item['id']}"

    guess_answer = str(guess_item["output"][0]["answer"]).strip()

    # empty answer
    if len(guess_answer) == 0:
        continue

    # 0. accuracy = strict exact match
    local_accuracy = 0
    if guess_answer in gold_candidate_answers:
        local_accuracy = 1
    accuracy += local_accuracy

    # 1. normalized exact match
    local_em = metric_max_over_ground_truths(
        exact_match_score, guess_answer, gold_candidate_answers
    )
    normalized_em += local_em

    # 2. normalized f1
    local_f1 = metric_max_over_ground_truths(
        f1_score, guess_answer, gold_candidate_answers
    )
    normalized_f1 += local_f1

    # 3. ROUGEL
    local_rougel = metric_max_over_ground_truths(
        rougel_score, guess_answer, gold_candidate_answers
    )
    rougel += local_rougel

    # KILT metrics
    Rprec = rprecision(
        guess_item, gold_item, rank_keys=["wikipedia_id"]
    )

    if Rprec == 1:
        # 1. KILT_AC
        kilt_accuracy += local_accuracy

        # 2. KILT_EM
        kilt_em += local_em

        # 3. KILT_F1
        kilt_f1 += local_f1

        # 4. KILT_RL
        kilt_rougel += local_rougel

guess_item: {'id': '6915606477668963399', 'output': [{'provenance': [{'wikipedia_id': 208157, 'doc_id': 379552}, {'wikipedia_id': 1216013, 'doc_id': 47596}, {'wikipedia_id': 18973788, 'doc_id': 465915}], 'answer': '\nThree dots (...) represent the number 3 in mathematics. They are often used to show repetition or to indicate something is being referred back to earlier information.'}]}
            gold_item: {'id': '6915606477668963399', 'input': 'what do the 3 dots mean in math', 'output': [{'answer': 'the therefore sign', 'provenance': [{'wikipedia_id': '10593264', 'title': 'Therefore sign', 'start_paragraph_id': 1, 'start_character': 44, 'end_paragraph_id': 1, 'end_character': 62, 'bleu_score': 1.0, 'section': 'Section::::Abstract.'}]}, {'answer': 'therefore sign', 'provenance': [{'wikipedia_id': '10593264', 'title': 'Therefore sign', 'start_paragraph_id': 1, 'start_character': 48, 'end_paragraph_id': 1, 'end_character': 62, 'bleu_score': 1.0, 'section': 'Section::::Abstract.'}]}, {'

In [40]:
if total_count > 0:
    accuracy /= total_count
    normalized_em /= total_count
    normalized_f1 /= total_count
    rougel /= total_count

    kilt_accuracy /= total_count
    kilt_em /= total_count
    kilt_f1 /= total_count
    kilt_rougel /= total_count

print(f'\
      KILT metrics:\n\
      KILT_AC: {kilt_accuracy}\n\
      KILT_EM: {kilt_em}\n\
      KILT_F1: {kilt_f1}\n\
      KILT_RL: {kilt_rougel}\n\
      downstream metrics:\n\
      accuracy: {accuracy}\n\
      normalized_em: {normalized_em}\n\
      normalized_f1: {normalized_f1}\n\
      rougel: {rougel}')

      KILT metrics:
      KILT_AC: 0.0
      KILT_EM: 0.0
      KILT_F1: 0.0
      KILT_RL: 0.0
      downstream metrics:
      accuracy: 0.0
      normalized_em: 0.0
      normalized_f1: 0.11335693935693936
      rougel: 0.09813491843585166


## Compute retrieval performance

In [41]:
ks = sorted([int(x) for x in ks])

result = OrderedDict()
result["Rprec"] = 0.0
result["entity_in_input"] = 0.0

for k in ks:
    if k>0:
        result["precision@{}".format(k)] = 0.0
        result["answer_in_context@{}".format(k)] = 0.0
        result["answer_and_entity_in_context@{}".format(k)] = 0.0
        result["entity_in_context@{}".format(k)] = 0.0
    if k>1:
        result["recall@{}".format(k)] = 0.0
        result["success_rate@{}".format(k)] = 0.0

In [42]:
assert len(guess_records) == len(gold_records),\
    "different size gold: {} guess: {}".format(len(guess_records), len(gold_records))

for guess_item, gold_item in zip(guess_records, gold_records):
    assert str(guess_item["id"]).strip() == str(gold_item["id"]).strip(), "IDs should match"

    ranking_metrics = get_ranking_metrics(guess_item, gold_item, ks, rank_keys)
    result["Rprec"] += ranking_metrics["Rprec"]
    result["entity_in_input"] += ranking_metrics["entity_in_input"]

    for k in ks:
        if k>0:
            result["precision@{}".format(k)] += ranking_metrics["precision@{}".format(k)]
            result["answer_in_context@{}".format(k)] += ranking_metrics["answer_in_context@{}".format(k)]
            result["answer_and_entity_in_context@{}".format(k)] += ranking_metrics["answer_and_ent_in_context@{}".format(k)]
            result["entity_in_context@{}".format(k)] += ranking_metrics["entity_in_context@{}".format(k)]
        if k>1:
            result["recall@{}".format(k)] += ranking_metrics["recall@{}".format(k)]
            result["success_rate@{}".format(k)] += ranking_metrics["success_rate@{}".format(k)]

gold_ids_list: [['10593264'], ['10593264'], [], [], ['10593264'], ['10593264']]
          guess_ids: ['208157', '18973788', '1216013']
          Rprec_vector: [0.0, 0.0, 0, 0, 0.0, 0.0]
gold: {'id': '6915606477668963399', 'input': 'what do the 3 dots mean in math', 'output': [{'answer': 'the therefore sign', 'provenance': [{'wikipedia_id': '10593264', 'title': 'Therefore sign', 'start_paragraph_id': 1, 'start_character': 44, 'end_paragraph_id': 1, 'end_character': 62, 'bleu_score': 1.0, 'section': 'Section::::Abstract.'}]}, {'answer': 'therefore sign', 'provenance': [{'wikipedia_id': '10593264', 'title': 'Therefore sign', 'start_paragraph_id': 1, 'start_character': 48, 'end_paragraph_id': 1, 'end_character': 62, 'bleu_score': 1.0, 'section': 'Section::::Abstract.'}]}, {'answer': 'a logical consequence , such as the conclusion of a syllogism'}, {'answer': 'the therefore sign ( ∴ ) is generally used before a logical consequence , such as the conclusion of a syllogism'}, {'provenance': [{

In [43]:
if len(guess_records) > 0:
    result["Rprec"] /= len(guess_records)
    result["entity_in_input"] /= len(guess_records)

    for k in ks:
        if k>0:
            result["precision@{}".format(k)] /= len(guess_records)
            result["answer_in_context@{}".format(k)] /= len(guess_records)
            result["answer_and_entity_in_context@{}".format(k)] /= len(guess_records)
            result["entity_in_context@{}".format(k)] /= len(guess_records)
        if k>1:
            result["recall@{}".format(k)] /= len(guess_records)
            result["success_rate@{}".format(k)] /= len(guess_records)

print(f'result: {result}')

result: OrderedDict([('Rprec', 0.0), ('entity_in_input', 0.5), ('precision@1', 0.0), ('answer_in_context@1', 0.0), ('answer_and_entity_in_context@1', 0.0), ('entity_in_context@1', 0.0), ('precision@5', 0.0), ('answer_in_context@5', 0.0), ('answer_and_entity_in_context@5', 0.0), ('entity_in_context@5', 0.0), ('recall@5', 0.0), ('success_rate@5', 0.0)])
