Extract a DistilBERT answer for each (question, passage) pair.

In [25]:
from transformers import pipeline
from datasets import Dataset
from transformers.pipelines.pt_utils import KeyPairDataset
from tqdm.auto import tqdm
import json

In [26]:
def read_rows(path):
    """Read a JSON Lines file into a list of decoded values.

    Each non-blank line of *path* is parsed with ``json.loads``; blank lines
    (e.g. a stray empty line at the end of the file) are skipped instead of
    raising ``json.JSONDecodeError``.

    Args:
        path: Path to a UTF-8 encoded JSON Lines file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.
    """
    # Explicit UTF-8: the rows contain non-ASCII text and are written with
    # ensure_ascii=False, while the platform default encoding may differ.
    # The `with` block guarantees the handle is closed.
    with open(path, encoding='utf-8') as f_in:
        return [json.loads(line) for line in f_in if line.strip()]

In [27]:
def write_json_format(path_out, rows):
    """Write *rows* to *path_out* in JSON Lines format (one value per line).

    Non-ASCII characters are written as-is (``ensure_ascii=False``), so the
    file is explicitly encoded as UTF-8. Output first goes to a sibling
    ``<path_out>.tmp`` file, which is then moved over *path_out* with
    ``os.replace``. An error part-way through, such as a value ``json.dumps``
    cannot serialize, therefore leaves any existing file at *path_out*
    untouched. This matters when the destination is the input file itself.

    Args:
        path_out: Destination path; replaced if it already exists.
        rows: Iterable of JSON-serializable values.

    Raises:
        TypeError: If a row is not JSON-serializable (nothing is replaced).
        OSError: If the file cannot be written or moved into place.
    """
    import os

    path_tmp = os.fspath(path_out) + '.tmp'
    try:
        with open(path_tmp, 'w', encoding='utf-8') as f_out:
            for row in rows:
                f_out.write(json.dumps(row, ensure_ascii=False) + '\n')
        # Swap the complete file into place only after every row was written.
        os.replace(path_tmp, path_out)
    except BaseException:
        # Don't leave a half-written temp file behind.
        if os.path.exists(path_tmp):
            os.remove(path_tmp)
        raise

In [28]:
# Question-answering model checkpoint and the JSON Lines dataset to annotate.
# NOTE: the annotated rows are later written back to this same dataset path.
model_name, path_dataset = (
    'distilbert-base-cased-distilled-squad',
    'test-B-big/allegro.jl',
)

In [29]:
# Load every JSON line of the dataset as one row.
rows = read_rows(path=path_dataset)

In [30]:
# Model inputs are the translated passage/question fields, kept in row order
# so the predicted answers can later be paired back with `rows` by position.
contexts = [entry['passage_translated'] for entry in rows]
questions = [entry['question_translated'] for entry in rows]

In [31]:
# Extractive question-answering pipeline; device=0 selects the first CUDA device.
pipe = pipeline(task="question-answering", model=model_name, device=0)

In [32]:
# One record per row, with the "question"/"context" keys the QA pipeline reads.
dataset = Dataset.from_dict(dict(question=questions, context=contexts))

In [33]:
# Run the pipeline in batches of 16 and keep only the predicted answer text,
# one entry per dataset record, in the same order.
answers = [
    prediction['answer']
    for prediction in pipe(dataset, batch_size=16)
]

In [34]:
# Attach each predicted answer to the row it was computed from (positional
# pairing). zip() would silently stop at the shorter sequence, leaving some
# rows unannotated before the dataset is overwritten, so check lengths first.
if len(answers) != len(rows):
    raise ValueError(
        f'Got {len(answers)} answers for {len(rows)} rows; refusing to annotate.'
    )
for row, answer in zip(rows, answers):
    # Key spelling ("distillbert") is kept as-is; other consumers of the file
    # may rely on it.
    row["distillbert_answer"] = answer

In [35]:
# Spot-check the last 10 rows, which now carry the "distillbert_answer" field.
rows[-10:]

[{'question_id': 499,
  'question_text': 'Co mogę umieścić w dodatkowych informacjach o dostawie i płatności?',
  'passage_id': '193',
  'passage_text': 'Postaraj się, aby jakość Twoich ocen na koncie, które bierze udział w promocji nie spadła poniżej 98%. Jeśli jednak tak się stanie lub jeśli przestaniesz spełniać inne warunki promocji, Twoje konto zostanie z niej wykluczone z końcem opłaconego okresu.',
  'score_bm25': 0.2670532,
  'score_bm25_not_lemmatized': 0.27342082998151523,
  'score_bm25_bigrams': 0,
  'passage_translated': 'Try to ensure that the quality of your rating in your account that participates in the promotion does not fall below 98%. However, if this happens or if you stop fulfilling other promotional conditions, your account will be excluded from it at the end of the paid period.',
  'question_translated': 'What can I put in the additional delivery and payment information?',
  'distillbert_answer': 'promotional conditions'},
 {'question_id': 499,
  'question_text':

In [36]:
# Persist the annotated rows. NOTE: this overwrites the input dataset file
# (`path_dataset`) in place rather than writing to a separate output path.
write_json_format(path_out=path_dataset, rows=rows)