In [47]:
import os
import psycopg2
import psycopg2.extras

# Load .env file
pass_ = {}
for line in open('../../.pass'):
    split = line.strip().split('=')
    pass_[split[0]] = split[1]

# Connect    
connection = psycopg2.connect(dbname=pass_['DB_NAME'],
                              port=pass_['DB_PORT'],
                              user=pass_['DB_USER'],
                              host=pass_['DB_HOST'],
                              password=pass_['DB_PASS'])
cursor = connection.cursor()

In [15]:
simple_qa = {
    'train': '../../data/SimpleQuestions_v2/annotated_fb_data_train.txt',
    'dev': '../../data/SimpleQuestions_v2/annotated_fb_data_valid.txt',
    'test': '../../data/SimpleQuestions_v2/annotated_fb_data_test.txt',
}

In [None]:
RELATION_CLASSIFIER = '../../results/0626.11-05_08:38:02.relation_classifier/11m_05d_08h_45m_57s.pt'

import math
import pprint
import re
import unicodedata

from lib.checkpoint import Checkpoint

    
def remove_accents(text):
    # https://stackoverflow.com/questions/517923/what-is-the-best-way-to-remove-accents-in-a-python-unicode-string
    nkfd_form = unicodedata.normalize('NFKD', text)
    return u"".join([c for c in nkfd_form if not unicodedata.combining(c)])

def preprocess(s):
    s = s.replace("'s", '')
    s = remove_accents(s)
    s = s.split()
    s = [re.sub('^\W+', '', w) for w in s]  # Strip prefix none alphanumeric
    s = [re.sub('\W+$', '', w) for w in s]  # Strip leading none alphanumeric
    s = [w.strip().lower() for w in s]
    s = [w for w in s if len(w) != 0] # Remove Empty
    return ' '.join(s)

relation_classifier_predict = Checkpoint(checkpoint_path=RELATION_CLASSIFIER).predict

def get_relation(question, top_k=3):
    """ 
    Given a question return the predicate in the question using `RELATION_CLASSIFIER` model.
    
    Args:
        question (str)
    Returns:
        list of predicates and their confidence
    """
    question = preprocess(question)
    predicted = list(relation_classifier_predict(question, top_k=top_k))
    if top_k == 1:
            predicted[1] = sum(predicted[1]) / len(predicted[1]) 
            return [predicted]
    return [(class_, sum(confidence) / len(confidence)) for class_, confidence in predicted]

# To test this cell
print(get_relation('Where was Obama born?', 1))

In [16]:
from IPython.display import display
import pandas as pd
from sklearn.utils import shuffle

names = ['subject', 'relation', 'object', 'question']
df_train = pd.read_table(simple_qa['train'], header=None, names=names)
df_dev = pd.read_table(simple_qa['dev'], header=None, names=names)
df_train = shuffle(df_train, random_state=123)
df_dev = shuffle(df_dev, random_state=123)
display(df_train)

Unnamed: 0,subject,relation,object,question
67691,www.freebase.com/m/0fqd80r,www.freebase.com/music/artist/genre,www.freebase.com/m/01kcty,what type of musician is bruna cara considered
60980,www.freebase.com/m/01v7rw,www.freebase.com/book/written_work/subjects,www.freebase.com/m/07dg1,what is the book troilus and cressida about
44045,www.freebase.com/m/03fl5vh,www.freebase.com/music/release/track,www.freebase.com/m/0x7w6c,what is a single on the release busted
35944,www.freebase.com/m/05sq1b,www.freebase.com/common/topic/notable_types,www.freebase.com/m/01xpjyz,what is the duluth international airport
44254,www.freebase.com/m/03yzs66,www.freebase.com/music/release/track_list,www.freebase.com/m/0n2nvvl,what song is on the release welcome to haiti: ...
68214,www.freebase.com/m/057vc_,www.freebase.com/military/military_person/part...,www.freebase.com/m/0kbcg,Which conflict did john gibbon participate in?
618,www.freebase.com/m/0q8552,www.freebase.com/music/recording/tracks,www.freebase.com/m/0kzl7vw,What track is featured in the recording of cha...
52003,www.freebase.com/m/09gdbzs,www.freebase.com/travel/tourist_attraction/nea...,www.freebase.com/m/0g7yx,what travel destination is quattro fontane near
56512,www.freebase.com/m/02v_gfg,www.freebase.com/people/person/profession,www.freebase.com/m/02h665k,What was Toby Wright's profession?
47413,www.freebase.com/m/0ch8kqd,www.freebase.com/fictional_universe/fictional_...,www.freebase.com/m/03bpv,what kind of entity is the character atropos


In [17]:
def preprocess(row):
    row['subject'] = row['subject'].strip().replace('www.freebase.com/m/', '')
    row['object'] = row['object'].strip().replace('www.freebase.com/m/', '')
    question = row['question'].strip()
    row['question'] = ' '.join(tokenize(question.lower()))
    row['relation'] = row['relation'].strip().replace('www.freebase.com/', '')
    return row

df_train = df_train.apply(preprocess, axis=1)
df_dev = df_dev.apply(preprocess, axis=1)
df_train

Unnamed: 0,subject,relation,object,question
67691,0fqd80r,music/artist/genre,01kcty,what type of musician is bruna cara considered
60980,01v7rw,book/written_work/subjects,07dg1,what is the book troilus and cressida about
44045,03fl5vh,music/release/track,0x7w6c,what is a single on the release busted
35944,05sq1b,common/topic/notable_types,01xpjyz,what is the duluth international airport
44254,03yzs66,music/release/track_list,0n2nvvl,what song is on the release welcome to haiti :...
68214,057vc_,military/military_person/participated_in_confl...,0kbcg,which conflict did john gibbon participate in ?
618,0q8552,music/recording/tracks,0kzl7vw,what track is featured in the recording of cha...
52003,09gdbzs,travel/tourist_attraction/near_travel_destination,0g7yx,what travel destination is quattro fontane near
56512,02v_gfg,people/person/profession,02h665k,what was toby wright 's profession ?
47413,0ch8kqd,fictional_universe/fictional_character/rank,03bpv,what kind of entity is the character atropos


In [18]:
from numpy import nan
from nltk.tokenize.treebank import TreebankWordTokenizer

tokenize = TreebankWordTokenizer().tokenize

def add_subject_name(row):
    rows = []
    for mid in [row['subject'], row['object']]:
        sql = """SELECT alias FROM fb_name
                 WHERE mid='{mid}'""".format(mid=mid)
        cursor.execute(sql)
        rows += cursor.fetchall()
    if len(rows) == 0:
        return nan
    # Sort by largest name first
    names = sorted([row[0].strip().lower() for row in rows], key=lambda n: len(n), reverse=True)
    for name in names:
        if name in row['question']:
            return name
    return nan

df_train['subject_name'] = df_train.apply(add_subject_name, axis=1)
df_dev['subject_name'] = df_dev.apply(add_subject_name, axis=1)
df_dev

Unnamed: 0,subject,relation,object,question,subject_name
6219,03k3r,biology/organism_classification/organisms_of_t...,0bs56bp,name an american thoroughbread racehorse,horse
3364,02qlppc,cvg/computer_videogame/cvg_genre,01sjng,what kind of game is vision racing driving sim...,vision racing driving simulator
9374,02l7c8,tv/tv_genre/programs,0dlmm88,what tv program is romance film,romance film
10142,049_zj3,location/location/containedby,04rrx,what state is polaski located in,polaski
97,02w9ycr,people/deceased_person/cause_of_death,0qcr0,what disease claimed the life of fern emmett,fern emmett
8296,08k_58,music/artist/album,02rw0hc,what is tristeza 's album that starts with an s,tristeza
3932,0dp76p7,music/album/artist,06twql,"who was the artist on the album just you , jus...","just you , just me"
6960,06c6tz,ice_hockey/hockey_player/hockey_position,02qvdc,what position does max bentley play,max bentley
2168,0d0gvn,computer/computer_processor/manufacturers,03s7h,what company made intel core 2 ?,intel core 2
1893,05gg9s,cvg/computer_videogame/cvg_genre,06zm8z,what kind of game is the legend of zelda : twi...,the legend of zelda : twilight princess


In [59]:
from tqdm import tqdm_notebook
from random import sample

flatten = lambda l: [item for sublist in l for item in sublist]

def get_candidates(df):
    ret = []
    n_rows = df.shape[0]
    for index, row in tqdm_notebook(df.iterrows(), total=n_rows):
        # CHECK: Should dev have the below filter?
        if not pd.isnull(row['subject_name']):
            cursor.execute("""SELECT fb_name.alias, fb_kg.subject_mid, array_agg(fb_kg.relation)
                              FROM fb_kg
                              LEFT JOIN fb_name
                              ON subject_mid = mid
                              WHERE alias LIKE %s
                              GROUP BY fb_kg.subject_mid, fb_name.alias""", (row['subject_name'],))
            rows = cursor.fetchall()
            all_relations = flatten([relations for _, _, relations in rows])
            total_facts = float(len(all_relations))
            total_relations = float(len(set(all_relations)))
            labels = []
            feature_rows = []
            for alias, subject_mid, relations in rows:
                label = subject_mid == row['subject']
                labels.append(label)
                n_facts = len(relations) / total_facts
                n_relations = len(set(relations)) / total_relations
                feature_rows.append([n_facts, n_relations])
            assert sum(labels) <= 1, 'n true: %d' % sum(labels) # Only one True
            ret.append({
                'rows': len(rows),
                'total_facts': total_facts,
                'total_relations': total_relations,
                'feature_rows': feature_rows,
                'labels': labels,
            })
    return ret

dev_candidates = get_candidates(df_dev.head(n=500))
print('Sample Candidates:', dev_candidates[:5])
train_candidates = get_candidates(df_train.head(n=4000))


Sample Candidates: [{'rows': 9, 'total_facts': 3199.0, 'total_relations': 65.0, 'feature_rows': [[0.003751172241325414, 0.1076923076923077], [0.0012503907471084713, 0.06153846153846154], [0.9681150359487339, 0.4], [0.009377930603313536, 0.2], [0.0025007814942169426, 0.12307692307692308], [0.00031259768677711783, 0.015384615384615385], [0.0015629884338855893, 0.046153846153846156], [0.010940919037199124, 0.24615384615384617], [0.002188183807439825, 0.1076923076923077]], 'labels': [False, False, True, False, False, False, False, False, False]}, {'rows': 1, 'total_facts': 3.0, 'total_relations': 3.0, 'feature_rows': [[1.0, 1.0]], 'labels': [True]}, {'rows': 1, 'total_facts': 20198.0, 'total_relations': 15.0, 'feature_rows': [[1.0, 1.0]], 'labels': [True]}, {'rows': 1, 'total_facts': 3.0, 'total_relations': 1.0, 'feature_rows': [[1.0, 1.0]], 'labels': [True]}, {'rows': 1, 'total_facts': 6.0, 'total_relations': 6.0, 'feature_rows': [[1.0, 1.0]], 'labels': [True]}]





In [60]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

train = flatten([row['feature_rows'] for row in train_candidates])
train_labels = flatten([row['labels'] for row in train_candidates])
dev = flatten([row['feature_rows'] for row in dev_candidates])
dev_labels = flatten([row['labels'] for row in dev_candidates])
model = LogisticRegression()
model.fit(train, train_labels)
dev_pred = model.predict(dev)
print(classification_report(dev_labels, dev_pred))

             precision    recall  f1-score   support

      False       1.00      1.00      1.00     30323
       True       0.93      0.78      0.85       481

avg / total       1.00      1.00      1.00     30804



In [65]:
correct = 0.0
for row in dev_candidates:
    if not row['feature_rows']:
        continue
    confidences = model.decision_function(row['feature_rows'])
    max_confidence = float("-inf")
    max_index = 0
    for i, confidence in enumerate(confidences):
        if confidence > max_confidence:
            max_confidence = confidence
            max_index = i
    if row['labels'][i]:
        correct += 1
print('Num Correct:', correct)
print('Accuracy:', correct / len(dev_candidates))

Num Correct: 336.0
Accuracy: 0.6956521739130435
