**Simplified python reference implementation of hqq-qdora.**

This notebook contains only the final module, not the build-up to it. For the full build-up, see `python_hqq_qdora_v2.ipynb` and `python_hqq_qdora.ipynb`.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor, cat, int32, float16 as fp16
from math import ceil

from fastcore.basics import store_attr

from hqq.core.quantize import Quantizer, HQQLinear, BaseQuantizeConfig # Quantizer - optional, only for optimizing during quanting ; HQQLinear & BaseQuantizeConfig to verify our implementation

torch.set_printoptions(linewidth=200, precision=2, sci_mode=False)

In [2]:
def max_abs_diff(a, b):
    """Return the largest element-wise absolute difference between `a` and `b` (as a 0-d tensor)."""
    delta = a - b
    return delta.abs().max()

def assert_close(a, b):
    """Assert that every element of `a` is within an absolute tolerance of 1e-2 of `b`."""
    within_tol = torch.isclose(a, b, atol=1e-2).all()
    assert within_tol, f'assert_close failed, max error = {max_abs_diff(a,b)}'

def assert_somehow_close(a, b, max_err=0.12):
    """Like `assert_close`, but with a caller-chosen (looser) absolute tolerance, to allow for quantization error."""
    within_tol = torch.isclose(a, b, atol=max_err).all()
    assert within_tol, f'assert_somehow_close failed, max error = {max_abs_diff(a,b)}'

In [3]:
# Sizes used below: n = in_features and m = out_features of the base linear layer, r = rank of the dora adapter
m,r,n = 128,32,128

In [4]:
class QuantedDoraModule(nn.Module):
    """Linear layer (no bias) whose base weight is stored quantized, with a trainable DoRA adapter on top.

    The base weight is quantized group-wise to `bits` bits; the resulting per-group zero points and
    scales are themselves quantized once more (using `group_size_zero` / `group_size_scale`).
    Only the adapter (`a`, `b`) and the magnitude vector `m` are registered as parameters.

    Args:
        linear: the `nn.Linear` to wrap; only its weight is used (bias is ignored).
        bits: bit width of the quantized values (also used for the quantized zeros & scales).
        group_size: number of weight values sharing one zero/scale pair.
        rank: rank of the low-rank adapter.
        alpha: scalar factor multiplied into the output together with `m`.
        compute_dtype: dtype of the dequantized weight and of the adapter parameters.
        packed: if True, store quantized values bit-packed into int32 words.
        optimized: if True, refine the weight quantization with hqq's `Quantizer.optimize_weights`.
        group_size_zero, group_size_scale: group sizes for quantizing zeros / scales (default 128).
    """
    def __init__(self, linear, bits, group_size, rank, alpha, compute_dtype=fp16, packed=True, optimized=True, group_size_zero=None, group_size_scale=None):
        super().__init__()
        # for quanting
        store_attr('bits,group_size,packed,optimized,compute_dtype',self)
        self.group_size_zero, self.group_size_scale = group_size_zero or 128, group_size_scale or 128 # hqq uses group size of 128 for zero & scale
        self.quant(linear.weight.data)
        # for dora -- adapter lives in compute_dtype so it matches the dequantized weight (was hard-coded to fp16)
        self.a = nn.Linear(linear.in_features, rank, bias=False, dtype=compute_dtype)
        self.b = nn.Linear(rank, linear.out_features, bias=False, dtype=compute_dtype)
        self.alpha = alpha
        # magnitude vector: one l2-norm per output row of the original (unquantized) weight
        self.m = nn.Parameter(linear.weight.detach().norm(p=2, dim=1).to(compute_dtype))
        # Only b is zeroed: b@a is then 0, so the module starts out equal to the quantized base layer.
        # a keeps nn.Linear's default random init -- if a and b were both 0, both would receive exactly
        # zero gradients (each one's gradient is proportional to the other) and the adapter could never train.
        self.b.weight.data.zero_()

    @staticmethod
    def pack(vals, bits=3):
        """Bit-pack a 2d tensor of integers in [0, 2**bits-1] into int32 words, along the row axis.

        Each int32 holds `31//bits` values (10 for 3 bits; the sign bit is left unused). Rows are
        zero-padded to a multiple of that count. Returns an int32 tensor of shape (ceil(rows/per_word), cols)
        on the same device as `vals`. Raises AssertionError for non-2d input or unrepresentable values.
        """
        assert len(vals.shape)==2, 'Pass a 2d tensor'
        assert 1<=bits<=31, f'bits must be in [1,31] to fit into an int32 word, got {bits}'
        # vectorized validity check; `%1 != 0` also flags nan/inf
        invalid = (vals<0) | (vals>2**bits-1) | (vals%1!=0)
        assert not invalid.any(), f'Value {vals[invalid][0]} can\'t be represented by {bits} bits or is not an integer'
        rows, cols = vals.shape
        per_word = 31//bits
        n_packs = ceil(rows/per_word)
        padded_vals = torch.zeros(n_packs*per_word, cols, dtype=int32, device=vals.device)
        padded_vals[:rows, :cols] = vals
        packed = torch.zeros(n_packs, cols, dtype=int32, device=vals.device)
        for k in range(per_word): packed = (packed << bits) | padded_vals[k*n_packs:(k+1)*n_packs,:] # shift left by `bits`, then set the lowest `bits` bits to the k-th slab of padded_vals
        return packed

    @staticmethod
    def unpack(packed, rows, bits=3):
        """Inverse of `pack`: expand int32 words back into an int32 tensor with `rows` rows (padding dropped)."""
        assert len(packed.shape)==2 and packed.dtype==int32, 'Pass a 2d tensor of int32s'
        assert 1<=bits<=31, f'bits must be in [1,31] to fit into an int32 word, got {bits}'
        per_word, mask = 31//bits, 2**bits-1
        n_packs, cols = packed.shape
        padded_vals = torch.zeros(n_packs*per_word, cols, dtype=int32, device=packed.device)
        for k in range(per_word):
            shift = bits*(per_word-1-k) # the first slab was packed first, so it ended up shifted furthest to the left
            padded_vals[k*n_packs:(k+1)*n_packs,:] = (packed >> shift) & mask # right-shift so the wanted bits are lowest, then select them via the mask
        return padded_vals[:rows,:]

    @staticmethod
    def _quant(data, group_size, bits=3, packed=True, optimize=True):
        """Quantize `data` group-wise to `bits` bits.

        `data` is viewed as (group_size, n_groups); each column is one group with its own zero & scale.
        Returns (quantized data, zero, inverse scale); dequantize via `(q - zero) * inverse_scale`.
        """
        assert data.numel()%group_size==0, f'group_size {group_size} can\'t evenly split the data (numel = {data.numel()})'
        data = data.float().reshape(group_size,-1)

        min_, max_ = data.min(axis=0, keepdim=True).values, data.max(axis=0, keepdim=True).values

        max_q = 2**bits-1
        # clamp like hqq does: keeps the scale finite for a constant group (max_==min_ would otherwise give inf -> nan)
        scale = (max_q / (max_-min_)).clamp(max=2e4)
        zero = -min_ * scale

        if optimize: data, scale, zero = Quantizer.optimize_weights(data, scale, zero, min_max=[0, max_q])
        else: data = (data * scale + zero).round().clamp(0, max_q) # clamp guards against float round-off leaving the valid range

        if packed: data = QuantedDoraModule.pack(data, bits)
        return data, zero, 1/scale # invert scale, so in dequanting we multiply instead of divide

    @staticmethod
    def _dequant(data, zero, scale, shape, group_size, packed=True, bits=3):
        """Invert `_quant`: unpack if needed, apply `(q - zero) * scale` (scale is the inverse scale), reshape to `shape`."""
        if packed: data = QuantedDoraModule.unpack(data, rows=group_size, bits=bits)
        data = (data-zero)*scale
        return data.reshape(shape)

    def quant(self, data):
        """Quantize `data`, then quantize the resulting zeros & scales; store everything needed by `dequant` on self."""
        qdata,  zero       , scale        = self._quant(data,  self.group_size,       self.bits, self.packed, self.optimized)
        qzero,  zeros_zero , zeros_scale  = self._quant(zero,  self.group_size_zero,  self.bits, self.packed, False)
        qscale, scales_zero, scales_scale = self._quant(scale, self.group_size_scale, self.bits, self.packed, False)
        store_attr('qdata, qzero, qscale, zeros_zero, zeros_scale, scales_zero, scales_scale', self)
        self.data_shape,self.zero_shape,self.scale_shape = data.shape, zero.shape, scale.shape

    def dequant(self):
        """Reconstruct the weight: first dequantize zeros & scales, then use them to dequantize the data. Returns `compute_dtype`."""
        zero  = self._dequant(self.qzero,  self.zeros_zero,  self.zeros_scale,  self.zero_shape,  self.group_size_zero,  self.packed, self.bits)
        scale = self._dequant(self.qscale, self.scales_zero, self.scales_scale, self.scale_shape, self.group_size_scale, self.packed, self.bits)
        return  self._dequant(self.qdata,  zero,             scale,             self.data_shape,  self.group_size,       self.packed, self.bits).to(self.compute_dtype)

    def forward(self, x):
        """Apply the layer to `x` of shape (..., in_features); returns (..., out_features).

        Output is `(W x + b(a(x))) / ||W + b@a||_row * m * alpha`, with W the dequantized weight.
        """
        w = self.dequant() # dequantize once and reuse for both the matmul and the norm
        row_norms = (w + self.b.weight @ self.a.weight).norm(p=2, dim=1).detach() # one norm per output feature; detached, so no gradient flows through the norm
        y = F.linear(x, w) + self.b(self.a(x)) # F.linear also handles batched x (plain `w@x` only worked for 1d x)
        # NOTE(review): alpha rescales the whole output here, not only the low-rank update -- confirm this is intended
        return y / row_norms * (self.m * self.alpha)

In [5]:
torch.manual_seed(42) # seed the RNG so the random weights (and the random test input created later) are reproducible on Restart & Run All
base_linear = nn.Linear(n,m, bias=False, dtype=fp16) # ignore bias for now
base_linear.weight.data

tensor([[-0.02,  0.05,  0.08,  ...,  0.01, -0.02, -0.08],
        [-0.06,  0.06, -0.07,  ...,  0.02, -0.09, -0.07],
        [ 0.04,  0.06, -0.04,  ..., -0.08,  0.04,  0.05],
        ...,
        [ 0.05,  0.01, -0.07,  ..., -0.00, -0.07,  0.07],
        [-0.06,  0.03,  0.01,  ..., -0.06,  0.05, -0.06],
        [-0.01, -0.05,  0.07,  ...,  0.06,  0.02,  0.04]], dtype=torch.float16)

In [6]:
# Random fp16 test input with n elements (a single, unbatched sample for base_linear)
x_tst = torch.randn(n, dtype=fp16); x_tst

tensor([-0.46, -1.52, -0.69,  0.25,  0.24, -1.34,  0.62,  0.97,  0.05, -0.63, -0.70, -0.70, -0.32,  1.02, -0.63, -0.01, -0.14, -1.09,  0.02, -0.64,  0.31, -0.71, -1.03,  0.68, -0.18, -1.04,  0.63,
        -1.56,  0.56,  0.48, -0.47, -0.48,  0.91, -0.26, -0.91, -0.36,  1.37,  0.61,  1.27, -1.89,  0.71,  1.18, -0.14,  1.89, -0.59, -2.38, -0.81, -0.79,  0.33, -0.18, -0.53, -1.63,  0.87, -0.52,
        -0.28,  0.15, -0.45,  0.22, -0.10,  0.34,  0.76,  0.36,  0.14, -0.82,  0.02,  0.11,  0.12, -0.99,  0.96, -1.13,  0.05, -0.32,  0.50, -1.92, -0.94,  0.40,  0.72,  0.17, -0.12, -0.57, -1.00,
         0.32, -0.55, -1.63,  0.55, -1.47, -0.77, -1.09, -2.43, -1.39, -1.45, -0.69, -0.41,  0.40,  0.52, -0.92,  0.58, -0.52, -1.04, -0.51,  0.90, -1.67, -0.12,  1.66,  1.82,  0.20,  0.26, -2.13,
         2.00,  0.82, -1.04,  0.66, -0.75, -0.06,  0.74,  1.77,  0.27, -0.43, -0.70,  0.47,  0.45,  0.21,  2.23, -0.52, -1.25,  1.01,  1.11, -0.61], dtype=torch.float16)

In [7]:
# Output of the unquantized layer -- used below as the exact result to compare the quantized module against
y_tst = base_linear(x_tst); y_tst

tensor([    -0.56,      0.68,      0.33,      0.17,      0.04,      0.01,     -0.11,      0.46,     -0.24,      0.21,      0.29,      0.24,     -0.64,      0.39,     -0.05,      0.75,     -0.37,
            -0.44,      0.41,     -0.43,      0.13,     -0.95,     -0.69,     -0.65,      0.47,      0.03,     -0.16,     -0.22,      0.26,     -0.35,      0.31,     -0.51,      1.36,      0.15,
             0.67,     -0.08,      0.97,      0.41,     -0.54,      0.89,     -1.37,      0.53,      0.98,      0.05,     -0.36,      0.41,      0.26,     -0.95,     -0.15,      0.30,     -0.35,
            -0.50,     -1.25,     -0.20,      0.30,      0.22,      0.04,     -0.67,     -0.16,      0.06,      0.88,      0.04,      1.04,      0.81,      0.42,      0.04,      1.38,      0.38,
            -0.41,     -0.46,      0.46,     -0.89,      0.39,     -0.51,     -0.38,      0.43,      0.27,      0.07,      0.30,      0.69,      0.16,      0.63,     -0.74,     -0.08,      0.52,
            -0.58,      0

In [9]:
# Wrap base_linear: 3-bit quantization with groups of 64 values, plus a dora adapter of rank r (packed & optimized stay at their defaults, True)
qdora_linear = QuantedDoraModule(base_linear, bits=3, group_size=64, rank=r, alpha=1, compute_dtype=fp16); qdora_linear

QuantedDoraModule(
  (a): Linear(in_features=128, out_features=32, bias=False)
  (b): Linear(in_features=32, out_features=128, bias=False)
)

In [10]:
# Forward pass through the quantized dora module, on the same input as y_tst
y_qdora = qdora_linear(x_tst)

In [11]:
# Show both outputs as 8x16 grids (easier to read & compare than flat vectors)
for title, y in (('quanted result (with packing):', y_qdora), ('exact result :', y_tst)):
    print(title)
    print(y.reshape(8,16).data)
    print()
# quantization is lossy, so only require agreement within a loose absolute tolerance
assert_somehow_close(y_qdora, y_tst, max_err=0.3)
print(f'Max error is {max_abs_diff(y_qdora, y_tst):.2f} ✓')

quanted result (with packing):
tensor([[-0.66,  0.69,  0.20,  0.15,  0.06, -0.10, -0.11,  0.41, -0.22,  0.15,  0.31,  0.16, -0.57,  0.42,  0.10,  0.70],
        [-0.38, -0.44,  0.55, -0.36,  0.18, -1.02, -0.74, -0.61,  0.52, -0.02, -0.14, -0.33,  0.27, -0.40,  0.42, -0.53],
        [ 1.30,  0.23,  0.66, -0.04,  0.98,  0.58, -0.47,  0.92, -1.37,  0.55,  1.00,  0.07, -0.46,  0.40,  0.17, -0.93],
        [-0.18,  0.24, -0.46, -0.53, -1.26, -0.21,  0.33,  0.29,  0.06, -0.57, -0.19,  0.13,  0.80,  0.03,  1.04,  0.79],
        [ 0.39,  0.15,  1.37,  0.37, -0.43, -0.48,  0.51, -0.97,  0.43, -0.60, -0.40,  0.42,  0.27,  0.18,  0.35,  0.77],
        [ 0.15,  0.55, -0.75, -0.07,  0.59, -0.57,  0.13,  0.20,  0.03, -0.63, -0.37, -0.34, -0.36,  0.46,  0.01, -0.13],
        [-0.07, -0.14, -0.19,  1.13,  0.44,  0.23,  1.15,  0.16,  0.31,  0.02,  0.56,  0.10, -0.32, -0.65, -0.22, -0.77],
        [-0.52,  0.78, -0.29, -0.25,  0.27,  1.00, -0.06,  0.34, -0.94, -0.79, -0.09, -0.40, -0.61,  0.27, -0.14, -

Let's call `backward()` on a scalar loss computed from the model output, to make sure the backward pass runs.

In [13]:
# Only the dora tensors (m, a, b) may be registered as parameters, i.e. only the dora part is trainable
param_names = {name for name, _ in qdora_linear.named_parameters()}
assert param_names == {'m','a.weight','b.weight'}

In [14]:
# Reduce the output to a single value to call backward() on; this is not a meaningful training loss
loss = y_qdora.sum() # arbitrary operation to make y_qdora a scalar
loss

tensor(3.90, dtype=torch.float16, grad_fn=<SumBackward0>)

In [15]:
# Backpropagate; the next cells read the resulting .grad of every parameter of qdora_linear
loss.backward()

In [18]:
# For each parameter, print the shape of its gradient next to the shape of the parameter itself (the two should match)
print('Loss shapes:')
for name,p in qdora_linear.named_parameters(): print(f'Shape of grad of {name:<8} is {str(list(p.grad.shape)):<7}; shape of  {name:<8} is {list(p.shape)}')

Loss shapes:
Shape of grad of m        is [128]  ; shape of  m        is [128]
Shape of grad of a.weight is [32, 128]; shape of  a.weight is [32, 128]
Shape of grad of b.weight is [128, 32]; shape of  b.weight is [128, 32]


In [19]:
# Print the gradient values of every parameter
for name,p in qdora_linear.named_parameters(): print(f'--- {name}.grad:\n{p.grad}')

--- m.grad:
tensor([-1.13,  1.16,  0.37,  0.26,  0.11, -0.16, -0.21,  0.69, -0.37,  0.24,  0.56,  0.28, -0.95,  0.71,  0.17,  1.25, -0.67, -0.77,  0.95, -0.64,  0.30, -1.73, -1.41, -1.03,  0.91, -0.03, -0.24,
        -0.62,  0.46, -0.72,  0.74, -0.92,  2.36,  0.43,  1.12, -0.07,  1.73,  0.93, -0.82,  1.74, -2.20,  0.93,  1.73,  0.11, -0.81,  0.69,  0.29, -1.51, -0.30,  0.43, -0.78, -0.91, -2.11, -0.39,
         0.57,  0.47,  0.12, -1.00, -0.32,  0.24,  1.30,  0.05,  1.80,  1.41,  0.68,  0.25,  2.29,  0.65, -0.72, -0.82,  0.89, -1.60,  0.77, -1.08, -0.71,  0.69,  0.50,  0.30,  0.61,  1.28,  0.27,
         0.94, -1.28, -0.11,  0.96, -1.07,  0.22,  0.34,  0.04, -0.97, -0.65, -0.56, -0.60,  0.77,  0.02, -0.23, -0.12, -0.24, -0.34,  1.90,  0.76,  0.44,  2.03,  0.28,  0.53,  0.03,  0.94,  0.16,
        -0.54, -1.11, -0.39, -1.34, -0.92,  1.29, -0.49, -0.43,  0.45,  1.85, -0.10,  0.60, -1.64, -1.30, -0.15, -0.72, -1.12,  0.49, -0.25, -0.29], dtype=torch.float16)
--- a.weight.grad:
tensor([[0.

**Let's verify our implementation against hqq:**

In [30]:
# Original (unquantized) weight -- the reference both dequantized weights are compared against
W = base_linear.weight.data; W

tensor([[-0.02,  0.05,  0.08,  ...,  0.01, -0.02, -0.08],
        [-0.06,  0.06, -0.07,  ...,  0.02, -0.09, -0.07],
        [ 0.04,  0.06, -0.04,  ..., -0.08,  0.04,  0.05],
        ...,
        [ 0.05,  0.01, -0.07,  ..., -0.00, -0.07,  0.07],
        [-0.06,  0.03,  0.01,  ..., -0.06,  0.05, -0.06],
        [-0.01, -0.05,  0.07,  ...,  0.06,  0.02,  0.04]], dtype=torch.float16)

In [27]:
# Quantize the same layer with hqq, using the same nbits & group_size as qdora_linear above.
# Note: device='cuda' means this cell needs a CUDA device.
hqq_linear = HQQLinear(
    base_linear,
    quant_config=BaseQuantizeConfig(nbits=3, group_size=64), #quantization configuration
    compute_dtype=torch.float16,
    device='cuda',
    initialize=True, #Use False to quantize later
)

In [45]:
# Weight as reconstructed by hqq, moved to the cpu so it can be compared with W and W_ours
W_hqq = hqq_linear.dequantize().cpu(); W_hqq

tensor([[-0.02,  0.06,  0.08,  ...,  0.01, -0.01, -0.09],
        [-0.06,  0.06, -0.06,  ...,  0.02, -0.09, -0.06],
        [ 0.03,  0.06, -0.03,  ..., -0.09,  0.04,  0.06],
        ...,
        [ 0.06,  0.01, -0.06,  ..., -0.01, -0.06,  0.07],
        [-0.06,  0.03,  0.01,  ..., -0.06,  0.04, -0.07],
        [-0.01, -0.06,  0.06,  ...,  0.07,  0.01,  0.04]], dtype=torch.float16)

In [46]:
# Weight as reconstructed by our implementation
W_ours = qdora_linear.dequant(); W_ours

tensor([[-0.02,  0.05,  0.08,  ...,  0.01, -0.01, -0.09],
        [-0.06,  0.06, -0.06,  ...,  0.01, -0.09, -0.05],
        [ 0.03,  0.05, -0.04,  ..., -0.08,  0.04,  0.06],
        ...,
        [ 0.06,  0.01, -0.06,  ..., -0.01, -0.06,  0.07],
        [-0.06,  0.03,  0.01,  ..., -0.06,  0.04, -0.06],
        [-0.01, -0.06,  0.06,  ...,  0.06,  0.01,  0.04]], dtype=torch.float16)

In [51]:
# Pairwise max abs differences: original vs hqq, original vs ours, hqq vs ours
print(f'Max abs diff between: W     and W_hqq : {max_abs_diff(W, W_hqq):.2f}')
print(f'Max abs diff between: W     and W_ours: {max_abs_diff(W, W_ours):.2f}')
print(f'Max abs diff between: W_hqq and W_ours: {max_abs_diff(W_hqq, W_ours):.2f}')

Max abs diff between: W     and W_hqq : 0.01
Max abs diff between: W     and W_ours: 0.01
Max abs diff between: W_hqq and W_ours: 0.03


In [52]:
# Our dequantized weight should agree with hqq's up to a small quantization-induced error.
# (W_est / W_est_hqq are not defined anywhere in this notebook; the tensors computed above are W_ours / W_hqq.)
assert_somehow_close(W_ours, W_hqq, max_err=0.03)