In [None]:
pip install torchtune

In [None]:
# Run the `tune` shell command with no arguments.
# NOTE(review): presumably this is the CLI installed by torchtune and is used
# here only to confirm the install worked — confirm.
!tune

In [None]:
from torch import nn,Tensor

class LORALinear(nn.Module):
  """Minimal LoRA wrapper around a bias-free linear layer.

  The layer computes

      linear(x) + (alpha / rank) * lora_b(lora_a(dropout(x)))

  where ``linear`` holds the frozen weights and ``lora_a`` / ``lora_b`` are
  the trainable low-rank factors.

  Args:
    in_dim: size of the last dimension of the input.
    out_dim: size of the last dimension of the output.
    rank: inner dimension of the low-rank update (in general
      rank << in_dim, out_dim).
    alpha: scaling numerator; the LoRA branch is multiplied by alpha / rank.
    dropout: dropout probability applied to the input of the LoRA branch only.
  """

  def __init__(
      self,
      in_dim: int,
      out_dim: int,
      rank: int,
      alpha: float,
      dropout: float,
  ):
    # nn.Module must be initialised before any submodule is assigned;
    # without this call the first `self.linear = ...` raises AttributeError.
    super().__init__()

    # weights from the original pretrained model
    self.linear = nn.Linear(in_dim, out_dim, bias=False)

    # new LoRA parameters: in_dim -> rank -> out_dim
    self.lora_a = nn.Linear(in_dim, rank, bias=False)
    self.lora_b = nn.Linear(rank, out_dim, bias=False)

    # rank and alpha are commonly tuned hyperparameters
    self.rank = rank
    self.alpha = alpha

    # dropout applied to the LoRA branch input
    self.dropout = nn.Dropout(p=dropout)

    # original parameters are frozen and only the LoRA parameters are trained
    self.linear.weight.requires_grad = False
    self.lora_a.weight.requires_grad = True
    self.lora_b.weight.requires_grad = True

  def forward(self, x: Tensor) -> Tensor:
    """Return the frozen output plus the scaled low-rank update for ``x``."""
    # output of the original (frozen) layer
    frozen_out = self.linear(x)

    # lora_a projects the input down to self.rank features,
    # then lora_b projects back up to the output dimension
    lora_out = self.lora_b(self.lora_a(self.dropout(x)))

    # scale by alpha (normalized by rank) and add to the frozen output
    return frozen_out + (self.alpha / self.rank) * lora_out




# Applying LoRA to the Phi-3 model

In [None]:
import torch

In [None]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Architecture hyperparameters passed unchanged to both the phi3 and the
# lora_phi3 builder calls further down.
vocab_size = 30_522
num_layers = 6
num_heads = 8
num_kv_heads = 1
embed_dim = 512
intermediate_dim = 2_048
max_seq_len = 512

In [None]:
from torchtune.models.phi3 import phi3, lora_phi3
# Phi-3 model without any LoRA layers
base_model = phi3(
    vocab_size=vocab_size,
    num_layers=num_layers,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    embed_dim=embed_dim,
    intermediate_dim=intermediate_dim,
    max_seq_len=max_seq_len,
)
# move the freshly built model onto the selected device
base_model = base_model.to(device)

# lora_phi3 takes the same architecture arguments as phi3, so we pass it the same values.
# We just need to define which layers we want LoRA applied to.
# Within each self-attention block, we can choose from ["q_proj", "k_proj", "v_proj", "output_proj"].
# We can also set apply_lora_to_mlp=True or apply_lora_to_output=True to apply LoRA to other linear
# layers outside of the self-attention.



In [None]:
# Same architecture as the base model, with LoRA adapters (rank 8, alpha 16)
# attached to the q_proj and v_proj attention projections.
# The model is moved to `device` so that it lives on the same device as
# base_model instead of being left on the CPU.
lora_model = lora_phi3(
    vocab_size=vocab_size,
    num_layers=num_layers,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    embed_dim=embed_dim,
    intermediate_dim=intermediate_dim,
    max_seq_len=max_seq_len,
    lora_attn_modules=['q_proj', 'v_proj'],
    lora_rank=8,
    lora_alpha=16,
).to(device)

In [None]:
# Show the first decoder layer of the base model, as a reference to compare
# against the LoRA model's first layer.
print(base_model.layers[0])

TransformerDecoderLayer(
  (sa_norm): RMSNorm()
  (attn): CausalSelfAttention(
    (q_proj): Linear(in_features=512, out_features=512, bias=False)
    (k_proj): Linear(in_features=512, out_features=64, bias=False)
    (v_proj): Linear(in_features=512, out_features=64, bias=False)
    (output_proj): Linear(in_features=512, out_features=512, bias=False)
    (pos_embeddings): Phi3RotaryPositionalEmbeddings()
  )
  (mlp_norm): RMSNorm()
  (mlp): FeedForward(
    (w1): Linear(in_features=512, out_features=2048, bias=False)
    (w2): Linear(in_features=2048, out_features=512, bias=False)
    (w3): Linear(in_features=512, out_features=2048, bias=False)
    (activation): SiLU()
  )
)


In [None]:
# Show the first decoder layer of the LoRA model; the modules listed in
# lora_attn_modules (q_proj and v_proj) now carry lora_a / lora_b sub-layers.
print(lora_model.layers[0])

TransformerDecoderLayer(
  (sa_norm): RMSNorm()
  (attn): CausalSelfAttention(
    (q_proj): LoRALinear(
      (dropout): Dropout(p=0.0, inplace=False)
      (lora_a): Linear(in_features=512, out_features=8, bias=False)
      (lora_b): Linear(in_features=8, out_features=512, bias=False)
    )
    (k_proj): Linear(in_features=512, out_features=64, bias=False)
    (v_proj): LoRALinear(
      (dropout): Dropout(p=0.0, inplace=False)
      (lora_a): Linear(in_features=512, out_features=8, bias=False)
      (lora_b): Linear(in_features=8, out_features=64, bias=False)
    )
    (output_proj): Linear(in_features=512, out_features=512, bias=False)
    (pos_embeddings): Phi3RotaryPositionalEmbeddings()
  )
  (mlp_norm): RMSNorm()
  (mlp): FeedForward(
    (w1): Linear(in_features=512, out_features=2048, bias=False)
    (w2): Linear(in_features=2048, out_features=512, bias=False)
    (w3): Linear(in_features=512, out_features=2048, bias=False)
    (activation): SiLU()
  )
)


In [None]:
# Assuming the base model holds pretrained weights, this loads them directly
# into the LoRA model without any key conversion.
# strict=False is needed because the LoRA adapter weights have no counterpart
# in the base model's state dict.
load_result = lora_model.load_state_dict(base_model.state_dict(), strict=False)

# strict=False would also hide a genuine mismatch, so check explicitly that
# nothing unexpected was supplied and that only LoRA weights were left unset.
assert not load_result.unexpected_keys, (
    f"unexpected keys in base state dict: {load_result.unexpected_keys}"
)
assert all("lora" in key for key in load_result.missing_keys), (
    f"non-LoRA weights were not loaded: {load_result.missing_keys}"
)

# keep the result as the last expression so the cell still displays it
load_result

_IncompatibleKeys(missing_keys=['layers.0.attn.q_proj.lora_a.weight', 'layers.0.attn.q_proj.lora_b.weight', 'layers.0.attn.v_proj.lora_a.weight', 'layers.0.attn.v_proj.lora_b.weight', 'layers.1.attn.q_proj.lora_a.weight', 'layers.1.attn.q_proj.lora_b.weight', 'layers.1.attn.v_proj.lora_a.weight', 'layers.1.attn.v_proj.lora_b.weight', 'layers.2.attn.q_proj.lora_a.weight', 'layers.2.attn.q_proj.lora_b.weight', 'layers.2.attn.v_proj.lora_a.weight', 'layers.2.attn.v_proj.lora_b.weight', 'layers.3.attn.q_proj.lora_a.weight', 'layers.3.attn.q_proj.lora_b.weight', 'layers.3.attn.v_proj.lora_a.weight', 'layers.3.attn.v_proj.lora_b.weight', 'layers.4.attn.q_proj.lora_a.weight', 'layers.4.attn.q_proj.lora_b.weight', 'layers.4.attn.v_proj.lora_a.weight', 'layers.4.attn.v_proj.lora_b.weight', 'layers.5.attn.q_proj.lora_a.weight', 'layers.5.attn.q_proj.lora_b.weight', 'layers.5.attn.v_proj.lora_a.weight', 'layers.5.attn.v_proj.lora_b.weight'], unexpected_keys=[])

In [None]:
#once weights are loaded set lora parameters to trainable
from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params

# Collect every parameter of the model that belongs to a LoRA adapter.
lora_params = get_adapter_params(lora_model)

# Enable gradients for the adapter parameters and disable them for all others.
set_trainable_params(lora_model, lora_params)

# Count all parameters and the subset that will receive gradients.
total_params = sum(p.numel() for p in lora_model.parameters())
trainable_params = sum(
    p.numel() for p in lora_model.parameters() if p.requires_grad
)
trainable_share = 100.0 * trainable_params / total_params

print(f"""{total_params} total parameters,
 {trainable_params} trainable parameters,
  {trainable_share:.2f}% of all params are trainable""")

53751296 total parameters,
 76800 trainable parameters,
  0.14% of all params are trainable
