**Simplified python reference implementation of hqq-qdora.**

This notebook only contains the final module, not the build-up to it. For the full build-up, see `python_hqq_qdora_v2.ipynb` and `python_hqq_qdora.ipynb`.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor, cat, int32, float16 as fp16, bfloat16 as bf16
from math import ceil

from fastcore.foundation import L
from fastcore.basics import store_attr, AttrDict

from hqq.core.quantize import Quantizer, HQQLinear, BaseQuantizeConfig # Quantizer - optional, only for optimizing during quanting ; HQQLinear & BaseQuantizeConfig to verify our implementation

# Compact tensor printing: wide lines, two decimals, never scientific notation.
torch.set_printoptions(precision=2, sci_mode=False, linewidth=200)

In [2]:
def max_abs_diff(a, b):
    "Largest element-wise absolute difference between `a` and `b` (returned as a 0-dim tensor)."
    return torch.max(torch.abs(a - b))

def assert_close(a, b):
    "Assert that `a` and `b` agree element-wise within an absolute tolerance of 1e-2."
    within_tol = torch.isclose(a, b, atol=1e-2).all()
    assert within_tol, f'assert_close failed, max error = {max_abs_diff(a,b)}'

def assert_somehow_close(a, b, max_err=0.12):
    "Like `assert_close`, but with a looser, configurable tolerance to allow for some error due to quanting."
    within_tol = torch.isclose(a, b, atol=max_err).all()
    assert within_tol, f'assert_somehow_close failed, max error = {max_abs_diff(a,b)}'

In [3]:
# Placeholder dimensions; m, n and r are reassigned in the setup cell further below.
m = 128
r = 32
n = 128

In [4]:
# adapted from `python_hqq_qdora_clean.ipynb`
# changes
# - unpack and _dequant keep data on device it's on
# - only dequant/unpack ourselves, and use hqq for quanting/packing
# - zero/scale are not packed (as in hqq)
# - a (of lora) is initialized correctly

class QuantedDoraModule(nn.Module):
    """DoRA adapter around an hqq-quantized linear layer, with our own dequantization.

    Quantized data and metadata are read from `hqq_linear`; unpacking and dequantizing are done here.
    Trainable parts are the low-rank factors `a`, `b` and the magnitude vector `m`.

    Args:
        hqq_linear: quantized layer providing `.device`, `.meta`, `.W_q`, `.in_features`, `.out_features`
            and `.dequantize()`.
        rank: rank of the low-rank update `b @ a`.
        alpha: scalar multiplied onto the magnitude `m` in `forward`.
    """
    def __init__(self, hqq_linear, rank, alpha):
        super().__init__()
        self.device = hqq_linear.device
        self.hqq_linear = hqq_linear
        # save all metadata we need for dequanting
        meta = hqq_linear.meta
        meta_zero, meta_scale = meta['meta_zero'], meta['meta_scale']
        self.qdata = hqq_linear.W_q.data
        # Quanted zero & scale are either combined in meta['zero_scale'] (index 0 = zero, 1 = scale) or stored
        # separately as meta['zero_q'] / meta['scale_q']. An explicit if/else avoids the precedence trap of
        # `x, y = (a, b) if cond else c, d`, which parses as `((a, b) if cond else c), d`.
        if 'zero_scale' in meta: self.qzero, self.qscale = meta['zero_scale'][0], meta['zero_scale'][1]
        else:                    self.qzero, self.qscale = meta['zero_q'], meta['scale_q']
        self.zeros_zero,  self.zeros_scale  = meta_zero['zero'],  meta_zero['scale']
        self.scales_zero, self.scales_scale = meta_scale['zero'], meta_scale['scale']
        self.data_shape, self.zero_shape, self.scale_shape = meta['shape'], meta_zero['shape'], meta_scale['shape']
        self.group_size, self.group_size_zero, self.group_size_scale = meta['group_size'], meta_zero['group_size'], meta_scale['group_size']
        self.compute_dtype = meta['compute_dtype']
        # for dora
        self.a = nn.Linear(hqq_linear.in_features, rank,  bias=False, dtype=self.compute_dtype, device=self.device)
        self.b = nn.Linear(rank, hqq_linear.out_features, bias=False, dtype=self.compute_dtype, device=self.device)
        self.alpha = alpha
        # magnitude starts as the per-output-row L2 norm of the dequantized weight
        self.m = nn.Parameter(hqq_linear.dequantize().norm(p=2, dim=1))
        # init a & b: a is random normal scaled by 1/sqrt(rank), b is zero, so b @ a starts at zero
        self.a.weight.data = torch.randn(rank, hqq_linear.in_features).to(dtype=self.compute_dtype, device=self.device) / (rank**0.5)
        self.b.weight.data.zero_()

    @staticmethod
    def unpack(packed, rows):
        """Unpack a 2d int32 tensor holding ten 3-bit values per element.

        Returns an int32 tensor of shape (min(rows, 10*n_packs), cols) on the same device as `packed`.
        Output rows [k*n_packs, (k+1)*n_packs) hold bits [3*(9-k), 3*(9-k)+3) of each packed element.

        Raises:
            AssertionError: if `packed` is not a 2d int32 tensor.
        """
        assert len(packed.shape)==2 and packed.dtype==int32, 'Pass a 2d tensor of int32s'
        n_packs, cols = packed.shape
        # allocate on packed's device, so no host<->device copies happen while unpacking
        padded_vals = torch.zeros(n_packs*10, cols, dtype=int32, device=packed.device)
        for k_up, k_down in zip(range(10), reversed(range(10))): # top-most 3bits vals (k_down=0) are most right-shifted (k_up=9)
            # right-shift by 3*k_up bits, so the last 3 bits are those we want; then select only those via 0b111
            padded_vals[k_down*n_packs:(k_down+1)*n_packs,:] = (packed >> (3*k_up)) & 0b111
        return padded_vals[:rows,:] # drop padding rows

    @staticmethod
    def _dequant(data, zero, scale, shape, group_size, packed):
        "Compute `(data - zero) * scale` reshaped to `shape`; if `packed`, first unpack `data` into `group_size` rows."
        if packed: data = QuantedDoraModule.unpack(data, rows=group_size)
        return ((data-zero)*scale).reshape(shape)

    def dequant(self):
        "Dequantize zero and scale first, then use them to dequantize the weight; result is in `compute_dtype`."
        zero  = self._dequant(self.qzero,  self.zeros_zero,  self.zeros_scale,  self.zero_shape,  self.group_size_zero,  packed=False) # zero/scale are uint8, so don't require unpacking
        scale = self._dequant(self.qscale, self.scales_zero, self.scales_scale, self.scale_shape, self.group_size_scale, packed=False)
        return  self._dequant(self.qdata,  zero,             scale,             self.data_shape,  self.group_size, packed=True).to(self.compute_dtype)

    def forward(self, x):
        "Apply `(W + b@a)` to `x`, divide by the detached per-row norm of `(W + b@a)` and rescale by `m * alpha`."
        w = self.dequant() # dequant once and reuse for both the matmul and the norms
        y = F.linear(x, w) + self.b(self.a(x)) # use F.linear so batched matmul works
        col_norms = (w + self.b.weight @ self.a.weight).norm(p=2, dim=1).detach() # detached: no grads flow through the norm
        return y / col_norms * (self.m * self.alpha)

In [5]:
# toy problem sizes
b = 32          # batch size
m = 16          # out features
n = 8           # in features
r = 2           # lora rank
gz = 64         # group size used when quanting the weights

base_linear = nn.Linear(n, m, bias=False, dtype=bf16, device='cuda')
ngroups = base_linear.weight.numel() // gz

# Modeled on BaseQuantizeConfig(nbits=3, group_size=gz, quant_zero=True, quant_scale=True, offload_meta=True, view_as_float=True),
# but with group_size for scale & zero set to ngroups instead of the default 128.
# Notes:
# - Kerem used view_as_float=True, which stores the quanted, packed weights as compute_dtype (for us: bf16) instead of int32.
# - hqq sets nbits for scale/zero to 8, regardless of nbits for the weights; nbits=3 results in an error further below.
# - Kerem used offload_meta=True; offload_meta=True concats meta['zero_q'] & meta['scale_q'] together into meta['zero_scale'].
quant_cfg = {
    'weight_quant_params': {'nbits': 3, 'group_size': gz,      'bitpack': True, 'optimize': True},
    'scale_quant_params':  {'nbits': 8, 'group_size': ngroups, 'bitpack': True, 'optimize': False},
    'zero_quant_params':   {'nbits': 8, 'group_size': ngroups, 'bitpack': True, 'optimize': False},
    'offload_meta': False,
}
hqq_linear = HQQLinear(base_linear, quant_cfg, compute_dtype=bf16)
assert hqq_linear.W_q.dtype==int32 # weights are stored packed into int32s

In [6]:
# wrap the quantized layer in a dora module with rank r and alpha 1.0
qdora_linear = QuantedDoraModule(hqq_linear, rank=r, alpha=1.0)
qdora_linear

QuantedDoraModule(
  (hqq_linear): HQQLinear(in_features=8, out_features=16, bias=False)
  (a): Linear(in_features=8, out_features=2, bias=False)
  (b): Linear(in_features=2, out_features=16, bias=False)
)

In [29]:
# one batched input, passed through the exact layer and through the quanted dora layer
x = torch.randn(b, n, dtype=bf16, device='cuda')
y, y_qdora = base_linear(x), qdora_linear(x)

In [31]:
shapes_msg = f'Shapes: x = {x.shape} ; y = {y.shape} ; y_qdora = {y_qdora.shape}'
print(shapes_msg)

Shapes: x = torch.Size([32, 8]) ; y = torch.Size([32, 16]) ; y_qdora = torch.Size([32, 16])


In [32]:
# show both outputs, each followed by a blank line
for title, result in (('quanted result (with packing):', y_qdora), ('exact result :', y)):
    print(title)
    print(result.data)
    print()
# quanting introduces some error, so only require a loose match
assert_somehow_close(y_qdora, y, max_err=0.3)
print(f'Max error is {max_abs_diff(y_qdora, y):.2f} ✓')

quanted result (with packing):
tensor([[    -0.20,      0.60,      0.04,     -0.60,     -0.40,     -0.07,     -0.64,     -0.74,      0.29,      0.15,     -1.36,     -0.84,     -0.55,     -0.03,      0.89,     -0.72],
        [     0.54,      0.48,      0.34,     -0.02,     -0.37,      0.17,     -0.46,      0.13,     -0.51,      0.33,     -0.73,      0.09,      0.56,     -0.01,      1.23,     -0.13],
        [    -0.56,     -0.09,     -0.72,      0.77,     -0.95,      0.41,      0.33,      1.02,     -1.45,     -0.88,      0.54,     -1.11,     -0.47,      0.01,     -1.29,     -0.38],
        [     1.23,     -0.77,     -0.75,      0.17,     -0.18,     -1.68,      0.88,      0.07,     -0.21,     -0.39,      0.82,      0.40,     -0.39,     -0.93,      0.12,      0.46],
        [     0.35,     -0.35,     -0.07,     -0.82,     -0.06,     -0.25,      0.06,     -0.02,      0.54,      0.29,     -0.37,      0.71,      0.17,     -0.29,     -0.61,      0.48],
        [    -0.22,      0.10,      0.2

Let's call `backward()` on a scalar loss computed from the model output, to make sure the backward pass runs.

In [33]:
def trainable_params(model):
    "Return an `L` of `(name, parameter)` pairs for all parameters of `model` that have `requires_grad=True`."
    # iterate over the `model` argument, not a global, so the helper works for any module
    return L((name,p) for name,p in model.named_parameters() if p.requires_grad)

In [34]:
trainable_names = set(trainable_params(qdora_linear).itemgot(0))
assert trainable_names == {'m','a.weight','b.weight'} # assert only the dora part is trainable

In [35]:
# reduce y_qdora to a scalar (the reduction itself is arbitrary) so we can backprop from it
loss = torch.sum(y_qdora)
loss

tensor(-28.38, device='cuda:0', dtype=torch.bfloat16, grad_fn=<SumBackward0>)

In [36]:
loss.backward() # backprop from the scalar loss; fills .grad of the trainable params (m, a.weight, b.weight)

In [37]:
print('Loss shapes:')
# each grad should have the same shape as its parameter
for name, p in trainable_params(qdora_linear):
    grad_shape, param_shape = str(list(p.grad.shape)), list(p.shape)
    print(f'Shape of grad of {name:<8} is {grad_shape:<7}; shape of  {name:<8} is {param_shape}')

Loss shapes:
Shape of grad of m        is [16]   ; shape of  m        is [16]
Shape of grad of a.weight is [2, 8] ; shape of  a.weight is [2, 8]
Shape of grad of b.weight is [16, 2]; shape of  b.weight is [16, 2]


In [38]:
# dump the gradient of every trainable parameter
for name, p in trainable_params(qdora_linear):
    print(f'--- {name}.grad:\n{p.grad}')

--- m.grad:
tensor([ -3.88,  -1.54,  -1.24,   2.14,  -8.12,  -1.60,   0.99,   1.09, -11.69,  -8.00,  -1.25, -10.31,  -1.73,  -5.25,   2.17,  -9.12], device='cuda:0', dtype=torch.bfloat16)
--- a.weight.grad:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)
--- b.weight.grad:
tensor([[-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69],
        [-12.56,  -9.69]], device='cuda:0', dtype=torch.bfloat16)


**Let's verify our implementation against hqq:**

In [39]:
# exact (unquantized) weights of the base layer
W = base_linear.weight.data
W

tensor([[-0.16, -0.11, -0.08,  0.03,  0.17,  0.11,  0.13, -0.30],
        [-0.16,  0.34,  0.16,  0.05,  0.13, -0.26,  0.28, -0.06],
        [ 0.07,  0.20,  0.06, -0.08, -0.28,  0.17, -0.11, -0.29],
        [-0.20,  0.32, -0.03, -0.03,  0.16,  0.26,  0.24,  0.20],
        [-0.08,  0.12, -0.09, -0.25, -0.02,  0.17, -0.22, -0.08],
        [ 0.11,  0.34,  0.03,  0.18,  0.31, -0.34, -0.34,  0.29],
        [-0.16, -0.06, -0.31, -0.06, -0.11,  0.23,  0.08,  0.03],
        [-0.33,  0.10, -0.05,  0.34, -0.31,  0.18,  0.29, -0.02],
        [-0.21, -0.27,  0.15, -0.11, -0.02, -0.20, -0.21, -0.14],
        [-0.22,  0.29, -0.16,  0.02,  0.00, -0.18,  0.08, -0.34],
        [-0.22,  0.01, -0.30, -0.17,  0.27,  0.34, -0.11,  0.33],
        [-0.19, -0.06, -0.29,  0.06,  0.16,  0.03, -0.28, -0.30],
        [ 0.09, -0.04,  0.24,  0.11, -0.18,  0.20, -0.28, -0.23],
        [-0.20,  0.26,  0.18,  0.06,  0.29, -0.28, -0.06,  0.18],
        [ 0.30,  0.34,  0.10, -0.30,  0.23,  0.27,  0.16, -0.34],
        [-

In [45]:
# weights as dequantized by hqq itself
W_hqq = hqq_linear.dequantize()
W_hqq

tensor([[-0.12, -0.15, -0.12,  0.05,  0.16,  0.15,  0.16, -0.35],
        [-0.12,  0.34,  0.16,  0.05,  0.16, -0.25,  0.25, -0.05],
        [ 0.07,  0.24,  0.07, -0.05, -0.30,  0.15, -0.12, -0.25],
        [-0.21,  0.34, -0.03, -0.05,  0.16,  0.24,  0.25,  0.24],
        [-0.12,  0.15, -0.12, -0.25, -0.03,  0.15, -0.21, -0.05],
        [ 0.07,  0.34,  0.07,  0.15,  0.34, -0.35, -0.30,  0.24],
        [-0.12, -0.05, -0.30, -0.05, -0.12,  0.24,  0.07,  0.05],
        [-0.30,  0.15, -0.03,  0.34, -0.30,  0.15,  0.25, -0.05],
        [-0.21, -0.25,  0.16, -0.15, -0.03, -0.25, -0.21, -0.15],
        [-0.21,  0.24, -0.12,  0.05, -0.03, -0.15,  0.07, -0.35],
        [-0.21,  0.05, -0.30, -0.15,  0.25,  0.34, -0.12,  0.34],
        [-0.21, -0.05, -0.30,  0.05,  0.16,  0.05, -0.30, -0.35],
        [ 0.07, -0.05,  0.25,  0.15, -0.21,  0.24, -0.30, -0.25],
        [-0.21,  0.24,  0.16,  0.05,  0.25, -0.25, -0.03,  0.15],
        [ 0.34,  0.34,  0.07, -0.35,  0.25,  0.24,  0.16, -0.35],
        [-

In [46]:
# weights as dequantized by our own implementation
W_ours = qdora_linear.dequant()
W_ours

tensor([[-0.12, -0.15, -0.12,  0.05,  0.16,  0.15,  0.16, -0.35],
        [-0.12,  0.34,  0.16,  0.05,  0.16, -0.25,  0.25, -0.05],
        [ 0.07,  0.24,  0.07, -0.05, -0.30,  0.15, -0.12, -0.25],
        [-0.21,  0.34, -0.03, -0.05,  0.16,  0.24,  0.25,  0.24],
        [-0.12,  0.15, -0.12, -0.25, -0.03,  0.15, -0.21, -0.05],
        [ 0.07,  0.34,  0.07,  0.15,  0.34, -0.35, -0.30,  0.24],
        [-0.12, -0.05, -0.30, -0.05, -0.12,  0.24,  0.07,  0.05],
        [-0.30,  0.15, -0.03,  0.34, -0.30,  0.15,  0.25, -0.05],
        [-0.21, -0.25,  0.16, -0.15, -0.03, -0.25, -0.21, -0.15],
        [-0.21,  0.24, -0.12,  0.05, -0.03, -0.15,  0.07, -0.35],
        [-0.21,  0.05, -0.30, -0.15,  0.25,  0.34, -0.12,  0.34],
        [-0.21, -0.05, -0.30,  0.05,  0.16,  0.05, -0.30, -0.35],
        [ 0.07, -0.05,  0.25,  0.15, -0.21,  0.24, -0.30, -0.25],
        [-0.21,  0.24,  0.16,  0.05,  0.25, -0.25, -0.03,  0.15],
        [ 0.34,  0.34,  0.07, -0.35,  0.25,  0.24,  0.16, -0.35],
        [-

In [47]:
# pairwise max abs differences between exact, hqq-dequantized and our dequantized weights
diff_exact_hqq, diff_exact_ours, diff_hqq_ours = max_abs_diff(W, W_hqq), max_abs_diff(W, W_ours), max_abs_diff(W_hqq, W_ours)
print(f'Max abs diff between: W     and W_hqq : {diff_exact_hqq:.2f}')
print(f'Max abs diff between: W     and W_ours: {diff_exact_ours:.2f}')
print(f'Max abs diff between: W_hqq and W_ours: {diff_hqq_ours:.2f}')

Max abs diff between: W     and W_hqq : 0.05
Max abs diff between: W     and W_ours: 0.05
Max abs diff between: W_hqq and W_ours: 0.00


In [49]:
assert_somehow_close(W_ours, W_hqq, max_err=0.03) # our dequant must match hqq's own dequantize within 0.03