In [None]:
import pickle, pickledb
import numpy as np
from itertools import count
from collections import defaultdict
import tensorflow as tf
import tensorflow_hub as hub

# --- Input / output locations -------------------------------------------------
# NOTE(review): train_file, test_file and USE_folder are absolute, machine-specific
# paths; the remaining paths are relative to the notebook's working directory.
train_file = "/data/Vivek/original/HypeNET/dataset/custom_train_0.0_0.05.tsv"
test_file =  "/data/Vivek/original/HypeNET/dataset/custom_test_0.0_0.05.tsv"
instances_file = '../files/dataset/test_instances.tsv'
knocked_file = '../files/dataset/test_knocked.tsv'
# output_folder, USE_folder and use_embeddings are not referenced in the visible cells.
output_folder = "../junk/Output/"
USE_folder = "/home/vlead/USE"
use_embeddings = "../files/embeddings.pt"

# --- Sizes and label set ------------------------------------------------------
# The *_DIM constants are not referenced in the visible cells; presumably they are
# feature sizes for a downstream model -- TODO confirm.
POS_DIM = 4
DEP_DIM = 5
DIR_DIM = 1
EMBEDDING_DIM = 512
# Placeholder path (a single all-zero edge) that parse_dataset substitutes when a
# pair ends up with no parsed paths.
NULL_PATH = ((0, 0, 0, 0),)
# Relation labels; their list positions become the target ids (see rel_indexer).
relations = ["hypernym", "hyponym", "concept", "instance", "none"]
NUM_RELATIONS = len(relations)
# Directory prefix prepended to the pickledb file names loaded further down.
prefix = "../junk/db_files/"

# TF-Hub handle passed to hub.load below; the URL names the Universal Sentence
# Encoder "large", version 5.
USE_link = "https://tfhub.dev/google/universal-sentence-encoder-large/5?tf-hub-format=compressed"
# Loaded once at module level; extractUSEEmbeddings calls this object directly.
model = hub.load(USE_link)

# `resolved` is consulted by entity_to_id via resolved.get(entity, ""): each hit is
# indexed as [0] (a replacement entity looked up in the db) and [1] (a score that is
# compared against `threshold`).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
# The context manager closes the handle, which the original left open.
with open("../junk/resolved_use.pkl", "rb") as f:
    resolved = pickle.load(f)

def extractUSEEmbeddings(words):
    """Run the module-level ``model`` on ``words`` and return the result as NumPy.

    Args:
        words: Input handed to ``model`` unchanged (the visible caller passes a
            list of indexer keys).

    Returns:
        The value of ``.numpy()`` called on the model output.
    """
    return model(words).numpy()

In [None]:
# Maps an arrow character found at either end of an edge to its direction name.
arrow_heads = {">": "up", "<": "down"}


def extract_direction(edge):
    """Strip a leading or trailing arrow from ``edge`` and report its direction.

    A leading arrow takes precedence over a trailing one; at most one arrow
    character is removed.

    Args:
        edge: Edge string, optionally starting or ending with ``>`` or ``<``.
            May be empty.

    Returns:
        A ``(direction, edge)`` tuple. ``direction`` is ``"start_up"``,
        ``"start_down"``, ``"end_up"``, ``"end_down"``, or ``' '`` when no arrow
        is present; ``edge`` is the input with that arrow removed.
    """
    # Slicing (edge[:1] / edge[-1:]) instead of indexing keeps an empty edge from
    # raising IndexError; it simply falls through to the "no direction" case.
    if edge[:1] in arrow_heads:
        return "start_" + arrow_heads[edge[0]], edge[1:]
    if edge[-1:] in arrow_heads:
        return "end_" + arrow_heads[edge[-1]], edge[:-1]
    return ' ', edge

def parse_path(path):
    """Convert a ``*##*``-separated path string into a tuple of index 4-tuples.

    Every edge has its direction arrow removed by ``extract_direction`` and is
    then split from the right into ``embedding/pos/dependency`` (so the
    embedding part may itself contain ``/``). Each component is looked up in
    its indexer (``emb_indexer``, ``pos_indexer``, ``dep_indexer``,
    ``dir_indexer``).

    Args:
        path: Path string whose edges are joined by ``"*##*"``.

    Returns:
        A tuple of ``(emb_idx, pos_idx, dep_idx, dir_idx)`` tuples, one per edge,
        or ``None`` if the per-edge guard fails.

    Raises:
        Whatever the edge unpacking raises (e.g. ``ValueError`` when an edge has
        fewer than three ``/``-separated parts); the offending edge and path are
        printed before the exception is re-raised.
    """
    steps = []
    for raw_edge in path.split("*##*"):
        direction, edge = extract_direction(raw_edge)
        # NOTE(review): str.split with an explicit separator never returns an
        # empty list, so this guard does not appear to ever trigger.
        if not edge.split("/"):
            return None
        try:
            # Split from the right so only the last two slashes are separators.
            embedding, pos, dependency = edge.rsplit("/", 2)
        except:
            print (edge, path)
            raise
        steps.append((
            emb_indexer[embedding],
            pos_indexer[pos],
            dep_indexer[dependency],
            dir_indexer[direction],
        ))
    return tuple(steps)

def parse_tuple(tup):
    """Collect the path strings and frequencies stored for a pair of terms.

    Both terms are mapped to ids through ``word2id_db``; paths are then fetched
    from ``relations_db`` for both orders, (x, y) and (y, x). Each path id is
    turned back into its string form and the ``X/`` / ``Y/`` placeholders are
    replaced so that ``tup[0]`` and ``tup[1]`` appear in the text.

    Args:
        tup: Two-element sequence ``(x_term, y_term)``.

    Returns:
        Dict mapping rewritten path string -> frequency. When the same string
        arises from both orders, the (y, x) frequency wins.
    """
    x_id, y_id = [entity_to_id(word2id_db, term) for term in tup]
    forward = list(extract_paths(relations_db, x_id, y_id).items())
    backward = list(extract_paths(relations_db, y_id, x_id).items())

    merged = {}
    for path_id, freq in forward:
        text = id_to_path(id2path_db, path_id)
        # Replacement order matters: "X/" first, then "Y/".
        text = text.replace("X/", tup[0] + "/").replace("Y/", tup[1] + "/")
        merged[text] = freq
    for path_id, freq in backward:
        text = id_to_path(id2path_db, path_id)
        # Reversed pair: "Y/" is rewritten first, then "X/".
        text = text.replace("Y/", tup[0] + "/").replace("X/", tup[1] + "/")
        merged[text] = freq
    return merged

def parse_dataset(dataset):
    """Turn a ``{(x, y): relation}`` mapping into model-ready paths and targets.

    Args:
        dataset: Dict whose keys are term pairs and whose values are relation
            names present in ``rel_indexer``.

    Returns:
        ``(paths, targets)`` where ``paths`` holds one dict per pair mapping a
        parsed path (see ``parse_path``) to its frequency -- or
        ``{NULL_PATH: 1}`` when the pair has no usable path -- and ``targets``
        holds the relation index of each pair, in the same order.
    """
    # First gather the raw path strings for every pair ...
    raw_counts = [parse_tuple(pair) for pair in dataset.keys()]

    # ... then index them, dropping paths that parse to a falsy value.
    paths = []
    for counts in raw_counts:
        kept = {}
        for path_text, freq in counts.items():
            parsed = parse_path(path_text)
            if parsed:
                kept[parsed] = freq
        paths.append(kept if kept else {NULL_PATH: 1})

    targets = [rel_indexer[relation] for relation in dataset.values()]
    return paths, targets

# Running logs filled by entity_to_id: entities it could / could not map to an id.
failed, success = [], []

def id_to_entity(db, entity_id):
    """Look up the entity stored under ``str(entity_id)`` in ``db``.

    Args:
        db: Object exposing ``get(key)``.
        entity_id: Id to look up; it is converted to ``str`` before the lookup.

    Returns:
        Whatever ``db.get`` returns for that key.
    """
    return db.get(str(entity_id))

def id_to_path(db, entity_id):
    """Fetch the path stored under ``str(entity_id)`` and re-delimit it.

    The stored string is cut at every ``/``; inside each piece only the first
    ``_`` is replaced by ``*##*`` (the separator ``parse_path`` splits on), and
    the pieces are joined with ``/`` again.

    Args:
        db: Object exposing ``get(key)`` that returns the stored path string.
        entity_id: Path id; converted to ``str`` before the lookup.

    Returns:
        The re-delimited path string.
    """
    stored = db.get(str(entity_id))
    pieces = []
    for piece in stored.split("/"):
        pieces.append(piece.replace("_", "*##*", 1))
    return "/".join(pieces)

def entity_to_id(db, entity):
    """Map ``entity`` to its integer id, falling back to a resolved substitute.

    The entity is looked up directly first. If that lookup is falsy, the
    module-level ``resolved`` mapping is consulted: its entry is used when it
    has a non-empty substitute at ``[0]`` and a score at ``[1]`` strictly above
    ``threshold``. Each call appends ``entity`` to ``success`` or ``failed``.

    Args:
        db: Object exposing ``get(key)`` (word -> id store).
        entity: Term to look up.

    Returns:
        The integer id, or ``-1`` when no id could be determined.
    """
    entity_id = db.get(entity)
    if entity_id:
        success.append(entity)
        return int(entity_id)
    closest_entity = resolved.get(entity, "")
    if closest_entity and closest_entity[0] and float(closest_entity[1]) > threshold:
        resolved_id = db.get(closest_entity[0])
        # A missing key comes back as a sentinel (False in legacy pickledb, None
        # for dict-like stores). int(False) is 0, so converting blindly would
        # return a valid-looking id and log a success; treat it as a failure.
        if resolved_id is not None and resolved_id is not False:
            success.append(entity)
            return int(resolved_id)
    failed.append(entity)
    return -1

def extract_paths(db, x, y):
    """Read the ``path_id -> count`` mapping stored for the ordered pair (x, y).

    The record is looked up under ``"<x>###<y>"`` and is expected to look like
    ``"id:count,id:count,..."``.

    Args:
        db: Object exposing ``get(key)``.
        x: First id of the pair (stringified for the key).
        y: Second id of the pair (stringified for the key).

    Returns:
        Dict of ``int`` path id -> ``int`` count. Any error while fetching or
        parsing the record (missing key, malformed entry) yields ``{}`` -- never
        a partial result.
    """
    key = str(x) + '###' + str(y)
    try:
        record = db.get(key)
        counts = {}
        for item in record.split(","):
            fields = item.split(":")
            counts[int(fields[0])] = int(fields[1])
        return counts
    except Exception:
        # Best effort: an unknown pair or a bad record simply has no paths.
        return {}

# Key-value stores opened from `prefix`; the second argument (False) is passed to
# pickledb.load unchanged. Only word2id_db, id2path_db and relations_db are read by
# the functions in the visible cells; id2word_db and path2id_db are not referenced there.
word2id_db = pickledb.load(prefix + "w2i.db", False)
id2word_db = pickledb.load(prefix + "i2w.db", False)
path2id_db = pickledb.load(prefix + "p2i.db", False)
id2path_db = pickledb.load(prefix + "i2p.db", False)
relations_db = pickledb.load(prefix + "relations.db", False)


In [None]:
# Minimum score (strictly greater-than) a resolved substitute needs in entity_to_id.
threshold = 0.8

# Four independent auto-incrementing vocabularies: looking up an unseen key stores
# and returns the next integer from that indexer's own counter (starting at 0).
emb_indexer, pos_indexer, dep_indexer, dir_indexer = [defaultdict(count(0).__next__) for i in range(4)]
# "<UNK>" is the first key requested from each indexer, so it receives index 0.
unk_emb, unk_pos, unk_dep, unk_dir = emb_indexer["<UNK>"], pos_indexer["<UNK>"], dep_indexer["<UNK>"], dir_indexer["<UNK>"]
# Relation name -> position in the `relations` list.
rel_indexer = {key: idx for (idx,key) in enumerate(relations)}

def _read_dataset(path):
    """Read a tab-separated file into a ``{(col0, col1): col2}`` dict.

    Empty lines (such as the one produced by a trailing newline) are skipped
    instead of raising ``IndexError``, and the file handle is closed once the
    content has been read. Later duplicates of a pair overwrite earlier ones.

    Args:
        path: Path of the ``.tsv`` file to read.

    Returns:
        Dict mapping the first two columns (as a tuple) to the third column.

    Raises:
        IndexError: If a non-empty line has fewer than three columns.
    """
    with open(path) as handle:
        rows = [line.split("\t") for line in handle.read().split("\n") if line]
    return {tuple(fields[:2]): fields[2] for fields in rows}


train_dataset = _read_dataset(train_file)
test_dataset = _read_dataset(test_file)
test_instances = _read_dataset(instances_file)
test_knocked = _read_dataset(knocked_file)

# Each parsed_* value is the (paths, targets) pair returned by parse_dataset.
# Parsing order matters: the indexers hand out ids on first sight, so the train
# set is parsed first and the other splits extend the same vocabularies.
parsed_train = parse_dataset(train_dataset)
parsed_test = parse_dataset(test_dataset)
parsed_instances = parse_dataset(test_instances)
parsed_knocked = parse_dataset(test_knocked)

# Report the number of pairs (length of the paths list) in every split.
print ("Train len: {}, Test len: {}, Instance len: {}, Knocked len: {}".format(len(parsed_train[0]), len(parsed_test[0]),  len(parsed_instances[0]), len(parsed_knocked[0])))

# Reverse lookup: embedding index -> token.
emb_indexer_inv = {idx: token for token, idx in emb_indexer.items()}
# Embed every token except the first key ("<UNK>", inserted first and so at index 0).
embeds = extractUSEEmbeddings(list(emb_indexer)[1:])
# Row 0 is an all-zero vector for "<UNK>"; the remaining rows follow indexer order.
emb_vals = np.vstack([np.zeros((1, embeds.shape[1])), embeds])


# Persist everything later stages need in one pickle; the file name embeds the
# resolution threshold used for this run.
output_file = "../junk/Input/data_use_" + str(threshold) + ".pkl"
# The context manager closes (and flushes) the file even if pickling fails.
with open(output_file, "wb+") as f:
    pickle.dump([parsed_train, parsed_test, parsed_instances, parsed_knocked, emb_indexer, emb_indexer_inv, emb_vals, pos_indexer, dep_indexer, dir_indexer], f)



In [30]:
# Inspect a previously saved pickle and show the sizes of entries 1-3.
# NOTE(review): this file name differs from the one written by the cell above;
# presumably it has the same list layout (test / instances / knocked) -- confirm.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open("../junk/use_input_0.8000000000000003.pkl", "rb") as fh:
    f = pickle.load(fh)
len(f[1][0]), len(f[2][0]), len(f[3][0])

(1197, 275, 5538)