my_model_trainer.py
import torch
from torch import nn

# Import ModelTrainer from the installed fedml_core package, falling back to
# the vendored FedML submodule layout when running from the repo root.
try:
    from fedml_core.trainer.model_trainer import ModelTrainer
except ImportError:
    from FedML.fedml_core.trainer.model_trainer import ModelTrainer


class MyModelTrainer(ModelTrainer):
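    """Client-side trainer for FedML: the framework exchanges model weights
    through get_model_params/set_model_params and runs local training and
    evaluation through train/test."""
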
    def get_model_params(self):
        # Move the weights to CPU so the state dict can be serialized and
        # sent to the server regardless of the training device.
        return self.model.cpu().state_dict()

    def set_model_params(self, model_parameters):
        self.model.load_state_dict(model_parameters)

    def train(self, train_data, device, args):
        model = self.model

        model.to(device)
        model.train()

        # train and update
        criterion = nn.CrossEntropyLoss().to(device)
        if args.client_optimizer == "sgd":
            optimizer = torch.optim.SGD(self.model.parameters(), lr=args.lr)
        else:
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
                                         lr=args.lr, weight_decay=args.wd, amsgrad=True)

        epoch_loss = []
        for epoch in range(args.epochs):
            batch_loss = []
            for batch_idx, (x, labels) in enumerate(train_data):
                x, labels = x.to(device), labels.to(device)
                # logging.info("x.size = " + str(x.size()))
                # logging.info("labels.size = " + str(labels.size()))
                optimizer.zero_grad()
                log_probs = model(x)
                loss = criterion(log_probs, labels)
                loss.backward()
                # Uncomment to clip gradients if the loss turns NaN:
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                optimizer.step()
                # logging.info('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                #     epoch, batch_idx + 1, len(train_data),
                #     100. * (batch_idx + 1) / len(train_data), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
            # logging.info('Client Index = {}\tEpoch: {}\tLoss: {:.6f}'.format(
            #     self.id, epoch, sum(epoch_loss) / len(epoch_loss)))

    def test(self, test_data, device, args):
        model = self.model

        model.to(device)
        model.eval()

        metrics = {
            'test_correct': 0,
            'test_loss': 0,
            'test_precision': 0,
            'test_recall': 0,
            'test_total': 0
        }

        '''
        stackoverflow_lr is a multi-label classification task, so it is scored
        with a binary cross-entropy loss summed over the labels. See the
        following links for a detailed explanation of cross-entropy and the
        corresponding implementation in the TFF research repository:
        https://towardsdatascience.com/cross-entropy-for-classification-d98e7f974451
        https://github.com/google-research/federated/blob/49a43456aa5eaee3e1749855eed89c0087983541/optimization/stackoverflow_lr/federated_stackoverflow_lr.py#L131
        '''
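        # Illustrative shapes for the multi-label branch below (assumed, not
        # taken from the dataset code): pred and target are both
        # (batch_size, num_tags); pred holds probabilities in [0, 1], as
        # BCELoss requires, and target holds {0, 1} tag indicators.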
if args.dataset == "stackoverflow_lr":
criterion = nn.BCELoss(reduction='sum').to(device)
else:
criterion = nn.CrossEntropyLoss().to(device)
with torch.no_grad():
for batch_idx, (x, target) in enumerate(test_data):
x = x.to(device)
target = target.to(device)
pred = model(x)
loss = criterion(pred, target)
                if args.dataset == "stackoverflow_lr":
                    predicted = (pred > .5).int()
                    # Exact-match accuracy: a sample counts as correct only
                    # if every one of its labels is predicted correctly.
                    correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum()
                    # target and predicted are {0, 1}, so their product marks
                    # the true positives of each sample.
                    true_positive = ((target * predicted) > .1).int().sum(axis=-1)
                    precision = true_positive / (predicted.sum(axis=-1) + 1e-13)
                    recall = true_positive / (target.sum(axis=-1) + 1e-13)
                    metrics['test_precision'] += precision.sum().item()
                    metrics['test_recall'] += recall.sum().item()
                else:
                    _, predicted = torch.max(pred, 1)
                    correct = predicted.eq(target).sum()
                metrics['test_correct'] += correct.item()
                metrics['test_loss'] += loss.item() * target.size(0)
                if len(target.size()) == 1:  # single-label classification
                    metrics['test_total'] += target.size(0)
                elif len(target.size()) == 2:  # for next-word prediction tasks
                    metrics['test_total'] += target.size(0) * target.size(1)
return metrics

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        # No centralized evaluation here; returning False tells FedML to fall
        # back to per-client test().
        return False
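

# --- Minimal usage sketch (illustrative; not part of the FedML API) ---
# Assumes ModelTrainer.__init__ stores the given model on self.model, and
# stands in for FedML's parsed arguments with a SimpleNamespace carrying only
# the fields this trainer reads. The toy model and data are hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch.utils.data import DataLoader, TensorDataset

    # Toy single-label task: 16-dim features, 10 classes.
    features = torch.randn(64, 16)
    targets = torch.randint(0, 10, (64,))
    loader = DataLoader(TensorDataset(features, targets), batch_size=8)

    args = SimpleNamespace(client_optimizer="sgd", lr=0.01, wd=0.0,
                           epochs=1, dataset="synthetic")
    trainer = MyModelTrainer(nn.Linear(16, 10))
    trainer.train(loader, torch.device("cpu"), args)
    print(trainer.test(loader, torch.device("cpu"), args))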