In [188]:
import networkx
import json
import pickle
import gensim
import numpy as np
from tqdm import tqdm
from gensim.scripts.glove2word2vec import glove2word2vec
from gensim.models import KeyedVectors
from utils import sample

In [158]:
DATASETS = 'Ohsumed'
STC_Benchmark_path = "../benchmark"
NumOfTrainTextPerClass = 2
TOPK = 10
SIM_MIN = 0.5
g = networkx.Graph()

train_: dict = np.load("{}/{}/train.npy".format(STC_Benchmark_path,DATASETS), allow_pickle=True).tolist()
test_: dict = np.load("{}/{}/test.npy".format(STC_Benchmark_path,DATASETS), allow_pickle=True).tolist()
train_tagme: list = np.load("{}/{}/train_tagme.npy".format(STC_Benchmark_path,DATASETS), allow_pickle=True).tolist()
test_tagme: list = np.load("{}/{}/test_tagme.npy".format(STC_Benchmark_path,DATASETS), allow_pickle=True).tolist()
all_: dict = np.load("{}/{}/all.npy".format(STC_Benchmark_path,DATASETS), allow_pickle=True).tolist()
all_tagme: list = np.load("{}/{}/all_tagme.npy".format(STC_Benchmark_path,DATASETS), allow_pickle=True).tolist()

### load text

In [120]:
def loadText(data_dict, trainNumPerClass=20):
    datapath = "{}/{}/".format(STC_Benchmark_path,DATASETS)
    X = []
    Y = []
    catemap = dict()
    from sklearn.model_selection import train_test_split
    for index, (key, value) in enumerate(data_dict.items()):
        X.append([index])
        Y.append(value[0])
    cateset = list(set(Y))
    catemap = dict()
    for i in range(len(cateset)):
        catemap[cateset[i]] = i
    Y = [catemap[i] for i in Y]
    X = np.array(X)
    trainNum = trainNumPerClass*len(catemap)
    ind_train, ind_test = train_test_split(X,train_size=trainNum, random_state=1, )
    ind_vali, ind_test = train_test_split(ind_test,train_size=trainNum/(len(X)-trainNum), random_state=1, )
    train = sum(ind_train.tolist(), [])
    vali = sum(ind_vali.tolist(), [])
    test = sum(ind_test.tolist(), [])
    alltext = set(train + vali + test)
    print( "train: {}\nvali: {}\ntest: {}\nAllTexts: {}".format( len(train), len(vali), len(test), len(alltext)) )
    with open(datapath+'train.list', 'w') as f:
        f.write( '\n'.join(map(str, train)) )
    with open(datapath+'vali.list', 'w') as f:
        f.write( '\n'.join(map(str, vali)) )
    with open(datapath+'test.list', 'w') as f:
        f.write( '\n'.join(map(str, test)) )
    return train, vali, test, alltext

In [121]:
train, vali, test, alltext = loadText(dict(train_,**test_))

train: 460
vali: 459
test: 6479
AllTexts: 7398


### load entity

In [159]:
def loadTagMeEntity(tagme_list):
    entitySet = set()
    rho = 0.1
    noEntity = set()
    for line in tqdm(tagme_list, desc="tagme_list: "):
        ind, entityList = line.strip('\n').split('\t')
        if int(ind) not in alltext:
            continue
        entityList = json.loads(entityList)
        entities = [(d['title'].replace(" ", '_'), d['rho'], d['link_probability'])\
                        for d in entityList if 'title' in d and float(d['rho']) > rho]
        
        entitySet.update([d['title'].replace(" ", '_')\
                        for d in entityList if 'title' in d and float(d['rho']) > rho])
        g.add_edges_from([(ind, e[0], {'rho': e[1], 'link_probability': e[2]}) for e in entities])
        if len(entities) == 0:
            noEntity.add(ind)
            g.add_node(ind)
    return entitySet, noEntity

In [160]:
entitySet, noEntity = loadTagMeEntity(train_tagme+test_tagme)

tagme_list: 100%|██████████| 7400/7400 [00:00<00:00, 30802.64it/s]


In [161]:
len(g.edges())

34912

### load labels

In [162]:
def loadLabel(data_dict):
    for index, (key, value) in enumerate(data_dict.items()):
        if index not in alltext:
            continue
        ind_str = str(index)
        if ind_str not in g.nodes():
            g.add_node(ind_str)
        g.nodes[ind_str]['type'] = value[0]

In [163]:
loadLabel(dict(train_,**test_))

In [166]:
#g._node

### load similarities between entities

In [184]:
glove_input_file = "/Users/sauron/Desktop/Code/SKB-STC/data/embeddings/glove.6B.300d.wiki.txt"
word2vec_output_file = './data/glove.6B.300d.word2vec.txt'
(count, dimensions) = glove2word2vec(glove_input_file, word2vec_output_file)
print(count, '\n', dimensions)

  This is separate from the ipykernel package so we can avoid doing imports until


400001 
 300


In [189]:
model = KeyedVectors.load_word2vec_format(word2vec_output_file, binary=False)

### topK + 阈值

In [182]:
el[1].strip(')')

'Cardiac_tamponade'

In [196]:
def entToEnt(entitySet):
    sim_min = SIM_MIN
    topK = TOPK
    el = list(entitySet)
    entity_edge = []
    cnt_no = 0
    cnt_yes = 0
    cnt = 0
    for i in tqdm(range(len(el)), desc="entity to entity: "):
        simList = []
        topKleft = topK
        for j in range(len(el)):
            if i == j:
                continue
            cnt += 1
            try:
                sim = model.similarity(el[i].lower().strip(')'), el[j].lower().strip(')'))
                cnt_yes += 1
                if sim >= sim_min:
                    entity_edge.append( (el[i], el[j], {'sim': sim}) )
                    topKleft -= 1
                else:
                    simList.append( (sim, el[j]) )
            except Exception as e:
                cnt_no += 1
        simList = sorted(simList, key=(lambda x: x[0]), reverse=True)
        for i in range(min(max(topKleft, 0), len(simList))):
            entity_edge.append( (el[i], simList[i][1], {'sim': simList[i][0]}) )
    print(cnt_yes, cnt_no)
    return entity_edge

In [197]:
entity_edge = entToEnt(entitySet)

entity to entity: 100%|██████████| 7666/7666 [14:22<00:00,  8.89it/s]

5633502 53126388





In [200]:
g.add_edges_from(entity_edge)

### save the network

In [201]:
with open('./data/model_network.pkl', 'wb') as f:
    pickle.dump(g, f)