In [None]:
# Why LoRA  ?
# well we dont want to fine tune a whole model just for a simple specific task
# we want an adapter that attaches to the model to add specificity


# instead of backpropagating through the whole model , for each / some weight
#  matrices (D , K)  , we attach small two matrices (D , r) , (r , K)
# such that their multiplication gives rise to a matrix (D , K)
#  r << min(D,k)

# notice that sizeof(mat(1000,1000)) >> sizeof(mat(1000,10)) + sizeof(mat(10,1000))
# 1,000,000 >> 10,000 + 10,000
# by matmul these two matrices we get a matrix with the same size as the original
# but it is less rich in information (high in redundant information ==> lower entropy)

# we want those two small matrices to hold enough information about the specificity of
#  the task while utilizing the general information from the model weights
# we are adding `specificity` params

# thus we can use the same big / `general` weights for many adapters (different specific
# tasks)


In [None]:
# pretrained models usually have low intrinsic rank , meaning we
# can represent the same information in a smaller space
# for example using the singular value  decomposition
# we get that w = u @ s @ v_t ,
# let us define the two matrices :
# A = v , B = u @ s
# and we notice that the w tranformation on the vector x
# is identical to the (B @ A) transformation on the vector x


In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader

# reproducable results (determinism)
torch.manual_seed(1)

<torch._C.Generator at 0x7f75117e19d0>

In [4]:


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307  ),(0.3081  ))
])

mnist_training = datasets.MNIST(root='./data',train=True,
                                download=True,transform=transform)
train_loader = DataLoader(mnist_training,batch_size = 8,shuffle=True)


mnist_testing = datasets.MNIST(root='./data',train=False,
                              download=True,transform=transform)
test_loader = DataLoader(mnist_testing,batch_size=8,shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')




class OverTheMoon(nn.Module):
    def __init__(self,hidden_size1=1000,hidden_size2=1000):
        super().__init__()

        self.linear1 = nn.Linear(28*28,hidden_size1)
        self.linear2 = nn.Linear(hidden_size1,hidden_size2)
        self.linear3 = nn.Linear(hidden_size2,10)

        self.relu = nn.ReLU()

    def forward(self,img):
        # we flatten the image
        x = img.view(-1,28*28)

        x = self.linear1(x)
        x = self.relu(x)

        x = self.linear2(x)
        x = self.relu(x)

        x = self.linear3(x)
        # x = self.relu(x)

        return x



model = OverTheMoon().to(device)





In [10]:
def train(model, dataloader,epochs=10,total_iterations_bound = None):

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

    total_iters = 0
    for epoch in range(epochs) :
        model.train()
        data_iterator = tqdm(dataloader,desc=f'Epoch : {epoch+1}/{epochs}')


        total_loss = 0
        num_iters = 0

        if total_iterations_bound is not None :
            data_iterator.total = total_iterations_bound

        for batch in data_iterator :
            num_iters += 1
            total_iters +=1
            x,y = batch

            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()

            output = model(x)
            loss = loss_fn(output,y)
            total_loss += loss.item()

            avg_loss = total_loss/num_iters

            data_iterator.set_postfix(loss=avg_loss)

            loss.backward()
            optimizer.step()


            if total_iterations_bound is not None and total_iters >= total_iterations_bound :
                return




train(model, train_loader ,epochs=3)



Epoch : 1/3: 100%|██████████| 7500/7500 [00:36<00:00, 205.79it/s, loss=0.0748]
Epoch : 2/3: 100%|██████████| 7500/7500 [00:36<00:00, 203.59it/s, loss=0.0613]
Epoch : 3/3: 100%|██████████| 7500/7500 [00:35<00:00, 210.32it/s, loss=0.0614]


In [11]:
# we stash the original pre-trained weights

original_weights = {}

for name , param in model.named_parameters():
    original_weights[name] = param.clone().detach()


print('Pre-Trained Weights Stashed')

Pre-Trained Weights Stashed


In [12]:
def test(model,dataloader):
    correct = 0
    total = 0
    miss_classifications = [ 0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(dataloader,desc='Test'):
            x,y = data
            x = x.to(device)
            y = y.to(device)

            output = model(x.view(-1,784))

            for idx , i in enumerate(output):
                # if the highest logit correspond to the correct label
                if torch.argmax(i)==y[idx]:
                    correct += 1
                else :
                    miss_classifications[y[idx]] += 1

                total += 1

    print(f'Accuracy : {round(correct/total,3)}')
    for i in range(10):
        print(f'Miss-classifications of the digit: {i} : {miss_classifications[i]}')







test(model,test_loader)


Test: 100%|██████████| 1250/1250 [00:02<00:00, 424.85it/s]

Accuracy : 0.974
Miss-classifications of the digit: 0 : 12
Miss-classifications of the digit: 1 : 12
Miss-classifications of the digit: 2 : 29
Miss-classifications of the digit: 3 : 22
Miss-classifications of the digit: 4 : 33
Miss-classifications of the digit: 5 : 35
Miss-classifications of the digit: 6 : 16
Miss-classifications of the digit: 7 : 45
Miss-classifications of the digit: 8 : 21
Miss-classifications of the digit: 9 : 40





In [17]:
print(model)
!pip install torchinfo
from torchinfo import summary

summary(model,input=(8,1,28,28))

OverTheMoon(
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (linear2): Linear(in_features=1000, out_features=1000, bias=True)
  (linear3): Linear(in_features=1000, out_features=10, bias=True)
  (relu): ReLU()
)
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


Layer (type:depth-idx)                   Param #
OverTheMoon                              --
├─Linear: 1-1                            785,000
├─Linear: 1-2                            1,001,000
├─Linear: 1-3                            10,010
├─ReLU: 1-4                              --
Total params: 1,796,010
Trainable params: 1,796,010
Non-trainable params: 0

In [23]:
class LoRAParametrization(nn.Module):
    def __init__(self,D,K,rank=1,alpha=1,device='cpu'):
        super().__init__()

        #  in the paper they init the lora_A weights with
        # random Gaussian
        self.lora_A = nn.Parameter(torch.zeros((rank,K)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((D,rank)).to(device))


        nn.init.normal_(self.lora_A,mean=0,std=1)

        # HOW DID U MISS THIS LMAO
        # simply adding the LoRA produced matrix to the initial matrix
        # is not a controlled strategy to determine how much
        # the LoRA contributes to the whole model ,
        # we first normalize the with respect to the rank
        #
        # Because 𝐴@𝐵 can produce large values (especially if rank is small),
        # we scale it down by rank to normalize its magnitude.
        # then we introduce the hyperparameter alpha to control
        # how much the LoRA contribution is amplified / down-weighted


        self.scale = alpha / rank
        self.enabled = True



    def forward(self,original_weights):
        # if enabled we apply on the weights only and not the biases
        if self.enabled :
            return original_weights  + torch.matmul(self.lora_B,self.lora_A).view(original_weights.shape)*self.scale

        else  :
            return original_weights






In [24]:
# NOW WE NEED TO EMBEDD THAT PARMETERIZATION  TO OUR MODEL
import torch.nn.utils.parametrize as parametrize


def linear_layer_parameterization(layer,device,rank=1,alpha=1):
    D,K = layer.weight.shape



    return LoRAParametrization(D,K,rank,alpha,device)




parametrize.register_parametrization(
    model.linear1 , "weight",linear_layer_parameterization(model.linear1,device)
)

parametrize.register_parametrization(
    model.linear2,"weight",linear_layer_parameterization(model.linear2,device)
)

parametrize.register_parametrization(
    model.linear3,"weight",linear_layer_parameterization(model.linear3,device)
)





def enable_disable_lora(enabled=True):
    for layer in [model.linear1,model.linear2,model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled




In [27]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear1,model.linear2,model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement()+ layer.parametrizations['weight'][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement()+layer.bias.nelement()

    print(f"""
        Layer {index+1} : W: {layer.weight.shape} + B : {layer.bias.shape}
                          lora_A :{layer.parametrizations['weight'][0].lora_A.shape} + lora_B : {layer.parametrizations['weight'][0].lora_B.shape}
    """)




print(f'Total Params with Lora : {total_parameters_lora}')

print(f'Total Params non Lora : {total_parameters_non_lora}')


Params_Ratio = ((total_parameters_lora + total_parameters_non_lora) /total_parameters_non_lora ) * 100

print(f'Parameters Ratio : {1/Params_Ratio}')



        Layer 1 : W: torch.Size([1000, 784]) + B : torch.Size([1000])
                          lora_A :torch.Size([1, 784]) + lora_B : torch.Size([1000, 1])
    

        Layer 2 : W: torch.Size([1000, 1000]) + B : torch.Size([1000])
                          lora_A :torch.Size([1, 1000]) + lora_B : torch.Size([1000, 1])
    

        Layer 3 : W: torch.Size([10, 1000]) + B : torch.Size([10])
                          lora_A :torch.Size([1, 1000]) + lora_B : torch.Size([10, 1])
    
Total Params with Lora : 4794
Total Params non Lora : 1796010
Parameters Ratio : 0.009973378557577614


In [29]:
# NOW WE MUST FREEZE THE PARAMS :
for name,param in model.named_parameters():
    if 'lora' not in name :
        print(f' Freezing NON-LoRA Param : {name}')
        param.requires_grad = False

 Freezing NON-LoRA Param : linear1.bias
 Freezing NON-LoRA Param : linear1.parametrizations.weight.original
 Freezing NON-LoRA Param : linear2.bias
 Freezing NON-LoRA Param : linear2.parametrizations.weight.original
 Freezing NON-LoRA Param : linear3.bias
 Freezing NON-LoRA Param : linear3.parametrizations.weight.original


In [31]:
# We notice that the model doesnt do well on numbers 7 and 9
# NOW WE FINETUNE
"""
Miss-classifications of the digit: 0 : 12
Miss-classifications of the digit: 1 : 12
Miss-classifications of the digit: 2 : 29
Miss-classifications of the digit: 3 : 22
Miss-classifications of the digit: 4 : 33
Miss-classifications of the digit: 5 : 35
Miss-classifications of the digit: 6 : 16
Miss-classifications of the digit: 7 : 45
Miss-classifications of the digit: 8 : 21
Miss-classifications of the digit: 9 : 40

"""

mnist_training = datasets.MNIST(root='./data',train=True,download=True,transform=transform)
exclude_indices = (mnist_training.targets == 7) | (mnist_training.targets == 9)
mnist_training.data = mnist_training.data[exclude_indices]
mnist_training.targets = mnist_training.targets[exclude_indices]




train_loader = DataLoader(mnist_training,batch_size=8,shuffle=True)


train(model,train_loader,epochs=10)


Epoch : 1/10: 100%|██████████| 1527/1527 [00:08<00:00, 180.10it/s, loss=0.0273]
Epoch : 2/10: 100%|██████████| 1527/1527 [00:08<00:00, 189.96it/s, loss=0.00774]
Epoch : 3/10: 100%|██████████| 1527/1527 [00:08<00:00, 186.46it/s, loss=0.00507]
Epoch : 4/10: 100%|██████████| 1527/1527 [00:08<00:00, 176.04it/s, loss=0.00286]
Epoch : 5/10: 100%|██████████| 1527/1527 [00:08<00:00, 188.85it/s, loss=0.00308]
Epoch : 6/10: 100%|██████████| 1527/1527 [00:08<00:00, 186.22it/s, loss=0.00255]
Epoch : 7/10: 100%|██████████| 1527/1527 [00:08<00:00, 182.19it/s, loss=0.00194]
Epoch : 8/10: 100%|██████████| 1527/1527 [00:08<00:00, 190.87it/s, loss=0.00166]
Epoch : 9/10: 100%|██████████| 1527/1527 [00:09<00:00, 158.65it/s, loss=0.0015]
Epoch : 10/10: 100%|██████████| 1527/1527 [00:10<00:00, 142.74it/s, loss=0.0015]


In [32]:
test(model,test_loader)

Test: 100%|██████████| 1250/1250 [00:03<00:00, 375.97it/s]

Accuracy : 0.739
Miss-classifications of the digit: 0 : 104
Miss-classifications of the digit: 1 : 123
Miss-classifications of the digit: 2 : 605
Miss-classifications of the digit: 3 : 306
Miss-classifications of the digit: 4 : 332
Miss-classifications of the digit: 5 : 163
Miss-classifications of the digit: 6 : 62
Miss-classifications of the digit: 7 : 8
Miss-classifications of the digit: 8 : 903
Miss-classifications of the digit: 9 : 6





In [33]:
# THAT LORA INTEGRATION ALSO ALTERS THE OUTPUT FOR OTHER CLASSES

enable_disable_lora(False)
test(model,test_loader)
# WE GOT THE SAME RESULTS BEFORE FINETUNNING THE LORA PARAMS


Test: 100%|██████████| 1250/1250 [00:03<00:00, 395.98it/s]

Accuracy : 0.974
Miss-classifications of the digit: 0 : 12
Miss-classifications of the digit: 1 : 12
Miss-classifications of the digit: 2 : 29
Miss-classifications of the digit: 3 : 22
Miss-classifications of the digit: 4 : 33
Miss-classifications of the digit: 5 : 35
Miss-classifications of the digit: 6 : 16
Miss-classifications of the digit: 7 : 45
Miss-classifications of the digit: 8 : 21
Miss-classifications of the digit: 9 : 40



