In [2]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

def create_factorized_compression_for_linear(source_linear, rank):
  """
  Adapt from: https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/cli_svd.py
  Create a factorized compression for a given linear layer using SVD.
  Args:
      source_linear (nn.Linear): The original linear layer to be compressed.
      rank (int, optional): The rank for the factorization. If None, it will be calculated based on rank_factor.
      rank_factor (float, optional): The factor to determine the rank if rank is not provided. Default is 0.3.
  Returns:
      nn.Sequential: A sequential container of the compressed linear layers.
  """
  with torch.no_grad():
      dtype = source_linear.weight.dtype
      # Check if the source linear layer has a bias term
      if hasattr(source_linear, 'bias'):
          bias = source_linear.bias
      else:
          bias = None
      # Calculate the total number of parameters in the source linear layer
      source_num_params = sum(param.numel() for param in source_linear.parameters())
      # Get the weight matrix of the source linear layer
      source_linear_weight = source_linear.weight.data
      # Ensure rank is less than the minimum dimension of the weight matrix
      assert rank < min(source_linear_weight.shape)
      # Perform SVD on the weight matrix
      U, S, Vh = torch.linalg.svd(source_linear_weight.float())
      # Truncate U, S, Vh to the specified rank
      U = U[:, :rank]  # Shape: [out_features, rank]
      S = S[:rank]     # Shape: [rank]
      Vh = Vh[:rank, :]  # Shape: [rank, in_features]
      # Incorporate singular values into U
      U = U @ torch.diag(S)  # Shape: [out_features, rank]
      # Flatten U and Vh for quantile computation
      U_flatten = U.flatten()
      Vh_flatten = Vh.flatten()
      # Define the maximum quantization size
      max_quant_size = 2**23
      # Compute high and low quantile values for clamping
      if len(U_flatten) + len(Vh_flatten) >= max_quant_size:
          dist2 = U_flatten[:min(len(U_flatten), max_quant_size)]
          dist3 = Vh_flatten[:min(len(Vh_flatten), max_quant_size)]
          hi_val = max(torch.quantile(dist3, 1), torch.quantile(dist2, 1))
      else:
          dist = torch.cat([U_flatten, Vh_flatten])
          hi_val = torch.quantile(dist, 1)
      low_val = -hi_val
      # Clamp U and Vh to the quantile values
      U = U.clamp(low_val, hi_val)
      Vh = Vh.clamp(low_val, hi_val)
      # Create the down projection linear layer (Vh)
      lora_down = nn.Linear(Vh.shape[1], Vh.shape[0], dtype=dtype, bias=False, device=source_linear_weight.device)
      lora_down.weight.data = Vh.to(device=source_linear_weight.device, dtype=dtype)
      # Create the up projection linear layer (U)
      lora_up = nn.Linear(U.shape[1], U.shape[0], dtype=dtype, bias=bias is not None, device=source_linear_weight.device)
      lora_up.weight.data = U.to(device=source_linear_weight.device, dtype=dtype)
      # If the original linear layer had a bias, copy it to the up projection layer
      if bias is not None:
          lora_up.bias = nn.Parameter(bias.clone())
      # Print compression ratio (for debugging purposes)
      #print('compression', sum(param.numel() for param in ret.parameters()) / source_num_params)
      return lora_down, lora_up
    
class AdaVocabHead(nn.Module):
  def __init__(self, lm_head, sub_vocab_dim, dora=False, svd=False, activation_func=None):
    self.dora = dora
    hidden_size, vocab_size = lm_head.in_features, lm_head.out_features
    super().__init__()
    if svd: # SVD initialization
      self.A, self.B = create_factorized_compression_for_linear(lm_head, sub_vocab_dim)
      if dora: 
        self.m = nn.Parameter(lm_head.weight.T.norm(p=2, dim=1, keepdim=True))  # (hidden_size, 1)
    else:  # Random initialization
      self.A = nn.Linear(hidden_size, sub_vocab_dim, bias=False)
      self.B = nn.Linear(sub_vocab_dim, vocab_size, bias=False)
      std_dev = 1 / math.sqrt(sub_vocab_dim)
      nn.init.normal_(self.A.weight, 0, std_dev)
      nn.init.zeros_(self.B.weight)
    self.activation_func = activation_func
    
  def forward(self, x):
    # x.shape: (..., hidden_size), A.shape: (hidden_size, sub_vocab_dim), B.shape: (sub_vocab_dim, vocab_size)
    if self.dora:
      comb_weight = self.A.weight.T @ self.B.weight.T  # (hidden_size, vocab_size)
      norm_vec = comb_weight.norm(p=2, dim=1, keepdim=True)  # (hidden_size, 1)
      directional_component = comb_weight / norm_vec  # (hidden_size, vocab_size)
      dora_weight = self.m * directional_component  # (hidden_size, vocab_size)
      ada_vocab_logits = x @ dora_weight  # ada_vocab_logits.shape: (..., vocab_size)
    else:
      logits = self.A(x)
      if self.activation_func is not None:
          logits = self.activation_func(logits)
      ada_vocab_logits = self.B(logits)  # ada_vocab_logits.shape: (..., vocab_size)  
    return ada_vocab_logits


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained("/home/z/zheng22/AdaVocab/experiment_ckpts/gemma-2b_SFT-2024-06-10-123619/checkpoint-11592", torch_dtype=torch.bfloat16)

Loading checkpoint shards: 100%|██████████| 2/2 [00:25<00:00, 12.82s/it]


In [31]:
x = torch.randn(2, 2048, dtype=torch.bfloat16).cuda()

In [2]:
new_adahead = AdaVocabHead(model.lm_head, 20, dora=True, svd=True).cuda()
new_adahead(x).shape

Downloading shards: 100%|██████████| 2/2 [00:17<00:00,  8.99s/it]
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.43s/it]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (20x256000 and 2048x20)

In [5]:
new_adahead.A, new_adahead.B, 

(Linear(in_features=20, out_features=256000, bias=False),
 Linear(in_features=2048, out_features=20, bias=False))

In [22]:
new_adahead.cuda()

AdaVocabHead(
  (A): Linear(in_features=20, out_features=256000, bias=False)
  (B): Linear(in_features=2048, out_features=20, bias=False)
)

In [7]:
comb_weight = new_adahead.B.weight.T @ new_adahead.A.weight.T

In [23]:
norm_vec = comb_weight.norm(p=2, dim=1, keepdim=True)  # (hidden_size, 1)
directional_component = comb_weight / norm_vec  # (hidden_size, vocab_size)
dora_weight = new_adahead.m * directional_component  # (hidden_size, vocab_size)

In [16]:
new_adahead.m = nn.Parameter(model.lm_head.weight.T.norm(p=2, dim=1, keepdim=True))

In [17]:
new_adahead.m.shape

torch.Size([2048, 1])

In [1]:
from transformers import AutoModelForCausalLM
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
a = torch.randn(2048, 256000, dtype=torch.float32)
b = torch.randn(2048, dtype=torch.float32)
c = torch.randn(2048, 2048, dtype=torch.float32)
torch.save((a, b, c), "test.pt")

In [2]:
model_path = "/home/z/zheng22/AdaVocab/experiment_ckpts/gemma-2b_SFT-2024-06-10-123619/checkpoint-11592"
model = AutoModelForCausalLM.from_pretrained(model_path)

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.34s/it]


In [3]:
source_linear_weight = model.lm_head.weight.data

In [4]:
U, S, Vh = torch.linalg.svd(source_linear_weight.float())


In [5]:
print(U.shape, S.shape, Vh.shape)
print(type(U), type(S), type(Vh))

torch.Size([256000, 256000]) torch.Size([2048]) torch.Size([2048, 2048])
<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>


In [6]:
max_rank = 2048
U = U[:, :max_rank].clone()  # Shape: [out_features, rank]
S = S[:max_rank].clone()     # Shape: [rank]
Vh = Vh[:max_rank, :].clone()  # Shape: [rank, in_features]

In [7]:
print(U.shape, S.shape, Vh.shape)

torch.Size([256000, 2048]) torch.Size([2048]) torch.Size([2048, 2048])


In [8]:
torch.save((U, S, Vh), "svd.pth")