**Pruning the Sarvam-1 Model with WANDA** — Kaggle environment, T4 GPU

Install Libraries

In [None]:
# Install required libraries.
# %pip (instead of !pip) installs into the environment of the running kernel.
%pip install transformers accelerate datasets lm-eval sacrebleu evaluate torch torchvision torchaudio

In [None]:
# Quiet install of lm-eval into the running kernel's environment.
%pip install -q lm-eval

In [None]:
# Pin protobuf to 3.20.3 in the running kernel's environment.
# NOTE(review): the reason for this pin is not recorded in the notebook — confirm it is still required.
%pip install protobuf==3.20.3

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import nn
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
from lm_eval import evaluator, tasks, models
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import tempfile

Download and Study the model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fetch the tokenizer and the model weights for the same checkpoint
tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")

# Fall back to the EOS token for padding when the tokenizer defines no pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Study the model architecture: printing the module lists its sub-modules by name,
# which is where the Linear layer names targeted for pruning later on come from.
print(model)

In [None]:
def get_output(prompt, model=model, tokenizer=tokenizer):
    """Generate text for ``prompt`` with deterministic beam search.

    Args:
        prompt: Text to feed to the model.
        model: Causal LM used for generation. The default is bound when this
            cell is executed, so pass a different model (e.g. the pruned one)
            explicitly to generate with it.
        tokenizer: Tokenizer used to encode ``prompt`` and decode the output.

    Returns:
        The first returned sequence decoded to text, with special tokens removed.
    """
    # Place the inputs on the device that holds the model's weights; the
    # original referenced an undefined global `device`.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=50,
        num_return_sequences=1,
        pad_token_id=tokenizer.pad_token_id,
        temperature=None,
        top_p=None,
        do_sample=False,          # Disable sampling
        num_beams=5,              # Use beam search
        early_stopping=True,      # Stop when end-of-sequence token is generated
        no_repeat_ngram_size=2    # Prevent repetition of 2-grams
    )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated

We choose:

SPARSITY_RATIO = 0.50 → prune 50% of weights

Use Wikitext-2 as calibration data

Use 128 samples to capture activations

Only prune the Linear layers inside Attention + MLP blocks

In [None]:
# Checkpoint that is reloaded for pruning (same checkpoint as the one studied above)
MODEL_ID = "sarvamai/sarvam-1"

SPARSITY_RATIO = 0.50  # Prune 50% of the weights (unstructured sparsity)
CALIBRATION_SAMPLES = 128 # Number of sequences to use for calibration (activation capture)
MAX_SEQ_LEN = 128  # Length, in tokens, of every calibration sequence
CALIB_DATASET = "wikitext"
CALIB_CONFIG = "wikitext-2-raw-v1"

# Target modules for pruning (typically Linear layers in Attention and MLP)
# These names are generally consistent across Llama-style models like Sarvam-1 (which uses SwiGLU and Grouped-Query Attention)
ATTENTION_LAYERS_TO_PRUNE = ["q_proj", "k_proj", "v_proj", "o_proj"]
MLP_LAYERS_TO_PRUNE = ["gate_proj", "up_proj", "down_proj"]
TARGET_MODULES = ATTENTION_LAYERS_TO_PRUNE + MLP_LAYERS_TO_PRUNE

# Run on the GPU when one is visible, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Before we perform WANDA pruning, we need a small amount of real text data to estimate how important each weight is.
This is called calibration data, and it allows us to measure the activation statistics of the model.

In [None]:
# Fetch the training split of the calibration corpus
print(f"Loading calibration dataset: {CALIB_DATASET}/{CALIB_CONFIG}...")
calib_dataset = load_dataset(path=CALIB_DATASET, name=CALIB_CONFIG, split="train")


def get_calibration_data(data, n_samples):
    """Build up to ``n_samples`` padding-free calibration sequences.

    Non-empty texts are tokenized in order and their token ids are packed
    back to back, then cut into windows of exactly ``MAX_SEQ_LEN`` tokens.
    Packing (instead of padding every line to ``MAX_SEQ_LEN``) keeps pad
    tokens out of the sequences, so they cannot distort the activation
    statistics collected from them.

    Relies on the module-level ``tokenizer``, ``MAX_SEQ_LEN`` and ``DEVICE``.

    Args:
        data: Iterable of mapping-like examples with a ``'text'`` entry.
        n_samples: Maximum number of sequences to return.

    Returns:
        A tensor of token ids with shape ``(k, MAX_SEQ_LEN)`` on ``DEVICE``,
        where ``k <= n_samples`` (fewer only if the data runs out).

    Raises:
        ValueError: If ``n_samples`` is not positive or the data does not
            contain enough text to fill a single window.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    windows = []
    token_buffer = []

    for example in data:
        text = example.get('text', '')
        if not text.strip():
            continue  # skip blank lines

        # No truncation/padding here: the ids are packed into fixed windows below
        token_buffer.extend(tokenizer(text)['input_ids'])

        # Emit every full window that the buffer currently holds
        while len(token_buffer) >= MAX_SEQ_LEN and len(windows) < n_samples:
            windows.append(torch.tensor(token_buffer[:MAX_SEQ_LEN], dtype=torch.long))
            token_buffer = token_buffer[MAX_SEQ_LEN:]

        if len(windows) >= n_samples:
            break

    if not windows:
        raise ValueError(
            "Calibration data does not contain enough text to fill one "
            f"sequence of {MAX_SEQ_LEN} tokens."
        )

    # Stack the samples into a single tensor
    return torch.stack(windows, dim=0).to(DEVICE)

# Build the calibration tensor and report its dimensions
calib_input_ids = get_calibration_data(data=calib_dataset, n_samples=CALIBRATION_SAMPLES)
print(f"Calibration data shape: {calib_input_ids.shape}")

This section performs the core of the WANDA pruning algorithm, which removes a percentage of weights (e.g., 50%) from the model without retraining, while preserving the most important parameters.

WANDA computes a per-weight importance score using both:

Weight magnitude (|W|)

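Activation strength of its input feature (‖X_j‖₂, the L2 norm of input feature j taken over all calibration tokens), giving the score $S_{ij} = |W_{ij}| \cdot \lVert X_j \rVert_2$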

This lets pruning be input-aware and significantly more accurate than magnitude-only pruning.

In [None]:
print(f"\nLoading model: {MODEL_ID}...")
# Reload the checkpoint in bfloat16, move it to DEVICE and switch to inference mode
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
model = model.to(DEVICE)
model.eval()

# Maps each hooked layer's name to the L2 norm of its input activation (X),
# one value per input feature of the layer's Linear weight (W)
act_scales = dict()

def save_input_hook(name):
    """Build a forward hook that tracks the input-activation L2 norm of a layer.

    Every call of the returned hook folds the current input into
    ``act_scales[name]`` so that, after any number of forward passes, the
    stored vector is the per-feature L2 norm taken over *all* tokens seen so
    far (root of the accumulated sum of squares), kept in float32.

    Args:
        name: Key under which the norms are stored in the module-level
            ``act_scales`` dictionary.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook signature.
    """
    def hook(module, input_tensor, output_tensor):
        # Forward hooks receive the positional inputs as a tuple; the first
        # entry is the activation fed to the layer.
        act = input_tensor[0].detach()

        # Flatten every leading dimension to get (tokens, in_features).
        # reshape (not view) also handles non-contiguous inputs, and float32
        # avoids accumulating the sum of squares in reduced precision.
        act = act.reshape(-1, act.shape[-1]).to(torch.float32)

        # Sum of squares per input feature, shape (in_features,)
        sq_norm = act.pow(2).sum(dim=0)

        previous = act_scales.get(name)
        if previous is None:
            act_scales[name] = sq_norm.sqrt()
        else:
            # ||[X_prev; X_new]||_2 = sqrt(||X_prev||_2^2 + ||X_new||_2^2)
            act_scales[name] = (previous.to(torch.float32).pow(2) + sq_norm).sqrt()

    return hook

# Register hooks on all target Linear layers (W in Y=XW+B), run the calibration
# data through the model, and always detach the hooks again afterwards.
handles = []
try:
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(target in name for target in TARGET_MODULES):
            # The hook reads the layer's input activation (X) on every forward call
            handles.append(module.register_forward_hook(save_input_hook(name)))

    print("Starting activation capture...")
    # One forward pass per calibration sequence; gradients are not needed
    with torch.no_grad():
        for i in range(calib_input_ids.shape[0]):
            input_ids = calib_input_ids[i].unsqueeze(0)
            model(input_ids)
finally:
    # Remove the hooks even if a forward pass fails (e.g. out of memory), so
    # re-running the cell does not stack duplicate hooks on the same layers
    for h in handles:
        h.remove()
print("Activation capture complete.")

# Calculate Wanda scores and prune
for name, module in model.named_modules():
    if not (isinstance(module, nn.Linear) and any(target in name for target in TARGET_MODULES)):
        continue

    print(f"Pruning layer: {name}...")
    W = module.weight.data  # shape: (out_features, in_features)
    X_norm = act_scales[name].to(device=W.device, dtype=torch.float32)  # shape: (in_features,)

    # Wanda Score: |W| * ||X||_2, with the norm broadcast across the rows of W.
    # Scores are ranked in float32 so that reduced-precision rounding does not
    # create artificial ties.
    wanda_score = W.abs().to(torch.float32) * X_norm.unsqueeze(0)

    # Per-output comparison: every output neuron (row of W) loses the same
    # fraction of its input weights.
    prune_per_row = int(W.shape[1] * SPARSITY_RATIO)
    if prune_per_row == 0:
        continue  # nothing to prune at this sparsity ratio

    # Select exactly `prune_per_row` lowest-score weights in each row. Picking
    # indices (instead of thresholding with `<`) keeps the count exact even
    # when several scores are equal.
    prune_idx = torch.topk(wanda_score, prune_per_row, dim=1, largest=False).indices

    # Apply the pruning: zero out the selected weights in place
    W.scatter_(1, prune_idx, 0.0)

print(f"\nModel pruned to {SPARSITY_RATIO * 100}% sparsity.")

Checking Sparsity after applying pruning

In [None]:
def check_sparsity(model, target_modules, target_ratio=None):
    """Print the layer-wise and overall sparsity of the targeted Linear layers.

    Args:
        model: Module whose sub-modules are inspected.
        target_modules: Name fragments; an ``nn.Linear`` is inspected when its
            qualified name contains at least one of them.
        target_ratio: Expected fraction of zero weights (0-1). Defaults to the
            module-level ``SPARSITY_RATIO``.

    Returns:
        None. All results are printed.
    """
    if target_ratio is None:
        target_ratio = SPARSITY_RATIO

    total_weights = 0
    zero_weights = 0

    print("\n--- Sparsity Check ---")

    for name, module in model.named_modules():
        # Only check the layers that were targeted for pruning
        if not (isinstance(module, nn.Linear) and any(target in name for target in target_modules)):
            continue

        weight_tensor = module.weight.data
        layer_total = weight_tensor.numel()
        if layer_total == 0:
            continue  # an empty weight has no meaningful sparsity
        layer_zeros = torch.sum(weight_tensor == 0).item()

        total_weights += layer_total
        zero_weights += layer_zeros

        layer_sparsity = (layer_zeros / layer_total) * 100
        print(f"Layer: {name} | Sparsity: {layer_sparsity:.2f}%")

    # Without any matching layer there is nothing to divide by
    if total_weights == 0:
        print("WARNING: No Linear layers matched the target modules; nothing to check.")
        return

    overall_sparsity = (zero_weights / total_weights) * 100
    print(f"\nOverall Pruned Weights Count: {zero_weights} / {total_weights}")
    print(f"*** Overall Model Sparsity: {overall_sparsity:.2f}% ***")

    # Check if the overall sparsity is within one percentage point of the target
    if abs(overall_sparsity - target_ratio * 100) < 1.0:
        print("Verification successful: Sparsity is close to the target ratio.")
    else:
        print("WARNING: Sparsity does not match the target ratio.")

# Run the check once the Wanda pruning cell has finished
check_sparsity(model=model, target_modules=TARGET_MODULES)

In [None]:
# --- SAVE THE PRUNED MODEL ---
# Tokenizer and weights go to the same local directory so it can be reloaded as one checkpoint
PRUNED_MODEL_DIR = "sarvam_wanda_pruned"
tokenizer.save_pretrained(save_directory=PRUNED_MODEL_DIR)
model.save_pretrained(save_directory=PRUNED_MODEL_DIR)

print(f"\nPruned model and tokenizer saved to: {PRUNED_MODEL_DIR}")


In [None]:
# Push the model to your Hugging Face repository
# NOTE(review): presumably requires an authenticated Hugging Face session — confirm before running.

# Placeholder repository name (the notebook never defined one); replace it
# with the repository you actually want to publish to.
new_model_name = "sarvam-1-wanda-pruned"

model.push_to_hub(new_model_name, private=True)
# Request a private repo here as well, so the result does not depend on which push runs first
tokenizer.push_to_hub(new_model_name, private=True)

Evaluating the Model

LAMBADA – tests the model’s ability to predict the final word of a sentence using long-range context.

BoolQ – a yes/no question-answering task.

ARC Easy – multiple-choice questions that test basic reasoning.

These tasks give a quick snapshot of how pruning affected understanding, reasoning, and language prediction abilities.

In [None]:
def evaluate_loaded_hf_model(model, tokenizer, tasks=['arc_easy', 'boolq', 'lambada'], num_fewshot=0):
    """Evaluate an in-memory Hugging Face model with lm-eval.

    The model and tokenizer are written to a temporary directory, which is
    then handed to lm-eval's ``hf`` model loader (``pretrained=...``). The
    directory is deleted once the evaluation has finished.

    Args:
        model: Object with a ``save_pretrained(directory)`` method.
        tokenizer: Object with a ``save_pretrained(directory)`` method.
        tasks: Names of the lm-eval tasks to run.
        num_fewshot: Number of few-shot examples given to each task.

    Returns:
        The ``'results'`` entry of the lm-eval output, or an empty dict when
        no results are returned.
    """
    # Evaluate on the GPU when one is visible, otherwise fall back to the CPU
    # (with float32, since half precision is a GPU-oriented choice).
    if torch.cuda.is_available():
        device, dtype = "cuda", "float16"
    else:
        device, dtype = "cpu", "float32"

    # Create a temporary directory to store the model
    with tempfile.TemporaryDirectory() as tmpdir:
        print(f"Saving in-memory model temporarily to: {tmpdir}")
        model.save_pretrained(tmpdir)
        tokenizer.save_pretrained(tmpdir)

        model_args = f"pretrained={tmpdir},device={device},dtype={dtype}"

        print(f"Loading model from temp path for evaluation...")

        results = evaluator.simple_evaluate(
            model="hf",
            model_args=model_args,
            tasks=list(tasks),  # copy: never hand out the shared default list
            num_fewshot=num_fewshot,
            limit=None,
            bootstrap_iters=10
        )

    # Tolerate a missing result object instead of failing on `.get`
    metrics = (results or {}).get('results', {})
    return metrics


In [None]:
# Benchmarks used to gauge the pruned model
tasks_to_run = ['lambada', 'boolq', 'arc_easy']

print(f"\n--- Starting Evaluation for In-Memory Model ---")
# Evaluate the model currently held in memory (the pruned one) on those tasks
metrics_pruned = evaluate_loaded_hf_model(model=model, tokenizer=tokenizer, tasks=tasks_to_run)
print("\n--- Pruned Sarvam-1 Evaluation Metrics ---")
print(metrics_pruned)