In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from tqdm import  tqdm

_ = torch.manual_seed(0)

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])

train_ds = datasets.MNIST(root='./data',train=False,download=True,transform=transform)

test_ds = datasets.MNIST(root='./data',train=False,download=True,transform=transform)

train_dl =torch.utils.data.DataLoader(train_ds,batch_size=10,shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=10,shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 6363927.66it/s] 


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 258659.54it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2285834.02it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [25]:
class Net(nn.Module):
    def __init__(self,hidden_size_1 = 1000,hidden_size_2=2000):
        super().__init__()
        
        self.linear1 = nn.Linear(28*28,hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1,hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2,10)
        self.relu = nn.ReLU()
        
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
    
net = Net().to(device)

In [26]:
def train(train_loader,net,epochs=5,total_iteration_limit=None):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(),lr=0.01)
    
    total_iterations = 0
    
    for epoch in range(epochs):
        net.train()
        
        loss_sum = 0
        num_iterations = 0
        
        data_iterator = tqdm(train_loader,desc=f'Epoch {epoch+1}')
        if total_iteration_limit is not None:
            data_iterator.total = total_iteration_limit
        
        for data in data_iterator:
            num_iterations +=1
            total_iterations +=1
            
            x,y = data
            x = x.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
            
            preds = net(x.view(-1,28*28))
            loss = loss_fn(preds,y)
            loss_sum += loss.item()
            avg_loss = loss_sum/num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            
            loss.backward()
            optimizer.step()
            
            if total_iteration_limit is not None and total_iterations >= total_iteration_limit:
                return
            
train(train_dl,net,epochs=1)

Epoch 1: 100%|██████████| 1000/1000 [01:00<00:00, 16.51it/s, loss=1.02]


In [27]:
original_weights = {}
for name,param in net.named_parameters():
    original_weights[name] = param.clone().detach()

In [28]:
def test():
    correct = 0
    total = 0
    
    wrong_counts = [0 for i in range(10)]
    
    with torch.no_grad():
        for data in tqdm(test_dl,desc='Testing'):
            x ,y = data
            x = x.to(device)
            y = y.to(device)
            pred = net(x.view(-1,28*28))
            
            # (batch,predictions)
            for idx,i in enumerate(pred):
                if torch.argmax(i) == y[idx]:
                    correct+=1
                else:
                    wrong_counts[y[idx]] +=1
                total+=1
                
    print(f"Accuracy : {round(correct/total,3)}")
    for i in range(len(wrong_counts)):
        print(f"wrong counts for the digit {i} : {wrong_counts[i]}")
        
test()

Testing: 100%|██████████| 1000/1000 [00:05<00:00, 184.85it/s]

Accuracy : 0.851
wrong counts for the digit 0 : 214
wrong counts for the digit 1 : 22
wrong counts for the digit 2 : 160
wrong counts for the digit 3 : 151
wrong counts for the digit 4 : 213
wrong counts for the digit 5 : 153
wrong counts for the digit 6 : 151
wrong counts for the digit 7 : 76
wrong counts for the digit 8 : 93
wrong counts for the digit 9 : 256





In [30]:
total_parameters_original = 0

for index,layer in enumerate([net.linear1,net.linear2,net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f"Layer : {index+1}, W: {layer.weight.shape} + B : {layer.bias.shape}")
print(f"Total nr of params : {total_parameters_original:,}")

Layer : 1, W: torch.Size([1000, 784]) + B : torch.Size([1000])
Layer : 2, W: torch.Size([2000, 1000]) + B : torch.Size([2000])
Layer : 3, W: torch.Size([10, 2000]) + B : torch.Size([10])
Total nr of params : 2,807,010


In [38]:
class LoRAParametriaztion(nn.Module):
    def __init__(self,in_f,out_f,rank=1,alpha=1,device="cpu"):
        super().__init__()
        
        self.lora_A = nn.Parameter(torch.zeros((rank,out_f)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((in_f,rank)).to(device))
        nn.init.normal_(self.lora_A,mean=0,std=1)
        
        self.scale = alpha/rank
        self.enabled = True
        
    def forward(self,original_weights):
        if self.enabled:
            # X + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B,self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights


In [49]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer,device,rank=1,lora_alpha=1):
    
    features_in,features_out = layer.weight.shape
    
    return LoRAParametriaztion(
        features_in,features_out,rank=rank,alpha=lora_alpha,device=device
    )

parametrize.register_parametrization(
        net.linear1,"weight",linear_layer_parameterization(net.linear1,device)
)

parametrize.register_parametrization(
    net.linear2,"weight",linear_layer_parameterization(net.linear2,device)
)

parametrize.register_parametrization(
    net.linear3,"weight",linear_layer_parameterization(net.linear3,device)
)

def enable_disable_lore(enabled=True):
    for layer in [net.linear1,net.linear2,net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [50]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [51]:
for name,param in net.named_parameters():
    # print(name,param.shape)
    # print("\n")
    if 'lora' not in name:
        print(f"Freezing non-LoRA parameter {name}")
        param.requires_grad = False

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


In [52]:
mnist_ds = datasets.MNIST(root='./data',train=True,download=True,transform=transform)

exclude_indices = mnist_ds.targets == 9
mnist_ds.data = mnist_ds.data[exclude_indices]
mnist_ds.targets = mnist_ds.targets[exclude_indices]

train_dl = torch.utils.data.DataLoader(mnist_ds,batch_size=10,shuffle=True)

train(train_dl,net,epochs=1,total_iteration_limit=100)

Epoch 1:  99%|█████████▉| 99/100 [00:03<00:00, 28.07it/s, loss=2.38e-10]


In [56]:
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lore(enabled=True)

assert torch.equal(net.linear1.weight,net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrization.weight[0].lora_A))

AttributeError: 'ParametrizedLinear' object has no attribute 'parametrization'

In [57]:
enable_disable_lore(True)
test()

Testing: 100%|██████████| 1000/1000 [00:12<00:00, 80.63it/s]

Accuracy : 0.227
wrong counts for the digit 0 : 694
wrong counts for the digit 1 : 1135
wrong counts for the digit 2 : 692
wrong counts for the digit 3 : 1010
wrong counts for the digit 4 : 982
wrong counts for the digit 5 : 889
wrong counts for the digit 6 : 346
wrong counts for the digit 7 : 1010
wrong counts for the digit 8 : 974
wrong counts for the digit 9 : 0





In [58]:
enable_disable_lore(False)
test()

Testing: 100%|██████████| 1000/1000 [00:12<00:00, 78.98it/s]

Accuracy : 0.854
wrong counts for the digit 0 : 208
wrong counts for the digit 1 : 27
wrong counts for the digit 2 : 150
wrong counts for the digit 3 : 135
wrong counts for the digit 4 : 223
wrong counts for the digit 5 : 167
wrong counts for the digit 6 : 139
wrong counts for the digit 7 : 73
wrong counts for the digit 8 : 132
wrong counts for the digit 9 : 204



