**Simplified python reference implementation of hqq-qdora.**

This notebook only contains the final module, not the build-up to it. For the full build-up, see `python_hqq_qdora_v2.ipynb` and `python_hqq_qdora.ipynb`.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor, cat, int32, float16 as fp16, bfloat16 as bf16
from math import ceil

from fastcore.foundation import L
from fastcore.basics import store_attr, AttrDict

from hqq.core.quantize import Quantizer, HQQLinear, BaseQuantizeConfig # Quantizer - optional, only for optimizing during quanting ; HQQLinear & BaseQuantizeConfig to verify our implementation

# compact tensor printing: wide lines, 2 decimals, no scientific notation
torch.set_printoptions(precision=2, linewidth=200, sci_mode=False)

In [None]:
def max_abs_diff(a, b):
    "Return the largest element-wise absolute difference between tensors `a` and `b`."
    diff = a - b
    return diff.abs().max()

def assert_close(a, b):
    "Assert that `a` and `b` match element-wise within an absolute tolerance of 1e-2."
    close = torch.isclose(a, b, atol=1e-2)
    assert close.all(), f'assert_close failed, max error = {max_abs_diff(a,b)}'

def assert_somehow_close(a, b, max_err=0.12):
    "Like `assert_close`, but with a looser, configurable tolerance `max_err` (allows for quantization error)."
    close = torch.isclose(a, b, atol=max_err)
    assert close.all(), f'assert_somehow_close failed, max error = {max_abs_diff(a,b)}'

In [None]:
# default sizes: m = out features, r = lora rank, n = in features (a later cell re-assigns m, n, r)
m = 128
r = 32
n = 128
# quantization axis passed to hqq's weight_quant_params
axis = 1

In [None]:
# adapted from `python_hqq_qdora_clean.ipynb`
# changes:
# - add an axis param, to allow using axis=1
# - unpack and _dequant keep data on the device it's on
# - only dequant/unpack ourselves, and use hqq for quanting/packing
# - zero/scale are not packed (as in hqq)
# - the lora matrix `a` is initialized correctly

class QuantedDoraModule(nn.Module):
    """DoRA adapter around an hqq-quantized linear layer.

    The frozen base weight is dequantized on the fly from `hqq_linear`'s packed 3-bit data and
    its quant metadata (see `unpack`, `_dequant`, `dequant`). Only the low-rank factors `a`, `b`
    and the magnitude vector `m` are trainable parameters of this module.

    Args:
        hqq_linear: quantized `HQQLinear` whose packed weight (`W_q`) and `meta` dict are reused.
        rank: rank of the low-rank update `b.weight @ a.weight`.
        alpha: scalar that, together with `m`, multiplies the normalized output.
    """
    def __init__(self, hqq_linear, rank, alpha):
        super().__init__()
        self.device = hqq_linear.device
        self.hqq_linear = hqq_linear
        # save all metadata we need for dequanting
        meta, meta_zero, meta_scale = AttrDict(hqq_linear.meta), AttrDict(hqq_linear.meta['meta_zero']), AttrDict(hqq_linear.meta['meta_scale'])
        self.qdata = hqq_linear.W_q.data
        # With offload_meta=True, hqq stores zero_q and scale_q together in meta['zero_scale'].
        # Explicit if/else on purpose: the former one-liner `A if cond else B, C` parsed as
        # `(A if cond else B), C`, so on the 'zero_scale' branch qzero became a tuple and
        # qscale was still read from meta.scale_q.
        if 'zero_scale' in meta:
            self.qzero, self.qscale = meta.zero_scale[0], meta.zero_scale[1]
        else:
            self.qzero, self.qscale = meta.zero_q, meta.scale_q
        self.zeros_zero,  self.zeros_scale  = meta_zero.zero,  meta_zero.scale
        self.scales_zero, self.scales_scale = meta_scale.zero, meta_scale.scale
        self.data_shape, self.zero_shape, self.scale_shape = meta.shape, meta_zero.shape, meta_scale.shape
        self.group_size, self.group_size_zero, self.group_size_scale = meta.group_size, meta_zero.group_size, meta_scale.group_size
        self.compute_dtype = meta.compute_dtype
        self.axis = meta.axis
        # for dora
        self.a = nn.Linear(hqq_linear.in_features, rank, bias=False, dtype=self.compute_dtype, device=self.device)
        self.b = nn.Linear(rank, hqq_linear.out_features, bias=False, dtype=self.compute_dtype, device=self.device)
        self.alpha = alpha
        # magnitude: L2 norm of each row (one entry per output feature) of the dequantized base weight
        self.m = nn.Parameter(hqq_linear.dequantize().norm(p=2, dim=1))
        # init a with gaussian noise scaled by 1/sqrt(rank) and b with zeros, so b @ a starts at 0
        self.a.weight.data = torch.randn(rank, hqq_linear.in_features).to(dtype=self.compute_dtype, device=self.device) / (rank**0.5)
        self.b.weight.data.zero_()

    @staticmethod
    def unpack(packed):
        """Unpack 3-bit values stored 10 per int32, as produced by hqq's 3-bit bitpacking.

        Row `i` of `packed` holds the values of unpacked rows `i, i+n_packs, ..., i+9*n_packs`.
        The first chunk sits in the top-most bits (shifted left by 27), and the last chunk sits in
        the lowest 3 bits. Returns an int32 tensor of shape (10*n_packs, cols) on `packed`'s device.
        Padding rows are NOT removed here.
        """
        assert len(packed.shape)==2 and packed.dtype==int32, 'Pass a 2d tensor of int32s'
        n_packs, cols = packed.shape
        # allocate directly on packed's device (no CPU round-trip)
        unpacked = torch.zeros(n_packs*10, cols, dtype=int32, device=packed.device)
        for chunk in range(10):
            shift = 3*(9-chunk)  # chunk 0 is most left-shifted (27 bits), chunk 9 is not shifted
            # right-shift so the wanted 3 bits are the lowest ones, then keep only those via 0b111
            unpacked[chunk*n_packs:(chunk+1)*n_packs] = (packed >> shift) & 0b111
        return unpacked

    @staticmethod
    def _dequant(data, zero, scale, shape, group_size, packed, axis):
        """Return `((data - zero) * scale).reshape(shape)`.

        If `packed`, `data` is first unpacked with `unpack`, then cut back to the row count of the
        grouped weight: `group_size` rows for axis=0, `shape.numel() // group_size` rows for axis=1.
        The dropped rows are padding that was added for packing (which needs a multiple of 10 rows).
        """
        if packed:
            data = QuantedDoraModule.unpack(data)
            rows = group_size if axis==0 else shape.numel()//group_size
            data = data[:rows, :]
        return ((data - zero) * scale).reshape(shape)

    def dequant(self):
        """Dequantize the base weight to `compute_dtype`, with shape `data_shape`.

        zero and scale are quantized themselves (8 bit, not bit-packed), so they are dequantized
        first with their own metadata along axis 0. They are then used to dequantize the packed weight.
        """
        zero  = self._dequant(self.qzero,  self.zeros_zero,  self.zeros_scale,  self.zero_shape,  self.group_size_zero,  packed=False, axis=0)
        scale = self._dequant(self.qscale, self.scales_zero, self.scales_scale, self.scale_shape, self.group_size_scale, packed=False, axis=0)
        return self._dequant(self.qdata, zero, scale, self.data_shape, self.group_size, packed=True, axis=self.axis).to(self.compute_dtype)

    def forward(self, x):
        """Compute `m * alpha * x @ (W + b@a)^T / rownorm(W + b@a)`, with `W` the dequantized base weight.

        `x` is the input whose last dim is `in_features`. F.linear handles arbitrary leading (batch) dims.
        """
        W = self.dequant()  # dequantize once; reused for the output and the norm
        out = F.linear(x, W) + self.b(self.a(x))
        # row norms of the merged weight; detached, so no gradient flows through the norm
        col_norms = (W + self.b.weight @ self.a.weight).norm(p=2, dim=1).detach()
        # NOTE(review): alpha scales the whole output here, not only the low-rank update
        return out / col_norms * (self.m * self.alpha)

In [None]:
b = 32        # batch size
m, n = 16, 8  # out features, in features
r = 2         # lora rank
gz = 64       # group size for quantizing the weight

base_linear = nn.Linear(n, m, bias=False, dtype=bf16, device='cuda')

# number of weight groups; used below as the group size when quantizing scale & zero
ngroups = base_linear.weight.numel() // gz

# Derived from BaseQuantizeConfig(nbits=3, group_size=gz, quant_zero=True, quant_scale=True, offload_meta=True, view_as_float=True),
# with these differences:
# - the group size for scale & zero is ngroups, instead of the default 128
# - view_as_float is not used (Kerem used view_as_float=True, which stores quanted, packed weights
#   as compute_dtype - for us: bf16 - instead of int32)
# - offload_meta=False (Kerem used offload_meta=True; offload_meta=True concats meta['zero_q'] & meta['scale_q']
#   together into meta['zero_scale'])
quant_cfg = {
    'weight_quant_params': dict(nbits=3, group_size=gz, bitpack=True, optimize=True, axis=axis),
    # hqq sets nbits for scale/zero to 8, regardless of nbits for weights; nbits=3 results in an error further below
    'scale_quant_params':  dict(nbits=8, group_size=ngroups, bitpack=True, optimize=False),
    'zero_quant_params':   dict(nbits=8, group_size=ngroups, bitpack=True, optimize=False),
    'offload_meta': False,
}
hqq_linear = HQQLinear(base_linear, quant_cfg, compute_dtype=bf16)
# packed weights are expected to be stored as int32
assert hqq_linear.W_q.dtype == int32

In [None]:
# wrap the quantized layer in our dora module (rank r, alpha 1.0) and display it
qdora_linear = QuantedDoraModule(hqq_linear, rank=r, alpha=1.0)
qdora_linear

QuantedDoraModule(
  (hqq_linear): HQQLinear(in_features=8, out_features=16, bias=False)
  (a): Linear(in_features=8, out_features=2, bias=False)
  (b): Linear(in_features=2, out_features=16, bias=False)
)

In [None]:
x = torch.randn(b, n, dtype=bf16, device='cuda')  # batched input of shape (b, n)
y = base_linear(x)         # reference output of the exact (unquantized) layer
y_qdora = qdora_linear(x)  # output of the quantized dora module

In [None]:
# compare the input shape with the output shapes of the exact layer and the quantized dora module
print(f'Shapes: x = {x.shape} ; y = {y.shape} ; y_qdora = {y_qdora.shape}')

Shapes: x = torch.Size([32, 8]) ; y = torch.Size([32, 16]) ; y_qdora = torch.Size([32, 16])


In [None]:
# print each result followed by an empty line
print('quanted result (with packing):', y_qdora.data, sep='\n', end='\n\n')
print('exact result :', y.data, sep='\n', end='\n\n')
# quantization error rules out an exact match, hence the loose tolerance
assert_somehow_close(y_qdora, y, max_err=0.3)
print(f'Max error is {max_abs_diff(y_qdora, y):.2f} ✓')

quanted result (with packing):
tensor([[    -0.18,     -0.18,      0.13,     -0.24,     -0.14,     -0.29,      0.04,     -0.10,     -0.24,      0.27,      0.37,      0.16,     -0.33,     -0.00,      0.17,      0.12],
        [    -0.40,     -0.61,     -0.12,     -0.27,     -0.41,     -0.15,      0.04,     -0.06,      0.36,     -0.84,     -0.29,      0.42,     -0.42,     -0.74,      0.53,      0.73],
        [     0.38,     -0.03,      0.38,      0.86,      0.37,     -0.72,      0.08,      0.04,     -0.15,     -0.26,     -0.90,      0.25,     -0.19,      0.03,      0.13,      0.34],
        [     0.20,      0.15,      0.04,      0.63,      0.81,     -1.13,     -0.48,     -0.41,      0.08,     -0.74,     -0.88,      0.64,     -0.16,     -0.52,      0.47,      0.47],
        [    -0.11,     -0.72,     -0.39,      1.24,     -0.76,      0.38,      0.31,      0.25,      0.76,     -1.09,     -1.74,     -0.31,     -0.01,     -0.08,      0.05,      0.89],
        [     0.56,      0.39,      1.5

Let's call `backward` on the model's output to make sure the backward pass runs.

In [None]:
def trainable_params(model):
    "Return an `L` of `(name, param)` pairs for the parameters of `model` that require grad."
    # iterate the passed-in `model` (previously the global `qdora_linear` was used, ignoring the argument)
    return L((name, p) for name, p in model.named_parameters() if p.requires_grad)

In [None]:
# only the dora part (magnitude m and lora factors a, b) may be trainable; the quantized base stays frozen
assert {'m', 'a.weight', 'b.weight'} == set(trainable_params(qdora_linear).itemgot(0))

In [None]:
# reduce y_qdora to a scalar with an arbitrary operation, so backward can be called on it
loss = torch.sum(y_qdora)
loss

tensor(-0.19, device='cuda:0', dtype=torch.bfloat16, grad_fn=<SumBackward0>)

In [None]:
torch.autograd.backward(loss)  # populate .grad of the trainable dora parameters

In [None]:
print('Loss shapes:')
for name, p in trainable_params(qdora_linear):
    # each gradient should have the same shape as its parameter
    print(f'Shape of grad of {name:<8} is {str(list(p.grad.shape)):<7}; shape of  {name:<8} is {list(p.shape)}')

Loss shapes:
Shape of grad of m        is [16]   ; shape of  m        is [16]
Shape of grad of a.weight is [2, 8] ; shape of  a.weight is [2, 8]
Shape of grad of b.weight is [16, 2]; shape of  b.weight is [16, 2]


In [None]:
# show the actual gradient values of every trainable parameter
for name, p in trainable_params(qdora_linear):
    print(f'--- {name}.grad:\n{p.grad}')

--- m.grad:
tensor([ -2.47,   0.41,   0.48,   6.28,   3.83,  -1.70,  -5.94,  -1.81,   9.44, -13.06,  -7.28,   7.66,  -3.36,  -9.12,   4.94,  11.06], device='cuda:0', dtype=torch.bfloat16)
--- a.weight.grad:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)
--- b.weight.grad:
tensor([[11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06],
        [11.00,  8.06]], device='cuda:0', dtype=torch.bfloat16)


**Let's verify our implementation against hqq:**

In [None]:
W = base_linear.weight.data  # exact (unquantized) weight, as reference
W

tensor([[     0.23,      0.17,      0.08,      0.21,     -0.13,     -0.10,      0.27,      0.07],
        [     0.05,     -0.10,      0.16,     -0.10,      0.08,     -0.03,      0.34,      0.03],
        [     0.28,      0.09,      0.24,     -0.22,     -0.30,      0.05,      0.35,      0.32],
        [    -0.14,     -0.18,      0.32,     -0.01,     -0.34,     -0.30,      0.19,     -0.19],
        [     0.12,      0.00,      0.14,     -0.10,     -0.06,     -0.19,      0.24,      0.26],
        [    -0.09,      0.06,      0.13,     -0.05,      0.32,      0.29,      0.18,     -0.31],
        [     0.08,      0.34,     -0.00,      0.01,     -0.14,     -0.02,      0.17,      0.01],
        [     0.20,      0.04,      0.15,      0.17,     -0.11,      0.16,      0.18,     -0.05],
        [    -0.15,     -0.18,      0.29,      0.05,      0.25,      0.03,     -0.26,     -0.17],
        [     0.26,     -0.09,     -0.34,     -0.10,     -0.07,      0.25,      0.30,     -0.02],
        [    -0.31, 

In [None]:
W_hqq = hqq_linear.dequantize()  # weight as dequantized by hqq itself
W_hqq

tensor([[ 0.27,  0.17,  0.07,  0.17, -0.12, -0.12,  0.27,  0.07],
        [ 0.07, -0.12,  0.17, -0.12,  0.07, -0.03,  0.37,  0.07],
        [ 0.27,  0.07,  0.27, -0.22, -0.32,  0.07,  0.37,  0.27],
        [-0.12, -0.22,  0.37, -0.03, -0.32, -0.32,  0.17, -0.22],
        [ 0.17, -0.03,  0.17, -0.12, -0.03, -0.22,  0.27,  0.27],
        [-0.12,  0.07,  0.17, -0.03,  0.37,  0.27,  0.17, -0.32],
        [ 0.07,  0.37, -0.03, -0.03, -0.12, -0.03,  0.17, -0.03],
        [ 0.17,  0.07,  0.17,  0.17, -0.12,  0.17,  0.17, -0.03],
        [-0.15, -0.15,  0.25,  0.05,  0.25,  0.05, -0.25, -0.15],
        [ 0.25, -0.05, -0.35, -0.15, -0.05,  0.25,  0.35, -0.05],
        [-0.35, -0.25, -0.25, -0.15, -0.15,  0.35,  0.35,  0.15],
        [ 0.25, -0.05,  0.25, -0.35,  0.05,  0.05, -0.15,  0.35],
        [ 0.05,  0.25, -0.05,  0.05,  0.35, -0.15,  0.15, -0.15],
        [-0.05, -0.15, -0.15,  0.15, -0.25,  0.05,  0.35, -0.25],
        [-0.25, -0.35,  0.05,  0.25, -0.35,  0.05, -0.25,  0.15],
        [-

In [None]:
W_ours = qdora_linear.dequant()  # weight as dequantized by our implementation
W_ours

tensor([[ 0.27,  0.17,  0.07,  0.17, -0.12, -0.12,  0.27,  0.07],
        [ 0.07, -0.12,  0.17, -0.12,  0.07, -0.03,  0.37,  0.07],
        [ 0.27,  0.07,  0.27, -0.22, -0.32,  0.07,  0.37,  0.27],
        [-0.12, -0.22,  0.37, -0.03, -0.32, -0.32,  0.17, -0.22],
        [ 0.17, -0.03,  0.17, -0.12, -0.03, -0.22,  0.27,  0.27],
        [-0.12,  0.07,  0.17, -0.03,  0.37,  0.27,  0.17, -0.32],
        [ 0.07,  0.37, -0.03, -0.03, -0.12, -0.03,  0.17, -0.03],
        [ 0.17,  0.07,  0.17,  0.17, -0.12,  0.17,  0.17, -0.03],
        [-0.15, -0.15,  0.25,  0.05,  0.25,  0.05, -0.25, -0.15],
        [ 0.25, -0.05, -0.35, -0.15, -0.05,  0.25,  0.35, -0.05],
        [-0.35, -0.25, -0.25, -0.15, -0.15,  0.35,  0.35,  0.15],
        [ 0.25, -0.05,  0.25, -0.35,  0.05,  0.05, -0.15,  0.35],
        [ 0.05,  0.25, -0.05,  0.05,  0.35, -0.15,  0.15, -0.15],
        [-0.05, -0.15, -0.15,  0.15, -0.25,  0.05,  0.35, -0.25],
        [-0.25, -0.35,  0.05,  0.25, -0.35,  0.05, -0.25,  0.15],
        [-

In [None]:
# quantization error of hqq and of our dequant vs. the exact weight, and how well our dequant matches hqq's
print(f'Max abs diff between: W     and W_hqq : {max_abs_diff(W, W_hqq):.2f}')
print(f'Max abs diff between: W     and W_ours: {max_abs_diff(W, W_ours):.2f}')
print(f'Max abs diff between: W_hqq and W_ours: {max_abs_diff(W_hqq, W_ours):.2f}')

Max abs diff between: W     and W_hqq : 0.05
Max abs diff between: W     and W_ours: 0.05
Max abs diff between: W_hqq and W_ours: 0.00


In [None]:
# our dequantized weight must match hqq's within a tight tolerance
assert_somehow_close(a=W_ours, b=W_hqq, max_err=0.03)