In [1]:
import torch
from torch import nn
from lora_modules import *

# Testing LoRALinear

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Seed the global generator so the random input below -- and any later random
# initialisation that draws from the global generator -- is the same on every
# Restart & Run All, keeping the saved outputs reproducible.
torch.manual_seed(0)

# Input for the Linear tests: 4 rows of 16 features (matches in_features=16 below).
x = torch.randn(size=(4, 16)).to(device)

In [4]:
# Hyperparameters of the reference nn.Linear layer.
# in_features equals the size of the last dimension of x, which is (4, 16).
in_features = 16
out_features = 32
bias = True

In [5]:
# Reference layer, placed on the selected device. Module.to returns the module
# itself, so chaining it keeps `linear` bound to the same object.
linear = nn.Linear(in_features=in_features, out_features=out_features, bias=bias).to(device)
print(linear)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_linear = sum(
    p.numel() for p in linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Linear: {num_trainable_parameters_in_linear}/{num_non_trainable_parameters_in_linear}')

Linear(in_features=16, out_features=32, bias=True)
Num trainable/non-trainable parameters in Linear: 544/0


In [6]:
# Baseline forward pass through the plain nn.Linear; the LoRA cells further down
# are compared against this output.
linear(x)

tensor([[ 0.3687,  0.0468, -0.3190,  0.1016, -0.4830,  0.8460,  0.7054,  1.0520,
         -0.4528, -0.4620, -0.2957,  0.0573, -0.5145, -0.4268,  0.4509, -0.3976,
         -0.1905,  0.0845, -0.6936,  0.4584,  0.7937, -0.2834,  0.0443,  0.1912,
         -0.4124,  1.0536, -0.6583,  1.4021,  1.5621, -0.0406,  0.2510, -0.4904],
        [ 0.7146, -1.1288,  0.0562, -0.1871,  0.5941, -0.3572,  0.3720,  0.9657,
         -0.0784, -0.0709, -0.2372,  0.3658, -0.2015, -0.1831,  0.1439, -0.2060,
         -0.7583, -0.6196,  0.2744,  0.1992, -0.1116, -0.4199, -0.0618,  0.0983,
         -0.7672,  0.9925, -0.4078, -0.1319,  0.2812,  0.5623, -0.1153, -0.2775],
        [-0.5062, -0.3455,  0.3720, -0.5979,  0.1882, -0.7664,  0.5001,  0.6932,
          0.6571,  0.9728,  0.0153,  0.6498,  0.2768,  0.1113,  0.4866, -0.2014,
          0.0801, -0.7385,  0.4667,  0.2217,  0.2173,  0.1230, -0.2807,  0.7251,
         -0.1757,  0.9174, -0.4382,  0.2998,  0.1583,  0.0766,  0.2383,  0.0841],
        [ 0.9465,  0.1658

In [7]:
# Inspect the tensors stored in the base layer's state dict.
linear.state_dict()

OrderedDict([('weight',
              tensor([[-0.1646,  0.2493, -0.1254,  0.0281,  0.0482,  0.1962,  0.1618,  0.2280,
                       -0.0525,  0.1150,  0.2007, -0.0625,  0.1419, -0.0669,  0.0110, -0.0775],
                      [ 0.0223,  0.1155, -0.1460, -0.0696,  0.0191, -0.1655, -0.1386, -0.1770,
                       -0.1754,  0.0947, -0.1609,  0.2419, -0.0366,  0.2174,  0.1259,  0.0252],
                      [-0.0453, -0.1989, -0.0440, -0.0036, -0.0248,  0.1939,  0.1785,  0.0125,
                        0.1256, -0.0822, -0.0768, -0.0134, -0.0790,  0.2028,  0.1096,  0.0386],
                      [-0.0928,  0.0819, -0.0700, -0.2032,  0.0685,  0.0217, -0.0809, -0.0657,
                       -0.1895, -0.0141, -0.0028, -0.0163,  0.1747, -0.0553, -0.0435,  0.2114],
                      [ 0.0829,  0.1678, -0.1398,  0.0765, -0.0837, -0.1303,  0.1041,  0.2073,
                        0.2070, -0.2315,  0.1979,  0.0810,  0.0383, -0.1006,  0.2333,  0.0735],
                     

In [8]:
# Adapter settings passed to LoRALinear in the next cell.
# The module repr printed there shows the scale as "α=2.0/r=4".
lora_config = {
    'rank': 4,  # adapter rank (reported as r=4 / rank=4 in the repr)
    'alpha': 2,  # reported as α=2.0 in the repr
    'delta_bias': False  # NOTE(review): presumably turns off a learned bias offset -- confirm in lora_modules
}

In [9]:
# Wrap the reference layer with a LoRA adapter configured by lora_config.
lora_linear = LoRALinear(linear, lora_config)
print(lora_linear)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_lora_linear = sum(
    p.numel() for p in lora_linear.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_lora_linear = sum(
    p.numel() for p in lora_linear.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in LoRA Linear: {num_trainable_parameters_in_lora_linear}/{num_non_trainable_parameters_in_lora_linear}')

LoRALinear(Linear(in_features=16, out_features=32, bias=True) + ((α=2.0/r=4) × Adapter(in_features=16, rank=4, out_features=32)))
Num trainable/non-trainable parameters in LoRA Linear: 192/544


In [10]:
# With the adapter switched off, the wrapper is expected to reproduce the base
# layer's output.
lora_linear.disable_adapter()
lora_linear_disabled_out = lora_linear(x)

# Check the equivalence numerically instead of comparing printed tensors by eye.
# NOTE(review): assumes wrapping left `linear`'s own weights untouched -- confirm in lora_modules.
torch.testing.assert_close(lora_linear_disabled_out, linear(x), rtol=1e-4, atol=1e-5)

lora_linear_disabled_out

tensor([[ 0.3687,  0.0468, -0.3190,  0.1016, -0.4830,  0.8460,  0.7054,  1.0520,
         -0.4528, -0.4620, -0.2957,  0.0573, -0.5145, -0.4268,  0.4509, -0.3976,
         -0.1905,  0.0845, -0.6936,  0.4584,  0.7937, -0.2834,  0.0443,  0.1912,
         -0.4124,  1.0536, -0.6583,  1.4021,  1.5621, -0.0406,  0.2510, -0.4904],
        [ 0.7146, -1.1288,  0.0562, -0.1871,  0.5941, -0.3572,  0.3720,  0.9657,
         -0.0784, -0.0709, -0.2372,  0.3658, -0.2015, -0.1831,  0.1439, -0.2060,
         -0.7583, -0.6196,  0.2744,  0.1992, -0.1116, -0.4199, -0.0618,  0.0983,
         -0.7672,  0.9925, -0.4078, -0.1319,  0.2812,  0.5623, -0.1153, -0.2775],
        [-0.5062, -0.3455,  0.3720, -0.5979,  0.1882, -0.7664,  0.5001,  0.6932,
          0.6571,  0.9728,  0.0153,  0.6498,  0.2768,  0.1113,  0.4866, -0.2014,
          0.0801, -0.7385,  0.4667,  0.2217,  0.2173,  0.1230, -0.2807,  0.7251,
         -0.1757,  0.9174, -0.4382,  0.2998,  0.1583,  0.0766,  0.2383,  0.0841],
        [ 0.9465,  0.1658

In [11]:
# Switch the adapter back on; the output shown below differs from the base
# layer's output in the earlier cells.
lora_linear.enable_adapter()
lora_linear(x)

tensor([[ 5.1307e-01, -1.8497e-01, -1.3637e-01, -1.2063e-01, -6.2212e-01,
          8.3575e-01,  5.4737e-01,  1.1717e+00, -4.2806e-01, -6.4353e-01,
         -5.5168e-01, -3.9587e-02, -6.1672e-01, -3.7094e-01,  3.7082e-01,
         -5.1223e-01, -3.9828e-03, -3.6075e-02, -6.3787e-01,  5.6137e-01,
          6.4046e-01, -1.2540e-01,  3.1898e-01, -6.2173e-02, -5.8526e-01,
          1.3357e+00, -4.4179e-01,  1.4343e+00,  1.5591e+00,  8.6541e-02,
          1.8251e-01, -6.9897e-01],
        [ 8.6375e-01, -1.3514e+00,  9.2576e-02, -4.7857e-01,  3.0491e-01,
         -3.1552e-01,  3.6685e-01,  8.7124e-01, -1.3686e-02, -2.1035e-01,
         -4.7755e-01,  2.9794e-01, -3.5352e-01, -2.1723e-01,  2.4069e-01,
         -3.1802e-01, -6.9046e-01, -6.0295e-01,  3.9239e-01,  1.4604e-01,
         -6.8641e-02, -2.5012e-01,  1.3297e-01, -2.4460e-02, -6.3584e-01,
          9.7395e-01, -4.8589e-01,  2.2571e-02,  3.0717e-01,  6.3215e-01,
         -4.0209e-01, -4.3993e-01],
        [-4.4214e-01, -4.6871e-01,  4.92

In [12]:
# Reference: the wrapper's output in its current state. The adapter was enabled
# by the preceding cell, so this is the adapter-on forward pass.
lora_linear_adapted_out = lora_linear(x)

# Obtain the merged module (presumably a plain layer with the adapter folded
# into its weights -- confirm in lora_modules) and run the same input through it.
merged_linear = lora_linear.get_merged_module()
merged_linear_out = merged_linear(x)

# The merged module should match the adapter-enabled wrapper up to float32
# rounding; assert it rather than comparing printed tensors by eye.
torch.testing.assert_close(merged_linear_out, lora_linear_adapted_out, rtol=1e-4, atol=1e-5)

merged_linear_out

tensor([[ 5.1307e-01, -1.8497e-01, -1.3637e-01, -1.2063e-01, -6.2212e-01,
          8.3575e-01,  5.4737e-01,  1.1717e+00, -4.2806e-01, -6.4353e-01,
         -5.5168e-01, -3.9587e-02, -6.1672e-01, -3.7094e-01,  3.7082e-01,
         -5.1223e-01, -3.9828e-03, -3.6075e-02, -6.3787e-01,  5.6137e-01,
          6.4046e-01, -1.2540e-01,  3.1898e-01, -6.2173e-02, -5.8526e-01,
          1.3357e+00, -4.4179e-01,  1.4343e+00,  1.5591e+00,  8.6541e-02,
          1.8251e-01, -6.9897e-01],
        [ 8.6375e-01, -1.3514e+00,  9.2576e-02, -4.7857e-01,  3.0491e-01,
         -3.1552e-01,  3.6685e-01,  8.7124e-01, -1.3686e-02, -2.1035e-01,
         -4.7755e-01,  2.9794e-01, -3.5352e-01, -2.1723e-01,  2.4069e-01,
         -3.1802e-01, -6.9046e-01, -6.0295e-01,  3.9239e-01,  1.4604e-01,
         -6.8641e-02, -2.5012e-01,  1.3297e-01, -2.4460e-02, -6.3584e-01,
          9.7395e-01, -4.8589e-01,  2.2571e-02,  3.0717e-01,  6.3215e-01,
         -4.0209e-01, -4.3993e-01],
        [-4.4214e-01, -4.6871e-01,  4.92

In [13]:
# After producing the merged module, disabling the adapter should still bring
# the wrapper back to the base layer's behaviour.
lora_linear.disable_adapter()
lora_linear_restored_out = lora_linear(x)

# Check numerically that merging did not alter what the disabled wrapper computes.
# NOTE(review): assumes `linear`'s own weights are never modified by the wrapper -- confirm in lora_modules.
torch.testing.assert_close(lora_linear_restored_out, linear(x), rtol=1e-4, atol=1e-5)

lora_linear_restored_out

tensor([[ 0.3687,  0.0468, -0.3190,  0.1016, -0.4830,  0.8460,  0.7054,  1.0520,
         -0.4528, -0.4620, -0.2957,  0.0573, -0.5145, -0.4268,  0.4509, -0.3976,
         -0.1905,  0.0845, -0.6936,  0.4584,  0.7937, -0.2834,  0.0443,  0.1912,
         -0.4124,  1.0536, -0.6583,  1.4021,  1.5621, -0.0406,  0.2510, -0.4904],
        [ 0.7146, -1.1288,  0.0562, -0.1871,  0.5941, -0.3572,  0.3720,  0.9657,
         -0.0784, -0.0709, -0.2372,  0.3658, -0.2015, -0.1831,  0.1439, -0.2060,
         -0.7583, -0.6196,  0.2744,  0.1992, -0.1116, -0.4199, -0.0618,  0.0983,
         -0.7672,  0.9925, -0.4078, -0.1319,  0.2812,  0.5623, -0.1153, -0.2775],
        [-0.5062, -0.3455,  0.3720, -0.5979,  0.1882, -0.7664,  0.5001,  0.6932,
          0.6571,  0.9728,  0.0153,  0.6498,  0.2768,  0.1113,  0.4866, -0.2014,
          0.0801, -0.7385,  0.4667,  0.2217,  0.2173,  0.1230, -0.2807,  0.7251,
         -0.1757,  0.9174, -0.4382,  0.2998,  0.1583,  0.0766,  0.2383,  0.0841],
        [ 0.9465,  0.1658

In [14]:
# Inspect the wrapper's state dict; the output below starts with the adapter
# tensor 'delta_weight_A'.
lora_linear.state_dict()

OrderedDict([('delta_weight_A',
              tensor([[ 0.1729,  0.3651, -0.0797,  0.1531],
                      [ 0.4142, -0.1591,  0.4630, -0.4466],
                      [-0.4248,  0.1843, -0.0393,  0.3035],
                      [-0.1716, -0.2575,  0.3485, -0.3803],
                      [-0.2672, -0.4598,  0.4146, -0.1257],
                      [ 0.4289, -0.2600,  0.0945,  0.0851],
                      [ 0.2880, -0.3184, -0.0835, -0.1873],
                      [-0.3545,  0.0724,  0.2528,  0.1996],
                      [-0.2943,  0.4123, -0.2847, -0.1053],
                      [ 0.1714, -0.3297,  0.1813, -0.2507],
                      [-0.4595, -0.3970,  0.0005, -0.3688],
                      [-0.3821,  0.4191, -0.1602, -0.3718],
                      [-0.2761,  0.1598,  0.1103, -0.2908],
                      [ 0.3839,  0.4383,  0.3949, -0.0951],
                      [-0.1905,  0.2758, -0.4814, -0.2520],
                      [-0.4300,  0.0819, -0.1018, -0.2698],
        

# Testing LoRAConv1d

In [15]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Same expression as in the LoRALinear section above.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [16]:
# Seed the global generator so the random input below -- and any later random
# initialisation that draws from the global generator -- is the same on every
# Restart & Run All, keeping the saved outputs reproducible.
torch.manual_seed(0)

# Input for the Conv1d tests: shape (2, 256), i.e. 2 channels (matches
# in_channels=2 below) by 256 positions, with no batch dimension.
x = torch.randn(size=(2, 256)).to(device)

In [17]:
# Hyperparameters of the reference nn.Conv1d layer.
# in_channels equals the leading dimension of x, which is (2, 256).
# kernel_size=5 with padding=2, stride=1 and dilation=1 keeps the output length
# equal to the input length.
in_channels = 2
out_channels = 8
kernel_size = 5
stride = 1
padding = 2
dilation = 1
bias = True
padding_mode = 'zeros'

In [18]:
# Reference layer, placed on the selected device. Module.to returns the module
# itself, so chaining it keeps `conv1d` bound to the same object.
conv1d = nn.Conv1d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    dilation=dilation,
    bias=bias,
    padding_mode=padding_mode,
).to(device)
print(conv1d)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_conv1d = sum(
    p.numel() for p in conv1d.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_conv1d = sum(
    p.numel() for p in conv1d.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Conv1d: {num_trainable_parameters_in_conv1d}/{num_non_trainable_parameters_in_conv1d}')

Conv1d(2, 8, kernel_size=(5,), stride=(1,), padding=(2,))
Num trainable/non-trainable parameters in Conv1d: 88/0


In [19]:
# Baseline forward pass through the plain nn.Conv1d; the LoRA cells further down
# are compared against this output.
conv1d(x)

tensor([[ 0.6512, -0.6503, -0.8993,  ...,  0.3132, -0.7204, -0.3859],
        [ 0.1260,  0.5410, -0.3867,  ...,  0.4675,  0.3979, -0.2770],
        [-0.5936,  0.6184, -0.1523,  ..., -1.4040,  0.1440, -0.1381],
        ...,
        [ 1.4024,  1.1076,  0.2687,  ..., -0.3982,  0.4824, -0.1141],
        [-1.0153, -0.6430,  0.5092,  ...,  0.6282, -0.4158,  0.5815],
        [-1.2079, -0.3756, -1.0491,  ..., -0.3632,  0.2821,  0.6019]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)

In [20]:
# Adapter settings passed to LoRAConv1d in the next cell.
# The module repr printed there shows "(α=2.0/r=1) × Adapter(...)" and
# "(1.5 × delta_bias)".
lora_config = {
    'alpha': 2,  # reported as α=2.0 in the repr
    'rank': 1,  # adapter rank (reported as r=1 / rank=1 in the repr)
    'delta_bias': True,  # NOTE(review): presumably adds a learned bias offset -- confirm in lora_modules
    'beta': 1.5,  # appears in the repr as the factor multiplying delta_bias
}

In [21]:
# Wrap the reference layer with a LoRA adapter configured by lora_config.
lora_conv1d = LoRAConv1d(conv1d, lora_config)
print(lora_conv1d)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_lora_conv1d = sum(
    p.numel() for p in lora_conv1d.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_lora_conv1d = sum(
    p.numel() for p in lora_conv1d.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in LoRA Conv1d: {num_trainable_parameters_in_lora_conv1d}/{num_non_trainable_parameters_in_lora_conv1d}')

LoRAConv1d(Conv1d(2, 8, kernel_size=(5,), stride=(1,), padding=(2,)) + ((α=2.0/r=1) × Adapter(in_channels=2, rank=1, out_features=8))) + (1.5 × delta_bias)
Num trainable/non-trainable parameters in LoRA Conv1d: 58/88


In [22]:
# With the adapter switched off, the wrapper is expected to reproduce the base
# layer's output.
lora_conv1d.disable_adapter()
lora_conv1d_disabled_out = lora_conv1d(x)

# Check the equivalence numerically instead of comparing printed tensors by eye.
# NOTE(review): assumes wrapping left `conv1d`'s own weights untouched -- confirm in lora_modules.
torch.testing.assert_close(lora_conv1d_disabled_out, conv1d(x), rtol=1e-4, atol=1e-5)

lora_conv1d_disabled_out

tensor([[ 0.6512, -0.6503, -0.8993,  ...,  0.3132, -0.7204, -0.3859],
        [ 0.1260,  0.5410, -0.3867,  ...,  0.4675,  0.3979, -0.2770],
        [-0.5936,  0.6184, -0.1523,  ..., -1.4040,  0.1440, -0.1381],
        ...,
        [ 1.4024,  1.1076,  0.2687,  ..., -0.3982,  0.4824, -0.1141],
        [-1.0153, -0.6430,  0.5092,  ...,  0.6282, -0.4158,  0.5815],
        [-1.2079, -0.3756, -1.0491,  ..., -0.3632,  0.2821,  0.6019]],
       device='cuda:0')

In [23]:
# Switch the adapter back on; the output shown below differs from the base
# layer's output in the earlier cells.
lora_conv1d.enable_adapter()
lora_conv1d(x)

tensor([[ 1.0845, -0.9089, -2.0287,  ...,  0.5806, -0.6369, -0.6526],
        [ 0.2594,  0.0869, -1.4050,  ...,  0.1325,  0.0195, -0.7689],
        [-0.6942,  0.2323, -1.0083,  ..., -1.7118, -0.1696, -0.0659],
        ...,
        [ 1.1910,  0.4724, -0.3036,  ..., -0.7932, -0.0737, -1.1429],
        [-1.3871, -0.9876, -0.3168,  ...,  0.4165, -0.5735,  0.7184],
        [-1.6186, -0.7638, -2.1851,  ..., -0.0695,  0.2602,  0.4420]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)

In [24]:
# Reference: the wrapper's output in its current state. The adapter was enabled
# by the preceding cell, so this is the adapter-on forward pass.
lora_conv1d_adapted_out = lora_conv1d(x)

# Obtain the merged module (presumably a plain layer with the adapter folded
# into its weights -- confirm in lora_modules) and run the same input through it.
merged_conv1d = lora_conv1d.get_merged_module()
merged_conv1d_out = merged_conv1d(x)

# The merged module should match the adapter-enabled wrapper up to float32
# rounding; assert it rather than comparing printed tensors by eye.
torch.testing.assert_close(merged_conv1d_out, lora_conv1d_adapted_out, rtol=1e-4, atol=1e-5)

merged_conv1d_out

tensor([[ 1.0845, -0.9089, -2.0287,  ...,  0.5806, -0.6369, -0.6526],
        [ 0.2594,  0.0869, -1.4050,  ...,  0.1325,  0.0195, -0.7689],
        [-0.6942,  0.2323, -1.0083,  ..., -1.7118, -0.1696, -0.0659],
        ...,
        [ 1.1910,  0.4724, -0.3036,  ..., -0.7932, -0.0737, -1.1429],
        [-1.3871, -0.9876, -0.3168,  ...,  0.4165, -0.5735,  0.7184],
        [-1.6186, -0.7638, -2.1851,  ..., -0.0695,  0.2602,  0.4420]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)

In [25]:
# After producing the merged module, disabling the adapter should still bring
# the wrapper back to the base layer's behaviour.
lora_conv1d.disable_adapter()
lora_conv1d_restored_out = lora_conv1d(x)

# Check numerically that merging did not alter what the disabled wrapper computes.
# NOTE(review): assumes `conv1d`'s own weights are never modified by the wrapper -- confirm in lora_modules.
torch.testing.assert_close(lora_conv1d_restored_out, conv1d(x), rtol=1e-4, atol=1e-5)

lora_conv1d_restored_out

tensor([[ 0.6512, -0.6503, -0.8993,  ...,  0.3132, -0.7204, -0.3859],
        [ 0.1260,  0.5410, -0.3867,  ...,  0.4675,  0.3979, -0.2770],
        [-0.5936,  0.6184, -0.1523,  ..., -1.4040,  0.1440, -0.1381],
        ...,
        [ 1.4024,  1.1076,  0.2687,  ..., -0.3982,  0.4824, -0.1141],
        [-1.0153, -0.6430,  0.5092,  ...,  0.6282, -0.4158,  0.5815],
        [-1.2079, -0.3756, -1.0491,  ..., -0.3632,  0.2821,  0.6019]],
       device='cuda:0')

# Testing LoRAConv2d

In [26]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Same expression as in the earlier sections.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [27]:
# Seed the global generator so the random input below -- and any later random
# initialisation that draws from the global generator -- is the same on every
# Restart & Run All, keeping the saved outputs reproducible.
torch.manual_seed(0)

# Input for the Conv2d tests: shape (3, 8, 8), i.e. 3 channels (matches
# in_channels=3 below) over an 8x8 grid, with no batch dimension.
x = torch.randn(size=(3, 8, 8)).to(device)

In [28]:
# Hyperparameters of the reference nn.Conv2d layer.
# in_channels equals the leading dimension of x, which is (3, 8, 8).
# kernel_size=5 with padding=2, stride=1 and dilation=1 keeps the spatial size
# of the output equal to that of the input.
in_channels = 3
out_channels = 2
kernel_size = 5
stride = 1
padding = 2
dilation = 1
bias = True
padding_mode = 'zeros'

In [29]:
# Reference layer, placed on the selected device. Module.to returns the module
# itself, so chaining it keeps `conv2d` bound to the same object.
conv2d = nn.Conv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    dilation=dilation,
    bias=bias,
    padding_mode=padding_mode,
).to(device)
print(conv2d)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_conv2d = sum(
    p.numel() for p in conv2d.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_conv2d = sum(
    p.numel() for p in conv2d.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Conv2d: {num_trainable_parameters_in_conv2d}/{num_non_trainable_parameters_in_conv2d}')

Conv2d(3, 2, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
Num trainable/non-trainable parameters in Conv2d: 152/0


In [30]:
# Baseline forward pass through the plain nn.Conv2d; the LoRA cells further down
# are compared against this output.
conv2d(x)

tensor([[[ 1.9037e-01,  1.4457e-02, -4.9597e-01, -3.6650e-01,  1.5104e-01,
          -8.7106e-03,  2.3315e-01,  2.8363e-02],
         [ 3.8243e-02, -7.0938e-01, -1.9897e-01,  5.3477e-01,  5.7587e-02,
           3.0670e-01, -1.7395e-01, -1.8100e-01],
         [ 3.3196e-01, -4.1181e-01,  1.1878e-01,  3.8952e-01,  8.1037e-01,
          -6.3297e-01, -7.5928e-01,  3.3828e-01],
         [ 2.1258e-02, -6.1072e-01,  6.7721e-01,  5.3313e-01, -7.1713e-01,
           8.9155e-02,  4.7531e-01,  2.0686e-01],
         [-3.0157e-01, -2.4520e-01,  8.3224e-01, -9.2104e-01, -6.7133e-01,
          -4.5130e-01,  5.1400e-01, -5.6369e-01],
         [-1.9601e-01, -1.0025e+00,  2.2458e-01, -3.4213e-01,  1.9763e-01,
          -4.4656e-01,  4.4935e-01,  9.7178e-01],
         [-3.4913e-01, -1.2886e-01,  2.3596e-01,  1.4481e-01,  6.2148e-01,
           1.0468e-01, -5.2466e-01, -2.4996e-01],
         [-4.7445e-01,  4.3592e-01,  4.9381e-01, -1.5376e-01,  2.2004e-01,
           6.7969e-01,  3.2152e-01,  2.5814e-01]],

In [31]:
# Adapter settings passed to LoRAConv2d in the next cell.
# The module repr printed there shows "(α=2.0/r=1) × Adapter(kH=5, rank=1, kW=5)"
# and "(2 × delta_bias)".
lora_config = {
    'alpha': 2,  # reported as α=2.0 in the repr
    'rank': 1,  # adapter rank (reported as r=1 / rank=1 in the repr)
    'rank_for': 'kernel',  # NOTE(review): presumably selects a low-rank factorisation over the kH x kW kernel dims -- confirm in lora_modules
    'delta_bias': True,  # NOTE(review): presumably adds a learned bias offset -- confirm in lora_modules
    'beta': 2  # appears in the repr as the factor multiplying delta_bias
}

In [32]:
# Wrap the reference layer with a LoRA adapter configured by lora_config.
lora_conv2d = LoRAConv2d(conv2d, lora_config)
print(lora_conv2d)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_lora_conv2d = sum(
    p.numel() for p in lora_conv2d.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_lora_conv2d = sum(
    p.numel() for p in lora_conv2d.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in LoRA Conv2d: {num_trainable_parameters_in_lora_conv2d}/{num_non_trainable_parameters_in_lora_conv2d}')

LoRAConv2d(Conv2d(3, 2, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) + ((α=2.0/r=1) × Adapter(kH=5, rank=1, kW=5))) + (2 × delta_bias)
Num trainable/non-trainable parameters in LoRA Conv2d: 62/152


In [33]:
# With the adapter switched off, the wrapper is expected to reproduce the base
# layer's output.
lora_conv2d.disable_adapter()
lora_conv2d_disabled_out = lora_conv2d(x)

# Check the equivalence numerically instead of comparing printed tensors by eye.
# NOTE(review): assumes wrapping left `conv2d`'s own weights untouched -- confirm in lora_modules.
torch.testing.assert_close(lora_conv2d_disabled_out, conv2d(x), rtol=1e-4, atol=1e-5)

lora_conv2d_disabled_out

tensor([[[ 1.9037e-01,  1.4457e-02, -4.9597e-01, -3.6650e-01,  1.5104e-01,
          -8.7106e-03,  2.3315e-01,  2.8363e-02],
         [ 3.8243e-02, -7.0938e-01, -1.9897e-01,  5.3477e-01,  5.7587e-02,
           3.0670e-01, -1.7395e-01, -1.8100e-01],
         [ 3.3196e-01, -4.1181e-01,  1.1878e-01,  3.8952e-01,  8.1037e-01,
          -6.3297e-01, -7.5928e-01,  3.3828e-01],
         [ 2.1258e-02, -6.1072e-01,  6.7721e-01,  5.3313e-01, -7.1713e-01,
           8.9155e-02,  4.7531e-01,  2.0686e-01],
         [-3.0157e-01, -2.4520e-01,  8.3224e-01, -9.2104e-01, -6.7133e-01,
          -4.5130e-01,  5.1400e-01, -5.6369e-01],
         [-1.9601e-01, -1.0025e+00,  2.2458e-01, -3.4213e-01,  1.9763e-01,
          -4.4656e-01,  4.4935e-01,  9.7178e-01],
         [-3.4913e-01, -1.2886e-01,  2.3596e-01,  1.4481e-01,  6.2148e-01,
           1.0468e-01, -5.2466e-01, -2.4996e-01],
         [-4.7445e-01,  4.3592e-01,  4.9381e-01, -1.5376e-01,  2.2004e-01,
           6.7969e-01,  3.2152e-01,  2.5814e-01]],

In [34]:
# Switch the adapter back on; the output shown below differs from the base
# layer's output in the earlier cells.
lora_conv2d.enable_adapter()
lora_conv2d(x)

tensor([[[-2.2451e-01,  2.6450e-01, -3.1098e-01,  5.2292e-04, -2.0205e-01,
          -1.4503e-01,  4.6376e-01,  1.7992e-01],
         [ 2.3258e-01, -4.6014e-01, -7.5545e-02,  6.6362e-01, -2.3476e-01,
           6.1440e-01,  3.2513e-01, -2.9521e-01],
         [ 5.0405e-01, -1.6319e-01,  1.0090e-02,  4.7648e-01,  1.2719e+00,
          -2.4622e-01, -9.0399e-01,  5.2989e-01],
         [-5.9454e-02, -6.2730e-01,  4.9892e-01,  7.7627e-01, -4.5845e-01,
           5.2210e-01, -1.7092e-01,  7.4754e-01],
         [-5.9252e-01, -1.8964e-01,  9.5407e-01, -7.4176e-01, -1.0738e+00,
          -5.2297e-01,  8.4761e-01, -8.0150e-01],
         [-5.9744e-02, -7.8112e-01,  4.7483e-01, -5.1805e-01,  5.3829e-01,
          -8.9158e-01,  9.4799e-01,  9.2014e-01],
         [-1.3348e-01,  9.1001e-02,  1.6131e-01,  2.2557e-01,  9.7555e-01,
           4.7241e-01, -5.1427e-01,  3.3591e-02],
         [-3.5757e-01,  5.5723e-01,  4.5122e-01,  3.2969e-01,  9.0439e-02,
           9.1061e-01,  3.3774e-01,  3.7299e-01]],

In [35]:
# Reference: the wrapper's output in its current state. The adapter was enabled
# by the preceding cell, so this is the adapter-on forward pass.
lora_conv2d_adapted_out = lora_conv2d(x)

# Obtain the merged module (presumably a plain layer with the adapter folded
# into its weights -- confirm in lora_modules) and run the same input through it.
merged_conv2d = lora_conv2d.get_merged_module()
merged_conv2d_out = merged_conv2d(x)

# The merged module should match the adapter-enabled wrapper up to float32
# rounding; assert it rather than comparing printed tensors by eye.
torch.testing.assert_close(merged_conv2d_out, lora_conv2d_adapted_out, rtol=1e-4, atol=1e-5)

merged_conv2d_out

tensor([[[-2.2451e-01,  2.6450e-01, -3.1098e-01,  5.2292e-04, -2.0205e-01,
          -1.4503e-01,  4.6376e-01,  1.7992e-01],
         [ 2.3258e-01, -4.6014e-01, -7.5545e-02,  6.6362e-01, -2.3476e-01,
           6.1440e-01,  3.2513e-01, -2.9521e-01],
         [ 5.0405e-01, -1.6319e-01,  1.0090e-02,  4.7648e-01,  1.2719e+00,
          -2.4622e-01, -9.0399e-01,  5.2989e-01],
         [-5.9454e-02, -6.2730e-01,  4.9892e-01,  7.7627e-01, -4.5845e-01,
           5.2210e-01, -1.7092e-01,  7.4754e-01],
         [-5.9252e-01, -1.8964e-01,  9.5407e-01, -7.4176e-01, -1.0738e+00,
          -5.2297e-01,  8.4761e-01, -8.0150e-01],
         [-5.9744e-02, -7.8112e-01,  4.7483e-01, -5.1805e-01,  5.3829e-01,
          -8.9158e-01,  9.4799e-01,  9.2014e-01],
         [-1.3348e-01,  9.1001e-02,  1.6131e-01,  2.2557e-01,  9.7555e-01,
           4.7241e-01, -5.1427e-01,  3.3591e-02],
         [-3.5757e-01,  5.5723e-01,  4.5122e-01,  3.2969e-01,  9.0439e-02,
           9.1061e-01,  3.3774e-01,  3.7299e-01]],

In [36]:
# After producing the merged module, disabling the adapter should still bring
# the wrapper back to the base layer's behaviour.
lora_conv2d.disable_adapter()
lora_conv2d_restored_out = lora_conv2d(x)

# Check numerically that merging did not alter what the disabled wrapper computes.
# NOTE(review): assumes `conv2d`'s own weights are never modified by the wrapper -- confirm in lora_modules.
torch.testing.assert_close(lora_conv2d_restored_out, conv2d(x), rtol=1e-4, atol=1e-5)

lora_conv2d_restored_out

tensor([[[ 1.9037e-01,  1.4457e-02, -4.9597e-01, -3.6650e-01,  1.5104e-01,
          -8.7106e-03,  2.3315e-01,  2.8363e-02],
         [ 3.8243e-02, -7.0938e-01, -1.9897e-01,  5.3477e-01,  5.7587e-02,
           3.0670e-01, -1.7395e-01, -1.8100e-01],
         [ 3.3196e-01, -4.1181e-01,  1.1878e-01,  3.8952e-01,  8.1037e-01,
          -6.3297e-01, -7.5928e-01,  3.3828e-01],
         [ 2.1258e-02, -6.1072e-01,  6.7721e-01,  5.3313e-01, -7.1713e-01,
           8.9155e-02,  4.7531e-01,  2.0686e-01],
         [-3.0157e-01, -2.4520e-01,  8.3224e-01, -9.2104e-01, -6.7133e-01,
          -4.5130e-01,  5.1400e-01, -5.6369e-01],
         [-1.9601e-01, -1.0025e+00,  2.2458e-01, -3.4213e-01,  1.9763e-01,
          -4.4656e-01,  4.4935e-01,  9.7178e-01],
         [-3.4913e-01, -1.2886e-01,  2.3596e-01,  1.4481e-01,  6.2148e-01,
           1.0468e-01, -5.2466e-01, -2.4996e-01],
         [-4.7445e-01,  4.3592e-01,  4.9381e-01, -1.5376e-01,  2.2004e-01,
           6.7969e-01,  3.2152e-01,  2.5814e-01]],

# Testing LoRAConv3d

In [37]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Same expression as in the earlier sections.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [38]:
# Seed the global generator so the random input below -- and any later random
# initialisation that draws from the global generator -- is the same on every
# Restart & Run All, keeping the saved outputs reproducible.
torch.manual_seed(0)

# Input for the Conv3d tests: shape (3, 8, 256, 256), i.e. 3 channels (matches
# in_channels=3 below) over an 8x256x256 volume, with no batch dimension.
x = torch.randn(size=(3, 8, 256, 256)).to(device)

In [39]:
# Hyperparameters of the reference nn.Conv3d layer.
# in_channels equals the leading dimension of x, which is (3, 8, 256, 256).
# kernel_size=5 with padding=2, stride=1 and dilation=1 keeps the spatial size
# of the output equal to that of the input.
in_channels = 3
out_channels = 2
kernel_size = 5
stride = 1
padding = 2
dilation = 1
bias = True
padding_mode = 'zeros'

In [40]:
# Reference layer, placed on the selected device. Module.to returns the module
# itself, so chaining it keeps `conv3d` bound to the same object.
conv3d = nn.Conv3d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    dilation=dilation,
    bias=bias,
    padding_mode=padding_mode,
).to(device)
print(conv3d)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_conv3d = sum(
    p.numel() for p in conv3d.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_conv3d = sum(
    p.numel() for p in conv3d.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in Conv3d: {num_trainable_parameters_in_conv3d}/{num_non_trainable_parameters_in_conv3d}')

Conv3d(3, 2, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
Num trainable/non-trainable parameters in Conv3d: 752/0


In [41]:
# Baseline forward pass through the plain nn.Conv3d; the LoRA cells further down
# are compared against this output.
conv3d(x)

tensor([[[[ 1.1222e-01, -2.8749e-01,  2.0797e-01,  ..., -1.2659e-02,
            3.5934e-01,  1.1450e-01],
          [ 2.3992e-02, -6.3007e-02,  1.8843e-01,  ..., -2.3352e-01,
           -5.1714e-01, -3.1130e-02],
          [-5.8726e-01,  7.0430e-02,  1.0287e+00,  ...,  5.2699e-02,
            1.3431e-01, -4.8851e-02],
          ...,
          [-3.2887e-01,  4.5803e-01,  3.3186e-02,  ..., -7.1173e-01,
            3.4119e-01,  1.0070e-01],
          [ 3.8214e-02, -7.3057e-02,  5.6187e-01,  ..., -6.2443e-02,
            2.8987e-01,  1.7929e-01],
          [ 5.9664e-01,  5.8491e-01,  1.0336e-01,  ...,  5.6393e-01,
            5.4424e-01, -1.5466e-02]],

         [[-3.9222e-01,  2.4899e-01,  6.5139e-01,  ...,  5.4746e-02,
            2.4662e-01, -2.1444e-01],
          [ 7.8010e-01, -3.6187e-01,  2.1524e-01,  ..., -4.9827e-01,
            2.3559e-01,  2.7592e-01],
          [-6.3338e-01, -4.4967e-01, -4.1325e-01,  ...,  1.8419e-01,
            1.2774e-01, -1.6728e-01],
          ...,
     

In [42]:
# Adapter settings passed to LoRAConv3d in the next cell.
# The module repr printed there shows
# "(α=2.0/r=1) × Adapter(in_channels=3, rank=1, out_features=2)" and "(2 × delta_bias)".
# Unlike the Conv2d config above, no 'rank_for' key is given here.
lora_config = {
    'alpha': 2,  # reported as α=2.0 in the repr
    'rank': 1,  # adapter rank (reported as r=1 / rank=1 in the repr)
    'delta_bias': True,  # NOTE(review): presumably adds a learned bias offset -- confirm in lora_modules
    'beta': 2  # appears in the repr as the factor multiplying delta_bias
}

In [43]:
# Wrap the reference layer with a LoRA adapter configured by lora_config.
lora_conv3d = LoRAConv3d(conv3d, lora_config)
print(lora_conv3d)

# Count parameter elements, split by whether they require gradients.
num_trainable_parameters_in_lora_conv3d = sum(
    p.numel() for p in lora_conv3d.parameters() if p.requires_grad
)
num_non_trainable_parameters_in_lora_conv3d = sum(
    p.numel() for p in lora_conv3d.parameters() if not p.requires_grad
)
print(f'Num trainable/non-trainable parameters in LoRA Conv3d: {num_trainable_parameters_in_lora_conv3d}/{num_non_trainable_parameters_in_lora_conv3d}')

LoRAConv3d(Conv3d(3, 2, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2)) + ((α=2.0/r=1) × Adapter(in_channels=3, rank=1, out_features=2))) + (2 × delta_bias)
Num trainable/non-trainable parameters in LoRA Conv3d: 627/752


In [44]:
# With the adapter switched off, the wrapper is expected to reproduce the base
# layer's output.
lora_conv3d.disable_adapter()
lora_conv3d_disabled_out = lora_conv3d(x)

# Check the equivalence numerically instead of comparing printed tensors by eye.
# NOTE(review): assumes wrapping left `conv3d`'s own weights untouched -- confirm in lora_modules.
torch.testing.assert_close(lora_conv3d_disabled_out, conv3d(x), rtol=1e-4, atol=1e-5)

lora_conv3d_disabled_out

tensor([[[[ 1.1222e-01, -2.8749e-01,  2.0797e-01,  ..., -1.2659e-02,
            3.5934e-01,  1.1450e-01],
          [ 2.3992e-02, -6.3007e-02,  1.8843e-01,  ..., -2.3352e-01,
           -5.1714e-01, -3.1130e-02],
          [-5.8726e-01,  7.0430e-02,  1.0287e+00,  ...,  5.2699e-02,
            1.3431e-01, -4.8851e-02],
          ...,
          [-3.2887e-01,  4.5803e-01,  3.3186e-02,  ..., -7.1173e-01,
            3.4119e-01,  1.0070e-01],
          [ 3.8214e-02, -7.3057e-02,  5.6187e-01,  ..., -6.2443e-02,
            2.8987e-01,  1.7929e-01],
          [ 5.9664e-01,  5.8491e-01,  1.0336e-01,  ...,  5.6393e-01,
            5.4424e-01, -1.5466e-02]],

         [[-3.9222e-01,  2.4899e-01,  6.5139e-01,  ...,  5.4746e-02,
            2.4662e-01, -2.1444e-01],
          [ 7.8010e-01, -3.6187e-01,  2.1524e-01,  ..., -4.9827e-01,
            2.3559e-01,  2.7592e-01],
          [-6.3338e-01, -4.4967e-01, -4.1325e-01,  ...,  1.8419e-01,
            1.2774e-01, -1.6728e-01],
          ...,
     

In [45]:
# Switch the adapter back on; the output shown below differs from the base
# layer's output in the earlier cells.
lora_conv3d.enable_adapter()
lora_conv3d(x)

tensor([[[[ 6.3268e-02, -1.6584e-01,  2.0463e-01,  ..., -1.1168e-01,
            3.5591e-01,  3.1528e-01],
          [ 2.2846e-01,  1.3258e-01,  9.8698e-02,  ..., -9.4560e-02,
           -5.7136e-01,  2.0923e-01],
          [-6.3094e-01,  4.7065e-02,  1.1532e+00,  ...,  2.8823e-01,
            3.8536e-01, -5.0469e-02],
          ...,
          [-2.2882e-01,  5.3669e-01, -8.5172e-03,  ..., -8.5821e-01,
            2.1856e-01,  1.5459e-01],
          [ 2.4417e-01,  7.2072e-02,  7.9518e-01,  ..., -3.9980e-02,
            2.0296e-01,  2.7116e-01],
          [ 7.5508e-01,  3.0815e-01,  1.3755e-01,  ...,  5.4818e-01,
            5.1906e-01,  1.6044e-01]],

         [[-4.1712e-01,  3.4350e-01,  7.2339e-01,  ...,  3.5224e-01,
            4.5797e-01,  5.3570e-03],
          [ 7.9017e-01, -3.8070e-01,  2.2116e-01,  ..., -2.3775e-01,
            2.1101e-01,  3.5523e-01],
          [-7.6579e-01, -2.8091e-01, -4.3878e-04,  ...,  4.5317e-01,
           -8.1324e-03, -1.9251e-01],
          ...,
     

In [46]:
# Reference: the wrapper's output in its current state. The adapter was enabled
# by the preceding cell, so this is the adapter-on forward pass.
lora_conv3d_adapted_out = lora_conv3d(x)

# Obtain the merged module (presumably a plain layer with the adapter folded
# into its weights -- confirm in lora_modules) and run the same input through it.
merged_conv3d = lora_conv3d.get_merged_module()
merged_conv3d_out = merged_conv3d(x)

# The merged module should match the adapter-enabled wrapper up to float32
# rounding; assert it rather than comparing printed tensors by eye.
torch.testing.assert_close(merged_conv3d_out, lora_conv3d_adapted_out, rtol=1e-4, atol=1e-5)

merged_conv3d_out

tensor([[[[ 6.3268e-02, -1.6584e-01,  2.0463e-01,  ..., -1.1168e-01,
            3.5591e-01,  3.1528e-01],
          [ 2.2846e-01,  1.3258e-01,  9.8698e-02,  ..., -9.4560e-02,
           -5.7136e-01,  2.0923e-01],
          [-6.3094e-01,  4.7065e-02,  1.1532e+00,  ...,  2.8823e-01,
            3.8536e-01, -5.0469e-02],
          ...,
          [-2.2882e-01,  5.3669e-01, -8.5172e-03,  ..., -8.5821e-01,
            2.1856e-01,  1.5459e-01],
          [ 2.4417e-01,  7.2072e-02,  7.9518e-01,  ..., -3.9980e-02,
            2.0296e-01,  2.7116e-01],
          [ 7.5508e-01,  3.0815e-01,  1.3755e-01,  ...,  5.4818e-01,
            5.1906e-01,  1.6044e-01]],

         [[-4.1712e-01,  3.4350e-01,  7.2339e-01,  ...,  3.5224e-01,
            4.5797e-01,  5.3570e-03],
          [ 7.9017e-01, -3.8070e-01,  2.2116e-01,  ..., -2.3775e-01,
            2.1101e-01,  3.5523e-01],
          [-7.6579e-01, -2.8091e-01, -4.3878e-04,  ...,  4.5317e-01,
           -8.1324e-03, -1.9251e-01],
          ...,
     

In [47]:
# After producing the merged module, disabling the adapter should still bring
# the wrapper back to the base layer's behaviour.
lora_conv3d.disable_adapter()
lora_conv3d_restored_out = lora_conv3d(x)

# Check numerically that merging did not alter what the disabled wrapper computes.
# NOTE(review): assumes `conv3d`'s own weights are never modified by the wrapper -- confirm in lora_modules.
torch.testing.assert_close(lora_conv3d_restored_out, conv3d(x), rtol=1e-4, atol=1e-5)

lora_conv3d_restored_out

tensor([[[[ 1.1222e-01, -2.8749e-01,  2.0797e-01,  ..., -1.2659e-02,
            3.5934e-01,  1.1450e-01],
          [ 2.3992e-02, -6.3007e-02,  1.8843e-01,  ..., -2.3352e-01,
           -5.1714e-01, -3.1130e-02],
          [-5.8726e-01,  7.0430e-02,  1.0287e+00,  ...,  5.2699e-02,
            1.3431e-01, -4.8851e-02],
          ...,
          [-3.2887e-01,  4.5803e-01,  3.3186e-02,  ..., -7.1173e-01,
            3.4119e-01,  1.0070e-01],
          [ 3.8214e-02, -7.3057e-02,  5.6187e-01,  ..., -6.2443e-02,
            2.8987e-01,  1.7929e-01],
          [ 5.9664e-01,  5.8491e-01,  1.0336e-01,  ...,  5.6393e-01,
            5.4424e-01, -1.5466e-02]],

         [[-3.9222e-01,  2.4899e-01,  6.5139e-01,  ...,  5.4746e-02,
            2.4662e-01, -2.1444e-01],
          [ 7.8010e-01, -3.6187e-01,  2.1524e-01,  ..., -4.9827e-01,
            2.3559e-01,  2.7592e-01],
          [-6.3338e-01, -4.4967e-01, -4.1325e-01,  ...,  1.8419e-01,
            1.2774e-01, -1.6728e-01],
          ...,
     