In [1]:
from embeddings.play import Play
import numpy as np
from heapq import nlargest
from sklearn.metrics.pairwise import cosine_similarity
import sys
sys.path.append('..')
from my_util import ugly_normalize, ntopidx
from vdbscan import do_cluster

# Load the checkpoint, its saved config and the relation vocabulary into the Play helper.
d = Play('data/result_3/best_model.pt', 'data/result_3/saved_config.json', 'data/rel_20.txt')

# Candidate keywords, one per line. The context manager closes the handle
# as soon as the content is read instead of leaving it to the garbage collector.
with open('data/keyword_p.txt') as f_in:
    keywords = f_in.read().strip().split('\n')
# NOTE(review): 6 is presumably the maximum token length used for alignment — confirm in Play.token_align.
keywords_token, keywords = d.token_align(keywords, 6)

# Relation phrases, one per line (same file that was handed to Play above).
with open('data/rel_20.txt') as f_in:
    relations = f_in.read().strip().split('\n')
# Whitespace-tokenised form of every raw relation line (padding tokens included).
relations_token = [rel.split() for rel in relations]
# Human-readable form: cut everything from the first '<pad>' marker onward;
# lines without a marker are kept unchanged.
relations = [(rel[:rel.index('<pad>')].strip() if '<pad>' in rel else rel) for rel in relations]

Loaded checkpoint 'data/result_3/best_model.pt' (epoch 17 iter: 540001 train_loss: 1.632509862942454, dev_loss: 2.3585431747436525, train_pos:0.5035273432731628, train_neg: 0.017596999183297157, dev_pos: 0.31248798966407776, dev_neg: 0.016466999426484108)


In [5]:
# Helper functions for finding similar keywords and relations
def read_test_file(file_name:str):
    """Parse a test file whose lines look like ``central_kw:kw1,kw2,...``.

    Blank lines (for example a trailing empty line) are skipped. Only the
    first ':' on a line separates the central keyword from the
    comma-separated keyword list, so keywords may themselves contain ':'.

    Args:
        file_name: path of the test file to read.

    Returns:
        A list of ``(central_kw, [kw, ...])`` tuples in file order.

    Raises:
        ValueError: if a non-blank line contains no ':' separator.
    """
    ret = []
    with open(file_name) as f_in:
        for line_no, line in enumerate(f_in, 1):
            line = line.strip()
            if not line:
                # Empty lines carry no test case; previously they crashed the unpacking.
                continue
            central_kw, sep, kws = line.partition(':')
            if not sep:
                raise ValueError(
                    "%s:%d: expected 'central_kw:kw1,kw2,...', got %r"
                    % (file_name, line_no, line)
                )
            ret.append((central_kw, kws.split(',')))
    return ret

def run_test(test_data:list, d:Play, keywords:list, keywords_token:list, relations:list):
    """Collect similar relations and similar keywords for every test pair.

    For each ``(central_kw, kws)`` entry in ``test_data`` the model ``d``
    predicts a relation vector for every test keyword paired with the
    central keyword, and for every candidate keyword in ``keywords_token``
    paired with the same central keyword. Test predictions are compared by
    cosine similarity against ``d.relation_representation`` and against the
    candidate-keyword predictions.

    Returns:
        A list of ``(kw, central_kw, similar_relations, similar_keywords)``
        tuples, one per test keyword, where the last two items are looked up
        in ``relations`` / ``keywords`` through ``ntopidx(40, scores)``.
    """
    # NOTE(review): ntopidx(40, row) presumably yields the indices of the 40 best scores — confirm in my_util.
    results = []
    candidate_count = len(keywords_token)
    for central_kw, kws in test_data:
        central_tokens = central_kw.split()
        kws_token, kws = d.token_align(kws, 6)

        # Predictions for (candidate keyword, central keyword) and (test keyword, central keyword).
        candidate_pred = d.get_prediction(keywords_token, [central_tokens] * candidate_count)
        test_pred = d.get_prediction(kws_token, [central_tokens] * len(kws_token))

        # Row i of each matrix holds the scores of test keyword i.
        relation_scores = cosine_similarity(test_pred, d.relation_representation)
        keyword_scores = cosine_similarity(test_pred, candidate_pred)

        for pos, kw in enumerate(kws):
            nearest_rels = [relations[j] for j in ntopidx(40, relation_scores[pos])]
            nearest_kws = [keywords[j] for j in ntopidx(40, keyword_scores[pos])]
            results.append((kw, central_kw, nearest_rels, nearest_kws))
    return results

def write_result(data:list, file_name):
    """Write test results to ``file_name`` as UTF-8 text.

    Each ``(kw, central_kw, similar_rels, similar_kws)`` record contributes a
    ``!kw<=>central_kw`` header, a ``>Similar Relation`` section listing the
    relations and a ``>Similar Keyword`` section listing the keywords. All
    pieces of all records are joined with a single newline; several pieces
    already carry their own newline, which yields the blank separator lines.
    """
    with open(file_name, 'w', encoding='utf-8') as f_out:
        pieces = []
        for kw, central_kw, similar_rels, similar_kws in data:
            pieces.append('!%s<=>%s\n' % (kw, central_kw))
            pieces.extend(['>Similar Relation', *similar_rels])
            pieces.extend(['\n>Similar Keyword', *similar_kws, '\n'])
        f_out.write('\n'.join(pieces))

def do_test(test_file:str, result_file:str, d:Play, keywords:list, keywords_token:list, relations:list):
    """Run the whole pipeline: read ``test_file``, score it, write ``result_file``."""
    cases = read_test_file(test_file)
    write_result(run_test(cases, d, keywords, keywords_token, relations), result_file)

In [7]:
# Helper functions for doing relation clustering
def find_rel(target_rel:str, rel_list:list) -> list:
    """Return ``(entry, index)`` for every entry of ``rel_list`` that contains ``target_rel``.

    Matches are reported in their original list order.
    """
    return [(entry, pos) for pos, entry in enumerate(rel_list) if target_rel in entry]

def find_group_member(target:int, clusters:dict):
    """Return the first value of ``clusters`` that contains ``target``.

    Values are checked in the dictionary's iteration order; ``None`` is
    returned when no value contains ``target``.
    """
    return next((members for members in clusters.values() if target in members), None)

def find_most_similar(target:int, vecs:np.ndarray, relation_representation:np.ndarray, n:int=10):
    """Find the rows of ``vecs`` most similar to one relation vector.

    Args:
        target: row index into ``relation_representation`` selecting the query vector.
        vecs: 2-D array of candidate vectors the query is compared against.
        relation_representation: 2-D array holding one vector per relation.
        n: how many results to request from ``ntopidx``.

    Returns:
        Whatever ``ntopidx(n, scores)`` returns for the 1-D array of cosine
        similarities between the query vector and every row of ``vecs``.
    """
    # Bug fix: the query row is selected by `target`. The original indexed with
    # `vecs` (the whole candidate matrix), which left `target` unused.
    # cosine_similarity expects 2-D inputs, hence the (1, dim) reshape.
    target_vec = np.asarray(relation_representation[target]).reshape(1, -1)
    similarities = cosine_similarity(target_vec, vecs)
    # Hand ntopidx a single 1-D score row, the same way run_test calls it.
    return ntopidx(n, similarities[0])

In [6]:
# Score the single-word test set and write the report to disk.
do_test(
    test_file='test/single_word_test.txt',
    result_file='result/single_word_out_2.txt',
    d=d,
    keywords=keywords,
    keywords_token=keywords_token,
    relations=relations,
)

In [8]:
# Generate clusters: label every relation vector, then group relation indices by label.
k = 7
cluster_id = do_cluster(d.relation_representation, k)
group_num = max(cluster_id) + 1
rel_id = np.arange(len(cluster_id))
# Labels run from -1 up to the largest label returned by do_cluster.
# NOTE(review): -1 is presumably the noise/unassigned label — confirm in vdbscan.
rel_clusters = {
    cid: set(rel_id[cluster_id == cid])
    for cid in range(-1, group_num)
}

KeyboardInterrupt: 

In [9]:
from sklearn.cluster import DBSCAN
# Fit scikit-learn's DBSCAN on the relation vectors using cosine distance.
temp = DBSCAN(
    eps=0.1,
    min_samples=5,
    metric='cosine',
).fit(d.relation_representation)

In [10]:
# Number of relation vectors that received cluster label 0.
np.sum(temp.labels_ == 0)

62

In [11]:
# Largest cluster label produced by the DBSCAN fit.
temp.labels_.max()

907

In [23]:
# Indices of the relations that DBSCAN assigned to cluster label 102.
s = np.flatnonzero(temp.labels_ == 102)
print(len(s))

35


In [24]:
# Show every relation phrase of the selected cluster, one per line.
print(*(relations[idx] for idx in s), sep='\n')

are based on
however are based on
were based on
are often based on
are mainly based on
are mostly based on
are either based on
are usually based on
are proposed based on
are derived based on
are designed based on
are typically based on
currently are based on
are based upon
are selected based on
are extracted based on
however are mainly based on
are largely based on
are primarily based on
are identified based on
are generated based on
are determined based on
moreover are based on
hence are based on
are still based on
are generally based on
are constraints based on
are based primarily on
traditionally are based on
are currently based on
are analysed based on
are reported based on
are therefore based on
are therefore often based on
moreover are based primarily on
