In [2]:
import pandas as pd
from datasets import load_dataset

In [34]:
# Input file; it is read further down with pd.read_json(..., lines=True),
# i.e. one JSON record per line.
SOURCE = '../data/natural-questions/nq-train-kilt.jsonl.txt'
# Output file; the reshaped records are written here as JSON Lines.
OUTPUT = '../data/natural-questions/nq-train-kilt_fixed.jsonl'

In [38]:
def get_references(output):
    """Build paragraph references from the first entry of an ``output`` list.

    Only ``output[0]`` is inspected; any further entries are ignored.

    Args:
        output: list of dicts. The first dict may carry a ``'provenance'``
            list whose items expose ``'wikipedia_id'`` and
            ``'start_paragraph_id'``.

    Returns:
        A list of reference dicts with the keys ``'document_id'``,
        ``'index'``, ``'global_id'`` and ``'text'`` (the last two are always
        ``None`` here), or ``None`` when there is nothing to reference:
        ``output`` is empty, or its first entry has no ``'provenance'`` key
        or an empty provenance list.
    """
    if len(output) == 0:
        return None

    # A missing key and an empty list are treated alike, so the caller's
    # isna() filter drops both kinds of rows instead of keeping rows whose
    # reference list is empty.
    provenance = output[0].get('provenance')
    if not provenance:
        return None

    return [
        {
            'document_id': entry['wikipedia_id'],
            # NOTE(review): the +1 presumably turns a 0-based paragraph id
            # into a 1-based index - confirm against the consumer.
            'index': entry['start_paragraph_id'] + 1,
            'global_id': None,
            'text': None,
        }
        for entry in provenance
    ]


# Load the raw export: one JSON record per line.
df = pd.read_json(SOURCE, lines=True)

# Store ids as strings in the exported records.
df['id'] = df['id'].astype(str)

# Take the answer from the first output entry. Empty output lists and entries
# without an 'answer' key yield None instead of raising, mirroring the guards
# get_references applies to the same column.
df['answer'] = df['output'].apply(
    lambda outputs: outputs[0].get('answer') if len(outputs) > 0 else None
)
df['references'] = df['output'].apply(get_references)

# Keep only rows for which references could be built.
with_refs = df.loc[df['references'].notna()]

# Export a small sample. The fixed random_state makes the written file
# reproducible across runs, and the size is capped so that fewer than 10
# qualifying rows does not raise.
with_refs.sample(n=min(10, len(with_refs)), random_state=42).to_json(
    OUTPUT, lines=True, orient='records'
)

In [34]:
def get_references(row):
    """Reduce one dataset row to its id, first answer and paragraph references.

    NOTE: this definition reuses the name of the list-based helper from the
    earlier pandas cell and replaces it once this cell has run.

    Args:
        row: mapping with an ``'id'`` and an ``'output'`` list. Only the first
            output entry is read; it must provide ``'provenance'`` and
            ``'answer'``.

    Returns:
        dict with the keys ``'id'``, ``'answer'`` and ``'references'``. Each
        reference holds ``'document_id'``, ``'index'``, ``'global_id'`` and
        ``'text'``; the last two are always ``None`` here.
    """
    first_output = row['output'][0]

    paragraph_refs = [
        {
            'document_id': prov['wikipedia_id'],
            # NOTE(review): presumably shifts a 0-based paragraph id to a
            # 1-based index - confirm.
            'index': prov['start_paragraph_id'] + 1,
            'global_id': None,
            'text': None,
        }
        for prov in first_output['provenance']
    ]

    return {
        'id': row['id'],
        'answer': first_output['answer'],
        'references': paragraph_refs,
    }


In [29]:
# Load the "nq" configuration of the KILT tasks dataset. The bare name on the
# last line makes the notebook display the available splits.
kilt_nq = load_dataset(path="facebook/kilt_tasks", name="nq")
kilt_nq

DatasetDict({
    train: Dataset({
        features: ['id', 'input', 'meta', 'output'],
        num_rows: 87372
    })
    validation: Dataset({
        features: ['id', 'input', 'meta', 'output'],
        num_rows: 2837
    })
    test: Dataset({
        features: ['id', 'input', 'meta', 'output'],
        num_rows: 1444
    })
})

In [35]:
# Keep only training examples whose first output lists at least one
# provenance entry.
train_clean = kilt_nq['train'].filter(
    lambda example: len(example['output'][0]['provenance']) > 0
)

# The whole filtered split is processed; `subset` is just an alias for it.
subset = train_clean

# map() merges the dict returned by get_references into each example, which
# adds the 'answer' and 'references' columns.
relevant = subset.map(get_references)
relevant


Map: 100%|██████████| 74038/74038 [00:06<00:00, 10645.83 examples/s]


Dataset({
    features: ['id', 'input', 'meta', 'output', 'answer', 'references'],
    num_rows: 74038
})

In [33]:
# Show the columns of interest as a DataFrame for inspection.
(
    relevant
    .select_columns(['id', 'input', 'answer', 'references'])
    .to_pandas()
)

Unnamed: 0,id,input,answer,references
0,5328212470870865242,how i.met your mother who is the mother,Tracy McConnell,"[{'document_id': '40262098', 'global_id': None..."
1,5289242154789678439,who had the most wins in the nfl,Tom Brady,"[{'document_id': '13929036', 'global_id': None..."
2,-2500044561429484630,who played mantis guardians of the galaxy 2,Pom Klementieff,"[{'document_id': '43209054', 'global_id': None..."
3,-7491001389340565191,god's not dead a light in the darkness release...,"March 30 , 2018","[{'document_id': '56095752', 'global_id': None..."
4,4706363175863681196,when do the eclipse supposed to take place,"August 21 , 2017","[{'document_id': '4762233', 'global_id': None,..."
...,...,...,...,...
74033,5088894102963373562,who plays tupac mother in all eyez on me,Danai Gurira,"[{'document_id': '48919161', 'global_id': None..."
74034,2872167102791875297,who has scored more goals messi or ronaldo,Cristiano Ronaldo,"[{'document_id': '43992506', 'global_id': None..."
74035,-6443678100256960002,who made the world's first computer virus,Basit Farooq Alvi,"[{'document_id': '175852', 'global_id': None, ..."
74036,7553133436598981958,when does season 5 of senora acero come out,15 October 2018,"[{'document_id': '58414668', 'global_id': None..."


In [None]:

def build_fever_references(target, subset_size=None, extract=None):
    """Export FEVER training examples with their references and list their documents.

    This logic was pasted from a method body (it used ``self`` and a top-level
    ``return``), so the cell could not run. The values it read from
    ``self.cfg.documents`` and ``self.extract_relevant_references`` are now
    explicit parameters.

    Args:
        target: path handed to ``Dataset.to_json`` for the exported examples.
        subset_size: when given, only the first ``subset_size`` filtered
            examples are processed (clamped to the number available); ``None``
            processes all of them.
        extract: per-example function passed to ``Dataset.map``; it must add
            the ``'answer'`` and ``'references'`` columns. Defaults to this
            notebook's row-based ``get_references``.
            NOTE(review): assumed to match the original
            ``extract_relevant_references`` method - confirm.

    Returns:
        list with the ``'document_id'`` of the first reference of every
        exported example.
    """
    allowed = ('SUPPORTS', 'REFUTES')

    def _is_usable(row):
        # Keep examples whose first output has an allowed label and at least
        # one provenance entry; rows without any output are dropped.
        outputs = row['output']
        if len(outputs) == 0:
            return False
        first = outputs[0]
        return first['answer'] in allowed and len(first['provenance']) > 0

    # Same namespaced dataset id as the "nq" cell.
    kilt_fever = load_dataset("facebook/kilt_tasks", name="fever")
    train_clean = kilt_fever['train'].filter(_is_usable)

    if subset_size is None:
        subset = train_clean
    else:
        # Clamp so that a size larger than the filtered split does not raise.
        subset = train_clean.select(range(min(subset_size, len(train_clean))))

    if extract is None:
        # Resolved at call time so the notebook's current definition is used.
        extract = get_references
    relevant = subset.map(extract)

    relevant.select_columns(['id', 'input', 'answer', 'references']).to_json(target)

    # The filter guarantees at least one provenance entry per example, so
    # refs[0] exists as long as `extract` emits one reference per entry.
    return [refs[0]['document_id'] for refs in relevant['references']]