**Pruning Sarvam Model** Kaggle Environment GPU T4

Install Libraries

In [None]:
# Install required libraries
!pip install transformers accelerate datasets lm-eval sacrebleu evaluate torch torchvision torchaudio


In [None]:
!pip install -q lm-eval

In [None]:
!pip install protobuf==3.20.3

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import nn
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
from lm_eval import evaluator, tasks, models
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import tempfile

In [None]:

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


Download Model and Explore Structure

In [None]:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Fetch the Sarvam-1 weights and the tokenizer that belongs to the same
# checkpoint, so both are guaranteed to come from one repository id.
checkpoint = "sarvamai/sarvam-1"
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

An MLP block typically consists of layers that scale the data to larger dimensions and others that return it to its original size.

In the MLP block of the model, we find two projection layers: gate_proj and up_proj, both scaling from 2048 to 8192. The purpose of having two layers projecting to the same intermediate size is related to gating mechanisms. A gating mechanism selectively controls information flow in neural networks by using learned weights to "gate" or filter inputs.

In [None]:
# Study the model architecture: print the module hierarchy before pruning.
print(model)

In [None]:
def get_output(prompt, model=model, tokenizer=tokenizer):
    """
    Generate a deterministic (beam-search) continuation for ``prompt``.

    Args:
    - prompt: Text to continue.
    - model: Causal LM used for generation. Defaults to the notebook's model
      as it was bound when this function was defined.
    - tokenizer: Tokenizer matching ``model``.

    Returns:
    - generated: Decoded text of the best beam, special tokens removed.
    """
    # Send the inputs to wherever the model's weights actually live. Using a
    # notebook-global device here breaks as soon as the model sits on a
    # different device than that global (e.g. model on CPU, global = cuda).
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

    # Some tokenizers define no pad token; fall back to EOS so generate()
    # receives an explicit, valid padding id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=50,            # Total length: prompt tokens + generated tokens
        num_return_sequences=1,
        pad_token_id=pad_token_id,
        temperature=None,         # Sampling knobs cleared because sampling is off
        top_p=None,
        do_sample=False,          # Disable sampling
        num_beams=5,              # Use beam search
        early_stopping=True,      # Stop when end-of-sequence token is generated
        no_repeat_ngram_size=2    # Prevent repetition of 2-grams
    )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated

In [None]:
def count_parameters(model):
    """Return the total number of scalar elements over all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


In [None]:
# Record the parameter count of the unpruned model; it is the baseline the
# pruned model is compared against later in the notebook.
original_param_count = count_parameters(model)
print(f"Original model parameters: {original_param_count}")

The Maximum Absolute Weight method works because it directly identifies the most influential neurons based on the magnitude of their connections. These neurons are likely responsible for key decisions, making the model more accurate after pruning. The Variance of Weights method, while useful in some contexts, can retain neurons that may not contribute significantly to the task, leading to less coherent model outputs.

In [None]:
# Peak-magnitude importance ("Maximum Absolute Weight"):
# neurons whose weights reach large positive or negative values are treated
# as the significant ones.

def compute_neuron_pair_importance(gate_weight, up_weight):
    """
    Score each gate/up neuron pair from the extremes of its weight rows.

    For every row the score is the largest weight plus the absolute value of
    the smallest weight. The per-row scores of the two matrices are then
    added, giving one score per neuron pair.

    Args:
    - gate_weight: Weight matrix from the gate_proj layer (one row per neuron).
    - up_weight: Weight matrix from the up_proj layer (one row per neuron).

    Returns:
    - importance_scores: 1-D tensor with one importance score per neuron pair.
    """

    def _row_extremes(weight):
        # Largest entry of each row plus the magnitude of its smallest entry.
        return weight.max(dim=1).values + weight.min(dim=1).values.abs()

    return _row_extremes(gate_weight) + _row_extremes(up_weight)


In [None]:
def _build_pruned_linear(source, weight, bias_indices=None):
    """Build an nn.Linear holding ``weight``, carrying over ``source``'s bias if it has one.

    ``bias_indices`` selects the bias entries to keep; ``None`` keeps them all
    (used for down_proj, whose output dimension is not pruned).
    """
    has_bias = source.bias is not None
    out_features, in_features = weight.shape
    # Allocate directly on the device/dtype of the source weights instead of a
    # notebook-global device, so the layer matches the rest of the model.
    layer = nn.Linear(
        in_features,
        out_features,
        bias=has_bias,
        device=weight.device,
        dtype=weight.dtype,
    )
    layer.weight.data = weight
    if has_bias:
        bias = source.bias.data
        layer.bias.data = bias.clone() if bias_indices is None else bias.index_select(0, bias_indices)
    return layer


#Prunes a specific percentage of neurons from the MLP (feed forward layers).
def prune_neuron_pairs(mlp, prune_percent):
    """
    Reduces the dimensions of the **gate_proj**, **up_proj**, **down_proj**
    layers removing the least important neurons.

    Args:
    - mlp: Module exposing gate_proj, up_proj and down_proj linear layers.
    - prune_percent: Fraction of neuron pairs to prune (e.g. 0.2 removes 20%).
      Values that would remove every neuron are clamped so one is always kept.

    Returns:
    - new_gate_proj, new_up_proj, new_down_proj:  New pruned layers.
    - k: New intermediate size.

    Raises:
    - ValueError: If prune_percent is negative or the MLP has no neurons.
    """
    if prune_percent < 0:
        raise ValueError(f"prune_percent must be non-negative, got {prune_percent}.")

    # Work on float32 copies for scoring only; the original weights (and their
    # dtype) are what gets copied into the pruned layers below.
    gate_weight = mlp.gate_proj.weight.data.float()
    up_weight = mlp.up_proj.weight.data.float()

    #Store the original number of neurons in the intermediate layer.
    original_intermediate_size = gate_weight.size(0)
    if original_intermediate_size == 0:
        raise ValueError("Cannot prune an MLP whose intermediate size is 0.")

    #Compute importance scores. Neurons with higher importance scores
    # are considered more important and less likely to be pruned.
    importance_scores = compute_neuron_pair_importance(gate_weight, up_weight)

    #Number of neurons to prune, clamped so at least one neuron survives.
    num_neuron_pairs_to_prune = min(
        int(prune_percent * original_intermediate_size),
        original_intermediate_size - 1,
    )
    #Number of neurons to keep: the new intermediate size (always >= 1 here).
    k = original_intermediate_size - num_neuron_pairs_to_prune

    #Select the neurons to keep; sorting the indices preserves the original
    # neuron order inside the pruned layers.
    indices_to_keep = torch.topk(importance_scores, k).indices.sort().values

    #Create the new layers from the kept rows (gate/up) and columns (down).
    # index_select returns fresh tensors that do not share storage with the
    # original weights.
    new_gate_proj = _build_pruned_linear(
        mlp.gate_proj,
        mlp.gate_proj.weight.data.index_select(0, indices_to_keep),
        bias_indices=indices_to_keep,
    )
    new_up_proj = _build_pruned_linear(
        mlp.up_proj,
        mlp.up_proj.weight.data.index_select(0, indices_to_keep),
        bias_indices=indices_to_keep,
    )
    new_down_proj = _build_pruned_linear(
        mlp.down_proj,
        mlp.down_proj.weight.data.index_select(1, indices_to_keep),
    )

    #return new layers and intermediate size.
    return new_gate_proj, new_up_proj, new_down_proj, k



The neurons are removed in the prune_neuron_pairs function based on the importance scores returned by compute_neuron_pair_importance.

In [None]:
#Iterates through the model layers and applies pruning.
def update_model(model, prune_percent):
    """
    Prune the MLP of every decoder layer in ``model`` in place, keeping only
    the most important neurons and replacing each projection with a smaller one.

    Args:
    - model: Model to prune; its layers are read from ``model.model.layers``.
    - prune_percent: Fraction of neuron pairs to prune in each MLP.

    Returns:
    - model: The same model object, now pruned.
    """
    new_intermediate_size = None

    #loop for each model layer.
    for layer in model.model.layers:
        #Each decoder layer holds several components; only the MLP is pruned.
        mlp = layer.mlp

        #Call prune_neuron_pairs with the MLP and receive the pruned layers.
        new_gate_proj, new_up_proj, new_down_proj, new_size = prune_neuron_pairs(mlp, prune_percent)

        #Replace the original layers with the pruned layers.
        mlp.gate_proj = new_gate_proj
        mlp.up_proj = new_up_proj
        mlp.down_proj = new_down_proj

        #Some MLP modules cache their intermediate size as an attribute;
        # keep it in sync with the new projections when it exists.
        if hasattr(mlp, "intermediate_size"):
            mlp.intermediate_size = new_size

        #new_intermediate_size only needs to be set once
        if new_intermediate_size is None:
            new_intermediate_size = new_size

    #Update the model config, but never overwrite it with None when the
    # model had no layers to prune.
    if new_intermediate_size is not None:
        model.config.intermediate_size = new_intermediate_size

    return model



**Obtain & test the pruned model.**

In [None]:
# prune_percent is a fraction, not a percentage value: 0.2 removes 20% of the
# neuron pairs in every MLP.
prune_percent = 0.2  # Prune 20% of neurons
# update_model swaps the MLP projections on this same model object and
# returns it, so `model` refers to the pruned model from here on.
model = update_model(model, prune_percent)

In [None]:

# Recalculate the number of parameters now that the MLP layers are smaller,
# and compare against the baseline captured before pruning.
pruned_param_count = count_parameters(model)
reduction_in_params = original_param_count - pruned_param_count
# Share of the original parameter count that pruning removed, in percent.
percentage_savings = (reduction_in_params / original_param_count) * 100

print(f"Pruned model parameters: {pruned_param_count}")
print(f"Reduction in parameters: {reduction_in_params}")
print(f"Percentage of weight savings: {percentage_savings:.2f}%")


Study Model Structure after pruning

In [None]:
# Print the module hierarchy again to inspect the layer sizes after pruning.
print(model)

In [None]:
#Save the pruned model and its tokenizer side by side in a local directory.
new_model_name = 'pruned20_sarvam_stw'
output_dir = './'+new_model_name
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(output_dir, exist_ok=True)

model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Pruned model saved to {output_dir}")

In [None]:
# Push the model to your Hugging Face repository
# Both calls upload to the repository named by new_model_name.
# NOTE(review): only the model push passes private=True; the tokenizer push
# appears to rely on the repository already existing from the first call —
# confirm the repo stays private if these lines are ever run out of order.

model.push_to_hub(new_model_name, private=True)
tokenizer.push_to_hub(new_model_name)

Evaluating the Model

LAMBADA – tests the model’s ability to predict the final word of a sentence using long-range context.

BoolQ – a yes/no question-answering task.

ARC Easy – multiple-choice questions that test basic reasoning.

These tasks give a quick snapshot of how pruning affected understanding, reasoning, and language prediction abilities.

In [None]:
def evaluate_loaded_hf_model(model, tokenizer, tasks=('arc_easy', 'boolq', 'lambada'), num_fewshot=0):
    """
    Evaluates a Hugging Face model already loaded into memory by temporarily saving it
    to a folder and using lm-eval's standard local model loading (pretrained=...).

    Args:
    - model: In-memory model to evaluate; must support save_pretrained.
    - tokenizer: Tokenizer saved alongside the model.
    - tasks: Task name or iterable of task names to run.
    - num_fewshot: Number of few-shot examples passed to lm-eval.

    Returns:
    - metrics: The 'results' mapping of the lm-eval output, or an empty dict
      when the evaluator returns nothing or has no 'results' entry.
    """
    # Accept a single task name as well as any iterable; always hand lm-eval a
    # fresh list so the caller's sequence (or the default) is never shared.
    task_names = [tasks] if isinstance(tasks, str) else list(tasks)

    # Evaluate on the GPU in half precision when available; otherwise fall
    # back to CPU in float32 instead of failing on a hard-coded cuda device.
    use_cuda = torch.cuda.is_available()
    device_name = "cuda" if use_cuda else "cpu"
    dtype_name = "float16" if use_cuda else "float32"

    # Create a temporary directory to store the model; it is removed once the
    # evaluation has finished.
    with tempfile.TemporaryDirectory() as tmpdir:
        print(f"Saving in-memory model temporarily to: {tmpdir}")
        model.save_pretrained(tmpdir)
        tokenizer.save_pretrained(tmpdir)

        model_args = f"pretrained={tmpdir},device={device_name},dtype={dtype_name}"

        print("Loading model from temp path for evaluation...")

        results = evaluator.simple_evaluate(
            model="hf",
            model_args=model_args,
            tasks=task_names,
            num_fewshot=num_fewshot,
            limit=None,
            bootstrap_iters=10
        )

    # Guard against the evaluator returning None rather than a results dict.
    if results is None:
        return {}
    return results.get('results', {})





In [None]:
# Benchmarks used to measure how pruning affected the model.
tasks_to_run = ['lambada', 'boolq', 'arc_easy']

print("\n--- Starting Evaluation for In-Memory Model ---")
# The pruned model lives in `model` (update_model pruned it in place); no
# separate `pruned_model` variable is ever defined in this notebook.
metrics_pruned = evaluate_loaded_hf_model(model, tokenizer, tasks=tasks_to_run)
print("\n--- Pruned Sarvam-1 Evaluation Metrics ---")
print(metrics_pruned)