Skip to content

Commit

Permalink
fix quant loop
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Apr 14, 2023
1 parent 3d211ff commit eaa84f4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 30 deletions.
13 changes: 5 additions & 8 deletions configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,7 @@

# learning policy
param_scheduler = dict(
_delete_=True,
type='CosineAnnealingLR',
T_max=100,
by_epoch=True,
begin=0,
end=100)
_delete_=True, type='ConstantLR', factor=1.0, by_epoch=True)

model_wrapper_cfg = dict(
type='mmrazor.MMArchitectureQuantDDP',
Expand All @@ -58,7 +53,9 @@
train_cfg = dict(
_delete_=True,
type='mmrazor.LSQEpochBasedLoop',
max_epochs=100,
max_epochs=10,
val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
test_cfg = val_cfg
# test_cfg = val_cfg

default_hooks = dict(sync=dict(type='SyncBuffersHook'))
59 changes: 37 additions & 22 deletions mmrazor/engine/runner/quantization_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from torch.nn.intrinsic.qat import freeze_bn_stats
except ImportError:
from mmrazor.utils import get_placeholder

disable_observer = get_placeholder('torch>=1.13')
enable_fake_quant = get_placeholder('torch>=1.13')
enable_observer = get_placeholder('torch>=1.13')
freeze_bn_stats = get_placeholder('torch>=1.13')

from mmengine.dist import all_reduce_params
from torch.utils.data import DataLoader

from mmrazor.models import register_torch_fake_quants, register_torch_observers
Expand Down Expand Up @@ -69,7 +71,18 @@ def prepare_for_run_epoch(self):
"""Toggle the state of the observers and fake quantizers before qat
training."""
self.runner.model.apply(enable_fake_quant)
self.runner.model.apply(enable_observer)

# The initialized _epoch equals to 0 so _epoch + 1
# equal to the current epoch
if (self.disable_observer_begin > 0
and self._epoch + 1 >= self.disable_observer_begin):
self.runner.model.apply(disable_observer)
else:
self.runner.model.apply(enable_observer)

if (self.freeze_bn_begin > 0
and self._epoch + 1 >= self.freeze_bn_begin):
self.runner.model.apply(freeze_bn_stats)

def prepare_for_val(self):
"""Toggle the state of the observers and fake quantizers before
Expand All @@ -89,8 +102,6 @@ def run(self):
if (self.runner.val_loop is not None
and self._epoch >= self.val_begin
and self._epoch % self.val_interval == 0):
# observer disabled during evaluation

self.runner.val_loop.run()

self.runner.call_hook('after_train')
Expand All @@ -100,16 +111,6 @@ def run_epoch(self) -> None:
self.runner.call_hook('before_train_epoch')
self.runner.model.train()

# The initialized _epoch equals to 0 so _epoch + 1
# equal to the current epoch
if (self.disable_observer_begin > 0
and self._epoch + 1 >= self.disable_observer_begin):
self.runner.model.apply(disable_observer)

if (self.freeze_bn_begin > 0
and self._epoch + 1 >= self.freeze_bn_begin):
self.runner.model.apply(freeze_bn_stats)

for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

Expand Down Expand Up @@ -165,7 +166,11 @@ def __init__(
def prepare_for_run_epoch(self):
"""Toggle the state of the observers and fake quantizers before qat
training."""
pass
if (self.freeze_bn_begin > 0
and self._epoch + 1 >= self.freeze_bn_begin):
self.runner.model.apply(freeze_bn_stats)

self.runner.model.apply(enable_param_learning)

def prepare_for_val(self):
"""Toggle the state of the observers and fake quantizers before
Expand All @@ -177,20 +182,30 @@ def run_epoch(self) -> None:
self.runner.call_hook('before_train_epoch')
self.runner.model.train()

# TODO freeze bn
if self._epoch + 1 >= self.freeze_bn_begin:
self.runner.model.apply(freeze_bn_stats)

for idx, data_batch in enumerate(self.dataloader):
if self.is_first_batch:
# lsq init
self.is_first_batch = False
# lsq observer init
self.runner.model.apply(enable_static_estimate)
else:
self.runner.model.apply(enable_param_learning)

self.run_iter(idx, data_batch)

if self.is_first_batch:
# In the first batch, scale in LearnableFakeQuantize is
# calculated through lsq observer. As the values of `scale` of
# different observers in different rank are usually different,
# we have to sync the `scale` here.
all_reduce_params(self.runner.model.parameters(), op='mean')
# Change back to param learning mode
self.is_first_batch = False
self.runner.model.apply(enable_param_learning)

if idx > 100:
break

self.runner.model.sync_qparams(src_mode='loss')
# Make sure the registered buffer such as `observer_enabled` is
# correct in the saved checkpoint.
self.prepare_for_val()
self.runner.call_hook('after_train_epoch')
self._epoch += 1

Expand Down

0 comments on commit eaa84f4

Please sign in to comment.