-
Notifications
You must be signed in to change notification settings - Fork 200
/
trainer_builder.py
132 lines (121 loc) · 5.58 KB
/
trainer_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import logging
import importlib
import federatedscope.register as register
logger = logging.getLogger(__name__)
# Built-in trainers: maps the lower-cased value of ``config.trainer.type``
# to the class name to load (the defining module is resolved in
# ``get_trainer`` via ``importlib``). Types not listed here fall back to
# the user-registered trainers in ``register.trainer_dict``.
TRAINER_CLASS_DICT = {
    "cvtrainer": "CVTrainer",
    "nlptrainer": "NLPTrainer",
    "graphminibatch_trainer": "GraphMiniBatchTrainer",
    "linkfullbatch_trainer": "LinkFullBatchTrainer",
    "linkminibatch_trainer": "LinkMiniBatchTrainer",
    "nodefullbatch_trainer": "NodeFullBatchTrainer",
    "nodeminibatch_trainer": "NodeMiniBatchTrainer",
    "flitplustrainer": "FLITPlusTrainer",
    "flittrainer": "FLITTrainer",
    "fedvattrainer": "FedVATTrainer",
    "fedfocaltrainer": "FedFocalTrainer",
    "mftrainer": "MFTrainer",
}
def get_trainer(model=None,
                data=None,
                device=None,
                config=None,
                only_for_eval=False,
                is_attacker=False):
    """Build (and optionally wrap) a trainer according to the config.

    Args:
        model: the model to be trained/evaluated.
        data: the data fed to the trainer.
        device: the device to run on (e.g. ``'cpu'`` or a CUDA device).
        config: configuration node; ``config.trainer.type`` selects the
            trainer class, and other fields (``nbafl``, ``sgdmf``,
            ``federate.method``, ``fedprox``) select plug-in wrappers.
        only_for_eval: if True, the trainer is built for evaluation only.
        is_attacker: if True, wrap the trainer with attacker behavior.

    Returns:
        The instantiated (possibly wrapped) trainer, or ``None`` when
        ``config.trainer.type == 'none'``.

    Raises:
        ValueError: for an unknown backend or an unknown trainer type.
    """
    trainer_type = config.trainer.type.lower()

    if config.trainer.type == 'general':
        if config.backend == 'torch':
            from federatedscope.core.trainers import GeneralTorchTrainer
            trainer = GeneralTorchTrainer(model=model,
                                          data=data,
                                          device=device,
                                          config=config,
                                          only_for_eval=only_for_eval)
        elif config.backend == 'tensorflow':
            from federatedscope.core.trainers.tf_trainer import GeneralTFTrainer
            trainer = GeneralTFTrainer(model=model,
                                       data=data,
                                       device=device,
                                       config=config,
                                       only_for_eval=only_for_eval)
        else:
            raise ValueError(
                'Unknown backend {} for the general trainer'.format(
                    config.backend))
    elif config.trainer.type == 'none':
        return None
    elif trainer_type in TRAINER_CLASS_DICT:
        # Module that defines each built-in trainer class; the class name
        # itself comes from TRAINER_CLASS_DICT.
        module_path_dict = {
            'cvtrainer': "federatedscope.cv.trainer.trainer",
            'nlptrainer': "federatedscope.nlp.trainer.trainer",
            'graphminibatch_trainer': "federatedscope.gfl.trainer.graphtrainer",
            'linkfullbatch_trainer': "federatedscope.gfl.trainer.linktrainer",
            'linkminibatch_trainer': "federatedscope.gfl.trainer.linktrainer",
            'nodefullbatch_trainer': "federatedscope.gfl.trainer.nodetrainer",
            'nodeminibatch_trainer': "federatedscope.gfl.trainer.nodetrainer",
            'flitplustrainer': "federatedscope.gfl.flitplus.trainer",
            'flittrainer': "federatedscope.gfl.flitplus.trainer",
            'fedvattrainer': "federatedscope.gfl.flitplus.trainer",
            'fedfocaltrainer': "federatedscope.gfl.flitplus.trainer",
            'mftrainer': "federatedscope.mf.trainer.trainer",
        }
        dict_path = module_path_dict.get(trainer_type)
        if dict_path is None:
            # Keep ValueError semantics if the two tables ever diverge.
            raise ValueError(
                'No module path registered for trainer type {}'.format(
                    config.trainer.type))
        trainer_cls = getattr(importlib.import_module(name=dict_path),
                              TRAINER_CLASS_DICT[trainer_type])
        trainer = trainer_cls(model=model,
                              data=data,
                              device=device,
                              config=config,
                              only_for_eval=only_for_eval)
    else:
        # try to find user registered trainer
        trainer = None
        for func in register.trainer_dict.values():
            trainer_cls = func(config.trainer.type)
            if trainer_cls is not None:
                trainer = trainer_cls(model=model,
                                      data=data,
                                      device=device,
                                      config=config,
                                      only_for_eval=only_for_eval)
                # Use the first matching registration; without the break a
                # later registration could silently overwrite the trainer.
                break
        if trainer is None:
            raise ValueError('Trainer {} is not provided'.format(
                config.trainer.type))

    # differential privacy plug-in
    if config.nbafl.use:
        from federatedscope.core.trainers.trainer_nbafl import wrap_nbafl_trainer
        trainer = wrap_nbafl_trainer(trainer)
    if config.sgdmf.use:
        from federatedscope.mf.trainer.trainer_sgdmf import wrap_MFTrainer
        trainer = wrap_MFTrainer(trainer)

    # personalization plug-in
    if config.federate.method.lower() == "pfedme":
        from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_pFedMeTrainer(trainer)
    elif config.federate.method.lower() == "ditto":
        from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_DittoTrainer(trainer)
    elif config.federate.method.lower() == "fedem":
        from federatedscope.core.trainers.trainer_FedEM import FedEMTrainer
        # copy construct style: instance a (class A) -> instance b (class B)
        trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer,
                               base_trainer=trainer)

    # attacker plug-in
    if is_attacker:
        logger.info(
            '---------------- This client is an attacker --------------------')
        from federatedscope.attack.auxiliary.attack_trainer_builder import wrap_attacker_trainer
        trainer = wrap_attacker_trainer(trainer, config)

    # fed algorithm plug-in
    if config.fedprox.use:
        from federatedscope.core.trainers.trainer_fedprox import wrap_fedprox_trainer
        trainer = wrap_fedprox_trainer(trainer)

    return trainer