In [1]:
import random
# Seed Python's `random` module so any sampling done with it is reproducible.
# NOTE(review): none of the cells visible here call `random`; confirm the seed is still needed.
random.seed(10)

In [2]:
def compute_presence_accuracy(ret_out):
    """Return the fraction of records whose stripped reference appears in the output.

    Parameters
    ----------
    ret_out : sequence of dict
        Records with string fields 'ref' and 'output'.

    Returns
    -------
    float
        Share of records where ``d['ref'].strip()`` is a substring of
        ``d['output']``. NaN for an empty sequence, since the accuracy of
        zero examples is undefined (previously this raised ZeroDivisionError).
    """
    if len(ret_out) == 0:
        return float('nan')
    hits = sum(1 for d in ret_out if d['ref'].strip() in d['output'])
    return hits / len(ret_out)

# Compute BLEU score
from nltk.translate.bleu_score import sentence_bleu
def str2list(s):
    """Split ``s`` on runs of whitespace and return the words, with no empty strings."""
    words = s.split(None)
    return words

def computeBLEU(reference, candidate, max_n=4):
    """Compute sentence-level BLEU between a reference and a candidate string.

    Both strings are split into whitespace-separated words. The score uses
    uniform weights over n-gram orders 1..n, where
    n = min(number of reference words, max_n), so that short references are
    not scored on n-gram orders they cannot contain.

    Parameters
    ----------
    reference : str
        Reference (ground-truth) text.
    candidate : str
        Generated text to score.
    max_n : int, optional
        Highest n-gram order to use (default 4, the original hard-coded value).

    Returns
    -------
    float
        The score from nltk's ``sentence_bleu``. Returns 0.0 when no n-gram
        order can be formed (empty or whitespace-only reference, or
        ``max_n <= 0``); previously this case raised ZeroDivisionError.
    """
    candidate_words = str2list(candidate)
    reference_words = str2list(reference)
    n = min(len(reference_words), max_n)
    if n <= 0:
        # Uniform weights 1/n are undefined for n == 0.
        return 0.0
    return sentence_bleu([reference_words], candidate_words, weights=[1. / n] * n)

def compute_BLEU(ret_out):
    """Return the mean ``computeBLEU(d['ref'], d['output'])`` over all records.

    Parameters
    ----------
    ret_out : sequence of dict
        Records with string fields 'ref' and 'output'.

    Returns
    -------
    float
        Average BLEU score. NaN for an empty sequence, since the mean of zero
        scores is undefined (previously this raised ZeroDivisionError).
    """
    if len(ret_out) == 0:
        return float('nan')
    scores = [computeBLEU(d['ref'], d['output']) for d in ret_out]
    return sum(scores) / len(scores)

# Find the first number in the string
import re
def find_first_num(s):
    """Return the first maximal run of digits in ``s`` as a string, or None if there is none."""
    match = re.search(r'\d+', s)
    return None if match is None else match.group()

In [3]:
import json
def _attach_tok_len(ret_out, inp_data):
    """Copy each input example's token length onto the output record at the same position.

    Mutates ``ret_out`` in place, adding a 'tok_len' key to every record. A record
    whose 'input' differs from the matching input example's first conversation
    turn is reported (with its index) but kept.

    Raises
    ------
    ValueError
        If ``inp_data`` has fewer examples than ``ret_out``, so the two files
        cannot be aligned by position.
    """
    if len(inp_data) < len(ret_out):
        raise ValueError(
            f"INP_FILE has {len(inp_data)} examples but FILE has {len(ret_out)}; "
            "cannot align them by position"
        )
    for idx, record in enumerate(ret_out):
        first_turn = inp_data[idx]['conversations'][0]
        record['tok_len'] = first_turn['tok_len']
        if record['input'] != first_turn['value']:
            print(f"Error: input mismatch at index {idx}")


def _print_grouped_metrics(df, column):
    """Print mean exact-match accuracy and mean BLEU score for each value of ``column``."""
    grouped = df.groupby(column)
    print(f"Average exact match accuracy for each {column}: ")
    print(grouped['exact match'].mean())
    print(f"Average BLEU score for each {column}: ")
    print(grouped['BLEU score'].mean())


def get_metrics(FILE, INP_FILE=None, tok_limit=None):
    """Print exact-match and BLEU metrics for a WikiQA inference output file.

    Parameters
    ----------
    FILE : str
        Path to the JSON file generated by run_inference_WikiQA.py. Each record
        needs 'ref' and 'output' keys (plus 'input' when ``tok_limit`` is set);
        only ``output[0]`` is scored.
    INP_FILE : str, optional
        Path to the WikiQA input JSON file, aligned with FILE by position.
        Read (and required) only when ``tok_limit`` is set.
    tok_limit : int, optional
        If set, only records whose input token length is <= tok_limit are scored.

    Question/answer location labels are derived from each record's position in
    FILE *before* any filtering: question location alternates end/start, and
    answer location cycles start, start, random, random, end, end. (Previously
    the labels were computed after filtering, which shifted them for every row
    that followed a dropped one.)

    Raises
    ------
    ValueError
        If ``tok_limit`` is set without ``INP_FILE``, INP_FILE has fewer examples
        than FILE, or no records remain to score.
    """
    import pandas as pd

    with open(FILE) as f:
        ret_out = json.load(f)

    # Original position of every record; the location labels depend on it.
    positions = list(range(len(ret_out)))

    if tok_limit is not None:
        if INP_FILE is None:
            raise ValueError("INP_FILE is required when tok_limit is set")
        with open(INP_FILE) as f:
            inp_data = json.load(f)

        print(len(ret_out), len(inp_data))
        print("Old len of data: " + str(len(ret_out)))
        _attach_tok_len(ret_out, inp_data)
        positions = [i for i, d in enumerate(ret_out) if d['tok_len'] <= tok_limit]
        ret_out = [ret_out[i] for i in positions]
        if ret_out:
            tok_len_stat = [d['tok_len'] for d in ret_out]
            print(sum(tok_len_stat) / len(tok_len_stat), max(tok_len_stat))
        print("New len of data: " + str(len(ret_out)))

    if not ret_out:
        raise ValueError(f"No records to score in {FILE}")

    # Keep only the first generated output of each record.
    for d in ret_out:
        d['output'] = d['output'][0]

    # Index rows by their original position so dropped rows leave visible gaps.
    df = pd.DataFrame(ret_out, index=positions)

    # Even positions -> question at the "end", odd positions -> at the "start".
    df['question location'] = ["start" if p % 2 == 1 else "end" for p in positions]
    # Within every group of 6 positions: 2x "start", 2x "random", 2x "end".
    answer_locations = ("start", "start", "random", "random", "end", "end")
    df['answer location'] = [answer_locations[p % 6] for p in positions]

    refs_and_outputs = list(zip(df['ref'], df['output']))
    df['exact match'] = [int(ref.strip() in out) for ref, out in refs_and_outputs]
    df['BLEU score'] = [computeBLEU(ref, out) for ref, out in refs_and_outputs]

    print("Average exact match accuracy: ", df['exact match'].mean())
    print("Average BLEU score: ", df['BLEU score'].mean())

    _print_grouped_metrics(df, 'answer location')
    _print_grouped_metrics(df, 'question location')

### Compute metrics on the 4k FFQA data (all examples)

In [11]:
FILE = None  # Add path to the generated file here
# Fail with an actionable message instead of an opaque TypeError from open(None).
if FILE is None:
    raise ValueError("Set FILE to the JSON file generated by run_inference_WikiQA.py")
get_metrics(FILE)

### Compute metrics on the same FFQA data, keeping only inputs short enough that input plus generated text fits within 4k tokens

In [12]:
# Some examples in the data may exceed 4k tokens (the mean is 4k; see the data statistics).
FILE = None  # Add path to the generated file here
INP_FILE = '../../../datasets/WikiQA/Free_Form_QA/ffqa_4k.json'
TOK_LIMIT = 4096  # token budget for input + generated text
GEN_SLACK = 256   # tokens reserved for the generated text

# Fail with an actionable message instead of an opaque TypeError from open(None).
if FILE is None:
    raise ValueError("Set FILE to the JSON file generated by run_inference_WikiQA.py")

# Limit the input to TOK_LIMIT - GEN_SLACK tokens so input + generated text fits in 4k.
get_metrics(FILE, INP_FILE, tok_limit=TOK_LIMIT - GEN_SLACK)