In [1]:
# Filter the NQ dev data, keeping only the instances for which the model gives the correct original answer

import sys
import json
import jsonlines
import string
import re
from pathlib import Path

In [4]:
# Filter the NQ test data provided in the NQ repository for Llama2 experiments.
# Only the test data shipped in the git repo is filtered here: we keep the subset
# of data points for which the model generates the correct original answer.

# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
proj_path = Path("/home/xiaotang/Project/context-faithful-llm/")
data_path = proj_path / "datasets"
orig_data_path = data_path / "nq" / "orig_dev_filtered.json"
sub_data_path = data_path / "nq" / "conflict_dev_filtered.json"

# Read the JSON files as UTF-8 explicitly so decoding does not depend on the locale.
with open(orig_data_path, "r", encoding="utf-8") as fin:
    orig_examples = json.load(fin)

with open(sub_data_path, "r", encoding="utf-8") as fin:
    counter_examples = json.load(fin)

print("Number of original examples: ", len(orig_examples))
print("Number of counterfactual examples: ", len(counter_examples))
# The two files are paired by position: the filtering cell indexes both with the same idx.
assert len(orig_examples) == len(counter_examples), (
    "original and counterfactual example files differ in length"
)

# Load predictions of Llama2-chat models
pred_path = proj_path / "results" / "llama2-7b-chat-no-demo-closed-book_preds.json"
with open(pred_path, "r", encoding="utf-8") as fin:
    predictions = json.load(fin)

# Predictions are matched to examples by position as well.
assert len(predictions) == len(orig_examples), (
    "prediction file and example file differ in length"
)

Number of original examples:  2773
Number of counterfactual examples:  2773


In [13]:
# Walk over the prediction file and collect the positions (sample indices) of the
# examples whose model output contains the original answer.

filtered_data_indices = []
for idx, sample in enumerate(predictions):
    golds = sample['answer']
    pred = sample['prediction']

    # Treat a single gold answer as a one-element list so both cases share one path.
    candidates = golds if isinstance(golds, list) else [golds]
    # Every candidate is scored; one match is enough to keep the sample.
    if any([recall_score(pred, g) for g in candidates]):
        filtered_data_indices.append(idx)

kept = len(filtered_data_indices)
print(kept)
print(kept / len(predictions))

1610
0.5805986296429859


In [15]:
# Filter the original and counterfactual samples for Llama2 experiments

# Both lists are built from the same indices so original/counterfactual pairs stay aligned.
filtered_orig_samples = [orig_examples[idx] for idx in filtered_data_indices]
filtered_counter_samples = [counter_examples[idx] for idx in filtered_data_indices]

print(len(filtered_orig_samples))
print(len(filtered_counter_samples))

# Write filtered NQ test data (for Llama-2 experiments) to file

output_path = data_path / "nq_llama2"
# open(..., 'w') does not create missing parent directories, so make sure it exists.
output_path.mkdir(parents=True, exist_ok=True)

with open(output_path / "orig_dev_filtered.json", 'w') as fout:
    json.dump(filtered_orig_samples, fout, indent=4)

with open(output_path / "conflict_dev_filtered.json", 'w') as fout:
    json.dump(filtered_counter_samples, fout, indent=4)



1610
1610


In [5]:
# Load NQ dev data
# NOTE(review): absolute, machine-specific path. This cell also rebinds proj_path and
# data_path, which the context-faithful-llm cells assign to a different project.
proj_path = Path('/home/xiaotang/Project/entity_perturb')
data_path = proj_path / "data"
orig_data_path = data_path / "MRQANaturalQuestionsDev_orig.jsonl"
orig_answers = get_gold_answers(orig_data_path)

# Load the predictions of Llama2-chat models
pred_path = proj_path / "results" / "llama2-7b-chat-nq-origin-greedy-no_special_token_preds.json"
# Read the JSON file as UTF-8 explicitly so decoding does not depend on the locale.
with open(pred_path, 'r', encoding="utf-8") as fin:
    predictions = json.load(fin)

pred_answers = [pred['prediction'] for pred in predictions]

# Predictions and gold answers are matched by position in the cells that follow.
assert len(pred_answers) == len(orig_answers), (
    "number of predictions differs from number of gold-answer entries"
)

In [6]:
# Score the predictions: exact match and recall (get_score scales both by 100)

em, recall = get_score(pred_answers, orig_answers)
for label, value in (("EM", em), ("recall", recall)):
    print(label, value)

EM 0.48461862519043697
recall 39.633375390574436


In [29]:
# Keep only the test instances for which at least one gold answer is recalled

filtered_data = []
for sample, gold in zip(predictions, orig_answers):
    pred = sample['prediction']
    # Treat a single gold answer as a one-element list so both cases share one path.
    candidates = gold if isinstance(gold, list) else [gold]
    # Every candidate is scored; one match is enough to keep the sample.
    if any([recall_score(pred, g) for g in candidates]):
        filtered_data.append(sample)

kept = len(filtered_data)
print(kept)
print(kept / len(orig_answers))

1881
0.3963337547408344


In [32]:
# Persist the filtered NQ dev data, one JSON object per line.
# Current setting: without special tokens + greedy decoding

output_path = data_path / "filterd_MRQANaturalQuestionsDev_orig.jsonl"
with jsonlines.open(output_path, mode='w') as writer:
    for record in filtered_data:
        writer.write(record)

In [11]:
# Convert the filtered test data and its perturbed version to the context-faithful-llm format

orig_data_path = data_path / "filterd_MRQANaturalQuestionsDev_orig.jsonl"
conflict_data_path = data_path / "conflict_MRQANaturalQuestionsDev-filtered.jsonl"

# Only the conflict file is read here; orig_data_path is assigned but not used in this cell.
src_data_path = conflict_data_path

converted_data = []
with jsonlines.open(src_data_path) as reader:
    for record in reader:
        # Lines without a "uid" key are skipped.
        if "uid" not in record.keys():
            continue
        converted_data.append({
            'question': record['query'],
            'context': record['context'],
            'answer': [ans['text'] for ans in record['gold_answers']],
        })

print(len(converted_data))

# Save the converted version of data to JSON
with open(data_path / "llama2_conflict_dev_filtered.json", 'w') as fout:
    json.dump(converted_data, fout, indent=4)

1881


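### Utility Functions

Run the cell below before the filtering and evaluation cells above: they call `recall_score`, `get_gold_answers` and `get_score`, which are only defined here.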

In [6]:
def normalize_answer(s):
    """Canonicalise an answer string for lenient comparison.

    The steps run in this order: lowercase the text, drop every character found
    in ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse each whitespace run into a single space (which
    also trims both ends).
    """
    lowered = s.lower()
    punctuation = set(string.punctuation)
    without_punct = ''.join(filter(lambda ch: ch not in punctuation, lowered))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are identical after ``normalize_answer``."""
    normalized_pred = normalize_answer(prediction)
    normalized_gold = normalize_answer(ground_truth)
    return normalized_pred == normalized_gold

def recall_score(prediction, ground_truth):
    """Return True when the normalized gold answer occurs inside the normalized prediction.

    Both strings go through ``normalize_answer`` first and the check is a plain
    substring search, so a gold answer that normalizes to the empty string
    matches every prediction.
    """
    haystack = normalize_answer(prediction)
    needle = normalize_answer(ground_truth)
    return haystack.find(needle) != -1

def get_gold_answers(data_path):
    """Collect the gold answer texts from a JSON-lines file.

    Every record that has a "uid" key contributes one list holding the "text"
    field of each entry in its "gold_answers"; records without "uid" are
    skipped.  The result is a list of those lists, in file order.
    """
    with jsonlines.open(data_path) as reader:
        return [
            [ans['text'] for ans in obj['gold_answers']]
            for obj in reader
            if "uid" in obj.keys()
        ]

def get_score(preds, golds):
    """Compute exact-match and recall percentages of predictions against gold answers.

    Args:
        preds: sequence of predicted answer strings.
        golds: sequence aligned with ``preds``; each item is either a single
            gold string or a list of acceptable gold strings.  With a list, a
            prediction scores when it matches any entry.

    Returns:
        ``(em, recall)`` as floats on a 0-100 scale; ``(0.0, 0.0)`` when
        ``preds`` is empty.

    Pairs are formed with ``zip``, so surplus items in the longer sequence are
    not scored, while the denominator is always ``len(preds)``.
    """
    total = len(preds)
    # Guard the empty case explicitly instead of padding the denominator with
    # an epsilon, which would slightly bias every non-empty result.
    if total == 0:
        return 0.0, 0.0

    em_hits, recall_hits = 0, 0
    for pred, gold in zip(preds, golds):
        # Treat a single gold answer as a one-element list so both cases share one path.
        candidates = gold if isinstance(gold, list) else [gold]
        em_hits += any(exact_match_score(pred, g) for g in candidates)
        recall_hits += any(recall_score(pred, g) for g in candidates)

    return em_hits * 100 / total, recall_hits * 100 / total