In [1]:
%load_ext autoreload
%autoreload 2
%cd ..

/Users/akkirr/Desktop/IT/annotated-diffusion


In [2]:
import pytest
import torch
from torch import nn
from torch.optim import Adam
from copy import deepcopy

from lora import *
from lora.lora import LORA_MODULES
from lora.utils import isinstance_by_class, unfreeze_module

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.QKV = nn.Linear(1, 1)
        self.C = nn.Linear(1, 1)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        return self.C(self.lrelu(self.QKV(x)))


class TimeEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_proj = nn.Linear(1, 1)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        return self.lrelu(self.time_proj(x))

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.just_linear = nn.Linear(1, 1)
        self.attn = Attention()
        self.time_embedder = TimeEmbedding()

    def forward(self, x):
        return self.attn(self.just_linear(x) + self.time_embedder(x))

def models_eq(sd1, sd2, keys):
    return all(torch.allclose(sd1[key], sd2[key]) for key in keys)

def models_neq(sd1, sd2, keys):
    return all(not torch.allclose(sd1[key], sd2[key]) for key in keys)

def find_lora_modules(model: nn.Module, lora_modules=LORA_MODULES):
    for name, module in model.named_modules():
        if isinstance_by_class(module, lora_modules):
            yield name

In [24]:
torch.manual_seed(0)

model = A()

inject_lora(
    model,
    2, 0,
    ["Attention"],
    [nn.Linear],
    [LoraInjectedLinear],
    verbose=True
)
optim = Adam(model.parameters())
freeze_lora(model)
sd1 = deepcopy(model.state_dict())

x = torch.tensor([[1]], dtype=torch.float32)

optim.zero_grad()
loss1 = model(x).mean()
loss1.backward()
optim.step()

optim.zero_grad()
loss2 = model(x).mean()
loss2.backward()
optim.step()

optim.zero_grad()
loss2 = model(x).mean()
loss2.backward()
optim.step()

loss3 = model(x).mean()
assert loss3 != loss1

sd2 = deepcopy(model.state_dict())

all_keys = set(sd2.keys())
lora_keys = {k for k in all_keys if 'lora_up' in k or 'lora_down' in k}
print()
print(all_keys)
print()
print(lora_keys)
print()

for a, b in sd1.items():
    print(a)
    print(b)
print()
for a, b in sd2.items():
    print(a)
    print(b)

assert models_eq(sd1, sd2, all_keys - lora_keys)
assert models_neq(sd1, sd2, lora_keys)


Injected lora (1x2x1) in attn.QKV
Injected lora (1x2x1) in attn.C
A False False
Linear True False
Attention False False
LoraInjectedLinear False True
*
LoraInjectedLinear False True
*
LeakyReLU True False
TimeEmbedding False False
Linear True False
LeakyReLU True False

{'attn.C.src_linear.bias', 'attn.QKV.src_linear.weight', 'just_linear.bias', 'attn.C.src_linear.weight', 'attn.C.lora_up.weight', 'time_embedder.time_proj.weight', 'attn.QKV.lora_up.weight', 'attn.QKV.lora_down.weight', 'just_linear.weight', 'attn.C.lora_down.weight', 'attn.QKV.src_linear.bias', 'time_embedder.time_proj.bias'}

{'attn.QKV.lora_up.weight', 'attn.QKV.lora_down.weight', 'attn.C.lora_down.weight', 'attn.C.lora_up.weight'}

just_linear.weight
tensor([[-0.0075]])
just_linear.bias
tensor([0.5364])
attn.QKV.src_linear.weight
tensor([[-0.8230]])
attn.QKV.src_linear.bias
tensor([-0.7359])
attn.QKV.lora_down.weight
tensor([[0.2017],
        [0.4190]])
attn.QKV.lora_up.weight
tensor([[0., 0.]])
attn.C.src_linear.we

AssertionError: 

In [4]:
for name, p in model.named_parameters():
    if p.requires_grad:
        print(name)

attn.QKV.lora_down.weight
attn.QKV.lora_up.weight
attn.C.lora_down.weight
attn.C.lora_up.weight


In [22]:
loss = model(x).mean()
loss.backward()
optim.step()
optim.zero_grad()
print(loss)

tensor(0.2751, grad_fn=<MeanBackward0>)


In [23]:
model.attn.QKV.lora_up.weight.grad

tensor([[-0.0111, -0.0231]])