# Synthetic SQuAD dataset based on MKQA

In [1]:
from bs4 import BeautifulSoup
import json
import math
import nltk
import os
import re
import requests
import urllib

from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizerFast

In [2]:
LM_NAME = 'bert-base-multilingual-cased'
INPUT_FILE = '../data/mkqa/mkqa.jsonl'

In [3]:
with open(INPUT_FILE, 'r', encoding='utf-8') as fp:
    mkqa_dataset = list(fp)

mkqa_dataset = [json.loads(jline) for jline in mkqa_dataset]

In [4]:
def str_to_num(text):
    try:
        return int(text)
    except:
        return float(text)

## DBPedia

In order to get the Wiki entities in the questions, we will use DBPedia. This service will run in a docker container.

Download the models from https://databus.dbpedia.org/dbpedia/spotlight/spotlight-model/

See:
- https://github.com/dbpedia-spotlight/dbpedia-spotlight
- https://www.dbpedia-spotlight.org/api

In [5]:
stopwords = set(nltk.corpus.stopwords.words('spanish'))
stopwords

{'a',
 'al',
 'algo',
 'algunas',
 'algunos',
 'ante',
 'antes',
 'como',
 'con',
 'contra',
 'cual',
 'cuando',
 'de',
 'del',
 'desde',
 'donde',
 'durante',
 'e',
 'el',
 'ella',
 'ellas',
 'ellos',
 'en',
 'entre',
 'era',
 'erais',
 'eran',
 'eras',
 'eres',
 'es',
 'esa',
 'esas',
 'ese',
 'eso',
 'esos',
 'esta',
 'estaba',
 'estabais',
 'estaban',
 'estabas',
 'estad',
 'estada',
 'estadas',
 'estado',
 'estados',
 'estamos',
 'estando',
 'estar',
 'estaremos',
 'estará',
 'estarán',
 'estarás',
 'estaré',
 'estaréis',
 'estaría',
 'estaríais',
 'estaríamos',
 'estarían',
 'estarías',
 'estas',
 'este',
 'estemos',
 'esto',
 'estos',
 'estoy',
 'estuve',
 'estuviera',
 'estuvierais',
 'estuvieran',
 'estuvieras',
 'estuvieron',
 'estuviese',
 'estuvieseis',
 'estuviesen',
 'estuvieses',
 'estuvimos',
 'estuviste',
 'estuvisteis',
 'estuviéramos',
 'estuviésemos',
 'estuvo',
 'está',
 'estábamos',
 'estáis',
 'están',
 'estás',
 'esté',
 'estéis',
 'estén',
 'estés',
 'fue',
 'f

In [6]:
def get_dbpedia_annotations(text, min_score=0.95, lang_name='spanish'):
    text = urllib.parse.quote_plus(text)
    url = 'http://127.0.0.1:8080/rest/annotate' + \
        '?text=%s' % text + \
        '&confidence=0'
    
    r = requests.get(url, headers={'Accept': 'application/json'})
    if r.status_code != 200:
        print('')
        print('Unexpected HTTP Status: %d - %s' % (r.status_code, url))
        #raise Exception('Unexpected HTTP Status: %d' % r.status_code)
        return []
    
    raw_data = r.json()
    
    resources = []
    
    if 'Resources' not in raw_data:
        return []
    
    stopwords = set(nltk.corpus.stopwords.words(lang_name))
    
    for item in raw_data['Resources']:
        item_score = str_to_num(item['@similarityScore'])
        if item_score >= min_score and item['@surfaceForm'] not in stopwords:
            resources.append({
                'uri': item['@URI'],
                'text': item['@surfaceForm'],
                'score': item_score,
                'types': item['@types'],
                'support': item['@support'],
            })

    return resources

In [7]:
sample_text = 'Qué país ha ganado el mayor número de títulos de Copa del mundo'
sample_annotations = get_dbpedia_annotations(sample_text)
sample_annotations

[{'uri': 'http://es.dbpedia.org/resource/Ganado',
  'text': 'ganado',
  'score': 0.9710442510960338,
  'types': '',
  'support': '1721'},
 {'uri': 'http://es.dbpedia.org/resource/Mayor',
  'text': 'mayor',
  'score': 0.9643462383456692,
  'types': '',
  'support': '2545'},
 {'uri': 'http://es.dbpedia.org/resource/Copa_del_Mundo_de_Tenis_de_Mesa',
  'text': 'Copa del mundo',
  'score': 0.9927741286808874,
  'types': '',
  'support': '41'}]

In [8]:
sample_text = '¿Bill Gates vive en Estados Unidos?'
sample_annotations = get_dbpedia_annotations(sample_text)
sample_annotations

[{'uri': 'http://es.dbpedia.org/resource/Bill_Gates',
  'text': 'Bill Gates',
  'score': 0.9999999999940457,
  'types': 'Http://xmlns.com/foaf/0.1/Person,Wikidata:Q5,Wikidata:Q24229398,Wikidata:Q215627,DUL:NaturalPerson,DUL:Agent,Schema:Person,DBpedia:Agent,DBpedia:Person',
  'support': '478'},
 {'uri': 'http://es.dbpedia.org/resource/Vive_(Venezuela)',
  'text': 'vive',
  'score': 0.9784932221938503,
  'types': 'Wikidata:Q43229,Wikidata:Q24229398,Wikidata:Q15265344,DUL:SocialPerson,DUL:Agent,Schema:TelevisionStation,Schema:Organization,DBpedia:Organisation,DBpedia:Broadcaster,DBpedia:Agent,DBpedia:TelevisionStation',
  'support': '45'},
 {'uri': 'http://es.dbpedia.org/resource/Estados_Unidos',
  'text': 'Estados Unidos',
  'score': 0.9999993981782999,
  'types': 'Wikidata:Q6256,Schema:Place,Schema:Country,DBpedia:PopulatedPlace,DBpedia:Place,DBpedia:Location,DBpedia:Country',
  'support': '399783'}]

In [9]:
def get_dbpedia_entity_id(url):
    r = requests.get(url, headers={'Accept': 'application/json'})
    if r.status_code != 200:
        print('')
        print('Unexpected HTTP Status: %d - %s' % (r.status_code, url))
        #raise Exception('Unexpected HTTP Status: %d' % r.status_code)
        return None
    
    raw_data = r.json()
    if url not in raw_data or 'http://dbpedia.org/ontology/wikiPageID' not in raw_data[url]:
        return None
    
    wiki_id = raw_data[url]['http://dbpedia.org/ontology/wikiPageID'][0]['value']
    
    return wiki_id

In [10]:
get_dbpedia_entity_id(sample_annotations[0]['uri'])

375

In [11]:
def get_dbpedia_search(query, min_score):
    annotations = get_dbpedia_annotations(query, min_score=min_score)
    entities = []
    for item in annotations:
        page_id = get_dbpedia_entity_id(item['uri'])
        if page_id is None:
            continue
        entities.append({
            'wiki_title': item['text'],
            'wiki_page_id': page_id,
            'entity_text': item['text'],
            'entity_score': item['score'],
            'is_mandatory': False,
        })
    return entities

## Wiki API

Wikipedia API:
```
https://en.wikipedia.org/w/api.php
?action=query
&list=search
&srsearch=zyz        # search query
&srlimit=1           # return only the first result
&srnamespace=0       # search only articles, ignoring Talk, Mediawiki, etc.
&format=json         # jsonfm prints the JSON in HTML for debugging.
```

In [12]:
def get_wiki_search(query, top_results, lang_code='es'):
    query = urllib.parse.quote_plus(query)
    url = 'https://%s.wikipedia.org/w/api.php' % lang_code + \
        '?action=query' + \
        '&list=search' + \
        '&srsearch=%s' % query + \
        '&srlimit=%d' % top_results + \
        '&srnamespace=0' + \
        '&format=json'
    
    r = requests.get(url)
    if r.status_code != 200:
        print('')
        print('Unexpected HTTP Status: %d - %s' % (r.status_code, url))
        #raise Exception('Unexpected HTTP Status: %d' % r.status_code)
        return []

    raw_data = r.json()
    
    if 'query' not in raw_data:
        return []
    
    items = []
    for raw_item in raw_data['query']['search']:
        items.append({
            'wiki_title': raw_item['title'],
            'wiki_page_id': raw_item['pageid'],
        })
    return items

In [13]:
query = 'Donde vive Bill Gates'
get_wiki_search(query, 3)

[{'wiki_title': 'Medina (Washington)', 'wiki_page_id': 4384248},
 {'wiki_title': 'Vive Latino', 'wiki_page_id': 83800},
 {'wiki_title': 'Gatos fantasma', 'wiki_page_id': 6969933}]

In [14]:
def get_wiki_article(page_id, lang_code='es', min_sentence_length=80):
    url = 'https://%s.wikipedia.org/w/api.php' % lang_code + \
        '?action=parse' + \
        '&pageid=%d' % page_id + \
        '&prop=text' + \
        '&format=json'
        #'&section=1' + \
    
    r = requests.get(url)
    if r.status_code != 200:
        print('')
        print('Unexpected HTTP Status: %d - %s' % (r.status_code, url))
        #raise Exception('Unexpected HTTP Status: %d' % r.status_code)
        return None
    
    json_data = r.json()
    if 'parse' not in json_data:
        # Not found
        return None
    article = json_data['parse']['text']['*']
    article = re.sub(r'><', ">\n<", article) # Adds new line between HTML tags
    
    soup = BeautifulSoup(article)
    article = soup.get_text().strip()
    
    # Remove URLs
    url_regex = r'(?i)\b((?:https?:(?:/{1,3}|[a-z0-9%])|[a-z0-9.\-]+[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|name|post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg|eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)/)(?:[^\s()<>{}\[\]]+|\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\))+(?:\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\)|[^\s`!()\[\]{};:\'".,<>?«»“”‘’])|(?:(?<!@)[a-z0-9]+(?:[.\-][a-z0-9]+)*[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|name|post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg|eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)\b/?(?!@)))'
    article = re.sub(url_regex, ' ', article)
    
    # Clean short sentences
    article = re.sub(r'\s*\n\s*', '\n', article)
    article = article.splitlines()
    article = [x for x in article if len(x) > min_sentence_length]
    article = '\n'.join(article) + ''
    
    # Clean wiki rubbish
    article = re.sub(r'\[[^\]]+\]+', ' ', article)
    article = re.sub(r'[\s\u200b]+', ' ', article)
    
    return article

In [15]:
article = get_wiki_article(12918)
article

'California es uno de los cincuenta estados que, junto con Washington D. C., forman los Estados Unidos de América. Su capital es Sacramento, y su ciudad más poblada, Los Ángeles. Está ubicado en la región oeste del país, división Pacífico, limitando al norte con Oregón, al este con Nevada, al sureste con el río Colorado que lo separa de Arizona, al sur con Baja California (México) y al oeste con el océano Pacífico. Con 37 253 956 habitantes en 2010 es el estado más poblado y con 423 970 km², el tercero más extenso, por detrás de Alaska y Texas. Fue admitido en la Unión el 9 de septiembre de 1850 como el estado número 31. Además, cuenta con las segunda y quinta áreas más pobladas de la nación, el Gran Los Ángeles y el Área de la Bahía de San Francisco y ocho de las ciudades más pobladas del país: Los Ángeles, San Diego, San José, San Francisco, Fresno, Sacramento, Long Beach y Oakland. La zona estuvo poblada desde hace milenios por los nativos americanos antes de las primeras expedicion

## Static entities

Sometimes, questions or answers contain proper nouns (with an initial capital letter) or quoted text. Those elements are considered "static entities" in this notebook and we can extract them using this function.

In [16]:
def get_static_entities(text, default_score=5.):
    all_matches = []
    matches = re.findall(r'"(.+?)"', text)
    if len(matches) > 0:
        all_matches += [x.strip() for x in matches]
    matches = re.findall(r'\'(.+?)\'', text)
    if len(matches) > 0:
        all_matches += [x.strip() for x in matches]
    matches = re.findall(r'((?:(?:[A-Z][A-Za-z]+|de)\s?)+)', text)
    if len(matches) > 0:
        all_matches += [x.strip() for x in matches]
    
    # Format static entities
    static_entities = []
    for x in all_matches:
        static_entities.append({
            'entity_text': x,
            'entity_score': default_score,
            'is_mandatory': True,
        })
    
    return static_entities

## Analysis functions

In [17]:
def get_word_tokenizer(artifacts_path='../artifacts/', lm_name='bert-base-multilingual-cased', lowercase=False):
    save_path = '%s%s/' % (artifacts_path, lm_name)
    tokenizer = BertTokenizerFast('%svocab.txt' % save_path, do_lower_case=lowercase)
    return lambda text : [text[start:end] \
                          for (start, end) in tokenizer(text, return_offsets_mapping=True,
                                                        return_special_tokens_mask=False)['offset_mapping'][1:-1]]

In [18]:
"""
def bleu_score(reference, hypothesis):
    if len(reference) == 1:
        return int(reference == hypothesis)
    elif len(reference) < 4:
        weights_len = [(0.5, 0.5), (0.34, 0.33, 0.33)]
        score = nltk.translate.bleu_score.sentence_bleu([reference], hypothesis,
                                                        weights=weights_len[len(reference)-1])
    else:
        score = nltk.translate.bleu_score.sentence_bleu([reference], hypothesis)
    return score
"""

def bleu_score(reference, hypothesis):
    matches = [int(x == y) for x, y in zip(reference, hypothesis)]
    return sum(matches) / len(reference)

In [19]:
a = ['esta', 'es', 'prueba','1']
b = ['esta', 'es', 'prueba','2']

bleu_score(a, b)

0.75

In [20]:
text = '¡Hola mundo! ¡Adiós mundo!'

tokenizer = get_word_tokenizer()
tokenizer(text)

['¡', 'Ho', 'la', 'mundo', '!', '¡', 'Adi', 'ós', 'mundo', '!']

## MKQA Dataset

In [21]:
def get_es_number_unit(raw_unit_name, is_plural=False):
    # First list correspond to singular and second list to plural
    es_conversion = {
        'Antes de la era vulgar': [['a.C.'], ['a.C.']],
        'Galones': [['galón'], ['galones']],
        'Millas por hora': [['mph', 'milla por hora'], ['mph', 'millas por hora']],
        'acre': [['acre'], ['acres']],
        'antes del Mediodia': [['AM', 'A.M.'], ['AM', 'A.M.']],
        'año terrestre': [['año'], ['años']],
        'caballo de potencia metrico': [['caballo de potencia', 'caballo', 'hp', 'cv'], ['caballos de potencia', 'caballos', 'hp', 'cv']],
        'centímetro': [['cm', 'centímetro'], ['cm', 'centímetros']],
        'día': [['día'], ['días']],
        'dólar estadounidense': [['dólar', '$'], ['dólares', '$']],
        'episodio': [['episodio'], ['episodios']],
        'escala Fahrenheit': [['grado Fahrenheit', 'Fahrenheit', 'ºF'], ['grados Fahrenheit', 'Fahrenheits', 'ºF']],
        'estaciones del año': [['temporada'], ['temporadas']],
        'grados centigrados': [['grado centigrado', 'grado', 'ºC'], ['grados centigrados', 'grados', 'ºC']],
        'gramo': [['gramo', 'gr', 'g'], ['gramos', 'gr', 'g']],
        'hora': [['hora', 'h'], ['horas', 'h']],
        'kilometraje': [['kilómetro', 'km'], ['kilómetros', 'km']],
        'libra avoirdupois': [['libra avoirdupois', 'libra', 'lb'], ['libras avoirdupois', 'libras', 'lb']],
        'light año terrestre': [['año luz'], ['años luz']],
        'mes sinódico': [['mes sinódico', 'mes'], ['meses sinódicos', 'meses']],
        'metros': [['metro', 'm'], ['metros', 'm']],
        'metros por segundo': [['metro por segundo', 'mps', 'm/s'], ['metros por segundo', 'mps', 'm/s']],
        'mililitro': [['mililitro', 'ml'], ['mililitros', 'ml']],
        'milimetro': [['milímetro', 'mm'], ['milímetros', 'mm']],
        'milla': [['milla', 'mi'], ['millas', 'mi']],
        'millas cuadradas': [['milla cuadrada'], ['millas cuadradas']],
        'minuto': [['minuto'], ['minutos']],
        'onza': [['onza'], ['onzas']],
        'other currency': [[], []], # Nothing to do
        'other unit': [[], []], # Nothing to do
        'palabra': [['palabra'], ['palabras']],
        'pie': [['pie'], ['pies']],
        'pies cuadrados': [['pie cuadrado'], ['pies cuadrados']],
        'post meridiem (time)': [['PM', 'P.M.'], ['PM', 'P.M.']],
        'pulgada': [['pulgada'], ['pulgadas']],
        'segundos': [['segundo'], ['segundos']],
        'septenario': [['septenario'], ['septenarios']],
        'tanto por ciento': [['porcentaje', '%'], ['porcentaje', '%']],
    }
    
    if raw_unit_name not in es_conversion or len(es_conversion[raw_unit_name]) == 0:
        return []
    else:
        return [x for x in es_conversion[raw_unit_name][int(is_plural)]]

def parse_nwu_answer_es(raw_answer):
    """
    Parses Spanish answers of type number_with_unit (nwu).
    """
    answers = []

    x_interval = re.search(r'^([\d.]+) ([\d.]+) (.+)$', raw_answer)
    if x_interval:
        unit_value_1 = str_to_num(x_interval[1])
        unit_value_2 = str_to_num(x_interval[2])
        unit_names = get_es_number_unit(x_interval[3], is_plural=True)
        if len(unit_names) == 0:
            # Skip if not unit name is provided in range of values
            pass
        else:
            for unit_name in unit_names:
                answers.append('entre %s y %s %s' % (str(unit_value_1), str(unit_value_2), unit_name))
                answers.append('desde %s hasta %s %s' % (str(unit_value_1), str(unit_value_2), unit_name))
    else:
        x_single = re.search(r'^([\d.]+) (.+)$', raw_answer)
        if x_single is None:
            return []
        unit_value = str_to_num(x_single[1])
        unit_names = get_es_number_unit(x_single[2], is_plural=(unit_value > 0))
        if len(unit_names) == 0:
            answers.append(str(unit_value))
        else:
            for unit_name in unit_names:
                answers.append('%s %s' % (str(unit_value), unit_name))
    
    return answers

In [22]:
def parse_date_answer_es(raw_answer):
    date_parts = re.search('^([\d.]+)-([\d.]+)-([\d.]+)$', raw_answer)
    if date_parts is None:
        # No ISO date format
        return [raw_answer]
    
    month_conversion = ['enero', 'febrero', 'marzo', 'abril', 'mayo', 'junio', 'julio',
                        'agosto', 'septiembre', 'octubre', 'noviembre', 'diciembre']
    
    year_number = date_parts[1]
    month_number = date_parts[2]
    month_str = month_conversion[str_to_num(date_parts[2]) - 1]
    day_number = date_parts[3]
    
    return [
        '%s-%s-%s' % (year_number, month_number, day_number),
        '%s-%s-%s' % (day_number, month_number, year_number),
        '%s-%s-%s' % (month_number, day_number, year_number),
        '%s/%s/%s' % (month_number, day_number, year_number),
        '%s de %s del %s' % (day_number, month_number, year_number),
        '%s de %s, %s' % (day_number, month_number, year_number),
    ]

In [23]:
def parse_raw_answer_es(type, main_answer, aliases):
    """
    Parses MKQA Answers of Spanish language.
    """
    parsed_answers = []
    if type == 'number_with_unit':
        # Example: 16.0 año terrestre
        parsed_answers += parse_nwu_answer_es(main_answer)
        for alias in aliases:
            parsed_answers += parse_nwu_answer_es(alias)
    elif type == 'number':
        # Example: 104.0
        parsed_answers.append(main_answer)
        for alias in aliases:
            parsed_answers.append(alias)
    elif type == 'date':
        # Example: 2001-08-29
        parsed_answers += parse_date_answer_es(main_answer)
    elif type == 'entity':
        # Example: Pokémon Ranger: Sombras de Almia
        parsed_answers.append(main_answer.strip())
        for alias in aliases:
            parsed_answers.append(alias.strip())
    elif type == 'short_phrase':
        # Example: rosemary almond
        parsed_answers.append(main_answer.strip())
        for alias in aliases:
            parsed_answers.append(alias.strip())
    else:
        # Ignored types: unanswerable, long_answer, binary
        pass
    
    return parsed_answers

In [24]:
def find_answers(context, answers, word_tokenizer, min_bleu):
    context_tokens = word_tokenizer(context)
    found_answers = []
    
    context_offsets = []
    last_span = 0
    for token in context_tokens:
        span_start = context.find(token, last_span)
        if span_start == -1:
            raise Exception('Error to tokenize context: %s' % context)
        span_end = span_start + len(token)
        context_offsets.append([span_start, span_end])
    
    for answer in answers:
        answer_tokens = word_tokenizer(answer)
        window_size = len(answer_tokens)
        
        i = 0
        len_answer = len(answer_tokens)
        while i < len(context_tokens) - len_answer:
            score = bleu_score(answer_tokens, context_tokens[i:i+window_size])
            if score >= min_bleu:
                span_start = context_offsets[i][0]
                span_end = context_offsets[i+window_size-1][1]
                found_answers.append({
                    'answer_start': span_start,
                    'answer_end': span_end,
                    'text': context[span_start:span_end],
                })
                i += len_answer # Skips answer position
            else:
                i += 1
    
    return found_answers

In [25]:
def has_entities(context, entities, n=1):
    if entities is None or len(entities) == 0:
        return True, 0.
    
    counter = 0
    max_score = 0.
    current_score = 0.
    for entity in entities:
        max_score = entity['entity_score']
        if context.find(entity['entity_text']) != -1:
            # Found entity
            counter += 1
            current_score += entity['entity_score']
        elif 'is_mandatory' in entity and entity['is_mandatory']:
            # Mandatory entitiy not found
            return False, 0.
    
    return (counter >= n), (current_score / max_score)

In [26]:
def has_tokens(context, tokens, token_score=1.):
    if tokens is None or len(tokens) == 0:
        raise Exception('No tokens provided for context: %s' % context)
    
    max_score = 0
    current_score = 0.
    for token in tokens:
        max_score += token_score
        if context.find(token) != -1:
            # Found token
            current_score += token_score
    
    return current_score / max_score

In [27]:
def combine_qa_entities(entities_a, entities_b):
    entities = {}
    for entity in entities_a:
        idx = entity['wiki_page_id'] if 'wiki_page_id' in entity else entity['entity_text']
        if idx in entities:
            if entity['entity_score'] < entities[idx]['entity_score']:
                entities[idx] = entity
        else:
            entities[idx] = entity
    for entity in entities_b:
        idx = entity['wiki_page_id'] if 'wiki_page_id' in entity else entity['entity_text']
        if idx in entities:
            if entity['entity_score'] < entities[idx]['entity_score']:
                entities[idx] = entity
        else:
            entities[idx] = entity
    return list(entities.values())

In [28]:
def get_best_squad_items(squad_items, top_items):
    squad_items = sorted(squad_items, key=lambda kv: kv['score'], reverse=True)
    return squad_items[:top_items]

In [29]:
def find_squad_item(idx, question, answers, answer_types, word_tokenizer, query_top_results=3,
                    entity_top_results=3, min_bleu=0.8, min_entity_score=0.96, lang_code='es',
                    max_tokens_length=512, max_chars_length=1500, verbose=None):
    question_tokens = word_tokenizer(question)
    squad_items = []
    
    # Get top results of question and answers (full text search)
    question_page_ids = get_wiki_search(question, top_results=query_top_results, lang_code=lang_code)
    answer_page_ids = []
    for answer, answer_type in zip(answers, answer_types):
        if answer_type == 'entity':
            # Only search wiki pages using the answer if it is an entity
            answer_page_ids += get_wiki_search(answer, top_results=query_top_results, lang_code=lang_code)
    
    # Get top results of static entities in question and answers
    q_static_entities = get_static_entities(question)
    for entity in q_static_entities:
        question_page_ids += get_wiki_search(entity['entity_text'], top_results=query_top_results, lang_code=lang_code)
    a_static_entities = []
    for answer in answers:
        a_static_entities += get_static_entities(answer)
    for entity in a_static_entities:
        answer_page_ids += get_wiki_search(entity['entity_text'], top_results=query_top_results, lang_code=lang_code)
    
    # Get top results of dbpedia entities in quesiton
    question_entities = get_dbpedia_search(question, min_score=min_entity_score)
    answers_entities = []
    #for answer, answer_type in zip(answers, answer_types):
    #    if answer_type in ['number', 'number_with_unit']:
    #        continue
    #    answers_entities += get_dbpedia_search(answer, min_score=min_entity_score)
    
    question_entities = combine_qa_entities(question_entities, q_static_entities)
    answers_entities = combine_qa_entities(answers_entities, a_static_entities)
    found_entities = combine_qa_entities(question_entities, answers_entities)

    # Given a set of entities, create a common dictionary of Wiki Page IDs from entities and OpenSearch results
    entity_page_ids = {}
    for entity in found_entities:
        if 'wiki_page_id' not in entity:
            continue
        elif entity['wiki_page_id'] in entity_page_ids:
            if entity_page_ids[entity['wiki_page_id']]['entity_score'] < entity['entity_score']:
                entity_page_ids[entity['wiki_page_id']] = entity
        else:
            entity_page_ids[entity['wiki_page_id']] = entity
    
    entity_page_ids = sorted(entity_page_ids.items(), key=lambda kv: kv[1]['entity_score'], reverse=True)
    entity_page_ids = [x[1] for x in entity_page_ids][:entity_top_results]
    
    all_page_ids = {x['wiki_page_id']:x for x in question_page_ids}
    for answer_page_data in answer_page_ids:
        if answer_page_data['wiki_page_id'] not in all_page_ids:
            all_page_ids[answer_page_data['wiki_page_id']] = answer_page_data
    for entity_page_data in entity_page_ids:
        if entity_page_data['wiki_page_id'] not in all_page_ids:
            all_page_ids[entity_page_data['wiki_page_id']] = entity_page_data
    all_page_ids = list(all_page_ids.values())
    
    # Get page content of each ID
    for i, page_data in enumerate(all_page_ids):
        if verbose:
            print('- Item: %d / %d | Page: %d / %d | Found: %d' % (verbose['i'], verbose['n_items'], i+1,
                    len(all_page_ids), verbose['n_found']), ' '*20, end='\r')
    
        page_id = page_data['wiki_page_id']
        page_title = page_data['wiki_title']
        page_content = get_wiki_article(page_id, lang_code=lang_code)
        
        if page_content is None:
            continue
        
        # Split content into sentences
        content_sentences = nltk.tokenize.sent_tokenize(page_content)
        
        group_sentences = ''
        k = 0
        while k < len(content_sentences):
            sentence = content_sentences[k]
            tmp_group_sentences = (group_sentences + ' ' + sentence).strip()
            tmp_group_tokens = word_tokenizer(tmp_group_sentences)
            
            tokens_length = len(tmp_group_tokens) + len(question_tokens) # Context + question
            chars_length = sum(len(token) for token in tmp_group_tokens) # Only context
            if tokens_length >= (max_tokens_length - 3) and chars_length >= max_chars_length:
                # Context must have at least N entities of the question and M entities of the answer
                has_question_entities, question_entities_score = has_entities(group_sentences, question_entities, n=1)
                has_answer_entities, answer_entities_score = has_entities(group_sentences, answers_entities, n=1)
                if has_question_entities and has_answer_entities:
                    # Note that we substract 3 since we add 3 additional tokens when encoding for QA model training
                    # Try to find answers in the grouped sentences
                    found_answers = find_answers(group_sentences, answers, word_tokenizer, min_bleu=min_bleu)
                    tokens_score = has_tokens(group_sentences, question_tokens)
                    if len(found_answers) > 0:
                        squad_item = {
                            'score': question_entities_score + answer_entities_score,
                            'title': page_title,
                            'paragraphs': [{
                                'context': group_sentences,
                                'qas': [{
                                    'id': idx,
                                    'question': question,
                                    'answers': [],
                                }],
                            }],
                        }
                        for found_answer in found_answers:
                            squad_item['paragraphs'][0]['qas'][0]['answers'].append({
                                'answer_start': found_answer['answer_start'],
                                'text': found_answer['text'],
                            })
                        squad_items.append(squad_item)
                
                group_sentences = ''
            else:
                group_sentences = tmp_group_sentences
            k += 1
    return get_best_squad_items(squad_items, top_items=3)

In [30]:
def main(mkqa_dataset, lang_code='es', save_path='../artifacts/synthetic/', max_aliases=5):
    squad_dataset = {'data': []}
    word_tokenizer = get_word_tokenizer()
    os.makedirs(save_path, exist_ok=True)
    
    # Load config of parsing
    config_file = os.path.join(save_path, 'config.json')
    if os.path.exists(config_file):
        with open(config_file, 'r') as fp:
            config = json.load(fp)
    else:
        with open(config_file, 'w') as fp:
            config = {'skipped': [], 'found': []}
            json.dump(config, fp)
    
    # Count items
    n_items = sum([1 for _ in mkqa_dataset])
    
    print('Process items...')
    for i, item in enumerate(mkqa_dataset):
        idx = 'mkqa_' + str(item['example_id'])
        query = item['queries'][lang_code]
        
        filename = '%s.json' % idx
        output_file = os.path.join(save_path, filename)
        
        # Skip if already parsed
        if item['example_id'] in config['found'] or item['example_id'] in config['skipped']:
            continue
        
        print('- Item %d of %d' % (i + 1, n_items), ' '*20, end='\r')
        
        parsed_answers = []
        answer_types = []
        for raw_answer_data in item['answers'][lang_code]:
            main_answer = raw_answer_data['text']
            aliases = raw_answer_data['aliases'][:max_aliases] if 'aliases' in raw_answer_data else []
            iter_parsed_answers = parse_raw_answer_es(raw_answer_data['type'], main_answer, aliases)
            for iter_parsed_answer in iter_parsed_answers:
                if iter_parsed_answer not in parsed_answers:
                    parsed_answers.append(iter_parsed_answer)
                    answer_types.append(raw_answer_data['type'])
        
        if len(parsed_answers) == 0:
            # No answers for this query
            config['skipped'].append(item['example_id'])
            with open(config_file, 'w') as fp:
                json.dump(config, fp)
            continue
        
        print('- Item: %d / %d | Found: %d' % (i + 1, n_items, len(config['found'])), ' '*20, end='\r')
        
        squad_dataset['data'] = find_squad_item(idx, query, parsed_answers, answer_types, word_tokenizer, lang_code=lang_code,
                                                verbose={'i': i+1, 'n_items': n_items, 'n_found': len(config['found'])})
        
        if len(squad_dataset['data']) == 0:
            config['skipped'].append(item['example_id'])
            with open(config_file, 'w') as fp:
                json.dump(config, fp)
            continue
        
        with open(output_file, 'w') as fp:
            json.dump(squad_dataset, fp)
        config['found'].append(item['example_id'])
        with open(config_file, 'w') as fp:
            json.dump(config, fp)

In [None]:
while True:
    try:
        main(mkqa_dataset, lang_code='es')
        break
    except Exception as e:
        continue

Process items...
- Item: 2698 / 10000 | Found: 557                                     
Unexpected HTTP Status: 400 - http://es.dbpedia.org/resource/Thelma_&_Louise
- Item: 4208 / 10000 | Found: 871                                     
Unexpected HTTP Status: 400 - http://es.dbpedia.org/resource/Thelma_&_Louise
- Item: 4400 / 10000 | Found: 906                                     
Unexpected HTTP Status: 404 - http://es.dbpedia.org/resource/Entrenador
- Item: 4852 / 10000 | Found: 1000                                     
Unexpected HTTP Status: 400 - http://es.dbpedia.org/resource/Law_&_Order
- Item: 5259 / 10000 | Found: 1059                                     
Unexpected HTTP Status: 400 - http://es.dbpedia.org/resource/Emerson,_Lake_&_Palmer
- Item: 6102 / 10000 | Page: 18 / 18 | Found: 1160                     