# Simple QA Subject Recognition IO

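The goal of this notebook is to preprocess a file of questions so that every question token is tagged as inside (`I`) or outside (`O`) the subject name, in the following format: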

**Subject Name:** angels vengeance

**Result:** what\O language\O is\O angels\I vengeance\I in\O ?\O

In [2]:
import sys
# Put the directory two levels up first on the import path. The next cell
# imports the `scripts` package, which presumably lives there -- TODO confirm.
sys.path.insert(0, '../../')

In [3]:
from scripts.utils.simple_qa import load_simple_qa 

# The loader's result is unpacked as a one-element sequence when only the
# training split is requested; the trailing comma does the unpacking.
df, = load_simple_qa(train=True)
# Display the first five rows as the cell output.
df[:5]

Unnamed: 0,subject,relation,object,question
0,04whkz5,book/written_work/subjects,01cj3p,what is the book e about
1,0tp2p24,music/release_track/release,0sjc7c1,to what release does the release track cardiac...
2,04j0t75,film/film/country,07ssc,what country was the film the debt from
3,0ftqr,music/producer/tracks_produced,0p600l,what songs have nobuo uematsu produced?
4,036p007,music/release/producers,0677ng,Who produced eve-olution?


In [32]:
import os

import pandas as pd

from nltk.stem.snowball import SnowballStemmer

# Simple QA Dataset statistics
# Num Test Rows: 21687
# Num Dev Rows: 10845
# Num Train Rows: 75910
# Total Rows: 108442

# Column names of the TSV that the rest of the notebook refers to.
OBJECT_COLUMN = 'Wikidata Entity Name'
QUESTION_COLUMN = 'Question'
MASK_COLUMN = 'Object Mask'

# Path of the TSV to annotate. Set the SIMPLE_QA_SOURCE environment variable to
# point at another copy; the default keeps the original author's local path.
SOURCE = os.environ.get('SIMPLE_QA_SOURCE', '/Users/petrochuk/data/siri/simple_qa/train.tsv')
# NOTE: the annotated table is written back to DEST by a later cell, so with
# DEST == SOURCE the input file is overwritten in place.
DEST = SOURCE

# Force string dtype for the property id and the mask column so pandas does not
# try to infer a numeric type for them.
data = pd.read_csv(SOURCE, sep='\t', dtype={"Wikidata Property PID": str, MASK_COLUMN: str})
print('Number of Rows:', len(data))
print('Number Object Rows:', data[OBJECT_COLUMN].notnull().sum())
data.head()

Number of Rows: 134847
Number Object Rows: 93657


Unnamed: 0,Accuracy,Question,Wikidata Entity Name,Wikidata Property PID,Object Mask
0,1.0,What's the population of United States of America,,P1082,
1,1.0,What is the population of San Diego,,P1082,
2,1.0,What's the population of Wyoming,,P1082,
3,1.0,What is the population of New Mexico,,P1082,
4,1.0,What is the population of Chicago,,P1082,


In [None]:
import re
import unicodedata
import csv
from numpy import nan

from nltk.stem.snowball import SnowballStemmer

# Shared English Snowball stemmer; get_object_mask uses it to normalise tokens.
stemmer = SnowballStemmer("english")
# Number of rows for which no mask could be aligned; incremented by
# get_object_mask through a `global` statement.
count_failed = 0
    
def remove_accents(text):
    """Return ``text`` without combining marks.

    The string is first decomposed with NFKD normalisation, then every
    character that ``unicodedata.combining`` reports as a combining mark is
    dropped. Approach taken from
    https://stackoverflow.com/questions/517923/what-is-the-best-way-to-remove-accents-in-a-python-unicode-string
    """
    decomposed = unicodedata.normalize('NFKD', text)
    kept = (ch for ch in decomposed if unicodedata.combining(ch) == 0)
    return u"".join(kept)
    
def get_object_mask(row):
    """Align the object name with the question and return a per-token mask.

    Both strings are normalised the same way: ``-GB`` / ``-US`` suffixes and
    possessive ``'s`` are removed, accents are stripped, the text is split on
    whitespace, non-word characters are trimmed from both ends of every token,
    tokens are lower-cased, empty tokens are dropped and the rest are stemmed.
    Articles are additionally stripped from both ends of the object. The first
    position where the object tokens occur contiguously in the question tokens
    is marked.

    Args:
        row: Row indexable by ``QUESTION_COLUMN`` and ``OBJECT_COLUMN``
            (for example a pandas Series passed by ``DataFrame.apply``).

    Returns:
        A space-separated string with one flag per *normalised* question
        token: ``'e'`` for tokens covered by the object and ``'c'`` for all
        others. When no alignment exists (including when either cell is not
        a string, e.g. NaN) the failure is printed, the global
        ``count_failed`` is incremented and ``nan`` is returned.

    Raises:
        ValueError: If the resulting mask string is longer than 1000 characters.
    """
    global count_failed
    raw_question = row[QUESTION_COLUMN]
    raw_object = row[OBJECT_COLUMN]

    def preprocess(s):
        # pandas represents empty cells as float NaN; treat anything that is
        # not a string as having no tokens so it is reported as a failure
        # instead of raising AttributeError.
        if not isinstance(s, str):
            return []
        s = s.replace('-GB', '')  # Weird Freebase suffixes
        s = s.replace('-US', '')  # Weird Freebase suffixes
        s = s.replace("'s", '')
        s = remove_accents(s)
        words = s.split()
        # Raw strings: '\W' in a plain literal is an invalid escape sequence.
        words = [re.sub(r'^\W+', '', w) for w in words]  # Strip leading non-alphanumerics
        words = [re.sub(r'\W+$', '', w) for w in words]  # Strip trailing non-alphanumerics
        words = [w.strip().lower() for w in words]
        # Drop tokens that became empty, then stem what is left.
        return [stemmer.stem(w) for w in words if len(w) != 0]

    question = preprocess(raw_question)
    object_ = strip_articles(preprocess(raw_object))

    width = len(object_)
    if width:  # An empty object can never be aligned
        # Only start positions that leave room for the whole object can match.
        for start in range(len(question) - width + 1):
            if question[start:start + width] == object_:
                trailing = len(question) - start - width
                flags = ['c'] * start + ['e'] * width + ['c'] * trailing
                ret = ' '.join(flags)
                if len(ret) > 1000:
                    raise ValueError(
                        'Object mask is %d characters long (limit is 1000)' % len(ret))
                return ret

    print('FAILED')
    print('Question:', question, raw_question)
    print('Object:', object_, raw_object)
    print()
    count_failed += 1
    return nan
    
def strip_articles(text):
    """Drop articles ('the', 'a', 'an') from both ends of a token sequence.

    ``text`` is a sequence of tokens, not a raw string. Articles in the
    interior of the sequence are kept; a sequence consisting only of articles
    becomes empty.
    """
    articles = {'the', 'a', 'an'}
    first, last = 0, len(text)
    # Shrink the window from the right, then from the left.
    while last > first and text[last - 1] in articles:
        last -= 1
    while first < last and text[first] in articles:
        first += 1
    return text[first:last]
    
# Compute a mask only for rows that have an object name. The assignment aligns
# on the index, so rows without an object name end up with NaN in MASK_COLUMN.
data[MASK_COLUMN] = (
    data.loc[data[OBJECT_COLUMN].notna()]
    .apply(get_object_mask, axis=1)
)
print('Failed:', count_failed)
print('Number of Rows:', len(data))
# Persist the annotated table as a tab-separated file without the index column.
data.to_csv(path_or_buf=DEST, sep='\t', index=False)
data.head()