In [36]:
import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.nn as nn
from torch.nn import functional as Fn
from tqdm import tqdm


In [37]:
# Seed PyTorch's global random number generator so operations that draw from it
# are repeatable between runs. torch.manual_seed returns a torch.Generator, so
# `seed` is bound to that generator object, not to the integer 0.
seed = torch.manual_seed(0)


In [38]:
import os

import torch.utils
import torch.utils.data


# Where MNIST is stored/downloaded. Defined once instead of repeating a machine-specific
# absolute path; set the MNIST_ROOT environment variable to point at an existing copy,
# otherwise a "dataset" folder relative to the working directory is used.
DATA_ROOT = os.environ.get("MNIST_ROOT", "./dataset")

# Convert images to tensors, then normalise with the given mean and std.
# Both are one-element tuples (one value per channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

mnistTrain = dataset.MNIST(root=DATA_ROOT, train=True, transform=transform, download=True)
mnistTest = dataset.MNIST(root=DATA_ROOT, train=False, transform=transform, download=True)

trainLoader = torch.utils.data.DataLoader(mnistTrain, batch_size=32, shuffle=True)
testLoader = torch.utils.data.DataLoader(mnistTest, batch_size=32, shuffle=True)

# Everything in this notebook runs on the CPU.
device = torch.device("cpu")

In [39]:
class StrongCustomModel(nn.Module):
    """Fully connected classifier: 784 -> 264 -> 512 -> 10 with ReLU in between.

    The input is reshaped to rows of 28 * 28 values. The output is the raw
    result of the last linear layer (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order. The attribute names end up in the
        # state-dict keys, so they stay exactly as they were.
        self.linear1 = nn.Linear(28 * 28, 264)
        self.linear2 = nn.Linear(264, 512)
        self.linear3 = nn.Linear(512, 10)

    def forward(self, x):
        """Return one row of 10 values per flattened input row."""
        flat = x.view(-1, 28 * 28)
        hidden = Fn.relu(self.linear1(flat))
        hidden = Fn.relu(self.linear2(hidden))
        return self.linear3(hidden)


# Build the network and move its parameters to `device`.
customModel = StrongCustomModel().to(device)

In [40]:
def trainModel(trainLoader, model, epochs = 5, lr = 0.0001):
    """Train `model` in place with Adam and cross-entropy loss.

    Args:
        trainLoader: iterable yielding ``(inputs, labels)`` batches.
        model: module to optimise; it is called on inputs reshaped to
            ``(-1, 28 * 28)`` and its parameters are updated in place.
        epochs: number of full passes over `trainLoader`.
        lr: learning rate handed to Adam. The default is the value that was
            previously hard-coded, so existing calls behave the same.

    Returns:
        None. Each batch is moved to the module-level `device`, and the tqdm
        bar shows the running mean loss of the current epoch.
    """
    lossFn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()

        data_iterator = tqdm(trainLoader, desc=f'Epoch {epoch + 1}')
        lossSum = 0
        # Counting from 1 makes num_iteration the number of batches seen so far,
        # which is the divisor for the running mean below.
        for num_iteration, (x, y) in enumerate(data_iterator, start=1):
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()

            y_pred = model(x.view(-1, 28 * 28))
            loss = lossFn(y_pred, y)

            lossSum += loss.item()
            data_iterator.set_postfix(loss = lossSum / num_iteration)

            loss.backward()
            optimizer.step()


# Train customModel on the training loader using trainModel's default settings.
trainModel(trainLoader, customModel)

Epoch 1: 100%|██████████| 1875/1875 [00:07<00:00, 260.41it/s, loss=0.364]
Epoch 2: 100%|██████████| 1875/1875 [00:07<00:00, 248.19it/s, loss=0.156]
Epoch 3: 100%|██████████| 1875/1875 [00:06<00:00, 282.47it/s, loss=0.108]
Epoch 4: 100%|██████████| 1875/1875 [00:06<00:00, 287.91it/s, loss=0.0809]
Epoch 5: 100%|██████████| 1875/1875 [00:06<00:00, 292.66it/s, loss=0.0634]


In [41]:
# Snapshot of every parameter of customModel at this point, keyed by parameter name.
# Each tensor is detached and cloned: storing the Parameter objects themselves would
# only alias the live weights, so the "original" values would silently follow any
# later in-place update and a before/after comparison would always succeed.
originalModelWeight = {}
for name, parameter in customModel.named_parameters():
    originalModelWeight[name] = parameter.detach().clone()

def count_parameters(model):
    """Return the total number of elements across the parameters of `model`
    that have ``requires_grad`` set; parameters without it are not counted."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Bare last expression: the trainable-parameter count of customModel is shown as the cell output.
count_parameters(customModel)

348050

In [42]:
class LoRA(nn.Module):
    """Low-rank additive update of a weight matrix (LoRA), usable as a parametrization.

    For a weight ``W`` with ``input_feature`` rows and ``out_feature`` columns,
    ``forward(W)`` returns ``W + (B @ A) * scale`` where ``B`` has shape
    ``(input_feature, rank)`` and ``A`` has shape ``(rank, out_feature)``.

    Note on naming: the two size arguments describe the dimensions of the matrix
    being parametrized, in order. ``nn.Linear`` stores its weight as
    ``(out_features, in_features)``, so for such a layer ``input_feature`` is the
    layer's output size and ``out_feature`` is its input size.

    Args:
        input_feature: first dimension (rows) of the weight.
        out_feature: second dimension (columns) of the weight.
        rank: inner dimension ``r`` of the factorisation.
        alpha: numerator of the scaling factor ``alpha / rank``.
        device: device on which ``A`` and ``B`` are created.

    Attributes:
        Avector: trainable ``(rank, out_feature)`` matrix, Gaussian-initialised.
        Bvector: trainable ``(input_feature, rank)`` matrix, zero-initialised.
        scale: ``alpha / rank``.
        enabled: when False, ``forward`` returns the weight unchanged.
        enable: backward-compatible alias of ``enabled``.
    """

    def __init__(self, input_feature, out_feature, rank =1, alpha = 1, device = device):
        super().__init__()
        # Create the tensors directly on `device` and wrap them afterwards. Calling
        # `.to(device)` on an nn.Parameter returns a plain non-leaf tensor whenever a
        # copy is needed, which would leave A and B unregistered and untrainable.
        self.Avector = nn.Parameter(torch.zeros((rank, out_feature), device=device))
        self.Bvector = nn.Parameter(torch.zeros((input_feature, rank), device=device))

        # As in the LoRA paper: A starts as random Gaussian noise and B as zeros, so
        # B @ A is zero and the parametrized weight initially equals the original.
        # `normal_` is the in-place initialiser; `nn.init.normal` is deprecated.
        nn.init.normal_(self.Avector, mean=0, std=1)

        # Scaling by alpha / rank reduces the need to retune hyperparameters when
        # the rank is varied.
        self.scale = alpha / rank
        # Single source of truth for the on/off switch read by forward().
        self.enabled = True

    @property
    def enable(self):
        """Alias of ``enabled`` kept for code that used the earlier attribute name."""
        return self.enabled

    @enable.setter
    def enable(self, value):
        self.enabled = value

    def forward(self, originalModelWeight):
        """Return the weight plus the scaled low-rank update, or the weight
        unchanged when the adapter is disabled."""
        if not self.enabled:
            return originalModelWeight
        delta = torch.matmul(self.Bvector, self.Avector).view(originalModelWeight.shape)
        return originalModelWeight + delta * self.scale

    



In [43]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    """Create a LoRA module sized to match ``layer.weight``.

    The two dimensions of ``layer.weight`` are passed to ``LoRA`` in the order
    they appear in ``layer.weight.shape``, so the weight must be 2-D.
    ``lora_alpha`` is forwarded as ``alpha``. Only the weight is considered;
    the bias is not parametrized.
    """
    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.

    weight_rows, weight_cols = layer.weight.shape
    lora_module = LoRA(weight_rows, weight_cols, rank=rank, alpha=lora_alpha, device=device)
    return lora_module

# Attach a LoRA parametrization to the weight of each linear layer, in layer order.
# The is_parametrized guard makes this cell safe to re-run: without it, every extra
# execution would stack another LoRA module on the same weight.
for layer in (customModel.linear1, customModel.linear2, customModel.linear3):
    if not parametrize.is_parametrized(layer, "weight"):
        parametrize.register_parametrization(
            layer, "weight", linear_layer_parameterization(layer, device)
        )


def enable_disable_lora(enabled=True):
    """Set the ``enabled`` attribute on the first "weight" parametrization of
    customModel's three linear layers.

    NOTE(review): this only has an effect if the parametrization's forward reads
    an attribute with exactly this name; confirm against the parametrization class.
    """
    target_layers = (customModel.linear1, customModel.linear2, customModel.linear3)
    for linear in target_layers:
        first_parametrization = linear.parametrizations["weight"][0]
        first_parametrization.enabled = enabled

  nn.init.normal(self.Avector, mean = 0, std = 1)


In [45]:
# Report how many parameters the LoRA matrices add on top of the original layers.
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([customModel.linear1, customModel.linear2, customModel.linear3]):
    # First (and expected only) parametrization registered on this layer's weight.
    lora = layer.parametrizations["weight"][0]
    total_parameters_lora += lora.Avector.nelement() + lora.Bvector.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {lora.Avector.shape} + Lora_B: {lora.Bvector.shape}'
    )
# The non-LoRA parameter count must match the original network. originalModelWeight
# maps parameter names to tensors, so compare against its total element count
# (comparing the integer to the dict itself could never be true).
assert total_parameters_non_lora == sum(
    weight.nelement() for weight in originalModelWeight.values()
), "non-LoRA parameter count does not match the snapshot in originalModelWeight"
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
# Size of the LoRA addition as a percentage of the original parameter count.
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([264, 784]) + B: torch.Size([264]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([264, 1])
Layer 2: W: torch.Size([512, 264]) + B: torch.Size([512]) + Lora_A: torch.Size([1, 264]) + Lora_B: torch.Size([512, 1])
Layer 3: W: torch.Size([10, 512]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 512]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 348,050
Total number of parameters (original + LoRA): 350,396
Parameters introduced by LoRA: 2,346
Parameters incremment: 0.674%
