### Quantization

##### 1 Dimensional Uniform Quantization

Given a uniform quantization grid, we aim to find the quantized vector.

In [16]:
import torch

v = torch.tensor([3.2, -1.4, 2.5, -0.9, 1.8, -3.7, 0.0, 4.0, 2.2, -1.3])

def quantization(v):
    w = (v + 4) * 7 
    a = w // 8
    b = w % 8
    b[b >= 4] = 1
    b[b < 4] = 0
    a += b
    a[a >= 8] = 7
    a[a < 0] = 0
    return a.to(torch.int8)

w = quantization(v) # quantized vector
w

tensor([6, 2, 5, 2, 5, 0, 3, 7, 5, 2], dtype=torch.int8)

##### 1 Dimensional Non-Uniform Quantization

In [17]:
from scipy.stats import norm

# determine the 3 bit normal float grid [-1, 1] according to the paper QLoRA

left = [0.98, 0.82, 0.66]
right = [0.62, 0.74, 0.86, 0.98]
grid = [0]
grid = [- norm.ppf(p) for p in left] + [0] + [norm.ppf(p) for p in right]
grid = torch.tensor(grid)
grid /= torch.max(torch.abs(grid))

def quantizer(v):

    v /= torch.max(torch.abs(v))
    w = torch.zeros_like(v)
    for i in range(len(grid) - 1):
        mask = (v >= grid[i]) & (v < grid[i + 1])
        w[mask] = i
    return w.to(torch.int8)

print(quantizer(v))

tensor([6, 1, 6, 1, 5, 0, 3, 0, 6, 1], dtype=torch.int8)
