In [2]:
import torch
import torch.nn as nn

In [93]:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.first_layer = nn.Linear(10, 10)

    def forward(self, x):
        x = self.first_layer(x)
        return x

In [94]:
m = MyModel()

# Train the model:
optimizer = torch.optim.Adam(m.parameters(), lr=0.001)
loss_function = nn.MSELoss() # or any other appropriate loss function

In [106]:
# 1000 times
loss = 0
for i in range(1000):
  x = torch.rand(32, 10) # 32: batch size
  target = torch.full((32,10), 1.0)
  y = m(x)

  loss = loss_function(y, target)

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('loss : ', loss)

loss :  tensor(1.1260e-12, grad_fn=<MseLossBackward0>)


In [108]:
# save the weights:
torch.save(m.state_dict(), 'my_model_state_dict.pth')

In [118]:
class MyModelLora(MyModel):
    def __init__(self):
        super().__init__() # Call the base class's __init__ method
        # Additional initialization logic for MyModel2
        self.lora_A = nn.Linear(10, 1)
        self.lora_B = nn.Linear(1, 10)

    def forward(self, x):
        # Custom forward pass logic here
        x1 = super().forward(x)
        x2 = self.lora_A(x) # -> [10 * 1]
        x2 = self.lora_B(x2) # [10 * 10]

        return (x1 + x2)
  
    def freeze_parent_weights(self):
        for name, param in self.first_layer.named_parameters():
            print('freezing first_layer.', name)
            param.requires_grad = False

In [132]:
# load the model from weights:
m2 = MyModelLora()
m2.load_state_dict(torch.load('my_model_state_dict.pth'), strict=False)

x = torch.rand(1, 10)
y = m2(x)
print(y)

tensor([[0.1110, 0.0508, 1.2207, 1.2921, 1.4932, 1.7057, 0.2008, 0.4966, 1.5487,
         1.5557]], grad_fn=<AddBackward0>)


In [133]:
# freeze the weights:
m2.freeze_parent_weights()

# Add new weights:

optimizer = torch.optim.Adam(m2.parameters(), lr=0.001)
loss_function = nn.MSELoss() # or any other appropriate loss function

freezing first_layer. weight
freezing first_layer. bias


In [164]:
# try it out:
loss = 0
for i in range(1000):
  x = torch.rand(32, 10) # 32: batch size
  target = torch.full((32, 10), 0.8)
  y = m2(x)

  loss = loss_function(y, target)

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('loss : ', loss)

loss :  tensor(1.8874e-08, grad_fn=<MseLossBackward0>)


In [169]:
x = torch.rand(1, 10)
x1 = m2.lora_A(x) # -> [0.8]
print(x1)
x2 = m2.lora_B(x1) # -> [-0.2, -0.2] ...
print(x2)


tensor([[0.8003]], grad_fn=<AddmmBackward0>)
tensor([[-0.2000, -0.2001, -0.1999, -0.1999, -0.2000, -0.1999, -0.2000, -0.2000,
         -0.1999, -0.1999]], grad_fn=<AddmmBackward0>)


In [193]:
print(m2.lora_B.weight.numel())

10


In [195]:
# let's count the params:
# original model:
count_1 = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(count_1)

# lora:
count_2 = sum(p.numel() for p in m2.parameters() if p.requires_grad)
print(count_2) # loraA: 10+1 elem loraB: 10+10 elems

110
31
