https://github.com/microsoft/LoRA/blob/main/loralib/ (ref)

In [4]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from typing import Optional, List

In [70]:
# todo: implement custom dropout

class LoRALayer:
    """Mixin that stores the LoRA hyper-parameters shared by adapted layers.

    Attributes set on the instance:
        r: Rank value passed by the caller.
        lora_alpha: Alpha value passed by the caller.
        lora_dropout: Callable applied to inputs; an ``nn.Dropout`` when the
            requested probability is positive, otherwise an identity function.
        merged: Flag initialised to ``False``; subclasses use it to track
            whether the low-rank update is currently folded into the weight.
        merge_weights: Merge policy flag passed by the caller.
    """

    def __init__(self, r, lora_alpha, lora_dropout, merge_weights):
        # Plain hyper-parameters are stored verbatim.
        self.r, self.lora_alpha = r, lora_alpha
        self.merge_weights = merge_weights

        # A freshly built layer never starts in the merged state.
        self.merged = False

        # Only pay for a real dropout module when dropout is requested;
        # otherwise fall back to a no-op pass-through.
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if lora_dropout > .0 else (lambda x: x)
        )

In [109]:
class Embedding(nn.Embedding, LoRALayer):
    """``nn.Embedding`` with an optional trainable low-rank (LoRA) update.

    When ``r > 0`` the base ``weight`` is frozen and two trainable factors
    are created: ``lora_A`` with shape ``(r, num_embeddings)`` and ``lora_B``
    with shape ``(embedding_dim, r)``.  The update added to the base lookup
    is ``(lora_B @ lora_A).T * (lora_alpha / r)``.

    Args:
        num_embeddings: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        r: Rank of the low-rank update; no LoRA parameters are created
            unless ``r > 0``.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        merge_weights: If true, the update is folded into ``weight`` when
            the module is switched to eval mode and removed again when it
            is switched back to training mode.
        **kwargs: Forwarded unchanged to ``nn.Embedding.__init__``.
    """

    def __init__(self, num_embeddings, embedding_dim, r, lora_alpha, merge_weights, **kwargs):
        nn.Embedding.__init__(self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, **kwargs)
        # Dropout is not used for the embedding variant, hence lora_dropout=0.
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0, merge_weights=merge_weights)

        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r

            # Only the low-rank factors are trained; the base table is frozen.
            self.weight.requires_grad = False

        # Re-run initialisation now that the LoRA factors (if any) exist.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the base table and, when present, the LoRA factors."""
        nn.Embedding.reset_parameters(self)

        # This method also runs from nn.Embedding.__init__, i.e. before the
        # LoRA factors are created, so their presence has to be checked.
        if hasattr(self, 'lora_A'):
            # lora_A is random and lora_B is zero, so the product (and thus
            # the initial update to the embedding table) is zero.
            nn.init.normal_(self.lora_A)
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        """Set training/eval mode and merge or un-merge the LoRA update.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract so that
            ``module.train()`` / ``module.eval()`` can be chained or assigned.
        """
        nn.Embedding.train(self, mode)

        if self.merge_weights:
            if mode and self.merged:
                # Back to training: remove the folded-in update so that the
                # factors are applied (and trained) separately again.
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
            elif not mode and not self.merged and self.r > 0:
                # Entering eval: fold the update into the base table so the
                # forward pass is a single lookup.
                self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True

        return self

    def forward(self, x):
        """Look up embeddings for ``x`` and add the LoRA update if un-merged.

        Args:
            x: Tensor of indices into the embedding table.

        Returns:
            The base embedding lookup, plus the scaled low-rank term when
            ``r > 0`` and the update is not already merged into ``weight``.
        """
        result = nn.Embedding.forward(self, x)

        if self.r > 0 and not self.merged:
            # The remaining F.embedding arguments mirror the base layer's:
            # padding_idx: index treated as padding; shorter sequences can be
            #   padded with it to reach a common length.
            # max_norm: maximum allowed norm of an embedding vector; vectors
            #   with a larger norm are rescaled to have norm max_norm.
            # norm_type: which p-norm max_norm refers to (e.g. 2 or inf).
            # scale_grad_by_freq: if True, gradients are scaled by the inverse
            #   frequency of each index in the mini-batch, so that a few
            #   frequent words do not dominate learning.
            # sparse: if True, the gradient w.r.t. the table is a sparse tensor.
            after_A = F.embedding(
                x,
                self.lora_A.transpose(0, 1),
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )

            # Project the rank-r lookup up to embedding_dim and add it in.
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling

        return result
     