In [1]:
import pickle
import os
import re
import time
from difflib import SequenceMatcher

import gensim
import nltk
# nltk.download("punkt")
import numpy as np
import pandas as pd
import spacy
from nltk.util import ngrams

In [2]:
# The model has to be installed once beforehand:
#!python3 -m spacy download en_core_web_lg
SPACY_MODEL_NAME = 'en_core_web_lg'

# Shared spaCy pipeline; the sentence-splitting helpers below call it as `nlp(text)`.
nlp = spacy.load(SPACY_MODEL_NAME)

In [3]:
def _read_pickle(pickle_path):
    """Return the object stored in the pickle file at ``pickle_path``."""
    # NOTE(review): pickle.load can execute arbitrary code - only use on trusted files.
    with open(pickle_path, "rb") as pickle_file:
        return pickle.load(pickle_file)


# Freebase MID -> entity labels (the code below reads the first label via `[0]`).
mid2entity = _read_pickle("../data/mid2ent.pkl")

# MID->DAWT wikipedia format: each value is a list of entries whose last element is
# an annotation that the code below indexes as [0] (id), [1] (label) and [-1] (type).
mid2dawt = _read_pickle("../data/fb_id2sentences_idstr_label_type.pickle")
print("Length: ", len(mid2dawt))

Length:  56658


In [4]:
def load_annotated_file(path):
    """Load an annotated SimpleQuestions file and find the samples whose answer is retrievable.

    The file is read as tab-separated text; its first four columns are taken as
    subject, relation, object and question. The ``www.freebase.com/`` prefixes are
    stripped from the first three.

    A sample is "answerable" when its subject MID is a key of the module-level
    ``mid2dawt`` and either

    * one of the subject's entries is annotated with the object MID
      (``entry[-1][0]``), or
    * the first label of the object in the module-level ``mid2entity`` occurs as a
      substring of the lowercased text of at least one entry.

    Summary counts are printed as a side effect.

    Args:
        path: Path of the tab-separated annotated file.

    Returns:
        Tuple ``(sbj_mid_, rel, question, obj_mid_, samples_with_answer_existing)``:
        the four columns as lists plus the list of row indices of answerable samples.
    """
    df = pd.read_csv(path, sep="\t", usecols=[0, 1, 2, 3],
                     names=["sbj", "relation", "obj", "question"])
    # regex=False: the prefixes are literal text. Without it the result depends on the
    # pandas version (older releases treat the pattern as a regex, so "." is a wildcard).
    sbj_mid_ = df["sbj"].str.replace("www.freebase.com/m/", "", regex=False).to_list()
    rel = df["relation"].str.replace("www.freebase.com/", "", regex=False).to_list()
    obj_mid_ = df["obj"].str.replace("www.freebase.com/m/", "", regex=False).to_list()
    question = df["question"].to_list()

    number_of_samples = 0
    samples_with_answer_existing = []
    samples_with_answer_existing_occ = []  # how many sentences contain the answer entity
    for k, sbj in enumerate(sbj_mid_):
        if sbj not in mid2dawt:
            continue
        number_of_samples += 1

        answer_mid = obj_mid_[k]
        entries = mid2dawt[sbj]
        annotated_ids = [entry[-1][0] for entry in entries]
        occurrences = annotated_ids.count(answer_mid)
        if occurrences:
            samples_with_answer_existing.append(k)
            samples_with_answer_existing_occ.append(occurrences)
            continue

        # Fallback: look for the answer's surface label in the raw text.
        if answer_mid not in mid2entity or len(mid2entity[answer_mid]) == 0:
            continue
        answer_label = mid2entity[answer_mid][0]
        # Test every entry separately: gluing all entries into one string (as was done
        # before) lets a label match across the boundary between two unrelated entries.
        if any(answer_label in " ".join(entry[:-1]).lower() for entry in entries):
            samples_with_answer_existing.append(k)
            # The answer MID is not among the annotated ids on this path, so the count is 0.
            samples_with_answer_existing_occ.append(occurrences)

    print("number_of_samples: ", number_of_samples)
    print("samples_with_answer_existing: ", len(samples_with_answer_existing))

    c = sum(1 for occ in samples_with_answer_existing_occ if occ > 1)
    print("More than one occurences of answer: ", c)

    return sbj_mid_, rel, question, obj_mid_, samples_with_answer_existing


In [5]:
def calculate_centroid(sent):
    """Return the mean embedding of the tokens of ``sent`` known to ``w2v_model``.

    Tokens that are not in the module-level ``w2v_model`` are ignored. When no token
    is known, a 300-element vector filled with 999999 is returned instead of a mean
    (presumably 300 matches the embedding dimensionality - TODO confirm).

    Args:
        sent: Iterable of tokens.

    Returns:
        The centroid vector, or the constant fallback vector described above.
    """
    known_vectors = [w2v_model.get_vector(token) for token in sent if token in w2v_model]

    if not known_vectors:
        return np.ones(300) * 999999

    # sum() starts at 0 and adds the vectors one by one, in token order.
    return sum(known_vectors) / len(known_vectors)


def get_scores(corpus, relation):
    """Score every sentence of ``corpus`` against ``relation`` by centroid cosine similarity.

    Args:
        corpus: Iterable of tokenized sentences.
        relation: Tokenized relation text.

    Returns:
        The result of ``w2v_model.cosine_similarities`` between the relation centroid
        and the list of sentence centroids (one score per sentence).
    """
    sentence_centroids = [calculate_centroid(sentence) for sentence in corpus]
    query_centroid = calculate_centroid(relation)
    return w2v_model.cosine_similarities(query_centroid, sentence_centroids)


def preprocess(string):
    """Tokenize ``string`` with ``nltk.word_tokenize`` and return the tokens."""
    return nltk.word_tokenize(string)


def lowercase_sentences(corpus):
    """Return a copy of ``corpus`` (a list of token lists) with every token lowercased."""
    lowered_corpus = []
    for sentence in corpus:
        lowered_corpus.append(list(map(lambda token: token.lower(), sentence)))
    return lowered_corpus


def uniquify_list(l):
    """Remove duplicate rows from a list of rows, keeping first-occurrence order.

    The previous implementation went through ``set`` (twice), so the order of the
    result depended on hash iteration order and - for rows of strings - changed from
    one interpreter run to the next. ``dict.fromkeys`` removes duplicates while
    keeping insertion order, which makes the output reproducible.

    Args:
        l: Iterable of rows; every row must be an iterable of hashable items.

    Returns:
        A list of lists with one entry per distinct row, in order of first appearance.
    """
    return [list(row) for row in dict.fromkeys(tuple(row) for row in l)]


def sentences_remove_fbid(corpus):
    """Return every entry of ``corpus`` as a list without its last element (the annotation)."""
    stripped = []
    for entry in corpus:
        stripped.append(list(entry[:-1]))
    return stripped


def remove_sentences_with_multiple_occurences(corpus, obj_mid_):
    """Deduplicate the texts of ``corpus`` and pick one annotation per distinct text.

    Every entry of ``corpus`` is a sequence of tokens followed by an annotation whose
    first element is a Freebase id. When the same token sequence occurs several times,
    the annotation whose id equals ``obj_mid_`` (the answer entity) is kept if there
    is one (the last such annotation wins); otherwise the first annotation
    encountered is kept.

    Args:
        corpus: Iterable of entries ``[token, ..., annotation]``.
        obj_mid_: Freebase id of the answer entity.

    Returns:
        Tuple ``(corpus_new, indx2fb_id)`` where ``corpus_new`` is the list of distinct
        token lists and ``indx2fb_id`` maps each position in ``corpus_new`` to the
        annotation kept for it.
    """
    corpus_new = uniquify_list(sentences_remove_fbid(corpus))

    # Token sequence -> position in corpus_new. Replaces a linear list.index() scan per
    # entry and also works when the corpus rows are tuples rather than lists.
    position_of = {tuple(sent): i for i, sent in enumerate(corpus_new)}

    # key-> position in the new corpus value-> the annotation that we will keep
    # (None means "nothing assigned yet").
    indx2fb_id = dict.fromkeys(range(len(corpus_new)))

    for sent_n_id in corpus:
        id_label_type = sent_n_id[-1]
        fb_id = id_label_type[0]  # when using the extended version
        indx = position_of[tuple(sent_n_id[:-1])]
        if indx2fb_id[indx] is None or fb_id == obj_mid_:
            indx2fb_id[indx] = id_label_type

    return corpus_new, indx2fb_id


def split_unpunctuated_text(corpus_new, indx2fb_id):
    """Split every text of ``corpus_new`` into sentences with spaCy and re-attach annotations.

    Each text is joined with spaces, parsed with the module-level ``nlp`` pipeline and
    split along ``doc.sents``. Every resulting sentence is lowercased and tokenized
    with ``preprocess``. A sentence whose original text contains the annotation label
    (``annotation[1]``) keeps the annotation of its source text; any other sentence
    gets the dummy annotation ``[0, 0, 0]``.

    Args:
        corpus_new: List of token lists.
        indx2fb_id: Mapping from position in ``corpus_new`` to the annotation of that text.

    Returns:
        Tuple ``(sentences, annotations)``: the list of tokenized sentences and a dict
        mapping each sentence position to its annotation.
    """
    sentences_out = []
    annotation_by_position = {}
    for corpus_pos, id_label_type in indx2fb_id.items():
        parsed = nlp(" ".join(corpus_new[corpus_pos]))
        for span in parsed.sents:
            mentions_label = id_label_type[1] in span.text
            tokens = preprocess(span.text.lower())
            # The next free position is always the current length of the output list.
            annotation_by_position[len(sentences_out)] = id_label_type if mentions_label else [0, 0, 0]
            sentences_out.append(tokens)

    return sentences_out, annotation_by_position


def split_unpunctuated_text_add_placeholders(corpus_new, indx2fb_id, pattern):
    """Split texts into sentences, replace entity labels by placeholders, re-attach annotations.

    Each text of ``corpus_new`` is joined with spaces, parsed with the module-level
    ``nlp`` pipeline and split along ``doc.sents``. In every sentence each key of
    ``pattern`` that occurs as a whole word (``\\b...\\b``, case-sensitive) is replaced
    by its value, in the iteration order of ``pattern``; the result is lowercased and
    tokenized with ``preprocess``. A sentence whose original text (before replacement)
    contains the annotation label (``annotation[1]``) keeps the annotation of its
    source text; any other sentence gets the dummy annotation ``[0, 0, 0]``.

    Args:
        corpus_new: List of token lists.
        indx2fb_id: Mapping from position in ``corpus_new`` to the annotation of that text.
        pattern: Mapping from entity label to placeholder string.

    Returns:
        Tuple ``(sentences, annotations)``: the list of tokenized sentences and a dict
        mapping each sentence position to its annotation.
    """
    # Build the regexes once instead of once per sentence and label.
    compiled_pattern = [
        (re.compile(r"\b{}\b".format(re.escape(label))), placeholder)
        for label, placeholder in pattern.items()
    ]

    corpus_new_sentences = []
    indx2fb_id_newsent = dict()
    for key, value in indx2fb_id.items():
        doc = nlp(" ".join(corpus_new[key]))
        for sent in doc.sents:
            # Decide on the raw sentence text, before labels are replaced.
            contains_label = value[1] in sent.text

            text_with_placeholders = sent.text
            for label_regex, placeholder in compiled_pattern:
                text_with_placeholders = label_regex.sub(placeholder, text_with_placeholders)

            sent_tokenized_lowered = preprocess(text_with_placeholders.lower())
            indx2fb_id_newsent[len(corpus_new_sentences)] = value if contains_label else [0, 0, 0]
            corpus_new_sentences.append(sent_tokenized_lowered)

    return corpus_new_sentences, indx2fb_id_newsent


In [6]:
def create_replacement(sorted_idlbtype):
    """Map each annotation's label (index 1) to its type (index 2) followed by ``"_"``.

    Keys appear in the order of ``sorted_idlbtype``; when a label occurs more than
    once, the type of its last occurrence is the one that is kept.
    """
    return {annotation[1]: annotation[2] + "_" for annotation in sorted_idlbtype}


def create_ngrams(text):
    """Return every word n-gram of ``text`` as a space-joined string.

    The text is tokenized with ``nltk.word_tokenize``; n-grams are listed by
    increasing size, from 1 up to the number of tokens.
    """
    tokens = nltk.word_tokenize(text)  # tokens = text.split()
    return [
        " ".join(gram)
        for size in range(1, len(tokens) + 1)
        for gram in ngrams(tokens, size)
    ]


def preprocess_rtrn_string(string):
    """Tokenize ``string`` with ``nltk.word_tokenize`` and join the tokens with single spaces."""
    return " ".join(nltk.word_tokenize(string))


def placeholder_to_question(sbj_mid_y, questions):
    """Replace the subject mention of every question by an entity-type placeholder.

    For each question the n-gram most similar (``similar``) to the first label of the
    subject in the module-level ``mid2entity`` is taken as the subject mention. It is
    replaced by ``<type>_``, where ``<type>`` is the last element of the annotation
    from the subject's ``mid2dawt`` entries whose label is most similar to that first
    label.

    The question is returned tokenized but otherwise unchanged when the subject has no
    label in ``mid2entity``, has no entries in ``mid2dawt``, or the question yields no
    n-grams. Earlier versions raised ``KeyError``/``IndexError`` in the first two
    situations.

    Args:
        sbj_mid_y: Subject MIDs, aligned by index with ``questions``.
        questions: Iterable of question strings.

    Returns:
        List with one (possibly rewritten) question string per input question.
    """
    annotations = []
    for i, question in enumerate(questions):
        question = preprocess_rtrn_string(question)
        sbj_mid = sbj_mid_y[i]

        labels = mid2entity[sbj_mid] if sbj_mid in mid2entity else []
        entries = mid2dawt[sbj_mid] if sbj_mid in mid2dawt else []
        if len(labels) == 0 or len(entries) == 0:
            annotations.append(question)
            continue

        true_label = labels[0]
        # Tokenize the label once; it does not change across n-grams.
        true_label_tokenized = preprocess_rtrn_string(true_label)

        # find the most probable part of the sentence that matches the true_label of the subject (sbj of the triple)
        most_similar_ngram = ""
        most_similar_ngram_v = -1
        for n_gram in create_ngrams(question):
            similarity = similar(true_label_tokenized, n_gram)
            if most_similar_ngram_v < similarity:
                most_similar_ngram = n_gram
                most_similar_ngram_v = similarity

        if not most_similar_ngram:
            # str.replace("", x) would insert x between all characters.
            annotations.append(question)
            continue

        # Only the best-matching annotation is needed, so take the maximum by
        # similarity instead of sorting (which also compared the annotations
        # themselves whenever two similarities were equal).
        best_idlbtype = max(
            (text_idlabeltype[-1] for text_idlabeltype in entries),
            key=lambda idlbtype: similar(true_label, idlbtype[1]),
        )
        annotations.append(question.replace(most_similar_ngram, best_idlbtype[-1] + "_"))

    return annotations


def similar(a, b):
    """Return the ``difflib.SequenceMatcher`` similarity ratio of ``a`` and ``b`` (no junk filter)."""
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()


In [None]:
# Load the pretrained word2vec embeddings (word2vec binary format).
# `w2v_model` is read as a global by calculate_centroid() and get_scores().
path_pretrained_emb="../data/embeddings/GoogleNews-vectors-negative300.bin"

w2v_model = gensim.models.KeyedVectors.load_word2vec_format(path_pretrained_emb, binary=True)

#### Run the cells below once per split (train, test and validation), changing the load and save paths each time

In [None]:
# Parse the annotated split and collect the indices of the samples whose answer can be found.
annotated_path = "../data/SimpleQuestions_v2/annotated_fb_data_train.txt"
(
    sbj_mid_,
    rel,
    question,
    obj_mid_,
    samples_with_answer_existing,
) = load_annotated_file(annotated_path)

In [None]:
# For every answerable sample, rank the subject's sentences by similarity to the relation
# and keep the best-ranked sentence that contains the answer ("correct") and the
# best-ranked sentence that is not annotated with it ("wrong").
total_correct_samples = []
total_wrong_samples = []


for ind, id_ in enumerate(samples_with_answer_existing):  # [samples_with_answer_existing[i] for i in x]):
    print("--------------------------------", ind, "---------------------------------------")

    sb = sbj_mid_[id_]
    ob = obj_mid_[id_]

    # Annotations of the subject's entries, leaving out those typed "MISC".
    idlbtype = [text_idlabeltype[-1]
                for text_idlabeltype in mid2dawt[sb] if text_idlabeltype[-1][-1] != "MISC"]

    # Labels with the most words come first, so they are the first to be replaced.
    labels_len = [len(id_lb_type[1].split()) for id_lb_type in idlbtype]
    longest_labels_ids = [x for _, x in sorted(zip(labels_len, idlbtype), reverse=True)]

    pattern = create_replacement(longest_labels_ids)

    corpus_new, indx2fb_id = remove_sentences_with_multiple_occurences(mid2dawt[sb], ob)
    corpus_new_sentences, indx2fb_id_newsent = split_unpunctuated_text_add_placeholders(corpus_new, indx2fb_id, pattern)

    if len(corpus_new_sentences) < 1:
        # NOTE(review): a skipped sample adds nothing to the two result lists, so they
        # no longer line up index-by-index with samples_with_answer_existing.
        print("SKIP")
        continue

    rel_txt = preprocess(rel[id_].replace("/", " ").replace("_", " ").lower())
    weights = get_scores(corpus_new_sentences, rel_txt)
    sorted_weight_indx = np.argsort(weights)

    # Surface label of the answer. load_annotated_file() also accepts samples whose
    # answer MID is missing from mid2entity (or has no labels) when the MID matched an
    # annotation id, so the lookup has to be guarded; without a label only the
    # annotation id can identify the correct sentence.
    if ob in mid2entity and len(mid2entity[ob]) != 0:
        ob_label = mid2entity[ob][0]
    else:
        ob_label = None

    cor = True  # still looking for the correct sentence
    wrg = True  # still looking for the wrong sentence
    correct_samples = []
    wrong_samples = []

    # Walk through the sentences from the highest to the lowest score.
    for indx_ in range(1, len(sorted_weight_indx) + 1):
        current_indx = sorted_weight_indx[-indx_]
        sentence = corpus_new_sentences[current_indx]
        annotated_with_answer = indx2fb_id_newsent[current_indx][0] == ob
        if cor and annotated_with_answer:
            correct_samples.append(sentence)
            cor = False
        elif cor and ob_label is not None and ob_label in " ".join(sentence):
            correct_samples.append(sentence)
            cor = False
        elif wrg and not annotated_with_answer:
            wrong_samples.append(sentence)
            wrg = False
        elif not cor and not wrg:
            break

    total_correct_samples.append(correct_samples)
    total_wrong_samples.append(wrong_samples)

# Drop the references to the large lookup tables and the embedding model so their
# memory can be reclaimed. NOTE(review): this cell cannot be re-run afterwards
# without reloading them first.
del mid2entity, mid2dawt, w2v_model

# Pause for two minutes before the results are written.
time.sleep(120)

In [None]:
def directory_exists(directory):
    """Make sure ``directory`` exists, creating it (and missing parents) when needed.

    ``exist_ok=True`` replaces the earlier "check, then create" sequence, which could
    fail if the directory was created by someone else between the two steps.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)
        
def save_pickle(path, name, python_object):
    """Pickle ``python_object`` into the file ``name`` inside the directory ``path``.

    The directory is created first when it does not exist. The target is built with
    ``os.path.join`` so that ``path`` works with or without a trailing separator;
    plain concatenation wrote to ``<path><name>`` next to the intended directory
    when the separator was missing.

    Args:
        path: Target directory.
        name: File name inside ``path``.
        python_object: Any picklable object.
    """
    directory_exists(path)

    with open(os.path.join(path, name), "wb") as handle:
        pickle.dump(python_object, handle)

In [None]:
# Output directory for this split; change it together with the input path above.
path_pickle = '../data/DAWT/train/'

# Persist the selected sentences and the indices of the samples they were built from.
save_pickle(path=path_pickle, name="total_correct_samples.pickle",
            python_object=total_correct_samples)
save_pickle(path=path_pickle, name="total_wrong_samples.pickle",
            python_object=total_wrong_samples)
save_pickle(path=path_pickle, name="samples_with_answer_existing.pickle",
            python_object=samples_with_answer_existing)
