
Post Training Quantization General Questions #327

Closed

karanchahal opened this issue Jul 20, 2019 · 13 comments

Labels
quantization (The issue is related to quantization)

Comments

@karanchahal

Hello,

I'm trying to quantise my CNN classification model using simple post-training quantisation. The whitepaper by Raghuraman Krishnamoorthi, "Quantizing deep convolutional networks for efficient inference", suggests two ways to go about this:

  1. Weight-only quantisation: quantise the tensors either layer-wise or channel-wise; the paper recommends channel-wise quantisation (sketched just below the list).
  2. Quantising both weights and activations.
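
To be concrete, here is a rough sketch of what I understand channel-wise quantisation to mean (per_channel_quantize is just my own name for it; it uses the quantize_tensor helper shown further below):

# A rough sketch of channel-wise weight quantisation as I understand it:
# a Conv2d weight has shape [out_channels, in_channels, kH, kW], and each
# output-channel slice gets its own scale and zero-point.
def per_channel_quantize(weight, num_bits=8):
    q_channels = []
    for c in range(weight.shape[0]):
        q_channels.append(quantize_tensor(weight[c], num_bits=num_bits))
    return q_channels  # one QTensor (values, scale, zero_point) per output channel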

I tried to perform weight-only quantisation but am getting horrible results: my MNIST model degraded from 99% accuracy to 11% after post-training quantisation. The code for quantising the PyTorch tensors is given here:

from collections import namedtuple

import torch

QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])

def quantize_tensor(x, num_bits=8):
    qmin = 0.
    qmax = 2.**num_bits - 1.
    min_val, max_val = x.min(), x.max()

    # affine (asymmetric) quantisation parameters
    scale = (max_val - min_val) / (qmax - qmin)

    initial_zero_point = qmin - min_val / scale

    zero_point = 0
    if initial_zero_point < qmin:
        zero_point = qmin
    elif initial_zero_point > qmax:
        zero_point = qmax
    else:
        zero_point = initial_zero_point

    zero_point = int(zero_point)
    q_x = zero_point + x / scale
    q_x.clamp_(qmin, qmax).round_()
    q_x = q_x.byte()  # already rounded in place above
    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    return q_x.scale * (q_x.tensor.float() - q_x.zero_point)
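
For reference, a quick round-trip sanity check of these two helpers (toy tensor, just to show the intended usage):

# Quick round-trip check of the helpers above (toy tensor).
x = torch.randn(4, 4)
q = quantize_tensor(x, num_bits=8)
x_hat = dequantize_tensor(q)
print((x - x_hat).abs().max())  # small, on the order of the quantisation step `scale`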

And the function I run on the trained model to perform post-training quantisation is as follows:

def quantizeModel(model):
    """
    Post-training quantisation of the model weights.
    """
    for m in model.modules():

        if isinstance(m, nn.Conv2d):
            weight = m.weight.data
            q_weight = weight.clone()

            # Conv2d weights have shape [out_channels, in_channels, kH, kW]
            num_output_channels = weight.shape[0]
            num_input_channels = weight.shape[1]

            # per channel quantisation
            for i in range(num_output_channels):
                for j in range(num_input_channels):
                    quantized_weight = quantize_tensor(weight[i, j])
                    q_weight[i, j] = quantized_weight.tensor.float()
            m.weight.data = q_weight

        elif isinstance(m, nn.Linear):
            # for fully connected layers, layer-wise quantisation
            weight = m.weight.data
            quantized_weight = quantize_tensor(weight)
            m.weight.data = quantized_weight.tensor.float()
    return model

What am I doing wrong here? Any help would be deeply appreciated! I know this isn't distiller code, but I was trying to build a very minimal example of post-training quantisation and you folks are the experts. I am using a very simple model for MNIST classification. My model code is given here:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, mnist=True):
      
        super(Net, self).__init__()
        if mnist:
          num_channels = 1
        else:
          num_channels = 3
          
        self.conv1 = nn.Conv2d(num_channels, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    
    def quantize(self):
        """
        Post-training quantisation of the conv weights.
        """
        for m in self.modules():

            if isinstance(m, nn.Conv2d):
                weight = m.weight.data
                q_weight = weight.clone()

                # Conv2d weights have shape [out_channels, in_channels, kH, kW]
                num_output_channels = weight.shape[0]
                num_input_channels = weight.shape[1]

                # per channel quantisation
                for i in range(num_output_channels):
                    for j in range(num_input_channels):
                        quantized_weight = quantize_tensor(weight[i, j])
                        q_weight[i, j] = quantized_weight.tensor.float()

                m.weight.data = q_weight
        
      
      
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)
    

Best,
Karanbir Chahal

@levzlotnik
Contributor

levzlotnik commented Jul 20, 2019

Hi @karanchahal ,

From your code it seems you forgot to dequantize the weights before the forward pass; I would suggest adding the dequantization step to model.quantize().
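
Something along these lines (an untested sketch, reusing your quantize_tensor / dequantize_tensor helpers; the function name is just for illustration):

import torch.nn as nn

# Untested sketch: "simulated" weight quantization - quantize, then immediately
# dequantize, so the forward pass still runs entirely in float but with the
# weights snapped to the 8-bit grid.
def quantize_dequantize_weights(model, num_bits=8):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            q_w = quantize_tensor(m.weight.data, num_bits=num_bits)
            m.weight.data = dequantize_tensor(q_w)
    return model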

By the way - you said the accuracy is 11%, while a random classifier on MNIST would get about 10%. That is a good indication that your model's weights are off :)

Also, in case you'd like to try out distiller, quantizing your model is as simple as:

import distiller
from distiller.quantization import PostTrainLinearQuantizer

quantizer = PostTrainLinearQuantizer(model, bits_activations=None, bits_weights=8)
quantizer.prepare_model(torch.rand(*your_input_shape))

And your model is good to go.
However, that wouldn't quantize the ReLU ops, so if you want to quantize them I suggest you add separate nn.ReLU layers in your model's __init__.

Cheers,
Lev

@levzlotnik
Contributor

Also, check out our FAQ, docs and also #316 if you'd like to know more about how to prepare your model for distiller to quantize it.

@karanchahal
Author

karanchahal commented Jul 21, 2019

Thanks! I'll give the distiller code a go, although I really wanted to implement a small standalone prototype. This repo is really amazing :)

I have a question: why do we need to dequantize the weights? According to gemmlowp's article on quantisation first principles, we convert a tensor into a quantised tensor with scale and zero-point parameters.

Then we perform fixed-point matrix multiplication, using a few tricks to avoid the floating-point scale multiplication, according to Benoit Jacob's paper. For the final quantised result matrix we get the equation:

result_quantized_value = result_zero_point +
    (lhs_scale * rhs_scale / result_scale) *
        Sum_over_i(
            (lhs_quantized_value[i] - lhs_zero_point) *
            (rhs_quantized_value[i] - rhs_zero_point)
        )                                                  (5)

Now that we have this quantised result matrix, we feed it into the next layer.

I don't understand where de-quantisation fits into this picture. Can't we use the quantised weights and activations all along the forward pass of the net, as long as we account for the zero points and scales along the way?

I'm sorry for all these doubts!

@levzlotnik
Contributor

Hi @karanchahal ,

You need dequantization because all the weights are rescaled and shifted, and even though the linear/conv layers are able to "respect" that by multiplying by the scale and shifting back (also scaled), the nn.ReLU and nn.Softmax layers aren't able to respect that, since they are non-linear.
Moreover - since the softmax is what gives you the actual probabilities vector, it's the most sensitive to fluctuations and quantization, and that's why we usually don't quantize anything before the softmax.

I would suggest either dequantizing all the weights back - that way you get a simulation of quantization, and this is exactly our implementation - or making adjustments to the relu to match the quantization, although that has been shown to not work very well. The usual approach in hardware is to fuse the activations into the linear/conv layers, but for now I suggest trying to dequantize the weights back to their normal scale.

Cheers,
Lev

@guyjacob
Contributor

Just to add to @levzlotnik's answer:
Note that the accumulated sum in the equation you pasted from the gemmlowp docs is multiplied by a factor:
(lhs_scale * rhs_scale / result_scale)
This is sometimes called "re-quantization". When you derive the math for doing quantized convolution or fully-connected layer, you end up with this multiplicative factor (see also in our docs here). It's what ensures that the different scales of the input, weights and output match up.

If you look at it closely, you'll notice that multiplying by lhs_scale * rhs_scale is exactly de-quantization of the sum. Dividing by result_scale then quantizes the result back to the scale of the output (i.e. the input of the next layer).
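
To make that concrete, here's a toy numeric sketch (made-up scales and zero-points, not from any real layer):

import torch

# Toy example showing the (lhs_scale * rhs_scale / result_scale) factor at work.
lhs = torch.tensor([0.3, -0.7, 1.2])
rhs = torch.tensor([2.0, 0.5, -1.0])
float_result = torch.dot(lhs, rhs)                       # what we want to approximate

lhs_scale, lhs_zp = 0.01, 128                            # made-up quantization params
rhs_scale, rhs_zp = 0.02, 100
res_scale, res_zp = 0.05, 120

lhs_q = torch.round(lhs / lhs_scale + lhs_zp)            # quantize both operands
rhs_q = torch.round(rhs / rhs_scale + rhs_zp)

acc = torch.sum((lhs_q - lhs_zp) * (rhs_q - rhs_zp))     # integer accumulator
res_q = res_zp + (lhs_scale * rhs_scale / res_scale) * acc   # eq. (5) above
print(float_result, res_scale * (res_q - res_zp))        # ~the same, up to rounding noise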

In your code above you're doing weights-only quantization. So you have integer weights and float inputs/outputs. These don't "match". So you can't just multiply one by the other and expect it to work. One thing you can do, as @levzlotnik implied, is to implement dedicated layers that will take care to match the different data types you're using (note that this is not the same as the analysis in gemmlowp and in our docs, which assume both weights and activations are quantized).

The simplest thing, though, would be to do as @levzlotnik suggested: just quantize and de-quantize the weights ahead of time.

@karanchahal
Author

Thank you both for your well-thought-out replies!

Hmm, interesting - I think I kind of get it. The quantise/dequantise step models the quantisation error without going through the trouble of rewriting the layers to perform true quantised calculations.

But the basic gist is that once I have some quantised tensor, operations on it need to take into account the scale and zero point (as the result has to be equivalent to the floating-point calculation). Is that correct?

Although now I have a separate question: since we are modelling the floating-point calculations with our quantised counterparts, why do we get accuracy dips with quantisation?

I realise that the scale factors are floating-point values and are responsible for the big M multiplication. A paper says that this M factor is precomputed and cached early, and made fast by using bit shifts.

Is it possible that the quantisation error comes from that? Or is it something else entirely?

Thanks,
Karanbir Chahal

@guyjacob added the quantization label on Jul 22, 2019
@karanchahal
Author

I tried your suggested approach: after quantising the weights of the conv layers, I updated the forward pass of the model as follows:

def quantisedForward(model, x, q_params):
  x = F.relu(model.conv1(x)) 
  x = dequantize(x, q_params[0]) # dequantise using scale and zero point of weights of conv1

  x = F.max_pool2d(x, 2, 2)
  x = F.relu(model.conv2(x))
  x = dequantize(x, q_params[1]) # dequantise using scale and zero point of weights of conv2
  
  x = F.max_pool2d(x, 2, 2)

  x = x.view(-1, 4*4*50)

  x = F.relu(model.fc1(x))

  x = model.fc2(x)

  return F.log_softmax(x, dim=1)

However, doing this leads to an accuracy degradation of around 60%, hence I'm definitely doing something wrong.

I thought dequantizing would help. Do you think this fails because the weights and activations are of different types? I would be really grateful if you could tell me how I should change this code for it to work.

Thank you again !

Karanbir Chahal

@levzlotnik
Contributor

Hi @karanchahal ,

We actually meant dequantizing the weights right after quantizing them, so your model's forward pass wouldn't change, but it would still give the same result as a quantized model.
However, in this case - you should dequantize the tensor before the non-linearities (like relu), since non-linearities "mess up" the tensors in terms of quantization.

Either dequantize the weights right after quantizing them, or, with the functional approach:

x = model.conv1(x) # the weights are quantized with (w_scale, zero_point)
x = dequantize(x, q_params[0]) 
x = F.relu(x)
... # and so on...

@nzmora
Contributor

nzmora commented Jul 23, 2019

Hi @karanchahal ,

I'm joining this party ;-).
Re your question:

Although now I have a separate question: since we are modelling the floating-point calculations with our quantised counterparts, why do we get accuracy dips with quantisation?

Quantization introduces noise into the computation. Let's look at the simplest example possible. Assume we have two FP32 numbers which we want to sum: 1.01 + 14.9 = 15.91. Now let's do the same math in a quantized number-space using 16 representable values. This is equivalent to 4-bit integer quantization, but let's simplify and assume we also represent the quantized values in FP32, uniformly, using the range [1.0..16.0]. In other words, in this simplified example the quantization is just simple rounding ("Quantization is the process of constraining an input from a continuous or otherwise large set of values (such as the real numbers) to a discrete set (such as the integers).").

So in this example 1.01 is mapped to the quantized value 1.0 and 14.9 is mapped to 15.0, and when we perform the summation from before we get: quant(1.01) + quant(14.9) = 1.0 + 15.0 = 16.0 != 15.91. This is rounding noise: the noise introduced by the process of mapping values in one number-space (e.g. 1.01 in |R) to values in another number-space (e.g. 1.0 in [1.0, 2.0, ..., 16.0]). Do this many times with additions and multiplications, and you're asking for trouble.
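
In code, the same toy example (plain rounding standing in for the quantizer):

# The toy example above, with plain rounding playing the role of quantization.
x, y = 1.01, 14.9
print(x + y)                # 15.91 - the FP32 "reference" result
print(round(x) + round(y))  # 16    - the quantized-domain result: rounding noise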

To reduce the rounding noise we either add bits and/or reduce the range of representable values in the quantized number-space (i.e. reduce the dynamic-range). @guyjacob did an excellent job of covering this here.
When we reduce the dynamic-range of our quantized representation, we introduce noise from clipping outliers. Quantization algorithms try to balance these two noise sources.

Re this:

I realise that the scale factors are floating-point values and are responsible for the big M multiplication. A paper says that this M factor is precomputed and cached early, and made fast by using bit shifts.

Again, let's look at a simplified example. Assume you want to multiply two numbers (a weight and an activation) from two different FP32 number-spaces. Assume the weights' FP32 number-space has a range of [-1..1], the activations' FP32 number-space has a range of [-15..16], and assume 5-bit quantization for both weights and activations. Both weights and activations are mapped to the same 5-bit representation, but they represent different values in the original space because they come from two very different number-space ranges.
This is why @guyjacob said you need to include (lhs_scale * rhs_scale / result_scale) (result_scale maps to the output number-space). As @guyjacob noted: "It's what ensures that the different scales of the input, weights and output match up."
There are nuances and subtleties obviously, but this is the gist of it. I hope this didn't confuse you further, but it's important you understand where the scale and shift values come from, and how they are used.

Cheers,
Neta

@karanchahal
Author

Thank you @nzmora! That made a lot of sense and gives me a much better intuition of how these things work.

I'm looking forward to PyTorch implementing quantised ops so that we can see actual code speedups; I realise they're close to that.

I was able to model weight quantisation as recommended and see very minimal accuracy drops (even at 3 bits on my small model). Looking forward to trying out quantisation of both weights and activations now.

Thank you again guys ! You rock :D This thread will be really useful for newcomers to quantisation.

@guyjacob changed the title from "Quantisation Aware Training simple example" to "Post Training Quantization General Questions" on Jul 24, 2019
@guyjacob
Contributor

Happy to help :)
I changed the issue title - hope you don't mind.
Will close this now.

@brisker

brisker commented Jul 31, 2019

@guyjacob
How do I run the quantization-aware training (quant_aware_train) experiments with DoReFa-Net? Do we need to modify the compress_classifier.py file? I cannot see any code related to quantization in the training mode here: https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py#L353
but I can see the quantization code in the evaluation mode here: https://github.com/NervanaSystems/distiller/blob/master/examples/classifier_compression/compress_classifier.py#L641

@karanchahal
Author

Hello, @guyjacob

Can you reopen this issue?

I have a few doubts, this time about quantisation-aware training. I'm glad to say that I was able to implement post-training quantisation for linear and conv2d layers by looking through your code and your comments here. MNIST gets quantised fine at 8 bits (weights and activations both) at the same accuracy (99%) :D

A little surprisingly, upon using symmetric 8-bit quantisation the accuracy drops from 99% to 97% (I'm not doing any channel-wise quantisation, maybe because of that?).

Upon going lower, to say 6 bits or 4 bits, post-training quantisation doesn't seem to work, so I thought I'd use quantisation-aware training.

I'm using the approach from Benoit Jacob's paper, where I quantise and then dequantise the weights at each forward pass through a custom PyTorch function. To backpropagate I'm using a straight-through estimator.

This is my fake quant Op:
import torch

class FakeQuantOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, num_bits=8, min_val=None, max_val=None):
        x = quantize_tensor(x, num_bits=num_bits, min_val=min_val, max_val=max_val)  # quantize_tensor extended to accept an optional min_val/max_val range
        x = dequantize_tensor(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator
        return grad_output, None, None, None

Then I modified the forward pass of my model to apply these fake-quant ops to the weights of the conv and fc layers before they are used, as such:

def quantAwareTrainingForward(model, x, num_bits=8):

  model.conv1.weight.data = FakeQuantOp.apply(model.conv1.weight.data, num_bits)
  x = F.relu(model.conv1(x))

  x = F.max_pool2d(x, 2, 2)

  model.conv2.weight.data = FakeQuantOp.apply(model.conv2.weight.data, num_bits)
  x = F.relu(model.conv2(x))

  x = F.max_pool2d(x, 2, 2)

  x = x.view(-1, 4*4*50)

  model.fc1.weight.data = FakeQuantOp.apply(model.fc1.weight.data, num_bits)
  x = F.relu(model.fc1(x))

  x = model.fc2(x)

  return F.log_softmax(x, dim=1)

I am not doing any fake quantisation of activations yet, as that comes later during training after collecting the EMA statistics, if I understand correctly.
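
(For context, this is roughly the kind of EMA tracking I have in mind for the activations later on - just a sketch, not something I'm using yet:)

# Rough sketch (not used yet): exponential moving average of the observed
# activation range, to be used later for fake-quantizing activations.
class EMAMinMax:
    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.min_val = None
        self.max_val = None

    def update(self, x):
        x_min, x_max = x.min().item(), x.max().item()
        if self.min_val is None:
            self.min_val, self.max_val = x_min, x_max
        else:
            m = self.momentum
            self.min_val = m * self.min_val + (1 - m) * x_min
            self.max_val = m * self.max_val + (1 - m) * x_max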

The problem with this when I try to train is that the loss seems to go upwards during training: the model trains for a short time (it gets up to 90% accuracy) but then slowly begins to diverge after each epoch, until the gradients turn into NaNs.

Could you give me any insight into why this is happening? Maybe it's my forward pass or my fake-quant function?

I would be very grateful for your advice.

Best,
Karanbir Singh Chahal
