**Simplified python reference implementation of hqq-qdora.**

(This notebook contains only the final module, not the step-by-step build-up to it. For the full build-up, see `python_hqq_qdora.ipynb`.)

Simplifications:
- Replaced hqq quantization with simple group-wise min/max quantization, while keeping hqq-style dequantization. This is okay because the forward pass only uses the output of quantization, and hqq quantization and simple group-wise quantization return the same kinds of output (namely: quantized data, zero points, and scales).
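- Initializing lora_a with 0 (instead of a random init). This doesn't change the forward pass at initialization, since lora_b is 0 as well. However, with both low-rank factors at 0 their gradients are also 0 (see the gradients printed at the end), so they would not train.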

**TODO:** Compare against the hqq implementation.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor, cat, int32
from torch import float16 as fp16
from math import ceil

# compact tensor printing for the outputs below: wide lines, 2 decimals, no scientific notation
torch.set_printoptions(linewidth=200, precision=2, sci_mode=False)

In [None]:
def assert_close(a, b):
    """Assert that `a` and `b` agree element-wise up to an absolute tolerance of 1e-2."""
    elementwise_ok = torch.isclose(a, b, atol=1e-2)
    assert elementwise_ok.all()


def assert_somehow_close(a, b):
    """Assert that `a` and `b` agree element-wise up to an absolute tolerance of 0.1.

    The looser tolerance leaves room for quantization error.
    """
    elementwise_ok = torch.isclose(a, b, atol=0.1)
    assert elementwise_ok.all()

In [None]:
class QuantedDoraModule(nn.Module):
    """DoRA adapter on top of a group-wise quantized, bit-packed copy of a linear layer's weight.

    Quantization (done once, in `quant`): the weight is flattened row-major and padded to a
    multiple of `group_size`. Each group of `group_size` values is mapped to integers in
    [0, 2**bits-1] via q = round((w - min) / scale), with scale = (max - min) / (2**bits - 1).
    The integers are packed into int32 words (`pack`). Zero points (group mins), scales and
    packed data are stored as buffers, so they follow `.to(...)` and appear in the state_dict.

    Only the DoRA parts are trainable: the magnitude vector `m` and the low-rank factors
    `a` (in_features -> rank) and `b` (rank -> out_features).

    Args:
        linear: layer to adapt. Its weight, in/out features, dtype and device are used;
            any bias is ignored.
        bits: bits per quantized value, in [1, 31].
        group_size: number of consecutive (flattened) weights sharing one zero point and scale.
        rank: rank of the low-rank update `b.weight @ a.weight`.
        alpha: scalar that multiplies the output together with `m` (see `forward`).
    """
    def __init__(self, linear, bits, group_size, rank, alpha):
        super().__init__()
        # for quanting
        assert 1 <= bits <= 31, f'bits must be in [1, 31], got {bits}'
        assert group_size > 0, f'group_size must be positive, got {group_size}'
        self.bits, self.group_size = bits, group_size
        self.quant(linear)
        # for dora -- adapters use the base weight's dtype/device, so forward never mixes dtypes
        factory = dict(dtype=linear.weight.dtype, device=linear.weight.device)
        self.a = nn.Linear(linear.in_features, rank, bias=False, **factory)
        self.b = nn.Linear(rank, linear.out_features, bias=False, **factory)
        self.alpha = alpha
        # magnitude vector: one entry per output row, initialized from the unquantized weight
        self.m = nn.Parameter(linear.weight.detach().norm(p=2, dim=1))
        # init a & b to 0 -- a should be inited differently, but for sake of simplicity, set it to 0 as well
        nn.init.zeros_(self.a.weight)
        nn.init.zeros_(self.b.weight)

    def quant(self, linear):
        """Quantize `linear.weight` group-wise; store `zero`, `scale` and `pqdata` as buffers."""
        data = linear.weight.detach()
        self.shape = data.shape
        data = data.flatten()

        # Repeat the last element until we have a multiple of group_size elements. The padding
        # lands in the last group and equals one of its members, so that group's min/max don't change.
        n_pad = -data.numel() % self.group_size
        if n_pad:
            data = F.pad(data, (0, n_pad), 'constant', data[-1].item())
        data = data.reshape(-1, self.group_size)

        min_ = data.min(dim=-1, keepdim=True).values
        max_ = data.max(dim=-1, keepdim=True).values
        max_q = 2**self.bits - 1
        scale = (max_ - min_) / max_q
        # A constant group has scale 0. Divide by 1 instead so that all its values quantize to 0;
        # dequant then gives 0 * 0 + min, which is exactly the constant.
        safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)

        # note: out-of-place ops only; clamp guards against rounding just past the integer range
        q = ((data - min_) / safe_scale).round().clamp(0, max_q).to(torch.int64)

        self.register_buffer('zero', min_)
        self.register_buffer('scale', scale)
        # packed quantized data
        self.register_buffer('pqdata', self.pack(q.flatten(), self.bits))

    @staticmethod
    def pack(vals, bits=3):
        """Pack unsigned `bits`-bit integers into int32 words, `31 // bits` values per word.

        For bits=3 that is 10 values per word. The first value of each word occupies its highest
        used bits, and the int32 sign bit is never used. The tail is zero-padded to a whole word.

        Raises:
            AssertionError: if a value is negative, exceeds 2**bits - 1, or is not an integer.
        """
        vals = torch.as_tensor(vals).flatten()
        max_val = 2**bits - 1
        ok = (vals >= 0) & (vals <= max_val)
        if vals.is_floating_point():
            ok &= vals == vals.floor()
        bad = vals[~ok]
        assert bad.numel() == 0, f'Value {bad[0].item()} can\'t be represented by {bits} bits or is not an integer'

        per_pack = 31 // bits
        # pad with 0, to have a multiple of per_pack elements
        n_pad = -vals.numel() % per_pack
        vals = F.pad(vals.to(torch.int64), (0, n_pad), 'constant', 0).reshape(-1, per_pack)
        shifts = bits * torch.arange(per_pack - 1, -1, -1, device=vals.device)
        # the shifted values occupy disjoint bits, so summing them equals OR-ing them
        return (vals << shifts).sum(dim=-1).to(int32)

    def dequant(self):
        """Reconstruct the approximate base weight, with the original shape and float dtype."""
        n_groups = self.scale.shape[0]
        # unpack & remove padding that was added during packing
        q = self.unpack(self.pqdata, self.bits)[: n_groups * self.group_size]
        data = q.reshape(n_groups, self.group_size) * self.scale + self.zero
        # remove padding that was added to complete the last group
        return data.flatten()[: self.shape.numel()].reshape(self.shape)

    @staticmethod
    def unpack(packed, bits=3):
        """Inverse of `pack`: split every int32 word into its `31 // bits` values (padding included)."""
        packed = torch.as_tensor(packed)
        assert not packed.is_floating_point(), f'packed data must be an integer tensor, got {packed.dtype}'
        per_pack = 31 // bits
        # right-shift so the wanted bits are the lowest ones, then keep only those via the mask
        shifts = bits * torch.arange(per_pack - 1, -1, -1, device=packed.device)
        return ((packed.to(torch.int64).unsqueeze(-1) >> shifts) & (2**bits - 1)).flatten()

    def forward(self, x):
        """Return (W_dq x + b(a(x))) / row_norms(W_dq + b.weight @ a.weight) * m * alpha.

        x has shape (..., in_features); the result has shape (..., out_features).
        """
        w = self.dequant()
        # L2 norm per output row of the adapted weight; detached, so no gradient flows through it
        col_norms = (w + self.b.weight @ self.a.weight).norm(p=2, dim=1).detach()
        out = F.linear(x, w) + self.b(self.a(x))
        # NOTE(review): alpha scales the whole output here, not only the low-rank term as in
        # LoRA-style scaling -- confirm this is intended (identical for alpha=1).
        return out / col_norms * (self.m * self.alpha)

In [None]:
# Base layer to quantize and adapt: 4 inputs -> 5 outputs in fp16. No seed is set here, so the
# weights (and all outputs below) vary from run to run.
base_linear = nn.Linear(4,5, bias=False, dtype=fp16) # ignore bias for now
base_linear.weight.data

tensor([[-0.44,  0.25,  0.37, -0.49],
        [ 0.27, -0.07, -0.40,  0.14],
        [ 0.10, -0.43, -0.15, -0.02],
        [ 0.03,  0.16,  0.09, -0.13],
        [ 0.02,  0.31, -0.17, -0.24]], dtype=torch.float16)

In [None]:
# a single, unbatched input vector with in_features (4) entries
x_tst = torch.randn(4, dtype=fp16); x_tst

tensor([ 0.75, -0.51,  1.20,  0.54], dtype=torch.float16)

In [None]:
# reference output of the unquantized base layer
y_tst = base_linear(x_tst); y_tst

tensor([-0.28, -0.16,  0.09, -0.02, -0.48], dtype=torch.float16, grad_fn=<SqueezeBackward4>)

In [None]:
# 3-bit quantization; group_size=5 splits the 4*5=20 base weights into 4 groups; rank-2 adapters
qdora = QuantedDoraModule(base_linear, bits=3, group_size=5, rank=2, alpha=1); qdora

QuantedDoraModule(
  (a): Linear(in_features=4, out_features=2, bias=False)
  (b): Linear(in_features=2, out_features=5, bias=False)
)

In [None]:
# forward pass through the quantized DoRA module (a and b are still 0 here)
y_qdora = qdora(x_tst)

In [None]:
# Compare against the unquantized reference. Quantization error means the results only match
# approximately, so assert_somehow_close allows an absolute error of 0.1 per element.
print(f'quanted result (with packing): {y_qdora}')
print(f'exact   result               : {y_tst}')
assert_somehow_close(y_qdora, y_tst)

quanted result (with packing): tensor([-0.30, -0.18,  0.07, -0.06, -0.49], dtype=torch.float16, grad_fn=<MulBackward0>)
exact   result               : tensor([-0.28, -0.16,  0.09, -0.02, -0.48], dtype=torch.float16, grad_fn=<SqueezeBackward4>)


Let's call `backward()` on (a scalar derived from) the model's output.

In [None]:
# assert only the dora part is trainable
# Only the DoRA parameters (magnitude m and the low-rank factors a, b) may be trainable;
# the quantized base weight must not show up as a parameter.
trainable_names = set(name for name, _ in qdora.named_parameters())
assert trainable_names == {'m', 'a.weight', 'b.weight'}

In [None]:
loss = y_qdora.sum() # arbitrary operation to make y_qdora a scalar, so backward() can be called on it
loss

tensor(-0.97, dtype=torch.float16, grad_fn=<SumBackward0>)

In [None]:
# populates .grad of the trainable parameters (m, a.weight, b.weight)
loss.backward()

In [None]:
# Every trainable parameter's gradient should have the same shape as the parameter itself.
# A parameter without a gradient (e.g. before backward() ran) is reported as None instead of crashing.
print('Gradient shapes:')
for n,p in qdora.named_parameters():
    grad_shape = None if p.grad is None else list(p.grad.shape)
    print(f'Shape of grad of {n:<8} is {str(grad_shape):<7}; shape of  {n:<8} is {list(p.shape)}')

Loss shapes:
Shape of grad of m        is [5]    ; shape of  m        is [5]
Shape of grad of a.weight is [2, 4] ; shape of  a.weight is [2, 4]
Shape of grad of b.weight is [5, 2] ; shape of  b.weight is [5, 2]


In [None]:
# Show the gradient tensor of every trainable parameter after backward().
for param_name, param in qdora.named_parameters():
    grad = param.grad
    print(f'--- {param_name}.grad:\n{grad}')

--- m.grad:
tensor([-0.38, -0.37,  0.15, -0.27, -1.15], dtype=torch.float16)
--- a.weight.grad:
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=torch.float16)
--- b.weight.grad:
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], dtype=torch.float16)
