Post Training Quantization General Questions #327
Comments
Hi @karanchahal,
From your code it seems you forgot to dequantize the weights before the forward pass; I would suggest dequantizing them right after quantization.
By the way - as you said, the accuracy is 11%, while a random classifier on MNIST would get just about 10%. That is a good indication that your model's weights are off :)
Also, in case you'd like to try out distiller, quantizing your model is as simple as:
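Roughly along these lines (a sketch only; the exact call may differ between distiller versions):

```python
from distiller.quantization import PostTrainLinearQuantizer

# Wrap the trained model and replace the relevant modules with
# post-training-quantized versions.
quantizer = PostTrainLinearQuantizer(model)
quantizer.prepare_model()
```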
And your model is good to go.
Cheers,
Also, check out our FAQ, docs, and #316 if you'd like to know more about how to prepare your model for distiller to quantize it.
Thanks! I'll give the distiller code a go, although I really wanted to implement a small standalone prototype. This repo is really amazing :)
I have a question: why do we need to dequantize the weights? According to gemmlowp's article on quantisation first principles, we convert a tensor into a quantised tensor with the scale and zero-point parameters. Then we perform fixed-point matrix multiplication through a few tricks that avoid the floating-point scale multiplication, according to Benoit's paper. When we get the final resultant quantised matrix, we get a final equation:
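For reference, the standard gemmlowp result being referred to is, roughly:

```
q3[i,k] = Z3 + (S1 * S2 / S3) * sum_j (q1[i,j] - Z1) * (q2[j,k] - Z2)
```

where each real tensor r is represented as r = S * (q - Z), with S a float scale and Z an integer zero point.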
Now that we have this quantised result matrix, we feed it into the next layer. I don't understand where de-quantisation fits into this picture. Can't we use the quantised weights and activations all along the forward pass of the net, as long as we account for the zero point and scale in the calculations along the way? I'm sorry for all these doubts!
Hi @karanchahal,
You need dequantization because all the weights are rescaled and shifted, and even though the linear/conv layers are able to "respect" that by multiplying by the scale and shifting back (also scaled), the rest of the network still expects inputs and outputs in the original floating-point range.
I would suggest either dequantizing all the weights back - that way you get a simulation of quantization, and this is exactly our implementation - or making adjustments in the layers themselves so they operate correctly on quantized values.
Cheers,
Just to add to @levzlotnik's answer:
If you look at it closely, you'll notice that quantizing and then de-quantizing the weights is not a no-op: the de-quantized weights carry the quantization rounding error, so the model behaves as if it were quantized while everything stays in floating point.
In your code above you're doing weights-only quantization. So you have integer weights and float inputs/outputs. These don't "match". So you can't just multiply one by the other and expect it to work.
One thing you can do, as @levzlotnik implied, is to implement dedicated layers that will take care to match the different data types you're using (note that this is not the same as the analysis in gemmlowp and in our docs, which assume both weights and activations are quantized). The simplest thing, though, will be to do as @levzlotnik suggested: just quantize and de-quantize the weights ahead of time.
Thank you both for your well thought out replies!
Hmm, interesting, I think I kind of get it. The quantisation/de-quantisation step models the quantisation error without going through the trouble of rewriting layers that perform valid quantized calculations. But the basic gist is that once I have some quantised tensor, operations with it need to take the scale and zero point into account (as they have to be equivalent to the floating-point calculations). Is that correct?
Although now I have a separate question: since we are modelling the floating-point calculations with our quantised counterparts, why do we get accuracy dips with quantisation? I realise that the scale factors are floating-point data types and responsible for the big M multiplication. A paper says that this M factor is precomputed and cached early, and made fast by using bit shifts. Is it possible that the quantisation error comes from that? Or is it something else entirely?
Thanks,
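That M trick, roughly, folds the float multiplier M = S1*S2/S3 into an int32 fixed-point multiplier plus a right shift, so the rescale needs no float math at inference time. A rough sketch (an illustration only; it glosses over gemmlowp's exact rounding and saturation details):

```python
import math

def quantize_multiplier(real_m):
    # Represent M = S1*S2/S3 (with 0 < M < 1) as m0 * 2**exp,
    # where m0 in [0.5, 1) is stored as a 31-bit fixed-point integer.
    m0, exp = math.frexp(real_m)          # real_m == m0 * 2**exp
    q_m0 = int(round(m0 * (1 << 31)))     # fixed-point representation of m0
    return q_m0, -exp                     # right-shift amount to apply afterwards

def rescale(acc, q_m0, right_shift):
    # Integer-only rescale of an accumulator: multiply by the fixed-point m0,
    # then shift right to apply the power of two.
    return ((acc * q_m0) >> 31) >> right_shift

# Example: M ~ 0.0072 becomes one integer multiply and two shifts.
q_m0, shift = quantize_multiplier(0.0072)
print(rescale(100000, q_m0, shift), "vs", 100000 * 0.0072)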
I tried your suggested approach. So when I quantised the weights for the conv layers, I updated the forward pass for the model
However, doing this leads to an accuracy degradation of around 60%, so I'm definitely doing something wrong. I thought dequantizing would help. Do you think this fails because the weights and activations are of different types? I would be really grateful if you could tell me how I should change this code for it to work.
Thank you again!
Karanbir Chahal
Hi @karanchahal,
We actually meant dequantization of the weights right after quantizing them, so your model's forward pass wouldn't change, but it would still give the same result as a quantized model. Either dequantize the weights right after quantizing them, or use the functional approach:
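A sketch of that functional approach, using the quantize_tensor / dequantize_tensor helpers from the original post and the layer names of the MNIST model used in this thread (an illustration, not the exact snippet from this comment):

```python
import torch.nn.functional as F

def simulated_quant_forward(model, x, num_bits=8):
    # Quantize and immediately dequantize each weight tensor, then run the
    # usual float ops with the "noisy" weights.
    # quantize_tensor / dequantize_tensor: the helpers from the original post.
    w = dequantize_tensor(quantize_tensor(model.conv1.weight, num_bits=num_bits))
    x = F.max_pool2d(F.relu(F.conv2d(x, w, model.conv1.bias)), 2, 2)
    w = dequantize_tensor(quantize_tensor(model.conv2.weight, num_bits=num_bits))
    x = F.max_pool2d(F.relu(F.conv2d(x, w, model.conv2.bias)), 2, 2)
    x = x.view(-1, 4 * 4 * 50)
    w = dequantize_tensor(quantize_tensor(model.fc1.weight, num_bits=num_bits))
    x = F.relu(F.linear(x, w, model.fc1.bias))
    return F.log_softmax(model.fc2(x), dim=1)
```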
Hi @karanchahal, I'm joining this party ;-).
Quantization introduces noise into the computation. Let's look at the simplest example possible: assume we have two FP32 numbers which we want to sum after quantizing them. Each number gets rounded to the nearest value representable in the quantized number-space, so the sum carries both rounding errors. To reduce the rounding noise we either add bits and/or reduce the range of representable values in the quantized number-space (i.e. reduce the dynamic range). @guyjacob did an excellent job of covering this here.
Re your question about where the accuracy drop comes from:
Again, let's look at a simplified example. Assume you want to multiply two numbers (a weight and an activation) from two different FP32 number-spaces. Assume the weights' FP32 number-space has a range of [-1..1], the activations' FP32 number-space has a range of [-15..16], and assume 5-bit quantization for both weights and activations. Both weights and activations are mapped to the same 5-bit representation, but they represent different values in the original space because they come from two very different number-space ranges.
Cheers,
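A small numeric illustration of those two examples (an assumption-laden sketch using simple asymmetric quantization, not distiller code):

```python
import numpy as np

def quantize(x, x_min, x_max, num_bits=5):
    # Affine (asymmetric) quantization of a float value onto [0, 2**num_bits - 1].
    qmax = 2 ** num_bits - 1
    scale = (x_max - x_min) / qmax
    zero_point = round(-x_min / scale)
    return np.clip(np.round(x / scale) + zero_point, 0, qmax), scale, zero_point

def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)

# Weights in [-1, 1] and activations in [-15, 16] share the same 5-bit grid,
# but their step sizes (and hence rounding noise) differ by a factor of ~15.
w, a = 0.3, 7.9
qw, sw, zw = quantize(w, -1.0, 1.0)     # step = 2/31  ~ 0.065
qa, sa, za = quantize(a, -15.0, 16.0)   # step = 31/31 = 1.0
print(dequantize(qw, sw, zw) * dequantize(qa, sa, za), "vs", w * a)  # ~2.58 vs 2.37
```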
Thank you @nzmora! That made a lot of sense and gives me a much better intuition of how these things work. I'm looking forward to PyTorch implementing quantised ops so that we can see actual code speedups; I realise they're close to that. I was able to model weight quantisation as recommended and see very minimal accuracy drops (even at 3 bits on my small model). Looking forward to trying out quantisation of both weights and activations now. Thank you again guys! You rock :D This thread will be really useful for newcomers to quantisation.
Happy to help :)
@guyjacob
Hello @guyjacob, can you reopen this issue? I have a few doubts, this time about quantisation-aware training.

I'm glad to say that I was able to implement post-training quantisation for linear and conv2d layers by looking through your code and your comments here. MNIST gets quantised fine at 8 bits (weights and activations both) at the same accuracy (99%) :D Somewhat surprisingly, upon using symmetric 8-bit quantisation the accuracy drops to 97% from 99% (I'm not doing any channel-wise quantisation, maybe because of that?). Upon going lower, to say 6 bits or 4 bits, post-training quantisation doesn't seem to work, so I thought I'd use quant-aware training.

I'm using the Jacob Benoit paper approach, where I quantise and then dequantise the weights at each forward prop through a custom-made pytorch function. To backpropagate I'm using a straight-through estimator. This is my fake quant op:

```python
class FakeQuantOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, num_bits=8, min_val=None, max_val=None):
        x = quantize_tensor(x, num_bits=num_bits, min_val=min_val, max_val=max_val)
        x = dequantize_tensor(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator
        return grad_output, None, None, None
```

Then I modified the forward prop of my model to apply these fake quant ops to the weights of the conv and fc layers before they enter their forward pass, as such:

```python
def quantAwareTrainingForward(model, x, num_bits=8):
    model.conv1.weight.data = FakeQuantOp.apply(model.conv1.weight.data, num_bits)
    x = F.relu(model.conv1(x))
    x = F.max_pool2d(x, 2, 2)

    model.conv2.weight.data = FakeQuantOp.apply(model.conv2.weight.data, num_bits)
    x = F.relu(model.conv2(x))
    x = F.max_pool2d(x, 2, 2)

    x = x.view(-1, 4*4*50)

    model.fc1.weight.data = FakeQuantOp.apply(model.fc1.weight.data, num_bits)
    x = F.relu(model.fc1(x))
    x = model.fc2(x)
    return F.log_softmax(x, dim=1)
```

I am not doing any fake quantisation of activations, as that comes later during training after collecting the EMA statistics, if I understand correctly.

The problem is that when I try to train, the loss seems to go upwards: the model trains for a short time (it gets up to 90% accuracy) but then slowly begins to diverge after each epoch, until the gradients turn into NaNs. Could you give me any insight into why this is happening? Maybe it's my forward prop or my fake quant function? I would be very grateful for your advice.
Best,
Hello,
I'm trying to quantise my CNN classification model using simple post-training quantisation. The paper by Krishnamoorthi on quantizing CNNs for efficient inference suggests that we can try two ways to go about this.
I tried to perform weight-only quantisation but am getting horrible results: my 99% MNIST model degraded to an accuracy of 11% after post-training quantisation. The code for quantising the pytorch tensors is given here:
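A sketch of what such min/max quantize / dequantize helpers could look like (the names quantize_tensor / dequantize_tensor come from how they are used later in the thread; the QTensor container and the details are assumptions, not the original code):

```python
import torch
from collections import namedtuple

QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])

def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    # Asymmetric (affine) quantization of a float tensor onto [0, 2**num_bits - 1].
    if min_val is None:
        min_val = x.min().item()
    if max_val is None:
        max_val = x.max().item()
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(qmin - min_val / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    q = (x / scale + zero_point).round().clamp(qmin, qmax)
    return QTensor(tensor=q, scale=scale, zero_point=zero_point)

def dequantize_tensor(q):
    # Map the quantized integer codes back to floats.
    return q.scale * (q.tensor - q.zero_point)
```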
And the function I run on the 99%-trained model to perform post-training quantisation is as follows:
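A sketch of the kind of weight-only pass being described (again an assumption, not the original code; per the replies in this thread, the missing dequantize step after quantizing is exactly what destroys the accuracy):

```python
import torch.nn as nn

def post_train_quantize_weights(model, num_bits=8):
    # Replace each conv/linear weight with its quantized integer codes.
    # Note: there is no dequantize step here, so the float activations end up
    # multiplied by raw integer codes - the mismatch discussed in the replies.
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            q = quantize_tensor(module.weight.data, num_bits=num_bits)
            module.weight.data = q.tensor
```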
What am I doing wrong here? Any help would be deeply appreciated! I know this isn't distiller code, but I was trying to build a very minimal example of post-training quantisation and you folks are the experts. I am using a very simple model for MNIST classification. My model code is given here:
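The model appears to be the standard PyTorch MNIST example network (inferred from the conv1/conv2/fc1/fc2 names and the 4*4*50 view used later in the thread); a sketch:

```python
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
```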
Best,
Karanbir Chahal