-
Notifications
You must be signed in to change notification settings - Fork 15
/
main.py
37 lines (30 loc) · 1.02 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import warnings

import numpy as np
import tensorflow as tf

from include.Config import Config
from include.Model import build, training
from include.Test import get_hits
from include.Load import *
# Suppress every warning for the whole process (presumably to hide the
# TensorFlow 1.x deprecation noise).  NOTE(review): this filter is global and
# also hides warnings raised by numpy and the project's own modules.
warnings.filterwarnings("ignore")

# Follow the code style of GCN-Align:
# https://github.com/1049451037/GCN-Align

# Fixed seed so that the shuffle of the alignment pairs (np.random.shuffle in
# the __main__ block) and TensorFlow's graph-level random ops are
# reproducible from run to run.
seed = 12306
np.random.seed(seed)
tf.set_random_seed(seed)
if __name__ == '__main__':
    # Total number of distinct entity ids over both entity files.
    # NOTE(review): loadfile's second argument presumably is the number of
    # leading columns to read from each line — confirm in include/Load.py.
    e = len(set(loadfile(Config.e1, 1)) | set(loadfile(Config.e2, 1)))

    # Pre-aligned entity pairs (two columns; presumably the ILL seed
    # alignments).  Shuffled in place so the split below is random but fixed
    # by the module-level seed.
    ILL = loadfile(Config.ill, 2)
    illL = len(ILL)
    np.random.shuffle(ILL)

    # The first `Config.seed` tenths of the pairs are training seeds; the rest
    # are held out for evaluation.  NOTE(review): the floor division happens
    # before the multiply, so a few pairs can move from train to test when
    # illL is not a multiple of 10.  It is kept unchanged so that existing
    # splits stay reproducible.
    n_train = illL // 10 * Config.seed
    train = np.array(ILL[:n_train])
    test = ILL[n_train:]

    # Three-column records of each KG (presumably (head, relation, tail)
    # triples); both lists are concatenated and passed to the model together.
    KG1 = loadfile(Config.kg1, 3)
    KG2 = loadfile(Config.kg2, 3)

    # Config.language[0:2] passes only the first two characters of the
    # configured language string (presumably a language code prefix).
    output_layer, loss = build(
        Config.dim, Config.act_func, Config.alpha, Config.beta, Config.gamma,
        Config.k, Config.language[0:2], e, train, KG1 + KG2)

    learning_rate = 0.001
    vec, J = training(output_layer, loss, learning_rate,
                      Config.epochs, train, e, Config.k, test)
    print('loss:', J)
    print('Result:')
    get_hits(vec, test)