import math

from federatedscope.core.auxiliaries.criterion_builder import get_criterion
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.auxiliaries.model_builder import get_trainable_para_names
from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer


class Context(dict):
"""Record and pass variables among different hook functions.
Arguments:
model (Module): training model
        data (dict): a dict containing the train/val/test dataset or dataloader
device: running device

    Record attributes:
- model (Module): the training model
        - data (dict): a dict containing the train/val/test dataset or dataloader
        - device (torch.device): the device on which the model runs
- criterion: specific loss function
- optimizer: specific optimizer
        - mode (list): a stack that maintains the current mode of the model
- data_batch: current batch data from train/test/val data loader
        - trainable_para_names (list): the names of the trainable parameters
          within ``ctx.model``
- train_data: training dataset
- train_loader: training dataloader
- num_train_data (int): the number of training samples within one epoch
- num_train_epoch (int): the number of total training epochs
- num_train_batch (int): the number of batches within one completed training epoch
- num_train_batch_last_epoch (int): the number of batches within the last epoch
- test_data: test data
- test_loader: test dataloader
- num_test_data (int): the number of test samples within one epoch
- num_test_epoch (int): the number of test epochs, default 1
- num_test_batch (int): the number of batches within one completed test epoch
- val_data: val data
- val_loader: val dataloader
- num_val_data (int): the number of val samples within one epoch
- num_val_epoch (int): the number of val epochs, default 1
- num_val_batch (int): the number of batches within one completed val epoch

    Statistical variables:
- loss_batch (float): loss of the current data_batch, shared by train/test/val
- loss_regular (float): loss of the regularizer
- loss_task (float): the sum of loss_batch and loss_regular
- loss_total_batch_train (float): accumulated batch loss during training
- loss_total_regular_train (float): accumulated regular loss during training
        - num_samples_train (int): the accumulated number of training samples
          involved so far
- loss_total_test (float): accumulated batch loss during test
        - num_samples_test (int): the accumulated number of test samples involved
- loss_total_val (float): accumulated batch loss during val
        - num_samples_val (int): the accumulated number of val samples involved
- eval_metrics (dict): evaluation results
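
    Example:
        A minimal usage sketch (``my_model``, ``my_cfg`` and ``my_data`` are
        hypothetical placeholders, not part of this module)::

            ctx = Context(my_model, my_cfg, data=my_data, device='cpu')
            ctx.append_mode('train')         # the model enters train state
            ctx.track_used_dataset('train')  # ctx.cur_data_split == 'train'
            # ... run hooks that read and write ctx ...
            ctx.reset_used_dataset()
            ctx.pop_mode()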
"""
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, item):
try:
return self[item]
except KeyError:
            raise AttributeError("Attribute {} is not found".format(item))

    def __init__(self,
model,
cfg,
data=None,
device=None,
init_dict=None,
init_attr=True):
if init_dict is None:
super(Context, self).__init__()
else:
super(Context, self).__init__(init_dict)
self.cfg = cfg
self.model = model
self.data = data
self.device = device
self.cur_mode = None
        self.cur_data_split = None

        if init_attr:
# setup static variables for training/evaluation
            self.setup_vars()

    def setup_vars(self):
if self.cfg.backend == 'torch':
self.trainable_para_names = get_trainable_para_names(self.model)
self.criterion = get_criterion(self.cfg.criterion.type,
self.device)
self.regularizer = get_regularizer(self.cfg.regularizer.type)
self.optimizer = get_optimizer(
self.cfg.optimizer.type,
self.model,
self.cfg.optimizer.lr,
weight_decay=self.cfg.optimizer.weight_decay,
momentum=self.cfg.optimizer.momentum)
self.grad_clip = self.cfg.optimizer.grad_clip
elif self.cfg.backend == 'tensorflow':
self.trainable_para_names = self.model.trainable_variables()
self.criterion = None
self.regularizer = None
self.optimizer = None
            self.grad_clip = None

        self.mode = list()
        self.cur_data_splits_used_by_routine = list()

        # Process training data
        if self.train_data is not None or self.train_loader is not None:
# Calculate the number of update steps during training given the local_update_steps
num_train_batch, num_train_batch_last_epoch, num_train_epoch, num_total_train_batch = self.pre_calculate_batch_epoch_num(
self.cfg.federate.local_update_steps)
self.num_train_epoch = num_train_epoch
self.num_train_batch = num_train_batch
self.num_train_batch_last_epoch = num_train_batch_last_epoch
            self.num_total_train_batch = num_total_train_batch

        # Process evaluation data
        for mode in ["val", "test"]:
            setattr(self, "num_{}_epoch".format(mode), 1)
            if self.get("{}_data".format(mode)) is not None or self.get(
                    "{}_loader".format(mode)) is not None:
                # Number of batches per epoch: floor division by batch_size,
                # plus one partial batch when drop_last is False and the
                # sample count does not divide evenly
                setattr(
                    self, "num_{}_batch".format(mode),
                    getattr(self, "num_{}_data".format(mode)) //
                    self.cfg.data.batch_size +
                    int(not self.cfg.data.drop_last and bool(
                        getattr(self, "num_{}_data".format(mode)) %
                        self.cfg.data.batch_size)))

def pre_calculate_batch_epoch_num(self, local_update_steps):
num_train_batch = self.num_train_data // self.cfg.data.batch_size + int(
not self.cfg.data.drop_last
and bool(self.num_train_data % self.cfg.data.batch_size))
if self.cfg.federate.batch_or_epoch == "epoch":
num_train_epoch = local_update_steps
num_train_batch_last_epoch = num_train_batch
num_total_train_batch = local_update_steps * num_train_batch
        else:
            num_train_epoch = math.ceil(local_update_steps / num_train_batch)
            # When local_update_steps is an exact multiple of num_train_batch,
            # the modulo would leave 0 batches for the last epoch; fall back
            # to a full epoch in that case
            num_train_batch_last_epoch = local_update_steps % num_train_batch \
                or num_train_batch
            num_total_train_batch = local_update_steps
return num_train_batch, num_train_batch_last_epoch, num_train_epoch, num_total_train_batch
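
    # Worked examples for pre_calculate_batch_epoch_num (assumed numbers,
    # for illustration only): with num_train_data = 100, batch_size = 32 and
    # drop_last = False, num_train_batch = 100 // 32 + 1 = 4.
    #   - batch_or_epoch == "epoch", local_update_steps = 3: 3 epochs of
    #     4 batches each, i.e. 12 batches in total.
    #   - batch_or_epoch == "batch", local_update_steps = 10:
    #     ceil(10 / 4) = 3 epochs, the last of which runs 10 % 4 = 2 batches,
    #     i.e. 10 batches in total.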
def append_mode(self, mode):
self.mode.append(mode)
self.cur_mode = self.mode[-1]
        self.change_mode(self.cur_mode)

def pop_mode(self):
self.mode.pop()
self.cur_mode = self.mode[-1] if len(self.mode) != 0 else None
if len(self.mode) != 0:
self.change_mode(self.cur_mode)
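
    # The mode stack above allows nested routines, e.g. running evaluation in
    # the middle of a training routine (a sketch, not a prescribed API):
    #   ctx.append_mode('train')   # model.train()
    #   ctx.append_mode('test')    # model.eval() for the nested evaluation
    #   ctx.pop_mode()             # back to 'train', model.train() again
    #   ctx.pop_mode()             # stack empty, ctx.cur_mode is None
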
def change_mode(self, mode):
        # Switch the model state: model.train() for the 'train' mode and
        # model.eval() for any evaluation mode (torch backend only)
        if self.cfg.backend == 'torch':
            getattr(self.model, mode if mode == 'train' else 'eval')()
else:
            pass

    def track_used_dataset(self, dataset):
        # Stack-style tracking enables mixed usage, such as evaluating on the
        # train dataset from within an evaluation routine
self.cur_data_splits_used_by_routine.append(dataset)
        self.cur_data_split = self.cur_data_splits_used_by_routine[-1]

    def reset_used_dataset(self):
self.cur_data_splits_used_by_routine.pop()
self.cur_data_split = self.cur_data_splits_used_by_routine[-1] if \
len(self.cur_data_splits_used_by_routine) != 0 else None
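
# The data-split stack mirrors the mode stack: track_used_dataset() pushes the
# split a routine reads from and reset_used_dataset() pops it, so evaluating on
# the train split from within another routine restores the previous split on
# exit (a sketch, not a prescribed call sequence):
#   ctx.track_used_dataset('test')    # ctx.cur_data_split == 'test'
#   ctx.track_used_dataset('train')   # temporarily read the train split
#   ctx.reset_used_dataset()          # back to 'test'
#   ctx.reset_used_dataset()          # stack empty, ctx.cur_data_split is None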