In [1]:
import torch
from torch import nn
from lora_modules import LoRALinear, LoRAConv2d

# Testing LoRALinear

In [2]:
# Seed the global RNG so the random input and the randomly initialised layer
# weights are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Random input of shape (4, 16): sampled on the CPU, then moved to the selected
# device. The last dimension matches the in_features of the layer under test.
x = torch.randn(size=(4, 16)).to(device)

In [4]:
# Constructor arguments for the reference nn.Linear layer.
in_features, out_features = 16, 32
bias = True

In [5]:
# Reference layer; a later cell wraps it in LoRALinear.
linear = nn.Linear(in_features, out_features, bias=bias).to(device)
print(linear)

# Split the parameter count by whether the parameter receives gradients.
num_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Linear: {num_trainable_parameters_in_linear}/{num_non_trainable_parameters_in_linear}')

Linear(in_features=16, out_features=32, bias=True)
Num trainable/non-trainable parameters in Linear: 544/0


In [6]:
# Baseline forward pass through the plain nn.Linear; later cells display the
# LoRA-wrapped layer's output for the same input.
linear(x)

tensor([[ 3.0121e-01,  1.7673e-01, -1.4974e-01, -4.8116e-01,  5.6039e-02,
         -5.9517e-01, -4.3947e-02, -2.5015e-01, -2.2185e-01,  7.1066e-01,
          3.5674e-01, -5.0222e-03, -4.8637e-01,  2.0448e-01, -1.6874e+00,
          1.3945e-01, -4.9349e-01,  1.1055e-01,  1.0014e+00, -2.3313e-01,
          2.9794e-01, -2.1086e-01, -1.1446e+00,  2.8450e-01,  8.1779e-01,
         -7.5265e-01, -1.7032e-01,  3.9302e-01, -8.7387e-01, -1.0879e-01,
         -1.5706e-01,  4.0187e-01],
        [ 1.0546e+00, -2.9261e-01,  3.6825e-01, -8.5120e-02, -1.1122e-01,
         -3.4779e-01,  7.6557e-01, -1.3964e-01, -1.5457e-01, -7.9979e-02,
          4.0814e-01,  3.0458e-01,  9.1026e-02, -2.6182e-01, -1.8413e-01,
         -4.0557e-01, -3.1348e-01,  4.3335e-01,  9.2016e-01,  5.8965e-02,
          2.0841e-01,  1.1663e-01, -1.2426e+00, -9.0326e-01, -4.1864e-01,
         -2.9908e-01,  2.4509e-01,  7.3977e-01, -2.0913e-01, -2.4118e-01,
         -2.9867e-02,  1.0786e-01],
        [ 4.2192e-01,  8.3268e-01, -2.59

In [7]:
# Show the base layer's stored tensors (weight, plus bias since bias=True).
linear.state_dict()

OrderedDict([('weight',
              tensor([[ 6.8529e-02,  2.4812e-01, -1.3410e-01,  5.9949e-02, -2.1223e-01,
                        9.5045e-02,  2.1307e-01,  1.9693e-01,  1.1533e-01,  3.1733e-02,
                       -8.0075e-02,  1.7078e-01, -1.0011e-01, -7.0609e-02,  3.1599e-04,
                       -8.0654e-02],
                      [ 2.3126e-01,  6.4427e-02,  1.0232e-01, -1.1834e-01, -9.6964e-02,
                        2.3165e-01,  6.9982e-02, -2.3467e-01,  1.1491e-01,  2.1404e-01,
                        2.2293e-01, -2.0243e-02, -8.7904e-02,  1.5950e-01,  1.2627e-01,
                        1.4240e-02],
                      [ 2.2045e-02, -8.6731e-02,  1.0289e-01,  1.2742e-01,  2.1366e-01,
                       -2.1565e-02,  3.8092e-02,  6.1714e-02, -1.5625e-01, -1.1964e-01,
                        2.0900e-01,  2.3788e-01,  7.8003e-02, -2.0716e-01,  2.2925e-01,
                       -2.4099e-01],
                      [-1.1507e-01,  1.6252e-01, -4.4262e-02, -8.3881e-02

In [8]:
# Adapter settings passed to LoRALinear. The module's printed repr shows the
# adapter scaled by alpha / rank; 'delta_bias' is switched off here.
lora_config = dict(
    rank=4,
    alpha=2,
    delta_bias=False,
)

In [9]:
# Wrap the reference layer with a LoRA adapter built from lora_config.
lora_linear = LoRALinear(linear, lora_config)
print(lora_linear)

# Split the parameter count by whether the parameter receives gradients.
num_trainable_parameters_in_lora_linear = sum(
    p.numel() for p in lora_linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_lora_linear = sum(
    p.numel() for p in lora_linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in LoRA Linear: {num_trainable_parameters_in_lora_linear}/{num_non_trainable_parameters_in_lora_linear}')

LoRALinear(Linear(in_features=16, out_features=32, bias=True) + ((α=2.0/r=4) × Adapter(in_features=16, rank=4, out_features=32)))
Num trainable/non-trainable parameters in LoRA Linear: 192/544


In [10]:
# With the adapter switched off, the wrapper should reproduce the plain layer.
lora_linear.disable_adapter()
lora_disabled_out = lora_linear(x)

# Verify numerically instead of comparing the printed tensors by eye.
torch.testing.assert_close(lora_disabled_out, linear(x))
lora_disabled_out

tensor([[ 3.0121e-01,  1.7673e-01, -1.4974e-01, -4.8116e-01,  5.6039e-02,
         -5.9517e-01, -4.3947e-02, -2.5015e-01, -2.2185e-01,  7.1066e-01,
          3.5674e-01, -5.0222e-03, -4.8637e-01,  2.0448e-01, -1.6874e+00,
          1.3945e-01, -4.9349e-01,  1.1055e-01,  1.0014e+00, -2.3313e-01,
          2.9794e-01, -2.1086e-01, -1.1446e+00,  2.8450e-01,  8.1779e-01,
         -7.5265e-01, -1.7032e-01,  3.9302e-01, -8.7387e-01, -1.0879e-01,
         -1.5706e-01,  4.0187e-01],
        [ 1.0546e+00, -2.9261e-01,  3.6825e-01, -8.5120e-02, -1.1122e-01,
         -3.4779e-01,  7.6557e-01, -1.3964e-01, -1.5457e-01, -7.9979e-02,
          4.0814e-01,  3.0458e-01,  9.1026e-02, -2.6182e-01, -1.8413e-01,
         -4.0557e-01, -3.1348e-01,  4.3335e-01,  9.2016e-01,  5.8965e-02,
          2.0841e-01,  1.1663e-01, -1.2426e+00, -9.0326e-01, -4.1864e-01,
         -2.9908e-01,  2.4509e-01,  7.3977e-01, -2.0913e-01, -2.4118e-01,
         -2.9867e-02,  1.0786e-01],
        [ 4.2192e-01,  8.3268e-01, -2.59

In [11]:
# Switch the adapter back on. The output shown below differs from the plain
# layer's output, i.e. the adapter contribution is non-zero here.
lora_linear.enable_adapter()
lora_linear(x)

tensor([[ 3.8348e-01,  2.7875e-01, -3.1137e-01, -3.4245e-01, -6.4704e-02,
         -4.4210e-01, -1.4597e-01, -3.5983e-01, -2.6085e-01,  6.9237e-01,
          4.5160e-01, -1.6079e-01, -2.3865e-01,  9.3919e-02, -1.5934e+00,
          1.7386e-01, -6.5451e-01,  1.7411e-04,  7.3359e-01, -2.5610e-01,
          4.1543e-01, -3.0035e-01, -1.2209e+00,  1.0434e-01,  5.2494e-01,
         -4.9435e-01, -1.8645e-01,  4.1482e-01, -8.9612e-01,  5.1659e-03,
         -1.0912e-01,  3.3511e-01],
        [ 9.3211e-01, -2.2974e-01,  2.5063e-01,  5.9250e-02, -1.5117e-01,
         -4.0332e-01,  9.8315e-01, -1.0679e-01, -2.9766e-02, -1.9202e-01,
          3.3213e-01,  4.7514e-01,  7.5761e-02, -9.3565e-02,  2.8139e-02,
         -4.3185e-01, -5.9644e-01,  6.0336e-01,  9.2772e-01,  2.8500e-01,
          3.1167e-01, -7.2709e-02, -1.2417e+00, -8.9830e-01, -6.2213e-01,
         -4.0055e-01,  3.7789e-02,  1.0335e+00, -3.7279e-01, -4.6125e-01,
         -1.1021e-01,  2.2263e-01],
        [ 3.6590e-01,  7.8862e-01, -1.39

In [12]:
# Output of the wrapper with the adapter active (it was enabled in the previous
# cell). Taken before merging so the comparison below does not depend on what
# get_merged_module() does to the wrapper itself.
lora_enabled_out = lora_linear(x)

# NOTE(review): presumably a standalone module with the adapter folded into its
# weights -- confirm against lora_modules.
merged_linear = lora_linear.get_merged_module()
merged_out = merged_linear(x)

# The merged module must match the adapter-enabled wrapper. The tolerance is
# slightly loose because the two paths may round differently in float32.
torch.testing.assert_close(merged_out, lora_enabled_out, rtol=1e-4, atol=1e-5)
merged_out

tensor([[ 3.8348e-01,  2.7875e-01, -3.1137e-01, -3.4245e-01, -6.4704e-02,
         -4.4210e-01, -1.4597e-01, -3.5983e-01, -2.6085e-01,  6.9237e-01,
          4.5160e-01, -1.6079e-01, -2.3865e-01,  9.3919e-02, -1.5934e+00,
          1.7386e-01, -6.5451e-01,  1.7411e-04,  7.3359e-01, -2.5610e-01,
          4.1543e-01, -3.0035e-01, -1.2209e+00,  1.0434e-01,  5.2494e-01,
         -4.9435e-01, -1.8645e-01,  4.1482e-01, -8.9612e-01,  5.1659e-03,
         -1.0912e-01,  3.3511e-01],
        [ 9.3211e-01, -2.2974e-01,  2.5063e-01,  5.9250e-02, -1.5117e-01,
         -4.0332e-01,  9.8315e-01, -1.0679e-01, -2.9766e-02, -1.9202e-01,
          3.3213e-01,  4.7514e-01,  7.5761e-02, -9.3565e-02,  2.8139e-02,
         -4.3185e-01, -5.9644e-01,  6.0336e-01,  9.2772e-01,  2.8500e-01,
          3.1167e-01, -7.2709e-02, -1.2417e+00, -8.9830e-01, -6.2213e-01,
         -4.0055e-01,  3.7789e-02,  1.0335e+00, -3.7279e-01, -4.6125e-01,
         -1.1021e-01,  2.2263e-01],
        [ 3.6590e-01,  7.8862e-01, -1.39

In [13]:
# After get_merged_module() has been called, disabling the adapter should
# still give back the plain layer's output.
lora_linear.disable_adapter()
lora_disabled_after_merge_out = lora_linear(x)

# Verify numerically instead of comparing the printed tensors by eye.
torch.testing.assert_close(lora_disabled_after_merge_out, linear(x))
lora_disabled_after_merge_out

tensor([[ 3.0121e-01,  1.7673e-01, -1.4974e-01, -4.8116e-01,  5.6039e-02,
         -5.9517e-01, -4.3947e-02, -2.5015e-01, -2.2185e-01,  7.1066e-01,
          3.5674e-01, -5.0222e-03, -4.8637e-01,  2.0448e-01, -1.6874e+00,
          1.3945e-01, -4.9349e-01,  1.1055e-01,  1.0014e+00, -2.3313e-01,
          2.9794e-01, -2.1086e-01, -1.1446e+00,  2.8450e-01,  8.1779e-01,
         -7.5265e-01, -1.7032e-01,  3.9302e-01, -8.7387e-01, -1.0879e-01,
         -1.5706e-01,  4.0187e-01],
        [ 1.0546e+00, -2.9261e-01,  3.6825e-01, -8.5120e-02, -1.1122e-01,
         -3.4779e-01,  7.6557e-01, -1.3964e-01, -1.5457e-01, -7.9979e-02,
          4.0814e-01,  3.0458e-01,  9.1026e-02, -2.6182e-01, -1.8413e-01,
         -4.0557e-01, -3.1348e-01,  4.3335e-01,  9.2016e-01,  5.8965e-02,
          2.0841e-01,  1.1663e-01, -1.2426e+00, -9.0326e-01, -4.1864e-01,
         -2.9908e-01,  2.4509e-01,  7.3977e-01, -2.0913e-01, -2.4118e-01,
         -2.9867e-02,  1.0786e-01],
        [ 4.2192e-01,  8.3268e-01, -2.59

In [14]:
# Show the wrapper's stored tensors; the displayed output starts with the
# adapter factor 'delta_weight_A'.
lora_linear.state_dict()

OrderedDict([('delta_weight_A',
              tensor([[-0.1459, -0.2784,  0.0151,  0.2318],
                      [-0.0873, -0.0340,  0.2679,  0.4783],
                      [ 0.3030, -0.1053, -0.0024, -0.1398],
                      [-0.3718,  0.2352, -0.2165, -0.2865],
                      [ 0.2626,  0.1223, -0.4108, -0.2716],
                      [-0.1616, -0.0715, -0.4687,  0.2970],
                      [ 0.3988,  0.4013,  0.1683,  0.2680],
                      [ 0.4024,  0.0231,  0.3496,  0.4640],
                      [ 0.2453,  0.2736, -0.1085,  0.1957],
                      [ 0.0635, -0.2199,  0.1029,  0.1791],
                      [-0.4053, -0.2016,  0.0435, -0.3164],
                      [ 0.4602,  0.4023, -0.0032,  0.0460],
                      [-0.4855, -0.1831, -0.0180,  0.2976],
                      [ 0.1561,  0.4176, -0.1439, -0.4723],
                      [-0.1228,  0.1841,  0.4708,  0.3475],
                      [-0.1662, -0.0434, -0.0778, -0.2095],
        

# Testing LoRAConv2d

In [15]:
# Same device selection as at the top of the notebook, repeated for this
# section: GPU when available, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [16]:
# Random input of shape (3, 8, 8), sampled on the CPU and then moved to the
# selected device. There is no batch dimension: the leading 3 matches the
# in_channels of the convolution under test. Overwrites the `x` used in the
# LoRALinear section.
x = torch.randn(size=(3, 8, 8)).to(device)

In [17]:
# Constructor arguments for the reference nn.Conv2d layer.
in_channels, out_channels = 3, 2
kernel_size, stride, padding, dilation = 5, 1, 2, 1
bias, padding_mode = True, 'zeros'

In [18]:
# Reference convolution; a later cell wraps it in LoRAConv2d.
conv2d = nn.Conv2d(
    in_channels,
    out_channels,
    kernel_size,
    stride=stride,
    padding=padding,
    dilation=dilation,
    bias=bias,
    padding_mode=padding_mode,
).to(device)
print(conv2d)

# Split the parameter count by whether the parameter receives gradients.
num_trainable_parameters_in_conv2d = sum(
    p.numel() for p in conv2d.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_conv2d = sum(
    p.numel() for p in conv2d.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Conv2d: {num_trainable_parameters_in_conv2d}/{num_non_trainable_parameters_in_conv2d}')

Conv2d(3, 2, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Num trainable/non-trainable parameters in Conv2d: 152/0


In [19]:
# Baseline forward pass through the plain nn.Conv2d; later cells display the
# LoRA-wrapped layer's output for the same input.
conv2d(x)

tensor([[[ 0.3699, -0.1732,  0.5671, -0.3061, -0.9183, -0.1548,  0.0539,
          -0.1086],
         [ 0.1930,  0.0248,  0.0059, -0.2635,  0.3348, -0.5729, -0.2542,
          -0.0828],
         [ 0.6446, -0.2782, -0.6072, -0.7126, -0.2808,  0.2979, -0.2630,
          -0.0157],
         [-0.6193,  0.7035, -0.5116,  0.2693, -0.0293,  0.2111, -0.8410,
           0.0143],
         [-0.6793, -1.2151,  0.9551, -0.0214, -0.5077, -0.7375,  0.2328,
          -0.4570],
         [ 0.5124, -0.0542, -0.8369,  0.4932, -0.6503, -0.2252,  0.3872,
          -0.7001],
         [-0.0476,  0.0763,  0.2960,  0.4775, -0.2568, -0.4678,  0.6129,
          -0.0215],
         [ 0.1026,  0.6305,  0.1867,  0.0358,  0.7258, -0.6614, -0.4136,
          -0.0712]],

        [[-0.2300, -0.0720, -0.5610,  0.0670, -0.7666,  0.4532, -0.3946,
          -0.0343],
         [ 0.3339, -0.0835, -0.1329,  0.0395,  0.9597, -0.4631,  0.3297,
          -0.3359],
         [-0.4838,  0.4900, -0.3297,  0.7092, -0.4341,  0.7464, -0.0

In [27]:
# Adapter settings passed to LoRAConv2d. The module's printed repr shows the
# adapter scaled by alpha / rank, a low-rank factorisation over the kernel's
# height and width, and a delta_bias term multiplied by beta.
lora_config = dict(
    alpha=2,
    rank=1,
    rank_for='kernel',
    delta_bias=True,
    beta=2,
)

In [28]:
# Wrap the reference convolution with a LoRA adapter built from lora_config.
lora_conv2d = LoRAConv2d(conv2d, lora_config)
print(lora_conv2d)

# Split the parameter count by whether the parameter receives gradients.
num_trainable_parameters_in_lora_conv2d = sum(
    p.numel() for p in lora_conv2d.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_lora_conv2d = sum(
    p.numel() for p in lora_conv2d.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in LoRA Conv2d: {num_trainable_parameters_in_lora_conv2d}/{num_non_trainable_parameters_in_lora_conv2d}')

LoRAConv2d(Conv2d(3, 2, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) + ((α=2.0/r=1) × Adapter(kernel_height=5, rank=1, kernel_width=5))) + (2 × delta_bias)
Num trainable/non-trainable parameters in LoRA Conv2d: 62/152


In [29]:
# With the adapter switched off, the wrapper should reproduce the plain
# convolution.
lora_conv2d.disable_adapter()
lora_disabled_out = lora_conv2d(x)

# Verify numerically instead of comparing the printed tensors by eye.
torch.testing.assert_close(lora_disabled_out, conv2d(x))
lora_disabled_out

tensor([[[ 0.3699, -0.1732,  0.5671, -0.3061, -0.9183, -0.1548,  0.0539,
          -0.1086],
         [ 0.1930,  0.0248,  0.0059, -0.2635,  0.3348, -0.5729, -0.2542,
          -0.0828],
         [ 0.6446, -0.2782, -0.6072, -0.7126, -0.2808,  0.2979, -0.2630,
          -0.0157],
         [-0.6193,  0.7035, -0.5116,  0.2693, -0.0293,  0.2111, -0.8410,
           0.0143],
         [-0.6793, -1.2151,  0.9551, -0.0214, -0.5077, -0.7375,  0.2328,
          -0.4570],
         [ 0.5124, -0.0542, -0.8369,  0.4932, -0.6503, -0.2252,  0.3872,
          -0.7001],
         [-0.0476,  0.0763,  0.2960,  0.4775, -0.2568, -0.4678,  0.6129,
          -0.0215],
         [ 0.1026,  0.6305,  0.1867,  0.0358,  0.7258, -0.6614, -0.4136,
          -0.0712]],

        [[-0.2300, -0.0720, -0.5610,  0.0670, -0.7666,  0.4532, -0.3946,
          -0.0343],
         [ 0.3339, -0.0835, -0.1329,  0.0395,  0.9597, -0.4631,  0.3297,
          -0.3359],
         [-0.4838,  0.4900, -0.3297,  0.7092, -0.4341,  0.7464, -0.0

In [30]:
# Switch the adapter back on. The output shown below differs from the plain
# convolution's output, i.e. the adapter contribution is non-zero here.
lora_conv2d.enable_adapter()
lora_conv2d(x)

tensor([[[ 0.4596, -0.4069,  0.5613, -0.3495, -0.7958,  0.0759,  0.2835,
          -0.1472],
         [ 0.5437, -0.0712,  0.1785, -0.5508,  0.2297, -0.6922, -0.3510,
          -0.0061],
         [ 0.7197, -0.0709, -0.5254, -1.2622,  0.9298,  0.8230, -0.2308,
           0.5121],
         [-0.2002,  0.8302,  0.0578,  0.6138, -0.1779,  0.8635, -0.9217,
           0.0604],
         [-0.3770, -0.9489,  0.6231,  0.1498, -1.1296, -0.4435,  0.3173,
          -0.3373],
         [ 0.5158,  0.3938, -0.6899,  1.0107, -0.6330,  0.0359,  0.7790,
          -1.3702],
         [-0.0343,  0.3595,  0.3138,  0.6837, -0.0986, -0.7158,  0.3846,
          -0.3951],
         [ 0.0964,  0.7790, -0.3216, -0.2878,  0.8724, -0.3148, -0.1302,
           0.2691]],

        [[ 0.0110, -0.6267, -0.3012, -0.1983, -1.6582,  0.4601, -0.3075,
          -0.1498],
         [ 0.0916, -0.1605, -0.3315, -0.1157,  1.4639, -1.0006, -0.1399,
          -0.3246],
         [-0.9052,  0.5669, -0.9034,  0.4905, -1.0355,  0.2986,  0.0

In [31]:
# Output of the wrapper with the adapter active (it was enabled in the previous
# cell). Taken before merging so the comparison below does not depend on what
# get_merged_module() does to the wrapper itself.
lora_enabled_out = lora_conv2d(x)

# NOTE(review): presumably a standalone module with the adapter (and delta
# bias) folded into its parameters -- confirm against lora_modules.
merged_conv2d = lora_conv2d.get_merged_module()
merged_out = merged_conv2d(x)

# The merged module must match the adapter-enabled wrapper. The tolerance is
# slightly loose because the two paths may round differently in float32.
torch.testing.assert_close(merged_out, lora_enabled_out, rtol=1e-4, atol=1e-5)
merged_out

tensor([[[ 0.4596, -0.4069,  0.5613, -0.3495, -0.7958,  0.0759,  0.2835,
          -0.1472],
         [ 0.5437, -0.0712,  0.1785, -0.5508,  0.2297, -0.6922, -0.3510,
          -0.0061],
         [ 0.7197, -0.0709, -0.5254, -1.2622,  0.9298,  0.8230, -0.2308,
           0.5121],
         [-0.2002,  0.8302,  0.0578,  0.6138, -0.1779,  0.8635, -0.9217,
           0.0604],
         [-0.3770, -0.9489,  0.6231,  0.1498, -1.1296, -0.4435,  0.3173,
          -0.3373],
         [ 0.5158,  0.3938, -0.6899,  1.0107, -0.6330,  0.0359,  0.7790,
          -1.3702],
         [-0.0343,  0.3595,  0.3138,  0.6837, -0.0986, -0.7158,  0.3846,
          -0.3951],
         [ 0.0964,  0.7790, -0.3216, -0.2878,  0.8724, -0.3148, -0.1302,
           0.2691]],

        [[ 0.0110, -0.6267, -0.3012, -0.1983, -1.6582,  0.4601, -0.3075,
          -0.1498],
         [ 0.0916, -0.1605, -0.3315, -0.1157,  1.4639, -1.0006, -0.1399,
          -0.3246],
         [-0.9052,  0.5669, -0.9034,  0.4905, -1.0355,  0.2986,  0.0

In [32]:
# After get_merged_module() has been called, disabling the adapter should
# still give back the plain convolution's output.
lora_conv2d.disable_adapter()
lora_disabled_after_merge_out = lora_conv2d(x)

# Verify numerically instead of comparing the printed tensors by eye.
torch.testing.assert_close(lora_disabled_after_merge_out, conv2d(x))
lora_disabled_after_merge_out

tensor([[[ 0.3699, -0.1732,  0.5671, -0.3061, -0.9183, -0.1548,  0.0539,
          -0.1086],
         [ 0.1930,  0.0248,  0.0059, -0.2635,  0.3348, -0.5729, -0.2542,
          -0.0828],
         [ 0.6446, -0.2782, -0.6072, -0.7126, -0.2808,  0.2979, -0.2630,
          -0.0157],
         [-0.6193,  0.7035, -0.5116,  0.2693, -0.0293,  0.2111, -0.8410,
           0.0143],
         [-0.6793, -1.2151,  0.9551, -0.0214, -0.5077, -0.7375,  0.2328,
          -0.4570],
         [ 0.5124, -0.0542, -0.8369,  0.4932, -0.6503, -0.2252,  0.3872,
          -0.7001],
         [-0.0476,  0.0763,  0.2960,  0.4775, -0.2568, -0.4678,  0.6129,
          -0.0215],
         [ 0.1026,  0.6305,  0.1867,  0.0358,  0.7258, -0.6614, -0.4136,
          -0.0712]],

        [[-0.2300, -0.0720, -0.5610,  0.0670, -0.7666,  0.4532, -0.3946,
          -0.0343],
         [ 0.3339, -0.0835, -0.1329,  0.0395,  0.9597, -0.4631,  0.3297,
          -0.3359],
         [-0.4838,  0.4900, -0.3297,  0.7092, -0.4341,  0.7464, -0.0