In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

Fix the random seed so that weight initialisation and data shuffling are reproducible across runs.

In [2]:
# Seed PyTorch's global random number generator; the return value is bound to `_`
# so the cell does not display it.
_ = torch.manual_seed(0)

In [3]:
# Per-image preprocessing: convert to a tensor, then normalise the single channel
# with mean 0.1307 and std 0.3081 (presumably the MNIST training-set statistics).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Both splits are read from ./data. With download=False the files must already be
# present there; pass download=True instead to let torchvision fetch them.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train, download=False, transform=transform)
    for is_train in (True, False)
)

# Mini-batches of 10 samples; only the training loader shuffles its data.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=10, shuffle=do_shuffle)
    for split, do_shuffle in ((mnist_trainset, True), (mnist_testset, False))
)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [4]:
class bigboy(nn.Module):
    """Three-layer fully connected classifier for 28x28 images.

    Layout: 784 -> hidden_size_1 -> hidden_size_2 -> 10, with a ReLU after each of
    the first two linear layers. The final layer's output is returned without any
    activation applied.

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        # Layers are created in forward order.
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Return the 10 output scores for every image in ``img``."""
        # Flatten each image into a row of 784 values; -1 lets PyTorch infer the
        # number of rows (the batch size) from the total element count.
        activations = img.view(-1, 28 * 28)

        for hidden_layer in (self.linear1, self.linear2):
            activations = self.relu(hidden_layer(activations))

        return self.linear3(activations)

# Build the network with its default hidden sizes and move it to the selected device.
net = bigboy().to(device)

In [5]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` with Adam (lr=0.001) and cross-entropy loss.

    Batches are moved to the module-level ``device`` before the forward pass. A
    tqdm progress bar is shown per epoch and displays the mean loss of the batches
    seen so far in that epoch.

    Args:
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        net: the module to optimise; it is put into training mode every epoch.
        epochs: number of passes over ``train_loader``.
        total_iterations_limit: optional cap on the total number of batches
            processed across all epochs. When given, it is also used as the
            progress bar's total, and the function returns as soon as the cap
            is reached.
    """
    loss_fn = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(net.parameters(), lr=0.001)

    has_limit = total_iterations_limit is not None
    steps_done = 0

    for epoch in range(epochs):
        net.train()

        running_loss = 0
        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if has_limit:
            progress.total = total_iterations_limit

        for batch_number, (inputs, labels) in enumerate(progress, start=1):
            steps_done += 1
            inputs = inputs.to(device)
            labels = labels.to(device)

            adam.zero_grad()
            logits = net(inputs.view(-1, 28*28))
            batch_loss = loss_fn(logits, labels)

            # Report the running mean of the loss for the current epoch
            running_loss += batch_loss.item()
            progress.set_postfix(loss=running_loss / batch_number)

            batch_loss.backward()
            adam.step()

            if has_limit and steps_done >= total_iterations_limit:
                return

# Run a single epoch over the full training loader (no iteration cap).
train(train_loader,net,epochs=1)



Epoch 1: 100%|██████████| 6000/6000 [00:12<00:00, 493.20it/s, loss=0.24] 


In [6]:
# Snapshot every named parameter as a detached copy, so that later changes to the
# network cannot alter the saved values.
original_weights = {
    param_name: tensor.clone().detach()
    for param_name, tensor in net.named_parameters()
}

print(original_weights)

{'linear1.weight': tensor([[ 0.0262,  0.0456, -0.0029,  ...,  0.0484,  0.0302,  0.0285],
        [ 0.0123,  0.0171,  0.0216,  ...,  0.0118,  0.0261,  0.0021],
        [ 0.0080,  0.0430, -0.0052,  ...,  0.0078,  0.0293,  0.0362],
        ...,
        [-0.0107,  0.0528,  0.0511,  ...,  0.0201,  0.0462, -0.0064],
        [ 0.0733,  0.0366,  0.0223,  ...,  0.0566,  0.0502,  0.0510],
        [ 0.0133, -0.0104,  0.0408,  ...,  0.0502,  0.0236,  0.0340]],
       device='cuda:0'), 'linear1.bias': tensor([-5.1073e-02, -1.0745e-02, -3.5749e-02, -3.4667e-02, -5.9566e-04,
        -4.2161e-02, -4.6397e-02, -3.9079e-02, -5.1368e-02, -7.2194e-03,
        -2.4266e-02, -3.8512e-02, -2.7861e-02, -2.7927e-02,  1.2791e-02,
        -2.0260e-02, -5.2526e-02, -2.6016e-02, -2.8698e-02, -5.0895e-02,
        -3.2971e-02, -7.8956e-02,  2.7680e-02, -4.4517e-02, -2.7908e-02,
        -4.5049e-02, -7.0258e-04, -2.2814e-02, -2.3985e-02, -1.6930e-02,
         2.4631e-03, -2.3111e-02, -2.1943e-02, -6.7954e-03, -2.5725e

In [8]:
def test():
    """Evaluate the module-level ``net`` on ``test_loader`` and print a summary.

    Prints the overall accuracy (rounded to three decimals) followed by, for each
    digit 0-9, how many samples whose true label is that digit were misclassified.
    Batches are moved to the module-level ``device`` before the forward pass.

    The network is switched to evaluation mode while scoring and its previous
    train/eval mode is restored afterwards. If the loader yields no samples the
    accuracy is printed as ``nan`` instead of raising ``ZeroDivisionError``.
    """
    correct = 0
    total = 0
    wrong_counts = [0] * 10

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for x, y in tqdm(test_loader, desc='Testing'):
                x = x.to(device)
                y = y.to(device)

                # Score the whole batch with tensor ops rather than one Python
                # comparison per sample.
                predicted = net(x.view(-1, 784)).argmax(dim=1)
                is_wrong = predicted != y

                num_wrong = int(is_wrong.sum())
                total += predicted.numel()
                correct += predicted.numel() - num_wrong

                # Tally the true labels of the misclassified samples.
                per_digit = torch.bincount(y[is_wrong], minlength=len(wrong_counts))
                for digit, count in enumerate(per_digit.tolist()):
                    wrong_counts[digit] += count
    finally:
        # Leave the network in whatever mode the caller had it in.
        net.train(was_training)

    accuracy = round(correct / total, 3) if total else float('nan')
    print(f'Accuracy: {accuracy}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

# Evaluate the network on the test loader and print accuracy plus per-digit error counts.
test()

    

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 685.48it/s]

Accuracy: 0.956
wrong counts for the digit 0: 16
wrong counts for the digit 1: 19
wrong counts for the digit 2: 55
wrong counts for the digit 3: 112
wrong counts for the digit 4: 20
wrong counts for the digit 5: 23
wrong counts for the digit 6: 36
wrong counts for the digit 7: 47
wrong counts for the digit 8: 18
wrong counts for the digit 9: 94





# Now let's calculate the total number of parameters

In [9]:
# Count the weights and biases of the three linear layers, printing each layer's
# tensor shapes along the way and the grand total at the end.
total_parameters_original = 0

for layer_number, linear in enumerate((net.linear1, net.linear2, net.linear3), start=1):
    weight, bias = linear.weight, linear.bias
    total_parameters_original += weight.numel() + bias.numel()
    print(f'Layer {layer_number}: W: {weight.shape} + B: {bias.shape}')

print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


Let's define the LoRA parametrization: https://pytorch.org/tutorials/intermediate/parametrizations.html

In [10]:
class LoRAParametrization(nn.Module):
    """Low-rank additive update for a weight matrix: ``W + (B @ A) * (alpha / rank)``.

    ``lora_B`` has shape ``(features_in, rank)`` and ``lora_A`` has shape
    ``(rank, features_out)``, so ``B @ A`` has ``features_in * features_out``
    elements and is reshaped to the shape of the weight it is added to.
    Note that ``features_in`` / ``features_out`` are simply the first and second
    dimensions of that weight, as passed by ``linear_layer_parameterization``.

    Args:
        features_in: first dimension of the weight being parametrized.
        features_out: second dimension of the weight being parametrized.
        rank: inner dimension of the low-rank factors.
        alpha: numerator of the scaling factor ``alpha / rank``.
        device: device on which the two factors are allocated.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()

        # LoRA paper, section 4.1: A starts as a random Gaussian and B as zeros,
        # which makes the update B @ A exactly zero before any training step.
        self.lora_A = nn.Parameter(torch.zeros(rank, features_out, device=device))
        self.lora_B = nn.Parameter(torch.zeros(features_in, rank, device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # LoRA paper, section 4.1: the update is scaled by alpha / r, which reduces
        # the need to retune hyperparameters when the rank is varied.
        self.scale = alpha / rank

        # When False, forward() passes the weight through untouched.
        self.enabled = True

    def forward(self, original_weights):
        """Return the adapted weight, or ``original_weights`` itself when disabled."""
        if not self.enabled:
            return original_weights

        low_rank_update = (self.lora_B @ self.lora_A).view(original_weights.shape)
        return original_weights + low_rank_update * self.scale

In [11]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    """Build a ``LoRAParametrization`` sized for ``layer.weight``.

    Only the weight matrix is adapted; the bias is left alone. Section 4.2 of the
    LoRA paper likewise leaves the empirical investigation of biases to future work.

    Args:
        layer: module whose two-dimensional ``weight`` will be parametrized.
        device: device on which the LoRA factors are allocated.
        rank: rank of the low-rank update.
        lora_alpha: scaling numerator forwarded as ``alpha``.

    Returns:
        A new ``LoRAParametrization`` whose update matches ``layer.weight`` in shape.
    """
    # The weight's two dimensions are forwarded in order as features_in, features_out.
    first_dim, second_dim = layer.weight.shape
    return LoRAParametrization(
        first_dim,
        second_dim,
        rank=rank,
        alpha=lora_alpha,
        device=device,
    )

# Attach a LoRA parametrization (default rank and alpha) to the weight of each
# linear layer, in forward order.
for layer in (net.linear1, net.linear2, net.linear3):
    parametrize.register_parametrization(
        layer, "weight", linear_layer_parameterization(layer, device)
    )


def enable_disable_lora(enabled=True):
    """Switch the LoRA contribution on or off for the three linear layers of ``net``.

    Sets the ``enabled`` attribute of the first parametrization registered on each
    layer's ``weight``.

    Args:
        enabled: value assigned to each parametrization's ``enabled`` attribute.
    """
    for layer in [net.linear1, net.linear2, net.linear3]:
        # Index 0 selects the first parametrization registered on "weight".
        layer.parametrizations["weight"][0].enabled = enabled