In [112]:
import torch
import numpy as np

In [113]:
np.random.seed(42)

arr = np.random.uniform(low=100, high=500, size=100)

min = np.min(arr)
max = np.max(arr)

In [114]:
arr[:10]

array([249.81604754, 480.28572256, 392.79757672, 339.46339368,
       162.40745618, 162.39780813, 123.23344487, 446.47045831,
       340.4460047 , 383.22903112])

In [119]:
def clamp(arr, lower, upper):
    arr[arr < lower] = lower
    arr[arr > upper] = upper
    return arr

def asymmetric_quantize_percentile(arr, percentile=99.9, bits=4):
    max = np.percentile(arr, percentile)
    min = np.percentile(arr, 100-percentile)

    scale = (max - min) / (2**bits - 1)
    zero_point = -1 * np.round((min / scale))

    upper = (max / scale) + zero_point
    lower = (min / scale) + zero_point 

    arr = np.round(((arr / scale) + zero_point))
    arr = clamp(arr, lower, upper)
    return arr, scale, zero_point

def asymmetric_quantize(arr, bits=4):
    min = np.min(arr)
    max = np.max(arr)

    scale = (max - min) / (2**bits - 1)
    zero_point = -1 * np.round((min / scale))

    upper = (max / scale) + zero_point
    lower = (min / scale) + zero_point 

    arr = np.round(((arr / scale) + zero_point))
    arr = clamp(arr, lower, upper)
    return arr, scale, zero_point

def asymmetric_dequantize(arr, scale, zero_point):
    return (arr - zero_point) * scale

def symmetric_quantize(arr, bits=4):
    max = np.max(np.abs(arr))
    scale = max / (2**(bits-1) - 1)
    arr = np.round(arr / scale)
    lower = - 2 ** (bits-1)
    upper = 2 ** (bits-1) - 1
    arr = clamp(arr, lower, upper)
    return arr, scale

def symmetric_dequantize(arr, scale):
    return arr * scale

def find_error(arr1, arr2):
    return np.sum(np.abs(arr1 - arr2))

In [116]:
quantized, scale, zero_point = asymmetric_quantize(arr, 8)
dequantized = asymmetric_dequantize(quantized, scale, zero_point)
print(f"The error in assymetric quantization is : {find_error(dequantized, arr)}")

The error in assymetric quantization is : 35.80548018681516


In [117]:
quantized, scale = symmetric_quantize(arr, 8)
dequantized = symmetric_dequantize(quantized, scale)
print(f"The error in assymetric quantization is : {find_error(dequantized, arr)}")

The error in assymetric quantization is : 102.83849154891263


In [121]:
quantized, scale, zero_point = asymmetric_quantize_percentile(arr, 99.9, 8)
dequantized = asymmetric_dequantize(quantized, scale, zero_point)
print(f"The error in assymetric quantization is : {find_error(dequantized, arr)}")

The error in assymetric quantization is : 40.16981129339155
