In [1]:
import torch
from torch import nn
from lora_modules import LoRALinear

# Testing LoRALinear

In [2]:
# Seed the global RNG before drawing anything so that a fresh
# Restart & Run All reproduces the same input batch (and, because the
# layers below are initialised from the same RNG stream, the same weights).
torch.manual_seed(0)

# Input batch: 4 samples with 16 features each.
x = torch.randn(size=(4, 16))

In [3]:
# Dimensions of the layer under test, and whether it carries a bias term.
in_features, out_features = 16, 32
bias = True

In [4]:
# Plain nn.Linear used as the reference layer for the LoRA comparison.
linear = nn.Linear(in_features, out_features, bias=bias)
print(linear)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Linear: {num_trainable_parameters_in_linear}/{num_non_trainable_parameters_in_linear}')

Linear(in_features=16, out_features=32, bias=True)
Num trainable/non-trainable parameters in Linear: 544/0


In [5]:
# Forward pass of the plain layer on `x`; this output is the baseline that
# the LoRA-wrapped outputs further down are compared against.
linear(x)

tensor([[-0.3338, -0.6755,  0.4683,  0.0024, -0.3157,  0.0406, -0.3260, -1.1837,
         -0.1478,  0.4814, -0.1412,  0.7134, -0.4066, -0.0354,  0.7776,  0.4889,
          0.6230,  0.5008, -0.2684,  0.3058, -0.3035,  0.7830, -1.0276,  0.6032,
         -1.3451, -0.8008, -0.9648,  0.6712,  1.1311,  0.1094,  1.2219, -0.7236],
        [-0.2401, -0.7181, -0.1038, -0.1265,  0.5834, -0.1813,  0.0111,  0.2977,
         -0.7698,  0.7042,  0.0683,  0.8807, -0.4751,  0.4882,  0.5226, -0.4313,
         -0.3034, -0.1315, -0.1363,  0.2115, -0.1644,  1.6095,  0.3726,  1.4429,
          0.2238, -0.6671,  0.4667,  0.3930, -0.6234,  0.2655,  0.4400, -0.6092],
        [-0.3184,  0.5939, -0.9461, -0.7104,  0.0713, -1.2302,  0.3890,  0.3616,
         -0.1091, -1.1380,  1.1048, -0.4713, -0.4286,  0.2103,  1.0851,  0.5802,
         -0.3804,  1.1119,  0.7919, -0.7343, -0.3367, -0.7302, -0.4052, -1.1207,
          0.3852, -1.1809,  0.2259, -0.4546,  0.0410,  0.2004, -0.1868, -0.5725],
        [ 0.3625, -0.6777

In [6]:
# Show the plain layer's state dict for later comparison with the
# LoRA wrapper's state dict.
linear.state_dict()

OrderedDict([('weight',
              tensor([[ 4.6334e-02, -1.8275e-01,  7.5900e-02,  1.4937e-01, -1.9879e-02,
                        2.2637e-01, -2.0035e-01,  4.2305e-02,  7.6439e-02, -1.4376e-01,
                       -2.3370e-01, -1.3972e-01,  1.7278e-01,  1.3309e-01, -8.8608e-02,
                        6.1890e-02],
                      [ 5.0454e-02,  1.5182e-01,  1.2406e-01, -9.5388e-02,  1.2845e-01,
                       -7.6813e-02,  1.0319e-01,  1.8125e-01,  9.0450e-02, -2.2708e-01,
                       -2.4401e-01,  1.3567e-01,  9.9102e-02, -2.0264e-01, -1.2471e-01,
                        8.8342e-02],
                      [-5.0123e-02, -8.8491e-02, -2.4167e-02,  5.7067e-02,  1.4452e-02,
                       -1.7246e-01,  1.0535e-01,  3.1072e-02,  2.1555e-01,  2.4766e-01,
                        2.0399e-01, -2.0178e-01,  2.3337e-01,  1.1378e-01,  9.0520e-02,
                       -5.5212e-02],
                      [ 1.4570e-01,  2.3578e-01,  2.0913e-01, -2.4772e-01

In [7]:
# Adapter settings handed to LoRALinear.build_new_adapter below.
lora_config = dict(rank=4, alpha=2, delta_bias=False)

In [8]:
# Wrap the reference layer: create an empty LoRALinear, attach `linear` as its
# base module, then build an adapter from `lora_config`.
lora_linear = LoRALinear()
lora_linear.add_base_module(base_module=linear)
lora_linear.build_new_adapter(lora_config=lora_config)
print(lora_linear)

# Count parameter elements of the wrapper, split by whether they require gradients.
num_trainable_parameters_in_lora_linear = sum(
    p.numel() for p in lora_linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_lora_linear = sum(
    p.numel() for p in lora_linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in LoRA Linear: {num_trainable_parameters_in_lora_linear}/{num_non_trainable_parameters_in_lora_linear}')

LoRALinear(Linear(in_features=16, out_features=32, bias=True) + ((α=2/r=4) × Adapter(in_features=16, rank=4, out_features=32, delta_bias=False)))
Num trainable/non-trainable parameters in LoRA Linear: 192/546


In [9]:
# With the adapter switched off, the wrapper should reproduce the plain
# layer's output. Check that explicitly instead of comparing printouts by eye.
lora_linear.disable_adapter()
adapter_off_output = lora_linear(x)
torch.testing.assert_close(adapter_off_output, linear(x))

# Keep the tensor as the cell's last expression so it is still displayed.
adapter_off_output

tensor([[-0.3338, -0.6755,  0.4683,  0.0024, -0.3157,  0.0406, -0.3260, -1.1837,
         -0.1478,  0.4814, -0.1412,  0.7134, -0.4066, -0.0354,  0.7776,  0.4889,
          0.6230,  0.5008, -0.2684,  0.3058, -0.3035,  0.7830, -1.0276,  0.6032,
         -1.3451, -0.8008, -0.9648,  0.6712,  1.1311,  0.1094,  1.2219, -0.7236],
        [-0.2401, -0.7181, -0.1038, -0.1265,  0.5834, -0.1813,  0.0111,  0.2977,
         -0.7698,  0.7042,  0.0683,  0.8807, -0.4751,  0.4882,  0.5226, -0.4313,
         -0.3034, -0.1315, -0.1363,  0.2115, -0.1644,  1.6095,  0.3726,  1.4429,
          0.2238, -0.6671,  0.4667,  0.3930, -0.6234,  0.2655,  0.4400, -0.6092],
        [-0.3184,  0.5939, -0.9461, -0.7104,  0.0713, -1.2302,  0.3890,  0.3616,
         -0.1091, -1.1380,  1.1048, -0.4713, -0.4286,  0.2103,  1.0851,  0.5802,
         -0.3804,  1.1119,  0.7919, -0.7343, -0.3367, -0.7302, -0.4052, -1.1207,
          0.3852, -1.1809,  0.2259, -0.4546,  0.0410,  0.2004, -0.1868, -0.5725],
        [ 0.3625, -0.6777

In [10]:
# Switch the adapter back on and run the same input; the displayed output now
# includes the adapter's contribution and differs from the baseline above.
lora_linear.enable_adapter()
lora_linear(x)

tensor([[ 6.0684e+00,  4.5627e+00,  2.1684e+00,  5.4600e+00, -5.4508e-03,
          8.1957e+00,  2.5495e+00, -2.3648e+00, -4.2509e+00, -1.4196e+00,
          4.2199e+00,  8.1056e+00,  4.5921e+00,  3.4768e-01,  7.7350e+00,
          4.5583e+00,  2.7909e-01,  4.4991e+00,  1.1443e+00,  2.5971e+00,
          2.7029e+00, -1.7982e+00,  3.4135e+00,  4.7136e+00, -1.6290e+00,
         -4.8815e+00,  1.2306e+00, -6.5736e-01,  8.2388e+00, -3.1371e-01,
          8.3457e+00, -1.5162e+00],
        [ 2.8686e+00,  4.1355e+00,  2.3246e-01, -4.5670e-01, -6.4506e-01,
         -2.3667e+00, -2.4346e+00,  1.2379e+00, -2.3705e+00, -6.9970e-01,
          1.6560e+00,  8.9758e+00,  5.0263e+00,  2.2981e+00, -5.1084e-01,
          2.0733e+00,  1.2014e+00,  1.3779e+00, -1.3204e+00, -7.8606e-01,
          2.7798e+00,  3.5013e-02, -1.5381e+00,  3.2651e+00, -6.5935e+00,
         -1.0630e-01, -3.3617e-03, -7.7015e-01,  4.4765e+00, -1.1149e-01,
          3.8171e+00, -5.2694e+00],
        [ 1.8984e+00, -5.6806e+00,  5.27

In [11]:
# Capture the adapter-enabled output first, so the comparison below does not
# depend on whatever state get_merged_module() leaves the wrapper in.
adapter_on_output = lora_linear(x)

# The merged module should give the same result as base + adapter.
merged_linear = lora_linear.get_merged_module()
merged_output = merged_linear(x)

# Slightly loosened tolerances: the merged module may accumulate the same
# terms in a different floating-point order than the two-branch forward pass.
torch.testing.assert_close(merged_output, adapter_on_output, rtol=1e-4, atol=1e-5)

# Keep the tensor as the cell's last expression so it is still displayed.
merged_output

tensor([[ 6.0684e+00,  4.5627e+00,  2.1684e+00,  5.4600e+00, -5.4508e-03,
          8.1957e+00,  2.5495e+00, -2.3648e+00, -4.2509e+00, -1.4196e+00,
          4.2199e+00,  8.1056e+00,  4.5921e+00,  3.4768e-01,  7.7350e+00,
          4.5583e+00,  2.7909e-01,  4.4991e+00,  1.1443e+00,  2.5971e+00,
          2.7029e+00, -1.7982e+00,  3.4135e+00,  4.7136e+00, -1.6290e+00,
         -4.8815e+00,  1.2306e+00, -6.5736e-01,  8.2388e+00, -3.1371e-01,
          8.3457e+00, -1.5162e+00],
        [ 2.8686e+00,  4.1355e+00,  2.3246e-01, -4.5670e-01, -6.4506e-01,
         -2.3667e+00, -2.4346e+00,  1.2379e+00, -2.3705e+00, -6.9970e-01,
          1.6560e+00,  8.9758e+00,  5.0263e+00,  2.2981e+00, -5.1084e-01,
          2.0733e+00,  1.2014e+00,  1.3779e+00, -1.3204e+00, -7.8606e-01,
          2.7798e+00,  3.5013e-02, -1.5381e+00,  3.2651e+00, -6.5935e+00,
         -1.0630e-01, -3.3617e-03, -7.7015e-01,  4.4765e+00, -1.1149e-01,
          3.8171e+00, -5.2694e+00],
        [ 1.8984e+00, -5.6806e+00,  5.27

In [12]:
# After building the merged module, the wrapper with its adapter switched off
# should still match the plain layer. Assert it rather than eyeballing it.
lora_linear.disable_adapter()
post_merge_adapter_off_output = lora_linear(x)
torch.testing.assert_close(post_merge_adapter_off_output, linear(x))

# Keep the tensor as the cell's last expression so it is still displayed.
post_merge_adapter_off_output

tensor([[-0.3338, -0.6755,  0.4683,  0.0024, -0.3157,  0.0406, -0.3260, -1.1837,
         -0.1478,  0.4814, -0.1412,  0.7134, -0.4066, -0.0354,  0.7776,  0.4889,
          0.6230,  0.5008, -0.2684,  0.3058, -0.3035,  0.7830, -1.0276,  0.6032,
         -1.3451, -0.8008, -0.9648,  0.6712,  1.1311,  0.1094,  1.2219, -0.7236],
        [-0.2401, -0.7181, -0.1038, -0.1265,  0.5834, -0.1813,  0.0111,  0.2977,
         -0.7698,  0.7042,  0.0683,  0.8807, -0.4751,  0.4882,  0.5226, -0.4313,
         -0.3034, -0.1315, -0.1363,  0.2115, -0.1644,  1.6095,  0.3726,  1.4429,
          0.2238, -0.6671,  0.4667,  0.3930, -0.6234,  0.2655,  0.4400, -0.6092],
        [-0.3184,  0.5939, -0.9461, -0.7104,  0.0713, -1.2302,  0.3890,  0.3616,
         -0.1091, -1.1380,  1.1048, -0.4713, -0.4286,  0.2103,  1.0851,  0.5802,
         -0.3804,  1.1119,  0.7919, -0.7343, -0.3367, -0.7302, -0.4052, -1.1207,
          0.3852, -1.1809,  0.2259, -0.4546,  0.0410,  0.2004, -0.1868, -0.5725],
        [ 0.3625, -0.6777

In [13]:
# Show the wrapper's state dict; in the output below the wrapped layer's
# tensors appear nested under a 'base' key.
lora_linear.state_dict()

OrderedDict([('base',
              OrderedDict([('weight',
                            tensor([[ 4.6334e-02, -1.8275e-01,  7.5900e-02,  1.4937e-01, -1.9879e-02,
                                      2.2637e-01, -2.0035e-01,  4.2305e-02,  7.6439e-02, -1.4376e-01,
                                     -2.3370e-01, -1.3972e-01,  1.7278e-01,  1.3309e-01, -8.8608e-02,
                                      6.1890e-02],
                                    [ 5.0454e-02,  1.5182e-01,  1.2406e-01, -9.5388e-02,  1.2845e-01,
                                     -7.6813e-02,  1.0319e-01,  1.8125e-01,  9.0450e-02, -2.2708e-01,
                                     -2.4401e-01,  1.3567e-01,  9.9102e-02, -2.0264e-01, -1.2471e-01,
                                      8.8342e-02],
                                    [-5.0123e-02, -8.8491e-02, -2.4167e-02,  5.7067e-02,  1.4452e-02,
                                     -1.7246e-01,  1.0535e-01,  3.1072e-02,  2.1555e-01,  2.4766e-01,
                      

In [14]:
# Write the wrapper's state dict to 'linear.pth' (path is relative to the
# current working directory).
torch.save(lora_linear.state_dict(), 'linear.pth')

In [15]:
# NOTE(review): the checkpoint is loaded back into the same instance that
# produced it, so this cell cannot show that a freshly constructed module
# restores correctly. Re-creating the wrapper first (`lora_linear = LoRALinear()`)
# was left disabled here — presumably a fresh instance would also need its
# base module and adapter built before loading; confirm against lora_modules.

lora_linear.load_state_dict(torch.load('linear.pth', weights_only=True))