In [13]:
import torch
import torch.nn as nn


class LSQQuantizer(nn.Module):
    """LSQ-style fake quantizer with a single learnable step size ``alpha_weight``.

    ``forward`` divides ``x`` by the step, clamps to the range ``[q_min, q_max]``,
    rounds with a straight-through estimator (STE) and multiplies back by the step.
    The step's gradient is multiplied by ``grad_scale_factor``; its forward value is
    left unchanged.

    Args:
        bits: bit width; the grid has ``2**bits`` levels.
        raise_zero: shift the grid by half a step, i.e. levels at ``(k + 1/2) * step``.
            For the signed range this gives a grid symmetric around zero with no zero
            level. NOTE: raise_zero should never be used with FP quantization.
        all_positive: use the unsigned range ``[0, 2**bits - 1]`` instead of the
            signed range ``[-2**bits / 2, 2**bits / 2 - 1]``.
        grad_scale_factor: keyword-only; gradient multiplier for ``alpha_weight``
            (default 2).
        verbose: keyword-only; print intermediate tensors on every forward call
            (default True; set to False outside of debugging).
        **kwargs: accepted and ignored.
    """

    def __init__(self, bits=4, raise_zero=True, all_positive=False, *, grad_scale_factor=2, verbose=True, **kwargs):
        super().__init__()
        # NOTE: raise_zero should never be used with FP quantization
        self.bits = bits
        self.n_levels = 2**bits
        self.all_positive = all_positive
        self.raise_zero = raise_zero
        self.grad_scale_factor = grad_scale_factor
        self.verbose = verbose

        self.q_min, self.q_max = self.get_dtype_bounds()

        # Not read by this class; presumably set by callers once alpha_weight has been
        # initialised (the notebook sets it to True manually).
        self.is_alpha_init = False
        self.alpha_weight = nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def get_dtype_bounds(self):
        """Return ``(q_min, q_max)``, the clamp range expressed in units of the step."""
        if self.all_positive:
            return 0, self.n_levels - 1
        return -self.n_levels / 2, self.n_levels / 2 - 1

    def cast(self, x):
        """Map scaled values onto the grid (integer rounding here).

        Subclasses can override this to use any casting, e.g. int, fp (e2m1, e1m2, ...),
        optimal gaussian, etc.
        NOTE: raise_zero should never be used with FP quantization.
        """
        return x.round()

    def ste_cast(self, x):
        """Return ``cast(x)`` in the forward pass with an identity gradient (STE)."""
        return (self.cast(x) - x).detach() + x

    def grad_scale(self, x, scale):
        """Return a tensor equal to ``x`` whose gradient is multiplied by ``scale``."""
        return (x - x * scale).detach() + x * scale

    def forward(self, x):
        """Fake-quantize ``x``; the result has the same shape as ``x``."""
        # Same forward value as alpha_weight; gradient scaled by grad_scale_factor.
        step = self.grad_scale(self.alpha_weight, self.grad_scale_factor)
        # raise_zero shifts the grid by half a step before rounding and back after it.
        offset = 1 / 2 if self.raise_zero else 0
        xs = x / step
        xsc = torch.clamp(xs - offset, self.q_min, self.q_max)
        xscr = self.ste_cast(xsc) + offset
        xq = xscr * step

        if self.verbose:
            print(f"x: {x}")
            print(f"step: {step.item()}")
            print(f"xs: {xs}")
            print(f"q_min: {self.q_min}, q_max: {self.q_max}")
            print(f"xsc: {xsc}")
            print(f"xscr: {xscr}")
            print(f"xq: {xq}")

        return xq

In [14]:
# Push one scalar through LSQQuantizer and inspect the learned step's gradient.
quantizer = LSQQuantizer(bits=4)
quantizer.alpha_weight.requires_grad_(True)
quantizer.is_alpha_init = True
x = torch.tensor(1.2)
xq = quantizer(x)
xq.backward()
print("alpha_weight.grad: {}".format(quantizer.alpha_weight.grad))
print(quantizer.q_min, quantizer.q_max)


x: 1.2000000476837158
step: 1.0
xs: 1.2000000476837158
q_min: -8.0, q_max: 7.0
xsc: 0.7000000476837158
xscr: 1.5
xq: 1.5
alpha_weight.grad: 0.5999999046325684
-8.0 7.0


In [136]:
class PACTQuantizer(LSQQuantizer):
    """LSQQuantizer variant whose step-size gradient comes only from clipped elements.

    Forward values match the parent's quantization grid, without the parent's
    gradient scaling of the step. In the backward pass:

    * ``alpha_weight`` receives gradient only from elements whose scaled value falls
      outside ``[q_min, q_max]``, as in PACT's clipping-threshold update;
    * ``x`` receives the straight-through gradient: 1 for in-range elements and 0 for
      clipped ones.
    """

    def forward(self, x):
        step = self.alpha_weight
        offset = 1 / 2 if self.raise_zero else 0

        # Which elements get clipped. The mask is a constant selector, so no autograd.
        # Exact bound comparisons are used rather than isclose(), whose relative
        # tolerance can miss values just past the bound.
        with torch.no_grad():
            shifted = x / step - offset
            outlier_mask = (shifted < self.q_min) | (shifted > self.q_max)  # clipped values will be True

        # In-range elements use a detached step: they still pass the STE gradient to x
        # but contribute nothing to alpha_weight. Clipped elements use the live step.
        # Masking the output instead would also cut the gradient to x for in-range
        # elements, leaving x with no gradient at all.
        step_eff = torch.where(outlier_mask, step, step.detach())

        xs = x / step_eff
        xsc = torch.clamp(xs - offset, self.q_min, self.q_max)
        xscr = self.ste_cast(xsc) + offset
        xq = xscr * step_eff

        # Print the debug trace unless the instance defines verbose=False.
        if getattr(self, "verbose", True):
            print(f"x: {x}")
            print(f"step: {step.item()}")
            print(f"xs: {xs}")
            print(f"q_min: {self.q_min}, q_max: {self.q_max}")
            print(f"xsc: {xsc}")
            print(f"outlier_mask: {outlier_mask}")
            print(f"xscr: {xscr}")
            print(f"xq: {xq}")

        return xq


In [137]:
# 2x2 batch with one large outlier (211.1) to exercise the clipping-only step gradient.
quantizer = PACTQuantizer(bits=4)
quantizer.alpha_weight.requires_grad_(True)
quantizer.is_alpha_init = True
x = torch.tensor([[1.3, 1.8], [211.1, 2.6]])
xq = quantizer(x)
xq.backward(torch.ones_like(xq))
print("alpha_weight.grad: {}".format(quantizer.alpha_weight.grad))


x: tensor([[  1.3000,   1.8000],
        [211.1000,   2.6000]])
step: 1.0
xs: tensor([[  1.3000,   1.8000],
        [211.1000,   2.6000]], grad_fn=<DivBackward0>)
q_min: -8.0, q_max: 7.0
xsc: tensor([[0.8000, 1.3000],
        [7.0000, 2.1000]], grad_fn=<ClampBackward1>)
outlier_mask: tensor([[False, False],
        [ True, False]])
xscr: tensor([[1.5000, 1.5000],
        [7.5000, 2.5000]], grad_fn=<AddBackward0>)
xq: tensor([[1.5000, 1.5000],
        [7.5000, 2.5000]], grad_fn=<MulBackward0>)
alpha_weight.grad: 7.5


In [138]:
# Clipping-scale multipliers keyed by bit width; STEQuantizer multiplies them by the
# row RMS. Presumably MSE-optimal clipping scales for a unit-variance Gaussian, as the
# name suggests (TODO confirm source). The key 1.585 is ~log2(3), presumably ternary.
OPTIMAL_GAUSSIAN_SCALES = {
    1: 0.7978845587140913,
    1.585: 1.2240089519030855,
    2: 1.4935346200015913,
    3: 2.051068354131873,
    4: 2.513930578568423,
    5: 2.9160938834961225,
    6: 3.276597282593217,
    7: 3.6010497188221655,
    8: 3.884938678807525,
}

# Scale and step of a 4-bit centered grid for a unit-RMS input (hence the factor 1).
bits = 4
n_levels = 2**bits
scale = 1 * OPTIMAL_GAUSSIAN_SCALES[bits] + 1e-8
# Centered grid: n_levels - 1 intervals spanning [-scale, scale].
step = (2 * scale) / (n_levels - 1)
print(scale, step)

2.513930588568423 0.3351907451424564


In [145]:
class BaseQuantizer(nn.Module):
    """Base class holding a quantizer's bit width and its number of levels.

    Args:
        bits: bit width. May be fractional (e.g. 1.585 ~ log2(3)); the level count is
            rounded so that ``n_levels`` is always a whole number.
    """

    def __init__(self, bits=4):
        super().__init__()
        self.bits = bits
        # round(): for fractional widths 2**bits is not integral (2**1.585 ~ 3.00008),
        # which would skew any step computed from n_levels.
        self.n_levels = round(2**bits)


class STEQuantizer(BaseQuantizer):
    """Fake quantizer with a per-row RMS-based clipping scale and identity STE gradient.

    The clipping scale is ``OPTIMAL_GAUSSIAN_SCALES[bits] * RMS(x)``, with the RMS taken
    over the last dimension, plus 1e-8. The backward pass passes the gradient w.r.t.
    ``x`` through unchanged, including for clipped elements.

    Args:
        bits: bit width; must be a key of ``OPTIMAL_GAUSSIAN_SCALES``.
        centered: if True, the grid consists of odd multiples of ``step / 2`` spanning
            ``[-scale, scale]`` (no zero level). Otherwise it is an integer grid where
            ``round(x_clip / step)`` spans ``-(n_levels - 2) / 2 .. n_levels / 2``.
        verbose: print intermediate tensors on every forward call (default True; set to
            False outside of debugging).
    """

    def __init__(self, bits=4, centered=True, verbose=True):
        super().__init__(bits)
        self.centered = centered
        self.verbose = verbose

    def forward(self, x):
        """Fake-quantize ``x`` row-wise (over the last dim); output is shaped like ``x``."""
        # Per-row scale (keepdim so it broadcasts back over the last dim); the 1e-8
        # keeps the step nonzero for an all-zero row.
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True))
        scale = OPTIMAL_GAUSSIAN_SCALES[self.bits] * rms + 1e-8
        if self.centered:
            # n_levels - 1 intervals over [-scale, scale]; levels at odd multiples of step/2.
            step = 2 * scale / (self.n_levels - 1)
            x_clip = torch.clamp(x, -scale, scale)
            xq = torch.round(x_clip / step + 1 / 2) * step - step / 2
        else:
            # Asymmetric lower clip bound limits round(x_clip / step) to
            # -(n_levels - 2) / 2 .. n_levels / 2.
            step = 2 * scale / self.n_levels
            x_clip = torch.clamp(x, -scale * (self.n_levels - 2) / self.n_levels, scale)
            xq = torch.round(x_clip / step) * step
        if self.verbose:
            print("scale", scale, scale.shape)
            print("step", step, step.shape)
            print("x_clip", x_clip)
            print("xq", xq)
        # Forward value is xq; gradient w.r.t. x is the identity.
        return x + (xq - x).detach()

# Row-wise RMS scaling: the 211.1 outlier inflates the second row's scale and step.
x = torch.tensor([[1.3, 1.8], [211.1, 2.6]])
quantizer = STEQuantizer(bits=4, centered=True)
xq = quantizer(x)

scale tensor([[  3.9470],
        [375.2835]]) torch.Size([2, 1])
step tensor([[ 0.5263],
        [50.0378]]) torch.Size([2, 1])
x_clip tensor([[  1.3000,   1.8000],
        [211.1000,   2.6000]])
xq tensor([[  1.3157,   1.8419],
        [225.1701,  25.0189]])
