This repository has been archived by the owner on May 1, 2023. It is now read-only.
/
policy.py
executable file
·327 lines (272 loc) · 15.5 KB
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Policies for scheduling by a CompressionScheduler instance.
- PruningPolicy: pruning policy
- RegularizationPolicy: regularization scheduling
- LRPolicy: learning-rate decay scheduling
- QuantizationPolicy: quantization scheduling
"""
import torch
import torch.nn as nn
import torch.optim.lr_scheduler
from collections import namedtuple, OrderedDict
import logging
import distiller
__all__ = ['PruningPolicy', 'RegularizationPolicy', 'QuantizationPolicy', 'LRPolicy', 'ScheduledTrainingPolicy',
           'PolicyLoss', 'LossComponent']
# Root logger; the application is expected to configure its handlers.
msglogger = logging.getLogger()
# PolicyLoss is the optional return value of `before_backward_pass` callbacks:
# 'overall_loss' replaces the incoming loss, and 'loss_components' lists the
# individual LossComponent terms the policy added to it.
PolicyLoss = namedtuple('PolicyLoss', ['overall_loss', 'loss_components'])
# LossComponent is one named loss term ('name', 'value') contributed by a policy.
LossComponent = namedtuple('LossComponent', ['name', 'value'])
class ScheduledTrainingPolicy(object):
    """Abstract parent of every policy a CompressionScheduler can schedule.

    The scheduler invokes the callbacks below at the matching points of the
    training loop.  Every default implementation is a no-op, so subclasses
    override only the hooks they care about.
    """
    def __init__(self, classes=None, layers=None):
        # Optional filters limiting the policy's scope to specific module
        # classes and/or specific layer names.
        self.classes = classes
        self.layers = layers

    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
        """Invoked when a new epoch is about to begin."""

    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch,
                           zeros_mask_dict, meta, optimizer=None):
        """Invoked right before the forward pass of a new mini-batch."""

    def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss, zeros_mask_dict,
                             optimizer=None):
        """Invoked between the forward pass and the backward pass.

        The 'loss' argument must not be modified by the callback.  A policy may
        instead return a PolicyLoss instance, which is then used in place of
        'loss'.  The PolicyLoss 'loss_components' field should list only the
        new, individual loss terms the policy contributed to 'overall_loss' —
        not the incoming 'loss' argument itself.
        """

    def before_parameter_optimization(self, model, epoch, minibatch_id, minibatches_per_epoch,
                                      zeros_mask_dict, meta, optimizer):
        """Invoked after the backward pass, just before the optimizer updates the weights."""

    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
        """Invoked when the mini-batch's training pass has ended."""

    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
        """Invoked when the current epoch has ended."""
class PruningPolicy(ScheduledTrainingPolicy):
    """Base class for pruning policies.

    Drives a Pruner: recomputes pruning masks at epoch granularity (and
    optionally every `mini_batch_pruning_frequency` mini-batches), applies the
    masks around the forward/backward passes, and optionally folds successor
    BatchNorm layers into Conv2d weights before ranking them.
    """
    def __init__(self, pruner, pruner_args, classes=None, layers=None):
        """
        Arguments:
        mask_on_forward_only: controls what we do after the weights are updated by the backward pass.
        In issue #53 (https://github.com/IntelLabs/distiller/issues/53) we explain why in some
        cases masked weights will be updated to a non-zero value, even if their gradients are masked
        (e.g. when using SGD with momentum). Therefore, to circumvent this weights-update performed by
        the backward pass, we usually mask the weights again - right after the backward pass. To
        disable this masking set:
        pruner_args['mask_on_forward_only'] = False

        use_double_copies: when set to `True`, two sets of weights are used. In the forward-pass we use
        masked weights to compute the loss, but in the backward-pass we update the unmasked weights (using
        gradients computed from the masked-weights loss).

        mini_batch_pruning_frequency: this controls pruning scheduling at the mini-batch granularity. Every
        mini_batch_pruning_frequency training steps (i.e. mini_batches) we perform pruning. This provides more
        fine-grained control over pruning than that provided by CompressionScheduler (epoch granularity).
        When setting 'mini_batch_pruning_frequency' to a value other than zero, make sure to configure the policy's
        schedule to once-every-epoch.

        fold_batchnorm: when set to `True`, the weights of BatchNorm modules are folded into the the weights of
        Conv-2D modules (if Conv2D->BN edges exist in the model graph). Each weights filter is attenuated using
        a different pair of (gamma, beta) coefficients, so `fold_batchnorm` is relevant for fine-grained and
        filter-ranking pruning methods. We attenuate using the running values of the mean and variance, as is
        done in quantization.
        This control argument is only supported for Conv-2D modules (i.e. other convolution operation variants and
        Linear operations are not supported).
        """
        super(PruningPolicy, self).__init__(classes, layers)
        self.pruner = pruner
        # Copy external policy configuration, if available
        if pruner_args is None:
            pruner_args = {}
        self.levels = pruner_args.get('levels', None)
        self.keep_mask = pruner_args.get('keep_mask', False)
        self.mini_batch_pruning_frequency = pruner_args.get('mini_batch_pruning_frequency', 0)
        self.mask_on_forward_only = pruner_args.get('mask_on_forward_only', False)
        self.mask_gradients = pruner_args.get('mask_gradients', False)
        # Masking gradients only makes sense when the weights themselves are
        # masked exclusively on the forward pass.
        if self.mask_gradients and not self.mask_on_forward_only:
            raise ValueError("mask_gradients and (not mask_on_forward_only) are mutually exclusive")
        self.backward_hook_handle = None  # The backward-callback handle
        self.use_double_copies = pruner_args.get('use_double_copies', False)
        self.discard_masks_at_minibatch_end = pruner_args.get('discard_masks_at_minibatch_end', False)
        self.skip_first_minibatch = pruner_args.get('skip_first_minibatch', False)
        self.fold_bn = pruner_args.get('fold_batchnorm', False)
        # These are required for BN-folding. We cache them to improve performance
        self.named_modules = None  # OrderedDict of the model's modules, built in on_epoch_begin
        self.sg = None             # distiller.SummaryGraph, built in on_epoch_begin
        # Initialize state
        self.is_last_epoch = False
        self.is_initialized = False  # True once the maskers have been configured

    @staticmethod
    def _fold_batchnorm(model, param_name, param, named_modules, sg):
        """Return `param` with its successor BatchNorm folded in (Conv2d only).

        If `param_name` does not belong to an nn.Conv2d module, or the module
        has no BatchNormalization successor in the summary-graph `sg`, the
        parameter is returned unchanged.
        """
        def _get_all_parameters(param_module, bn_module):
            # When BN is not affine there are no learned (gamma, beta);
            # substitute the identity transform.
            w, b, gamma, beta = param_module.weight, param_module.bias, bn_module.weight, bn_module.bias
            if not bn_module.affine:
                gamma = 1.
                beta = 0.
            return w, b, gamma, beta

        def get_bn_folded_weights(conv_module, bn_module):
            """Compute the weights of `conv_module` after folding successor BN layer.
            In inference, DL frameworks and graph-compilers fold the batch normalization into
            the weights as defined by equations 20 and 21 of https://arxiv.org/pdf/1806.08342.pdf
            :param conv_module: nn.Conv2d module
            :param bn_module: nn.BatchNorm2d module which succeeds `conv_module`
            :return: Folded weights
            """
            w, b, gamma, beta = _get_all_parameters(conv_module, bn_module)
            with torch.no_grad():
                # Use the running statistics (not batch statistics), as is
                # done when folding BN for inference/quantization.
                sigma_running = torch.sqrt(bn_module.running_var + bn_module.eps)
                w_corrected = w * (gamma / sigma_running).view(-1, 1, 1, 1)
            return w_corrected

        layer_name = distiller.utils.param_name_2_module_name(param_name)
        if not isinstance(named_modules[layer_name], nn.Conv2d):
            # BN-folding is only supported for Conv2d modules
            return param
        bn_layers = sg.successors_f(layer_name, ['BatchNormalization'])
        if bn_layers:
            # A Conv2d is expected to have at most one BN successor
            assert len(bn_layers) == 1
            bn_module = named_modules[bn_layers[0]]
            conv_module = named_modules[layer_name]
            param = get_bn_folded_weights(conv_module, bn_module)
        return param

    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
        """Prune at the start of the epoch; lazily configure maskers on first call."""
        msglogger.debug("Pruner {} is about to prune".format(self.pruner.name))
        self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
        if self.levels is not None:
            # Override the pruner's configured pruning levels with the policy's
            self.pruner.levels = self.levels
        meta['model'] = model
        # Snapshot the flag: we flip self.is_initialized inside the loop, but all
        # parameters of this first invocation must take the initialization path.
        is_initialized = self.is_initialized
        if self.fold_bn:
            # Cache this information (required for BN-folding) to improve performance
            self.named_modules = OrderedDict(model.named_modules())
            dummy_input = torch.randn(model.input_shape)
            self.sg = distiller.SummaryGraph(model, dummy_input)
        for param_name, param in model.named_parameters():
            if self.fold_bn:
                param = self._fold_batchnorm(model, param_name, param, self.named_modules, self.sg)
            if not is_initialized:
                # Initialize the maskers
                masker = zeros_mask_dict[param_name]
                masker.use_double_copies = self.use_double_copies
                masker.mask_on_forward_only = self.mask_on_forward_only
                # register for the backward hook of the parameters
                if self.mask_gradients:
                    masker.backward_hook_handle = param.register_hook(masker.mask_gradient)
                self.is_initialized = True
                if not self.skip_first_minibatch:
                    self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
            else:
                self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)

    def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch,
                           zeros_mask_dict, meta, optimizer=None):
        """Optionally recompute masks at mini-batch granularity, then apply them."""
        set_masks = False
        global_mini_batch_id = epoch * minibatches_per_epoch + minibatch_id
        if ((minibatch_id > 0) and
            (self.mini_batch_pruning_frequency != 0) and
            (global_mini_batch_id % self.mini_batch_pruning_frequency == 0)):
            # This is _not_ the first mini-batch of a new epoch (performed in on_epoch_begin)
            # and a pruning step is scheduled
            set_masks = True
        if self.skip_first_minibatch and global_mini_batch_id == 1:
            # Because we skipped the first mini-batch of the first epoch (global_mini_batch_id == 0)
            set_masks = True
        for param_name, param in model.named_parameters():
            if set_masks:
                if self.fold_bn:
                    param = self._fold_batchnorm(model, param_name, param, self.named_modules, self.sg)
                self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
            # Masks are (re)applied before every forward pass
            zeros_mask_dict[param_name].apply_mask(param)

    def before_parameter_optimization(self, model, epoch, minibatch_id, minibatches_per_epoch,
                                      zeros_mask_dict, meta, optimizer):
        """Restore unmasked weights (double-copies mode) before the optimizer step."""
        for param_name, param in model.named_parameters():
            zeros_mask_dict[param_name].revert_weights(param)

    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
        """Optionally drop all masks when the mini-batch ends."""
        if self.discard_masks_at_minibatch_end:
            for param_name, param in model.named_parameters():
                zeros_mask_dict[param_name].mask = None

    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
        """The current epoch has ended.

        On the last epoch: optionally make the masks permanent (`keep_mask`),
        and remove any gradient-masking backward hooks we registered.
        """
        if self.is_last_epoch:
            for param_name, param in model.named_parameters():
                masker = zeros_mask_dict[param_name]
                if self.keep_mask:
                    # Make the mask permanent: stop double-copying / forward-only
                    # masking and apply the mask to the tensor itself.
                    masker.use_double_copies = False
                    masker.mask_on_forward_only = False
                    masker.mask_tensor(param)
                if masker.backward_hook_handle is not None:
                    masker.backward_hook_handle.remove()
                    masker.backward_hook_handle = None
class RegularizationPolicy(ScheduledTrainingPolicy):
    """Schedules a regularizer: adds its penalty to the training loss and,
    optionally, turns the resulting regularization mask permanent at the end
    of the last epoch.
    """
    def __init__(self, regularizer, keep_mask=False):
        super(RegularizationPolicy, self).__init__()
        self.regularizer = regularizer
        self.keep_mask = keep_mask
        self.is_last_epoch = False

    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
        # Track whether this epoch is the final one of the schedule.
        self.is_last_epoch = (meta['ending_epoch'] - 1) == meta['current_epoch']

    def before_backward_pass(self, model, epoch, minibatch_id, minibatches_per_epoch, loss,
                             zeros_mask_dict, optimizer=None):
        """Return a PolicyLoss adding the regularization penalty to `loss`."""
        reg_loss = torch.tensor(0, dtype=torch.float, device=loss.device)
        # The regularizer accumulates each parameter's penalty into reg_loss
        # (presumably in-place — its contract is defined by the regularizer).
        for name, param in model.named_parameters():
            self.regularizer.loss(param, name, reg_loss, zeros_mask_dict)
        component = LossComponent(self.regularizer.__class__.__name__ + '_loss', reg_loss)
        return PolicyLoss(loss + reg_loss, [component])

    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
        """Threshold the parameters; possibly make the regularization mask permanent."""
        if self.regularizer.threshold_criteria is None:
            return
        # The mask becomes permanent only on the last mini-batch of the last
        # epoch, and only when the scheduler asked us to keep it.
        last_minibatch = (minibatches_per_epoch - 1 == minibatch_id)
        make_permanent = last_minibatch and self.is_last_epoch and self.keep_mask
        if make_permanent:
            # If this is the last mini_batch in the last epoch, and the scheduler wants to
            # keep the regularization mask, then now is the time ;-)
            msglogger.info("RegularizationPolicy is keeping the regularization mask")
        for name, param in model.named_parameters():
            self.regularizer.threshold(param, name, zeros_mask_dict)
            if make_permanent:
                # Clearing this flag prevents the mask from being discarded later
                zeros_mask_dict[name].is_regularization_mask = False
            zeros_mask_dict[name].apply_mask(param)
class LRPolicy(ScheduledTrainingPolicy):
    """Steps a PyTorch learning-rate scheduler at the end of every epoch."""
    def __init__(self, lr_scheduler):
        super(LRPolicy, self).__init__()
        self.lr_scheduler = lr_scheduler

    def on_epoch_end(self, model, zeros_mask_dict, meta, **kwargs):
        next_epoch = meta['current_epoch'] + 1
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Note: ReduceLROnPlateau doesn't inherit from _LRScheduler; it
            # needs the monitored metric (selected by its 'mode') to decide
            # whether to reduce the learning rate.
            monitored = kwargs['metrics'][self.lr_scheduler.mode]
            self.lr_scheduler.step(monitored, epoch=next_epoch)
        else:
            self.lr_scheduler.step(epoch=next_epoch)
class QuantizationPolicy(ScheduledTrainingPolicy):
    """Applies a Quantizer to the model and re-quantizes the parameters after
    every optimizer update, so the model stays quantized throughout training.
    """
    def __init__(self, quantizer):
        super(QuantizationPolicy, self).__init__()
        self.quantizer = quantizer
        # Transform the model once up-front, then quantize its initial parameters
        quantizer.prepare_model()
        quantizer.quantize_params()

    def on_minibatch_end(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer):
        # The optimizer's update de-quantizes the parameters; quantizing again
        # here ensures they are quantized at training completion (and at
        # validation time).
        self.quantizer.quantize_params()