# graphtrainer.py
import logging
import os

import numpy as np

from federatedscope.core.monitors import Monitor
from federatedscope.register import register_trainer
from federatedscope.core.trainers import GeneralTorchTrainer

logger = logging.getLogger(__name__)
class GraphMiniBatchTrainer(GeneralTorchTrainer):
    def _hook_on_fit_start_init(self, ctx):
        super()._hook_on_fit_start_init(ctx)
        setattr(ctx, "{}_y_inds".format(ctx.cur_data_split), [])
    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        pred = ctx.model(batch)
        # TODO: deal with the type of data within the dataloader or dataset
        if 'regression' in ctx.cfg.model.task.lower():
            label = batch.y
        else:
            label = batch.y.squeeze(-1).long()
        if len(label.size()) == 0:
            label = label.unsqueeze(0)
        ctx.loss_batch = ctx.criterion(pred, label)

        ctx.batch_size = len(label)
        ctx.y_true = label
        ctx.y_prob = pred

        # record the index of the ${MODE} samples
        if hasattr(ctx.data_batch, 'data_index'):
            setattr(
                ctx, f'{ctx.cur_data_split}_y_inds',
                ctx.get(f'{ctx.cur_data_split}_y_inds') +
                [batch[_].data_index.item() for _ in range(len(label))])
    def _hook_on_batch_forward_flop_count(self, ctx):
        if not isinstance(self.ctx.monitor, Monitor):
            logger.warning(
                f"The trainer {type(self)} does not contain a valid monitor; "
                f"this may be caused by initializing trainer subclasses "
                f"without passing a valid monitor instance. "
                f"Please check whether this is what you want.")
            return

        if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \
                == 0:
            # calculate flops_per_sample from the first batch
            try:
                batch = ctx.data_batch.to(ctx.device)
                from torch_geometric.data import Data
                if isinstance(batch, Data):
                    x, edge_index = batch.x, batch.edge_index
                from fvcore.nn import FlopCountAnalysis
                flops_one_batch = FlopCountAnalysis(ctx.model,
                                                    (x, edge_index)).total()
                if self.model_nums > 1 and ctx.mirrored_models:
                    flops_one_batch *= self.model_nums
                    logger.warning(
                        "flops_per_batch is multiplied by the number of "
                        "internal models since self.mirrored_models=True; "
                        "if this is not what you want, please customize "
                        "the count hook")
                self.ctx.monitor.track_avg_flops(flops_one_batch,
                                                 ctx.batch_size)
            except Exception:
                logger.warning(
                    "The current flop count implementation only covers the "
                    "general case where ctx.model takes (x, edge_index) "
                    "extracted from batch = ctx.data_batch as input. "
                    "Please check the forward format or implement your own "
                    "flop_count function")
                self.ctx.monitor.flops_per_sample = -1  # warn only at the
                # first failure

        # by default, we assume the data has the same input shape,
        # thus simply multiply the flops to avoid redundant forward passes
        self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
            ctx.batch_size
    def save_prediction(self, path, client_id, task_type):
        y_inds, y_probs = self.ctx.test_y_inds, self.ctx.test_y_prob
        os.makedirs(path, exist_ok=True)

        # TODO: make this more general; for now it is hard-coded for cikmcup
        y_preds = np.argmax(y_probs, axis=-1) if 'classification' in \
            task_type.lower() else y_probs

        if len(y_inds) != len(y_preds):
            raise ValueError(
                f'The number of predictions {len(y_preds)} is not equal to '
                f'the number of samples {len(y_inds)}.')

        with open(os.path.join(path, 'prediction.csv'), 'a') as file:
            for y_ind, y_pred in zip(y_inds, y_preds):
                if 'classification' in task_type.lower():
                    line = [client_id, y_ind] + [y_pred]
                else:
                    line = [client_id, y_ind] + list(y_pred)
                file.write(','.join([str(_) for _ in line]) + '\n')
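
# Output format of `save_prediction` (derived from the method above): one CSV
# row is appended to `<path>/prediction.csv` per test sample. The values below
# are illustrative only, e.g.
#   classification: "3,17,2"        -> client_id, sample index, predicted class
#   regression:     "3,17,0.4,1.2"  -> client_id, sample index, per-dim outputs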
def call_graph_level_trainer(trainer_type):
    if trainer_type == 'graphminibatch_trainer':
        trainer_builder = GraphMiniBatchTrainer
        return trainer_builder


register_trainer('graphminibatch_trainer', call_graph_level_trainer)
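
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module). Importing this
# module registers the builder under the name 'graphminibatch_trainer', so it
# can be resolved through FederatedScope's trainer config (the exact config
# field, e.g. `cfg.trainer.type`, is an assumption and may differ across
# versions) or looked up directly as shown below.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Resolve the registered builder by name and confirm it is the class
    # defined above.
    builder = call_graph_level_trainer('graphminibatch_trainer')
    assert builder is GraphMiniBatchTrainer
    print(f'Resolved trainer builder: {builder.__name__}')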