In [1]:
import torch
import torch.nn as nn
from functools import partial

from minsara import SaRAParametrization,add_sara, apply_to_sara, disable_sara, enable_sara, get_sara_params, merge_sara, name_is_sara, remove_sara,get_sara_state_dict
# Turn autograd off globally; no cell in this notebook calls backward().
_ = torch.set_grad_enabled(False)

import sys
# Raise Python's recursion limit well above the default.
# NOTE(review): the original inline note said "e.g. raise to 1500", but the value
# actually set is 150000 - confirm which was intended. A RecursionError is still
# recorded in the saved output of the disable_sara cell, so this limit does not
# appear to resolve it.
sys.setrecursionlimit(150000)  # adjust to actual needs


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Toy network: two stacked linear layers mapping 5 -> 7 -> 3 features.
model = torch.nn.Sequential(
    *(
        torch.nn.Linear(in_features=n_in, out_features=n_out)
        for n_in, n_out in [(5, 7), (7, 3)]
    )
)

# One random input row; its output is kept as Y0, the baseline for later comparisons.
x = torch.randn(1, 5)
Y0 = y = model(x)
print(y)


tensor([[-0.2117, -0.2127, -0.4248]])


In [3]:
# SaRA settings: the "weight" of every nn.Linear is parametrized via
# SaRAParametrization.from_linear with rank fixed to 2.
sara_config = {nn.Linear: dict(weight=partial(SaRAParametrization.from_linear, rank=2))}


In [4]:
# add sara to the model
# because B is initialized to 0, the output is the same as before
# NOTE(review): the assert that would verify this (last line of this cell) is
# commented out, and the following cell asserts the output differs from Y0 -
# confirm whether the claim above actually holds for SaRA.
import pysnooper
# Debugging aid: trace the add_sara call with pysnooper.
with pysnooper.snoop():
    add_sara(model, sara_config=sara_config)
y = model(x)

# print(model)
from labml.logger import inspect
# Dump the model (now carrying SaRA) with labml's inspect helper.
inspect(model)
# from torchkeras import summary
# summary(model, input_shape=(5,))
# assert torch.allclose(y, Y0)


[33m[2mSource path:... [22m/tmp/ipykernel_1762759/2490519804.py[0m
[32m[2mNew var:....... [22m__name__ = '__main__'[0m
[32m[2mNew var:....... [22m__doc__ = 'Automatically created module for IPython interactive environment'[0m
[32m[2mNew var:....... [22m__package__ = None[0m
[32m[2mNew var:....... [22m__loader__ = None[0m
[32m[2mNew var:....... [22m__spec__ = None[0m
[32m[2mNew var:....... [22m__builtin__ = <module 'builtins' (built-in)>[0m
[32m[2mNew var:....... [22m__builtins__ = <module 'builtins' (built-in)>[0m
[32m[2mNew var:....... [22m_ih = ['', 'import torch\nimport torch.nn as nn\nfrom ...put_shape=(5,))\n# assert torch.allclose(y, Y0)'][0m
[32m[2mNew var:....... [22m_oh = {}[0m
[32m[2mNew var:....... [22m_dh = [PosixPath('/root/shiym_proj/Sara/utils/SaRA'), PosixPath('/root/shiym_proj/Sara/utils/SaRA')][0m
[32m[2mNew var:....... [22mIn = ['', 'import torch\nimport torch.nn as nn\nfrom ...put_shape=(5,))\n# assert torch.allclose(y,

: 

In [None]:
# to make the output different, we need to initialize B to something non-zero
# NOTE(review): that initialization is commented out below, so the assert relies on
# the output already differing from Y0 right after add_sara - confirm this is intended.
# model.apply(apply_to_sara(lambda x: torch.nn.init.ones_(x.lora_B)))
y = model(x)
print(y)
assert not torch.allclose(y, Y0)
# Keep the SaRA-enabled output as Y1; later cells compare against it.
Y1 = y
# print(model)


tensor([[ 0.5632, -0.2915, -0.0617]])


In [None]:
# now let's try to disable sara, the output is the same as before sara is added
# NOTE(review): the saved output following this cell is a RecursionError ("maximum
# recursion depth exceeded"); which call raised it is not shown - investigate before
# relying on this check.
disable_sara(model)
y = model(x)
assert torch.allclose(y, Y0)


RecursionError: maximum recursion depth exceeded while calling a Python object

In [None]:
# enable sara again
enable_sara(model)
y = model(x)
# With SaRA re-enabled the output is expected to match Y1, the SaRA-enabled output saved earlier.
assert torch.allclose(y, Y1)


In [None]:
# let's save the state dict for later use
state_dict_to_save = get_sara_state_dict(model)
# Last expression of the cell: show the keys of the saved state dict.
state_dict_to_save.keys()


In [None]:
# you can remove sara from the model
# (the parameters saved in state_dict_to_save are loaded back in the next cell)
remove_sara(model)


In [None]:
# let's try to load the sara back
# first we need to add sara to the model, using the same sara_config (rank=2) that the
# saved state dict was produced with; otherwise the re-created sara parameters may not
# match the saved shapes, and load_state_dict raises on shape mismatches even with
# strict=False
add_sara(model, sara_config=sara_config)
# then we can load the sara parameters
# strict=False is needed because we are loading a subset of the parameters
_ = model.load_state_dict(state_dict_to_save, strict=False)
y = model(x)
# the reloaded model is expected to reproduce the sara-enabled output saved as Y1
assert torch.allclose(y, Y1)


In [None]:
# we can merge it to make it a normal linear layer, so there is no overhead for inference
merge_sara(model)
y = model(x)
# Merging is expected to leave the output unchanged, i.e. still equal to Y1.
assert torch.allclose(y, Y1)


In [None]:
# model now has no sara parameters
# (bare expression: the notebook displays the model's repr as the cell output)
model


## Training a model

In [None]:
# Fresh single-layer model for the training walkthrough
model = torch.nn.Linear(in_features=5, out_features=3)

# Step 1: attach sara to the model (no explicit sara_config is passed here)
add_sara(model)

# Step 2: gather the sara parameters into one param group and hand it to the optimizer
parameters = [dict(params=list(get_sara_params(model)))]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)

# Step 3: the real training loop would go here.
# As a stand-in, overwrite lora_A and then lora_B with normally distributed values.
model.apply(apply_to_sara(lambda layer: torch.nn.init.normal_(layer.lora_A)))
model.apply(apply_to_sara(lambda layer: torch.nn.init.normal_(layer.lora_B)))

# Step 4: export only the sara entries of the model's state dict
state_dict = model.state_dict()
sara_state_dict = {
    key: value for key, value in state_dict.items() if name_is_sara(key)
}


## Loading and Inferencing with sara

In [None]:
# Step 1: Add sara to your model
# The model may still carry sara from the training cell above; remove it first so that
# add_sara does not stack a second parametrization on top of the existing one.
remove_sara(model)
add_sara(model)

# Step 2: Load the sara parameters
# strict=False because sara_state_dict only holds the sara subset of the model's parameters
_ = model.load_state_dict(sara_state_dict, strict=False)

# Step 3: Merge the sara parameters into the model
merge_sara(model)


## Inferencing with multiple sara models

In [None]:
# # to avoid re-adding sara to the model when rerunning the cell, remove sara first
# remove_sara(model)
# # Step 1: Add sara to your model
# add_sara(model)

# # Step 2: Load the sara parameters

# # fake 3 sets of sara parameters
# sara_state_dict_0 = sara_state_dict
# sara_state_dict_1 = {k: torch.ones_like(v) for k, v in sara_state_dict.items()}
# sara_state_dict_2 = {k: torch.zeros_like(v) for k, v in sara_state_dict.items()}
# sara_state_dicts = [sara_state_dict_0, sara_state_dict_1, sara_state_dict_2]

# load_multiple_sara(model, sara_state_dicts)

# # Step 3: Select which sara to use at inference time
# Y0 = select_sara(model, 0)(x)
# Y1 = select_sara(model, 1)(x)
# Y2 = select_sara(model, 2)(x)


In [None]:
# Y0, Y1, Y2


In [None]:
# Three fake sets of sara parameters to compare: the exported ones, all-ones and all-zeros.
# They are built here because the only other definition of sara_state_dicts in this
# notebook is commented out, which would leave the loop below failing with a NameError.
sara_state_dicts = [
    sara_state_dict,
    {k: torch.ones_like(v) for k, v in sara_state_dict.items()},
    {k: torch.zeros_like(v) for k, v in sara_state_dict.items()},
]

remove_sara(model)
init_state_dict = model.state_dict()
# verify that it's the same as if we load the sara parameters one by one
for state_dict in sara_state_dicts:
    # start every round from the same sara-free weights
    remove_sara(model)
    _ = model.load_state_dict(init_state_dict, strict=False)
    # attach sara, load one parameter set, fold it into the weights and run the model
    add_sara(model)
    _ = model.load_state_dict(state_dict, strict=False)
    merge_sara(model)
    y = model(x)
    print(y)


In [None]:
# class Test():
#     def __init__(self,num=1, layer=None):
#         # self.layer = layer
#         # for arg in args:
#             # print(arg)
#         self.layer = layer
#         self.num = num
# layer = torch.nn.Sequential(
#             torch.nn.Linear(5, 3),
#             torch.nn.ReLU(),
#             torch.nn.Linear(3, 3),
#             torch.nn.ReLU()
# )
# test = Test(1,layer=layer)
# # print(test)        
# # print(test.num) # 1
# # print(test.layer) # Linear(in_features=5, out_features=3, bias=True)
# # inspect(test.layer)
# print(test.layer[0])
# print(test.layer[0].weight)
# # print(test.layer[0]) 
# # print(test.layer[0].weight) 
# """Parameter containing:
# tensor([[-0.1163,  0.1544,  0.0566, -0.2275,  0.4066],
#         [-0.0287, -0.3928,  0.2575, -0.1188, -0.0773],
#         [-0.0870, -0.2780,  0.2427,  0.0463, -0.0287]], requires_grad=True)"""
        
# # print(test.layer.weight.shape) # torch.Size([3, 5])

# # print(test.layer.weight.dtype) # torch.float32

