In [2]:
import numpy as np

If the `use_min` flag is not set, only a per-block scale is used:
$$
s = \frac{|x|_{max}}{-2^{n-1}} \\
$$
The conversion is given below ($q$ is additionally rounded to the nearest integer and clipped to $[-2^{n-1},\ 2^{n-1}-1]$):
$$
q = \frac{x}{s} \\ 
dq = q \cdot s
$$

If `use_min` is set, both a scale and a minimum are stored per block:
$$
s = \frac{x_{max} - x_{min}}{2^n - 1} \\
m = x_{min}
$$
The conversion is given below ($q$ is additionally rounded to the nearest integer and clipped to $[0,\ 2^n-1]$):
$$
q = \frac{x - x_{min}}{s} \\ 
dq = s \cdot q + x_{min}
$$

In [3]:
class quantizer:
    """Block-wise uniform quantizer for NumPy arrays.

    The input is split into blocks of ``bsz`` elements and every block is
    quantized independently:

    * ``use_min=False``: ``scale = max(|x|) / -2**(n_bits-1)`` and
      ``q = clip(round(x / scale), -2**(n_bits-1), 2**(n_bits-1) - 1)``.
    * ``use_min=True``: ``scale = (max(x) - min(x)) / (2**n_bits - 1)`` and
      ``q = clip(round((x - min(x)) / scale), 0, 2**n_bits - 1)``.

    The last quantized input, codes, reconstruction, per-block scales and
    per-block mins are kept on the instance (``x``, ``q``, ``dq``, ``scales``,
    ``mins``) so that ``dequantize`` and ``calc_mse`` can reuse them.
    """

    def __init__(self, n_bits: int = 8, bsz: int = 32, use_min: bool = False) -> None:
        """
        Args:
            n_bits (int): number of bits
            bsz (int): block size
            use_min (bool): store a per-block minimum in addition to the scale
        """
        self.n_bits = n_bits
        self.q_range = 2**n_bits - 1
        self.q_max = 2**(n_bits-1) - 1
        self.q_min = -2**(n_bits-1)
        self.bsz = bsz
        self.use_min = use_min
        self.nb = None # number of blocks
        self.x = None
        self.q = None
        self.dq = None
        self.scales = None
        self.mins = None

    def quantize_block(self, x: np.ndarray) -> "tuple[np.ndarray, float, float | None]":
        """Quantize a single block.

        Args:
            x: values of one block.

        Returns:
            ``(q, scale, x_min)`` where ``q`` holds the rounded and clipped
            codes, ``scale`` is the block scale and ``x_min`` is the block
            minimum, or ``None`` when ``use_min`` is not set.
        """
        if self.use_min:
            x_min = np.min(x)
            scale = (np.max(x) - x_min) / self.q_range
            shifted = x - x_min
            lo, hi = 0, self.q_range
        else:
            x_min = None
            # q_min is negative, so the scale is negative as well: +max(|x|)
            # maps to q_min, while -max(|x|) maps to -q_min and is clipped to
            # q_max.
            scale = np.max(np.abs(x)) / self.q_min
            shifted = x
            lo, hi = self.q_min, self.q_max
        # A zero scale means every entry of `shifted` is zero (all-zero block,
        # or constant block in use_min mode). Dividing by it would give 0/0 =
        # NaN, so divide by 1 instead, which yields all-zero codes.
        divisor = scale if scale != 0 else 1.0
        q = np.clip(np.round(shifted / divisor), lo, hi)
        return q, scale, x_min

    def dequantize_block(self, q: np.ndarray, scale: float, min: float) -> np.ndarray:
        """Reconstruct one block from its codes.

        Args:
            q: codes of one block.
            scale: scale of the block.
            min: minimum of the block; ignored unless ``use_min`` is set.

        Returns:
            ``q * scale`` (no-min mode) or ``q * scale + min`` (``use_min``).
        """
        q = q.astype(np.float32)
        if self.use_min:
            return q * scale + min
        return q * scale

    def quantize(self, x: np.ndarray) -> np.ndarray:
        """Quantize ``x`` block by block.

        Args:
            x: array whose number of elements is a multiple of ``bsz``.

        Returns:
            The codes, with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` cannot be split into blocks of ``bsz`` elements.
        """
        x = np.asarray(x)
        if self.bsz <= 0 or x.size % self.bsz != 0:
            raise ValueError(
                f"cannot split {x.size} elements into blocks of size {self.bsz}"
            )
        orig_shape = x.shape
        self.x = x.reshape(-1, self.bsz) # (nb, bsz), nb: number of blocks, bsz: block size
        self.nb = self.x.shape[0]
        self.q = np.zeros_like(self.x)
        self.scales = np.zeros(self.nb)
        self.mins = np.zeros(self.nb)
        # quantize every block and record its scale (and min, if any)
        for i in range(self.nb):
            q_block, scale, x_min = self.quantize_block(self.x[i])
            self.q[i] = q_block
            self.scales[i] = scale
            # no-min mode has no block minimum; store NaN as the "unused" marker
            self.mins[i] = np.nan if x_min is None else x_min
        return self.q.reshape(orig_shape)

    def dequantize(self, q: np.ndarray) -> np.ndarray:
        """Reconstruct values from codes using the stored scales and mins.

        Args:
            q: codes laid out like the array returned by ``quantize``.

        Returns:
            The reconstruction, with the same shape as ``q``.
        """
        q = np.asarray(q)
        orig_shape = q.shape
        self.q = q.reshape(-1, self.bsz)
        # The reconstruction is real-valued even when the codes are stored as
        # integers, so promote the output buffer to a floating dtype.
        self.dq = np.zeros(self.q.shape, dtype=np.result_type(self.q.dtype, np.float32))
        for i in range(self.dq.shape[0]):
            self.dq[i] = self.dequantize_block(self.q[i], self.scales[i], self.mins[i])
        return self.dq.reshape(orig_shape)

    def calc_mse(self) -> float:
        """Mean squared error between the last quantized input and its reconstruction."""
        return np.mean((self.x - self.dq) ** 2)

In [4]:
# Fixed seed so that Restart & Run All reproduces the same test data.
SEED = 0
# 4096 samples drawn from a normal distribution with standard deviation 10.
x = 10 * np.random.default_rng(SEED).standard_normal(4096)

In [9]:
# Demo: 4-bit block quantization with a scale only (no per-block minimum).
q1 = quantizer(n_bits=4, bsz=32, use_min=False)
q = q1.quantize(x)
dq = q1.dequantize(q)

# Report the round trip: input, codes, reconstruction, error and code range.
for label, value in (
    ("x: ", x),
    ("q: ", q),
    ("dq: ", dq),
    ("mse: ", q1.calc_mse()),
    ("q.max:", q.max()),
    ("q.min:", q.min()),
):
    print(label, value)

x:  [-8.34008678 -3.48670848 -9.30686668 ... -4.27291433 -9.67298438
  0.22511705]
q:  [ 2.  1.  3. ...  1.  3. -0.]
dq:  [ -7.37396145  -3.68698072 -11.0609417  ...  -3.32198882  -9.96596622
   0.        ]
mse:  0.8757239981009514
q.max: 7.0
q.min: -8.0


The per-block scales and mins are themselves quantized with the no-min method.

In [27]:
class quantizer_2_order(quantizer):
    """
    2-order quantizer: the per-block scales and mins are quantized as well,
    using the no-min scheme.
    s_bits: number of bits for scales and mins quantization
    """
    def __init__(self, s_bits: int = 4, n_bits: int = 8, bsz: int = 32, use_min: bool = False) -> None:
        super().__init__(n_bits, bsz, use_min)
        self.s_bits = s_bits
        self.qs_max = 2 ** (s_bits - 1) - 1
        self.qs_min = - 2 ** (s_bits - 1)
        self.scale_scales = None
        self.scale_mins = None
        self.q_scales = None
        self.q_mins = None

    def _quantize_no_min(self, values: np.ndarray) -> "tuple[np.ndarray, float]":
        """Quantize a 1-D array to ``s_bits`` with the no-min scheme; returns (codes, scale)."""
        if values.size == 0:
            # nothing to quantize (no blocks); np.max would raise on an empty array
            return np.zeros_like(values), 0.0
        # qs_min is negative, so the scale is negative: +max(|v|) maps to
        # qs_min, while -max(|v|) maps to -qs_min and is clipped to qs_max.
        scale = np.max(np.abs(values)) / self.qs_min
        # A zero scale means all values are zero; divide by 1 instead so the
        # codes are 0 rather than 0/0 = NaN (which would poison dequantize).
        divisor = scale if scale != 0 else 1.0
        codes = np.clip(np.round(values / divisor), self.qs_min, self.qs_max)
        return codes, scale

    def quantize(self, x: np.ndarray) -> np.ndarray:
        """Quantize ``x`` block-wise, then quantize the resulting scales and mins.

        Returns the codes of ``x``; the quantized scales/mins and their own
        scales are stored in ``q_scales``/``scale_scales`` and
        ``q_mins``/``scale_mins``.
        """
        q = super().quantize(x)

        # use no-min methods to quantize scales and mins
        self.q_scales, self.scale_scales = self._quantize_no_min(self.scales)
        # mins are not read by dequantize_block when use_min is False
        self.q_mins, self.scale_mins = self._quantize_no_min(self.mins)
        return q

    def dequantize(self, q: np.ndarray) -> np.ndarray:
        """Reconstruct values from codes using the dequantized scales and mins.

        Note that this overwrites ``self.scales`` and ``self.mins`` with their
        dequantized (lossy) versions.
        """
        # use the dequantized scales and mins to overwrite the original scales and mins in the super class
        self.scales = self.q_scales * self.scale_scales
        self.mins = self.q_mins * self.scale_mins
        return super().dequantize(q)

In [33]:
# Demo: 4-bit block quantization with per-block mins, where the block scales
# and mins are themselves quantized to 6 bits.
q1 = quantizer_2_order(s_bits=6, n_bits=4, bsz=32, use_min=True)
q = q1.quantize(x)
dq = q1.dequantize(q)

# Report the round trip: input, codes, reconstruction, error and code range.
for label, value in (
    ("x: ", x),
    ("q: ", q),
    ("dq: ", dq),
    ("mse: ", q1.calc_mse()),
    ("q.max:", q.max()),
    ("q.min:", q.min()),
):
    print(label, value)

x:  [  2.70740678  -8.25383121 -12.58215199 ...   2.41243454 -10.60010584
   3.9098263 ]
q:  [9. 5. 4. ... 8. 3. 8.]
dq:  [  2.31284523  -8.34332657 -11.00736904 ...   2.74177742 -11.15757561
   2.74177742]
mse:  0.7194974231267448
q.max: 15.0
q.min: 0.0
