# LoRA+

In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

from LoRA import Lora, LoraLinear, create_loraplus_optim

torch.manual_seed(0);


# Setup

In [2]:
# Convert images to tensors, then standardize them with the given mean/std
# (0.1307 / 0.3081 -- presumably the MNIST training-set statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split of MNIST (downloaded into ./data when missing) and its loader
mnist_trainset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
    mnist_trainset, batch_size=10, shuffle=True
)

# Test split of MNIST and its loader
mnist_testset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
test_loader = torch.utils.data.DataLoader(
    mnist_testset, batch_size=10, shuffle=True
)

# Use the first CUDA device when one is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Create an overly expensive model

In [3]:
class Net(nn.Module):
    """Fully connected classifier for 28x28 images.

    Layer widths are 784 -> hidden_size_1 -> 2 * hidden_size_1 -> 10, with a
    ReLU after each of the first two linear layers.
    """

    def __init__(self, hidden_size_1=1000):
        super().__init__()
        input_dim = 28 * 28
        wide_dim = hidden_size_1 * 2
        self.linear1 = nn.Linear(input_dim, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, wide_dim)
        self.linear3 = nn.Linear(wide_dim, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten `img` into rows of 784 values and return 10 raw scores per row."""
        flat = img.view(-1, 28 * 28)
        hidden = self.relu(self.linear1(flat))
        hidden = self.relu(self.linear2(hidden))
        return self.linear3(hidden)

In [4]:
# Instantiate the network with its default hidden size and move it to `device`.
net = Net().to(device)

### Train the model

In [5]:
def train(train_loader, net, epochs=5, total_iterations_limit=None, optim=None):
    """Train `net` on `train_loader` with a cross-entropy loss.

    Args:
        train_loader: iterable yielding `(inputs, labels)` batches.
        net: model to optimize; it is put into training mode at every epoch.
        epochs: number of passes over `train_loader`.
        total_iterations_limit: optional cap on the total number of optimizer
            steps across all epochs. Training returns as soon as the cap is
            reached; a non-positive cap performs no training step.
        optim: optimizer to use. When None, Adam over all parameters of `net`
            with lr=0.001 is created.

    Returns:
        None. The optimizer is printed once and progress is reported via tqdm.

    Batches are moved to the notebook-level `device` before the forward pass.
    """
    cross_el = nn.CrossEntropyLoss()
    # Compare against None explicitly rather than relying on the truthiness
    # of an optimizer object.
    optimizer = optim if optim is not None else torch.optim.Adam(net.parameters(), lr=0.001)
    print(optimizer)

    if total_iterations_limit is not None and total_iterations_limit <= 0:
        return

    # len() is unavailable for loaders that wrap iterable-style datasets.
    try:
        batches_per_epoch = len(train_loader)
    except TypeError:
        batches_per_epoch = None

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        # Size the bar by the number of steps this epoch will really perform.
        epoch_total = batches_per_epoch
        if total_iterations_limit is not None:
            remaining = total_iterations_limit - total_iterations
            epoch_total = remaining if epoch_total is None else min(epoch_total, remaining)

        # The bar is advanced by hand (instead of iterating over tqdm) so that
        # the last batch is counted before an early return; the context
        # manager closes the bar on every exit path.
        with tqdm(total=epoch_total, desc=f'Epoch {epoch+1}') as progress:
            for x, y in train_loader:
                num_iterations += 1
                total_iterations += 1
                x = x.to(device)
                y = y.to(device)
                optimizer.zero_grad()
                output = net(x.view(-1, 28*28))
                loss = cross_el(output, y)
                loss_sum += loss.item()
                avg_loss = loss_sum / num_iterations
                loss.backward()
                optimizer.step()
                progress.update(1)
                progress.set_postfix(loss=avg_loss)

                if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                    return

train(train_loader, net, epochs=1)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)


Epoch 1: 100%|██████████| 6000/6000 [00:34<00:00, 172.20it/s, loss=0.237]


In [6]:
def test():
    """Evaluate the notebook-level `net` on `test_loader` and print a summary.

    Prints the overall accuracy (rounded to 3 decimals) followed by the number
    of misclassified samples for each true digit 0-9. The model is switched to
    evaluation mode for the duration of the pass and its previous
    training/eval mode is restored afterwards. An empty loader reports an
    accuracy of 0.0 instead of raising ZeroDivisionError.
    """
    correct = 0
    total = 0

    wrong_counts = [0] * 10

    # Layers such as dropout behave differently in training mode, so evaluate
    # in eval mode and put the model back the way it was found.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for x, y in tqdm(test_loader, desc='Testing'):
                x = x.to(device)
                y = y.to(device)
                output = net(x.view(-1, 784))
                # Score the whole batch at once instead of sample by sample.
                missed = output.argmax(dim=1) != y
                batch_size = y.numel()
                total += batch_size
                correct += batch_size - int(missed.sum())
                for label in y[missed].tolist():
                    wrong_counts[label] += 1
    finally:
        net.train(was_training)

    accuracy = correct / total if total else 0.0
    print(f'Accuracy: {round(accuracy, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 292.11it/s]

Accuracy: 0.957
wrong counts for the digit 0: 33
wrong counts for the digit 1: 34
wrong counts for the digit 2: 42
wrong counts for the digit 3: 69
wrong counts for the digit 4: 19
wrong counts for the digit 5: 23
wrong counts for the digit 6: 45
wrong counts for the digit 7: 49
wrong counts for the digit 8: 17
wrong counts for the digit 9: 97





In [7]:
def replace_linear_with_lora(model, rank, dropout_rate=0.0):
    """Recursively replace every `nn.Linear` inside `model` with a `LoraLinear`, in place.

    Each replacement is built with the same in/out feature sizes, receives a
    copy of the original weight (and bias, when the original has one) and is
    moved to the original layer's device and dtype. Children that already are
    `LoraLinear` instances are left untouched, so calling the function again
    does not rebuild them.

    Args:
        model: module whose children are rewritten in place.
        rank: rank passed to `LoraLinear`.
        dropout_rate: forwarded to `LoraLinear` as `lora_dropout`.
    """
    # Iterate over a snapshot because children are reassigned inside the loop.
    for name, module in list(model.named_children()):
        if isinstance(module, LoraLinear):
            # Already converted; this guard matters if LoraLinear is itself an
            # nn.Linear subclass (not visible from this notebook).
            continue

        if isinstance(module, nn.Linear):
            out_f, in_f = module.weight.shape
            lora_layer = LoraLinear(in_f, out_f, rank, lora_dropout=dropout_rate)
            # Keep the whole new layer, adapter parameters included, on the
            # same device/dtype as the weights it takes over.
            lora_layer = lora_layer.to(device=module.weight.device, dtype=module.weight.dtype)
            with torch.no_grad():
                lora_layer.weight.copy_(module.weight)
                if module.bias is not None:
                    lora_layer.bias.copy_(module.bias)
            # NOTE(review): a source layer without bias still ends up with
            # whatever bias LoraLinear creates by default -- confirm whether
            # LoraLinear accepts a bias flag.
            setattr(model, name, lora_layer)
        else:
            # Recursively apply this function to child modules
            replace_linear_with_lora(module, rank, dropout_rate=dropout_rate)

In [8]:
# LoRA hyper-parameters. NOTE: `alpha` is not referenced by any cell visible in this notebook.
rank ,alpha = 1,1
# Weight decay applied to the optimizer parameter groups built further down.
weight_decay = 0.0
# Replace the linear layers of `net` with LoRA layers (the model is modified in place).
replace_linear_with_lora(net, rank=rank)

In [9]:
# Move the converted model to `device`; as the cell's last expression, the module repr is displayed.
net.to(device)

Net(
  (linear1): LoraLinear(in_features=784, out_features=1000, bias=True)
  (linear2): LoraLinear(in_features=1000, out_features=2000, bias=True)
  (linear3): LoraLinear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
)

In [10]:
# Names of all parameters registered on the model, in registration order
params_names = [param_name for param_name, _param in net.named_parameters()]
params_names

['linear1.weight',
 'linear1.bias',
 'linear1.A',
 'linear1.B',
 'linear2.weight',
 'linear2.bias',
 'linear2.A',
 'linear2.B',
 'linear3.weight',
 'linear3.bias',
 'linear3.A',
 'linear3.B']

In [11]:
# Buckets for the LoRA parameters: each maps a full parameter name to its
# Parameter, one bucket for the `A` parameters and one for the `B` parameters.
param_group ={
    'A': {},
    'B': {}
}

In [12]:
for name, param in net.named_parameters():
    # Parameters that do not require gradients are not grouped.
    if not param.requires_grad:
        continue

    # Match on the last component of the dotted parameter name. A substring
    # test such as `'A' in name` would also pick up unrelated parameters whose
    # path merely contains that letter.
    leaf_name = name.rsplit('.', 1)[-1]
    if leaf_name in param_group:
        param_group[leaf_name][name] = param
        print(name)
param_group
    

linear1.A
linear1.B
linear2.A
linear2.B
linear3.A
linear3.B


{'A': {'linear1.A': Parameter containing:
  tensor([[ 3.9330e-02,  3.9752e-02, -8.0098e-02, -2.6622e-02, -5.8098e-02,
           -3.0868e-02,  7.7938e-02,  8.5724e-02, -3.1529e-02, -3.4676e-02,
           -1.9467e-02,  5.9036e-02,  8.0732e-02, -2.1240e-03, -1.3514e-02,
            6.6837e-02, -7.0199e-02,  1.5412e-02, -3.5974e-03,  5.7088e-02,
           -8.6246e-02, -8.2492e-02,  2.2794e-02,  7.5971e-02,  2.9438e-02,
           -7.1050e-02, -8.0348e-02, -5.1602e-03,  2.2886e-02, -4.2793e-02,
            4.1867e-02, -6.3142e-02, -1.6274e-03, -7.5760e-02,  5.0614e-02,
           -3.5353e-03, -4.5013e-03,  5.0154e-02, -4.6476e-02, -3.5197e-02,
           -3.7123e-02,  3.7014e-02, -2.7894e-02, -8.4516e-02,  1.3658e-02,
           -1.5273e-02,  4.6396e-02,  1.7682e-02,  4.8418e-02,  5.1854e-02,
           -8.2378e-02, -7.5649e-02,  2.5404e-03, -3.6122e-02, -8.4040e-02,
            3.4419e-02,  3.5138e-02,  8.6566e-02, -1.2819e-02,  5.1159e-02,
           -7.4816e-02, -3.8456e-02,  3.5854e-

In [13]:
lr = 0.001
loraplus_lr_ratio = 2 ** 4  # value from paper

# One optimizer group per bucket: the A parameters use the base learning rate,
# the B parameters use `loraplus_lr_ratio` times that rate.
optimizer_grouped_parameters = [
    {
        "params": list(param_group[bucket].values()),
        "weight_decay": weight_decay,
        "lr": bucket_lr,
    }
    for bucket, bucket_lr in (("A", lr), ("B", lr * loraplus_lr_ratio))
]
optimizer_grouped_parameters

[{'params': [Parameter containing:
   tensor([[ 3.9330e-02,  3.9752e-02, -8.0098e-02, -2.6622e-02, -5.8098e-02,
            -3.0868e-02,  7.7938e-02,  8.5724e-02, -3.1529e-02, -3.4676e-02,
            -1.9467e-02,  5.9036e-02,  8.0732e-02, -2.1240e-03, -1.3514e-02,
             6.6837e-02, -7.0199e-02,  1.5412e-02, -3.5974e-03,  5.7088e-02,
            -8.6246e-02, -8.2492e-02,  2.2794e-02,  7.5971e-02,  2.9438e-02,
            -7.1050e-02, -8.0348e-02, -5.1602e-03,  2.2886e-02, -4.2793e-02,
             4.1867e-02, -6.3142e-02, -1.6274e-03, -7.5760e-02,  5.0614e-02,
            -3.5353e-03, -4.5013e-03,  5.0154e-02, -4.6476e-02, -3.5197e-02,
            -3.7123e-02,  3.7014e-02, -2.7894e-02, -8.4516e-02,  1.3658e-02,
            -1.5273e-02,  4.6396e-02,  1.7682e-02,  4.8418e-02,  5.1854e-02,
            -8.2378e-02, -7.5649e-02,  2.5404e-03, -3.6122e-02, -8.4040e-02,
             3.4419e-02,  3.5138e-02,  8.6566e-02, -1.2819e-02,  5.1159e-02,
            -7.4816e-02, -3.8456e-02,  3.

In [14]:
# Adam over the parameter groups built above. Every group carries its own "lr"
# entry; the `lr` argument only serves as the default for groups without one.
optimizer_v2 = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)
optimizer_v2

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.016
    maximize: False
    weight_decay: 0.0
)

In [15]:
# Build the optimizer through the project helper, reusing the hyper-parameters
# defined earlier in the notebook instead of repeating their literal values
# (lr=0.001, weight_decay=0.0, ratio=16).
optimizer_v2 = create_loraplus_optim(
    net,
    torch.optim.Adam,
    {'lr': lr, 'weight_decay': weight_decay},
    loraplus_lr_ratio,
)
optimizer_v2

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.016
    maximize: False
    weight_decay: 0.0
)

In [16]:
# Reload the MNIST training split so it can be filtered without touching the earlier copy.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# NOTE: despite its name, this boolean mask is True for the samples labelled 9;
# the two assignments below therefore KEEP only the digit-9 samples.
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader over the digit-9-only training subset
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Fine-tune on that subset with `optimizer_v2`, capped at 25 iterations in total.
train(train_loader, net, epochs=1, total_iterations_limit=25, optim=optimizer_v2)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.016
    maximize: False
    weight_decay: 0.0
)


Epoch 1:  96%|█████████▌| 24/25 [00:00<00:00, 119.10it/s, loss=0.0926]


In [17]:
# Re-evaluate on the test loader after the digit-9 fine-tuning.
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 278.05it/s]

Accuracy: 0.936
wrong counts for the digit 0: 37
wrong counts for the digit 1: 49
wrong counts for the digit 2: 66
wrong counts for the digit 3: 124
wrong counts for the digit 4: 84
wrong counts for the digit 5: 51
wrong counts for the digit 6: 69
wrong counts for the digit 7: 125
wrong counts for the digit 8: 23
wrong counts for the digit 9: 12





We reduced the number of iterations by `2` and even got a better result on digit 9.