-
Notifications
You must be signed in to change notification settings - Fork 71
/
film_trainer.py
117 lines (89 loc) · 4.38 KB
/
film_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
import argparse

import numpy as np
import tensorlayerx as tlx
from sklearn.metrics import f1_score
from tensorlayerx.model import TrainOneStep, WithLoss

from gammagl.datasets.ppi import PPI
from gammagl.loader import DataLoader
from gammagl.models import FILMModel
class SemiSpvzLoss(WithLoss):
    """Couples the backbone network with a loss function for TrainOneStep.

    `forward` runs the backbone on a batched graph and evaluates the loss
    against the batch's node labels.
    """

    def __init__(self, net, loss_fn):
        super().__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y):
        logits = self.backbone_network(data['x'], data['edge_index'])
        # Labels are cast to float32: sigmoid cross-entropy expects
        # floating-point multi-label targets, not integer class ids.
        targets = tlx.cast(data['y'], dtype=tlx.float32)
        return self._loss_fn(logits, targets)
def calculate_acc(logits, y, metrics):
    """Feed one batch of predictions into *metrics*, return the score, and reset.

    Parameters
    ----------
    logits : model outputs for the batch
    y : ground-truth labels for the batch
    metrics : a metric object exposing update() / result() / reset()
    """
    metrics.update(logits, y)
    try:
        return metrics.result()
    finally:
        # Clear accumulated state so the next call starts fresh.
        metrics.reset()
def _evaluate(net, loader):
    """Return micro-F1 of `net` aggregated over EVERY batch in `loader`.

    Predictions are thresholded at logit 0 (equivalent to sigmoid > 0.5 for
    this multi-label task). All batches are concatenated before scoring so
    the metric reflects the whole split, not a single batch.
    """
    preds, labels = [], []
    for batch in loader:
        logits = net(batch['x'], batch['edge_index'])
        pred = tlx.where(logits > 0, 1, 0)
        # convert_to_numpy is backend-agnostic (torch / tf / paddle / ms),
        # unlike the torch-only `.cpu()`.
        preds.append(tlx.convert_to_numpy(pred))
        labels.append(tlx.convert_to_numpy(batch['y']))
    return f1_score(np.concatenate(labels), np.concatenate(preds), average='micro')


def main(args):
    """Train FILMModel on PPI, checkpoint on best validation micro-F1,
    then reload the best weights and report micro-F1 on the test split.
    """
    print("loading ppi dataset...")
    train_dataset = PPI()
    val_dataset = PPI(split='val')
    test_dataset = PPI(split='test')

    batch_size = int(args.batch_size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    net = FILMModel(in_channels=train_dataset.num_node_features,
                    hidden_dim=args.hidden_dim,
                    out_channels=train_dataset.num_classes,
                    num_layers=args.num_layers,
                    drop_rate=args.drop_rate,
                    name="FILM")

    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    train_weights = net.trainable_weights
    loss_func = SemiSpvzLoss(net, tlx.losses.sigmoid_cross_entropy)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    best_val_f1 = 0
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = None
        for batch in train_loader:
            # train_loss after the loop holds the LAST batch's loss — kept
            # as a cheap progress indicator, matching the original output.
            train_loss = train_one_step(batch, batch['y'])

        net.set_eval()
        # BUG FIX: the original computed f1 per batch inside the loop and
        # kept only the last batch's score; aggregate over the full split.
        val_f1 = _evaluate(net, val_loader)

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " val f1-micro: {:.4f}".format(val_f1))

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict')

    # Restore the best checkpoint before the final test evaluation.
    net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict')
    net.set_eval()
    test_f1 = _evaluate(net, test_loader)
    print("Test f1-micro: {:.4f}".format(test_f1))
if __name__ == '__main__':
    # Command-line hyperparameters for the FiLM / PPI training run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=2000, help="number of epochs")
    parser.add_argument("--hidden_dim", type=int, default=320, help="dimension of hidden layers")
    parser.add_argument("--drop_rate", type=float, default=0.1, help="drop_rate")
    parser.add_argument("--l2_coef", type=float, default=5e-4, help="l2 loss coefficient")
    parser.add_argument('--dataset', type=str, default='ppi', help='dataset(ppi)')
    # parser.add_argument("--dataset_path", type=str, default=r'../../../data', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
    parser.add_argument("--self_loops", type=int, default=1, help="number of graph self-loop")
    parser.add_argument("--batch_size", type=int, default=2, help="batch_size of dataloader")
    parser.add_argument("--num_layers", type=int, default=4, help="num of film layers")
    parser.add_argument("--gpu", type=int, default=0, help="gpu id; a negative value selects CPU")
    args = parser.parse_args()

    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)