This repository has been archived by the owner on May 1, 2023. It is now read-only.

Bugfix in EMA calculation in FakeLinearQuantization
guyjacob committed Dec 9, 2018
1 parent aa316b6 commit 51880a2
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions distiller/quantization/range_linear.py
@@ -359,7 +359,7 @@ def replace_fn(module, name, qbits_map):
 def update_ema(biased_ema, value, decay, step):
     biased_ema = biased_ema * decay + (1 - decay) * value
     unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
-    return unbiased_ema
+    return biased_ema, unbiased_ema


 def inputs_quantize_wrapped_forward(self, input):
@@ -394,8 +394,10 @@ def forward(self, input):
         with torch.no_grad():
             current_min, current_max = get_tensor_min_max(input)
         self.iter_count = self.iter_count + 1
-        self.tracked_min = update_ema(self.tracked_min_biased, current_min, self.ema_decay, self.iter_count)
-        self.tracked_max = update_ema(self.tracked_max_biased, current_max, self.ema_decay, self.iter_count)
+        self.tracked_min_biased, self.tracked_min = update_ema(self.tracked_min_biased,
+                                                               current_min, self.ema_decay, self.iter_count)
+        self.tracked_max_biased, self.tracked_max = update_ema(self.tracked_max_biased,
+                                                               current_max, self.ema_decay, self.iter_count)

         if self.mode == LinearQuantMode.SYMMETRIC:
             max_abs = max(abs(self.tracked_min), abs(self.tracked_max))
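For context: before this fix, update_ema returned only the bias-corrected value, so the biased accumulator passed in as tracked_min_biased / tracked_max_biased was never written back. The running average therefore never accumulated state across iterations. The following is an illustrative sketch only (not repository code; the loop and sample values are made up) of the corrected pattern, where the caller stores the returned biased state between steps:

# Illustrative sketch (not from the repository): the biased accumulator must be
# returned and stored by the caller, otherwise every step restarts the average
# from its initial value.
def update_ema(biased_ema, value, decay, step):
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)  # bias correction
    return biased_ema, unbiased_ema

biased = 0.0   # running (biased) state; the pre-fix code never updated this
decay = 0.999
for step, value in enumerate([1.0, 1.0, 1.0], start=1):
    biased, tracked = update_ema(biased, value, decay, step)
    print(step, biased, tracked)   # 'tracked' stays 1.0 thanks to bias correction

With the old single-return version, forward() could not persist the biased state, so each call effectively averaged against the module's initial buffer value. Returning both values lets forward() keep tracked_min_biased / tracked_max_biased up to date while using the bias-corrected tracked_min / tracked_max for the quantization range.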
