# Requirements



In [None]:
# One-time data setup (majka dictionary and the sqad_v3 archive). The shell commands are
# wrapped in a string literal, so running this cell does not execute them.
"""
# get majka database
!curl --remote-name-all https://nlp.fi.muni.cz/ma{/majka.w-lt}
!mv majka.w-lt drive/MyDrive/data/
# download czech squad
!curl --remote-name-all https://lindat.cz/repository/xmlui/bitstream/handle/11234/1-3069{/sqad_v3.tar.xz}
!mv sqad_v3.tar.xz drive/MyDrive/data/
!tar -xf drive/MyDrive/data/sqad_v3.tar.xz
"""

In [None]:
# Install the third-party packages imported by the next cell.
# Only googletrans is pinned to a specific version; the others install whatever is current.
!pip install sentencepiece
!pip install datasets transformers
!pip install googletrans==4.0.0-rc1
!pip install wikipedia
!pip install rank_bm25
!pip install majka

In [None]:
import torch
import string
import os
import sys
import time
import shutil
import json
import numpy as np
import collections
import datetime
from tqdm.auto import tqdm
import warnings

from datasets import load_dataset, load_metric
from typing import List, Tuple, Dict
from collections import defaultdict
from transformers import AlbertTokenizerFast, AlbertForQuestionAnswering, TrainingArguments, Trainer, default_data_collator

from rank_bm25 import BM25Okapi, BM25Plus, BM25L
import re
import majka
import wikipedia
from googletrans import Translator
import requests


from google.colab import drive

In [None]:
# Remove the pre-cached "sample_data" directory from the working directory, if it exists.
if os.path.isdir("sample_data"):
  shutil.rmtree("sample_data")
# Mount Google Drive at /content/drive. Later cells use relative "drive/MyDrive/..." paths,
# which presumably assumes the working directory is /content - TODO confirm.
drive.mount('/content/drive')

Print GPU information

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == 'cuda':
    # Show which GPU was assigned, followed by the nvidia-smi status table.
    print(torch.cuda.get_device_name(0))
    !nvidia-smi

In [None]:
# This flag is the difference between SQUAD v1 or 2 (if you're using another dataset, it indicates if impossible
# answers are allowed or not).
# It is also read as a global by postprocess_qa_predictions and selects the dataset/metric name below.
squad_v2 = True
# Checkpoint directory used to load both the tokenizer and the model.
if squad_v2:
  model_checkpoint = "./drive/MyDrive/albert_models/albert_squad2_finetuned"
else:
  model_checkpoint = "./drive/MyDrive/albert_models/albert_finetuned"
# Per-device batch size used for both training and evaluation in TrainingArguments.
batch_size = 16

# Data preprocessing

In [None]:
def prepare_train_features(examples):
    """Tokenize a batch of examples and label every resulting feature with answer token positions.

    Reads the notebook-level globals `tokenizer`, `max_length` and `doc_stride`.

    Args:
        examples: a batch with "question", "context" and "answers" columns; each answers
            entry is indexed below through its "answer_start" and "text" lists, and only
            the first answer of each example is used.

    Returns:
        The tokenizer output (without "overflow_to_sample_mapping" and "offset_mapping")
        extended with "start_positions" and "end_positions". Features whose example has no
        answer, or whose window does not contain the whole answer, point at the CLS token.
    """
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            # The context is the second tokenizer argument, so its tokens carry sequence id 1.
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            # End token index of the current span in the text.
            # Walking backwards skips every trailing token that is not part of the context.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                # The loop overshoots by one token, hence the -1.
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                # Same overshoot in the other direction, hence the +1.
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [None]:
# Load the fast ALBERT tokenizer stored with the selected checkpoint.
tokenizer = AlbertTokenizerFast.from_pretrained(model_checkpoint)
# Load the dataset that matches the squad_v2 flag.
datasets = load_dataset("squad_v2" if squad_v2 else "squad")

In [None]:
# Maximum number of tokens per feature (question + context window, padded to this length).
max_length=384
# Number of tokens by which consecutive windows of a long context overlap.
doc_stride=128
# Turn every split into labelled features; the original columns are dropped because
# one example may produce several features.
tokenized_datasets = datasets.map(prepare_train_features,
                                  batched=True, 
                                  remove_columns=datasets["train"].column_names)

# Model fine-tuning

In [None]:
# Load the question-answering model from the selected checkpoint directory.
model = AlbertForQuestionAnswering.from_pretrained(model_checkpoint)

Create trainer

In [None]:
# Training configuration: checkpoints are written to Drive and evaluation runs once per epoch.
args = TrainingArguments(
    f"./drive/MyDrive/data/checkpoints",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Features are already padded to max_length, so the default collator is used as-is.
data_collator = default_data_collator

# The Trainer is reused later for prediction on the validation features.
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

Train and save

In [None]:
# Run fine-tuning with the arguments configured above.
trainer.train()

In [None]:
# Persist the fine-tuned model to Drive.
# NOTE(review): this directory differs from both model_checkpoint values used in this
# notebook ("./drive/MyDrive/albert_models/...") and does not depend on squad_v2 - confirm
# the model is copied to the expected location before it is loaded again.
trainer.save_model("./drive/MyDrive/data/albert_finetuned")

# Model evaluation

In [None]:
def prepare_validation_features(examples):
    """Tokenize a batch of examples for evaluation, keeping what post-processing needs.

    Reads the notebook-level globals `tokenizer`, `max_length` and `doc_stride`.

    Returns the tokenizer output with two changes: an "example_id" entry naming the example
    each feature was cut from, and an "offset_mapping" in which every token outside the
    context (sequence id other than 1) is replaced by None.
    """
    # Long contexts are split into overlapping windows, so one example may yield several features.
    encoded = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # feature index -> index of the example the feature originates from
    feature_to_example = encoded.pop("overflow_to_sample_mapping")
    context_sequence = 1

    encoded["example_id"] = []
    for feature_idx, offsets in enumerate(encoded["offset_mapping"]):
        seq_ids = encoded.sequence_ids(feature_idx)
        encoded["example_id"].append(examples["id"][feature_to_example[feature_idx]])

        # Blank out offsets of non-context tokens so a later `is None` test can reject them.
        encoded["offset_mapping"][feature_idx] = [
            offset if seq_ids[position] == context_sequence else None
            for position, offset in enumerate(offsets)
        ]

    return encoded

In [None]:
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    """Turn raw start/end logits into one predicted answer text per example.

    Reads the notebook-level globals `tokenizer` (to locate the CLS token) and `squad_v2`
    (to decide whether the empty answer may be predicted).

    Args:
        examples: the original examples; "id" and "context" are read.
        features: tokenized features with "example_id", "input_ids" and an "offset_mapping"
            whose non-context entries are None.
        raw_predictions: pair (all_start_logits, all_end_logits), indexed by feature.
        n_best_size: number of top start and top end indices combined per feature.
        max_answer_length: maximum answer length in tokens.

    Returns:
        OrderedDict mapping example id to the predicted answer text ("" means no answer).
    """
    all_start_logits, all_end_logits = raw_predictions
    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []

        context = example["context"]
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update minimum null prediction: keep the LOWEST null score over all features of
            # the example. The previous comparison was inverted and kept the highest one,
            # which biased the result towards predicting "no answer".
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or feature_null_score < min_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )

        if valid_answers:
            # max() returns the first of several equally scored answers, exactly like taking
            # element 0 of a stable descending sort.
            best_answer = max(valid_answers, key=lambda x: x["score"])
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}

        # Let's pick our final answer: the best one or the null answer (only for squad_v2)
        if not squad_v2 or min_null_score is None:
            # min_null_score is None only when the example has no features at all; there is
            # then no null score to compare against.
            predictions[example["id"]] = best_answer["text"]
        else:
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            predictions[example["id"]] = answer

    return predictions

Final evaluation

In [None]:
# Tokenize the validation split for evaluation: these features keep "example_id" and a
# context-only "offset_mapping" instead of answer labels.
validation_features = datasets["validation"].map(prepare_validation_features,
                                                 batched=True,
                                                 remove_columns=datasets["validation"].column_names)

# get raw predictions
raw_predictions = trainer.predict(validation_features)

# Re-expose every column of the features dataset.
# NOTE(review): presumably needed because trainer.predict narrows the visible columns to the
# model inputs - confirm against the installed transformers version.
validation_features.set_format(type=validation_features.format["type"], 
                               columns=list(validation_features.features.keys()))

In [None]:
# hyperparameters (passed explicitly below, so changing them here takes effect)
max_answer_length = 30
n_best_size = 20

# examples and the features derived from them
examples = datasets["validation"]
features = validation_features

# get final predictions; the example -> features mapping is built inside
# postprocess_qa_predictions, so it is not computed a second time here
final_predictions = postprocess_qa_predictions(
    examples,
    features,
    raw_predictions.predictions,
    n_best_size=n_best_size,
    max_answer_length=max_answer_length,
)
# get metric used
metric = load_metric("squad_v2" if squad_v2 else "squad")

# evaluate: the squad_v2 metric additionally expects a no_answer_probability per prediction
if squad_v2:
    formatted_predictions = [{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in final_predictions.items()]
else:
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
metric.compute(predictions=formatted_predictions, references=references)

# Decode

In [None]:
def decode(output, context, offset_mappings):
  """Collect the best-scoring answer spans over all features of one question/context pair.

  `output` provides `start_logits` and `end_logits` (one row per feature) and
  `offset_mappings[i]` gives, for feature i, the (start_char, end_char) pair of every token;
  entries that are None are never used as a span boundary.

  Returns up to 10 dicts {"score", "text"} sorted by descending score, where the text is
  the stripped slice of `context` covered by the span.
  """
  # hyperparameters
  max_answer_length = 10
  n_best_size = 10

  candidates = []
  for feature_idx in range(len(output.start_logits)):
    start_scores = output.start_logits[feature_idx].cpu().detach().numpy()
    end_scores = output.end_logits[feature_idx].cpu().detach().numpy()
    offsets = offset_mappings[feature_idx]
    n_tokens = len(offsets)

    # Indices of the n_best_size highest logits, best first.
    best_starts = np.argsort(start_scores)[::-1][:n_best_size].tolist()
    best_ends = np.argsort(end_scores)[::-1][:n_best_size].tolist()

    for start in best_starts:
      # Skip start tokens that are out of bounds or masked out.
      if start >= n_tokens or offsets[start] is None:
        continue
      for end in best_ends:
        if end >= n_tokens or offsets[end] is None:
          continue
        # Reject spans that run backwards or are longer than max_answer_length tokens.
        span_length = end - start + 1
        if span_length < 1 or span_length > max_answer_length:
          continue
        first_char = offsets[start][0]
        last_char = offsets[end][1]
        candidates.append(
          {
            "score": start_scores[start] + end_scores[end],
            "text": context[first_char: last_char].strip()
          }
        )

  ranked = sorted(candidates, key=lambda x: x["score"], reverse=True)
  return ranked[:n_best_size]

Get answers

In [None]:
def get_answers(question, context):
  """Run the reader model on one question/context pair and return ranked answer candidates.

  Uses the notebook-level globals `tokenizer`, `model` and `device`. A long context is
  split into overlapping windows; the candidates of all windows are merged by `decode`.

  Returns the list produced by `decode`: dicts with "score" and "text", best first.
  """
  inputs = tokenizer(question, context,
                     return_tensors='pt',
                     truncation="only_second",
                     max_length=384, # to prevent cuda running out of memory
                     stride=128,     # overlap within splitted long
                     return_offsets_mapping=True,
                     return_overflowing_tokens=True,
                     padding="max_length")

  # Keep only the offsets of context tokens (sequence id 1) and replace the rest by None.
  # Offsets of question tokens are relative to the question, so using them to slice the
  # context would return unrelated text; `decode` skips None entries.
  offset_mappings = []
  for feature_idx, offsets in enumerate(inputs['offset_mapping'].tolist()):
    sequence_ids = inputs.sequence_ids(feature_idx)
    offset_mappings.append([
      tuple(offset) if sequence_ids[position] == 1 else None
      for position, offset in enumerate(offsets)
    ])

  inputs.to(device)

  # Inference only: no_grad avoids building the autograd graph and saves GPU memory.
  with torch.no_grad():
    outputs = model(inputs['input_ids'],
                    token_type_ids=inputs['token_type_ids'],
                    attention_mask=inputs['attention_mask'])

  return decode(outputs, context, offset_mappings)

# Loading model


Choose trained version

In [None]:
# Select which fine-tuned checkpoint the question-answering pipeline below loads.
# Note that this rebinds the global squad_v2 flag, which postprocess_qa_predictions and the
# evaluation cell also read.
squad_v2 = False
if squad_v2:
  model_checkpoint = "./drive/MyDrive/albert_models/albert_squad2_finetuned"
else:
  model_checkpoint = "./drive/MyDrive/albert_models/albert_finetuned"

# Alternative checkpoint name, currently disabled:
# model_checkpoint = "ktrapeznikov/albert-xlarge-v2-squad-v2"

Load model and tokenizer

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load tokenizer and model from the chosen checkpoint; this replaces the `tokenizer` and
# `model` globals bound by earlier cells, and moves the model to the selected device.
tokenizer = AlbertTokenizerFast.from_pretrained(model_checkpoint)
model = AlbertForQuestionAnswering.from_pretrained(model_checkpoint).to(device)

In [None]:
def count_parameters(model):
    """Return the total number of elements over all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
print(f"Model has {count_parameters(model)} parameters")

# Retriever


In [None]:
# Translator client used for the Czech <-> English translation steps of the pipeline.
translator = Translator()
# Make every call to the wikipedia package use the Czech ("cs") language edition.
wikipedia.set_lang("cs") 

Build an index for searching relevant article titles on Czech Wikipedia:

In [None]:
def get_title_search_index():
  """Build a BM25 index over the article titles listed in the titles dump on Drive.

  Each line of the dump is one title with underscores instead of spaces.

  Returns:
    (bm25, titles): the BM25Okapi index over the lowercased, space-tokenized titles and the
    list of human-readable titles, in the same order as the indexed documents.
  """
  # The context manager closes the file even if reading fails; the encoding is explicit so
  # that titles with diacritics do not depend on the platform default.
  with open("drive/MyDrive/data/wiki/cswiki-latest-all-titles-in-ns0", "r", encoding="utf-8") as f:
    titles = [line.replace("_", " ").strip() for line in f]

  # tokenize for bm25; empty tokens (from repeated spaces) are filtered out while building
  # the list, instead of being removed from a list that is being iterated over, which
  # skipped elements and could leave empty tokens behind
  tok_titles = [
    [tok for tok in title.lower().split(" ") if tok != ""]
    for title in titles
  ]

  bm25 = BM25Okapi(tok_titles)

  return bm25, titles

In [None]:
def search_titles(question, bm25, titles):
  """Return the 5 titles that best match the question according to the BM25 title index.

  The query is the lowercased question, lemmatized and stripped of common words.
  """
  query_tokens = delete_common(lemmatize(question.lower()))
  return bm25.get_top_n(query_tokens, titles, n=5)

Extract the lemmas of named entities from the question

In [None]:
def getNamed_ER(question):
  """
  Extracts named entities from the question.

  Sends the question to the remote NER service and returns the 'lemma' field of every
  item in its JSON response, in response order.

  The lookup is best-effort: if the service cannot be reached, answers with an error
  status, times out, or returns something that is not valid JSON, a message is printed and
  an empty list is returned so the rest of the retrieval can continue without entities.
  """

  URL = "https://nlp.fi.muni.cz/projekty/ner/nerJSON.py"
  PARAMS = {'text': question}

  try:
    # Without a timeout the request could block the notebook indefinitely.
    r = requests.get(url = URL, params = PARAMS, timeout = 10)
    r.raise_for_status()
    data = r.json()
  except (requests.RequestException, ValueError) as err:
    print("named entity lookup failed: " + str(err))
    return []

  # assumes the response is a JSON object whose values carry a 'lemma' key, as the original
  # code did - TODO confirm against the service
  return [data[item]['lemma'] for item in data]

In [None]:
def extract_que_ans(dirnum, filename, lemmatized=False):
  """Rebuild plain text from a tab-separated vertical file of the Czech SQAD data set.

  Args:
    dirnum: name of the record directory below drive/MyDrive/data/cz_sqad/data/.
    filename: vertical file inside that directory (one token per line).
    lemmatized: if False, column 0 of every line is used; if True, column 1 is used and
      lines with fewer than two whitespace-separated fields are skipped.

  Returns:
    The tokens joined by single spaces. A structural tag (<s>, <g/>, </s>) removes the last
    character written so far, and a "?" token is written as " ? ".
  """
  # get question and answer from czech squad
  path = f"drive/MyDrive/data/cz_sqad/data/{dirnum}/{filename}"
  # The context manager closes the file even when parsing raises.
  with open(path, "r", encoding="utf-8") as f:
    lines = f.read().split("\n")

  structural_tags = {"<s>", "<g/>", "</s>"}
  column = 1 if lemmatized else 0

  text = ""
  for line in lines:
    # Only the lemmatized parse skips lines without a second field.
    if lemmatized and len(line.split()) < 2:
      continue

    token = line.split("\t")[column]
    if token in structural_tags:
      text = text[:-1]
      continue

    if token != "?":
      text += token + " "
    else:
      text += " ? "

  return text

In [None]:
# Stop-word list: very frequent Czech words that carry little information in a query.
# Earlier, broader version of the list, kept for reference:
# common = "být a se v na ten on že s z který mít do já o k i jeho ale svůj jako za moci rok pro tak po tento co když všechen už jak aby od nebo říci jeden jen můj jenž člověk ty stát u muset velký chtít také až než ještě při jít pak před dva však ani vědět nový hodně podle další celý jiný mezi dát tady den tam kde doba každý místo dobrý takový strana protože nic začít něco vidět říkat ne sám bez či dostat nějaký proto"
common = "kdy být a se v na ten on že s z který mít do já o k i jeho ale svůj jako za moci pro tak po tento co když všechen už jak aby od nebo říci jeden jen můj jenž ty stát u muset chtít také až než ještě při jít pak před však ani vědět hodně podle další celý jiný mezi dát tady tam kde každý takový protože nic něco ne sám bez či dostat nějaký proto".split()
# Punctuation tokens that are treated like stop words.
punctuation = ". , ? ! ... \" ( ) ; - /".split()

def iscommon(x):
  """Tell whether a query token is a stop word or a punctuation token."""
  return x in common or x in punctuation

def delete_common(tokens):
  """Return a new list with the stop words and punctuation tokens filtered out."""
  return list(filter(lambda tok: not iscommon(tok), tokens))

In [None]:
def search_again(tokens):
  """Search Wikipedia with progressively shorter queries until something is found.

  The tokens are joined by spaces and searched; while the search returns nothing, the first
  token is dropped and the search is repeated.

  Returns the first non-empty result list (at most one title), or [] when `tokens` is empty
  or no suffix of it yields a hit. The caller's list is left unmodified and no request is
  sent for an empty query.
  """
  remaining = list(tokens)

  while len(remaining) > 0:
    searched_term = (' ').join(remaining)
    doc_list = wikipedia.search(searched_term, results=1)
    if len(doc_list) > 0:
      return doc_list
    # nothing found: retry without the first token
    remaining = remaining[1:]

  return []

In [None]:
# Open the majka dictionary file stored on Drive.
morph = majka.Majka('drive/MyDrive/data/majka.w-lt')
morph.flags |= majka.ADD_DIACRITICS  # find word forms with diacritics
morph.flags |= majka.DISALLOW_LOWERCASE  # do not enable to find lowercase variants
morph.flags |= majka.IGNORE_CASE  # ignore the word case whatsoever
# NOTE: the assignment below overwrites the three flags OR-ed in above, so the analyzer is
# left with no flags set.
morph.flags = 0  # unset all flags

morph.tags = False  # return just the lemma, do not process the tags
morph.first_only = True  # return only the first entry
morph.negative = "ne"

# returns lemma of each token in a list of lemmatized tokens
def lemmatize(text):
  """Split text into word tokens and replace each token by its lemma.

  The text is lowercased, split on non-word characters, and every non-empty token is looked
  up with the notebook-level `morph` analyzer. Tokens without any analysis are kept as-is.

  Returns the list of lemmas/tokens in their original order.
  """
  # Lowercase BEFORE splitting: previously the lowercased copy was computed and then
  # discarded, because the split was applied to the unmodified input.
  tok_text = re.split(r"\W", text.lower())

  # lemmatize each token
  lemmatized_tokens = []
  for token in tok_text:
    if token == '':
      continue
    lemma = morph.find(token)
    if len(lemma) == 0:
      # unknown word: keep the surface form
      lemmatized_tokens.append(token)
    else:
      lemmatized_tokens.append(lemma[0]['lemma'])

  return lemmatized_tokens


In [None]:
def get_doc_list(question, bm_articles_index, titles):
  """Collect up to three Wikipedia article titles that may contain the answer.

  Three searches are tried in this order, each contributing at most its first hit:
    1. the first named entity recognised in the question,
    2. the best match from the BM25 title index,
    3. the question itself, falling back to a simplified keyword query when the plain
       search returns nothing.

  Returns the titles in that order without duplicates; the list is empty when none of the
  searches finds anything.
  """
  # get names entities if present
  named_ERs = getNamed_ER(question)
  # get relevant article title names
  relevant_titles = search_titles(question, bm_articles_index, titles)

  #search for documents
  max_docs = 1
  doc_list = []

  def add_first_hit(hits):
    # Skip empty results (previously an empty fallback result raised IndexError) and
    # titles already collected (the same page would otherwise be downloaded again and
    # its paragraphs would appear several times in the paragraph ranking).
    if len(hits) > 0 and hits[0] not in doc_list:
      doc_list.append(hits[0])

  # search based on recognised named entity
  if len(named_ERs) > 0:
    add_first_hit(wikipedia.search(named_ERs[0], results=max_docs))
  # search based on best wiki title match
  if len(relevant_titles) > 0:
    add_first_hit(wikipedia.search(relevant_titles[0], results=max_docs))

  # basic search for the question
  article = wikipedia.search(question, results=max_docs)
  # simplify the search if its too bad
  if len(article) == 0:
    # extract important for wiki
    tokens = delete_common(lemmatize(question))
    article = search_again(tokens)
  add_first_hit(article)

  return doc_list

In [None]:
def retrieve(question, bm_articles_index, titles):
  """
  returns the top 3 paragraphs for the given question

  The candidate articles come from get_doc_list. Every article is split into paragraphs
  (on section headings and blank lines), the paragraphs are lemmatized and indexed with
  BM25Plus, and the index is queried with the lemmatized question.

  Articles that cannot be loaded (missing page, or a disambiguation page whose first option
  cannot be loaded either) are skipped. Returns [] when no paragraph could be collected.
  """
  question = question.strip('?')

  doc_list = get_doc_list(question, bm_articles_index, titles)

  # CONTROL PRINT######################################
  print(doc_list)

  pars = []      # original paragraphs, returned to the caller
  tok_text = []  # lemmatized tokens of each paragraph, same order as pars

  for doc in doc_list:
    # get whole page content
    try:
      page = wikipedia.page(doc)
    except wikipedia.DisambiguationError as e:
      # ambiguous title: try the first suggested option once
      if len(e.options) == 0:
        continue
      try:
        page = wikipedia.page(e.options[0])
      except (wikipedia.DisambiguationError, wikipedia.PageError):
        continue
    except wikipedia.PageError:
      # the search returned a title that has no page
      continue

    # split docs into paragraphs
    for par in re.split(r'== .*. ==|\n\n', page.content):
      # strip whitespace, then leftover heading markers, then line breaks
      par = par.strip().strip('=').strip('\r\n')

      if par == '':
        continue

      # get lemmas and tokenize them for bm25; empty tokens are filtered while building the
      # list instead of being removed from a list that is being iterated over
      lemmas = (' ').join(delete_common(lemmatize(par.lower())))
      tok_text.append([tok for tok in re.split(r"\W", lemmas.lower()) if tok != ""])
      pars.append(par)

  # BM25 cannot be built on an empty corpus (its average document length divides by the
  # corpus size), so report "nothing retrieved" instead.
  if len(pars) == 0:
    return []

  # build index
  # bm25 = BM25L(tok_text)
  bm25 = BM25Plus(tok_text)
  # bm25 = BM25Okapi(tok_text)

  # tokenize and lemmatize the query
  query = (' ').join(delete_common(lemmatize(question.lower())))
  tokenized_query = [tok for tok in re.split(r"\W", query) if tok != ""]

  # get results
  results = bm25.get_top_n(tokenized_query, pars, n=3)

  return results


In [None]:
def count_log_conf(best_answer, all_answers):
  """Sum the scores of all candidates whose text contains, or is contained in, best_answer.

  Note that a candidate with empty text always counts, because the empty string is a
  substring of every string.
  """
  return sum(
    candidate['score']
    for candidate in all_answers
    if best_answer in candidate['text'] or candidate['text'] in best_answer
  )

# Testing with sqad

In [None]:
# build index for searching through wiki article titles
# (the results file is not opened here: the evaluation cell opens it itself, and a second
# handle opened in this cell would never be closed)
bm25_for_title_index, titles = get_title_search_index()

In [None]:
# Answer a range of SQAD questions and append a report for each of them to the results file.
# write first question-answer pairs in sqad
from_q = 95
to_q = 95
warnings.filterwarnings("ignore")

# write results to (append mode); the context manager closes the file even when an
# iteration raises
with open("drive/MyDrive/data/saved_answers/saved_answers_albert/1-100_ner_with_info.txt", "a", encoding="utf-8") as f:
  for i in range(from_q, to_q+1):

    # get question number: the record directories are named with six zero-padded digits
    name = str(i).zfill(6)

    # extract from dataset
    question_cs = extract_que_ans(name, "01question.vert")
    correct_answer = extract_que_ans(name, "09answer_extraction.vert")

    # wiki search
    documents = retrieve(question_cs, bm25_for_title_index, titles)

    # for saving the best results (the four lists stay index-aligned)
    bestAnswers = []
    bestDocs = []
    bestLogProbs = []
    bestSummedLogProbs = []

    # English question, translated lazily and only ONCE. Previously the already translated
    # question was sent through the cs -> en translation again for every paragraph.
    question_en = None

    # iterate over retrieved paragraphs
    for document in documents:

      # strip whitespaces
      document_cs = document.strip()

      # skip empty paragraphs; a question without any usable paragraph is reported once
      # by the "no answer" branch below
      if document_cs == "":
        continue

      # remove some trash
      # TODO this not cool (in retriever)
      if document_cs.startswith("Obrázky, zvuky či videa k tématu"):
        continue

      try:
        document_en = translator.translate(document_cs, src='cs', dest='en').text
      except TypeError:
        continue

      # translate
      if question_en is None:
        question_en = translator.translate(question_cs, src='cs', dest='en').text

      #get answer -------------------------------------------
      answers = get_answers(question_en, document_en)

      # first candidate with non-empty text; paragraphs without one are skipped so that no
      # placeholder entry (or stale confidence from a previous paragraph) is recorded
      best = next((candidate for candidate in answers if candidate['text'] != ''), None)
      if best is None:
        continue
      #######################################################

      # save probs, answer and paragraph
      bestDocs.append(document_cs)
      bestAnswers.append(best['text'])
      bestLogProbs.append(best['score'])
      bestSummedLogProbs.append(count_log_conf(best['text'], answers))

    # check if any answer was found
    if len(bestLogProbs) == 0:
      f.write("question: " + question_cs + "\n" +
                "answer: odpověď nenalezena" + "\n" +
                "correct answer: " + correct_answer + "\n\n")
      continue

    # get the best doc
    # get best answer from retriever according to reader
    # get the best confidence
    best_index = int(np.argmax(bestLogProbs, axis=0))
    document = bestDocs[best_index]
    answer_en = bestAnswers[best_index]

    # translate the final answer; keep the English text when the translation fails
    try:
      answer = translator.translate(answer_en, src='en', dest='cs').text
    except IndexError:
      answer = answer_en

    # convert to string to write to file
    answers_line = (";; ").join(bestAnswers)
    log_probs_line = (";; ").join(str(v) for v in bestLogProbs)
    summed_log_probs_line = (";; ").join(str(v) for v in bestSummedLogProbs)

    # write the result to file
    f.write("otázka č." + name + ": " + question_cs + "\n" +
            "::odpověď: " + answer + " / " + answer_en + "\n" +
            "::správná odpověď podle sqad : " + correct_answer + "\n\n" +
            answers_line + "\n" +
            log_probs_line + "\n" +
            summed_log_probs_line + "\n\n" +
            "----------------------------------------------------------------\n"+
            "získaný dokument: " + document +
            "\n----------------------------------------------------------------\n"+
            "----------------------------------------------------------------"+
            "\n\n")

    # controlprint
    print("wrote: " + name)


# Ask a question

In [None]:
def find_answer(question, bm25_for_title_index, titles):
  """
  finds the answer to the question

  The Czech question is used to retrieve Czech paragraphs; each paragraph and the question
  are translated to English, the reader model extracts answer candidates, and the paragraph
  whose best candidate has the highest summed confidence wins.

  Returns:
    the string "odpověď nenalezena" when no paragraph yields a non-empty answer, otherwise
    a tuple (answer translated to Czech, English answer, Czech paragraph it was found in).
  """
  question_cs = question # save czech question
  # English question, translated lazily and only once. Previously the already translated
  # question was sent through the cs -> en translation again for every paragraph.
  question_en = None

  # wiki search
  documents = retrieve(question_cs, bm25_for_title_index, titles)

  # for saving the best results (the three lists stay index-aligned)
  bestAnswers = []
  bestDocs = []
  bestLogProbs = []

  # iterate over retrieved paragraphs
  for document in documents:

    # strip whitespaces
    document_cs = document.strip()

    # chceck if any document has been found for the question
    if document_cs == "":
      continue

    # remove some trash
    # TODO this not cool (in retriever)
    if document_cs.startswith("Obrázky, zvuky či videa k tématu"):
      continue

    try:
      document_en = translator.translate(document_cs, src='cs', dest='en').text
    except TypeError:
      continue

    # translate
    if question_en is None:
      question_en = translator.translate(question_cs, src='cs', dest='en').text

    #get answer -------------------------------------------
    answers = get_answers(question_en, document_en)
    print(answers)

    # first candidate with non-empty text; paragraphs without one are skipped, so neither a
    # leftover candidate dict nor a confidence from a previous paragraph can be recorded
    best = next((candidate for candidate in answers if candidate['text'] != ''), None)
    if best is None:
      continue
    #######################################################

    # save probs and answer
    bestAnswers.append(best['text'])
    bestLogProbs.append(count_log_conf(best['text'], answers))
    # save retrieved doc
    bestDocs.append(document_cs)

  ############################################################
  # check if any answer was found
  if len(bestLogProbs) == 0:
    return "odpověď nenalezena"

  print(bestAnswers)
  print(bestLogProbs)
  # get the best doc
  # get best answer from retriever according to reader
  best_index = int(np.argmax(bestLogProbs, axis=0))
  document = bestDocs[best_index]
  answer_en = bestAnswers[best_index]

  # translate the final answer; keep the English text when the translation fails, as the
  # SQAD evaluation loop does
  try:
    answer = translator.translate(answer_en, src='en', dest='cs').text
  except IndexError:
    answer = answer_en

  return answer, answer_en, document

In [None]:
# build index for searching through wiki article titles
# NOTE: the "Testing with sqad" section contains the same call; running both cells reads
# the titles file and rebuilds the index a second time.
bm25_for_title_index, titles = get_title_search_index()

In [None]:
warnings.filterwarnings("ignore")

question = "Jak se jmenuje otec spisovatele Jiřího Muchy ?"
answer = find_answer(question, bm25_for_title_index, titles)

# find_answer returns a plain message string when nothing was found and an
# (answer_cs, answer_en, document) tuple otherwise; tell the two apart by type rather than
# by length, so a message that happens to be three characters long is still printed whole.
if isinstance(answer, str):
  print(answer)
else:
  print(answer[0])
  print(answer[1])