In [2]:
from LoRA import Lora, LoraLinear
import torch
import torch.nn as nn
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
from tqdm import tqdm

torch.manual_seed(0);


In [None]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class Net(nn.Module):
    def __init__(self, hidden_size_1=1000):
        super(Net,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_1) 
        self.linear3 = nn.Linear(hidden_size_1, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [5]:
net = Net().to(device)

In [6]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [01:39<00:00, 60.53it/s, loss=0.228]


In [7]:
def replace_linear_with_lora(model, rank, dropout_rate=0.0):
    for name,module in model.named_children():
        if isinstance(module,nn.Linear):
            out_f,in_f = module.weight.shape
            bias = module.bias is not None
            lora_layer = LoraLinear(in_f, out_f, rank, lora_dropout=dropout_rate)
            lora_layer.weight.data = module.weight.data.clone()
            if bias:
                lora_layer.bias.data = module.bias.data.clone()
            setattr(model, name, lora_layer)
        else:
            # Recursively apply this function to child modules
            replace_linear_with_lora(module, rank, dropout_rate=dropout_rate)

In [8]:
rank ,alpha = 1,1
weight_decay = 0.0
replace_linear_with_lora(net, rank=rank)

In [9]:
net

Net(
  (linear1): LoraLinear(in_features=784, out_features=1000, bias=True)
  (linear2): LoraLinear(in_features=1000, out_features=1000, bias=True)
  (linear3): LoraLinear(in_features=1000, out_features=10, bias=True)
  (relu): ReLU()
)

In [14]:
optimizer_v1 = torch.optim.Adam;optimizer_v1

torch.optim.adam.Adam

In [36]:
params_names = [param[0] for param in (iter(net.named_parameters()))]
params_names

['linear1.weight',
 'linear1.bias',
 'linear1.A',
 'linear1.B',
 'linear2.weight',
 'linear2.bias',
 'linear2.A',
 'linear2.B',
 'linear3.weight',
 'linear3.bias',
 'linear3.A',
 'linear3.B']

In [46]:
param_group ={
    'A': {},
    'B': {}
}

In [47]:
for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        
        if 'A' in name:
            param_group['A'][name] = param
            print(name)
        elif 'B' in name:
            param_group['B'][name] = param
            print(name)
param_group
    

linear1.A
linear1.B
linear2.A
linear2.B
linear3.A
linear3.B


{'A': {'linear1.A': Parameter containing:
  tensor([[ 4.3848e-02,  4.4304e-02,  5.3137e-02, -3.3112e-02,  4.7397e-02,
            4.1626e-02, -8.4931e-03,  6.6490e-02, -8.3801e-03,  4.9271e-02,
            8.3646e-02, -9.9587e-03,  1.3388e-02, -3.3355e-02, -7.8271e-02,
           -4.4980e-02, -1.8543e-02,  5.1049e-03,  7.0606e-03,  2.6766e-02,
            2.3162e-02,  4.4538e-02, -5.8175e-02,  1.7762e-02,  3.5382e-02,
            2.6536e-02, -7.1259e-02, -1.2687e-02,  7.4822e-02,  4.0373e-02,
           -3.1082e-02, -8.3951e-02,  5.7237e-02,  4.9943e-02, -3.0137e-02,
           -5.8583e-02,  4.0435e-02, -4.2852e-02, -5.5431e-03, -6.6480e-02,
           -3.1952e-02,  4.5162e-03, -3.7911e-02,  7.5508e-02, -5.9996e-03,
           -5.5640e-02, -5.3007e-02, -8.1382e-02, -6.1219e-03,  6.1317e-03,
           -5.3232e-02,  7.2264e-02, -6.3679e-02, -8.0188e-02, -1.8057e-02,
           -7.2452e-02,  2.9887e-02, -6.8376e-02,  4.2223e-02,  2.7215e-02,
           -6.1261e-02, -4.3755e-02, -5.4216e-

In [51]:
list(iter(net.parameters()))

[Parameter containing:
 tensor([[ 0.0238,  0.0432, -0.0053,  ...,  0.0460,  0.0278,  0.0261],
         [ 0.0154,  0.0202,  0.0247,  ...,  0.0149,  0.0292,  0.0052],
         [ 0.0042,  0.0392, -0.0089,  ...,  0.0040,  0.0255,  0.0324],
         ...,
         [-0.0031,  0.0605,  0.0588,  ...,  0.0277,  0.0539,  0.0013],
         [ 0.0763,  0.0396,  0.0253,  ...,  0.0596,  0.0532,  0.0540],
         [ 0.0225, -0.0012,  0.0500,  ...,  0.0594,  0.0329,  0.0432]]),
 Parameter containing:
 tensor([-4.8653e-02, -1.3826e-02, -3.1972e-02,  7.1351e-03, -1.5521e-03,
         -4.5384e-02, -3.7178e-02, -2.8117e-02, -6.2857e-02, -2.7319e-02,
         -9.8505e-03, -1.7442e-02, -3.7520e-02, -1.8868e-02, -1.9660e-02,
         -2.2316e-02, -5.5815e-02, -4.7388e-02, -4.2925e-02, -3.9011e-02,
         -5.8413e-02, -3.8276e-02,  9.9136e-03, -6.3864e-02, -2.9130e-02,
         -4.0044e-02, -2.9538e-02, -1.0945e-03, -1.5453e-02, -2.5208e-02,
         -3.5245e-03, -3.9785e-02, -3.5010e-02, -9.1595e-03, -3.8553

In [49]:
lr = 1e-4
loraplus_lr_ratio = 2**4 # value from paper
optimizer_grouped_parameters = [
    {
        "params": list(param_group["A"].values()),
        "weight_decay": weight_decay,
        "lr": lr,
    },
    {
        "params": list(param_group["B"].values()),
        "weight_decay": weight_decay,
        "lr": lr * loraplus_lr_ratio,
    },
]
optimizer_grouped_parameters

[{'params': [Parameter containing:
   tensor([[ 4.3848e-02,  4.4304e-02,  5.3137e-02, -3.3112e-02,  4.7397e-02,
             4.1626e-02, -8.4931e-03,  6.6490e-02, -8.3801e-03,  4.9271e-02,
             8.3646e-02, -9.9587e-03,  1.3388e-02, -3.3355e-02, -7.8271e-02,
            -4.4980e-02, -1.8543e-02,  5.1049e-03,  7.0606e-03,  2.6766e-02,
             2.3162e-02,  4.4538e-02, -5.8175e-02,  1.7762e-02,  3.5382e-02,
             2.6536e-02, -7.1259e-02, -1.2687e-02,  7.4822e-02,  4.0373e-02,
            -3.1082e-02, -8.3951e-02,  5.7237e-02,  4.9943e-02, -3.0137e-02,
            -5.8583e-02,  4.0435e-02, -4.2852e-02, -5.5431e-03, -6.6480e-02,
            -3.1952e-02,  4.5162e-03, -3.7911e-02,  7.5508e-02, -5.9996e-03,
            -5.5640e-02, -5.3007e-02, -8.1382e-02, -6.1219e-03,  6.1317e-03,
            -5.3232e-02,  7.2264e-02, -6.3679e-02, -8.0188e-02, -1.8057e-02,
            -7.2452e-02,  2.9887e-02, -6.8376e-02,  4.2223e-02,  2.7215e-02,
            -6.1261e-02, -4.3755e-02, -5.

In [None]:
optimizer_v2 = optimizer_v1(optimizer_grouped_parameters, lr=lr)