In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [12]:
class myModel(nn.Module):
    
    def __init__(self, input_dims, out_dims):
        super().__init__()
        
        self.layer1 = nn.Linear(input_dims, out_dims)
        
    def forward(self, x):
        
        x = self.layer1(x)
        
        return x
    
model = myModel(16, 128)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

data = torch.rand(16)

model(data)

tensor([-0.4178,  0.0527,  0.2860, -0.2160, -0.9628, -0.0234,  0.1412, -0.2157,
        -0.1364,  0.4179, -0.4966, -0.5282,  0.4678,  0.1319,  0.3803,  0.7978,
        -0.0436, -0.0223, -0.1968,  0.3369,  0.3349, -0.7447, -0.4242, -0.0657,
         0.0318, -0.5602, -0.0867,  0.0194, -0.4563,  0.0027, -0.1443,  0.1436,
         0.7101,  0.4038, -0.1858,  0.2627,  0.2492,  0.1637,  0.3377, -0.6230,
        -0.4291,  0.0948,  0.6636, -0.3977,  0.6548, -0.2269, -0.3907, -0.7490,
        -0.1500,  0.3006, -0.3635,  0.4097, -0.1951, -0.0927,  0.5026,  0.8984,
         0.6316,  0.0168, -0.8660,  0.2549,  0.6663,  0.5140, -0.0759, -0.3557,
        -0.5989, -0.4600, -0.3807,  0.4687,  0.4531,  0.4831, -0.8452,  0.4256,
        -0.4777,  0.4604, -0.0837,  0.3188, -0.5522, -0.7701,  0.2572, -0.1687,
         0.2905,  0.7840,  0.3558,  0.1854, -0.1331, -0.0621,  0.2936, -0.2118,
        -0.4452, -0.0603, -0.2236,  0.3912, -0.2629,  0.4622, -0.0030,  0.2306,
         0.1220, -1.0691,  0.1937, -0.33

In [10]:
class myModel(nn.Module):
    
    def __init__(self, input_dims, out_dims, rank):
        super().__init__()
        
        self.layer1 = nn.Linear(input_dims, out_dims)
        self.A = nn.Linear(input_dims, rank)
        self.B = nn.Linear(rank, out_dims)
        
    def forward(self, x):
        
        AcrossB = self.A @ self.B
        out = self.layer1 + AcrossB
        x = out(x)
        
        return x
    
model = myModel(16, 128, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

data = torch.rand(16)

model(data)

TypeError: unsupported operand type(s) for @: 'Linear' and 'Linear'

# working lora

In [15]:
import torch
import torch.nn as nn

class myModel(nn.Module):
    
    def __init__(self, input_dims, out_dims, rank):
        super().__init__()
        
        self.layer1 = nn.Linear(input_dims, out_dims)
        self.rank = rank

        # Initialize LoRA matrices A and B
        self.A = nn.Parameter(torch.randn(out_dims, rank))
        self.B = nn.Parameter(torch.randn(rank, input_dims))
        
    def forward(self, x):
        # Modify layer1 weights with LoRA adaptation
        modified_W = self.layer1.weight + self.A @ self.B
        x = torch.nn.functional.linear(x, modified_W, self.layer1.bias)
        return x

# Model and optimizer
model = myModel(16, 128, rank=5)  # Assuming rank=5 for LoRA
# optimizer = torch.optim.AdamW([model.A, model.B], lr=1e-3)  # Optimize only A and B

params_to_update = [param for name, param in model.named_parameters() if "layer1" not in name]
optimizer = torch.optim.AdamW(params_to_update, lr=1e-3)

# Sample data
data = torch.rand(1, 16)  # Adjust the shape for batch processing

# Forward pass
output = model(data)

output

tensor([[ 2.0814, -3.3569,  0.2003,  2.6699,  3.0409, -2.9248, -0.8539, -1.1681,
          5.6208,  5.4735,  3.5509, -0.7432, -3.7090, -0.2704, -2.5739,  4.2626,
          1.8779, -3.5988,  1.5255, -0.9325,  0.6537, -0.7779,  2.6920, -6.6480,
         -1.7500,  4.5094, -4.5425,  4.1396, -1.9845, -0.1102, -3.1174,  3.3400,
          2.3653,  1.5188,  1.8876, -0.5312, -5.4253,  2.0907, -1.2952, -0.4010,
          1.8288,  1.3090, -0.3476, -2.7510, -3.4891,  5.9249, -5.0030, -3.8181,
         -1.8146, -1.9316,  2.5059,  0.5311, -9.0898,  2.4114, -0.5630,  0.9388,
         -4.0803, -3.2848,  2.8547,  3.5545, -2.4992,  2.6213,  2.4592, -3.3518,
         -0.6746,  0.3652, -2.0115, -1.1867,  1.3993,  0.1393,  4.0982,  2.1443,
         -0.0917, -2.9917,  1.3907, -3.2344,  0.1646, -2.5243,  2.3510, -1.5404,
         -1.5686, -1.1608,  3.4435,  4.6569,  3.2349,  0.5873, -5.2667,  2.7617,
         -0.1261,  0.0996,  1.7889, -0.4298,  1.3854, -5.4287, -0.0255,  1.9962,
         -7.7961,  1.8806, -