# Entity Linking Regression for Subject Name Duplicates

During our experiments, we found that the largest bottleneck is entity linking: even given the true subject name and relation, linking the subject name to the mid is ambiguous, because 21% of rows have duplicate rows with the same subject name but different mids.

Here we investigate whether there are features we could use to determine the correct mid in the presence of duplicates.

For this exploration, we use the true subject name, provided it appears in the question.

In [164]:
from IPython.display import display
import pandas as pd
from sklearn.utils import shuffle

simple_qa = {
    'train': '../../data/SimpleQuestions_v2/annotated_fb_data_train.txt',
    'dev': '../../data/SimpleQuestions_v2/annotated_fb_data_valid.txt',
    'test': '../../data/SimpleQuestions_v2/annotated_fb_data_test.txt',
}
names = ['subject', 'relation', 'object', 'question']
df_train = pd.read_table(simple_qa['train'], header=None, names=names)
df_dev = pd.read_table(simple_qa['dev'], header=None, names=names)
# Shuffle data to make sure order does not affect our results
df_train = shuffle(df_train, random_state=123)
df_dev = shuffle(df_dev, random_state=123)
display(df_dev[:5])

Unnamed: 0,subject,relation,object,question
6219,www.freebase.com/m/03k3r,www.freebase.com/biology/organism_classificati...,www.freebase.com/m/0bs56bp,Name an American Thoroughbread racehorse
3364,www.freebase.com/m/02qlppc,www.freebase.com/cvg/computer_videogame/cvg_genre,www.freebase.com/m/01sjng,what kind of game is vision racing driving sim...
9374,www.freebase.com/m/02l7c8,www.freebase.com/tv/tv_genre/programs,www.freebase.com/m/0dlmm88,what tv program is romance film
10142,www.freebase.com/m/049_zj3,www.freebase.com/location/location/containedby,www.freebase.com/m/04rrx,what state is polaski located in
97,www.freebase.com/m/02w9ycr,www.freebase.com/people/deceased_person/cause_...,www.freebase.com/m/0qcr0,what disease claimed the life of fern emmett


## Preprocess

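Standard preprocessing: we strip the `www.freebase.com/` URL prefix from the subject, object, and relation. To standardize the questions so they can be matched against the subject name, we also lowercase them and tokenize them with the Treebank tokenizer.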

In [165]:
from nltk.tokenize.treebank import TreebankWordTokenizer

tokenize = TreebankWordTokenizer().tokenize

def preprocess(row):
    row['subject'] = row['subject'].strip().replace('www.freebase.com/m/', '')
    row['object'] = row['object'].strip().replace('www.freebase.com/m/', '')
    question = row['question'].strip()
    row['question'] = ' '.join(tokenize(question.lower()))
    row['relation'] = row['relation'].strip().replace('www.freebase.com/', '')
    return row

df_dev = df_dev.apply(preprocess, axis=1)
display(df_dev[:5])
df_train = df_train.apply(preprocess, axis=1)

Unnamed: 0,subject,relation,object,question
6219,03k3r,biology/organism_classification/organisms_of_t...,0bs56bp,name an american thoroughbread racehorse
3364,02qlppc,cvg/computer_videogame/cvg_genre,01sjng,what kind of game is vision racing driving sim...
9374,02l7c8,tv/tv_genre/programs,0dlmm88,what tv program is romance film
10142,049_zj3,location/location/containedby,04rrx,what state is polaski located in
97,02w9ycr,people/deceased_person/cause_of_death,0qcr0,what disease claimed the life of fern emmett


## Get True Subject Name

Here we query our DB for the names of the subject mid and check whether one of them appears in the question. After this check, about 4.5% of the rows have no subject name embedded in the question. The check also substantially reduces the number of names we have to deal with; typically, there is more than one candidate name per question.
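
The matching step described above can be sketched without the database. This is a minimal illustration, assuming the aliases for a mid are already in memory; `find_subject_name` and the alias list are hypothetical names, not part of the notebook. The point it shows is that candidates should be ordered by the length of the name itself, so that a longer alias wins over a shorter alias that is a substring of it.

```python
def find_subject_name(question, aliases):
    """Return the longest alias that occurs in the question, or None."""
    # Pair each lowercased alias with its original form, longest name first
    candidates = sorted(((a.strip().lower(), a) for a in aliases),
                        key=lambda pair: len(pair[0]), reverse=True)
    for name, original in candidates:
        if name in question:
            return original  # keep the original casing for later lookups
    return None

# Hypothetical aliases for a single mid
aliases = ['New York', 'New York City', 'NYC']
print(find_subject_name('what state is new york city in', aliases))  # New York City
```

Sorting on `len(pair)` instead of `len(pair[0])` would compare tuple lengths, which are all 2, and leave the order arbitrary with respect to name length.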

In [166]:
import os
import psycopg2
import psycopg2.extras

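# Load credentials from the .pass file (one KEY=VALUE pair per line)
pass_ = {}
with open('../../.pass') as pass_file:
    for line in pass_file:
        # Split once so that values may themselves contain '='
        key, value = line.strip().split('=', 1)
        pass_[key] = value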

# Connect    
connection = psycopg2.connect(dbname=pass_['DB_NAME'],
                              port=pass_['DB_PORT'],
                              user=pass_['DB_USER'],
                              host=pass_['DB_HOST'],
                              password=pass_['DB_PASS'])
cursor = connection.cursor()

In [168]:
from numpy import nan

def add_subject_name(row):
    # Parameterized query instead of building the SQL string with str.format
    cursor.execute("SELECT alias FROM fb_name WHERE mid = %s", (row['subject'],))
    rows = cursor.fetchall()
    if len(rows) == 0:
        return nan
    # Sort by longest name first; key on the name itself, because len() of the
    # (name, original_name) tuple is always 2
    names = sorted([(r[0].strip().lower(), r[0]) for r in rows],
                   key=lambda n: len(n[0]), reverse=True)
    for name, original_name in names:
        if name in row['question']:
            return original_name  # Return the original name so we can find it again
    return nan

df_dev['subject_name'] = df_dev.apply(add_subject_name, axis=1)
display(df_dev[:5])
df_train['subject_name'] = df_train.apply(add_subject_name, axis=1)

Unnamed: 0,subject,relation,object,question,subject_name
6219,03k3r,biology/organism_classification/organisms_of_t...,0bs56bp,name an american thoroughbread racehorse,horse
3364,02qlppc,cvg/computer_videogame/cvg_genre,01sjng,what kind of game is vision racing driving sim...,vision racing driving simulator
9374,02l7c8,tv/tv_genre/programs,0dlmm88,what tv program is romance film,romance
10142,049_zj3,location/location/containedby,04rrx,what state is polaski located in,polaski
97,02w9ycr,people/deceased_person/cause_of_death,0qcr0,what disease claimed the life of fern emmett,fern emmet


In [169]:
# Expect around 4.5% of the subject_name to not be present as consistent with past experiments
print('% missing subject_name:', 1 - (df_train.count()['subject_name'] / df_train.shape[0]))

% missing subject_name: 0.0431563693848


## Load Relation Model

Here we use an already trained relation model as one of our features. This allows us to experiment with more of the end-to-end parts.

In [170]:
import sys
sys.path.insert(0, '../../')

import math
import pprint
import re
import unicodedata

from lib.checkpoint import Checkpoint


RELATION_CLASSIFIER = '../../results/0626.11-05_08:38:02.relation_classifier/11m_05d_08h_45m_57s.pt'
    
def remove_accents(text):
    # https://stackoverflow.com/questions/517923/what-is-the-best-way-to-remove-accents-in-a-python-unicode-string
    nkfd_form = unicodedata.normalize('NFKD', text)
    return u"".join([c for c in nkfd_form if not unicodedata.combining(c)])

def preprocess_question(s):
    s = s.replace("'s", '|')
    s = remove_accents(s)
    s = s.split()
    s = [re.sub(r'^\W+', '', w) for w in s]  # Strip leading non-alphanumeric characters
    s = [re.sub(r'\W+$', '', w) for w in s]  # Strip trailing non-alphanumeric characters
    s = [w.strip().lower() for w in s]
    s = [w for w in s if len(w) != 0] # Remove Empty
    return ' '.join(s)

relation_classifier_predict = Checkpoint(checkpoint_path=RELATION_CLASSIFIER).predict

def get_relation(question, top_k=3):
    """ 
    Given a question return the predicate in the question using `RELATION_CLASSIFIER` model.
    
    Args:
        question (str)
    Returns:
        list of predicates and their confidence
    """
    question = preprocess_question(question)
    predicted = list(relation_classifier_predict(question, top_k=top_k))
    if top_k == 1:
        predicted[1] = sum(predicted[1]) / len(predicted[1])
        return [predicted]
    return [(class_, sum(confidence) / len(confidence)) for class_, confidence in predicted]

# To test this cell
print(get_relation('Where was Obama born?', 10))



[('www.freebase.com/people/person/place_of_birth', 1.0), ('www.freebase.com/people/person/nationality', 1.8365256779269813e-08), ('www.freebase.com/location/location/people_born_here', 5.0509226076001125e-09), ('www.freebase.com/music/album/album_content_type', 1.8572093479922836e-09), ('www.freebase.com/music/album/genre', 1.0259027594091375e-09), ('www.freebase.com/fictional_universe/fictional_character/place_of_birth', 4.3435644844782266e-10), ('www.freebase.com/people/deceased_person/place_of_death', 3.6899369947368274e-10), ('www.freebase.com/medicine/drug_formulation/drug_category', 1.7813512957546355e-10), ('www.freebase.com/media_common/netflix_title/netflix_genres', 8.754611452654275e-11), ('www.freebase.com/people/profession/people_with_this_profession', 6.442428900204096e-11)]


## Experiments

Here we run and report experiments on linking the subject name to the mid. The experiments assume that there are features that can rerank the candidate mids accurately.

Some of the experiments to run are:
- Sanity check that the number of rows without duplicates is around 56%, as shown in other notebooks
- Set up a baseline with random guessing; it should score above 56%.
- Hypotheses to test:
    - The number of facts associated with the mid is indicative
    - The likelihood of the set of relations associated with the mid is indicative
    - The type of the mid is indicative
    - The relation is indicative of the mid
    - The number of unique relations associated with the mid is indicative
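
The random-guessing baseline from the list above can be sketched as follows. This is a minimal sketch, assuming the true mid is always among the candidates and that a guess picks uniformly among them; `random_guess_accuracy` and the counts are hypothetical. It also shows why the baseline cannot fall below the fraction of rows with a single candidate: those rows are always answered correctly, and every other row still contributes `1 / n`.

```python
def random_guess_accuracy(candidate_counts):
    """Expected accuracy when picking uniformly among each row's candidates."""
    return sum(1.0 / n for n in candidate_counts) / len(candidate_counts)

# Hypothetical number of candidate mids per row
counts = [1, 1, 2, 4]
single_fraction = sum(n == 1 for n in counts) / len(counts)

print(single_fraction)                # 0.5
print(random_guess_accuracy(counts))  # 0.6875
```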

In [171]:
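# Get samples to speed up experiments
# .copy() so that adding columns to the samples later does not write to a slice of the original frame
df_dev_sample = df_dev[df_dev.subject_name.notnull()].head(n=1000).copy()
df_train_sample = df_train[df_train.subject_name.notnull()].head(n=5000).copy()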
print('df_dev sample length:', len(df_dev_sample))
print('df_train sample length:', len(df_train_sample))

df_dev sample length: 1000
df_train sample length: 5000


In [172]:
# Add all possible mids to dedup
def add_candidate_mids(row):
    cursor.execute("""SELECT fb_name.alias, fb_kg.subject_mid, array_agg(fb_kg.relation)
                  FROM fb_kg
                  LEFT JOIN fb_name
                  ON subject_mid = mid
                  WHERE alias = %s
                  GROUP BY fb_kg.subject_mid, fb_name.alias""", (row['subject_name'],))
    return cursor.fetchall()

df_dev_sample['candidate_mids'] = df_dev_sample.apply(add_candidate_mids, axis=1)
display(df_dev_sample[:5])
df_train_sample['candidate_mids'] = df_train_sample.apply(add_candidate_mids, axis=1)

Unnamed: 0,subject,relation,object,question,subject_name,candidate_mids
6219,03k3r,biology/organism_classification/organisms_of_t...,0bs56bp,name an american thoroughbread racehorse,horse,"[(horse, 01qd1g9, [music/group_member/instrume..."
3364,02qlppc,cvg/computer_videogame/cvg_genre,01sjng,what kind of game is vision racing driving sim...,vision racing driving simulator,"[(vision racing driving simulator, 02qlppc, [c..."
9374,02l7c8,tv/tv_genre/programs,0dlmm88,what tv program is romance film,romance,"[(romance, 0bdttnp, [tv/tv_genre/programs, tv/..."
10142,049_zj3,location/location/containedby,04rrx,what state is polaski located in,polaski,"[(polaski, 049_zj3, [location/location/contain..."
97,02w9ycr,people/deceased_person/cause_of_death,0qcr0,what disease claimed the life of fern emmett,fern emmet,"[(fern emmet, 02w9ycr, [people/person/national..."


In [173]:
zero_mid_rows = [len(row['candidate_mids']) == 0 for index, row in df_train_sample.iterrows()]
# Expected to be zero if subject_name was refound in the fb_name database
print('number of rows with no mid:', sum(zero_mid_rows))

# Print if not zero for df_train
for index, row in df_train_sample.iterrows():
    if len(row['candidate_mids']) == 0:
        print(row)

# Remove from sample
filter_train = [len(row['candidate_mids']) != 0 for index, row in df_train_sample.iterrows()]
df_train_sample = df_train_sample[filter_train]
filter_dev = [len(row['candidate_mids']) != 0 for index, row in df_dev_sample.iterrows()]
df_dev_sample = df_dev_sample[filter_dev]
print('df_dev sample length:', len(df_dev_sample))
print('df_train sample length:', len(df_train_sample))

number of rows with no mid: 2
subject                                  0fy0qfr
relation          music/album/album_content_type
object                                    02jbfk
question            what type of album is 青之绘卷 ?
subject_name                                青之绘卷
candidate_mids                                []
Name: 61060, dtype: object
subject                                   0zj9_cn
relation               music/composition/composer
object                                      0ftqr
question          who is the composer of 仲間を求めて ?
subject_name                               仲間を求めて
candidate_mids                                 []
Name: 62805, dtype: object
df_dev sample length: 1000
df_train sample length: 4998


In [175]:
# Expect around (1 / (1 - .0431 - 0.0004)) * .567 == .5928 (59.28%) of rows to contain no duplicates
# The equation corrects for the bias caused by removing the ~4% of rows that had no subject_name in the question
# and the bias caused by removing the 2 of 5000 rows that had no candidate mids
n_single_rows = sum(len(row['candidate_mids']) == 1 for index, row in df_train_sample.iterrows())
print('% rows without duplicate mid names:', n_single_rows / df_train_sample.shape[0])

% rows without duplicate mid names: 0.5846338535414166
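
The expected value quoted in the comment of the cell above can be reproduced with plain arithmetic. This sketch only renormalizes the 56.7% figure over the rows that remain after filtering; it assumes, as the comment appears to, that none of the removed rows counted toward that figure. The variable names are illustrative.

```python
no_duplicate_rate_all_rows = 0.567  # figure quoted in the comment above
missing_subject_name = 0.0431       # rows with no subject_name in the question
no_candidate_mids = 0.0004          # 2 of 5000 rows

# Fraction of rows that survive both filters
remaining = 1 - missing_subject_name - no_candidate_mids
expected = no_duplicate_rate_all_rows / remaining

print(round(expected, 4))  # 0.5928
```

The observed value of 0.5846 is close to, but slightly below, this estimate.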


In [176]:
baseline_expected_accuracy = sum(1.0 / len(row['candidate_mids']) for index, row in df_dev_sample.iterrows()) / len(df_dev_sample)
print('% baseline expected from random guessing', baseline_expected_accuracy)

% baseline expected from random guessing 0.6686948211333618


In [190]:
flatten = lambda l: [item for sublist in l for item in sublist]

def add_features(row):
    all_relations = flatten([relations for _, _, relations in row['candidate_mids']])
    total_facts = float(len(all_relations))
    total_relations = float(len(set(all_relations)))
    candidate_labels = []
    candidate_features = []
    for alias, subject_mid, relations in row['candidate_mids']:
        percentage_facts = len(relations) / total_facts
        percentage_relations = len(set(relations)) / total_relations
        candidate_labels.append(subject_mid == row['subject'])
        features = [percentage_facts, percentage_relations, len(set(relations)), len(relations)]
        candidate_features.append(features)
    row['candidate_features'] = candidate_features
    row['candidate_labels'] = candidate_labels
    if sum(candidate_labels) == 0:
        # Note: for this row the true mid is not present in the candidate list
        # This should be fixed
        print(row)
    return row
    
df_train_sample = df_train_sample.apply(add_features, axis=1)
df_dev_sample = df_dev_sample.apply(add_features, axis=1)

subject                                                         01gd34t
relation                                       music/album/release_type
object                                                           02lx2r
question                                           how was 電車男 released
subject_name                                                        電車男
candidate_mids        [(電車男, 03gy3p6, [film/film/country, film/film/...
candidate_features                                   [[1.0, 1.0, 3, 3]]
candidate_labels                                                [False]
Name: 6739, dtype: object
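
To make the four features concrete, here is a small standalone sketch of the same computation on toy data. The mids and relation lists are hypothetical, and `candidate_features` is an illustrative helper rather than the notebook's `add_features`; it returns, per candidate, the share of all facts, the share of all unique relations, the number of unique relations, and the number of facts.

```python
def candidate_features(candidates):
    """candidates: list of (mid, relations) pairs that share one subject name."""
    all_relations = [r for _, relations in candidates for r in relations]
    total_facts = len(all_relations)
    total_relations = len(set(all_relations))
    features = {}
    for mid, relations in candidates:
        unique = set(relations)
        features[mid] = [len(relations) / total_facts,   # share of all facts
                         len(unique) / total_relations,  # share of unique relations
                         len(unique),                    # number of unique relations
                         len(relations)]                 # number of facts
    return features

# Hypothetical candidates for one subject name
candidates = [
    ('mid_a', ['film/film/country', 'film/film/genre', 'film/film/genre']),
    ('mid_b', ['music/album/genre', 'music/album/artist']),
]
features = candidate_features(candidates)
print(features['mid_a'])  # [0.6, 0.5, 2, 3]
print(features['mid_b'])  # [0.4, 0.5, 2, 2]
```

A row with a single candidate always gets `[1.0, 1.0, ...]` for the first two features, as in the printed row above.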


In [191]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

train = flatten([row['candidate_features'] for _, row in df_train_sample.iterrows()])
train_labels = flatten([row['candidate_labels'] for _, row in df_train_sample.iterrows()])
dev = flatten([row['candidate_features'] for _, row in df_dev_sample.iterrows()])
dev_labels = flatten([row['candidate_labels'] for _, row in df_dev_sample.iterrows()])
model = LogisticRegression()
model.fit(train, train_labels)
dev_pred = model.predict(dev)
print(classification_report(dev_labels, dev_pred))

             precision    recall  f1-score   support

      False       1.00      1.00      1.00     54490
       True       0.93      0.77      0.84      1000

avg / total       0.99      0.99      0.99     55490
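
The report above scores every candidate independently, so it does not directly give linking accuracy. For linking, exactly one candidate is chosen per question: the one with the highest score. A minimal sketch of that per-question reranking, assuming one score and one boolean label per candidate (`rerank_accuracy` and the scores are hypothetical):

```python
def rerank_accuracy(rows):
    """rows: list of (scores, labels) with one score and one label per candidate."""
    correct = 0
    for scores, labels in rows:
        # Index of the highest-scoring candidate for this question
        best = max(range(len(scores)), key=lambda i: scores[i])
        if labels[best]:
            correct += 1
    return correct / len(rows)

# Hypothetical decision scores for three questions
rows = [
    ([-2.0, 1.5, -0.3], [False, True, False]),  # top candidate is the true mid
    ([0.4, 0.9], [True, False]),                # top candidate is a wrong mid
    ([-1.2], [True]),                           # single candidate, chosen despite a negative score
]
print(rerank_accuracy(rows))
```

The third row shows where the two views differ: a thresholded classifier would label the only candidate `False`, while reranking still selects it.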



In [194]:
correct = 0.0
for _, row in df_dev_sample.iterrows():
    confidences = model.decision_function(row['candidate_features'])
    max_confidence = float("-inf")
    max_index = 0
    for i, confidence in enumerate(confidences):
        if confidence > max_confidence:
            max_confidence = confidence
            max_index = i
    # Check the label of the highest-confidence candidate, not the last one iterated
    if row['candidate_labels'][max_index]:
        correct += 1
print('Num Correct:', correct)
print('Accuracy:', correct / len(df_dev_sample))


## Numbers

- Random Choice Baseline: 0.6687 expected accuracy on the dev sample
- 
