In [14]:
import numpy as np

In [18]:
# Draw 20 sample parameters uniformly from the interval [-50, 150)
params = np.random.uniform(low=-50, high=150, size=20)

# Pin a clear maximum, a clear minimum and an exact zero to the first three
# slots so they are easy to locate when inspecting the quantized values
params[0] = np.max(params) + 1
params[1] = np.min(params) - 1
params[2] = 0

# Keep two decimals so the printed arrays stay readable
params = params.round(2)
print(params)

[149.35 -50.37   0.    22.8  137.8  -15.66 148.35 -26.68  17.28  12.57
  18.5   94.82  44.14 -49.37  94.86 138.89  82.92 -21.36  72.43  -6.14]


## Define the Quantization Methods

 - Asymmetric Quantization
 - Asymmetric Dequantization
 - Symmetric Quantization
 - Symmetric Dequantization
 - Quantization Error

In [19]:
def clamp(params_q: np.ndarray, lower_bound: int, upper_bound: int) -> np.ndarray:
    """Limit every element of ``params_q`` to ``[lower_bound, upper_bound]``.

    Args:
        params_q: Values to clamp.
        lower_bound: Smallest value allowed in the result.
        upper_bound: Largest value allowed in the result.

    Returns:
        A new array in which elements below ``lower_bound`` are replaced by
        ``lower_bound`` and elements above ``upper_bound`` by ``upper_bound``.
        The input array is not modified.
    """
    # np.clip allocates its result, so the caller's array is never mutated
    # (the previous boolean-mask assignment overwrote the argument in place).
    return np.clip(params_q, lower_bound, upper_bound)

def asymmetric_quantization(params: np.ndarray, bits: int) -> tuple[np.ndarray, float, int]:
    """Quantize ``params`` to unsigned ``bits``-bit integers (asymmetric scheme).

    The observed range ``[min(params), max(params)]`` is mapped linearly onto
    the integer range ``[0, 2**bits - 1]``.

    Args:
        params: Real-valued parameters to quantize.
        bits: Number of bits of the target integer range; must be >= 1.

    Returns:
        A tuple ``(quantized, scale, zero)`` where ``quantized`` is an int32
        array of integer codes, ``scale`` is the real-valued step between two
        adjacent codes, and ``zero`` is the integer zero-point, i.e. the code
        that the real value 0.0 maps to before clamping.

    Raises:
        ValueError: If ``bits`` is smaller than 1.
    """
    if bits < 1:
        raise ValueError(f'bits must be >= 1, got {bits}')

    alpha = np.max(params)
    beta = np.min(params)
    lower_bound, upper_bound = 0, 2**bits - 1

    value_range = alpha - beta
    # A constant input has zero range; dividing by a zero scale would produce
    # NaN/inf codes, so fall back to a unit step in that case.
    scale = value_range / upper_bound if value_range > 0 else 1.0

    # The zero-point is an integer by definition; return it as a Python int so
    # it matches the annotated type.
    zero = int(-np.round(beta / scale))

    # Quantize the parameters
    quantized = clamp(np.round(params / scale + zero), lower_bound, upper_bound).astype(np.int32)
    return quantized, scale, zero

def asymmetric_dequantize(params_q: np.ndarray, scale: float, zero: int) -> np.ndarray:
    """Map asymmetric integer codes back to real values.

    Args:
        params_q: Integer codes to convert back.
        scale: Real-valued step between two adjacent codes.
        zero: Zero-point, i.e. the code that represents the real value 0.0.

    Returns:
        A float64 array holding ``(params_q - zero) * scale``.
    """
    # Promote to float before subtracting: in an unsigned or narrow integer
    # dtype (e.g. uint8) `params_q - zero` would wrap around for codes that
    # lie below the zero-point.
    shifted = np.asarray(params_q, dtype=np.float64) - zero
    return shifted * scale

def symmetric_dequantize(params_q: np.array, scale: float) -> np.array:
    """Map symmetric integer codes back to real values.

    The symmetric scheme has no zero-point, so each code is simply multiplied
    by ``scale``.
    """
    restored = scale * params_q
    return restored

def symmetric_quantization(params: np.ndarray, bits: int) -> tuple[np.ndarray, float]:
    """Quantize ``params`` to signed ``bits``-bit integers (symmetric scheme).

    The range ``[-max|params|, +max|params|]`` is mapped linearly onto
    ``[-(2**(bits-1) - 1), 2**(bits-1) - 1]``, so the real value 0.0 always
    maps to the code 0 and no zero-point is needed.

    Args:
        params: Real-valued parameters to quantize.
        bits: Number of bits of the target integer range; must be >= 2.

    Returns:
        A tuple ``(quantized, scale)`` where ``quantized`` is an int32 array of
        integer codes and ``scale`` is the real-valued step between two
        adjacent codes.

    Raises:
        ValueError: If ``bits`` is smaller than 2 (the positive range
            ``2**(bits-1) - 1`` would be empty and the scale undefined).
    """
    if bits < 2:
        raise ValueError(f'bits must be >= 2, got {bits}')

    alpha = np.max(np.abs(params))
    lower_bound = -2**(bits - 1)
    upper_bound = 2**(bits - 1) - 1

    # An all-zero input gives alpha == 0; dividing by a zero scale would turn
    # every code into NaN, so fall back to a unit step in that case.
    scale = alpha / upper_bound if alpha > 0 else 1.0

    # Quantize the parameters
    quantized = clamp(np.round(params / scale), lower_bound, upper_bound).astype(np.int32)
    return quantized, scale

def quantization_error(params: np.array, params_q: np.array):
    """Return the mean squared error between ``params`` and ``params_q``."""
    residual = params - params_q
    return np.mean(np.square(residual))

# Quantize the same parameters to 8 bits with both schemes
asymmetric_q, asymmetric_scale, asymmetric_zero = asymmetric_quantization(params, 8)
symmetric_q, symmetric_scale = symmetric_quantization(params, 8)

# Show the originals, then each scheme's metadata followed by its integer codes
print('Original:', np.round(params, 2), sep='\n')
print('------------------------')
print(f'Asymmetric scale: {asymmetric_scale}, zero: {asymmetric_zero}', asymmetric_q, sep='\n')
print('------------------------')
print(f'Symmetric scale: {symmetric_scale}', symmetric_q, sep='\n')

Original:
[149.35 -50.37   0.    22.8  137.8  -15.66 148.35 -26.68  17.28  12.57
  18.5   94.82  44.14 -49.37  94.86 138.89  82.92 -21.36  72.43  -6.14]
------------------------
Asymmetric scale: 0.7832156862745098, zero: 64.0
[255   0  64  93 240  44 253  30  86  80  88 185 120   1 185 241 170  37
 156  56]
------------------------
Symmetric scale: 1.175984251968504
[127 -43   0  19 117 -13 126 -23  15  11  16  81  38 -42  81 118  71 -18
  62  -5]


In [20]:
# Reconstruct real values from the integer codes of each scheme
params_deq_asymmetric = asymmetric_dequantize(asymmetric_q, asymmetric_scale, asymmetric_zero)
params_deq_symmetric = symmetric_dequantize(symmetric_q, symmetric_scale)

# Print the originals next to both reconstructions for a visual comparison
print('Original:', np.round(params, 2), sep='\n')
print('--------------------')
print('Dequantize Asymmetric:', np.round(params_deq_asymmetric, 2), sep='\n')
print('---------------------')
print('Dequantize Symmetric:', np.round(params_deq_symmetric, 2), sep='\n')

Original:
[149.35 -50.37   0.    22.8  137.8  -15.66 148.35 -26.68  17.28  12.57
  18.5   94.82  44.14 -49.37  94.86 138.89  82.92 -21.36  72.43  -6.14]
--------------------
Dequantize Asymmetric:
[149.59 -50.13   0.    22.71 137.85 -15.66 148.03 -26.63  17.23  12.53
  18.8   94.77  43.86 -49.34  94.77 138.63  83.02 -21.15  72.06  -6.27]
---------------------
Dequantize Symmetric:
[149.35 -50.57   0.    22.34 137.59 -15.29 148.17 -27.05  17.64  12.94
  18.82  95.25  44.69 -49.39  95.25 138.77  83.49 -21.17  72.91  -5.88]


In [21]:
# Report the mean squared error of each reconstruction, labels right-aligned to 20 columns
print('Asymmetric error: '.rjust(20), np.round(quantization_error(params, params_deq_asymmetric), 2), sep='')
print('Symmetric error: '.rjust(20), np.round(quantization_error(params, params_deq_symmetric), 2), sep='')

  Asymmetric error: 0.03
   Symmetric error: 0.11
