<a href="https://colab.research.google.com/github/abhinavsb3/Pruning_for_Model_Distillation/blob/main/Layer_Width_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Layer Pruning with naive Width Pruning**

In [1]:
!python -m pip install --upgrade pip -q
!pip install transformers -qU

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [7]:
from transformers import AutoModelForCausalLM
import torch
import math

def total_parameters(model):
  """Return the number of trainable scalar weights held by ``model``.

  Only parameters whose ``requires_grad`` flag is set contribute to the
  total; frozen parameters are skipped.
  """
  trainable_count = 0
  for param in model.parameters():
    if param.requires_grad:
      trainable_count += param.numel()
  return trainable_count

def prune_layers(model, target_params, original_params):
  """Drop decoder layers so the depth shrinks in proportion to a parameter budget.

  The number of layers kept is
  ``round(target_params / original_params * total_layers)``, clamped to
  ``[1, total_layers]``. The kept set is the leading layers plus the final
  layer; the layers removed are the ones immediately before the final layer.

  Note that the ratio is applied to the layer count directly, so any
  parameters living outside ``model.model.layers`` are not accounted for and
  the resulting parameter count will generally not equal ``target_params``.

  Args:
    model: Object exposing ``model.model.layers`` (a sequence of modules) and
      ``model.config.num_hidden_layers``.
    target_params: Desired parameter count after pruning.
    original_params: Parameter count of ``model`` before pruning.

  Returns:
    The same ``model`` object, modified in place.
  """
  layers = model.model.layers
  total_layers = len(layers)

  # Calculate number of layers to keep. Without the clamp a rounded value of 0
  # would slice ``layers[:-1]`` and silently keep every layer, and a value
  # above ``total_layers`` would append the final layer a second time.
  layers_to_keep = round((target_params / original_params) * total_layers)
  layers_to_keep = max(1, min(layers_to_keep, total_layers))

  # Keep all layers except those right before the last layer.
  final_layers = [layers[i] for i in range(layers_to_keep - 1)]
  final_layers.append(layers[-1])

  # The relocated final layer would otherwise keep its old position index.
  # NOTE(review): ``self_attn.layer_idx`` is the attribute Hugging Face Llama
  # attention uses to address its KV-cache slot - confirm for other
  # architectures. Modules without the attribute are left untouched.
  for new_index, layer in enumerate(final_layers):
    attention = getattr(layer, "self_attn", None)
    if attention is not None and hasattr(attention, "layer_idx"):
      attention.layer_idx = new_index

  model.model.layers = torch.nn.ModuleList(final_layers)
  model.config.num_hidden_layers = len(final_layers)

  return model


def _slice_linear(linear, out_features, in_features):
  """Keep the leading ``out_features`` rows / ``in_features`` columns of a Linear, in place."""
  old_weight = linear.weight
  # torch.nn.Linear stores its weight as (out_features, in_features).
  linear.weight = torch.nn.Parameter(
      old_weight.detach()[:out_features, :in_features].contiguous(),
      requires_grad=old_weight.requires_grad,
  )
  old_bias = getattr(linear, "bias", None)
  if old_bias is not None:
    linear.bias = torch.nn.Parameter(
        old_bias.detach()[:out_features].contiguous(),
        requires_grad=old_bias.requires_grad,
    )
  # Keep the bookkeeping attributes in step with the actual tensor shape.
  linear.out_features, linear.in_features = linear.weight.shape


def _keep_leading_features(module, size):
  """Truncate the last dimension of ``module.weight`` to ``size`` entries, in place."""
  old_weight = module.weight
  module.weight = torch.nn.Parameter(
      old_weight.detach()[..., :size].contiguous(),
      requires_grad=old_weight.requires_grad,
  )


def prune_hidden_dimensions(model, target_params, current_params):
  """Naively shrink the hidden and MLP widths of a Llama-style model in place.

  The new hidden size is ``hidden_size * sqrt(target_params / current_params)``
  (never larger than the current size), rounded down to a multiple of
  ``2 * num_attention_heads`` so every head keeps an even dimension. The
  intermediate size is scaled to preserve the original
  intermediate/hidden ratio. Weights are truncated to their leading
  rows/columns; no importance ranking is applied, so the pruned model is only
  a starting point for further training.

  Args:
    model: Object exposing ``model.config`` (``hidden_size``,
      ``intermediate_size``, ``num_attention_heads``, ``num_key_value_heads``)
      and ``model.model.layers`` whose entries have ``self_attn.{q,k,v,o}_proj``
      and ``mlp.{gate,up,down}_proj`` linear modules. ``model.model.embed_tokens``,
      ``model.model.norm``, ``model.model.rotary_emb``, ``model.lm_head`` and the
      per-layer ``input_layernorm`` / ``post_attention_layernorm`` are resized
      when present.
    target_params: Desired parameter count after pruning.
    current_params: Parameter count of ``model`` before this step.

  Returns:
    The same ``model`` object, modified in place.

  Raises:
    ValueError: If the budget is so small that the rounded hidden size is 0.
  """
  config = model.config
  original_hidden_size = config.hidden_size
  original_intermediate_size = config.intermediate_size
  num_attention_heads = config.num_attention_heads
  num_key_value_heads = config.num_key_value_heads

  # Estimate the new hidden size for the target parameter count. The square
  # root is a heuristic: it treats the parameter count as quadratic in width.
  # Capping at 1.0 prevents the config from claiming a width larger than the
  # weights that actually exist.
  reduction_ratio = min(1.0, math.sqrt(target_params / current_params))
  head_multiple = 2 * num_attention_heads  # ensure divisibility
  new_hidden_size = int(original_hidden_size * reduction_ratio) // head_multiple * head_multiple
  if new_hidden_size == 0:
    raise ValueError(
        f"target_params={target_params} is too small: hidden size would round down to 0"
    )

  new_head_dim = new_hidden_size // num_attention_heads
  new_kv_dim = new_head_dim * num_key_value_heads
  # Integer arithmetic keeps the original ratio without float truncation.
  new_intermediate_size = new_hidden_size * original_intermediate_size // original_hidden_size

  # Update hidden size and intermediate size in the config.
  config.hidden_size = new_hidden_size
  config.intermediate_size = new_intermediate_size
  if getattr(config, "head_dim", None) is not None:
    config.head_dim = new_head_dim

  # Token embeddings (and the output head) live on the residual stream, so
  # their feature dimension must shrink together with the layers.
  embed = getattr(model.model, "embed_tokens", None)
  lm_head = getattr(model, "lm_head", None)
  tied = embed is not None and lm_head is not None and lm_head.weight is embed.weight
  if embed is not None:
    _keep_leading_features(embed, new_hidden_size)
    if hasattr(embed, "embedding_dim"):
      embed.embedding_dim = new_hidden_size
  if lm_head is not None:
    if tied:
      # Preserve weight tying so the shared matrix is still a single parameter.
      lm_head.weight = embed.weight
      lm_head.in_features = new_hidden_size
    else:
      _slice_linear(lm_head, lm_head.weight.shape[0], new_hidden_size)

  for layer in model.model.layers:
    # Adjust attention projection layers. K/V project hidden -> kv_dim, so
    # rows are limited by ``new_kv_dim`` and columns by ``new_hidden_size``.
    attn = layer.self_attn
    _slice_linear(attn.q_proj, new_hidden_size, new_hidden_size)
    _slice_linear(attn.k_proj, new_kv_dim, new_hidden_size)
    _slice_linear(attn.v_proj, new_kv_dim, new_hidden_size)
    _slice_linear(attn.o_proj, new_hidden_size, new_hidden_size)

    # Some transformers versions cache these values on the attention module;
    # only attributes that already exist are refreshed.
    # NOTE(review): confirm the names against the installed transformers version.
    for attr_name, attr_value in (
        ("head_dim", new_head_dim),
        ("hidden_size", new_hidden_size),
        ("scaling", new_head_dim ** -0.5),
    ):
      if hasattr(attn, attr_name):
        setattr(attn, attr_name, attr_value)

    # Adjust MLP layers.
    _slice_linear(layer.mlp.gate_proj, new_intermediate_size, new_hidden_size)
    _slice_linear(layer.mlp.up_proj, new_intermediate_size, new_hidden_size)
    _slice_linear(layer.mlp.down_proj, new_hidden_size, new_intermediate_size)

    # Per-layer norms scale the residual stream and must match its width.
    for norm_name in ("input_layernorm", "post_attention_layernorm"):
      norm = getattr(layer, norm_name, None)
      if norm is not None:
        _keep_leading_features(norm, new_hidden_size)

  final_norm = getattr(model.model, "norm", None)
  if final_norm is not None:
    _keep_leading_features(final_norm, new_hidden_size)

  # Adjust rotary positional embeddings once (not per layer): the frequency
  # table holds one entry per pair of head dimensions.
  # NOTE(review): if ``inv_freq`` is a non-persistent buffer, a model reloaded
  # from disk recomputes it from the config rather than using this truncated
  # table - confirm which behaviour is wanted.
  rotary = getattr(model.model, "rotary_emb", None)
  if rotary is not None:
    for buffer_name in ("inv_freq", "original_inv_freq"):
      freqs = getattr(rotary, buffer_name, None)
      if freqs is not None:
        setattr(rotary, buffer_name, freqs[: new_head_dim // 2].contiguous())

  return model


def create_pruned_lm(model_name, target_params_1, target_params_2):
  """Load a causal LM, prune its depth and then its width, reporting sizes along the way.

  Args:
    model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
    target_params_1: Parameter budget handed to ``prune_layers``.
    target_params_2: Parameter budget handed to ``prune_hidden_dimensions``.

  Returns:
    The pruned model.
  """
  divider = "____________________________________"

  # Stage 1: fetch the checkpoint and show its architecture and size.
  model = AutoModelForCausalLM.from_pretrained(model_name)
  print(divider)
  print("Original Model:")
  print(model)

  original_params = total_parameters(model)
  print(divider)
  print(f"Original model Parameters: {original_params}")

  # Stage 2: depth pruning, driven by the first budget.
  model = prune_layers(model, target_params_1, original_params)
  params_after_depth = total_parameters(model)
  print(divider)
  print(f"\nModel parameters after layer pruning: {params_after_depth:,}")

  # Stage 3: width pruning, driven by the second budget and the size
  # measured after stage 2.
  model = prune_hidden_dimensions(model, target_params_2, params_after_depth)
  final_params = total_parameters(model)
  print(divider)
  print(f"\nModel parameters after hidden dimension pruning: {final_params:,}")

  # Overall shrinkage relative to the unpruned checkpoint.
  kept_fraction = final_params / original_params
  reduction_percentage = (1 - kept_fraction) * 100
  print(divider)
  print(f"\nTotal size reduction: {reduction_percentage:.2f} %")

  print(divider)
  print(model)

  return model


model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
target_params_1 = 110_000_000  # budget passed to the layer-pruning stage
target_params_2 = 90_000_000   # budget passed to the width-pruning stage

# Create the pruned model
pruned_lm = create_pruned_lm(model_name, target_params_1, target_params_2)

# Name the output directory after the last path component of the model id and
# the width-pruning budget expressed in millions ("...-90M-raw"); formatting
# the raw count followed by "M" would read as "90000000M".
target_millions = int(target_params_2) // 1_000_000
modified_model_path = (
    f"{model_name.split('/')[-1]}-layer-width-pruned-{target_millions}M-raw"
)
pruned_lm.save_pretrained(modified_model_path)
print("____________________________________")
print(f"\nPruned model saved to: {modified_model_path}")
print("____________________________________")
# Check the first layer of your pruned model
layer_to_inspect = pruned_lm.model.layers[0]

print("--- Actual pruned Shapes ---")

# Attention block weights
print("Q-Proj shape:", layer_to_inspect.self_attn.q_proj.weight.shape)
print("K-Proj shape:", layer_to_inspect.self_attn.k_proj.weight.shape)
print("V-Proj shape:", layer_to_inspect.self_attn.v_proj.weight.shape)

print("-" * 20)

# MLP block weights
print("MLP Gate-Proj shape:", layer_to_inspect.mlp.gate_proj.weight.shape)
print("MLP Down-Proj shape:", layer_to_inspect.mlp.down_proj.weight.shape)
print("____________________________________")
print(pruned_lm.config)



____________________________________
Original Model:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
   