# starter.py — forked from NickDrake117/GNNkeras
# coding=utf-8
import tensorflow as tf
from numpy import random
from GNN.Models.MLP import MLP, get_inout_dims
from GNN.Models.GNN import GNNgraphBased
from GNN.Models.LGNN import LGNN
from GNN.Sequencers.GraphSequencers import MultiGraphSequencer
#######################################################################################################################
# SCRIPT OPTIONS - modify the parameters to adapt the execution to the problem under consideration ####################
#######################################################################################################################
### GRAPHS OPTIONS
# node-aggregation scheme applied to every graph; valid values: 'average', 'sum', 'normalized'
aggregation_mode = 'average'
# c: Classification
addressed_problem = 'c'
# g: graph-focused
focus = 'g'
# NET STATE PARAMETERS
# activation/initializers for the state MLPs; selu + lecun_normal is the standard self-normalizing pairing
activations_net_state : str = 'selu'
kernel_init_net_state : str = 'lecun_normal'
bias_init_net_state : str = 'lecun_normal'
### NET OUTPUT PARAMETERS
# softmax output for the classification targets; glorot initializers for the output MLPs
activations_net_output : str = 'softmax'
kernel_init_net_output : str = 'glorot_normal'
bias_init_net_output : str = 'glorot_normal'
# GNN PARAMETERS
# dim_state == 0 -> state dimension is derived from the node labels (no extra state component)
dim_state : int = 0
# maximum number of state-propagation iterations per forward pass
max_iter : int = 5
# convergence threshold on the state update; iteration stops early below this value
state_threshold : float = 0.01
# LGNN PARAMETERS
# number of stacked GNN layers in the LGNN
layers : int = 3
# whether each LGNN layer receives the previous layer's state / output as extra input
get_state : bool = True
get_output : bool = True
# 'serial' trains one layer at a time; see LGNN docs for alternatives
training_mode : str = 'serial'
# LEARNING PARAMETERS
epochs : int = 10
batch_size : int = 1000
# categorical cross-entropy matches the softmax output above
loss_function = tf.keras.losses.categorical_crossentropy
# NOTE(review): a single optimizer instance is later shared by both gnn.compile and lgnn.compile
optimizer : tf.keras.optimizers.Optimizer = tf.optimizers.Adam(learning_rate=0.01)
#######################################################################################################################
# SCRIPT ##############################################################################################################
#######################################################################################################################
### LOAD DATASET from MUTAG
# problem is graph-focused (focus='g'); every graph is re-tagged with the aggregation
# scheme chosen above ('average', 'sum' or 'normalized').
from load_MUTAG import graphs

for single_graph in graphs:
    single_graph.setAggregation(aggregation_mode)

### PREPROCESSING
# Shuffle in place, then hold out the last 1500 graphs: 750 for test, 750 for validation.
# No graph normalization is applied.
# NOTE(review): assumes the dataset holds well over 1500 graphs — verify against load_MUTAG.
random.shuffle(graphs)
gVa = graphs[-750:]
gTe = graphs[-1500:-750]
gTr = graphs[:-1500]
# template graph: a copy used only to read label/target dimensions when sizing the MLPs
gGen = gTr[0].copy()
### MODELS
# MLP NETS - STATE: one (input_dims, layer_dims) spec per LGNN layer, then one MLP per input dim.
_state_specs = [get_inout_dims(net_name='state', dim_node_label=gGen.DIM_NODE_LABEL,
                               dim_arc_label=gGen.DIM_ARC_LABEL, dim_target=gGen.DIM_TARGET,
                               focus=focus, dim_state=dim_state,
                               layer=layer_idx, get_state=get_state, get_output=get_output)
                for layer_idx in range(layers)]
input_net_st, layers_net_st = zip(*_state_specs)
# a layer's spec may list several input dims, hence the inner loop over in_dims
nets_St = [MLP(input_dim=in_dim, layers=layer_spec, activations=activations_net_state,
               kernel_initializer=kernel_init_net_state, bias_initializer=bias_init_net_state,
               name=f'State_{idx}')
           for idx, (in_dims, layer_spec) in enumerate(zip(input_net_st, layers_net_st))
           for in_dim in in_dims]
# MLP NETS - OUTPUT: mirrors the state-net construction with the output activations/initializers.
_output_specs = [get_inout_dims(net_name='output', dim_node_label=gGen.DIM_NODE_LABEL,
                                dim_arc_label=gGen.DIM_ARC_LABEL, dim_target=gGen.DIM_TARGET,
                                focus=focus, dim_state=dim_state,
                                layer=layer_idx, get_state=get_state, get_output=get_output)
                 for layer_idx in range(layers)]
input_net_out, layers_net_out = zip(*_output_specs)
# a layer's spec may list several input dims, hence the inner loop over in_dims
nets_Out = [MLP(input_dim=in_dim, layers=layer_spec, activations=activations_net_output,
                kernel_initializer=kernel_init_net_output, bias_initializer=bias_init_net_output,
                name=f'Out_{idx}')
            for idx, (in_dims, layer_spec) in enumerate(zip(input_net_out, layers_net_out))
            for in_dim in in_dims]
# GNN: a single graph-based model built from the first state/output net pair.
gnn = GNNgraphBased(nets_St[0], nets_Out[0], dim_state, max_iter, state_threshold).copy()
gnn.compile(optimizer=optimizer, loss=loss_function, average_st_grads=False,
            metrics=['accuracy'], run_eagerly=True)
# LGNN: a stack of `layers` GNNs, one per state/output net pair.
# NOTE(review): the same optimizer instance is shared with `gnn` above — confirm LGNN tolerates this.
_gnn_stack = [GNNgraphBased(st_net, out_net, dim_state, max_iter, state_threshold)
              for st_net, out_net in zip(nets_St, nets_Out)]
lgnn = LGNN(_gnn_stack, get_state, get_output)
lgnn.compile(optimizer=optimizer, loss=loss_function, average_st_grads=True,
             metrics=['accuracy'], run_eagerly=True, training_mode=training_mode)
### DATA PROCESSING
# Wrap each split in a sequencer that batches graphs for keras-style fit/evaluate.
gTr_Sequencer, gVa_Sequencer, gTe_Sequencer = (
    MultiGraphSequencer(split, focus, aggregation_mode, batch_size)
    for split in (gTr, gVa, gTe))
### LEARNING PROCEDURE
# Uncomment one of the lines below to train the single GNN or the layered LGNN.
# gnn.fit(gTr_Sequencer, epochs=epochs, validation_data=gVa_Sequencer)
# lgnn.fit(gTr_Sequencer, epochs=epochs, validation_data=gVa_Sequencer)