LoRA: Low-Rank Adaptation of Large Language Models — https://arxiv.org/abs/2106.09685 (paper) <br/>
https://github.com/microsoft/LoRA/blob/main/loralib/ (reference code) <br/>

In [4]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from typing import Optional, List

In [115]:
# LoRA layer: the basic idea is to keep the weight changes "dw" learned by back-propagation during
# fine-tuning separate from the main pre-trained weights.
# "dw" is further decomposed into two smaller matrices of rank "r"; this low-rank assumption is the
# primary assumption of LoRA.
# This formulation cleanly separates the fine-tuning weight changes from the main pre-trained
# weights, and, if required, the two can be combined (merged) at inference time.
# These primary operations are captured in the LoRA layer; specific layers can inherit from it
# with modifications to accommodate this change.

class LoRALayer:
    """Mixin holding the LoRA hyper-parameters shared by the adapted layers.

    Intended to be combined with a ``torch.nn`` layer (e.g. ``nn.Embedding`` or
    ``nn.Linear``) whose ``__init__`` has already run, so that the module
    attribute assigned here (``lora_dropout``) is registered as a sub-module.

    Args:
        r: rank of the low-rank update ``B @ A``.
        lora_alpha: numerator of the update scaling (the adapted layers use
            ``lora_alpha / r``).
        lora_dropout: dropout probability applied to the LoRA branch input;
            a value ``<= 0`` disables dropout.
        merge_weights: whether the adapted layer should fold the LoRA update
            into its frozen weight when switched to eval mode.
    """

    def __init__(self, r, lora_alpha, lora_dropout, merge_weights):
        self.r = r
        self.lora_alpha = lora_alpha

        # nn.Identity rather than a lambda: a lambda attribute cannot be pickled,
        # which makes pickling the object (e.g. torch.save of a whole model) fail.
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity()

        # True once the LoRA update has been added into the base weight.
        self.merged = False
        self.merge_weights = merge_weights

In [119]:
# Embedding layer: a normal "nn.Embedding" layer takes vocabulary indices and returns the
# corresponding embeddings.
# This is the same nn.Embedding layer with custom changes to accommodate the LoRA layer.


class Embedding(nn.Embedding, LoRALayer):
    """``nn.Embedding`` with a trainable low-rank (LoRA) update.

    The frozen pre-trained table ``weight`` (num_embeddings x embedding_dim) is
    augmented by ``(lora_B @ lora_A).T * scaling``, where ``lora_A`` is
    (r x num_embeddings), ``lora_B`` is (embedding_dim x r) and
    ``scaling = lora_alpha / r``. When ``r > 0`` only the LoRA factors are trained.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of each embedding vector.
        r: rank of the update; ``r <= 0`` disables LoRA and the layer behaves
            like a plain ``nn.Embedding``.
        lora_alpha: numerator of ``scaling``.
        merge_weights: if True, ``eval()`` adds the update into ``weight`` and
            ``train()`` subtracts it again.
        **kwargs: forwarded to ``nn.Embedding`` (``padding_idx``, ``max_norm``, ...).
    """

    def __init__(self, num_embeddings, embedding_dim, r=0, lora_alpha=1, merge_weights=True, **kwargs):
        nn.Embedding.__init__(self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, **kwargs)
        # LoRA dropout is disabled for the embedding lookup (lora_dropout=0).
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0, merge_weights=merge_weights)

        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freeze the pre-trained table; only lora_A / lora_B are trained.
            self.weight.requires_grad = False

        # nn.Embedding.__init__ already called reset_parameters() before the LoRA
        # factors existed; call it again so they get initialized too.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize ``weight`` and, once they exist, the LoRA factors.

        ``nn.Embedding.__init__`` invokes this before ``lora_A``/``lora_B`` are
        created, hence the ``hasattr`` guard.
        """
        nn.Embedding.reset_parameters(self)

        if hasattr(self, 'lora_A'):
            # One factor random, the other zero: B @ A == 0, so the adapted layer
            # starts out identical to the pre-trained one.
            # NOTE: the author reports the reference repo's version of this
            # initialization as a mistake (pull request raised upstream).
            nn.init.normal_(self.lora_A)
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        """Set train/eval mode, un-merging or merging the LoRA update as needed.

        With ``merge_weights`` enabled, switching to eval mode adds the scaled
        update into ``weight`` (inference then needs a single lookup) and
        switching back to train mode subtracts it again.

        Returns:
            self, like ``nn.Module.train`` (so ``layer.eval()`` can be chained).
        """
        nn.Embedding.train(self, mode)

        if mode:
            if self.merge_weights and self.merged:
                # Remove the update so training sees the pure pre-trained table.
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Fold the update into the table for inference.
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True
        return self

    def forward(self, x):
        """Look up embeddings for the index tensor ``x``.

        When LoRA is active and not merged, adds the low-rank update
        ``lookup(x, lora_A.T) @ lora_B.T * scaling`` to the main lookup.
        """
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)

            # Arguments forwarded to F.embedding so the LoRA lookup matches the main table:
            #   x: tensor of vocabulary indices.
            #   padding_idx: index treated as padding (shorter sequences can be padded to a
            #     common length); entries at this index do not contribute to the gradient.
            #   max_norm: maximum allowed norm of an embedding vector; vectors with a larger
            #     norm are rescaled to max_norm. NOTE(review): per the F.embedding docs this
            #     renormalizes the passed weight in place, i.e. a view of lora_A here.
            #   norm_type: the p of the p-norm used for max_norm (e.g. 2 or inf).
            #   scale_grad_by_freq: if True, scale gradients by the inverse frequency of each
            #     index in the mini-batch, so a few frequent words do not dominate learning.
            #   sparse: if True, the gradient w.r.t. the weight is a sparse tensor
            #     (torch.sparse stores sparse tensors efficiently).
            after_A = F.embedding(x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.sparse)

            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result

        return nn.Embedding.forward(self, x)
     

In [149]:
# lora layer modification on "nn.Linear" layer

class Linear(nn.Linear, LoRALayer):
    """``nn.Linear`` with a trainable low-rank (LoRA) update.

    Computes ``F.linear(x, W, b) + dropout(x) @ A.T @ B.T * scaling``, where the
    pre-trained ``W`` is frozen and only ``A`` (``lora_A``, r x in_features) and
    ``B`` (``lora_B``, out_features x r) are trained; ``scaling = lora_alpha / r``.

    Args:
        in_features, out_features: as for ``nn.Linear``.
        r: rank of the update; ``r <= 0`` disables LoRA.
        lora_alpha: numerator of ``scaling``.
        lora_dropout: dropout probability applied to the LoRA branch input.
        fan_in_fan_out: set True when the weight is stored as
            (in_features, out_features) instead of nn.Linear's
            (out_features, in_features). ``weight`` is then kept in that
            transposed layout and transposed back on use; the input ``x`` is
            still (..., in_features).
        merge_weights: if True, ``eval()`` adds the update into ``weight`` and
            ``train()`` subtracts it again.
        **kwargs: forwarded to ``nn.Linear`` (e.g. ``bias``).
    """

    def __init__(self, in_features, out_features, r=0, lora_alpha=1, lora_dropout=0.,
                 fan_in_fan_out=False, merge_weights=True, **kwargs):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r, lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout, merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out

        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freeze the pre-trained weight; only lora_A / lora_B are trained.
            self.weight.requires_grad = False

        # nn.Linear.__init__ already called reset_parameters() before the LoRA
        # factors existed; call it again so they get initialized too.
        self.reset_parameters()

        # Keep the weight in the (in_features, out_features) layout. Only the stored
        # layout changes; forward() and train() transpose it back where needed.
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _maybe_t(self, w):
        """Transpose ``w`` when ``fan_in_fan_out`` is set (converts between the
        stored weight layout and the (out_features, in_features) layout)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def reset_parameters(self):
        """Re-initialize ``weight``/``bias`` and, once they exist, the LoRA factors.

        ``nn.Linear.__init__`` invokes this before ``lora_A``/``lora_B`` are
        created, hence the ``hasattr`` guard.
        """
        nn.Linear.reset_parameters(self)

        if hasattr(self, 'lora_A'):
            # Same scheme nn.Linear uses for its own weight: Kaiming-uniform keeps the
            # variance under control; `a` is the negative slope assumed for the
            # following non-linearity.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            # B = 0 makes the initial update B @ A zero.
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        """Set train/eval mode, un-merging or merging the LoRA update as needed.

        Returns:
            self, like ``nn.Module.train`` (so ``layer.eval()`` can be chained).
        """
        nn.Linear.train(self, mode)

        if mode:
            # Before training, if already merged, remove the LoRA update.
            if self.merge_weights and self.merged:
                if self.r > 0:
                    self.weight.data -= self._maybe_t(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            # For inference, fold the LoRA update into the weight.
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    self.weight.data += self._maybe_t(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
        return self

    def forward(self, x):
        """Apply the layer to ``x`` of shape (..., in_features).

        When LoRA is active and not merged, the low-rank branch
        ``dropout(x) @ A.T @ B.T * scaling`` is added to the frozen projection.
        """
        result = F.linear(x, self._maybe_t(self.weight), bias=self.bias)

        if self.r > 0 and not self.merged:
            lora_out = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
            result += lora_out * self.scaling

        return result