In [21]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm


In [22]:
## setting random seed
_=torch.manual_seed(0)

### Loading the MNIST dataset --> Applying transforms --> shifting to cuda/mps

In [23]:

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])

## loading the dataset
mnist_dataset=datasets.MNIST(root='/Users/praroopchanda/Desktop/Models_Coding_Practice/LoRA',train=True,download=True,transform=transform)
## creating a dataloader for training
train_loader=DataLoader(mnist_dataset,batch_size=10,shuffle=True)

#MNIST Test set
mnist_testset=datasets.MNIST(root="/Users/praroopchanda/Desktop/Models_Coding_Practice/LoRA",train=False,download=True,transform=transform)
test_loader=DataLoader(mnist_testset,batch_size=10,shuffle=True)

# device
device=torch.device("mps" if torch.backends.mps.is_available() else "cpu")

## Defining the Network

In [24]:
## creating overly expensive neural network
## in this passing image as by flattening it

class ExpensiveNet(nn.Module):
    def __init__(self,hidden_size_1=1000,hidden_size_2=2000):
        super().__init__()
        self.linear_1=nn.Linear(28*28,hidden_size_1)
        self.linear_2=nn.Linear(hidden_size_1,hidden_size_2)
        self.linear_3=nn.Linear(hidden_size_2,10)
        self.relu=nn.ReLU()

    def forward(self,img):
        x=img.view(-1,28*28) ## flattening the image tensor
        x=self.relu(self.linear_1(x))
        x=self.relu(self.linear_2(x))
        x=self.linear_3(x)
        return x


## moving to device --> mps in our case
net=ExpensiveNet().to(device)

##  Training just for one Epoch -->one pass of the data 

In [25]:
def train(train_loader:DataLoader,net,epochs=5,total_iteration_limit=None):
    cross_el=nn.CrossEntropyLoss() ## defining the loss
    optimizer=torch.optim.Adam(net.parameters(),lr=0.001) ## defining the optimizer

    total_iterations=0

    for epoch in range(epochs):
        net.train()

        loss_sum=0
        num_iterations=0

        data_iterator=tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iteration_limit is not None:
            data_iterator.total=total_iteration_limit

        for data in data_iterator:
            num_iterations+=1
            total_iterations+=1

            x,y=data
            x=x.to(device)
            y=y.to(device)
            optimizer.zero_grad()

            output=net(x.view(-1,28*28))  ## i think we should be able to do it directly as well --> as it is changing the dimensions in the forward method
            loss=cross_el(output,y)

            loss_sum+=loss.item()
            avg_loss=loss_sum/num_iterations

            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iteration_limit is not None and total_iterations>=total_iteration_limit:
                return


train(train_loader,net,epochs=1)            



Epoch 1: 100%|██████████| 6000/6000 [00:38<00:00, 156.75it/s, loss=0.236]


## keeping a copy of original weights(cloning them) to check later for modifications

In [26]:
original_weights={}
for name,param in net.named_parameters():
    original_weights[name]=param.clone().detach()

## Testing the performance

In [28]:
def test():
    correct=0
    total=0

    wrong_counts=[0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader,desc="Testing"):
            x,y=data
            x=x.to(device)
            y=y.to(device)
            output=net(x.view(-1,28*28))
            
            for idx, i in enumerate(output):
                if torch.argmax(i)==y[idx]: ## matching the indices
                    correct+=1
                else:
                    wrong_counts[y[idx]]+=1
                total+=1

    print(f"Accuracy: {round(correct/total,3)}")
    for i in range(len(wrong_counts)):
        print(f"wrong count for {i} is {wrong_counts[i]}")                

test()



Testing: 100%|██████████| 1000/1000 [00:03<00:00, 253.73it/s]

Accuracy: 0.954
wrong count for 0 is 18
wrong count for 1 is 30
wrong count for 2 is 52
wrong count for 3 is 119
wrong count for 4 is 25
wrong count for 5 is 11
wrong count for 6 is 57
wrong count for 7 is 56
wrong count for 8 is 18
wrong count for 9 is 73





### We observe that accuracy of digit 3 is not very good


### checking total number of parameters in the original layers first

In [64]:
### Total parameter count including all layers
total_parameters_original=0
for index, layer in enumerate([net.linear_1,net.linear_2,net.linear_3]):
    total_parameters_original+=layer.weight.nelement() + layer.bias.nelement()
    print(f"Layer: {index+1}, W: {layer.weight.nelement()} + B: {layer.bias.nelement()}")
print(f"Total number of parameters are:",total_parameters_original)    

## Can also directly do by parameter count directly from the model
params_count=0

for params in net.parameters(): 
    params_count+=params.numel()

print(params_count)    

Layer: 1, W: 784000 + B: 1000
Layer: 2, W: 2000000 + B: 2000
Layer: 3, W: 20000 + B: 10
Total number of parameters are: 2807010
2813804


### Defining the LORA Parametrization from the paper

![alt text](image.png)

In [37]:
class LoRAParametrization(nn.Module):
    def __init__(self,features_in: int, features_out:int, rank:int=1,alpha:int =1,device="cpu"):
        super().__init__()
        ## section 4.1 of the paper
        ## random gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning
        self.lora_A=nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B=nn.Parameter(torch.zeros((features_in,rank)).to(device))
        nn.init.normal_(self.lora_A,mean=0,std=1) ## random gaussian initialization

        ## Section 4.1 of the paper:
        # Scaling ∆W by α/r , where α is a constant --> this will basically help preventing instability and large updates in training
        # when optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately
        # As a result, we just simply set α to first r and do not tune it
        # This scaling will also help reduce the effect when we increase r as its (α/r) , so reduce the need to retune hyperparameters as r is varied.

        self.scale=alpha/rank

        self.enabled=True

    def forward(self,original_weights):
        if self.enabled:
            ## return X+(B@A)*scale
            return original_weights+(self.lora_B @ self.lora_A).view(original_weights.shape)*self.scale
            ## can also do without view as it would not give error ---> return original_weights+torch.matmul(self.lora_B,self.lora_A).view(originak_weights.shape)*self.scale
        else:
            return original_weights    



### Using Pytorch Parametrization to inject/replace original weights with LoRA

In [38]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parametrization(layer, device,rank=1,lora_alpha=1):
    '''
    ** Only adding paramterization to the weights and leave the bias
    ** The study limits to only adapting weights to the downstream tasks, and freeze the MLP modules
    '''

    features_in, features_out=layer.weight.shape

    return LoRAParametrization(features_in, features_out,rank,lora_alpha,device)

'''
Registering LoRA now
'''

parametrize.register_parametrization(
    net.linear_1,"weight",linear_layer_parametrization(net.linear_1,device) ## this basically replaces net.linear1 with now LoRAParametrization() --> Tensor which has the same output
)

parametrize.register_parametrization(
    net.linear_2,"weight",linear_layer_parametrization(net.linear_2,device)
)

parametrize.register_parametrization(
    net.linear_3,"weight",linear_layer_parametrization(net.linear_3,device)
)


def enable_disable_lora(enabled=True):
    for layer in [net.linear_1,net.linear_2,net.linear_3]:
        layer.parametrizations["weight"][0].enabled=enabled



### Displaying total number of parameters added by LoRA

In [49]:
print(net.linear_1.weight)
print(net.linear_1.parametrizations["weight"][0].lora_A.shape)

tensor([[ 0.0262,  0.0457, -0.0029,  ...,  0.0485,  0.0302,  0.0286],
        [ 0.0026,  0.0074,  0.0120,  ...,  0.0021,  0.0164, -0.0075],
        [ 0.0110,  0.0461, -0.0021,  ...,  0.0109,  0.0324,  0.0392],
        ...,
        [ 0.0099,  0.0735,  0.0718,  ...,  0.0407,  0.0669,  0.0143],
        [ 0.0531,  0.0164,  0.0022,  ...,  0.0365,  0.0300,  0.0308],
        [ 0.0151, -0.0086,  0.0426,  ...,  0.0519,  0.0254,  0.0358]],
       device='mps:0', grad_fn=<AddBackward0>)
torch.Size([1, 784])


In [68]:
total_parameters_lora=0
total_parameters_non_lora=0
for index,layer in enumerate([net.linear_1,net.linear_2,net.linear_3]):
    total_parameters_lora+=layer.parametrizations['weight'][0].lora_A.nelement() + layer.parametrizations['weight'][0].lora_B.nelement()
    total_parameters_non_lora+=layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer:{index+1}: W: {layer.weight.shape}+ B:{layer.bias.shape}+ lora_A:{layer.parametrizations['weight'][0].lora_A.shape} + lora_B: {layer.parametrizations['weight'][0].lora_B.shape} '
    )     


## Non Lora parameters should be equal to original network
assert total_parameters_non_lora== total_parameters_original

print(f"Total number of parameters (original):",total_parameters_non_lora)
print(f"Total number of parameters (original+LoRA):",total_parameters_lora+total_parameters_non_lora)
print(f" number of parameters induced by LoRA):",total_parameters_lora)
parameters_increment=(total_parameters_lora/total_parameters_non_lora)*100

print(f"Parameters increment:{parameters_increment:.3f}%")



Layer:1: W: torch.Size([1000, 784])+ B:torch.Size([1000])+ lora_A:torch.Size([1, 784]) + lora_B: torch.Size([1000, 1]) 
Layer:2: W: torch.Size([2000, 1000])+ B:torch.Size([2000])+ lora_A:torch.Size([1, 1000]) + lora_B: torch.Size([2000, 1]) 
Layer:3: W: torch.Size([10, 2000])+ B:torch.Size([10])+ lora_A:torch.Size([1, 2000]) + lora_B: torch.Size([10, 1]) 
Total number of parameters (original): 2807010
Total number of parameters (original+LoRA): 2813804
 number of parameters induced by LoRA): 6794
Parameters increment:0.242%


In [69]:
# params_count=0
# for params in net.parameters(): 
#     params_count+=params.numel()

# print(params_count)    

### Now fine tuning only LORA weights for digit 3

In [73]:
 ## freezing non-lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        param.requires_grad=False
        print(f"freezing non LoRA paramater:{name}")

## Loading MNIST dataset again, by keeping only the digit 9
mnist_trainset=datasets.MNIST(root="/Users/praroopchanda/Desktop/Models_Coding_Practice/LoRA",train=True,transform=transform,download=True)
exlude_indices=mnist_trainset.targets==3
mnist_trainset.data=mnist_trainset.data[exlude_indices]
mnist_trainset.targets=mnist_trainset.targets[exlude_indices]

## create a dataloader for training
train_finetune_loader=torch.utils.data.DataLoader(mnist_trainset,batch_size=10, shuffle=True)

## Train/Fine tune the network now using LoRA and doing it for just 100 batches assumption being it would improve performance
train(train_finetune_loader,net,epochs=1,total_iteration_limit=100)

freezing non LoRA paramater:linear_1.bias
freezing non LoRA paramater:linear_1.parametrizations.weight.original
freezing non LoRA paramater:linear_2.bias
freezing non LoRA paramater:linear_2.parametrizations.weight.original
freezing non LoRA paramater:linear_3.bias
freezing non LoRA paramater:linear_3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:02<00:00, 46.30it/s, loss=0.164]


### Verifying that we only changed the LoRA parameters and not the original ones

In [94]:
### comparing the frozen paramters to orignal weights 
assert torch.all(net.linear_1.parametrizations.weight.original==original_weights['linear_1.weight'])
assert torch.all(net.linear_2.parametrizations.weight.original==original_weights['linear_2.weight'])
assert torch.all(net.linear_3.parametrizations.weight.original==original_weights['linear_3.weight'])

## now comparing the paramters of full lora and doing matrix multiplication by ourselves
## should be same
enable_disable_lora(enabled=True)

assert torch.equal(net.linear_1.weight,net.linear_1.parametrizations.weight.original+(net.linear_1.parametrizations.weight[0].lora_B @ net.linear_1.parametrizations.weight[0].lora_A)** net.linear_1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
assert torch.equal(net.linear_1.weight,original_weights['linear_1.weight'])



### Testing the network now (should be a better classifier on digit 3)

In [95]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:06<00:00, 161.49it/s]

Accuracy: 0.926
wrong count for 0 is 29
wrong count for 1 is 38
wrong count for 2 is 135
wrong count for 3 is 22
wrong count for 4 is 27
wrong count for 5 is 47
wrong count for 6 is 64
wrong count for 7 is 141
wrong count for 8 is 143
wrong count for 9 is 89





In [96]:
# Test with LoRA disabled (Same accuracy as before)
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 233.80it/s]

Accuracy: 0.954
wrong count for 0 is 18
wrong count for 1 is 30
wrong count for 2 is 52
wrong count for 3 is 119
wrong count for 4 is 25
wrong count for 5 is 11
wrong count for 6 is 57
wrong count for 7 is 56
wrong count for 8 is 18
wrong count for 9 is 73



