In [86]:
from LoRA import Lora, LoraLinear
import torch
import torch.nn as nn
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
from tqdm import tqdm


In [87]:
# Seed PyTorch's global random number generator so that runs are repeatable.
torch.manual_seed(0);

## Setup

In [88]:
# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081
# (presumably the statistics of the MNIST training set -- confirm if the data changes).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST training split (downloaded into ./data on first use)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Training batches are shuffled
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test split
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=False)

# Use the first CUDA device when available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### Create an overly expensive neural network to classify MNIST digits

In [89]:
class Net(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    Args:
        hidden_size_1: Width of the first hidden layer.
        hidden_size_2: Width of the second hidden layer.
        input_size: Number of features per sample once flattened
            (default ``28*28``).
        num_classes: Number of output logits (default ``10``).
    """

    def __init__(self, hidden_size_1=1000, hidden_size_2=2000, input_size=28*28, num_classes=10):
        super().__init__()
        # Remembered so forward() can flatten inputs to the right width.
        self.input_size = input_size
        self.linear1 = nn.Linear(input_size, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to ``(-1, input_size)`` and return the raw logits."""
        # reshape() instead of view(): view() raises on non-contiguous tensors,
        # reshape() copies only when it has to.
        x = img.reshape(-1, self.input_size)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        # No activation on the last layer: the caller receives unnormalized logits.
        x = self.linear3(x)
        return x

In [90]:
# Instantiate the classifier with its default layer sizes and move it to the selected device.
net = Net().to(device)

In [91]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` in place with Adam (lr=0.001) and cross-entropy loss.

    Batches are moved to the notebook-level ``device`` and flattened to
    ``(-1, 28*28)`` before being fed to the model. A tqdm bar shows the running
    mean loss of the current epoch.

    Args:
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        net: Model to optimize.
        epochs: Maximum number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the total number of optimizer
            steps across all epochs. ``None`` means no cap. When the cap is
            reached (or is ``<= 0``), training stops without further steps.

    Returns:
        None.
    """
    # Function-scope import keeps this cell self-contained (standard library only).
    from itertools import islice

    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        batches = train_loader
        bar_total = None  # None lets tqdm take the length from the loader itself
        if total_iterations_limit is not None:
            remaining = total_iterations_limit - total_iterations
            if remaining <= 0:
                return
            # Bounding the iterable (instead of returning from inside the loop)
            # lets tqdm finish normally, so the bar ends at N/N rather than N-1/N.
            batches = islice(train_loader, remaining)
            try:
                bar_total = min(remaining, len(train_loader))
            except TypeError:
                # Loader without a length: the cap is the best estimate available.
                bar_total = remaining

        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(batches, total=bar_total, desc=f'Epoch {epoch+1}')
        for x, y in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            # Running mean of the loss over the batches seen so far in this epoch.
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()
# Train the freshly created network for a single pass over the training loader.
train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:34<00:00, 171.83it/s, loss=0.237]


### Evaluate the model

In [92]:
def test():
    """Evaluate the notebook-level ``net`` on ``test_loader`` and print a report.

    Prints the overall accuracy (rounded to 3 decimals) followed by, for each
    digit 0-9, how many samples with that true label were misclassified.
    The model is switched to evaluation mode for the duration of the call and
    its previous training/eval mode is restored afterwards.

    Returns:
        None.
    """
    correct = 0
    total = 0

    # wrong_counts[d] = number of misclassified samples whose true label is d
    wrong_counts = [0 for _ in range(10)]

    was_training = net.training
    net.eval()  # evaluate in eval mode (matters for layers such as dropout)
    try:
        with torch.no_grad():
            for x, y in tqdm(test_loader, desc='Testing'):
                x = x.to(device)
                y = y.to(device)
                output = net(x.view(-1, 784))
                # Score the whole batch at once instead of one sample at a time.
                predictions = output.argmax(dim=1)
                mismatched = predictions != y
                total += y.numel()
                correct += int((~mismatched).sum())
                for label in y[mismatched].tolist():
                    wrong_counts[label] += 1
    finally:
        net.train(was_training)

    # An empty loader would otherwise raise ZeroDivisionError.
    accuracy = correct / total if total else float('nan')
    print(f'Accuracy: {round(accuracy, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

# Report accuracy and per-digit error counts for the network trained above.
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 294.59it/s]

Accuracy: 0.957
wrong counts for the digit 0: 33
wrong counts for the digit 1: 34
wrong counts for the digit 2: 42
wrong counts for the digit 3: 69
wrong counts for the digit 4: 19
wrong counts for the digit 5: 23
wrong counts for the digit 6: 45
wrong counts for the digit 7: 49
wrong counts for the digit 8: 17
wrong counts for the digit 9: 97





In [93]:
# Inspect the first layer and the shape of its weight matrix.
net.linear1,net.linear1.weight.shape

(Linear(in_features=784, out_features=1000, bias=True),
 torch.Size([1000, 784]))

In [94]:
# The weight of linear1 has shape (out_features, in_features).
out_f,in_f = net.linear1.weight.shape
rank = 2  # passed as `r` / `rank` to the LoraLinear layers created below
alpha = 1  # passed as `lora_alpha` to the stand-alone LoraLinear created below

In [95]:
# Sanity check: number of input features of linear1.
in_f

784

In [96]:
# Build a stand-alone LoRA layer with the same input/output sizes as net.linear1.
lora_linear = LoraLinear(in_f=in_f, out_f=out_f, r=rank, lora_alpha=alpha)

In [97]:
# Copy linear1's trained weight and bias into the stand-alone LoRA layer.
# clone() gives the new layer its own storage instead of sharing the original tensors.
lora_linear.weight.data = net.linear1.weight.data.clone()
lora_linear.bias.data = net.linear1.bias.data.clone()

In [98]:
# Element-wise comparison of the copied weight against the original linear1 weight.
lora_linear.weight.data == net.linear1.weight.data

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')

In [99]:
def replace_linear_with_lora(model, rank, dropout_rate=0.0):
    """Recursively replace every ``nn.Linear`` inside ``model`` with a ``LoraLinear``.

    The replacement is done in place. Each new layer receives a copy of the
    original weight (and bias, when present) and is moved to the original
    layer's device and dtype. Layers that are already ``LoraLinear`` are left
    untouched, so calling this function twice does not discard existing
    adapters.

    Args:
        model: Module whose children are inspected and rewritten.
        rank: Rank passed to each ``LoraLinear``.
        dropout_rate: Value passed as ``lora_dropout`` to each ``LoraLinear``.

    Returns:
        None.
    """
    # Snapshot the children first: they are reassigned while looping.
    for name, module in list(model.named_children()):
        if isinstance(module, LoraLinear):
            # Already converted: neither re-wrap it nor descend into its internals.
            continue
        if isinstance(module, nn.Linear):
            out_f, in_f = module.weight.shape
            lora_layer = LoraLinear(in_f, out_f, rank, lora_dropout=dropout_rate)
            # Put every parameter of the new layer where the original lived, so the
            # model stays usable without a separate model.to(device) afterwards.
            lora_layer.to(device=module.weight.device, dtype=module.weight.dtype)
            lora_layer.weight.data = module.weight.detach().clone()
            if module.bias is not None:
                lora_layer.bias.data = module.bias.detach().clone()
            elif getattr(lora_layer, 'bias', None) is not None:
                # The source layer has no bias; zero the new one so outputs match.
                # NOTE(review): LoraLinear may offer a cleaner way to disable bias -- confirm.
                lora_layer.bias.data.zero_()
            setattr(model, name, lora_layer)
        else:
            # Recursively apply this function to child modules
            replace_linear_with_lora(module, rank, dropout_rate=dropout_rate)

In [100]:
# Peek at linear1's weights before the layers are swapped.
net.linear1.weight.data

tensor([[ 0.0228,  0.0422, -0.0064,  ...,  0.0450,  0.0268,  0.0251],
        [ 0.0013,  0.0061,  0.0106,  ...,  0.0008,  0.0151, -0.0089],
        [ 0.0145,  0.0496,  0.0014,  ...,  0.0143,  0.0359,  0.0427],
        ...,
        [ 0.0082,  0.0718,  0.0701,  ...,  0.0390,  0.0652,  0.0126],
        [ 0.0803,  0.0436,  0.0293,  ...,  0.0636,  0.0572,  0.0580],
        [ 0.0246,  0.0009,  0.0521,  ...,  0.0615,  0.0350,  0.0453]],
       device='cuda:0')

In [101]:
# Swap the nn.Linear layers of the network for LoraLinear layers of the chosen rank (in place).
replace_linear_with_lora(net, rank=rank)

In [102]:
# Move the model to the device again: the LoraLinear layers were constructed after the
# earlier .to(device) call, so parameters they created themselves may not be on it yet.
net.to(device)

Net(
  (linear1): LoraLinear(in_features=784, out_features=1000, bias=True)
  (linear2): LoraLinear(in_features=1000, out_features=2000, bias=True)
  (linear3): LoraLinear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
)

The weights copied into the LoRA layer are identical to the original `linear1` weights.

In [103]:
# Compare the swapped-in layer's weight with the stand-alone copy made earlier.
net.linear1.weight.data == lora_linear.weight.data

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')

In [104]:
# Re-evaluate after the layer swap, before any fine-tuning.
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 268.93it/s]

Accuracy: 0.957
wrong counts for the digit 0: 33
wrong counts for the digit 1: 34
wrong counts for the digit 2: 42
wrong counts for the digit 3: 69
wrong counts for the digit 4: 19
wrong counts for the digit 5: 23
wrong counts for the digit 6: 45
wrong counts for the digit 7: 49
wrong counts for the digit 8: 17
wrong counts for the digit 9: 97





### Fine-tune the LoRA layers

In [105]:
# Reload the training split and keep only the samples labelled 9, the digit with the
# most errors in the baseline evaluation.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_mask = mnist_trainset.targets == 9  # True for the samples that are kept
mnist_trainset.data, mnist_trainset.targets = (
    mnist_trainset.data[digit_9_mask],
    mnist_trainset.targets[digit_9_mask],
)
# Shuffled loader over the digit-9 subset only
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Short fine-tuning run, capped at 50 iterations.
# NOTE(review): train() hands net.parameters() to the optimizer; whether the base
# weights stay frozen depends on LoraLinear -- confirm.
train(train_loader, net, epochs=1, total_iterations_limit=50)

Epoch 1:  98%|█████████▊| 49/50 [00:00<00:00, 134.69it/s, loss=0.204]


In [107]:
# Evaluate again after fine-tuning on the digit-9 subset.
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 264.16it/s]

Accuracy: 0.944
wrong counts for the digit 0: 29
wrong counts for the digit 1: 54
wrong counts for the digit 2: 56
wrong counts for the digit 3: 116
wrong counts for the digit 4: 70
wrong counts for the digit 5: 24
wrong counts for the digit 6: 61
wrong counts for the digit 7: 98
wrong counts for the digit 8: 42
wrong counts for the digit 9: 13



