In [1]:
import torch
from torch import nn
from lora_modules import LoRALinear

# Testing LoRALinear

In [2]:
# Seed torch's global RNG so the random input below, and the default
# nn.Linear initialisation in a later cell, are reproducible under
# Restart & Run All. (The adapter init presumably also draws from this RNG —
# confirm in lora_modules.)
SEED = 0
torch.manual_seed(SEED)

# Random batch of 4 input vectors; the feature size (16) must match
# `in_features` in the next cell.
x = torch.randn(size=(4, 16))

In [3]:
# Shape of the base nn.Linear layer under test; `in_features` has to equal
# the last dimension of `x` from the previous cell.
in_features, out_features = 16, 32
bias = True  # give the base layer an additive bias term

In [4]:
# Plain base layer; later cells compare the LoRA wrapper against it.
linear = nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
print(linear)

# Split the parameter count by `requires_grad`
# (expected: 16*32 weight + 32 bias = 544 trainable, 0 frozen).
num_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Linear: {num_trainable_parameters_in_linear}/{num_non_trainable_parameters_in_linear}')

Linear(in_features=16, out_features=32, bias=True)
Num trainable/non-trainable parameters in Linear: 544/0


In [5]:
# Reference output of the plain base layer on the random batch.
base_output = linear(x)
base_output

tensor([[ 0.3112, -0.5434,  0.5286, -0.3006, -0.2723, -0.3482, -0.1714, -0.2795,
         -0.8456,  0.3614,  0.3751, -0.4710, -0.4060, -0.4528,  0.1198, -0.3354,
         -0.0631, -0.3033,  0.2364,  0.2626, -0.4088, -0.4054, -0.3306,  0.0158,
         -0.2222,  0.3794, -0.4798, -0.3367, -0.1746, -0.5843,  0.0875,  0.4903],
        [-0.2201, -0.4519,  0.5687,  0.7261, -1.4542, -0.3455, -0.5674, -0.5083,
          0.5267,  0.2598, -0.2588, -0.1397,  0.1017,  0.8528,  0.0521,  0.9468,
         -0.0486, -0.1825, -0.4143, -0.6748, -0.2408,  0.4069,  0.4898,  0.9341,
         -0.2967,  0.3735, -0.4919, -0.0398,  1.0945,  0.5666, -0.0376,  0.5202],
        [-0.5604,  0.6451,  0.1797, -0.6332,  0.6717,  0.2009,  1.1105, -1.0070,
          0.3122, -0.3347,  0.2356,  0.0648,  0.9382, -0.4550,  0.6496, -0.6390,
         -0.1045, -0.5934,  0.7739, -0.2603,  0.4633,  0.9035,  0.7268, -0.6977,
         -0.5872, -0.4794,  0.0224, -0.4526,  0.2953,  0.4776,  0.3914, -0.1528],
        [-0.3303, -0.8498

In [6]:
# LoRA hyper-parameters passed to `build_new_adapter`:
#   rank       - adapter rank r
#   alpha      - scaling numerator; the wrapper's repr shows the scale as α/r
#   delta_bias - presumably whether the adapter also learns a bias delta (confirm in lora_modules)
lora_config = dict(rank=4, alpha=2, delta_bias=False)

In [7]:
# Wrap the base layer and attach a freshly built adapter.
lora_linear = LoRALinear()
lora_linear.add_base_module(base_module=linear)
lora_linear.build_new_adapter(lora_config=lora_config)
print(lora_linear)

# NOTE(review): the recorded split is 736/2. 736 = 544 (the base Linear's
# parameter count) + 192 (= 16*4 + 4*32 if the adapter is two bias-free rank-4
# matrices), which would mean the base weights still report requires_grad=True.
# Standard LoRA freezes the base layer — confirm whether that is intended here.
grad_flags_and_sizes = [(p.requires_grad, p.numel()) for p in lora_linear.parameters()]
num_trainable_parameters_in_lora_linear = sum(size for trainable, size in grad_flags_and_sizes if trainable)
num_non_trainable_parameters_in_lora_linear = sum(size for trainable, size in grad_flags_and_sizes if not trainable)
print(f'Num trainable/non-trainable parameters in LoRA Linear: {num_trainable_parameters_in_lora_linear}/{num_non_trainable_parameters_in_lora_linear}')

LoRALinear(Linear(in_features=16, out_features=32, bias=True) + ((α=2/r=4) × Adapter(in_features=16, rank=4, out_features=32, delta_bias=False)))
Num trainable/non-trainable parameters in LoRA Linear: 736/2


In [8]:
# With the adapter disabled, the wrapper should reproduce the base layer.
# The displayed tensor is truncated, so verify every element programmatically
# instead of comparing printed values by eye.
lora_linear.disable_adapter()
disabled_output = lora_linear(x)
assert torch.allclose(disabled_output, linear(x)), 'disabled adapter changed the base output'
disabled_output

tensor([[ 0.3112, -0.5434,  0.5286, -0.3006, -0.2723, -0.3482, -0.1714, -0.2795,
         -0.8456,  0.3614,  0.3751, -0.4710, -0.4060, -0.4528,  0.1198, -0.3354,
         -0.0631, -0.3033,  0.2364,  0.2626, -0.4088, -0.4054, -0.3306,  0.0158,
         -0.2222,  0.3794, -0.4798, -0.3367, -0.1746, -0.5843,  0.0875,  0.4903],
        [-0.2201, -0.4519,  0.5687,  0.7261, -1.4542, -0.3455, -0.5674, -0.5083,
          0.5267,  0.2598, -0.2588, -0.1397,  0.1017,  0.8528,  0.0521,  0.9468,
         -0.0486, -0.1825, -0.4143, -0.6748, -0.2408,  0.4069,  0.4898,  0.9341,
         -0.2967,  0.3735, -0.4919, -0.0398,  1.0945,  0.5666, -0.0376,  0.5202],
        [-0.5604,  0.6451,  0.1797, -0.6332,  0.6717,  0.2009,  1.1105, -1.0070,
          0.3122, -0.3347,  0.2356,  0.0648,  0.9382, -0.4550,  0.6496, -0.6390,
         -0.1045, -0.5934,  0.7739, -0.2603,  0.4633,  0.9035,  0.7268, -0.6977,
         -0.5872, -0.4794,  0.0224, -0.4526,  0.2953,  0.4776,  0.3914, -0.1528],
        [-0.3303, -0.8498

In [9]:
# Turn the adapter back on and look at the adapted output.
# NOTE(review): the recorded output differs from the base layer's even though
# the adapter was just built. LoRA as described by Hu et al. zero-initialises B
# so a new adapter starts as a no-op. Confirm this initialisation is intended.
lora_linear.enable_adapter()
adapted_output = lora_linear(x)
adapted_output

tensor([[ -3.3009,  -2.4167,   3.1948,  -5.2583,  -0.3836,  -2.5413,  -1.5520,
           0.9593,   0.9490,   2.5349,   3.4287,  -2.5208,  -3.3501,  -1.4585,
          -6.6375,  -8.1450,  -2.5488,  -0.8129,   1.5925,  -1.0689,  -1.5069,
           3.3821,   3.6107,   4.5306,  -4.1186,   0.9992,  -1.4753,  -1.9164,
           1.9871,  -5.9008,   0.3173,  -0.1720],
        [ -2.4532,  -3.0258,   5.2253,  -2.8147,   0.4041,  -7.1005,  -9.2512,
          -4.0248,  -3.6919,  -0.1330,  -0.6812,  -0.5600,   0.2983,  -0.2980,
           2.1269,  -4.6184,   3.8216,   0.6124,   3.6335,   1.1796,  -1.2590,
         -11.9950, -14.7049,  -5.4569, -14.1237,   1.0820, -12.3658,  -2.9393,
           4.1095,  11.7725,  -7.9211,  -8.7318],
        [ -1.3734,   0.0388,   0.5073,  -1.2695,  -0.3597,  -1.4701,  -0.5408,
          -1.4293,  -0.4984,   0.7661,   1.6293,   1.5113,   0.8249,  -0.3904,
           2.2003,  -0.5381,   2.0508,  -2.3375,   1.1881,   1.4566,  -1.7635,
          -1.5520,  -0.4522,  -

In [10]:
# Fold the adapter into a single plain layer and check that it reproduces the
# adapter-enabled wrapper. Assumes the adapter is still enabled from the
# previous cell.
unmerged_output = lora_linear(x)
merged_linear = lora_linear.get_merged_module()
merged_output = merged_linear(x)
# The recorded outputs only agree on the visible (truncated) portion, so
# compare every element. atol=1e-4 absorbs float32 rounding from evaluating
# W x + s·B(A x) as (W + s·BA) x, yet stays far below the O(1) adapter effect.
assert torch.allclose(merged_output, unmerged_output, atol=1e-4), 'merged module disagrees with adapter-enabled LoRALinear'
merged_output

tensor([[ -3.3009,  -2.4167,   3.1948,  -5.2583,  -0.3836,  -2.5413,  -1.5520,
           0.9593,   0.9490,   2.5349,   3.4287,  -2.5208,  -3.3501,  -1.4585,
          -6.6375,  -8.1450,  -2.5488,  -0.8129,   1.5925,  -1.0689,  -1.5069,
           3.3821,   3.6107,   4.5306,  -4.1186,   0.9992,  -1.4753,  -1.9164,
           1.9871,  -5.9008,   0.3173,  -0.1720],
        [ -2.4532,  -3.0258,   5.2253,  -2.8147,   0.4041,  -7.1005,  -9.2512,
          -4.0248,  -3.6919,  -0.1330,  -0.6812,  -0.5600,   0.2983,  -0.2980,
           2.1269,  -4.6184,   3.8216,   0.6124,   3.6335,   1.1796,  -1.2590,
         -11.9950, -14.7049,  -5.4569, -14.1237,   1.0820, -12.3658,  -2.9393,
           4.1095,  11.7725,  -7.9211,  -8.7318],
        [ -1.3734,   0.0388,   0.5073,  -1.2695,  -0.3597,  -1.4701,  -0.5408,
          -1.4293,  -0.4984,   0.7661,   1.6293,   1.5113,   0.8249,  -0.3904,
           2.2003,  -0.5381,   2.0508,  -2.3375,   1.1881,   1.4566,  -1.7635,
          -1.5520,  -0.4522,  -

In [11]:
# After toggling and merging, disabling the adapter should still give back the
# base layer's output. Check the full tensor; the display below is truncated.
lora_linear.disable_adapter()
disabled_output = lora_linear(x)
assert torch.allclose(disabled_output, linear(x)), 'disabled adapter no longer matches the base layer after merging'
disabled_output

tensor([[ 0.3112, -0.5434,  0.5286, -0.3006, -0.2723, -0.3482, -0.1714, -0.2795,
         -0.8456,  0.3614,  0.3751, -0.4710, -0.4060, -0.4528,  0.1198, -0.3354,
         -0.0631, -0.3033,  0.2364,  0.2626, -0.4088, -0.4054, -0.3306,  0.0158,
         -0.2222,  0.3794, -0.4798, -0.3367, -0.1746, -0.5843,  0.0875,  0.4903],
        [-0.2201, -0.4519,  0.5687,  0.7261, -1.4542, -0.3455, -0.5674, -0.5083,
          0.5267,  0.2598, -0.2588, -0.1397,  0.1017,  0.8528,  0.0521,  0.9468,
         -0.0486, -0.1825, -0.4143, -0.6748, -0.2408,  0.4069,  0.4898,  0.9341,
         -0.2967,  0.3735, -0.4919, -0.0398,  1.0945,  0.5666, -0.0376,  0.5202],
        [-0.5604,  0.6451,  0.1797, -0.6332,  0.6717,  0.2009,  1.1105, -1.0070,
          0.3122, -0.3347,  0.2356,  0.0648,  0.9382, -0.4550,  0.6496, -0.6390,
         -0.1045, -0.5934,  0.7739, -0.2603,  0.4633,  0.9035,  0.7268, -0.6977,
         -0.5872, -0.4794,  0.0224, -0.4526,  0.2953,  0.4776,  0.3914, -0.1528],
        [-0.3303, -0.8498