# https://pytorch.org/torchtune/stable/tutorials/lora_finetune.html
LoRA memory savings come primarily from gradient and optimizer states, so if your model’s peak memory comes in its forward() method, then LoRA may not reduce peak memory.

The weight matrix of an nn.Linear(in_dim, out_dim) layer can have rank as high as min(in_dim, out_dim).

**The main idea:** Instead of updating the weights of a layer, freeze the layer, add a new low-rank layer alongside it, and fine-tune only the new layer.
**But:** Low-rank approximation of a matrix is an optimization problem.

What does it do? It creates a branch network with far fewer parameters and trains only that branch, then sums the frozen layer's output with the branch's output. Training becomes cheaper; inference does not!

In [1]:
from torch import nn, Tensor

In [2]:
class LoRALinear(nn.Module):
    """Linear layer with a frozen base weight plus a trainable low-rank branch.

    The forward pass computes::

        linear(x) + (alpha / rank) * lora_b(lora_a(lora_dropout(x)))

    where ``linear`` is the frozen ``in_dim -> out_dim`` projection and
    ``lora_a`` / ``lora_b`` are the ``in_dim -> rank`` and ``rank -> out_dim``
    adapter projections. None of the three layers has a bias.

    Args:
        in_dim: Size of each input feature vector.
        out_dim: Size of each output feature vector.
        rank: Inner dimension of the low-rank branch; must be positive.
        alpha: Scaling numerator; the branch output is scaled by
            ``alpha / rank``.
        dropout: Dropout probability applied to the input of the LoRA branch
            only (the frozen path sees the raw input).

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int,
        alpha: float,
        dropout: float
    ):
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise the first `self.x = nn.Linear(...)` raises AttributeError.
        super().__init__()

        if rank <= 0:
            # forward() divides by rank, so reject it up front.
            raise ValueError(f"rank must be a positive integer, got {rank}")

        # original pretrained layer
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.linear = nn.Linear(
            self.in_dim,
            self.out_dim,
            bias=False
        )

        # -----------------LoRA----------------
        # new hyper parameters
        self.rank = rank
        self.alpha = alpha
        # kept on the instance so the helper below can read it
        self.dropout_p = dropout

        # new low-rank adapter branch
        self._addLoraLayers()

        # now freeze the original model params
        self._prepWeightsForFinetuning()

    def _addLoraLayers(self):
        """Create the low-rank adapter projections and their input dropout."""
        self.lora_a = nn.Linear(
            self.in_dim,
            self.rank,
            bias=False
        )
        # NOTE: lora_b keeps nn.Linear's default initialisation, so the
        # adapter branch is not guaranteed to output zeros at construction.
        self.lora_b = nn.Linear(
            self.rank,
            self.out_dim,
            bias=False
        )
        # follow the convention: dropout on the adapter input
        self.lora_dropout = nn.Dropout(p=self.dropout_p)

    def _prepWeightsForFinetuning(self):
        """Freeze the base weight and leave only the adapter weights trainable."""
        self.linear.weight.requires_grad = False
        self.lora_a.weight.requires_grad = True
        self.lora_b.weight.requires_grad = True

    def forward(self, x: Tensor) -> Tensor:
        """Return the frozen projection of ``x`` plus the scaled LoRA update."""
        frozen_out = self.linear(x)

        lora_out = self.lora_b(
            self.lora_a(
                self.lora_dropout(x)
            )
        )

        return frozen_out + (self.alpha / self.rank) * lora_out
    
        

## Optimizing Llama2 Q,K,V projection layers
Self-attention in Llama2-7B has in_dim = out_dim = 4096, so each projection layer has 4096 x 4096 ≈ 16.8M parameters. With rank = 8, LoRA reduces the number of trainable parameters of each projection to
4096 x 8 + 8 x 4096 ≈ 65K params!

In [3]:
# Get the Llama2
from  torchtune.models.llama2 import llama2_7b, lora_llama2_7b

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Build the base Llama2-7B model with torchtune's `llama2_7b` builder.
baseModel = llama2_7b()

In [6]:
# The default settings for lora_llama2_7b match those of llama2_7b;
# we only need to say which layers LoRA should be applied to.
# Within each self-attention block we can choose any of
# "q_proj", "k_proj", "v_proj" and "output_proj".
# We can also set apply_lora_to_mlp=True or apply_lora_to_output=True to apply
# LoRA to other linear layers outside of the self-attention.
loraModel = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

**Calling lora_llama2_7b alone does not define which parameters are trainable.**

In [7]:
# Show the attention module of the base model's first layer.
print(baseModel.layers[0].attn)

CausalSelfAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)


In [8]:
# Show the same attention module in the LoRA model, for comparison with the
# base model's layout.
print(loraModel.layers[0].attn)

CausalSelfAttention(
  (q_proj): LoRALinear(
    (dropout): Dropout(p=0.0, inplace=False)
    (lora_a): Linear(in_features=4096, out_features=8, bias=False)
    (lora_b): Linear(in_features=8, out_features=4096, bias=False)
  )
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): LoRALinear(
    (dropout): Dropout(p=0.0, inplace=False)
    (lora_a): Linear(in_features=4096, out_features=8, bias=False)
    (lora_b): Linear(in_features=8, out_features=4096, bias=False)
  )
  (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)


In [9]:
# Copy the weights of baseModel into the LoRA model.
# strict=False lets the load proceed even though the base state dict has no
# entries for the adapter weights (lora_a / lora_b); those are reported back
# as missing keys instead of raising an error.
loraModel.load_state_dict(baseModel.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=['layers.0.attn.q_proj.lora_a.weight', 'layers.0.attn.q_proj.lora_b.weight', 'layers.0.attn.v_proj.lora_a.weight', 'layers.0.attn.v_proj.lora_b.weight', 'layers.1.attn.q_proj.lora_a.weight', 'layers.1.attn.q_proj.lora_b.weight', 'layers.1.attn.v_proj.lora_a.weight', 'layers.1.attn.v_proj.lora_b.weight', 'layers.2.attn.q_proj.lora_a.weight', 'layers.2.attn.q_proj.lora_b.weight', 'layers.2.attn.v_proj.lora_a.weight', 'layers.2.attn.v_proj.lora_b.weight', 'layers.3.attn.q_proj.lora_a.weight', 'layers.3.attn.q_proj.lora_b.weight', 'layers.3.attn.v_proj.lora_a.weight', 'layers.3.attn.v_proj.lora_b.weight', 'layers.4.attn.q_proj.lora_a.weight', 'layers.4.attn.q_proj.lora_b.weight', 'layers.4.attn.v_proj.lora_a.weight', 'layers.4.attn.v_proj.lora_b.weight', 'layers.5.attn.q_proj.lora_a.weight', 'layers.5.attn.q_proj.lora_b.weight', 'layers.5.attn.v_proj.lora_a.weight', 'layers.5.attn.v_proj.lora_b.weight', 'layers.6.attn.q_proj.lora_a.weight', 'layers.6.attn.q_p

In [11]:
# Setting trainable parameters in the LoraModel
from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params

In [12]:
# Fetch the adapter parameters of the LoRA model and register them as the
# trainable set.
# NOTE(review): presumably every other parameter ends up frozen — confirm
# against the torchtune peft_utils documentation.
loraParams = get_adapter_params(loraModel)
set_trainable_params(loraModel, loraParams)

In [13]:
# Element count over every parameter, and over only those that require grad.
totalParams = sum(p.numel() for p in loraModel.parameters())
trainableParams = sum(
    p.numel() for p in loraModel.parameters() if p.requires_grad
)

In [14]:
# Report the total and trainable parameter counts and the trainable share.
# (A stray double quote after the trainable count used to leak into the
# printed text; it has been removed.)
print(
  f"""
  {totalParams} total params,
  {trainableParams} trainable params,
  {(100.0 * trainableParams / totalParams):.2f}% of all params are trainable.
  """
)


  6742609920 total params,
  4194304" trainable params,
  0.06% of all params are trainable.
  
