# Synthetic SQuAD dataset based on MKQA (Google)

Need to install some Spacy dictionaries:
- `python -m spacy download en_core_web_sm`
- `python -m spacy download es_core_news_sm`
- `python -m spacy download ru_core_news_sm`
- `python -m spacy download ja_core_news_sm`
- `python -m spacy download xx_ent_wiki_sm`

For Vietnamese, install this repo: https://github.com/trungtv/vi_spacy

In [1]:
import json
import mwparserfromhell
import nltk
import numpy as np
import os
import random
import re
import requests
import spacy
import threading
import time
from time import sleep

from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizerFast

from requests.packages.urllib3.exceptions import InsecureRequestWarning
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

In [2]:
LM_NAME = 'bert-base-multilingual-cased'
INPUT_FILE = '../data/mkqa/mkqa.jsonl'

"""
LANG_NAME = 'spanish'
LANG_CODE = 'es-ES'
REGION_CODE = 'es'
SPACY_DICT = 'es_core_news_sm'
"""

"""
LANG_NAME = 'japanese'
LANG_CODE = 'ja-JP'
REGION_CODE = 'ja'
SPACY_DICT = 'ja_core_news_sm'
"""

"""
LANG_NAME = 'russian'
LANG_CODE = 'ru-RU'
REGION_CODE = 'ru'
SPACY_DICT = 'ru_core_news_sm'
"""

LANG_NAME = 'vietnamese'
LANG_CODE = 'vi-VN'
REGION_CODE = 'vi'
SPACY_DICT = ''
from spacy.lang.vi import Vietnamese


OUTPUT_PATH = '../artifacts/synthetic_wikigoogle_top_n/'

In [3]:
with open(INPUT_FILE, 'r', encoding='utf-8') as fp:
    mkqa_dataset = list(fp)

mkqa_dataset = [json.loads(jline) for jline in mkqa_dataset]

In [4]:
if LANG_NAME == 'vietnamese':
    from spacy.lang.vi import STOP_WORDS as STOP_WORDS_VI
    STOPWORDS = set([x for x in STOP_WORDS_VI if x])
elif LANG_NAME == 'japanese':
    from spacy.lang.ja import STOP_WORDS as STOP_WORDS_JA
    STOPWORDS = set([x for x in STOP_WORDS_JA if x])
elif LANG_NAME in nltk.corpus.stopwords.fileids():
    STOPWORDS = set(nltk.corpus.stopwords.words(LANG_NAME))
else:
    STOPWORDS = set()

len(STOPWORDS)

1942

In [5]:
def str_to_num(text):
    try:
        return int(text)
    except:
        return float(text)

## Google Search functions

Useful resources:
- API usage: https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
- Free API Key: https://developers.google.com/custom-search/v1/introduction
- Setup Search Engine and get ID: https://cse.google.com/

Non-free API Key can be obtained from the Google's console.

In [6]:
API_KEY = 'XXXX'
SEARCH_ENGINE_ID = 'XXXX'

In [7]:
def google_search(query, top_results=5, lang_code=LANG_CODE, region_code=REGION_CODE,
                  recall=False, site_restrict=True):
    query = re.sub(r'[\u0060\u00B4\u2018\u2019]', '\'', query)
    query = re.sub(r'[\u201C\u201D]', '"', query)
    query = re.sub(r'"', '', query)
    
    quota_user = '%s%d' % (region_code, np.random.randint(1, 10000000))
    params = {
        'quotaUser': quota_user,
        'key': API_KEY,
        'cx': SEARCH_ENGINE_ID,
        'q': query,
        'hl': lang_code,
        'gl': region_code,
    }
    if site_restrict:
        url = 'https://www.googleapis.com/customsearch/v1/siterestrict'
    else:
        url = 'https://www.googleapis.com/customsearch/v1'
    
    r = requests.get(url=url, params=params)
    if r.status_code == 429:
        print(quota_user, r.text)
        raise Exception('GCloud limit reached!')
    elif r.status_code != 200:
        #print('')
        #print('GSearch - HTTP Status: %d - Query: %s' % (r.status_code, query))
        print(r.text)
        raise Exception('GSearch - HTTP Status: %d - Query: %s' % (r.status_code, query))
        return []

    raw_data = r.json()
    
    try:
        if 'items' not in raw_data:
            if recall:
                # If already this method has been re-called
                return []
            elif 'spelling' in raw_data and 'correctedQuery' in raw_data['spelling']:
                return google_search(query=raw_data['spelling']['correctedQuery'],
                                     top_results=top_results,
                                     lang_code=lang_code,
                                     region_code=region_code,
                                     recall=True)
            else:
                print('GSearch - No items - Query: %s' % (query))
                return []
                #raise Exception('No items')
        else:
            parsed_data = []
            for item in raw_data['items']:
                if 'title' not in item or 'link' not in item:
                    # Bad links
                    continue
                elif re.search(r'\.pdf', item['link']):
                    # Skip pdfs
                    continue
                parsed_data.append({
                    'title': item['title'],
                    'url': item['link'],
                })
            return parsed_data[:top_results]
    except Exception as e:
        print()
        print('Query: %s' % query)
        raise e

In [8]:
query = 'quién interpreta a pam en aquellos maravillosos 70'
google_search(query)

[{'title': "That '70s Show - Wikipedia, la enciclopedia libre",
  'url': 'https://es.wikipedia.org/wiki/That_%2770s_Show'},
 {'title': 'Brooke Shields - Wikipedia, la enciclopedia libre',
  'url': 'https://es.wikipedia.org/wiki/Brooke_Shields'},
 {'title': "That '70s Show - Viquipèdia, l'enciclopèdia lliure",
  'url': 'https://ca.wikipedia.org/wiki/That_%2770s_Show'},
 {'title': 'Jenna Fischer - Wikipedia, la enciclopedia libre',
  'url': 'https://es.wikipedia.org/wiki/Jenna_Fischer'},
 {'title': 'Bobcat Goldthwait - Wikipedia, la enciclopedia libre',
  'url': 'https://es.wikipedia.org/wiki/Bobcat_Goldthwait'}]

In [9]:
query = 'フェルナンドアロンソは誰ですか？'
google_search(query)

[{'title': 'フェルナンド・アロンソ - Wikipedia',
  'url': 'https://ja.wikipedia.org/wiki/%E3%83%95%E3%82%A7%E3%83%AB%E3%83%8A%E3%83%B3%E3%83%89%E3%83%BB%E3%82%A2%E3%83%AD%E3%83%B3%E3%82%BD'},
 {'title': 'F1歴代記録 - Wikipedia',
  'url': 'https://ja.wikipedia.org/wiki/F1%E6%AD%B4%E4%BB%A3%E8%A8%98%E9%8C%B2'},
 {'title': 'F1ドライバーズチャンピオンの一覧 - Wikipedia',
  'url': 'https://ja.wikipedia.org/wiki/F1%E3%83%89%E3%83%A9%E3%82%A4%E3%83%90%E3%83%BC%E3%82%BA%E3%83%81%E3%83%A3%E3%83%B3%E3%83%94%E3%82%AA%E3%83%B3%E3%81%AE%E4%B8%80%E8%A6%A7'},
 {'title': '2006年のF1世界選手権 - Wikipedia',
  'url': 'https://ja.wikipedia.org/wiki/2006%E5%B9%B4%E3%81%AEF1%E4%B8%96%E7%95%8C%E9%81%B8%E6%89%8B%E6%A8%A9'},
 {'title': 'ミハエル・シューマッハ - Wikipedia',
  'url': 'https://ja.wikipedia.org/wiki/%E3%83%9F%E3%83%8F%E3%82%A8%E3%83%AB%E3%83%BB%E3%82%B7%E3%83%A5%E3%83%BC%E3%83%9E%E3%83%83%E3%83%8F'}]

## Wikipedia API

In [10]:
def get_wiki_article(url):
    # Get wiki region and page ID
    matches = re.search(r'https?://(?:www)?(.+)\.wikipedia\.org/(?:wiki|.+-.+)/([^&]+)', url)
    if not matches:
        print('Unexpected Wiki URL format: %s' % url)
        raise Exception('Unexpected Wiki URL format: %s' % url)
        #return None
        
    wiki_region = matches[1]
    page_id = matches[2]
    
    api_url = 'https://%s.wikipedia.org/w/api.php' % wiki_region + \
        '?action=parse' + \
        '&page=%s' % page_id + \
        '&prop=text' + \
        '&format=json'
        #'&section=1' + \
    
    r = requests.get(api_url)
    if r.status_code != 200:
        print('')
        print('Unexpected HTTP Status: %d - %s' % (r.status_code, url))
        raise Exception('Unexpected HTTP Status: %d' % r.status_code)
        #return None
    
    json_data = r.json()
    if 'parse' not in json_data:
        # Not found
        return ''
    article = json_data['parse']['text']['*']
    
    parsed_wikicode = mwparserfromhell.parse(article)
    article = parsed_wikicode.strip_code()
    
    # Clean wiki rubbish
    article = re.sub(r'\[[^\]]+\]+', ' ', article)
    
    return article

In [11]:
def filter_paragraphs(text, min_length=50, max_length=1300, group_max_length=1000):
    paragraphs = re.split(r'\n\n*', text)
    paragraphs = [x.strip() for x in paragraphs]
    paragraphs = [re.sub(r'[\u200b\xa0 ]+', ' ', x) for x in paragraphs]
    paragraphs = [x for x in paragraphs if len(x) > 5]
    
    # Split big paragraphs
    measured_paragraphs = []
    for p in paragraphs:
        if len(p) > max_length:
            # Big paragraph must be split
            p_sentences = nltk.tokenize.sent_tokenize(p)
            grouped_sentences = ''
            k = 0
            while k < len(p_sentences):
                sentence = p_sentences[k]
                if len(grouped_sentences) > group_max_length:
                    measured_paragraphs.append(grouped_sentences)
                    grouped_sentences = ''
                else:
                    grouped_sentences = ('%s %s' % (grouped_sentences, sentence)).strip()
                    k += 1
        elif len(p) >= min_length:
            # The paragraph is not too small
            measured_paragraphs.append(p)

    return measured_paragraphs

In [12]:
url = 'https://es.wikipedia.org/wiki/Pearl_Jam'

content = get_wiki_article(url)
content = filter_paragraphs(content)
print(content)

['Para otros usos de este término, véase Pearl Jam (desambiguación).', 'Pearl JamPearl Jam en Oakland, 2013.Datos generalesOrigen Seattle, Washington, Estados UnidosEstadoActivoInformación artísticaGénero(s)Rock alternativo', 'Hard rockPeríodo de actividad1990-presenteDiscográfica(s)J RecordsEpic RecordsUniversal Music GroupMonkeywrench RecordsWebSitio webSitio OficialMiembrosEddie VedderMike McCreadyStone GossardJeff AmentMatt CameronKenneth GasparExmiembrosDave KrusenMatt ChamberlainDave AbbruzzeseJack Irons', 'Pearl Jam es un grupo de grunge formado en Seattle, Estados Unidos, en el año 1990, con integrantes de las bandas Mother Love Bone y Temple of the Dog. Con la edición de su álbum debut Ten en 1991, Pearl Jam irrumpiría con fuerza en el ámbito musical alternativo. Junto a Nirvana, Alice in Chains, Stone Temple Pilots y Soundgarden están considerados como una de las bandas más grandes e influyentes de toda la escena del movimiento Grunge. Sus miembros fundadores y que aun siguen

In [13]:
url = 'https://zh.wikipedia.org/zh-sg/%E4%B8%AD%E5%AF%86%E6%AD%87%E6%A0%B9%E5%A4%A7%E5%AD%A6'
matches = re.search(r'https?://(?:www)?(.+)\.wikipedia\.org/(?:wiki|.+-.+)/([^&]+)', url)
wiki_region = matches[1]
page_id = matches[2]

print(wiki_region, page_id)

zh %E4%B8%AD%E5%AF%86%E6%AD%87%E6%A0%B9%E5%A4%A7%E5%AD%A6


## MKQA functions

In [14]:
def get_es_number_unit(raw_unit_name):
    es_conversion = {
        'Antes de la era vulgar': ['a.C.'],
        'Galones': ['galón'],
        'Millas por hora': ['mph', 'milla por hora'],
        'acre': ['acre'],
        'antes del Mediodia': ['AM', 'A.M.'],
        'año terrestre': ['año'],
        'caballo de potencia metrico': ['caballo de potencia', 'caballo', 'hp', 'cv'],
        'centímetro': ['cm', 'centímetro'],
        'día': ['día'],
        'dólar estadounidense': ['dólar', '$'],
        'episodio': ['episodio'],
        'escala Fahrenheit': ['grado Fahrenheit', 'Fahrenheit', 'ºF'],
        'estaciones del año': ['temporada'],
        'grados centigrados': ['grado centigrado', 'grado', 'ºC'],
        'gramo': ['gramo', 'gr', 'g'],
        'hora': ['hora', 'h'],
        'kilometraje': ['kilómetro', 'km'],
        'libra avoirdupois': ['libra avoirdupois', 'libra', 'lb'],
        'light año terrestre': ['año luz'],
        'mes sinódico': ['mes sinódico', 'mes'],
        'metros': ['metro', 'm'],
        'metros por segundo': ['metro por segundo', 'mps', 'm/s'],
        'mililitro': ['mililitro', 'ml'],
        'milimetro': ['milímetro', 'mm'],
        'milla': ['milla', 'mi'],
        'millas cuadradas': ['milla cuadrada'],
        'minuto': ['minuto'],
        'onza': ['onza'],
        'other currency': [],
        'other unit': [],
        'palabra': ['palabra'],
        'pie': ['pie'],
        'pies cuadrados': ['pie cuadrado'],
        'post meridiem (time)': ['PM', 'P.M.'],
        'pulgada': ['pulgada'],
        'segundos': ['segundo'],
        'septenario': ['septenario'],
        'tanto por ciento': ['porcentaje', '%'],
    }
    
    if raw_unit_name not in es_conversion or len(es_conversion[raw_unit_name]) == 0:
        return ['']
    else:
        return [x for x in es_conversion[raw_unit_name]]

def parse_nwu_answer_es(raw_answer):
    """
    Parses Spanish answers of type number_with_unit (nwu).
    """
    answers = []

    x_interval = re.search(r'^([\d.]+) ([\d.]+)(?: (.+))?$', raw_answer)
    if x_interval:
        unit_value_1 = x_interval[1]
        unit_value_2 = x_interval[2]
        unit_names = get_es_number_unit(x_interval[3])
        if len(unit_names) == 0:
            # Skip if not unit name is provided in range of values
            pass
        else:
            for unit_name in unit_names:
                answers.append('entre %s y %s %s' % (unit_value_1, unit_value_2, unit_name))
                answers.append('desde %s hasta %s %s' % (unit_value_1, unit_value_2, unit_name))
    else:
        x_single = re.search(r'^([\d.]+)(?: (.+))$', raw_answer)
        if x_single is None:
            return []
        unit_value = x_single[1]
        unit_names = get_es_number_unit(x_single[2])
        if len(unit_names) == 0:
            answers.append(unit_value)
        else:
            for unit_name in unit_names:
                answers.append('%s %s' % (unit_value, unit_name))
    
    return answers

In [15]:
def get_ja_number_unit(raw_unit_name):
    return [raw_unit_name]

def parse_nwu_answer_ja(raw_answer):
    """
    Parses Japanese answers of type number_with_unit (nwu).
    """
    answers = []

    x_interval = re.search(r'^([\d.]+) ([\d.]+)(?: (.+))?$', raw_answer)
    if x_interval:
        unit_value_1 = x_interval[1]
        unit_value_2 = x_interval[2]
        unit_names = get_ja_number_unit(x_interval[3])
        if len(unit_names) == 0:
            # No unit names provided
            unit_names = ['']

        for unit_name in unit_names:
            answers.append('%s%sから%s%sまで' % (unit_value_1, unit_name, unit_value_2, unit_name))
            answers.append('%s%sから%s%s' % (unit_value_1, unit_name, unit_value_2, unit_name))
            answers.append('%s%s乃至%s%s' % (unit_value_1, unit_name, unit_value_2, unit_name))
    else:
        x_single = re.search(r'^([\d.]+)(?: (.+))?$', raw_answer)
        if x_single is None:
            return []
        unit_value = x_single[1]
        unit_names = get_ja_number_unit(x_single[2])
        if len(unit_names) == 0:
            answers.append(unit_value)
        else:
            for unit_name in unit_names:
                answers.append('%s%s' % (unit_value, unit_name))
    
    return answers

In [16]:
def get_ru_number_unit(raw_unit_name):
    ru_conversion = {
        'Ounce': ['унция'],
        'Pound': ['фунт', 'lb'],
        'a.m. (time)': ['часов дня', 'утра'],
        'bc (date)': ['до н.э.'],
        'other currency': [],
        'other unit': [],
        'p.m. (time)': ['часов вечера', 'вечера'],
        'square английский фут': ['квадратный английский фут', 'квадратный фут', 'фут'],
        'акр': ['акр'],
        'американский доллар': ['американский доллар', 'доллар', '$'],
        'английский фут': ['английский фут', 'фут'],
        'времена года': ['времена года', 'сезон'],
        'градус Фаренгейта': ['градус Фаренгейта', 'Фаренгейта', 'ºF'],
        'грамм': ['грамм', 'гр', 'г'],
        'дюйм': ['дюйм', 'inch'],
        'имперский галлон': ['имперский галлон', 'галлон', 'gallon'],
        'квадратная миля': ['квадратная миля', 'миля'],
        'километр': ['километр', 'км'],
        'лет': ['лет'],
        'лошадиная сила': ['лошадиная сила', 'л.с', 'лс'],
        'метр': ['метр', 'м'],
        'метр в секунду': ['метр в секунду', 'м/с', 'm/s'],
        'миллилитр': ['миллилитр', 'мл'],
        'миллиметр': ['миллиметр', 'мм'],
        'миля': ['миля'],
        'миля в час': ['миля в час', 'м/с'],
        'минута': ['минута', 'мин', 'м'],
        'сантиметр': ['сантиметр', 'см'],
        'световой год': ['световой год', 'св. год', 'св. г.'],
        'седьмица': ['седьмица', 'седмица', 'неделя'],
        'секунда': ['секунда', 'с'],
        'синодический лунный месяц': ['синодический лунный месяц', 'лунный месяц'],
        'слово': ['слово'],
        'сотая доля': ['сотая доля', 'сотая часть', 'сотый'],
        'суток': ['суток', 'день'],
        'температурная шкала Цельсия': ['градус Цельсия', 'Цельсия', 'ºC'],
        'часов': ['час'],
        'эпизод': ['эпизод'],
    }
    
    if raw_unit_name not in ru_conversion or len(ru_conversion[raw_unit_name]) == 0:
        return ['']
    else:
        return [x for x in ru_conversion[raw_unit_name]]

def parse_nwu_answer_ru(raw_answer):
    """
    Parses Russian answers of type number_with_unit (nwu).
    """
    answers = []

    x_interval = re.search(r'^([\d.]+) ([\d.]+)(?: (.+))?$', raw_answer)
    if x_interval:
        unit_value_1 = x_interval[1]
        unit_value_2 = x_interval[2]
        unit_names = get_es_number_unit(x_interval[3])
        if len(unit_names) == 0:
            # Skip if not unit name is provided in range of values
            pass
        else:
            for unit_name in unit_names:
                answers.append('от %s %s до %s %s' % (unit_value_1, unit_name, unit_value_2, unit_name))
                answers.append('от %s до %s %s' % (unit_value_1, unit_value_2, unit_name))
    else:
        x_single = re.search(r'^([\d.]+)(?: (.+))$', raw_answer)
        if x_single is None:
            return []
        unit_value = x_single[1]
        unit_names = get_es_number_unit(x_single[2])
        if len(unit_names) == 0:
            answers.append(unit_value)
        else:
            for unit_name in unit_names:
                answers.append('%s %s' % (unit_value, unit_name))
    
    return answers

In [17]:
def get_vi_number_unit(raw_unit_name):
    vi_conversion = {
        'Sức ngựa': ['Sức ngựa', 'Mã lực', 'HP', 'CV'],
        'Tiếng đồng hồ': ['Tiếng đồng hồ', 'Giờ'],
        'Tập chương trình': ['Tập chương trình'],
        'ante meridiem': ['A.M.', 'AM'],
        'bc (date)': ['bc'],
        'chữ': ['chữ'],
        'dặm Anh': ['dặm Anh', 'dặm'],
        'dặm vuông Anh': ['dặm vuông Anh', 'dậm vuông Anh' 'dặm vuông'],
        'foot': ['foot', 'chân'],
        'foot vuông': ['foot vuông', 'square foot'],
        'ga-lông': ['ga-lông', 'galông', 'gallon'],
        'gam': ['gờ ram', 'cờ ram', 'gam', 'gram', 'g'],
        'giây': ['giây', 's'],
        'inch': ['inch', 'in'],
        'ki-lô-mét': ['ki-lô-mét', 'kilômét', 'km'],
        'miles per Tiếng đồng hồ': ['dặm một giờ', 'mph'],
        'mililít': ['mililít', 'ml'],
        'milimét': ['milimét', 'mm'],
        'mét': ['mét', 'm'],
        'mét trên giây': ['mét trên giây', 'm/s'],
        'mùa': ['mùa', 'phần'],
        'mẫu vuông': ['mẫu vuông', 'mẫu', 'acre'],
        'ngày': ['ngày'],
        'năm': ['năm'],
        'năm ánh sáng': ['năm ánh sáng'],
        'other currency': [],
        'other unit': [],
        'ounce avoirdupois quốc tế': ['ounce', 'oz'],
        'phút': ['phút', 'm', 'ph'],
        'phần trăm': ['phần trăm', '%'],
        'post meridiem (time)': ['P.M.', 'PM'],
        'pound': ['pound', 'cân Anh', 'pao', 'lb'],
        'tháng': ['tháng'],
        'tuần': ['tuần'],
        'xentimét': ['xentimét', 'cm'],
        'đô la Hoa Kì': ['đô la Hoa Kì', 'đô la Hoa Kỳ', 'đô la Mỹ', 'Mỹ kim', '$'],
        'độ Fahrenheit': ['độ Fahrenheit', 'độ F', '°F'],
        'độ bách phân': ['độ Celsius', 'độ C', '°C'],
    }
    
    if raw_unit_name not in vi_conversion or len(vi_conversion[raw_unit_name]) == 0:
        return ['']
    else:
        return [x for x in vi_conversion[raw_unit_name]]

def parse_nwu_answer_vi(raw_answer):
    """
    Parses Vietnamese answers of type number_with_unit (nwu).
    """
    answers = []

    x_interval = re.search(r'^([\d.]+) ([\d.]+)(?: (.+))?$', raw_answer)
    if x_interval:
        unit_value_1 = x_interval[1]
        unit_value_2 = x_interval[2]
        unit_names = get_es_number_unit(x_interval[3])
        if len(unit_names) == 0:
            # Skip if not unit name is provided in range of values
            pass
        else:
            for unit_name in unit_names:
                answers.append('từ %s đến %s %s' % (unit_value_1, unit_value_2, unit_name))
                answers.append('từ %s %s đến %s %s' % (unit_value_1, unit_name, unit_value_2, unit_name))
    else:
        x_single = re.search(r'^([\d.]+)(?: (.+))$', raw_answer)
        if x_single is None:
            return []
        unit_value = x_single[1]
        unit_names = get_es_number_unit(x_single[2])
        if len(unit_names) == 0:
            answers.append(unit_value)
        else:
            for unit_name in unit_names:
                answers.append('%s %s' % (unit_value, unit_name))
    
    return answers

In [18]:
def parse_date_answer_es(raw_answer):
    date_parts = re.search('^([\d.]+)-([\d.]+)-([\d.]+)$', raw_answer)
    if date_parts is None:
        # No ISO date format
        return [raw_answer]
    
    month_conversion = ['enero', 'febrero', 'marzo', 'abril', 'mayo', 'junio', 'julio',
                        'agosto', 'septiembre', 'octubre', 'noviembre', 'diciembre']
    
    year_number = date_parts[1]
    month_number = date_parts[2]
    month_str = month_conversion[str_to_num(date_parts[2]) - 1]
    day_number = date_parts[3]
    
    return [
        '%s-%s-%s' % (year_number, month_number, day_number),
        '%s-%s-%s' % (day_number, month_number, year_number),
        '%s-%s-%s' % (month_number, day_number, year_number),
        '%s/%s/%s' % (month_number, day_number, year_number),
        '%s de %s del %s' % (day_number, month_number, year_number),
        '%s de %s, %s' % (day_number, month_number, year_number),
    ]

In [19]:
def parse_date_answer_ja(raw_answer):
    date_parts = re.search('^([\d.]+)-([\d.]+)-([\d.]+)$', raw_answer)
    if date_parts is None:
        # No ISO date format
        return [raw_answer]
    
    number_conversion = ['一', '二', '三', '四', '五', '六', '七', '八', '九', '十',
                         '十一', '十二', '十三', '十四', '十五', '十六', '十七', '十八', '十九', '二十',
                         '二十一', '二十二', '二十三', '二十四', '二十五', '二十六', '二十七', '二十八', '二十九', '三十',
                         '三十一']
    
    year_number = date_parts[1]
    month_number = date_parts[2]
    day_number = date_parts[3]
    
    all_dates = [
        '%s-%s-%s' % (year_number, month_number, day_number),
        '%s-%s-%s' % (day_number, month_number, year_number),
        '%s/%s/%s' % (year_number, month_number, day_number),
        '%s/%s/%s' % (day_number, month_number, year_number),
    ]
    
    kanji_date = '%s年 %s月 %s日' % (year_number, month_number, day_number)
    all_dates = [kanji_date] + all_dates
    
    month_str = number_conversion[str_to_num(date_parts[2]) - 1]
    day_str = number_conversion[str_to_num(date_parts[3]) - 1]
    kanji_date = '%s月 %s日' % (month_str, day_str)
    all_dates = [kanji_date] + all_dates
    
    return all_dates

In [20]:
def parse_date_answer_ru(raw_answer):
    date_parts = re.search('^([\d.]+)-([\d.]+)-([\d.]+)$', raw_answer)
    if date_parts is None:
        # No ISO date format
        return [raw_answer]
    
    month_conversion = ['январь', 'февраль', 'март', 'апрель', 'май', 'июнь', 'июль',
                        'август', 'сентябрь', 'октябрь', 'ноябрь', 'декабрь']
    
    year_number = date_parts[1]
    month_number = date_parts[2]
    month_str = month_conversion[str_to_num(date_parts[2]) - 1]
    day_number = date_parts[3]
    
    return [
        '%s-%s-%s' % (day_number, month_number, year_number),
        '%s.%s.%s' % (day_number, month_number, year_number),
        '%s/%s/%s' % (day_number, month_number, year_number),
        '%s %s %s г.' % (day_number, month_str, year_number),
    ]

In [21]:
def parse_date_answer_vi(raw_answer):
    date_parts = re.search('^([\d.]+)-([\d.]+)-([\d.]+)$', raw_answer)
    if date_parts is None:
        # No ISO date format
        return [raw_answer]
    
    month_conversion = ['giêng', 'hai', 'ba', 'bốn', 'năm', 'sáu',
                        'bảy', 'tám', 'chín', 'mười', 'mười một', 'mười hai']
    
    year_number = date_parts[1]
    month_number = date_parts[2]
    month_str = month_conversion[str_to_num(date_parts[2]) - 1]
    day_number = date_parts[3]
    
    return [
        '%s tháng %s %s' % (day_number, month_str, year_number),
        '%s tháng %s %s' % (day_number, month_number, year_number),
        '%s-%s-%s' % (year_number, month_number, day_number),
        '%s-%s-%s' % (day_number, month_number, year_number),
        '%s/%s/%s' % (day_number, month_number, year_number),
    ]

In [22]:
def parse_raw_answer(type, main_answer, aliases, region_code=REGION_CODE):
    """
    Parses MKQA Answers.
    """
    
    if REGION_CODE == 'es':
        parse_functions = {
            'number_with_unit': parse_nwu_answer_es,
            'date': parse_date_answer_es,
        }
    elif REGION_CODE == 'ja':
        parse_functions = {
            'number_with_unit': parse_nwu_answer_ja,
            'date': parse_date_answer_ja,
        }
    elif REGION_CODE == 'ru':
        parse_functions = {
            'number_with_unit': parse_nwu_answer_ru,
            'date': parse_date_answer_ru,
        }
    elif REGION_CODE == 'vi':
        parse_functions = {
            'number_with_unit': parse_nwu_answer_vi,
            'date': parse_date_answer_vi,
        }
    else:
        raise Exception('Unknown region code: %s' % REGION_CODE)
    
    parsed_answers = []
    if type == 'number_with_unit':
        # Example: 16.0 año terrestre
        parsed_answers += parse_functions[type](main_answer)
        for alias in aliases:
            parsed_answers += parse_functions[type](alias)
    #elif type == 'number':
    #    # Example: 104.0
    #    parsed_answers.append(main_answer)
    #    for alias in aliases:
    #        parsed_answers.append(alias)
    elif type == 'date':
        # Example: 2001-08-29
        parsed_answers += parse_functions[type](main_answer)
    elif type == 'entity':
        # Example: Pokémon Ranger: Sombras de Almia
        parsed_answers.append(main_answer.strip())
        for alias in aliases:
            parsed_answers.append(alias.strip())
    elif type == 'short_phrase':
        # Example: rosemary almond
        parsed_answers.append(main_answer.strip())
        for alias in aliases:
            parsed_answers.append(alias.strip())
    else:
        # Ignored types: unanswerable, long_answer, binary
        pass
    
    # Filter low-quality answers (e.g. only numbers)
    parsed_answers = [x for x in parsed_answers if not re.search(r'^[\d.]+$', x)]
    parsed_answers = [x for x in parsed_answers if x != '']
    
    return parsed_answers

In [23]:
def find_answers(context, context_tokens, answers, answers_tokens, min_bleu, min_answer_length=3, max_answer_length=30):
    found_answers = []
    
    context_lemmas = [token[1] for token in context_tokens]
    context_offsets = [token[2] for token in context_tokens]
    
    for answer_tokens in answers_tokens:
        answer_lemmas = [token[1] for token in answer_tokens]
        
        window_size = len(answer_tokens)
        i = 0
        while i < len(context_tokens) - window_size:
            score = bleu_score(answer_lemmas, context_lemmas[i:i+window_size])
            span_start = context_offsets[i][0]
            span_end = context_offsets[i+window_size-1][1]
            context_answer = context[span_start:span_end]
            if (score >= min_bleu and len(context_answer) <= max_answer_length and
                len(context_answer) >= min_answer_length and context_answer.strip() != ''):
                found_answers.append({
                    'answer_start': span_start,
                    'answer_end': span_end,
                    'text': context[span_start:span_end],
                })
                i += window_size # Skips answer position
            else:
                i += 1
    
    return found_answers

## Tokenization and analysis

In [24]:
def get_piece_tokenizer(artifacts_path='../artifacts/', lm_name='bert-base-multilingual-cased', lowercase=False):
    save_path = '%s%s/' % (artifacts_path, lm_name)
    tokenizer = BertTokenizerFast('%svocab.txt' % save_path, do_lower_case=lowercase)
    return lambda text : [text[start:end] \
                          for (start, end) in tokenizer(text, return_offsets_mapping=True,
                                                        return_special_tokens_mask=False)['offset_mapping'][1:-1]]

In [25]:
def get_word_tokenizer(dictionary_name=None, region_code=REGION_CODE):
    if region_code == 'vi':
        nlp = Vietnamese()
    else:
        nlp = spacy.load(dictionary_name)
    return lambda text : [(token.text, (token.lemma_ if token.lemma_ != '' else token.text.lower()), [token.idx, token.idx + len(token.text)]) for token in nlp(text)]

In [26]:
def bleu_score(reference, hypothesis):
    matches = [int(x == y) for x, y in zip(reference, hypothesis)]
    return sum(matches) / len(reference)

In [27]:
a = ['esta', 'es', 'prueba','1']
b = ['esta', 'es', 'prueba','2']

bleu_score(a, b)

0.75

In [28]:
text = '¡Hola mundo! ¡Adiós mundo!'

tokenizer = get_piece_tokenizer()
tokenizer(text)

['¡', 'Ho', 'la', 'mundo', '!', '¡', 'Adi', 'ós', 'mundo', '!']

In [29]:
text = '¡Hola mundo! ¡Adiós mundo!'

tokenizer = get_word_tokenizer(SPACY_DICT)
tokenizer(text)

[('¡', '¡', [0, 1]),
 ('Hola', 'hola', [1, 5]),
 ('mundo', 'mundo', [6, 11]),
 ('!', '!', [11, 12]),
 ('¡', '¡', [13, 14]),
 ('Adiós', 'adiós', [14, 19]),
 ('mundo', 'mundo', [20, 25]),
 ('!', '!', [25, 26])]

## App

In [30]:
class ItemPicker():
    def __init__(self, dataset, timeout=60, sleep_interval=1):
        self._n_items = len(dataset)
        self._dataset = dataset
        self._idx = 0
        self._timeout = timeout
        self._sleep_interval = sleep_interval
        self._locked = False
    
    def pick(self):
        self.lock()
        if self._idx >= self._n_items:
            item = None
        else:
            item = self._dataset.pop(0)
            self._idx += 1
        self.unlock()
        return item

    def lock(self):
        start_time = time.time()
        while self._locked:
            end_time = time.time()
            if end_time - start_time >= self._timeout:
                raise Exception('Cannot pick an item (timeout)')
            sleep(self._sleep_interval)
        self._locked = True
    
    def unlock(self):
        self._locked = False

In [31]:
def get_best_squad_items(squad_items, top_items):
    squad_items = sorted(squad_items, key=lambda kv: kv['score'], reverse=True)
    return squad_items[:top_items]

In [32]:
def get_static_entities(text, default_score=5.):
    all_matches = []
    matches = re.findall(r'"(.+?)"', text)
    if len(matches) > 0:
        all_matches += [x.strip() for x in matches]
    matches = re.findall(r'\'(.+?)\'', text)
    if len(matches) > 0:
        all_matches += [x.strip() for x in matches]
    matches = re.findall(r'((?:(?:[A-Z][A-Za-z]+|de)\s?)+)', text)
    if len(matches) > 0:
        all_matches += [x.strip() for x in matches]
    
    # Format static entities
    static_entities = []
    for x in all_matches:
        static_entities.append({
            'entity_text': x,
            'entity_score': default_score,
            'is_mandatory': True,
        })
    
    return static_entities

In [33]:
def has_entities(context, entities, n=1):
    if entities is None or len(entities) == 0:
        return True, 0.
    
    counter = 0
    max_score = 0.
    current_score = 0.
    for entity in entities:
        max_score = entity['entity_score']
        if context.find(entity['entity_text']) != -1:
            # Found entity
            counter += 1
            current_score += entity['entity_score']
        elif 'is_mandatory' in entity and entity['is_mandatory']:
            # Mandatory entitiy not found
            return False, 0.
    
    return (counter >= n), (current_score / max_score)

In [34]:
def has_tokens(context, tokens, token_score=1.):
    if tokens is None or len(tokens) == 0:
        raise Exception('No tokens provided for context: %s' % context)
    
    max_score = 0
    current_score = 0.
    for token in tokens:
        if token in list(STOPWORDS):
            continue
        max_score += token_score
        if context.find(token) != -1:
            # Found token
            current_score += token_score
    
    return current_score / max_score if max_score > 0 else 0.

In [35]:
def find_squad_item(idx, question, answers, answer_types, piece_tokenizers,
                    word_tokenizer, query_top_results, max_threads=3,
                    min_bleu=0.8, max_pieces_length=512, max_chars_length=1500):
    if len(piece_tokenizers) != max_threads:
        raise Exception('Need more tokenizers (expected: %d).' % max_threads)

    question_pieces = piece_tokenizers[0](question)
    question_tokens = word_tokenizer(question)
    
    answers_tokens = [word_tokenizer(answer) for answer in answers]
    
    # Sort answers by length
    a_zip = sorted(zip(answers, answer_types, answers_tokens), key=lambda x: len(x[0]), reverse=True)
    answers, answer_types, answers_tokens = list(zip(*a_zip))
    
    # Get top results of Google
    url_items = google_search(question, top_results=query_top_results)
    static_entities = get_static_entities(question)
    
    # Get page content of each ID
    squad_items = []
    end_status = []
    all_threads = []

    n_threads = 0
    url_picker = ItemPicker(url_items)
    n_threads = min([max_threads, len(url_items)])
    
    for i in range(n_threads):
        x = threading.Thread(
            target=find_squad_item_in_url,
            args=(end_status, squad_items, idx, url_picker, piece_tokenizers[i], word_tokenizer,
                  question, question_pieces, question_tokens, answers, answer_types, answers_tokens,
                  static_entities, min_bleu, max_pieces_length, max_chars_length))
        x.start()
        all_threads.append(x)
        sleep(1)
    
    # Wait for threads
    for x in all_threads:
        x.join()
    
    if sum(end_status) != len(all_threads):
        raise Exception('Some thread failed!')
    
    return squad_items

In [36]:
def find_squad_item_in_url(end_status, squad_items, idx, url_picker, piece_tokenizer, word_tokenizer,
                           question, question_pieces, question_tokens, answers, answer_types, answers_tokens,
                           static_entities, min_bleu, max_pieces_length, max_chars_length):
    url_item = url_picker.pick()
    if url_item is None:
        end_status.append(1)
        return
    
    page_url = url_item['url']
    page_title = url_item['title']
    
    try:
        wiki_article = get_wiki_article(page_url)
    except Exception as e:
        print(page_title, '|', page_url)
        raise e
    try:
        paragraphs = filter_paragraphs(wiki_article)
    except Exception as e:
        print(page_title, '|', page_url)
        print(wiki_article)
        raise e

    for paragraph_i, paragraph in enumerate(paragraphs):
        try:
            paragraph_pieces = piece_tokenizer(paragraph)
            paragraph_tokens = word_tokenizer(paragraph)
        except:
            # Something bad is in the paragraph
            continue

        pieces_length = len(paragraph_pieces) + len(question_pieces) # Context + question
        chars_length = sum(len(piece) for piece in paragraph_pieces) # Only context

        if pieces_length <= (max_pieces_length - 3) and chars_length <= max_chars_length:
            p_has_entities, entities_score = has_entities(paragraph, static_entities, n=1)
            if p_has_entities:
                found_answers = find_answers(paragraph, paragraph_tokens, answers, answers_tokens, min_bleu=min_bleu)
                tokens_score = has_tokens(paragraph, [token[1] for token in question_tokens])
                if len(found_answers) > 0:
                    squad_score = entities_score + tokens_score + max([len(x['text']) for x in found_answers])
                    squad_item = {
                        'score': squad_score,
                        'title': page_title,
                        'paragraphs': [{
                            'context': paragraph,
                            'qas': [{
                                'id': '%s_%d' % (idx, paragraph_i),
                                'question': question,
                                'answers': [],
                            }],
                        }],
                    }
                    for found_answer in found_answers:
                        squad_item['paragraphs'][0]['qas'][0]['answers'].append({
                            'answer_start': found_answer['answer_start'],
                            'text': found_answer['text'],
                        })
                    squad_items.append(squad_item)
    
    end_status.append(1)

In [37]:
# Manager of config.json with a simple lock file
class ConfigManager:
    def __init__(self, save_path, timeout=60, sleep_interval=1):
        self._save_path = save_path
        self._timeout = timeout
        self._sleep_interval = sleep_interval
        self._locked = False
    
    def read(self):
        self.lock()
        config_file = os.path.join(self._save_path, 'config.json')
        if os.path.exists(config_file):
            with open(config_file, 'r') as fp:
                config = json.load(fp)
        else:
            with open(config_file, 'w') as fp:
                config = {'skipped': [], 'found': []}
                json.dump(config, fp)
        self.unlock()
        return config

    def save(self, config_data):
        self.lock()
        config_file = os.path.join(self._save_path, 'config.json')
        with open(config_file, 'w') as fp:
            json.dump(config_data, fp)
        self.unlock()

    def lock(self):
        start_time = time.time()
        while self._locked:
            end_time = time.time()
            if end_time - start_time >= self._timeout:
                raise Exception('Cannot lock config file (timeout)')
            sleep(self._sleep_interval)
        self._locked = True
    
    def unlock(self):
        self._locked = False

In [38]:
class LoggerThread(threading.Thread):
    def __init__(self, n_items, config_manager, sleep_interval=5):
        super().__init__()
        self._n_items = n_items
        self._config_manager = config_manager
        self._sleep_interval = sleep_interval
        self._kill = threading.Event()
    
    def run(self):
        while True:
            config = config_manager.read()
            n_skipped = len(config['skipped'])
            n_found = len(config['found'])
            n_processed = n_skipped + n_found
            print('- Processed: %d / %d | Found: %d' % (n_processed, self._n_items, n_found), ' '*20, end='\r')
            
            is_killed = self._kill.wait(self._sleep_interval)
            if is_killed:
                break

    def kill(self):
        self._kill.set()

In [39]:
def main(end_status, save_path, item_picker, config_manager, max_threads=7,
         region_code=REGION_CODE, spacy_dict=SPACY_DICT, max_aliases=5, query_top_results=7):
    # Load initial config
    config = config_manager.read()

    # Need one tokenizer as many query_top_results (since we need one per thread)
    piece_tokenizers = [get_piece_tokenizer() for _ in range(max_threads)]
    word_tokenizer = get_word_tokenizer(SPACY_DICT)
    
    # Create subdirs
    top_items = [1, 2, 3, 5]
    for k in top_items:
        k_save_path = os.path.join(save_path, 'top_%d' % k)
        os.makedirs(k_save_path, exist_ok=True)
    
    while True:
        item = item_picker.pick()
        if item is None:
            break
        
        idx = 'mkqa_' + str(item['example_id'])
        query = item['queries'][region_code]
        
        # Skip if already parsed
        file_exists = os.path.exists(os.path.join(save_path, 'top_1', '%s.json' % idx))
        if item['example_id'] in config['found'] or item['example_id'] in config['skipped'] or file_exists:
            continue
        
        parsed_answers = []
        answer_types = []
        for raw_answer_data in item['answers'][region_code]:
            main_answer = raw_answer_data['text']
            aliases = raw_answer_data['aliases'][:max_aliases] if 'aliases' in raw_answer_data else []
            
            iter_parsed_answers = parse_raw_answer(raw_answer_data['type'], main_answer,
                                                   aliases, region_code=REGION_CODE)
            
            for iter_parsed_answer in iter_parsed_answers:
                if iter_parsed_answer not in parsed_answers:
                    parsed_answers.append(iter_parsed_answer)
                    answer_types.append(raw_answer_data['type'])
        
        if len(parsed_answers) == 0:
            # No answers for this query
            config = config_manager.read()
            config['skipped'].append(item['example_id'])
            config_manager.save(config)
            continue
        
        squad_items = find_squad_item(
                idx, query, parsed_answers, answer_types, piece_tokenizers, word_tokenizer,
                max_threads=max_threads, query_top_results=query_top_results)
        
        if len(squad_items) == 0:
            config = config_manager.read()
            config['skipped'].append(item['example_id'])
            config_manager.save(config)
            continue
        
        for k in top_items:
            output_file = os.path.join(save_path, 'top_%d' % k, '%s.json' % idx)
            with open(output_file, 'w', encoding='utf8') as fp:
                squad_dataset = {'data': get_best_squad_items(squad_items, top_items=k)}
                json.dump(squad_dataset, fp, ensure_ascii=False)
        
        config = config_manager.read()
        config['found'].append(item['example_id'])
        config_manager.save(config)
        
        # Save memory
        del squad_items
        del squad_dataset
    
    end_status.append(1)

In [40]:
max_errors = 2
n_errors = 0
n_threads = 12
n_sub_threads = 1

#######

save_path = os.path.join(OUTPUT_PATH, REGION_CODE)
os.makedirs(save_path, exist_ok=True)

print('Loading config manager...')
config_manager = ConfigManager(save_path)
config_manager.read() # Just to create the initial config file if not exists

print('Loading item picker...')
item_picker = ItemPicker(mkqa_dataset)

print('Loading logger thread...')
n_items = len(mkqa_dataset)
logger_thread = LoggerThread(n_items, config_manager)
logger_thread.start()

while True:
    try:
        all_main_threads = []
        end_status = []
        
        # Start first thread as warm-up
        x = threading.Thread(
            target=main,
            args=(end_status, save_path, item_picker, config_manager, n_sub_threads))
        x.start()
        all_main_threads.append(x)
        sleep(5)
        
        # Start rest of threads
        for _ in range(n_threads - 1):
            x = threading.Thread(
                target=main,
                args=(end_status, save_path, item_picker, config_manager, n_sub_threads))
            x.start()
            all_main_threads.append(x)
            sleep(1)

        for x in all_main_threads:
            x.join()
        
        print()
        logger_thread.kill()
        print('Killing logger thread...')
        logger_thread.join()

        if sum(end_status) != len(all_main_threads):
            raise Exception('Some thread failed!')
        break
    except requests.exceptions.ConnectionError as e:
        n_errors += 1
        if n_errors >= max_errors:
            raise e
    except Exception as e:
        print()
        print('-' * 10)
        print(e)
        raise e
        sleep(5)

print()
print('Done')

Loading config manager...
Loading item picker...
Loading logger thread...
- Processed: 9999 / 10000 | Found: 1354                     
Killing logger thread...

Done
