Design doc of fixed-point quantization. #10553

Fixed-point quantization uses lower-bit representations, for example 2-bit, 3-bit or 8-bit fixed point, for weights and activations, which are usually stored in 32-bit single-precision floating point. The fixed-point representation reduces memory bandwidth, power consumption, computational resources and model storage requirements. This is especially important for inference in embedded-device deployment.

According to some experiments, directly quantizing a model trained in floating point works sufficiently well for large models, like the over-parameterized VGG model, but the accuracy drops a lot for small models. In order to improve the trade-off between accuracy and latency, many quantized training approaches have been proposed.

This document presents the design of a quantized training framework on Fluid. The first part introduces how to quantize, the second part describes the quantized training framework, and the last part describes how to calculate the quantization range.

### How to quantize

There are many ways to quantize a float value to a fixed-point value. For example:

$$ r = min(max(x, a), b) $$
$$ s = \frac{b - a}{n - 1} $$
$$ q = \left \lfloor \frac{r - a}{s} \right \rceil $$

where $x$ is the float value to be quantized, $[a, b]$ is the quantization range, $a$ is the minimum value and $b$ is the maximum value. $\left \lfloor \right \rceil$ denotes rounding to the nearest integer. If the quantization level is $k$, $n$ is $2^k$; for example, when $k$ is 8, $n$ is 256. $q$ is the quantized integer.

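As an illustration only (this is not part of the framework code, and the helper name and `num_bits` argument are ours), a minimal NumPy sketch of this *min-max* scheme could look like:

```python
import numpy as np

def quantize_min_max(x, a, b, num_bits=8):
    """Quantize x to integers in [0, n - 1] over the range [a, b]."""
    n = 2 ** num_bits
    r = np.clip(x, a, b)         # r = min(max(x, a), b)
    s = (b - a) / (n - 1)        # quantization step
    q = np.round((r - a) / s)    # round to the nearest integer
    return q.astype(np.int32)
```
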
The quantization we applied is parameterized by the number of quantization levels and the maximum absolute value:

$$ M = max(abs(x)) $$
$$ q = \left \lfloor \frac{x}{M} * (n - 1) \right \rceil $$

where $x$ is the float value to be quantized and $M$ is the maximum absolute value. $\left \lfloor \right \rceil$ denotes rounding to the nearest integer. For 8-bit quantization, $n=2^{8}=256$. $q$ is the quantized integer.

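A corresponding sketch of the *max-abs* scheme (again illustrative only; the helper names are ours) could be:

```python
import numpy as np

def quantize_max_abs(x, num_bits=8):
    """q = round(x / M * (n - 1)), with M = max(|x|)."""
    n = 2 ** num_bits
    m = np.max(np.abs(x))
    q = np.round(x / m * (n - 1))
    return q.astype(np.int32), m

def dequantize_max_abs(q, m, num_bits=8):
    """Map the integer q back to float: x ~= q / (n - 1) * m."""
    n = 2 ** num_bits
    return q.astype(np.float32) / (n - 1) * m
```
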
Whether we use the *min-max* quantization or the *max-abs* quantization, both can be represented as:

$q = scale * r + b$

We call the *min-max* range or the *max-abs* value the quantization arguments, also known as the quantization scale or quantization range.

How to calculate the quantization range (or maximum absolute value) for inference will be described in the last part.

### Training Framework

#### Forward pass

The forward pass is simulated quantization; see Figure 1.

The training framework is shown in the following figure.

<p align="center">
<img src="quantization_forward.png" width="300" height="340" /><br/>
Fig 1. Forward in training with simulated quantization.
</p>

- At first, both the input and the weight are quantized to 8 bits.
- Then, do the multiplication (or convolution) operation with the integers.
- Then, dequantize the multiplication (or convolution) results to 32-bit floating point.
- At last, do the bias addition in 32-bit floating point. Here, the bias is not quantized.

For general matrix-matrix multiplication (GEMM), quantize $X$ and $W$ with their maximum absolute values $X_m$ and $W_m$:

$$ X_q = \left \lfloor \frac{X}{X_m} * (n - 1) \right \rceil $$
$$ W_q = \left \lfloor \frac{W}{W_m} * (n - 1) \right \rceil $$

Do the GEMM:

$$ Y = X_q * W_q $$

Dequantize $Y$:

$$
\begin{align}
Y_{dq} &=\frac{Y}{(n - 1) * (n - 1)} * X_m * W_m \\\
&=\frac{X_q * W_q}{(n - 1) * (n - 1)} * X_m * W_m \\\
&=(\frac{X_q}{n - 1} * X_m) * (\frac{W_q}{n - 1} * W_m)
\end{align}
$$

From these formulas, the dequantization can also be moved before the GEMM: first dequantize $X_q$ and $W_q$, then do the GEMM in float. The forward workflow in training is equivalent to the following framework.

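This equivalence can be checked numerically. The following NumPy snippet is illustrative only (the variable names follow the formulas above); it compares dequantizing after the GEMM with dequantizing before it:

```python
import numpy as np

n = 2 ** 8
X = np.random.randn(4, 16).astype(np.float32)
W = np.random.randn(16, 8).astype(np.float32)

# Max-abs quantization of the input and the weight.
X_m, W_m = np.max(np.abs(X)), np.max(np.abs(W))
X_q = np.round(X / X_m * (n - 1))
W_q = np.round(W / W_m * (n - 1))

# Path 1: GEMM in the quantized domain, then dequantize the result.
Y_dq = (X_q @ W_q) / ((n - 1) * (n - 1)) * X_m * W_m

# Path 2: dequantize X_q and W_q first, then do the GEMM in float.
Y_dq2 = (X_q / (n - 1) * X_m) @ (W_q / (n - 1) * W_m)

assert np.allclose(Y_dq, Y_dq2)
```
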
<p align="center">
<img src="quantization_forward.png" width="300" height="330" /><br/>
Fig 2. Equivalent forward in training with simulated quantization.
</p>

We use this equivalent workflow in the training. In our design, a quantization transpiler inserts the quantization operators and the dequantization operators into the Fluid `ProgramDesc`.

#### Backward pass

See Figure 3. The gradients are calculated with the dequantized weights and activations. All inputs and outputs are 32-bit float. In the weight-updating process, the gradients are added to the original float weights, not to the quantized or dequantized weights.

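For illustration only, consider a hypothetical fully connected layer $Y = X W_{dq}$. A minimal NumPy sketch of its backward pass and weight update (our own naming, not the actual transpiler output) could be:

```python
import numpy as np

def backward_and_update(W, X, dY, lr=0.01, num_bits=8):
    """Backward pass of Y = X @ W_dq and the SGD weight update."""
    n = 2 ** num_bits
    W_m = np.max(np.abs(W))
    W_dq = np.round(W / W_m * (n - 1)) / (n - 1) * W_m   # dequantized weight
    dX = dY @ W_dq.T     # the input gradient uses the dequantized weight
    dW = X.T @ dY        # weight gradient, passed straight through the quantization
    W_new = W - lr * dW  # the update is applied to the original float weight, not W_dq
    return dX, W_new
```
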
<p align="center">
<img src="quantization_backward_and_optimization.png" /><br/>
Fig 3. Backward and weight updating in training with simulated quantization.
</p>

So the quantization transpiler will also change some inputs of the corresponding backward operators, since the backward pass needs to use the dequantized weights and activations (see Figure 3).

### How to calculate the quantization scale

There are two strategies to calculate the quantization scale; we call them the dynamic and static strategies. The dynamic strategy recalculates the quantization scale in each iteration. The static strategy fixes the quantization scale across different inputs.

For weights, we apply the dynamic strategy in training; that is to say, the quantization scale is recalculated in each iteration until the training is finished.

For activations, the quantization scales are estimated during training and then used in inference. There are several different ways to estimate them:

1. Calculate the mean of the maximum absolute value during a window.
2. Calculate the max of the maximum absolute value during a window.
3. Calculate the running mean of the maximum absolute value during a window, as follows:

$$ V_t = (1 - k) * V + k * V_{t-1} $$

where $V$ is the maximum absolute value of the current batch, $V_t$ is the running mean value, and $k$ is a factor, such as 0.9.

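For example, the running-mean strategy could be maintained per activation tensor with a small helper like the one below (a sketch under our own naming, not the Fluid operator):

```python
import numpy as np

def update_running_max_abs(running_v, batch_x, k=0.9):
    """V_t = (1 - k) * V + k * V_{t-1}, where V = max(|batch_x|)."""
    v = float(np.max(np.abs(batch_x)))
    if running_v is None:   # the first batch initializes the estimate
        return v
    return (1 - k) * v + k * running_v
```
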