In [1]:
import torch
from torch import nn
from lora_modules import LoRALinear

# Testing LoRALinear

In [2]:
# Prefer the GPU when PyTorch reports one as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Seed PyTorch's global RNG so the random input below is identical after every
# Restart & Run All. Anything created later from the same global RNG, in the
# same order, becomes repeatable as well.
torch.manual_seed(0)

# Input batch: 4 random vectors of 16 features each. The tensor is sampled on
# the CPU and only then moved, so its values do not depend on which device is
# selected.
x = torch.randn(size=(4, 16))
x = x.to(device)

In [4]:
# Constructor arguments for the reference Linear layer built in the next cell:
# input width, output width, and whether it carries a bias term.
in_features, out_features, bias = 16, 32, True

In [5]:
# Reference layer: a plain fully connected layer placed on the selected device.
# `Module.to` returns the module itself, so the call can be chained.
linear = nn.Linear(in_features=in_features, out_features=out_features, bias=bias).to(device)
print(linear)

# Count parameter elements, split by whether they receive gradients.
num_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Linear: {num_trainable_parameters_in_linear}/{num_non_trainable_parameters_in_linear}')

Linear(in_features=16, out_features=32, bias=True)
Num trainable/non-trainable parameters in Linear: 544/0


In [6]:
# Baseline: output of the plain Linear layer for the input batch `x`.
# Later cells compare the LoRA-wrapped layer against this result.
linear(x)

tensor([[-0.9974,  0.2971,  0.4827, -0.1939,  0.0840,  0.3795, -0.1560, -0.8703,
          0.2838,  0.2311, -0.0806, -0.3444, -0.1399, -0.7673, -0.1217, -0.4844,
         -0.0116, -1.2735,  0.1225,  0.6411,  0.2643,  0.5202,  0.3701,  0.0701,
         -0.3931,  0.2665,  0.8281,  0.3925, -0.9762,  0.0633,  0.5376, -0.3201],
        [ 0.9448, -1.0424, -0.7873,  0.5302,  0.0658,  0.6236, -0.1832,  1.1294,
          0.4910, -0.0195,  0.7048,  1.3998,  0.4556,  0.8646, -0.2026, -0.4290,
         -0.1479, -0.0901,  1.0054, -0.4510, -0.1345, -0.1178, -0.3862,  0.4515,
          1.1967,  0.5483, -0.3600,  0.5713,  0.3901, -0.4534,  0.4421, -0.2241],
        [ 0.3602,  0.3605,  1.6869,  0.0447,  0.0131,  0.4407, -0.8250, -1.0779,
         -0.4620, -0.6680,  0.3935, -0.0041, -0.0141, -0.6126, -0.7687, -0.1662,
          0.3416, -0.8928,  0.2881,  0.9859, -0.6207,  0.1412,  0.5736, -0.8701,
         -0.0902, -0.3355,  1.2091,  0.4819,  0.5685, -0.2906,  0.4154, -0.7510],
        [ 0.4440, -0.3239

In [7]:
# Settings handed to LoRALinear in the next cell. Going by the module's printed
# repr, the adapter output is scaled by alpha / rank, and `delta_bias` is an
# option of the adapter itself.
lora_config = dict(rank=4, alpha=2, delta_bias=False)

In [8]:
# Wrap the reference layer with a LoRA adapter built from `lora_config`.
lora_linear = LoRALinear(linear, lora_config)
print(lora_linear)

# Count parameter elements of the wrapped module, split by whether they
# receive gradients.
num_trainable_parameters_in_lora_linear = sum(
    p.numel() for p in lora_linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_lora_linear = sum(
    p.numel() for p in lora_linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in LoRA Linear: {num_trainable_parameters_in_lora_linear}/{num_non_trainable_parameters_in_lora_linear}')

LoRALinear(Linear(in_features=16, out_features=32, bias=True) + ((α=2.0/r=4) × Adapter(in_features=16, rank=4, out_features=32, delta_bias=False)))
Num trainable/non-trainable parameters in LoRA Linear: 192/546


In [9]:
# With the adapter switched off, the wrapper is expected to reproduce the
# wrapped Linear layer.
lora_linear.disable_adapter()
adapter_off_output = lora_linear(x)

# Check the equivalence programmatically instead of comparing printed tensors
# by eye; a regression now fails the cell.
assert torch.allclose(adapter_off_output, linear(x), atol=1e-6), \
    'LoRALinear with the adapter disabled should match the wrapped Linear layer'

# Keep the tensor as the cell's displayed result.
adapter_off_output

tensor([[-0.9974,  0.2971,  0.4827, -0.1939,  0.0840,  0.3795, -0.1560, -0.8703,
          0.2838,  0.2311, -0.0806, -0.3444, -0.1399, -0.7673, -0.1217, -0.4844,
         -0.0116, -1.2735,  0.1225,  0.6411,  0.2643,  0.5202,  0.3701,  0.0701,
         -0.3931,  0.2665,  0.8281,  0.3925, -0.9762,  0.0633,  0.5376, -0.3201],
        [ 0.9448, -1.0424, -0.7873,  0.5302,  0.0658,  0.6236, -0.1832,  1.1294,
          0.4910, -0.0195,  0.7048,  1.3998,  0.4556,  0.8646, -0.2026, -0.4290,
         -0.1479, -0.0901,  1.0054, -0.4510, -0.1345, -0.1178, -0.3862,  0.4515,
          1.1967,  0.5483, -0.3600,  0.5713,  0.3901, -0.4534,  0.4421, -0.2241],
        [ 0.3602,  0.3605,  1.6869,  0.0447,  0.0131,  0.4407, -0.8250, -1.0779,
         -0.4620, -0.6680,  0.3935, -0.0041, -0.0141, -0.6126, -0.7687, -0.1662,
          0.3416, -0.8928,  0.2881,  0.9859, -0.6207,  0.1412,  0.5736, -0.8701,
         -0.0902, -0.3355,  1.2091,  0.4819,  0.5685, -0.2906,  0.4154, -0.7510],
        [ 0.4440, -0.3239

In [10]:
# Switch the adapter back on and show the wrapper's output for `x`.
# In the recorded run this output differs from the baseline Linear output.
lora_linear.enable_adapter()
lora_linear(x)

tensor([[-1.1160e+00,  6.7169e-02,  1.7643e-01, -2.2289e-01,  2.8446e-02,
          1.0022e-01, -2.4928e-01, -9.5091e-01,  3.0792e-01, -1.9356e-02,
          1.0941e-01, -5.7883e-01,  3.6339e-01, -1.0004e+00, -2.0747e-01,
         -4.2455e-01, -1.3423e-03, -1.3222e+00, -5.3823e-02,  8.4764e-01,
          2.0883e-02,  6.4379e-01,  2.6028e-01, -2.8996e-01, -3.3305e-01,
          3.0832e-01,  6.7640e-01,  3.9435e-01, -5.0337e-01,  1.9751e-01,
          2.9437e-01, -4.5602e-01],
        [ 5.7233e-01, -6.6440e-01, -1.0173e+00,  1.6374e-01,  6.4171e-02,
          7.8138e-01, -2.7517e-01,  1.0964e+00,  6.9501e-01, -1.5336e-03,
          5.8168e-01,  1.2845e+00,  3.0624e-01,  5.9930e-01,  2.0957e-01,
         -7.8737e-01,  1.1396e-01, -5.4319e-02,  9.9113e-01, -6.8297e-01,
         -2.5250e-02,  5.3367e-01, -2.2419e-01,  2.6684e-01,  1.6029e+00,
          3.2535e-01, -3.4595e-01,  4.1655e-01,  5.6776e-01, -2.4828e-01,
          5.3327e-01,  8.7938e-02],
        [ 2.6089e-01,  1.0196e-01,  1.50

In [11]:
# Reference output from the un-merged wrapper, taken before merging.
# This relies on the adapter still being enabled from the previous cell.
adapter_on_output = lora_linear(x)

# Obtain the merged module and run it on the same input.
merged_linear = lora_linear.get_merged_module()
merged_output = merged_linear(x)

# The merged module should match the adapter-enabled wrapper up to float
# rounding; check it programmatically instead of comparing printed tensors.
assert torch.allclose(merged_output, adapter_on_output, atol=1e-5), \
    'Merged module should match LoRALinear with the adapter enabled'

# Keep the tensor as the cell's displayed result.
merged_output

tensor([[-1.1160e+00,  6.7169e-02,  1.7643e-01, -2.2289e-01,  2.8446e-02,
          1.0022e-01, -2.4928e-01, -9.5091e-01,  3.0792e-01, -1.9356e-02,
          1.0941e-01, -5.7883e-01,  3.6339e-01, -1.0004e+00, -2.0747e-01,
         -4.2455e-01, -1.3423e-03, -1.3222e+00, -5.3823e-02,  8.4764e-01,
          2.0883e-02,  6.4379e-01,  2.6028e-01, -2.8996e-01, -3.3305e-01,
          3.0832e-01,  6.7640e-01,  3.9435e-01, -5.0337e-01,  1.9751e-01,
          2.9437e-01, -4.5602e-01],
        [ 5.7233e-01, -6.6440e-01, -1.0173e+00,  1.6374e-01,  6.4171e-02,
          7.8138e-01, -2.7517e-01,  1.0964e+00,  6.9501e-01, -1.5336e-03,
          5.8168e-01,  1.2845e+00,  3.0624e-01,  5.9930e-01,  2.0957e-01,
         -7.8737e-01,  1.1396e-01, -5.4319e-02,  9.9113e-01, -6.8297e-01,
         -2.5250e-02,  5.3367e-01, -2.2419e-01,  2.6684e-01,  1.6029e+00,
          3.2535e-01, -3.4595e-01,  4.1655e-01,  5.6776e-01, -2.4828e-01,
          5.3327e-01,  8.7938e-02],
        [ 2.6089e-01,  1.0196e-01,  1.50

In [12]:
# After the merge, switching the adapter off should still reproduce the
# wrapped Linear layer.
lora_linear.disable_adapter()
post_merge_adapter_off_output = lora_linear(x)

# Check the equivalence programmatically instead of comparing printed tensors
# by eye; a regression now fails the cell.
assert torch.allclose(post_merge_adapter_off_output, linear(x), atol=1e-6), \
    'LoRALinear with the adapter disabled should match the wrapped Linear layer after merging'

# Keep the tensor as the cell's displayed result.
post_merge_adapter_off_output

tensor([[-0.9974,  0.2971,  0.4827, -0.1939,  0.0840,  0.3795, -0.1560, -0.8703,
          0.2838,  0.2311, -0.0806, -0.3444, -0.1399, -0.7673, -0.1217, -0.4844,
         -0.0116, -1.2735,  0.1225,  0.6411,  0.2643,  0.5202,  0.3701,  0.0701,
         -0.3931,  0.2665,  0.8281,  0.3925, -0.9762,  0.0633,  0.5376, -0.3201],
        [ 0.9448, -1.0424, -0.7873,  0.5302,  0.0658,  0.6236, -0.1832,  1.1294,
          0.4910, -0.0195,  0.7048,  1.3998,  0.4556,  0.8646, -0.2026, -0.4290,
         -0.1479, -0.0901,  1.0054, -0.4510, -0.1345, -0.1178, -0.3862,  0.4515,
          1.1967,  0.5483, -0.3600,  0.5713,  0.3901, -0.4534,  0.4421, -0.2241],
        [ 0.3602,  0.3605,  1.6869,  0.0447,  0.0131,  0.4407, -0.8250, -1.0779,
         -0.4620, -0.6680,  0.3935, -0.0041, -0.0141, -0.6126, -0.7687, -0.1662,
          0.3416, -0.8928,  0.2881,  0.9859, -0.6207,  0.1412,  0.5736, -0.8701,
         -0.0902, -0.3355,  1.2091,  0.4819,  0.5685, -0.2906,  0.4154, -0.7510],
        [ 0.4440, -0.3239