# Requirements



Install the packages this notebook needs to run.

In [None]:
# Install the dependencies (IPython shell magics, written for Google Colab).
# Uncomment the next line when you are not running in Colab:
# !pip install torch
# reader: HuggingFace tokenizers/models
!pip install sentencepiece
!pip install datasets transformers

# translation (googletrans), Wikipedia API, BM25 ranking and the Czech lemmatizers (Majka, MorphoDiTa via corpy)
!pip install googletrans==4.0.0-rc1
!pip install wikipedia
!pip install rank_bm25
!pip install majka
!pip install corpy

# DeepPavlov NER model used by the retriever (presumably needs TensorFlow 1.15 and this BERT fork - verify)
!pip install tensorflow-gpu==1.15.2
!pip install deeppavlov
!pip install git+https://github.com/deepmipt/bert.git@feat/multi_gpu

Import the required libraries.

In [None]:
import torch
import string
import os
import sys
import time
import shutil
import json
import numpy as np
import collections
import datetime
import warnings
import nltk.data
import nltk
import re

import pickle
from xml.dom import minidom
from tqdm.auto import tqdm

from transformers import AlbertTokenizerFast, AlbertForQuestionAnswering
from transformers import BertTokenizerFast, BertForQuestionAnswering

from deeppavlov import build_model, configs

from rank_bm25 import BM25Okapi, BM25Plus, BM25L

import majka
from corpy.morphodita import Tagger

import wikipedia
from googletrans import Translator

from google.colab import drive

# silence all library warnings to keep the notebook output readable
warnings.filterwarnings("ignore")
# Punkt sentence-tokenizer models; the Retriever loads the Czech one
nltk.download('punkt')

Download and extract the MorphoDiTa tagger model used for lemmatization.

In [None]:
# download the MorphoDiTa Czech tagger model and unpack it into ./czech-morfflex-pdt-161115/
!curl --remote-name-all https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1836{/czech-morfflex-pdt-161115.zip}
!unzip ./czech-morfflex-pdt-161115.zip

Download the Majka database used for lemmatization.

In [None]:
# download the Majka morphological database; it is saved as ./majka.w-lt
!curl --remote-name-all https://nlp.fi.muni.cz/ma{/majka.w-lt}

# Reader

In [None]:
class Reader():
  """
  Extractive question-answering reader.

  Wraps a HuggingFace ALBERT or multilingual BERT question-answering model
  together with its fast tokenizer and returns the best-scoring answer spans
  found in a given context.
  """

  def __init__(self, model_checkpoint, model_type="albert", max_answer_length=10, n_best_size=10, max_length=384, stride=128, use_cpu=False):
    """
    Load the tokenizer and the model from `model_checkpoint`.

    Args:
      model_checkpoint: path/name of a fine-tuned question-answering checkpoint.
      model_type: 'albert' or 'mbert'; selects the tokenizer and model classes.
      max_answer_length: maximum answer span length in tokens.
      n_best_size: number of top start/end logits considered per feature and
        maximum number of answers returned by `decode`.
      max_length: maximum number of tokens in one tokenized feature.
      stride: overlap (in tokens) between consecutive features when a long
        context is split into several features.
      use_cpu: force the CPU even when CUDA is available.

    Raises:
      ValueError: if `model_type` is neither 'albert' nor 'mbert'.
    """
    self.max_answer_length = max_answer_length
    self.n_best_size = n_best_size
    self.max_length = max_length
    self.stride = stride

    # choose device; cuda if available and not explicitly disabled
    self.device = torch.device("cuda:0" if (torch.cuda.is_available() and not use_cpu) else "cpu")

    if model_type == 'albert':
      tokenizer_cls, model_cls = AlbertTokenizerFast, AlbertForQuestionAnswering
    elif model_type == 'mbert':
      tokenizer_cls, model_cls = BertTokenizerFast, BertForQuestionAnswering
    else:
      # fail loudly instead of returning a half-initialized reader that
      # crashes later with an AttributeError
      raise ValueError(f"Wrong model type parameter: {model_type!r} (expected 'albert' or 'mbert').")

    # load tokenizer and model from the pretrained checkpoint; model goes to the device
    self.tokenizer = tokenizer_cls.from_pretrained(model_checkpoint)
    self.model = model_cls.from_pretrained(model_checkpoint).to(self.device)

    print("Model loaded from: " + model_checkpoint)
    print(f"Model has {self.count_parameters(self.model)} parameters")
    print("Device selected:")
    print(self.device)


  def decode(self, output, context, offset_mappings):
    """
    Turn the model's start/end logits into scored text spans of `context`.

    Partly borrowed from (and thoroughly commented in)
    https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/question_answering.ipynb

    Args:
      output: model output with `start_logits` and `end_logits`, one row per feature.
      context: the context string the offsets point into.
      offset_mappings: per feature, a sequence of (start_char, end_char)
        pairs; entries that are None (tokens outside the context) are never
        used as span boundaries.

    Returns:
      Up to `n_best_size` dicts {"score", "text"} sorted by descending score,
      where score is the sum of the start and end logits.
    """
    n_best = self.n_best_size
    valid_answers = []

    for i in range(len(output.start_logits)):
      start_logits = output.start_logits[i].detach().cpu().numpy()
      end_logits = output.end_logits[i].detach().cpu().numpy()
      offset_mapping = offset_mappings[i]

      # indices of the n_best highest start/end logits, best first
      start_indexes = np.argsort(start_logits)[-1 : -n_best - 1 : -1].tolist()
      end_indexes = np.argsort(end_logits)[-1 : -n_best - 1 : -1].tolist()
      for start_index in start_indexes:
        for end_index in end_indexes:
          # skip indices out of bounds or pointing outside the context
          if (
              start_index >= len(offset_mapping)
              or end_index >= len(offset_mapping)
              or offset_mapping[start_index] is None
              or offset_mapping[end_index] is None
          ):
            continue
          # skip spans that are reversed or longer than max_answer_length
          if end_index < start_index or end_index - start_index + 1 > self.max_answer_length:
            continue
          start_char = int(offset_mapping[start_index][0])
          end_char = int(offset_mapping[end_index][1])
          valid_answers.append(
              {
                  "score": start_logits[start_index] + end_logits[end_index],
                  "text": context[start_char:end_char].strip()
              }
          )
    return sorted(valid_answers, key=lambda x: x["score"], reverse=True)[:n_best]


  def get_answers(self, question, context):
    """
    Return the best answers to `question` found in `context` (see `decode`).

    Long contexts are split into overlapping features of at most `max_length`
    tokens (overlap `stride`) so the GPU does not run out of memory.
    """
    inputs = self.tokenizer(question, context,
                      return_tensors='pt',
                      truncation="only_second",
                      max_length=self.max_length,
                      stride=self.stride,
                      return_offsets_mapping=True,
                      return_overflowing_tokens=True,
                      padding="max_length")

    # Only context tokens (sequence id 1) may start or end an answer: offsets of
    # question tokens index into `question`, and special/padding tokens have
    # meaningless (0, 0) offsets.
    offset_mappings = []
    for i, offsets in enumerate(inputs['offset_mapping'].tolist()):
      sequence_ids = inputs.sequence_ids(i)
      offset_mappings.append(
          [tuple(pair) if sequence_ids[k] == 1 else None for k, pair in enumerate(offsets)]
      )

    inputs = inputs.to(self.device)  # port inputs to the selected device

    # inference only - no need to track gradients
    with torch.no_grad():
      outputs = self.model(inputs['input_ids'],
                      token_type_ids=inputs['token_type_ids'],
                      attention_mask=inputs['attention_mask'])

    return self.decode(outputs, context, offset_mappings)


  def count_parameters(self, model):
    """
    Return the number of trainable parameters of `model`.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
    

# Retriever


In [None]:
class Retriever():
  """
  Finds Czech Wikipedia paragraphs that may contain the answer to a question.

  Candidate articles come from the Wikipedia search API, queried with the
  named entities of the question, the best BM25 match among article titles,
  the best BM25 match among article abstracts and the question itself.
  The articles are split into paragraphs, which are ranked with BM25 against
  the lemmatized question.
  """

  def __init__(self, wiki_abstracts, wiki_titles, majka_file, dita_file, index_file, use_majka=False, download_ner_model=False, model_type='mbert'):
    """
    Load the lemmatizers, search indexes, sentence tokenizer and NER model.

    Args:
      wiki_abstracts: path to a JSON file mapping ids to {'title', 'abstract'} records.
      wiki_titles: path to a text file with one article title per line
        (words separated by underscores).
      majka_file: path to the Majka morphological database.
      dita_file: path to the MorphoDiTa tagger model.
      index_file: path of the pickled BM25 index over the abstracts; loaded
        when it exists, otherwise built and saved there.
      use_majka: lemmatize with Majka instead of MorphoDiTa.
      download_ner_model: passed as `download` to DeepPavlov's build_model.
      model_type: 'mbert' allows 3000-character paragraphs; any other value
        (e.g. 'albert', whose input gets translated) limits them to 1500.
    """
    # set wiki api language to search the Czech Wikipedia
    wikipedia.set_lang("cs")

    # save the model type for some small nuances in the retriever
    self.model_type = model_type

    # the most common czech words (stop words) - low information value for search
    common = "kdy být a se v na ten on že s z který mít do já o k i jeho ale svůj jako za moci pro tak po tento co když všechen už jak aby od nebo říci jeden jen můj jenž ty stát u muset chtít také až než ještě při jít pak před však ani vědět hodně podle další celý jiný mezi dát tady tam kde každý takový protože nic něco ne sám bez či dostat nějaký proto"
    self.common = common.split()

    # punctuation tokens to be removed
    punctuation = ". , ? ! ... \" ( ) ; - /"
    self.punctuation = punctuation.split()

    # majka lemmatizer settings
    self.morph = majka.Majka(majka_file)
    self.morph.flags = 0  # unset all flags
    self.morph.tags = False  # return just the lemma, do not process the tags
    self.morph.first_only = True  # return only the first entry
    self.morph.negative = "ne"

    print("Loading lemmatizer")
    # morphodita lemmatizer
    self.tagger = Tagger(dita_file)
    print("Lemmatizer loaded")

    # choose the lemmatizer BEFORE building the indexes: building the
    # abstracts index (when no saved index exists) calls self.lemmatize
    if use_majka:
      self.lemmatize = self.lemmatize_majka
      print("using Majka")
    else:
      self.lemmatize = self.lemmatize_morphodita
      print("using MorphoDiTa")

    print("Building titles index")
    self.bm25_articles_index, self.titles = self.get_title_search_index(wiki_titles)
    print("Titles index done")

    print("Building articles index")
    self.bm25_abstract_index, self.abstract_titles, self.abstracts = self.get_abstract_search_index(wiki_abstracts, index_file)
    print("articles index done")

    print("Loading tokenizer")
    # Punkt sentence splitter for Czech, used to split long paragraphs
    self.tokenizer = nltk.data.load('tokenizers/punkt/czech.pickle')
    print("Tokenizer loaded")

    print("Building ner model")
    self.ner = build_model(configs.ner.ner_ontonotes_bert_mult, download=download_ner_model)
    print("Ner model loaded")

    print("Retriever initialized")


  def _lemma_tokens(self, text):
    """Lower-case, lemmatize and drop stop words/punctuation from `text`."""
    return self.delete_common(self.lemmatize(text.lower()))


  def _max_paragraph_length(self):
    """Maximum paragraph length in characters (albert input is translated, so it is shorter)."""
    return 3000 if self.model_type == 'mbert' else 1500


  def _add_paragraph(self, par, pars, lemm_pars):
    """Append `par` (split first when too long) to `pars` and its lemmatized form to `lemm_pars`."""
    chunks = self.normalize_length(par) if len(par) > self._max_paragraph_length() else [par]
    for chunk in chunks:
      pars.append(chunk)
      lemm_pars.append(' '.join(self._lemma_tokens(chunk)))


  def get_title_search_index(self, wiki_titles):
    """
    Build a BM25 index over the Czech Wikipedia article titles.

    Returns:
      (bm25, titles): the index and the list of titles (underscores replaced
      by spaces) in index order.
    """
    with open(wiki_titles, "r", encoding="utf-8") as f:
      titles = [" ".join(line.split("_")).strip() for line in f]

    # tokenize for bm25; empty tokens (from repeated spaces) are dropped
    tok_titles = [[tok for tok in title.lower().split(" ") if tok] for title in titles]

    return BM25Okapi(tok_titles), titles


  def get_abstract_search_index(self, saved_abstracts, index_file):
    """
    Build (or load) a BM25 index over the Czech Wikipedia abstracts.

    Returns:
      (bm25, titles, abstracts): the index plus the titles and abstracts in
      index order.
    """
    # load abstracts and titles from the preprocessed JSON file
    with open(saved_abstracts, "r", encoding="utf-8") as f:
      wiki_abstracts = json.load(f)

    records = list(wiki_abstracts.values())
    titles = [record['title'] for record in records]
    abstracts = [record['abstract'] for record in records]

    # if the index was saved, load it
    # NOTE: pickle.load can execute arbitrary code - only load index files you created yourself
    if os.path.isfile(index_file):
      with open(index_file, "rb") as fd:
        print("loading from pickle")
        bm25 = pickle.load(fd)
        return bm25, titles, abstracts

    # build the index from the lemmatized abstracts and save it for next time
    bm25 = BM25Okapi([self._lemma_tokens(abstract) for abstract in abstracts])
    with open(index_file, "wb") as fd:
      pickle.dump(bm25, fd, pickle.HIGHEST_PROTOCOL)

    return bm25, titles, abstracts


  def search_titles(self, question):
    """Return the 5 article titles that best match the question (BM25)."""
    return self.bm25_articles_index.get_top_n(self._lemma_tokens(question), self.titles, n=5)


  def search_abstracts(self, question):
    """Return the titles of the 5 abstracts that best match the question (BM25)."""
    return self.bm25_abstract_index.get_top_n(self._lemma_tokens(question), self.abstract_titles, n=5)


  def get_named_entities(self, question):
    """
    Extract named entities from the question with the NER model.

    Returns a list with at most one string: all tokens whose tag is not 'O',
    joined with spaces (empty list when there are none).
    """
    ner_tags = self.ner([question])
    # as used here: ner_tags[0] holds the tokens, ner_tags[1] the tags of each batch item
    tokens, tags = ner_tags[0][0], ner_tags[1][0]
    named_entities = " ".join(tok for tok, tag in zip(tokens, tags) if tag != 'O').strip()
    return [named_entities] if named_entities else []


  def iscommon(self, x):
    """Return True when token `x` is a stop word or punctuation."""
    return x in self.common or x in self.punctuation


  def delete_common(self, tokens):
    """Remove the most common czech words and punctuation (low information value)."""
    return [x for x in tokens if not self.iscommon(x)]


  def lemmatize_majka(self, text):
    """
    Split `text` on non-word characters and return the Majka lemma of each
    token (the token itself when Majka knows no lemma).
    """
    lemmatized_tokens = []
    for token in re.split(r"\W", text):
      if token == '':
        continue
      lemma = self.morph.find(token)
      lemmatized_tokens.append(lemma[0]['lemma'] if lemma else token)
    return lemmatized_tokens


  def lemmatize_morphodita(self, text):
    """
    Return the MorphoDiTa lemma of each token of `text`.
    """
    # splitting on non-word characters and joining with spaces first helps
    # MorphoDiTa tokenize the text correctly
    normalized = " ".join(re.split(r"\W", text))
    return [token.lemma for token in self.tagger.tag(normalized, convert='strip_lemma_id')]


  def search_again(self, tokens):
    """
    Search Wikipedia for the joined tokens, dropping the first token and
    retrying until an article is found or no tokens are left.

    Returns a list with at most one article title (empty when nothing was
    found). `tokens` itself is not modified.
    """
    remaining = list(tokens)
    while remaining:
      searched_term = ' '.join(remaining)
      if searched_term.strip() != "":
        doc_list = wikipedia.search(searched_term, results=1)
        if doc_list:
          return doc_list
      del remaining[0]
    return []


  def get_doc_list(self, question):
    """
    Return up to 4 Wikipedia article titles (possibly repeated) that might
    answer the question - at most one per search strategy.
    """
    named_ERs = self.get_named_entities(question)
    relevant_titles = self.search_titles(question)
    relevant_abstracts = self.search_abstracts(question)

    max_docs = 1
    doc_list = []

    # search by named entity, best title match and best abstract match
    for candidates in (named_ERs, relevant_titles, relevant_abstracts):
      if candidates:
        article = wikipedia.search(candidates[0], results=max_docs)
        if article:
          doc_list.append(article[0])

    # basic search for the non-processed question; if it finds nothing,
    # simplify the query and search recursively
    article = wikipedia.search(question, results=max_docs)
    if not article:
      article = self.search_again(self._lemma_tokens(question))
    if article:
      doc_list.append(article[0])

    return doc_list


  def normalize_length(self, par, max_chars=1300, overlap=2):
    """
    Split a long paragraph into paragraphs of roughly `max_chars` characters.

    Sentences are joined with single spaces. Each new paragraph starts with
    the last `overlap` sentences of the previous one. A single sentence
    longer than `max_chars` forms its own (long) paragraph; no sentence is
    dropped.
    """
    sentences = self.tokenizer.tokenize(par)

    normalized_pars = []
    current = []  # sentences of the paragraph being built
    for idx, sentence in enumerate(sentences):
      if current and len(" ".join(current)) + 1 + len(sentence) > max_chars:
        normalized_pars.append(" ".join(current))
        # overlap: carry the previous sentences over to the new paragraph
        current = sentences[max(idx - overlap, 0):idx]
      current.append(sentence)

    if current:
      normalized_pars.append(" ".join(current))

    return normalized_pars


  def get_wiki_page(self, doc):
    """
    Get the Wikipedia page for title `doc`.

    On a disambiguation page the first option is used. Returns the string
    "not_found" if that fails too. A PageError for `doc` itself propagates.
    """
    try:
      return wikipedia.page(doc, auto_suggest=False)
    except wikipedia.DisambiguationError as e:
      if not e.options:
        return "not_found"
      try:
        return wikipedia.page(e.options[0], auto_suggest=False)
      except (wikipedia.DisambiguationError, wikipedia.PageError):
        return "not_found"


  def split_documents(self, doc_list):
    """
    Split each retrieved wiki article into paragraphs and normalize their lengths.

    Returns:
      (pars, lemm_pars): the paragraphs and their lemmatized, stop-word-free
      versions (space-joined) used for building the BM25 index.
    """
    pars = []
    lemm_pars = []

    for doc in doc_list:
      try:
        page = self.get_wiki_page(doc)
      except wikipedia.PageError:
        continue
      # a string ("not_found") means no actual page was found
      if isinstance(page, str):
        continue

      # drop the references part, then split on headings and blank lines
      content = re.split('=== Reference ===|== Reference ==', page.content)[0]
      for par in re.split('== .*. ==|\\n\\n', content):
        par = par.strip().strip('=').strip()

        # the media section usually contains many searched terms but not the
        # answer, so it is removed to not end high in the bm25 ranking
        if par == '' or par.startswith("Obrázky, zvuky či videa k tématu"):
          continue

        self._add_paragraph(par, pars, lemm_pars)

    return pars, lemm_pars


  def retrieve(self, question, max_docs):
    """
    Return the `max_docs` paragraphs that best match the question.

    Returns:
      (paragraphs, doc_list): the top paragraphs ranked by BM25 and the set of
      candidate article titles. Returns ([], set()) for questions longer than
      250 characters.
    """
    # check max question length - just set something
    if len(question) > 250:
      return [], set()

    question = question.strip('?')

    # only work with unique article names
    doc_list = set(self.get_doc_list(question))

    # split docs into paragraphs -- this is the slowest part of the process
    pars, lemm_pars = self.split_documents(doc_list)

    # if the wiki api found nothing, fall back to the 5 most relevant abstracts
    if not pars:
      results = self.bm25_abstract_index.get_top_n(self._lemma_tokens(question), self.abstracts, n=5)
      for par in results:
        self._add_paragraph(par.strip(), pars, lemm_pars)

    # BM25 cannot be built over an empty corpus
    if not pars:
      return [], doc_list

    # build the index from the processed paragraphs
    # (BM25Okapi gives similar results)
    tok_text = [[tok for tok in re.split(r"\W", par) if tok] for par in lemm_pars]
    bm25 = BM25Plus(tok_text)

    query = ' '.join(self._lemma_tokens(question))
    tokenized_query = [tok for tok in re.split(r"\W", query) if tok]

    return bm25.get_top_n(tokenized_query, pars, n=max_docs), doc_list


  @staticmethod
  def count_log_conf(best_answer, all_answers):
    """
    Return the summed scores of all answers overlapping `best_answer`.

    An answer counts when its text contains `best_answer` or is contained in
    it, so answers the model chose multiple times get a higher confidence.
    """
    log_conf = 0
    for answer in all_answers:
      if (best_answer in answer['text']) or (answer['text'] in best_answer):
        log_conf += answer['score']

    return log_conf


# Preprocessing SQAD

Download the Czech SQAD dataset (only needed if it has to be processed again).

In [None]:
# download and unpack the Czech SQAD v3 dataset (only needed to re-process it)
!curl --remote-name-all https://lindat.cz/repository/xmlui/bitstream/handle/11234/1-3069{/sqad_v3.tar.xz}
!tar -xf sqad_v3.tar.xz

In [None]:
class SqadDataset():
  """
  Process SQAD into JSON (only questions, answers and their lemmas).
  """

  # special marks of the vertical format that are not part of the text
  _SPECIAL_MARKS = frozenset({"<s>", "<g/>", "</s>", "<s desamb=\"1\">"})

  def __init__(self, sqad_dir, save_dir="./sqad_processed", process_boolean=False):
    self.save_dir = save_dir  # path of the JSON file the processed dataset is written to
    self.sqad_dir = sqad_dir  # the directory containing the sqad dataset
    self.process_boolean = process_boolean  # if false, the yes/no questions are skipped


  def _read_lines(self, dirnum, filename):
    """Return the lines of `filename` in record directory `dirnum`."""
    with open(f"{self.sqad_dir}/{dirnum}/{filename}", "r", encoding="utf-8") as f:
      return f.read().split("\n")


  def extract_answer(self, dirnum):
    """
    Parse the answer of the dataset record `dirnum`.

    Returns:
      (answer, answer_lemma): space-separated answer tokens and their lemmas.
      "[number]" lemmas, and lines without a lemma column, use the token itself.
    """
    answer = ""
    answer_lemma = ""

    for raw_line in self._read_lines(dirnum, "09answer_extraction.vert"):
      columns = raw_line.split("\t")
      token = columns[0]
      # end sign
      if token == "</s>":
        break

      if token in self._SPECIAL_MARKS:
        # drop the space appended after the previous token
        answer = answer[:-1]
        answer_lemma = answer_lemma[:-1]
        continue

      lemma = columns[1] if len(columns) > 1 else token
      if lemma == "[number]":
        lemma = token

      answer += token + " "
      answer_lemma += lemma + " "

    return answer, answer_lemma


  def extract_question(self, dirnum):
    """
    Parse the question of the dataset record `dirnum`.

    Returns the question as space-separated tokens.
    """
    question = ""

    for raw_line in self._read_lines(dirnum, "01question.vert"):
      # the first column is the token we want
      token = raw_line.split("\t")[0]
      # end sign (endswith also works for a still-empty question)
      if token == "</s>" and question.endswith("?"):
        break
      if token in self._SPECIAL_MARKS:
        if token != "<s desamb=\"1\">":
          question = question[:-1]
        else:
          question += " "
        continue
      question += token + " "

    return question


  def process_dataset(self, from_q, to_q):
    """
    Process SQAD records from_q..to_q (inclusive) and save them as JSON.

    The saved keys are consecutive integers starting at `from_q`; because
    yes/no questions may be skipped, they do not necessarily equal the SQAD
    record numbers.
    """
    sqad_dataset = {}
    counter = from_q

    for i in tqdm(range(from_q, to_q+1)):
      # record directories are zero-padded to 6 digits
      q_number = f"{i:06d}"

      question = self.extract_question(q_number)
      correct_answer, lemmatized_answer = self.extract_answer(q_number)

      # exclude yes/no questions
      if not self.process_boolean:
        if correct_answer.lower().strip() in ("ano", "ne"):
          continue

      sqad_dataset[counter] = {
          "question": question,
          "answer": correct_answer,
          "answer_lemma": lemmatized_answer,
      }
      counter += 1

    with open(self.save_dir, "w", encoding="utf-8") as f:
      json.dump(sqad_dataset, f)
      print("Sqad dataset has been processed to: " + self.save_dir)


  @staticmethod
  def load_sqad(saved_dataset_file):
    """
    Load the saved JSON dataset as a dictionary.
    """
    with open(saved_dataset_file, encoding="utf-8") as f:
      data = json.load(f)
    print("Sqad dataset loaded from: " + saved_dataset_file)

    return data

In [None]:
# process SQAD records 000001-013476 into JSON (yes/no questions are skipped by default)
sqad = SqadDataset(sqad_dir="./data/", save_dir="./sqad_processed.json")
sqad.process_dataset(1, 13476)

HBox(children=(FloatProgress(value=0.0, max=13476.0), HTML(value='')))


Sqad dataset has been processed to: drive/MyDrive/data/sqad_processed.json


# Wiki abstracts parsing

Download the Czech Wikipedia abstracts dump.

In [None]:
# download and unpack the latest Czech Wikipedia abstracts dump (-> cswiki-latest-abstract.xml)
!curl --remote-name-all https://dumps.wikimedia.org/cswiki/latest/{cswiki-latest-abstract.xml.gz}
!gunzip ./cswiki-latest-abstract.xml.gz

In [None]:
def parse_abstracts(save_file, xml_file='cswiki-latest-abstract.xml'):
  """
  Parse the Wikipedia abstracts XML dump into a compact JSON file.

  The output maps the record index to {'title': ..., 'abstract': ...}, with
  the 'Wikipedie: ' prefix removed from titles. Records with an empty
  abstract (or title) are skipped. Links are not kept.

  Args:
    save_file: path of the JSON file to write.
    xml_file: path of the uncompressed abstracts XML dump.
  """
  prefix = 'Wikipedie: '

  # minidom loads the whole dump into memory -- this needs a lot of RAM
  xmldoc = minidom.parse(xml_file)
  abstracts = xmldoc.getElementsByTagName('abstract')
  titles = xmldoc.getElementsByTagName('title')

  wiki_abstracts = {}
  for idx, title in enumerate(titles):
    title_node = title.firstChild
    abstract_node = abstracts[idx].firstChild
    # skip records without abstract (or title text)
    if title_node is None or abstract_node is None:
      continue

    title_text = title_node.nodeValue
    # remove the exact prefix; str.lstrip would strip any of its characters
    # (e.g. 'Wikipedie: Wikipedie' would become '')
    if title_text.startswith(prefix):
      title_text = title_text[len(prefix):]

    wiki_abstracts[idx] = {'title': title_text, 'abstract': abstract_node.nodeValue}

  with open(save_file, "w") as f:
    json.dump(wiki_abstracts, f)
    print("Wiki abstracts have been processed to: " + save_file)

In [None]:
# parse the abstracts dump into JSON (the parsing function is parse_abstracts)
save_file = "./wiki_abstracts_processed.json"
parse_abstracts(save_file)

Download the titles of all Czech Wikipedia articles:

In [None]:
# download the list of Czech Wikipedia article titles (namespace 0) and save it as ./wiki_titles
!curl --remote-name-all https://dumps.wikimedia.org/cswiki/latest/{cswiki-latest-all-titles-in-ns0.gz}
!gunzip ./cswiki-latest-all-titles-in-ns0.gz
!mv cswiki-latest-all-titles-in-ns0 wiki_titles

# Loading model


Choose the model checkpoint.

In [None]:
# either 'mbert' or 'albert'
model_type = 'mbert'
# True selects the *_squad2 checkpoints, False the *_squad ones
squadv2 = True

# fine-tuned reader checkpoints, keyed by (model_type, squadv2)
MODEL_CHECKPOINTS = {
  ('mbert', True): "../data/mbert_finetuned_czech_squad2",
  ('albert', True): "../data/albert_squad2_finetuned",
  ('mbert', False): "../data/mbert_finetuned_czech_squad",
  ('albert', False): "../data/albert_squad_finetuned",
}

if model_type not in ('mbert', 'albert'):
  # fail here instead of leaving model_checkpoint undefined for later cells
  raise ValueError("wrong model type name")
model_checkpoint = MODEL_CHECKPOINTS[(model_type, bool(squadv2))]

# files with additional saved data
# depending on where you run from, you can download and process them with the
# cells in this notebook, or use the preprocessed/downloaded data from the
# data directory (paths below)
majka_file = './majka.w-lt'
dita_file  = "./czech-morfflex-pdt-161115/czech-morfflex-pdt-161115.tagger"
wiki_titles = "../data/wiki_titles" 
wiki_abstracts = "../data/wiki_abstracts_processed.json"
index_file = "../data/abstracts_index.pkl"

Create reader



In [None]:
# load the reader model selected in the model-checkpoint configuration cell
reader = Reader(model_checkpoint, model_type=model_type)

Model loaded from: ./drive/MyDrive/mbert_models/bert_finetuned_czech_squad2
Model has 177264386 parameters
Device selected:
cuda:0


Create the retriever. Initialization takes about 10 minutes because the BM25 index over the Wikipedia abstracts has to be built.
If a saved index file is supplied, it is much faster.

In [None]:
# pass the reader's model type: it decides the maximum paragraph length
# (albert input is translated, so its paragraphs must be shorter)
retriever = Retriever(wiki_abstracts=wiki_abstracts, wiki_titles=wiki_titles, majka_file=majka_file, dita_file=dita_file, 
                      index_file=index_file, use_majka=False, download_ner_model=True, model_type=model_type)

# Final pipeline

In [None]:
def translate(question_cs, documents_cs, translator):
  """
  Translate a Czech question and its documents to English in one request.

  The question and the documents are joined with " _____ " so googletrans is
  called only once; the English text is then split back on "_____".

  Returns:
    (question, documents): the translated question and the list of
    translated documents.
  """
  joined = " _____ ".join([question_cs] + list(documents_cs))

  # keep the request under 5000 characters (presumably the googletrans limit)
  if len(joined) > 5000:
    joined = joined[0:4999]

  english = translator.translate(joined, src='cs', dest='en').text
  pieces = english.split("_____")

  return pieces[0], pieces[1:]

In [None]:
def find_answer(question, reader, retriever, translator, model_type):
  """
  Finds the answer to a czech question - connects retriever, reader and
  (for the english albert reader) the translator.

  Args:
    question: the question in czech.
    reader: provides get_answers(question, document) returning a list of
      dicts with 'text' and 'score' keys.
    retriever: provides retrieve(question, max_docs=...) returning
      (documents, article_list).
    translator: googletrans-like Translator; only used when
      model_type == 'albert'.
    model_type: 'mbert' (reader works on czech directly) or 'albert'
      (question/documents are translated cs -> en and answers en -> cs).

  Returns:
    "" when no answer was found, otherwise the tuple
    (answer, answer_en, document, bestAnswers, bestLogProbs,
     bestSummedLogProbs, article_list, bestDocs), where `document` is the
    czech document of the highest-scoring answer and the best* lists hold one
    entry per document that produced an answer.
  """
  question_cs = question  # keep the original czech question

  # retrieve the relevant paragraphs of context
  documents_cs, article_list = retriever.retrieve(question, max_docs=10)

  # one entry per document that produced a usable answer
  bestAnswers = []
  bestDocs = []
  bestLogProbs = []
  bestSummedLogProbs = []

  # translate according to model type - mbert doesnt need translation
  if model_type == 'mbert':
    question = question_cs
    documents = documents_cs
  elif model_type == 'albert':
    question, documents = translate(question_cs, documents_cs, translator)

  # Empty documents are skipped inside the loop instead of being filtered out
  # beforehand, so that `idx` keeps pointing at the matching czech document.
  for idx, document in enumerate(documents):
    document = document.strip()
    if document == "":
      continue

    # workaround for a translation delimiter that ended up in the question
    if len(question) > len(document):
      question = question.split("____")[0]

    answers = reader.get_answers(question, document)

    # choose the first valid answer - the first one with a non-empty text
    best = next((a for a in answers if a['text'] != ''), None)
    if best is None or not isinstance(best['text'], str):
      # no usable answer in this document; don't record a placeholder that
      # could win the argmax below over real answers
      continue
    answer = best['text']

    # save probs and answer
    bestAnswers.append(answer)
    bestLogProbs.append(best['score'])
    bestSummedLogProbs.append(Retriever.count_log_conf(answer, answers))
    # the translated list need not have the same length as the czech one
    # (truncation in translate()); fall back to the translated text
    bestDocs.append(documents_cs[idx] if idx < len(documents_cs) else document)

  # check if any answer was found
  if not bestLogProbs:
    return ""

  # best answer and its document according to the reader score
  best_idx = int(np.argmax(bestLogProbs))
  document = bestDocs[best_idx]
  answer = bestAnswers[best_idx]

  # translate the final answer - in case of using english albert
  answer_en = answer
  if model_type == 'albert':
    # best answer first, then all candidates, in a single request
    concatenated = " _____ ".join([answer] + bestAnswers)
    translated = translator.translate(concatenated, src='en', dest='cs').text
    # split on any run of 3+ underscores and strip the spaces around the
    # delimiter so the czech answers can be compared exactly in evaluation
    parts = [part.strip() for part in re.split(r"_{3,}", translated)]
    answer = parts[0]
    bestAnswers = parts[1:]

  return answer, answer_en, document, bestAnswers, bestLogProbs, bestSummedLogProbs, article_list, bestDocs

# Get predictions on SQAD

In [None]:
def get_timestamp(fmt='%d-%m-%Y_%H_%M'):
    """
    Returns the current local time formatted with `fmt`; used for unique file names.

    The default format uses '_' instead of ':' between hours and minutes,
    because ':' is not allowed in file names on Windows.
    """
    return datetime.datetime.now().strftime(fmt)

In [None]:
def sqad_eval_predictions(from_q, to_q, data, save_dir, model_type):
  """
  Gets answers for the sqad questions from_q..to_q (inclusive) and saves them
  as json for later evaluation (see sqad_eval_score).

  Uses the notebook-level `reader` and `retriever`. After every 100th
  question, an interim checkpoint of the results so far is written to `save_dir`.

  Args:
    from_q, to_q: first and last question id (inclusive); `data` is indexed
      with str(id).
    data: processed sqad dataset; each entry has "question", "answer" and
      "answer_lemma".
    save_dir: directory for the output files.
    model_type: 'mbert' or 'albert', forwarded to find_answer.

  Questions without an answer are not written to the results. Their count is
  printed at the end because sqad_eval_score only scores the saved questions.
  """
  def save_results(file_name):
    # dump the results collected so far and return the written path
    save_file = os.path.join(save_dir, file_name)
    with open(save_file, "w") as j:
      json.dump(results, j)
    return save_file

  # create translator
  translator = Translator()

  # for writing structured results to json
  results = {}
  unanswered = []

  # iterate over sqad questions in required range
  for i in tqdm(range(from_q, to_q + 1)):
    item = data[str(i)]
    question = item["question"]
    correct_answer = item["answer"]
    lemmatized_answer = item["answer_lemma"]

    try:
      prediction = find_answer(question, reader, retriever, translator, model_type)
    except ValueError as err:
      # keep the long evaluation running, but make the failure visible
      tqdm.write(f"question {i} skipped: {err}")
      unanswered.append(i)
      continue
    if not prediction:
      # find_answer returns "" when no answer was found
      unanswered.append(i)
      continue

    (answer, answer_en, document, bestAnswers, bestLogProbs,
     bestSummedLogProbs, article_list, best_docs) = prediction

    # write to dict for json dump for later evaluation
    results[i] = {
        "question" : question,
        "answer" : answer,
        "answer_orig" : answer_en,
        "answer_sqad" : correct_answer,
        "answer_sqad_lemma" : lemmatized_answer,
        "articles" : list(article_list),  # convert set to list for json
        "document" : document,
        "all_answers" : bestAnswers,
        "all_log_probs" : ['{:.2f}'.format(x) for x in bestLogProbs],
        "all_documents" : best_docs
    }

    # interim save
    if i % 100 == 0:
      save_results(f"checkpoint_{from_q}-{i}__{get_timestamp()}.json")

  # save results as json for later evaluation (no ':' in the file name)
  save_file = save_results(f"answers_{from_q}-{to_q}__{get_timestamp()}.json")
  print()
  print("Results has been saved to " + save_file)
  if unanswered:
    print(f"{len(unanswered)} question(s) without an answer were not saved")

Get the model predictions for the chosen range of questions and save them to the chosen directory

In [None]:
# load preprocessed sqad dataset
data = SqadDataset.load_sqad("../data/sqad_processed.json")
save_to = "./"

# get the model predictions on the selected range of the dataset
sqad_eval_predictions(1, 100, data, save_to, model_type)

# Evaluate SQAD predictions

In [None]:
# morphodita lemmatizer
tagger = Tagger("./czech-morfflex-pdt-161115/czech-morfflex-pdt-161115.tagger")

def lemmatize(text):
    """
    Returns lemma of each token in a list of lemmatized tokens
    Used to lemmatize during the evaluation

    """
    # function is thoroughly described as the retriever method - see Retriever()
    text = re.split("\W", text)
    text = (" ").join(text)

    tokens = list(tagger.tag(text, convert='strip_lemma_id'))

    lemmas = []
    for token in tokens:
      lemmas.append(token.lemma)

    return lemmas

In [None]:
def _as_int(text):
  """Returns `text` as an int with all spaces removed ("15 350" -> 15350), or None if it is not a number."""
  try:
    return int(text.replace(" ", ""))
  except ValueError:
    return None


def _contains_either(a, b):
  """
  True when one of two non-empty strings contains the other.

  Empty strings never match: `"" in s` is always True in Python, which would
  otherwise count an empty answer as correct.
  """
  return bool(a) and bool(b) and (a in b or b in a)


def _char_overlap_prf(answer, correct_answer):
  """
  Character-level (precision, recall, f1) of `answer` against `correct_answer`.

  Non-zero only when one string contains the other; the true positives are
  then the length of the shorter string. Empty strings give (0, 0, 0)
  instead of a division by zero.
  """
  if not _contains_either(answer, correct_answer):
    return 0.0, 0.0, 0.0
  tp = min(len(answer), len(correct_answer))
  precision = tp / len(answer)
  recall = tp / len(correct_answer)
  return precision, recall, 2 * (precision * recall) / (precision + recall)


def sqad_eval_score(results_json):
  """
  Evaluates the predictions in file 'results_json' (as written by
  sqad_eval_predictions) on standard metrics and prints a report.

  All answers are lowercased and stripped before comparison; empty answers
  never count as a match.

  Args:
    results_json: path to the json file with the predictions.

  Returns:
    dict with the percentage of every metric plus "n_answers",
    or None when the file contains no results.
  """
  # load results
  with open(results_json, "r") as f:
    results = json.load(f)

  n_answers = len(results)
  if n_answers == 0:
    print("results: nothing to evaluate in " + results_json)
    return None

  # counters of the individual metrics
  score = EM = EM_lemma = f3 = doc_score = 0
  f1 = mrr = recall = precision = doc_mrr = 0

  for result in results.values():
    # convert to lowercase and strip
    answer = result["answer"].lower().strip()
    answer_en = result["answer_orig"].lower().strip()
    correct_answer = result["answer_sqad"].lower().strip()
    lemmatized_answer = result["answer_sqad_lemma"].lower().strip()
    all_docs = result["all_documents"]
    all_answers = result["all_answers"]
    all_scores = result["all_log_probs"]

    # lemmatized versions of answers
    ans_lemma = " ".join(lemmatize(answer))
    corr_ans_lemma = " ".join(lemmatize(correct_answer))
    lemm_ans_lemma = " ".join(lemmatize(lemmatized_answer))

    # doc score / doc MRR: rank of the first retrieved document whose raw or
    # lemmatized text contains the (raw or lemmatized) correct answer
    needles = [n for n in (correct_answer, corr_ans_lemma, lemm_ans_lemma) if n]
    for rank, doc in enumerate(all_docs, start=1):
      found = any(n in doc for n in needles)
      if not found:
        # lemmatizing is expensive - only do it when the raw text didn't match
        lemm_doc = " ".join(lemmatize(doc))
        found = any(n in lemm_doc for n in needles)
      if found:
        doc_score += 1
        doc_mrr += 1 / rank
        break

    # convert numbers to the same format when both answers are numbers
    # e.g.: 15 350 == 15350 -> every number without the spaces
    answer_int = _as_int(answer)
    corr_answer_int = _as_int(correct_answer) if answer_int is not None else None
    numbers_unified = answer_int is not None and corr_answer_int is not None
    if numbers_unified:
      answer = str(answer_int)
      correct_answer = str(corr_answer_int)

    # EM3 - F3: any of the per-document answers matches exactly
    if any(a.lower().strip() == correct_answer for a in all_answers):
      f3 += 1

    # MRR: rank of the correct answer among the answers sorted by score
    ranked = sorted(zip((float(s) for s in all_scores), all_answers),
                    key=lambda pair: pair[0], reverse=True)
    for rank, (_, ans) in enumerate(ranked, start=1):
      if ans.lower().strip() == correct_answer:
        mrr += 1 / rank
        break

    # NAIVE SCORE - any containment between the predicted and correct answer
    if (_contains_either(answer, correct_answer) or
        _contains_either(ans_lemma, corr_ans_lemma) or
        _contains_either(ans_lemma, lemm_ans_lemma) or
        _contains_either(answer_en, lemm_ans_lemma)):
      score += 1

    # ELM - exact match of lemmas; the correct answer is lemmatized again
    # when the number unification above rewrote it
    corr_lemma_now = " ".join(lemmatize(correct_answer)) if numbers_unified else corr_ans_lemma
    if ans_lemma and ans_lemma in (corr_lemma_now, lemm_ans_lemma):
      EM_lemma += 1

    # exact match (of the czech or the original reader answer)
    if correct_answer and correct_answer in (answer, answer_en):
      EM += 1

    # character-overlap precision, recall and F1
    f_precision, f_recall, f_f1 = _char_overlap_prf(answer, correct_answer)
    precision += f_precision
    recall += f_recall
    f1 += f_f1

  def pct(value):
    # share of all evaluated answers, in percent
    return (value / n_answers) * 100

  metrics = {
      "n_answers": n_answers,
      "naive": pct(score),
      "em": pct(EM),
      "em_lemma": pct(EM_lemma),
      "f3": pct(f3),
      "mrr": pct(mrr),
      "recall": pct(recall),
      "precision": pct(precision),
      "f1": pct(f1),
      "doc_score": pct(doc_score),
      "doc_mrr": pct(doc_mrr),
  }

  print("results:")
  print()
  print(f"naive score: {score}/{n_answers} ... {metrics['naive']:5.4} %")
  print(f"exact match: {EM}/{n_answers} ... {metrics['em']:5.4} %")
  print(f"exact match - lemma: {EM_lemma}/{n_answers} ... {metrics['em_lemma']:5.4} %")
  print()
  print(f"EM3 - F3: {f3}/{n_answers} ... {metrics['f3']:5.4} %")
  print(f"MRR: {float(mrr):8.6}/{n_answers} ... {metrics['mrr']:5.4} %")
  print()
  print(f"recall: {float(recall):8.6}/{n_answers} ... {metrics['recall']:5.4} %")
  print(f"precision: {float(precision):8.6}/{n_answers} ... {metrics['precision']:5.4} %")
  print(f"F1: ... {float(metrics['f1']):5.4} %")
  print()
  print(f"document score: {doc_score}/{n_answers} ... {metrics['doc_score']:5.4} %")
  print(f"document mrr score: {doc_mrr}/{n_answers} ... {metrics['doc_mrr']:5.4} %")

  return metrics

In [None]:
# score a previously saved predictions file (';' suppresses the cell's output value)
sqad_eval_score(results_json="../benchmarks/whole_sqad_eval/8answers_1-11273__20-04-2021_14_01.json");

# Ask a question

In [None]:
# create translator - only the english albert reader needs translation
translator = Translator() if model_type == 'albert' else None

# type the question, get the answer
question = "kdy začala první světová válka"
answer = find_answer(question, reader, retriever, translator, model_type)

# find_answer returns "" when nothing was found, otherwise a tuple of
# (answer, answer_en, document, all answers, their scores, summed scores,
#  articles, documents)
if not answer:
  print("No answer found for: " + question)
else:
  (best_answer, answer_en, best_document, all_answers, all_scores,
   summed_scores, articles, all_documents) = answer
  print(question)
  print()
  print(best_answer)
  if model_type == 'albert':
    print(answer_en)  # untranslated english answer of the reader
  print()
  print(best_document)
  print()
  print(all_answers)
  print(all_scores)
  print(articles)

kdy začala první světová válka

28. července 1914

První světová válka (před rokem 1939 známá jako Velká válka nebo světová válka) byl globální válečný konflikt probíhající od 28. července 1914 do 11. listopadu 1918.První světová válka zasáhla Evropu, Afriku a Asii a bojovalo se i na světových oceánech.Formální příčinou války byl úspěšný atentát na následníka rakousko-uherského trůnu arcivévodu Františka Ferdinanda d'Este v Sarajevu dne 28. června 1914.Měsíc poté, 28. července 1914, vyhlásilo Rakousko-Uhersko odvetou válku Srbsku.Na základě předchozích aliančních smluv následovala řetězová reakce ostatních států a během zhruba čtyř týdnů se ve válečném konfliktu ocitla většina Evropy.Ve válce se střetly dvě znepřátelené koalice: Trojdohoda (dále jen „Dohoda“) a Ústřední mocnosti.Dohodovými mocnostmi byly při vypuknutí války Francie, carské Rusko a Spojené království (resp. celé Britské impérium) stojící na straně napadeného Srbska a Belgie, jejíž neutralitu na počátku války nerespektov