# https://pytorch.org/torchtune/stable/tutorials/lora_finetune.html
LoRA's memory savings come primarily from gradients and optimizer states, so if your model's peak memory occurs in its `forward()` method (i.e. activations dominate), LoRA may not reduce peak memory.

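The weight matrix of an `nn.Linear(in_dim, out_dim)` layer can have rank as high as `min(in_dim, out_dim)`. LoRA restricts the weight *update* to a much smaller rank $r \ll \min(\text{in\_dim}, \text{out\_dim})$.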

**The main idea:** Instead of updating the weights of a layer, freeze the layer, add a parallel low-rank adapter branch, and fine-tune only that branch.
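**But:** Finding a low-rank approximation of a matrix is itself an optimization problem. LoRA does not decompose the existing weights; it learns the low-rank update $\Delta W = BA$ directly by gradient descent.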

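What does it do? It creates a branch network with far fewer parameters and trains only that branch. Then it sums the frozen layer's outputs with the branch's (scaled) outputs. Training gets cheaper; inference does not (unless $BA$ is merged back into $W$, the extra branch even adds a little compute)!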

In [1]:
from torch import nn, Tensor

In [2]:
class LoRALinear(nn.Module):
    """Linear layer augmented with a trainable low-rank (LoRA) adapter.

    Computes ``linear(x) + (alpha / rank) * lora_b(lora_a(dropout(x)))``.
    ``linear`` is frozen; only ``lora_a`` and ``lora_b`` are trainable.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        rank: Rank of the low-rank update; must be positive.
        alpha: Scaling numerator; the adapter output is scaled by ``alpha / rank``.
        dropout: Dropout probability applied only to the adapter branch input.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int,
        alpha: float,
        dropout: float = 0.0,
    ):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        if rank <= 0:
            # forward() divides by rank, and a zero-width adapter is meaningless.
            raise ValueError(f"rank must be positive, got {rank}")

        # Base layer. It is freshly initialized here; pretrained weights are
        # expected to be loaded into ``self.linear`` separately.
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.linear = nn.Linear(self.in_dim, self.out_dim, bias=False)

        # -----------------LoRA----------------
        # LoRA hyperparameters
        self.rank = rank
        self.alpha = alpha
        self.dropout = dropout

        # Add the trainable adapter branch alongside the base layer.
        self._addLoraLayers()

        # Freeze the base layer; only the adapter is trained.
        self._prepWeightsForFinetuning()

    def _addLoraLayers(self):
        """Create factors A (in_dim -> rank) and B (rank -> out_dim) plus dropout."""
        self.lora_a = nn.Linear(self.in_dim, self.rank, bias=False)
        self.lora_b = nn.Linear(self.rank, self.out_dim, bias=False)
        # Standard LoRA init: A keeps nn.Linear's default random init, while B
        # starts at zero. Then B @ A == 0, and before any training the layer
        # reproduces the base layer's output exactly.
        nn.init.zeros_(self.lora_b.weight)
        # Following the LoRA convention, dropout is applied only to the
        # adapter branch input, not to the frozen path.
        self.lora_dropout = nn.Dropout(p=self.dropout)

    def _prepWeightsForFinetuning(self):
        """Freeze the base weight and mark the adapter weights as trainable."""
        self.linear.weight.requires_grad_(False)
        self.lora_a.weight.requires_grad_(True)
        self.lora_b.weight.requires_grad_(True)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``linear(x) + (alpha / rank) * lora_b(lora_a(dropout(x)))``.

        The last dimension of ``x`` must be ``in_dim``. Leading dimensions are
        passed through, as with ``nn.Linear``.
        """
        frozen_out = self.linear(x)
        lora_out = self.lora_b(self.lora_a(self.lora_dropout(x)))
        return frozen_out + (self.alpha / self.rank) * lora_out
    
        

In [3]:
# Get the Llama2
from  torchtune.models.llama2 import llama2_7b, lora_llama2_7b

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'torch._higher_order_ops'