# train_distilled_image.py (forked from ssnl/dataset-distillation)
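"""Trainer for dataset distillation: optimizes small synthetic data tensors
and per-step learning rates (class labels are fixed per class) so that a few
unrolled gradient-descent steps on them train the given models to perform
well on real data, by backpropagating the real-batch loss through the
unrolled steps. (This overview docstring is an editorial summary of the
code below.)
"""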
import logging
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from basics import task_loss, final_objective_loss, evaluate_steps
from utils.distributed import broadcast_coalesced, all_reduce_coalesced
from utils.io import save_results


def permute_list(list):
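    """Return a copy of `list` in a uniformly random order."""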
    indices = np.random.permutation(len(list))
    return [list[i] for i in indices]


class Trainer(object):
    def __init__(self, state, models):
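        """Store the state and the target models, then create the distilled
        data, labels, and learning rates via init_data_optim()."""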
        self.state = state
        self.models = models
        self.num_data_steps = state.distill_steps  # how much data we have
        self.T = state.distill_steps * state.distill_epochs  # total number of unrolled GD steps
        self.num_per_step = state.num_classes * state.distilled_images_per_class_per_step
        assert state.distill_lr >= 0, 'distill_lr must be >= 0'
        self.init_data_optim()

    def init_data_optim(self):
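        """Create the trainable tensors and their optimizer.

        self.params holds everything that is optimized: one distilled-data
        tensor per distill step plus the raw (pre-softplus) per-step
        learning rates.
        """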
        self.params = []
        state = self.state
        optim_lr = state.lr

        # labels
        self.labels = []
        distill_label = torch.arange(state.num_classes, dtype=torch.long, device=state.device) \
                             .repeat(state.distilled_images_per_class_per_step, 1)  # [[0, 1, 2, ...], [0, 1, 2, ...]]
        distill_label = distill_label.t().reshape(-1)  # [0, 0, ..., 1, 1, ...]
        for _ in range(self.num_data_steps):
            self.labels.append(distill_label)
        self.all_labels = torch.cat(self.labels)

        # data
        self.data = []
        for _ in range(self.num_data_steps):
            # distill_data = torch.randn(self.num_per_step, state.nc, state.input_size, state.input_size,
            #                            device=state.device, requires_grad=True)
            distill_data = torch.randn(self.num_per_step, state.input_size, device=state.device, requires_grad=True)
            self.data.append(distill_data)
            self.params.append(distill_data)

        # lr

        # undo the softplus + threshold
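        # (editorial note) softplus(x) = log(1 + e^x), so its inverse is
        # log(e^y - 1) = log(expm1(y)), computed below via expm1_().log_();
        # this makes softplus(raw_distill_lrs) start exactly at distill_lr.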
        raw_init_distill_lr = torch.tensor(state.distill_lr, device=state.device)
        raw_init_distill_lr = raw_init_distill_lr.repeat(self.T, 1)
        self.raw_distill_lrs = raw_init_distill_lr.expm1_().log_().requires_grad_()
        self.params.append(self.raw_distill_lrs)

        assert len(self.params) > 0, "must have at least 1 parameter"

        # now all the params are in self.params, sync if using distributed
        if state.distributed:
            broadcast_coalesced(self.params)
            logging.info("parameters broadcast done!")

        self.optimizer = optim.Adam(self.params, lr=state.lr, betas=(0.5, 0.999))
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=state.decay_epochs,
                                                   gamma=state.decay_factor)
        for p in self.params:
            p.grad = torch.zeros_like(p)

    def get_steps(self):
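        """Build the T = distill_steps * distill_epochs unrolled steps as
        (data, label, lr) tuples; softplus keeps each trained lr positive."""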
        data_label_iterable = (x for _ in range(self.state.distill_epochs) for x in zip(self.data, self.labels))
        lrs = F.softplus(self.raw_distill_lrs).unbind()

        steps = []
        for (data, label), lr in zip(data_label_iterable, lrs):
            steps.append((data, label, lr))

        return steps

    def forward(self, model, rdata, rlabel, steps):
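        """Run the unrolled inner loop: one GD step per (data, label, lr)
        tuple, with create_graph=True so the step gradients stay
        differentiable, then return the final loss on the real batch
        (rdata, rlabel) plus everything backward() needs."""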
        state = self.state

        # forward
        model.train()
        w = model.get_param()
        params = [w]
        gws = []

        for step_i, (data, label, lr) in enumerate(steps):
            with torch.enable_grad():
                output = model.forward_with_param(data, w)
                loss = task_loss(state, output, label)
            gw, = torch.autograd.grad(loss, w, lr.squeeze(), create_graph=True)

            with torch.no_grad():
                new_w = w.sub(gw).requires_grad_()
                params.append(new_w)
                gws.append(gw)
                w = new_w

        # final L
        model.eval()
        output = model.forward_with_param(rdata, params[-1])
        ll = final_objective_loss(state, output, rlabel)
        return ll, (ll, params, gws)

    def backward(self, model, rdata, rlabel, steps, saved_for_backward):
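        """Walk the unrolled steps in reverse, computing one Hessian-vector
        product per step to obtain the gradients of the final loss w.r.t.
        each step's distilled data and lr."""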
        l, params, gws = saved_for_backward
        state = self.state

        datas = []
        gdatas = []
        lrs = []
        glrs = []

        dw, = torch.autograd.grad(l, (params[-1],))

        # backward
        model.train()
        # Notation:
        #   math:   \grad is \nabla
        #   symbol: d* means the gradient of final L w.r.t. *
        #           dw is \d L / \d w
        #           dgw is \d L / \d (\grad_w_t L_t)
        # We fold lr as part of the input to the step-wise loss
        #
        #   gw_t    = \grad_w_t L_t        (1)
        #   w_{t+1} = w_t - gw_t           (2)
        #
        # Invariants at beginning of each iteration:
        #   ws are BEFORE applying gradient descent in this step
        #   Gradient dw is w.r.t. the updated ws AFTER this step:
        #       dw = \d L / \d w_{t+1}
        # Note: params has T + 1 entries; zip truncates it, so each w below
        # is the PRE-GD weight of its step.
        for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):
            # hvp_in are the tensors we need gradients of final L w.r.t.:
            #   lr (if learning)
            #   data
            #   ws (PRE-GD) (needed for next step)
            #
            # source of gradients can be from:
            #   gw, the gradient in this step, whose gradients come from:
            #     the POST-GD updated ws
            hvp_in = [w]
            hvp_in.append(data)
            hvp_in.append(lr)
            dgw = dw.neg()  # gw is already weighted by lr, so simple negation
            hvp_grad = torch.autograd.grad(
                outputs=(gw,),
                inputs=hvp_in,
                grad_outputs=(dgw,)
            )
            # Update for next iteration, i.e., previous step
            with torch.no_grad():
                # Save the computed gdata and glrs
                datas.append(data)
                gdatas.append(hvp_grad[1])
                lrs.append(lr)
                glrs.append(hvp_grad[2])

                # Update dw: it becomes the gradient w.r.t. the updated w
                # for the previous step
                dw.add_(hvp_grad[0])

        return datas, gdatas, lrs, glrs

    def accumulate_grad(self, grad_infos):
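        """Accumulate per-model gradients into the .grad buffers: data
        gradients are added directly; lr gradients are pushed through the
        shared softplus with a single torch.autograd.backward call."""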
        bwd_out = []
        bwd_grad = []
        for datas, gdatas, lrs, glrs in grad_infos:
            bwd_out += list(lrs)
            bwd_grad += list(glrs)
            for d, g in zip(datas, gdatas):
                d.grad.add_(g)
        if len(bwd_out) > 0:
            torch.autograd.backward(bwd_out, bwd_grad)

    def save_results(self, steps=None, visualize=False, subfolder=''):
        with torch.no_grad():
            steps = steps or self.get_steps()
            save_results(self.state, steps, visualize=visualize, subfolder=subfolder)

    def __call__(self):
        return self.train()

    def prefetch_train_loader_iter(self):
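        """Yield (epoch, iteration, batch); the loader iterator for the next
        epoch is created two iterations early so its worker processes start
        prefetching before the current epoch finishes."""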
        state = self.state
        device = state.device
        train_iter = iter(state.train_loader)
        for epoch in range(state.epochs):
            niter = len(train_iter)
            prefetch_it = max(0, niter - 2)
            for it, val in enumerate(train_iter):
                # Prefetch (start workers) at the end of epoch BEFORE yielding
                if it == prefetch_it and epoch < state.epochs - 1:
                    train_iter = iter(state.train_loader)
                yield epoch, it, val

    def train(self):
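        """Outer optimization loop: for every real batch, run forward/backward
        through a sample of networks, average the gradients (all-reduced when
        distributed), then take an Adam step on the distilled parameters."""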
        state = self.state
        device = state.device
        train_loader = state.train_loader
        sample_n_nets = state.local_sample_n_nets
        grad_divisor = state.sample_n_nets  # i.e., global sample_n_nets
        ckpt_int = state.checkpoint_interval

        data_t0 = time.time()
        for epoch, it, (rdata, rlabel) in self.prefetch_train_loader_iter():
            data_t = time.time() - data_t0

            if it == 0:
                self.scheduler.step()

            if it == 0 and ((ckpt_int >= 0 and epoch % ckpt_int == 0) or epoch == 0):
                with torch.no_grad():
                    steps = self.get_steps()
                self.save_results(steps=steps, subfolder='checkpoints/epoch{:04d}'.format(epoch))
                evaluate_steps(state, steps, 'Begin of epoch {}'.format(epoch))

            do_log_this_iter = it == 0 or (state.log_interval >= 0 and it % state.log_interval == 0)

            self.optimizer.zero_grad()
            rdata, rlabel = rdata.to(device, non_blocking=True), rlabel.to(device, non_blocking=True)

            if sample_n_nets == state.local_n_nets:
                tmodels = self.models
            else:
                idxs = np.random.choice(state.local_n_nets, sample_n_nets, replace=False)
                tmodels = [self.models[i] for i in idxs]

            t0 = time.time()
            losses = []
            steps = self.get_steps()

            # activate everything needed to run on this process
            grad_infos = []
            for model in tmodels:
                if state.train_nets_type == 'unknown_init':
                    model.reset(state)

                l, saved = self.forward(model, rdata, rlabel, steps)
                losses.append(l.detach())
                grad_infos.append(self.backward(model, rdata, rlabel, steps, saved))
                del l, saved
            self.accumulate_grad(grad_infos)

            # all-reduce (if distributed) and average the gradients
            all_reduce_tensors = [p.grad for p in self.params]
            if do_log_this_iter:
                losses = torch.stack(losses, 0).sum()
                all_reduce_tensors.append(losses)
            if state.distributed:
                all_reduce_coalesced(all_reduce_tensors, grad_divisor)
            else:
                for t in all_reduce_tensors:
                    t.div_(grad_divisor)

            # opt step
            self.optimizer.step()
            t = time.time() - t0

            if do_log_this_iter:
                loss = losses.item()
                logging.info((
                    'Epoch: {:4d} [{:7d}/{:7d} ({:2.0f}%)]\tLoss: {:.4f}\t'
                    'Data Time: {:.2f}s\tTrain Time: {:.2f}s'
                ).format(
                    epoch, it * train_loader.batch_size, len(train_loader.dataset),
                    100. * it / len(train_loader), loss, data_t, t,
                ))
                if loss != loss:  # nan
                    raise RuntimeError('loss became NaN')

            del steps, grad_infos, losses, all_reduce_tensors
            data_t0 = time.time()

        with torch.no_grad():
            steps = self.get_steps()
        self.save_results(steps)
        return steps


def distill(state, models):
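    """Convenience entry point: build a Trainer and run the distillation,
    returning the final distilled (data, label, lr) steps."""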
    return Trainer(state, models).train()