I've analytically derived the gradient of the DoRA forward wrt `b` (lora up), `a` (lora down) and `mag` (magnitude).

In this notebook, I verify my analytical derivatives against the PyTorch autograd derivatives. **Edit:** They match!

Reminder: The DoRA forward is `y = (w + b@a)@x * mag * alpha / col_norms` for a single input `x ~ (n)` (the code below uses the batched equivalent `x@(w + b@a).t()`), where `@` is matmul, `*`, `/` and `+` are pointwise, and the shapes are:
- `w ~ (m,n)`
- `b ~ (m,r)`
- `a ~ (r,n)`
- `mag ~ (m)`
- `col_norms ~ (m)`
- `alpha ~ ()` (ie a scalar)
- `x ~ (bs, n)` (bs = batch size)
- `y ~ (bs, m)`

Also, as in the dora paper, `col_norms` is treated as constant in the backward (even though it depends on `a` and `b`), which greatly simplifies the backward.
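
As a rough illustration of what "treated as constant" means in PyTorch terms, here is a minimal sketch. It assumes `col_norms` is one norm per output row of the merged weight `w + b@a` (consistent with the shape `(m)` above) and uses `.detach()` to cut it out of the autograd graph. The notebook itself just uses fixed random values for `col_norms`, so this construction is an assumption, not the code used below.

```python
import torch

torch.manual_seed(0)
m, r, n = 5, 3, 4  # out, lora_rank, in (same roles as in this notebook)

w = torch.randn(m, n)                      # frozen base weight
b = torch.randn(m, r, requires_grad=True)  # lora up
a = torch.randn(r, n, requires_grad=True)  # lora down

# Assumption: one norm per output row of the merged weight, giving shape (m,)
merged = w + b @ a
col_norms = merged.norm(dim=1).detach()  # detach() makes it a constant for autograd

# Gradients now flow only through the numerator, not through col_norms
loss = (merged / col_norms[:, None]).sum()
loss.backward()

# With col_norms constant: dloss/db[i, p] = sum_j a[p, j] / col_norms[i]
expected_b_grad = a.detach().sum(dim=1)[None, :] / col_norms[:, None]
print(torch.allclose(b.grad, expected_b_grad, atol=1e-4))
```

Without the `.detach()`, `b.grad` would pick up an extra term from the derivative of the norm, and the simple closed form above would no longer hold.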

Note: Calling `loss.backward()` only yields the **total derivative** of the **final loss**, **evaluated at a point**, eg for `b`:
$$\frac{\partial \text{loss}}{\partial b} \Big|_{x} = \frac{\partial \text{loss}}{\partial y} \Big|_{x} \cdot \frac{\partial y}{\partial b} \Big|_{x}$$
What we are interested in, however, is the function $x \mapsto \frac{\partial y}{\partial b}\Big|_{x}$ (and same for $\frac{\partial y}{\partial a}$ and $\frac{\partial y}{\partial \text{mag}}$).
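
For reference, here is a sketch of those three derivatives in index notation, derived from the forward above with $\beta = 1 / \text{col\_norms}$ held constant. The index names are introduced here for illustration: $k$ runs over the batch, $i, j$ over the output dimension, $p$ over the rank and $q$ over the input dimension, and $\delta_{ij}$ is the Kronecker delta.

$$
\begin{aligned}
z_{ki} &= \sum_{q} (w + ba)_{iq}\, x_{kq}, \qquad y_{ki} = \alpha\, \text{mag}_i\, \beta_i\, z_{ki} \\
\frac{\partial y_{ki}}{\partial \text{mag}_{j}} &= \alpha\, \beta_i\, z_{ki}\, \delta_{ij} \\
\frac{\partial y_{ki}}{\partial a_{pq}} &= \alpha\, \text{mag}_i\, \beta_i\, b_{ip}\, x_{kq} \\
\frac{\partial y_{ki}}{\partial b_{jp}} &= \alpha\, \text{mag}_i\, \beta_i\, \delta_{ij} \sum_{q} a_{pq}\, x_{kq}
\end{aligned}
$$

The resulting tensors have shapes `(bs, m, m)`, `(bs, m, r, n)` and `(bs, m, m, r)` respectively.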

In [None]:
import torch
import torch.nn as nn
from torch import tensor, diag, isclose

torch.set_printoptions(precision=2, sci_mode=False, linewidth=200)

from einops import einsum

In [None]:
do_assertion = True

In [None]:
class MyModel(nn.Module):
    def __init__(self, m,r,n, alpha=10, rand_weights=True):
        super().__init__()
        if rand_weights:
            # for robust testing
            self.a = nn.Parameter(torch.randn(r, n))
            self.b = nn.Parameter(torch.randn(m, r))
            self.w = torch.randn(m, n)
            self.mag = nn.Parameter(torch.randn(m))
            self.col_norms = torch.randn(m)
        else:
            # for debugging
            self.a = nn.Parameter(torch.ones(r, n) * 0.5)
            self.b = nn.Parameter(torch.ones(m, r) * 0.3)
            self.w = torch.ones(m, n) * 0.1    
            self.mag = nn.Parameter(torch.ones(m) * 1.1)
            self.col_norms = torch.ones(m) * 0.2
        self.alpha = alpha

    def forward(self, x):
        x = x@self.w.t() + x@self.a.t()@self.b.t() # x@(w+b@a).t() instead of (w+b@a)@x so it works with batched x
        x /= self.col_norms
        x *= self.mag * self.alpha
        return x

In [None]:
bs,m,r,n = 2,5,3,4 # batch size, out, lora_rank, in

In [None]:
x = torch.ones(bs,n)
x

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]])

In [None]:
model = MyModel(m,r,n, rand_weights=True)
y = model(x)
print(y.shape)
y

torch.Size([2, 5])


tensor([[ 15.84,  -5.73,  16.66, -13.06,  24.94],
        [ 15.84,  -5.73,  16.66, -13.06,  24.94]], grad_fn=<MulBackward0>)

In [None]:
loss = y.mean()
loss

tensor(7.73, grad_fn=<MeanBackward0>)

In [None]:
y.retain_grad() # we need dloss_dy as input for our manual gradient calculation below
loss.backward()
print(f'grad of y (shape {list(y.grad.shape)}):\n{y.grad}')

grad of y (shape [2, 5]):
tensor([[0.10, 0.10, 0.10, 0.10, 0.10],
        [0.10, 0.10, 0.10, 0.10, 0.10]])


---

In [None]:
# for notational convenience
w,b,a = model.w.data,model.b.data,model.a.data
mag = model.mag.data
beta = 1 / model.col_norms.data
alpha = model.alpha

In [None]:
def compare(should,is_):
    print(f'Should:\n{should}')
    print(f'Is    :\n{is_}')
    if do_assertion: assert isclose(should, is_).all()

**Derivative wrt mag**

In [None]:
def dy_dmag():
    z = x@w.t() + x@a.t()@b.t()
    return alpha * z[:,:,None]*diag(beta) 
assert dy_dmag().shape == (bs,m,m)

def dloss_dmag(dloss_dy):
    return einsum(dloss_dy, dy_dmag(), 'bs m, bs m m2 -> m2')

In [None]:
compare(model.mag.grad, dloss_dmag(y.grad))

Should:
tensor([11.41,  0.76, -2.34,  1.86, -6.53])
Is    :
tensor([11.41,  0.76, -2.34,  1.86, -6.53])


Yes!

**Derivative wrt a**

In [None]:
def dy_da():
    return alpha * (mag * beta)[None,:,None,None] * x[:,None,None,:] * b[None,:,:,None]
assert dy_da().shape == (bs,m,r,n)

def dloss_da(dloss_dy):
    return einsum(dloss_dy, dy_da(), 'bs m, bs m r n -> r n')

In [None]:
compare(model.a.grad, dloss_da(y.grad))

Should:
tensor([[7.93, 7.93, 7.93, 7.93],
        [1.63, 1.63, 1.63, 1.63],
        [1.89, 1.89, 1.89, 1.89]])
Is    :
tensor([[7.93, 7.93, 7.93, 7.93],
        [1.63, 1.63, 1.63, 1.63],
        [1.89, 1.89, 1.89, 1.89]])


Yes!

**Derivative wrt b**

In [None]:
def dy_db():
    return alpha * diag(mag*beta)[None,:,:,None] * (x@a.t())[:,None,None,:]
assert dy_db().shape == (bs,m,m,r)

def dloss_db(dloss_dy):
    return einsum(dloss_dy, dy_db(), 'bs m, bs m m2 r -> m2 r')

In [None]:
compare(model.b.grad, dloss_db(y.grad))

Should:
tensor([[-0.01, -0.04,  0.35],
        [-0.06, -0.18,  1.63],
        [ 0.04,  0.12, -1.08],
        [ 0.06,  0.18, -1.64],
        [ 0.06,  0.20, -1.76]])
Is    :
tensor([[-0.01, -0.04,  0.35],
        [-0.06, -0.18,  1.63],
        [ 0.04,  0.12, -1.08],
        [ 0.06,  0.18, -1.64],
        [ 0.06,  0.20, -1.76]])


Yes!
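
The comparisons above all go through the loss gradient. As a complementary check, the full Jacobians can also be compared directly. The sketch below is self-contained and does not reuse the notebook's model: it rebuilds the forward as a plain function (here called `forward`, a name chosen for illustration), uses a positive random stand-in for `1 / col_norms`, and gets the Jacobians from `torch.autograd.functional.jacobian`.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
bs, m, r, n = 2, 5, 3, 4  # batch size, out, lora_rank, in
alpha = 10.0

x = torch.randn(bs, n)
w, b, a = torch.randn(m, n), torch.randn(m, r), torch.randn(r, n)
mag = torch.randn(m)
beta = 1 / (torch.rand(m) + 0.5)  # stand-in for 1 / col_norms, kept away from zero

def forward(b_, a_, mag_):
    # same forward as the notebook's model, with beta treated as a constant
    return (x @ w.t() + x @ a_.t() @ b_.t()) * beta * mag_ * alpha

# One Jacobian per input, each of shape output.shape + input.shape
jac_b, jac_a, jac_mag = jacobian(forward, (b, a, mag))

# Analytical Jacobians, same expressions as in the cells above
z = x @ w.t() + x @ a.t() @ b.t()
dy_dmag = alpha * z[:, :, None] * torch.diag(beta)
dy_da = alpha * (mag * beta)[None, :, None, None] * x[:, None, None, :] * b[None, :, :, None]
dy_db = alpha * torch.diag(mag * beta)[None, :, :, None] * (x @ a.t())[:, None, None, :]

for name, auto, manual in [('mag', jac_mag, dy_dmag), ('a', jac_a, dy_da), ('b', jac_b, dy_db)]:
    print(name, list(auto.shape), torch.allclose(auto, manual, rtol=1e-4, atol=1e-4))
```

Because this compares every entry of each Jacobian instead of a single contraction with `dloss_dy`, it would also catch errors that happen to cancel out for one particular loss.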