In [12]:
import prunning
import os 
import utils.misc as misc
import utils.visulaiser as visulaiser
import download_datasets_models as dataset
import evaluate_llm as eval
from torch import nn
from tqdm import tqdm
import numpy as np
import torch
import copy
import matplotlib.pyplot as plt

from transformers import AutoModelForCausalLM, AutoTokenizer


In [3]:
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
dataloader = dataset.gsm8k

In [4]:
model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((

In [None]:
def get_layer_norm(model):
    norms = []
    nLayers = len(model.model.layers)
    for i in range(nLayers):
        layer_norm = 0
        for name, param in model.model.layers[i].named_parameters():
            if ("weight" in name) and (("self_attn" in name) or ("mlp" in name)):
                layer_norm += torch.linalg.matrix_norm(param, ord='fro')

        norms.append((i, layer_norm.item()))

    return sorted(norms, key=lambda x: x[1])

In [8]:
def get_pruned_model(model, num_layers_to_prune=5):
    model_pruned =copy.deepcopy(model)
    layer_norms = get_layer_norm(model)

    # Get indices of the least important layers
    layers_to_prune = [idx for idx, _ in layer_norms[:num_layers_to_prune]]
    print(f"Layers to be pruned (smallest L2 norms): {layers_to_prune}")


    # Rebuild the model without these layers
    new_layers = torch.nn.ModuleList(
        [layer for i, layer in enumerate(model.model.layers) if i not in layers_to_prune]
    )

    model_pruned.model.layers = new_layers

    return model_pruned

In [9]:
model_pruned = get_pruned_model(model, 10)

Layers to be pruned (smallest L2 norms): [11, 9, 13, 12, 10, 3, 14, 5, 15, 6]


In [10]:
model_pruned

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-13): 14 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((

In [None]:
def count_weight_layers(model):
    weight_layers = 0
    for module_name, module in model.named_modules():
        # Check if the module has learnable parameters (weights)
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
            # Linear layers, Embedding layers, etc., have weights
            weight_layers += 1
    
    print(f"Total number of weight layers: {weight_layers}")
    return weight_layers

calculate_layers(model_pruned)

Total number of layers: 14
Layer 0: Qwen2DecoderLayer(
  (self_attn): Qwen2SdpaAttention(
    (q_proj): Linear(in_features=896, out_features=896, bias=True)
    (k_proj): Linear(in_features=896, out_features=128, bias=True)
    (v_proj): Linear(in_features=896, out_features=128, bias=True)
    (o_proj): Linear(in_features=896, out_features=896, bias=False)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (mlp): Qwen2MLP(
    (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
    (up_proj): Linear(in_features=896, out_features=4864, bias=False)
    (down_proj): Linear(in_features=4864, out_features=896, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
  (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
)
Layer 1: Qwen2DecoderLayer(
  (self_attn): Qwen2SdpaAttention(
    (q_proj): Linear(in_features=896, out_features=896, bias=True)
    (k_proj): Linear(in_features=896, out_features=128, bias=True)
    (v_proj): Linear(in_

14

In [None]:
def plot_weight_distribution(model, bins=256, count_nonzero_only=False):
    fig, axes = plt.subplots(10, 5, figsize=(20, 12))  
    axes = axes.ravel()  
    plot_index = 0

    for name, param in model.named_parameters():
        if param.dim() > 1:  
            ax = axes[plot_index]

            if count_nonzero_only:
                param_cpu = param.detach().view(-1).cpu()
                param_cpu = param_cpu[param_cpu != 0].view(-1) 
                ax.hist(param_cpu, bins=bins, density=True, alpha=0.5)
            else:
                ax.hist(param.detach().view(-1).cpu(), bins=bins, density=True, alpha=0.5)
            
            ax.set_xlabel(name, fontsize=8)
            ax.set_ylabel('Density', fontsize=8)
            plot_index += 1
    fig.suptitle('Histogram of Weights', fontsize=16)
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)  
    plt.show()