If not `use_min`, the optimization goal is to choose the scale $d$ that minimizes the following function:
$$
\text{Error}(d) = \sum_{j=0}^{N-1} \left( x_j - d \cdot q_j \right)^2
$$

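Given the original values $x_j$ and the quantized levels $q_j$, setting the derivative of $\text{Error}$ with respect to $d$ to zero gives
$$
d = \frac{\sum_{j=0}^{N-1} x_j \cdot q_j}{\sum_{j=0}^{N-1} q_j^2}
$$
which minimizes $\text{Error}$.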


If `use_min`, the objective is:
$$
E = \sum_{i=0}^{N-1} \left( \text{scale} \cdot l_i + \text{min} - x_i \right)^2
$$
Given the original values $x_i$ and the quantized levels $l_i$ (`q` in the code), set the partial derivatives of $E$ with respect to $\text{scale}$ and $\text{min}$ to zero. This is an ordinary least-squares line fit, and it gives:

$$
\text{scale} = \frac{N\sum_{i=0}^{N-1} l_ix_i - \sum_{i=0}^{N-1}x_i\sum_{i=0}^{N-1}l_i }{N\sum_{i=0}^{N-1}l_i^2 - (\sum_{i=0}^{N-1}l_i)^2}
$$

$$
\text{min} = \frac{\sum_{i=0}^{N-1} x_i \sum_{i=0}^{N-1} l_i^2- \sum_{i=0}^{N-1}x_il_i\sum_{i=0}^{N-1}l_i }{N\sum_{i=0}^{N-1}l_i^2 - (\sum_{i=0}^{N-1}l_i)^2}
$$

These values minimize $E$.

In [5]:
import numpy as np

In [6]:
class quantizer:
    """Round-to-nearest block quantizer.

    * ``use_min=False``: symmetric mode, ``x ≈ scale * q`` with signed integer
      levels ``q`` in ``[q_min, q_max]``.
    * ``use_min=True``: asymmetric mode, ``x ≈ scale * q + min`` with unsigned
      integer levels ``q`` in ``[0, q_range]``.
    """

    def __init__(self, n_bits: int = 8, use_min: bool = False) -> None:
        """
        Args:
            n_bits (int): number of bits per quantized value
            use_min (bool): use asymmetric quantization with a per-block
                minimum instead of symmetric quantization
        """
        self.n_bits = n_bits
        self.q_range = 2**n_bits - 1        # largest unsigned level (use_min mode)
        self.q_max = 2**(n_bits-1) - 1      # largest signed level
        self.q_min = -2**(n_bits-1)         # smallest signed level
        self.use_min = use_min

    def quantize(self, x: np.ndarray) -> tuple[np.ndarray, float, float]:
        """Quantize ``x`` with round-to-nearest.

        Returns:
            (q, scale, min): integer-valued float levels, the scale, and the
            block minimum (``None`` when ``use_min`` is False).

        Raises:
            ValueError: if ``x`` is empty (from numpy's reductions).
        """
        x = np.asarray(x)
        if not self.use_min:
            # Divide the *signed* element of largest magnitude by the negative
            # q_min, so that element lands exactly on q_min whatever its sign.
            # Dividing |x|max by q_min instead always gave a negative scale,
            # which sent a negative extreme to -q_min (outside [q_min, q_max])
            # where it was clipped.
            x_ext = x.flat[np.argmax(np.abs(x))]
            scale = x_ext / self.q_min
            min_val = None
            if scale == 0:
                # all-zero block: any scale reproduces it; avoid 0/0 -> NaN
                q = np.zeros(x.shape)
            else:
                q = np.clip(np.round(x / scale), self.q_min, self.q_max)
        else:
            min_val = np.min(x)
            scale = (np.max(x) - min_val) / self.q_range
            if scale == 0:
                # constant block: every value equals min; avoid 0/0 -> NaN
                q = np.zeros(x.shape)
            else:
                q = np.clip(np.round((x - min_val) / scale), 0, self.q_range)
        return q, scale, min_val

    def dequantize(self, q: np.ndarray, scale: float, min: float) -> np.ndarray:
        """Map levels back to values: ``q * scale`` (+ ``min`` when ``use_min``)."""
        q = q.astype(np.float32)
        if not self.use_min:
            dq = q * scale # dequantize q
        else:
            dq = q * scale + min
        return dq

    def calc_mse(self, x: np.ndarray, dq: np.ndarray) -> float:
        """Mean squared error between ``x`` and its reconstruction ``dq``."""
        return np.mean((x - dq)**2)

    # recalculation of scale and min according to the formula above
    def recalc_scale_and_min(self, x: np.ndarray, q: np.ndarray) -> tuple[float, float]:
        """Least-squares refit of scale (and min) for fixed levels ``q``.

        Uses the closed-form solutions of the normal equations. Degenerate
        level patterns, where those equations are singular, return an optimal
        solution instead of dividing by zero:

        * symmetric mode with all levels 0 returns ``(0.0, None)``;
        * ``use_min`` mode with a single distinct level returns
          ``(0.0, mean(x))``.

        Returns:
            (scale, min); ``min`` is ``None`` when ``use_min`` is False.
        """
        assert x.shape == q.shape
        if not self.use_min:
            denom = (q**2).sum()
            if denom == 0:
                # every level is 0, so the error does not depend on scale
                return 0.0, None
            return (x * q).sum() / denom, None
        # number of elements in the block
        N = q.size
        if N == 0 or np.all(q == q.flat[0]):
            # One distinct level: scale is undetermined. Put everything into
            # min, whose optimum is then the mean of x.
            return 0.0, (float(x.mean()) if N else 0.0)
        D = N * (q**2).sum() - q.sum()**2
        scale = (N * (x * q).sum() - x.sum() * q.sum()) / D
        min_val = (x.sum() * (q**2).sum() - q.sum() * (x * q).sum()) / D
        return scale, min_val



In [78]:
# Eight random test values drawn from a normal distribution with std 10.
x = np.random.randn(8) * 10

In [102]:
# Quantize x to 4 bits, then compare two reconstructions from the same levels q:
#   * dq0 uses the scale/min derived directly from the value range;
#   * dq1 uses scale/min refit by the closed-form least-squares formulas above.
q1 = quantizer(n_bits=4, use_min=True)
# q1 = quantizer(n_bits=4, use_min=False)
q, scale0, min0 = q1.quantize(x)
print("x:", x)
print("q:", q)
print("scale0:", scale0, "min0:", min0)
# reconstruction with the range-derived scale/min
dq0 = q1.dequantize(q, scale0, min0)
print ("dq0:", dq0)
mse0 = q1.calc_mse(x, dq0)
print("mse0:", mse0)
# refit scale/min for the fixed levels q and reconstruct again
scale1, min1 = q1.recalc_scale_and_min(x, q)
print("scale1:", scale1, "min1:", min1)
dq1 = q1.dequantize(q, scale1, min1)
print ("dq1:", dq1)
mse1 = q1.calc_mse(x, dq1)
print("mse1:", mse1)

x: [ 5.63109251 -0.66333746 -3.7742885  -5.53918171 -8.02601077  3.97313181
  0.47977315  3.40512398]
q: [15.  8.  5.  3.  0. 13.  9. 13.]
scale0: 0.910473552006432 min0: -8.026010772858172
dq0: [ 5.631092   -0.7422223  -3.4736428  -5.29459    -8.0260105   3.8101454
  0.16825104  3.8101454 ]
mse0: 0.05551108470975463
scale1: 0.9171847297558767 min1: -8.130986144472903
dq1: [ 5.6267843 -0.7935085 -3.5450625 -5.379432  -8.130986   3.7924147
  0.1236763  3.7924147]
mse1: 0.05193813484932613


Weighted quantization: each element $j$ gets a weight $w_j$, and the optimization function becomes:
$$
\text{Error}(d) = \sum_{j=0}^{N-1} w_j \left( x_j - d \cdot q_j \right)^2
$$

which is minimized by
$$
d = \frac{\sum_{j=0}^{N-1} w_j \cdot x_j \cdot q_j}{\sum_{j=0}^{N-1} w_j \cdot q_j^2}
$$
$$

In the `use_min` case, it becomes:
$$
E = \sum_{i=0}^{N-1} w_i \left( \text{scale} \cdot l_i + \text{min} - x_i \right)^2
$$

which is minimized by

$$
\text{scale} = \frac{\sum_{i=0}^{N-1} w_il_ix_i \sum_{i=0}^{N-1} w_i - \sum_{i=0}^{N-1}w_ix_i\sum_{i=0}^{N-1}w_il_i }{\sum_{i=0}^{N-1} w_i \sum_{i=0}^{N-1}w_il_i^2 - (\sum_{i=0}^{N-1}w_il_i)^2}
$$

$$
\text{min} = \frac{\sum_{i=0}^{N-1} w_ix_i \sum_{i=0}^{N-1} w_il_i^2- \sum_{i=0}^{N-1}w_ix_il_i\sum_{i=0}^{N-1}w_il_i }{\sum_{i=0}^{N-1}w_i \sum_{i=0}^{N-1}w_il_i^2 - (\sum_{i=0}^{N-1}w_il_i)^2}
$$

In [117]:
class quantizer_weighted:
    """
    Quantizer whose error (mse) is weighted per element. It searches over
    offsets to the quantization range for the scale and min that minimize the
    weighted error.

    * ``use_min=False``: symmetric mode, ``x ≈ scale * q`` with signed levels
      in ``[q_min, q_max]``.
    * ``use_min=True``: asymmetric mode, ``x ≈ scale * q + min`` with unsigned
      levels in ``[0, q_range]``.
    """

    def __init__(self,  n_bits: int = 8, use_min: bool = False) -> None:
        """
        Args:
            n_bits (int): number of bits per quantized value
            use_min (bool): use asymmetric quantization with a per-block
                minimum instead of symmetric quantization
        """
        self.n_bits = n_bits
        self.q_range = 2**n_bits - 1        # largest unsigned level (use_min mode)
        self.q_max = 2**(n_bits-1) - 1      # largest signed level
        self.q_min = -2**(n_bits-1)         # smallest signed level
        self.use_min = use_min

    def quantize(self, x: np.ndarray, offset: float = 0.0) -> tuple[np.ndarray, float, float]:
        """Quantize ``x``, adding ``offset`` to the range that sets the scale.

        A positive ``offset`` widens the range (coarser scale) and a negative
        one narrows it. An offset can make the range zero or negative.

        Returns:
            (q, scale, min): integer-valued float levels, the scale, and the
            block minimum (``None`` when ``use_min`` is False).
        """
        x = np.asarray(x)
        if not self.use_min:
            x_abs_max = np.max(np.abs(x))
            # Give scale the sign of the element with the largest magnitude, so
            # that element maps to q_min (the signed level with the extra step)
            # for either sign. Dividing |x|max by the negative q_min
            # unconditionally pushed a negative extreme to -q_min, where it was
            # clipped.
            x_ext = x.flat[np.argmax(np.abs(x))]
            sign = -1.0 if x_ext < 0 else 1.0
            scale = sign * (x_abs_max + offset) / self.q_min
            min_val = None
            if scale == 0:
                q = np.zeros(x.shape)   # avoid 0/0 -> NaN levels
            else:
                q = np.clip(np.round(x / scale), self.q_min, self.q_max)
        else:
            min_val = np.min(x)
            scale = (np.max(x) - min_val + offset) / self.q_range
            if scale == 0:
                q = np.zeros(x.shape)   # avoid 0/0 -> NaN levels
            else:
                q = np.clip(np.round((x - min_val) / scale), 0, self.q_range)
        return q, scale, min_val

    def dequantize(self, q: np.ndarray, scale: float, min: float) -> np.ndarray:
        """Map levels back to values: ``q * scale`` (+ ``min`` when ``use_min``)."""
        q = q.astype(np.float32)
        if not self.use_min:
            dq = q * scale # dequantize q
        else:
            dq = q * scale + min
        return dq

    def calc_mse(self, x: np.ndarray, dq: np.ndarray, w: np.ndarray) -> float:
        """Weighted squared error ``mean(w * (x - dq)**2)``.

        This is normalized by the element count, not by ``sum(w)``. For a fixed
        ``w`` that does not change which candidate has the lowest error.
        """
        assert x.shape == dq.shape
        assert w.shape == x.shape
        return np.mean(w * (x - dq)**2)

    # recalculation of scale and min according to the formula above
    def recalc_scale_and_min(self, x: np.ndarray, q: np.ndarray, w: np.ndarray) -> tuple[float, float]:
        """Weighted least-squares refit of scale (and min) for fixed levels ``q``.

        Degenerate level patterns return an optimal solution instead of
        dividing by (near) zero:

        * symmetric mode with zero weighted level energy returns
          ``(0.0, None)``;
        * ``use_min`` mode with a single distinct level returns
          ``(0.0, weighted mean of x)``.

        Returns:
            (scale, min); ``min`` is ``None`` when ``use_min`` is False.
        """
        assert x.shape == q.shape
        assert x.shape == w.shape
        if not self.use_min:
            denom = (w * q**2).sum()
            if denom == 0:
                # the error does not depend on scale
                return 0.0, None
            return (w * x * q).sum() / denom, None
        if q.size == 0 or np.all(q == q.flat[0]):
            # Single distinct level: the determinant below is zero in exact
            # arithmetic but may round to a tiny non-zero value, so handle it
            # explicitly. Scale is undetermined; min's optimum is the
            # weighted mean of x.
            w_sum = w.sum()
            return 0.0, ((w * x).sum() / w_sum if w_sum != 0 else 0.0)
        D = w.sum() * (w * q**2).sum() - ((w * q).sum())**2
        scale = (w.sum() * (w * x * q).sum() - (w * x).sum() * (w * q).sum()) / D
        min_val = ((w * x).sum() * (w * q**2).sum() - (w * q).sum() * (w * x * q).sum()) / D
        return scale, min_val

    def search_scale_and_min(self, x: np.ndarray, w: np.ndarray,
                             n_steps: int = 10, step: float = 0.1,
                             verbose: bool = True):
        """Grid-search the range offset and return the best refitted quantization.

        For each offset ``k * step`` with ``k`` in ``[-n_steps, n_steps]``, the
        method quantizes ``x`` and refits scale/min by weighted least squares
        for the resulting levels. It keeps the candidate with the lowest
        weighted error after refitting.

        Args:
            x: values to quantize.
            w: per-element weights, same shape as ``x``.
            n_steps: number of offset steps on each side of zero.
            step: offset increment, in the units of ``x``.
            verbose: print per-candidate diagnostics and the best result.

        Returns:
            (q, scale, min) of the best candidate, using the refitted scale and
            min (``min`` is None when ``use_min`` is False). If every candidate
            produced a NaN error, returns the plain offset-0 quantization.
        """
        best_mse1 = np.inf
        best_off = 0
        best = None
        for off in range(-n_steps, n_steps + 1):
            q, scale0, min0 = self.quantize(x, off * step)
            scale1, min1 = self.recalc_scale_and_min(x, q, w)
            dq0 = self.dequantize(q, scale0, min0)
            dq1 = self.dequantize(q, scale1, min1)
            mse0 = self.calc_mse(x, dq0, w)
            mse1 = self.calc_mse(x, dq1, w)
            if mse1 < best_mse1:
                best_mse1 = mse1
                best_off = off
                best = (q, scale1, min1)
            if verbose:
                print("off:", off, "mse0:", mse0,  "mse1:", mse1)
                print("q:", q)
                print("scale0:", scale0, "min0:", min0)
                print("scale1:", scale1, "min1:", min1)
                print("dq0:", dq0)
                print("dq1:", dq1)
                print()
        if verbose:
            print("best_off:", best_off)
            print("best_mse1:", best_mse1)
        if best is None:
            best = self.quantize(x, 0.0)
        return best



In [111]:
# Random per-element weights in [0, 10); use np.ones_like(x) instead to get
# unweighted quantization.
w = np.random.rand(8) * 10

In [118]:
# x = 10 * np.random.randn(8)
# Run the offset search on the x and w defined in earlier cells; the search
# prints the weighted error before (mse0) and after (mse1) refitting scale/min
# for each offset, then the best offset.
qtr = quantizer_weighted(n_bits=4, use_min=True)
# qtr = quantizer_weighted(n_bits=4, use_min=False)
qtr.search_scale_and_min(x, w)
print("x:", x)

off: -10 mse0: 1.0040145708848196 mse1: 0.7030886704500996
q: [15.  9.  5.  3.  0. 14. 10. 14.]
scale0: 0.8438068853397653 min0: -8.026010772858172
scale1: 0.8936402408642165 min1: -8.384840698270105
dq0: [ 4.631092   -0.43174887 -3.8069763  -5.49459    -8.0260105   3.7872858
  0.41205788  3.7872858 ]
dq1: [ 5.019762   -0.34207916 -3.9166398  -5.7039204  -8.384841    4.1261225
  0.55156136  4.1261225 ]

off: -9 mse0: 0.848790786622413 mse1: 0.44622695493893827
q: [15.  9.  5.  3.  0. 14. 10. 13.]
scale0: 0.850473552006432 min0: -8.026010772858172
scale1: 0.910816281149878 min1: -8.47379406539703
dq0: [ 4.7310925  -0.37174892 -3.773643   -5.47459    -8.0260105   3.880619
  0.47872448  3.0301456 ]
dq1: [ 5.18845   -0.2764473 -3.9197125 -5.7413454 -8.473794   4.2776337
  0.6343689  3.3668175]

off: -8 mse0: 0.7123609377117937 mse1: 0.44622695493893827
q: [15.  9.  5.  3.  0. 14. 10. 13.]
scale0: 0.8571402186730986 min0: -8.026010772858172
scale1: 0.910816281149878 min1: -8.47379406539703
