fix quant loop
HIT-cwh committed Apr 14, 2023
1 parent eaa84f4 commit 1c82c86
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mmrazor/engine/runner/quantization_loops.py
@@ -17,7 +17,7 @@
     enable_observer = get_placeholder('torch>=1.13')
     freeze_bn_stats = get_placeholder('torch>=1.13')
 
-from mmengine.dist import all_reduce_params
+from mmengine.dist import all_reduce_params, is_distributed
 from torch.utils.data import DataLoader
 
 from mmrazor.models import register_torch_fake_quants, register_torch_observers
@@ -162,6 +162,7 @@ def __init__(
             dynamic_intervals=dynamic_intervals)
 
         self.is_first_batch = True
+        self.distributed = is_distributed()
 
     def prepare_for_run_epoch(self):
         """Toggle the state of the observers and fake quantizers before qat
@@ -194,7 +195,10 @@ def run_epoch(self) -> None:
                 # calculated through lsq observer. As the values of `scale` of
                 # different observers in different rank are usually different,
                 # we have to sync the `scale` here.
-                all_reduce_params(self.runner.model.parameters(), op='mean')
+                if self.distributed:
+                    all_reduce_params(
+                        self.runner.model.parameters(), op='mean')
 
                 # Change back to param learning mode
                 self.is_first_batch = False
                 self.runner.model.apply(enable_param_learning)
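Context for the fix: mmengine's all_reduce_params wraps torch.distributed collectives, which require an initialized process group, so the previously unguarded call would fail in a single-process (non-distributed) run. Below is a minimal sketch of the guarded sync, not part of the commit; the helper name sync_lsq_scales is hypothetical, while all_reduce_params and is_distributed are the mmengine.dist functions imported in the diff above.

from mmengine.dist import all_reduce_params, is_distributed

def sync_lsq_scales(model):
    # Each rank's LSQ observer derives `scale` from its own first batch,
    # so the learnable fake-quantize params differ across ranks; averaging
    # the parameters once brings all ranks back into agreement.
    if is_distributed():  # skip when no process group exists (single GPU/CPU)
        all_reduce_params(model.parameters(), op='mean')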
