In [1]:
import torch
from minlora import add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params, merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora, get_lora_state_dict
_ = torch.set_grad_enabled(False)


In [2]:
# a simple model
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=5, out_features=7),
    torch.nn.Linear(in_features=7, out_features=3),
)

x = torch.randn(1, 5)
y = model(x)
print(y)
Y0 = y


tensor([[ 0.2740,  0.1809, -0.1765]])


In [3]:
# # print(model.named_modules())
# from labml.logger import inspect 
# # inspect(model.named_modules())
# # model.named_modules()
# for name, module in model.named_modules():
#     # print(name) # 0 1
#     # print("---"*20)
#     # print(name,module) 
#     """Sequential(
#   (0): Linear(in_features=5, out_features=7, bias=True)
#   (1): Linear(in_features=7, out_features=3, bias=True)
# )
# Linear(in_features=5, out_features=7, bias=True)
# Linear(in_features=7, out_features=3, bias=True)"""
#     # inspect(module)
#     if hasattr(module, "parametrizations"):
#       print("yes")


In [4]:
# add lora to the model
# becase B is initialized to 0, the output is the same as before
import pysnooper
with pysnooper.snoop():
    add_lora(model)
  

y = model(x)

print(model)
assert torch.allclose(y, Y0)


Sequential(
  (0): ParametrizedLinear(
    in_features=5, out_features=7, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (1): ParametrizedLinear(
    in_features=7, out_features=3, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
)


[33m[2mSource path:... [22m/tmp/ipykernel_1761884/1850508365.py[0m
[32m[2mNew var:....... [22m__name__ = '__main__'[0m
[32m[2mNew var:....... [22m__doc__ = 'Automatically created module for IPython interactive environment'[0m
[32m[2mNew var:....... [22m__package__ = None[0m
[32m[2mNew var:....... [22m__loader__ = None[0m
[32m[2mNew var:....... [22m__spec__ = None[0m
[32m[2mNew var:....... [22m__builtin__ = <module 'builtins' (built-in)>[0m
[32m[2mNew var:....... [22m__builtins__ = <module 'builtins' (built-in)>[0m
[32m[2mNew var:....... [22m_ih = ['', 'import torch\nfrom minlora import add_lora...)\n\nprint(model)\nassert torch.allclose(y, Y0)'][0m
[32m[2mNew var:....... [22m_oh = {}[0m
[32m[2mNew var:....... [22m_dh = [PosixPath('/home/yimingshi/shiym_proj/Sarautils/minLoRA'), PosixPath('/home/yimingshi/shiym_proj/Sarautils/minLoRA')][0m
[32m[2mNew var:....... [22mIn = ['', 'import torch\nfrom minlora import add_lora...)\n\nprint(model)\n

In [5]:
# to make the output different, we need to initialize B to something non-zero
model.apply(apply_to_lora(lambda x: torch.nn.init.ones_(x.lora_B)))
y = model(x)
print(y)
assert not torch.allclose(y, Y0)
Y1 = y


tensor([[ 0.3777,  0.2858, -0.1262]])


In [6]:
# now let's try to disable lora, the output is the same as before lora is added
disable_lora(model)
y = model(x)
assert torch.allclose(y, Y0)


In [7]:
# enable lora again
enable_lora(model)
y = model(x)
assert torch.allclose(y, Y1)


In [8]:
# let's save the state dict for later use
state_dict_to_save = get_lora_state_dict(model)
state_dict_to_save.keys()


dict_keys(['0.parametrizations.weight.0.lora_A', '0.parametrizations.weight.0.lora_B', '1.parametrizations.weight.0.lora_A', '1.parametrizations.weight.0.lora_B'])

In [9]:
# you can remove lora from the model
remove_lora(model)


In [10]:
# lets try to load the lora back
# first we need to add lora to the model
add_lora(model)
# then we can load the lora parameters
# strict=False is needed because we are loading a subset of the parameters
_ = model.load_state_dict(state_dict_to_save, strict=False) 
y = model(x)
assert torch.allclose(y, Y1)


In [11]:
# we can merge it to make it a normal linear layer, so there is no overhead for inference
merge_lora(model)
y = model(x)
assert torch.allclose(y, Y1)


In [12]:
# model now has no lora parameters
model


Sequential(
  (0): Linear(in_features=5, out_features=7, bias=True)
  (1): Linear(in_features=7, out_features=3, bias=True)
)

## Training a model

In [13]:
model = torch.nn.Linear(in_features=5, out_features=3)
# Step 1: Add LoRA to the model
add_lora(model)

# Step 2: Collect the parameters, pass them to the optimizer

parameters = [
    {"params": list(get_lora_params(model))},
]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)

# Step 3: Train the model
# ...
# simulate training, update the LoRA parameters
model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.lora_A)))
model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.lora_B)))

# Step 4: export the LoRA parameters
state_dict = model.state_dict()
lora_state_dict = {k: v for k, v in state_dict.items() if name_is_lora(k)}


  from .autonotebook import tqdm as notebook_tqdm


## Loading and Inferencing with LoRA

In [14]:
# Step 1: Add LoRA to your model
add_lora(model)

# Step 2: Load the LoRA parameters
_ = model.load_state_dict(lora_state_dict, strict=False)

# Step 3: Merge the LoRA parameters into the model
merge_lora(model)


## Inferencing with multiple LoRA models

In [15]:
# to avoid re-adding lora to the model when rerun the cell, remove lora first 
remove_lora(model)
# Step 1: Add LoRA to your model
add_lora(model)

# Step 2: Load the LoRA parameters

# fake 3 sets of LoRA parameters
lora_state_dict_0 = lora_state_dict
lora_state_dict_1 = {k: torch.ones_like(v) for k, v in lora_state_dict.items()}
lora_state_dict_2 = {k: torch.zeros_like(v) for k, v in lora_state_dict.items()}
lora_state_dicts = [lora_state_dict_0, lora_state_dict_1, lora_state_dict_2]

load_multiple_lora(model, lora_state_dicts)

# Step 3: Select which LoRA to use at inference time
Y0 = select_lora(model, 0)(x)
Y1 = select_lora(model, 1)(x)
Y2 = select_lora(model, 2)(x)


In [16]:
Y0, Y1, Y2


(tensor([[ 1.1677,  0.1617, -0.8168]]),
 tensor([[1.2984, 0.5447, 0.2660]]),
 tensor([[ 0.6715, -0.0822, -0.3609]]))

In [17]:
remove_lora(model)
init_state_dict = model.state_dict()
# verify that it's the same as if we load the lora parameters one by one
for state_dict in lora_state_dicts:
    remove_lora(model)
    _ = model.load_state_dict(init_state_dict, strict=False)
    add_lora(model)
    _ = model.load_state_dict(state_dict, strict=False)
    merge_lora(model)
    y = model(x)
    print(y)


tensor([[ 1.1677,  0.1617, -0.8168]])
tensor([[1.2984, 0.5447, 0.2660]])
tensor([[ 0.6715, -0.0822, -0.3609]])


In [18]:
# class Test():
#     def __init__(self,num=1, layer=None):
#         # self.layer = layer
#         # for arg in args:
#             # print(arg)
#         self.layer = layer
#         self.num = num
# layer = torch.nn.Sequential(
#             torch.nn.Linear(5, 3),
#             torch.nn.ReLU(),
#             torch.nn.Linear(3, 3),
#             torch.nn.ReLU()
# )
# test = Test(1,layer=layer)
# # print(test)        
# # print(test.num) # 1
# # print(test.layer) # Linear(in_features=5, out_features=3, bias=True)
# # inspect(test.layer)
# print(test.layer[0])
# print(test.layer[0].weight)
# # print(test.layer[0]) 
# # print(test.layer[0].weight) 
# """Parameter containing:
# tensor([[-0.1163,  0.1544,  0.0566, -0.2275,  0.4066],
#         [-0.0287, -0.3928,  0.2575, -0.1188, -0.0773],
#         [-0.0870, -0.2780,  0.2427,  0.0463, -0.0287]], requires_grad=True)"""
        
# # print(test.layer.weight.shape) # torch.Size([3, 5])

# # print(test.layer.weight.dtype) # torch.float32

