In [1]:
!which python

/nfs/scistore19/alistgrp/stabesh/micromamba/envs/quest/bin/python


In [2]:
import torch

In [3]:
def ste_round(x):
    """Round ``x`` element-wise with a straight-through gradient.

    The rounding residual ``x.round() - x`` is detached from the graph and
    added back onto ``x``. The forward value is therefore the rounded tensor
    (up to floating-point error of the subtract/add), while the backward pass
    only sees the ``+ x`` term, i.e. an identity gradient.
    """
    residual = x.round() - x
    # Detaching the residual hides the rounding step from autograd.
    return residual.detach() + x


# Demo: push a random 4x4 tensor through ste_round and inspect the gradient.
x = torch.randn(4, 4)
x[0, 0] = 1  # pin one entry so an exact integer (10 after scaling) is included
x = (x * 10).requires_grad_(True)  # scale up, then start tracking gradients
print(x)

# Swap in torch.round(x) here to compare against plain rounding.
y = ste_round(x)
print(y)

# Gradient of sum(y) with respect to x.
y.sum().backward()
print(x.grad)


tensor([[10.0000, 11.5285, -5.9167,  9.1315],
        [-9.4046, -3.7989,  6.5704, -9.5838],
        [-7.1115, 18.2400, -0.9039, -3.0564],
        [14.4162, -2.6938, -5.0514,  7.3895]], requires_grad=True)
tensor([[ 10.,  12.,  -6.,   9.],
        [ -9.,  -4.,   7., -10.],
        [ -7.,  18.,  -1.,  -3.],
        [ 14.,  -3.,  -5.,   7.]], grad_fn=<AddBackward0>)
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])


In [7]:
from torch import nn
from fast_hadamard_transform import hadamard_transform

# Clipping-scale multipliers keyed by quantizer bit width (1 through 8).
# SimplifiedHadamardClipQuantizer looks up the entry for its `bits` and
# multiplies it by the RMS (over the last dimension) of the Hadamard-transformed
# input to obtain the clamp range.
# NOTE(review): the name suggests these are optimal clip points for a
# unit-variance Gaussian; the derivation is not shown here -- confirm the source.
OPTIMAL_GAUSSIAN_SCALES = {
    1: 0.7978845587140913,
    2: 1.4935346200015913,
    3: 2.051068354131873,
    4: 2.513930578568423,
    5: 2.9160938834961225,
    6: 3.276597282593217,
    7: 3.6010497188221655,
    8: 3.884938678807525,
}


class SimplifiedHadamardClipQuantizer(nn.Module):
    """Clip-and-round quantizer applied in a block-Hadamard-rotated basis.

    The input's last dimension is multiplied by a block-diagonal matrix built
    from 128x128 copies of ``aux_matrix``, clipped to
    ``+/- OPTIMAL_GAUSSIAN_SCALES[bits] * RMS`` (RMS over the last dimension),
    rounded onto ``2**bits`` uniformly spaced levels, and multiplied by the
    transpose of the same matrix.

    ``forward1`` and ``forward2`` are two formulations of that computation:
    ``forward1`` assembles the straight-through gradient by hand with an
    explicit clip mask, while ``forward2`` lets autograd differentiate through
    ``clamp`` and a straight-through ``round``. The class defines no
    ``forward``; call one of the two methods explicitly.

    Args:
        bits: Bit width; must be a key of ``OPTIMAL_GAUSSIAN_SCALES``.
    """

    # 128x128 Hadamard block scaled by 2**(-7/2) == 1/sqrt(128).
    # Evaluated once when the class body runs, so defining the class requires a
    # CUDA device. NOTE(review): presumably this scaling makes the block
    # orthonormal, which is why its transpose is used as the inverse -- confirm
    # against the fast_hadamard_transform documentation.
    aux_matrix = hadamard_transform(
        torch.eye(128, dtype=torch.bfloat16, device="cuda"), scale=2 ** (-7 / 2)
    )

    def __init__(self, bits=4):
        super().__init__()
        self.bits = bits
        self.n_levels = 2**bits
        # Lazily built block-diagonal transform; see _get_matrix.
        self.matrix = None

    def _get_matrix(self, x):
        """Return the block-diagonal transform matching ``x``, rebuilding if stale.

        Raises:
            ValueError: if ``x.shape[-1]`` is not a positive multiple of the
                block size (128).
        """
        dim = x.shape[-1]
        block = self.aux_matrix.shape[-1]
        if dim == 0 or dim % block != 0:
            raise ValueError(
                f"last dimension must be a positive multiple of {block}, got {dim}"
            )

        cached = self.matrix
        # The cache is only valid for the size/dtype/device it was built for;
        # reusing it for a different input would fail (or silently mix dtypes).
        if (
            cached is None
            or cached.shape[-1] != dim
            or cached.dtype != x.dtype
            or cached.device != x.device
        ):
            self.matrix = torch.block_diag(
                *[self.aux_matrix.to(x.device).to(x.dtype)] * (dim // block),
            )
        return self.matrix

    def _scale_and_step(self, x_had):
        """Compute the clip range and quantization step, outside the autograd graph."""
        with torch.no_grad():
            scale = (
                OPTIMAL_GAUSSIAN_SCALES[self.bits]
                * torch.sqrt(torch.mean(x_had**2, dim=-1, keepdim=True))
                + 1e-8  # keeps the step non-zero for an all-zero row
            )
            step = 2 * scale / (self.n_levels - 1)
        return scale, step

    def forward1(self, x):
        """Quantize ``x`` with a hand-built straight-through gradient.

        The returned value equals the quantized tensor; the gradient is that of
        ``(x_had * mask) @ matrix.T``, i.e. it passes only through entries that
        were not clipped.
        """
        matrix = self._get_matrix(x)
        x_had = x @ matrix
        scale, step = self._scale_and_step(x_had)

        with torch.no_grad():
            x_clip = torch.clamp(x_had, -scale, scale)
            # Shift by 1/2 before rounding and back afterwards so the levels sit
            # at odd multiples of step/2, symmetric around zero.
            xq = torch.round(x_clip / step + 1 / 2) * step - step / 2
            # Keep the mask in the input dtype: a float32 mask would promote
            # half/bfloat16 activations to float32 and break the matmul below.
            mask = (torch.abs(x_had) <= scale).to(x_had.dtype)
            xq = xq @ matrix.T

        grad_flow_output = (x_had * mask) @ matrix.T

        # Forward value is xq; backward sees only grad_flow_output.
        return grad_flow_output + (xq - grad_flow_output).detach()

    @staticmethod
    def ste_round(x):
        """Round ``x`` in the forward pass while passing gradients through unchanged."""
        return (x.round() - x).detach() + x

    def forward2(self, x):
        """Quantize ``x`` letting autograd handle ``clamp`` and a straight-through round."""
        matrix = self._get_matrix(x)
        x_had = x @ matrix
        scale, step = self._scale_and_step(x_had)

        x_clip = torch.clamp(x_had, -scale, scale)
        xq = self.ste_round(x_clip / step + 1 / 2) * step - step / 2
        return xq @ matrix.T


In [9]:
def test_forward1_vs_forward2():
    """Check that ``forward1`` and ``forward2`` agree in outputs and input gradients.

    Runs a 4-bit ``SimplifiedHadamardClipQuantizer`` on CUDA inputs drawn from
    several distributions and shapes, printing a per-case report.

    Raises:
        AssertionError: after all cases have been reported, if any output or
            gradient comparison failed.
    """
    # Tolerances shared by the output and gradient comparisons.
    atol = 1e-6
    rtol = 1e-6

    # We define multiple distributions and shapes to test
    distributions = {
        "normal": lambda shape: torch.randn(shape, device="cuda"),
        "uniform": lambda shape: torch.rand(shape, device="cuda").mul(2).sub(1),  # in [-1, 1]
        "positive_uniform": lambda shape: torch.rand(shape, device="cuda"),       # in [0, 1]
    }
    shapes = [
        (32, 128),  # typical 'batch x features'
        (1, 128),   # single sample
        (64, 256),  # bigger batch and features
    ]

    failures = []

    for dist_name, dist_func in distributions.items():
        for shape in shapes:
            # Fresh quantizer per case so no state carries over between shapes.
            q = SimplifiedHadamardClipQuantizer(bits=4).cuda()

            torch.manual_seed(0)  # reset seed for reproducibility
            print(f"Testing dist='{dist_name}', shape={shape}")

            # Create input x; each path gets its own leaf copy so gradients
            # accumulate independently.
            x_init = dist_func(shape)

            # forward1
            x1 = x_init.detach().clone().requires_grad_(True)
            out1 = q.forward1(x1)

            # forward2
            x2 = x_init.detach().clone().requires_grad_(True)
            out2 = q.forward2(x2)

            # Compare forward outputs
            same_output = torch.allclose(out1, out2, atol=atol, rtol=rtol)
            print(f"  Outputs match: {same_output}")
            if not same_output:
                max_diff_output = (out1 - out2).abs().max().item()
                print(f"  Max diff in outputs: {max_diff_output:.3e}")
                failures.append(f"outputs differ for dist='{dist_name}', shape={shape}")

            # Compare backward gradients
            out1.backward(torch.ones_like(out1))
            out2.backward(torch.ones_like(out2))

            same_grad = torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
            print(f"  Gradients match: {same_grad}")
            if not same_grad:
                max_diff_grad = (x1.grad - x2.grad).abs().max().item()
                print(f"  Max diff in grads: {max_diff_grad:.3e}")
                failures.append(f"gradients differ for dist='{dist_name}', shape={shape}")

            print("")

    # Fail loudly only after every case has been reported, so the full
    # printed report is still available when something mismatches.
    if failures:
        raise AssertionError("forward1/forward2 mismatch: " + "; ".join(failures))

test_forward1_vs_forward2()


Testing dist='normal', shape=(32, 128)
  Outputs match: True
  Gradients match: True

Testing dist='normal', shape=(1, 128)
  Outputs match: True
  Gradients match: True

Testing dist='normal', shape=(64, 256)
  Outputs match: True
  Gradients match: True

Testing dist='uniform', shape=(32, 128)
  Outputs match: True
  Gradients match: True

Testing dist='uniform', shape=(1, 128)
  Outputs match: True
  Gradients match: True

Testing dist='uniform', shape=(64, 256)
  Outputs match: True
  Gradients match: True

Testing dist='positive_uniform', shape=(32, 128)
  Outputs match: True
  Gradients match: True

Testing dist='positive_uniform', shape=(1, 128)
  Outputs match: True
  Gradients match: True

Testing dist='positive_uniform', shape=(64, 256)
  Outputs match: True
  Gradients match: True

