In [None]:
import os
import pickle

import numpy as np
import pandas as pd
from simpletransformers.t5 import T5Model
from sentence_transformers import SentenceTransformer

In [None]:
# Demo queries (Polish):
#   1. "In which year did the Battle of Grunwald take place?"
#   2. "What is the capital of Poland?"
questions = [
    "W którym roku odbyła się bitwa pod Grunwaldem?",
    "Jaka jest stolica Polski?",
]

## Retrieve Passages

In [None]:
def encode_texts(texts, encoder):
    """Encode texts and L2-normalise each embedding to unit length.

    With unit-length vectors, the inner product used for retrieval equals cosine similarity.

    Args:
        texts: text(s) passed straight to ``encoder.encode``.
        encoder: object exposing ``encode(texts, convert_to_numpy=True)`` that returns a
            numpy array (e.g. a sentence-transformers ``SentenceTransformer``).

    Returns:
        numpy array shaped like the encoder output, with each vector along the last axis
        scaled to unit norm. All-zero vectors are returned as zeros rather than NaN.
    """
    emb = encoder.encode(texts, convert_to_numpy=True)
    # axis=-1 normalises both a (n, dim) batch and a single 1-D vector.
    norms = np.linalg.norm(emb, axis=-1, keepdims=True)
    # Dividing by a zero norm would give NaN. numpy's sort order places NaN last,
    # so NaN scores could end up in the top-k slice taken by argpartition.
    return emb / np.where(norms == 0, 1, norms)

def retrieve_passages(questions, passages, encoder, max_passage_len=300, n_candidates=10):
    """Retrieve the best-scoring passages for each question by inner-product search.

    Args:
        questions: list of question strings; encoded with ``encode_texts``.
        passages: list of dicts, each with a precomputed embedding under ``'emb'`` plus
            ``'title'`` and ``'text'`` strings.
        encoder: encoder forwarded to ``encode_texts`` for the questions.
        max_passage_len: each passage ``text`` is truncated to this many characters.
        n_candidates: number of passages kept per question, clamped to ``len(passages)``.
            Values <= 0 yield no passages.

    Returns:
        One list per question of ``"title: text"`` strings, highest score first.
    """
    if len(questions) == 0:
        return []
    # Clamp first: argpartition raises when kth is out of bounds
    # (n_candidates > len(passages)), and `[-0:]` would select every passage.
    n_top = min(n_candidates, len(passages))
    if n_top <= 0:
        return [[] for _ in questions]

    encoded_passages = np.vstack([p['emb'] for p in passages])
    encoded_questions = encode_texts(questions, encoder)
    # Score every question against every passage in one call: one row per question.
    all_scores = np.inner(encoded_questions, encoded_passages)

    candidates = []
    for scores in all_scores:
        # Select the top n_top indices (unordered), then sort only those.
        top = np.argpartition(scores, -n_top)[-n_top:]
        # Best first; equal scores fall back to the larger passage index.
        ranked = sorted(zip(scores[top], top), reverse=True)
        candidates.append([
            f"{passages[int(idx)]['title']}: {passages[int(idx)]['text'][:max_passage_len]}"
            for _, idx in ranked
        ])
    return candidates

In [None]:
encoder = SentenceTransformer('piotr-rybak/poleval2021-task4-herbert-large-encoder')

# Passage shards: every *.pkl file in this directory. Files are read in sorted order
# because os.listdir order is filesystem-dependent. This keeps passage indices, and
# therefore retrieval tie-breaks, reproducible across machines.
PASSAGES_DIR = '../passages/'

passages = []
for file in sorted(os.listdir(PASSAGES_DIR)):
    if file.endswith('.pkl'):
        # NOTE: pickle.load can execute arbitrary code; only load trusted passage files.
        with open(os.path.join(PASSAGES_DIR, file), 'rb') as f:
            passages += pickle.load(f)
assert passages, f"no passages loaded from .pkl files in {PASSAGES_DIR!r}"

candidates = retrieve_passages(questions, passages, encoder)

## Generate Answers

In [None]:
def prepare_input(questions, candidates):
    """Build reader prompts that pair each question with its retrieved passages.

    Each prompt has the form ``"Pytanie: <question> | Kontekst: <p1> | <p2> | ..."``
    ("Pytanie" = question, "Kontekst" = context).

    Args:
        questions: sequence of question strings.
        candidates: sequence parallel to ``questions``, holding one list of passage
            strings per question.

    Returns:
        list of str, one prompt per question, in input order.

    Raises:
        ValueError: if ``questions`` and ``candidates`` differ in length. Without this
            check ``zip`` would silently drop the unmatched tail, e.g. after editing the
            questions cell without re-running retrieval.
    """
    questions = list(questions)
    candidates = list(candidates)
    if len(questions) != len(candidates):
        raise ValueError(
            f"got {len(questions)} questions but {len(candidates)} candidate lists; "
            "re-run passage retrieval for the current questions"
        )
    return [f"Pytanie: {q} | Kontekst: {' | '.join(c)}" for q, c in zip(questions, candidates)]

In [None]:
input_text = prepare_input(questions, candidates)

# Seq2seq reader: T5Model.predict returns one generated string per prompt (CPU only).
model = T5Model(
    model_type='t5',
    model_name='piotr-rybak/poleval2021-task4-plt5-base-qa',
    use_cuda=False,
)
model.predict(input_text)