**Simplified python reference implementation of hqq-qdora.**

This nb only contains the final module, not the build up to it. For the full build up, see `python_hqq_qdora_v2.ipynb` and `python_hqq_qdora.ipynb`.

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor, cat, int32, float16 as fp16
from math import ceil

from fastcore.basics import store_attr

from hqq.core.quantize import Quantizer # optional; only for optimizing during quanting

torch.set_printoptions(linewidth=200, precision=2, sci_mode=False)

In [10]:
def max_abs_diff(a,b): return (a-b).abs().max()
def assert_close(a,b): assert torch.isclose(a,b,atol=1e-2).all(), f'assert_close failed, max error = {max_abs_diff(a,b)}'
def assert_somehow_close(a,b,max_err=0.12): assert torch.isclose(a,b,atol=max_err).all(), f'assert_somehow_close failed, max error = {max_abs_diff(a,b)}' # allow some error due to quanting

In [24]:
class QuantedDoraModule(nn.Module):
    def __init__(self, linear, bits, group_size, rank, alpha, compute_dtype=fp16, packed=True, optimized=True, group_size_zero=None, group_size_scale=None):
        super().__init__()
        # for quanting
        store_attr('bits,group_size,packed,optimized,compute_dtype',self)
        self.group_size_zero, self.group_size_scale = group_size_zero or 128, group_size_scale or 128 # hqq uses group size of 128 for zero & scale
        self.quant(linear.weight.data)
        # for dora
        self.a = nn.Linear(linear.in_features, rank, bias=False, dtype=fp16)
        self.b = nn.Linear(rank, linear.out_features, bias=False, dtype=fp16)
        self.alpha = alpha
        self.m = nn.Parameter(linear.weight.norm(p=2, dim=1))
        # init a & b to 0 -- a should be inited differently, but for sake of simplicity, set it to 0 as well
        self.a.weight.data.zero_()
        self.b.weight.data.zero_()

    @staticmethod
    def pack(vals):
        assert len(vals.shape)==2, 'Pass a 2d tensor'
        for v in vals.flatten(): assert 0<=v.item()<=7 and v.item()//1==v.item(), f'Value {v} can\'t be represented by 3 bits or is not an integer'    
        rows, cols = vals.shape
        n_packs = ceil(rows/10)
        padded_vals = torch.zeros(n_packs*10, cols, dtype=int32)
        padded_vals[:rows, :cols] = vals
        packed = torch.zeros(n_packs, cols, dtype=int32)
        for k in range(10): packed = (packed << 3) | padded_vals[k*n_packs:(k+1)*n_packs,:] # shift right 3 bits, then set last 3 bits to padded_vals[...,...]
        return packed

    @staticmethod
    def unpack(packed, rows):
        def bin_to_dec(b3,b2,b1): return 4*b3 + 2*b2 + b1
        assert len(packed.shape)==2 and packed.dtype==int32, 'Pass a 2d tensor of int32s'
        n_packs, cols = packed.shape
        padded_vals = torch.zeros(n_packs*10, cols, dtype=int32)
        for k_up, k_down in zip(range(10), reversed(range(10))): # top-most 3bits vals (k_up=0) are most right-shifted (k_down=9)
            padded_vals[k_down*n_packs:(k_down+1)*n_packs,:] = ((packed >> (3*k_up)) & 0b111) # righ-shift 3*k_up times, so last 3 bits are those we want; then only select those via 0b111            
        return padded_vals[:rows,:]
    
    @staticmethod
    def _quant(data, group_size, bits=3, packed=True, optimize=True):
        assert data.numel()%group_size==0, f'group_size {group_size} can\'t evenly split the data (numel = {data.numel()})'
        data = data.float().reshape(group_size,-1)
        
        min_, max_ = data.min(axis=0, keepdim=True).values, data.max(axis=0, keepdim=True).values
    
        scale = (2**bits-1) / (max_-min_) # note: hqq clamp to 2e4 to avoid half-precision problems, let's ignore that for now
        zero = -min_ * scale
    
        if optimize: data, scale, zero = Quantizer.optimize_weights(data, scale, zero, min_max=[0, 2**bits-1])
        else: data = (data * scale + zero).round()

        if packed: data = QuantedDoraModule.pack(data)
        return data, zero, 1/scale # invert scale, so in dequanting we multiply instead of divide 

    @staticmethod
    def _dequant(data, zero, scale, shape, group_size, packed=True):
        if packed: data = QuantedDoraModule.unpack(data, rows=group_size)
        data = (data-zero)*scale
        return data.reshape(shape)

    def quant(self, data):
        qdata,  zero       , scale        = self._quant(data,  self.group_size,       self.bits, self.packed, self.optimized)
        qzero,  zeros_zero , zeros_scale  = self._quant(zero,  self.group_size_zero,  self.bits, self.packed, False)
        qscale, scales_zero, scales_scale = self._quant(scale, self.group_size_scale, self.bits, self.packed, False)
        store_attr('qdata, qzero, qscale, zeros_zero, zeros_scale, scales_zero, scales_scale', self)
        self.data_shape,self.zero_shape,self.scale_shape = data.shape, zero.shape, scale.shape

    def dequant(self):
        zero  = self._dequant(self.qzero,  self.zeros_zero,  self.zeros_scale,  self.zero_shape,  self.group_size_zero,  self.packed)
        scale = self._dequant(self.qscale, self.scales_zero, self.scales_scale, self.scale_shape, self.group_size_scale, self.packed)
        return  self._dequant(self.qdata,  zero,             scale,             self.data_shape,  self.group_size,       self.packed).to(self.compute_dtype)
    
    def forward(self, x):
        x = self.dequant()@x + self.b(self.a(x))
        col_norms =  (self.dequant() + self.b.weight @ self.a.weight).norm(p=2, dim=1).detach()
        x /= col_norms
        x *= self.m * self.alpha
        return x

In [25]:
base_linear = nn.Linear(4,5, bias=False, dtype=fp16) # ignore bias for now
base_linear.weight.data

tensor([[-0.03, -0.16, -0.25,  0.38],
        [ 0.13,  0.37, -0.12, -0.47],
        [-0.07, -0.41,  0.50,  0.15],
        [ 0.19, -0.39, -0.34, -0.15],
        [-0.49,  0.48,  0.10, -0.05]], dtype=torch.float16)

In [26]:
x_tst = torch.randn(4, dtype=fp16); x_tst

tensor([1.06, 0.13, 1.72, 1.35], dtype=torch.float16)

In [27]:
y_tst = base_linear(x_tst); y_tst

tensor([ 0.04, -0.65,  0.93, -0.64, -0.35], dtype=torch.float16, grad_fn=<SqueezeBackward4>)

In [28]:
qdora = QuantedDoraModule(base_linear, bits=3, group_size=5, group_size_zero=4, group_size_scale=4, rank=2, alpha=1, compute_dtype=fp16); qdora

QuantedDoraModule(
  (a): Linear(in_features=4, out_features=2, bias=False)
  (b): Linear(in_features=2, out_features=5, bias=False)
)

In [29]:
y_qdora = qdora(x_tst)

In [30]:
print(f'quanted result (with packing): {y_qdora}')
print(f'exact   result               : {y_tst}')
assert_somehow_close(y_qdora, y_tst)

quanted result (with packing): tensor([ 0.11, -0.70,  0.88, -0.57, -0.37], dtype=torch.float16, grad_fn=<MulBackward0>)
exact   result               : tensor([ 0.04, -0.65,  0.93, -0.64, -0.35], dtype=torch.float16, grad_fn=<SqueezeBackward4>)


Let's call backwards on the model

In [31]:
assert {n for n,p in qdora.named_parameters()} == {'m','a.weight','b.weight'} # assert only the dora part is trainable

In [32]:
loss = y_qdora.sum() # abitrary operation to make y_qdora a scalar
loss

tensor(-0.65, dtype=torch.float16, grad_fn=<SumBackward0>)

In [33]:
loss.backward()

In [34]:
print('Loss shapes:')
for n,p in qdora.named_parameters():
    print(f'Shape of grad of {n:<8} is {str(list(p.grad.shape)):<7}; shape of  {n:<8} is {list(p.shape)}')

Loss shapes:
Shape of grad of m        is [5]    ; shape of  m        is [5]
Shape of grad of a.weight is [2, 4] ; shape of  a.weight is [2, 4]
Shape of grad of b.weight is [5, 2] ; shape of  b.weight is [5, 2]


In [35]:
for n,p in qdora.named_parameters():
    print(f'--- {n}.grad:\n{p.grad}')

--- m.grad:
tensor([ 0.22, -1.12,  1.32, -1.00, -0.54], dtype=torch.float16)
--- a.weight.grad:
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=torch.float16)
--- b.weight.grad:
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], dtype=torch.float16)
