## 🔍 LLaMA 3 Pruning for Efficient Inference

This project explores structured pruning of the Llama 3.1 8B model (`meta-llama/Llama-3.1-8B`) to reduce its size and improve inference efficiency. Block-level influence scores, computed from the similarity between the hidden states entering and leaving a block, identify the least important layers. Those layers are then removed, with the aim of preserving performance on downstream tasks such as HellaSwag.

The workflow includes:
- Loading and preparing the model
- Measuring influence of transformer blocks
- Pruning selected layers
- Evaluating the pruned model using `lm_eval`

This notebook provides a reproducible implementation of the pruning pipeline, along with benchmarks to compare model performance before and after pruning.


In [1]:
!pip install datasets lm_eval



In [2]:
# Import the required libraries
from collections import OrderedDict
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from lm_eval import evaluator, tasks
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, Trainer, TrainingArguments

In [3]:
from huggingface_hub import notebook_login
notebook_login()



In [4]:
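def layer_removal(
    model: nn.Module,
    layers_to_remove: OrderedDict
):
    """
    Generic removal implementation.

    `layers_to_remove` maps a dotted attribute path (e.g. "model.layers") to
    either None (delete the whole attribute) or an index into that container.
    """

    for layer_name, layer_idx in layers_to_remove.items():
        modules = layer_name.split(".")
        mod = model
        for m in modules[:-1]:
            mod = getattr(mod, m)

        if layer_idx is None:
            delattr(mod, modules[-1])
        else:
            # delattr() returns None, so fetch the container and delete the item from it
            del getattr(mod, modules[-1])[layer_idx]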

### `block_influence`

Measures how much a layer changes the hidden states.

- Takes the input and output hidden states of a layer (or a span of layers), each of shape `(B, S, D)`.
- Computes the cosine similarity between them, token by token.
- If `angular=True`, returns the angular distance `arccos(sim) / π`.
- Otherwise, returns `1 - sim`.

Used to find less important layers for pruning.
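
As a dependency-free illustration of the two scores, the sketch below works on plain Python lists rather than the notebook's tensors. The helper names `cosine`, `bi_score`, and `angular_score` are hypothetical, and the clamp before `acos` is an addition for floating-point safety that the notebook code does not perform. Identical vectors score 0 under both forms, while orthogonal vectors score 1 under the cosine form and 0.5 under the angular form.

```python
import math

def cosine(x, y):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return dot / (norm_x * norm_y)

def bi_score(x, y):
    """Cosine-based block influence: 1 - cos(x, y)."""
    return 1.0 - cosine(x, y)

def angular_score(x, y):
    """Angular distance: arccos(cos(x, y)) / pi."""
    # Clamp to [-1, 1] so rounding error cannot push acos out of its domain
    # (an assumption of this sketch, not part of the notebook code).
    c = max(-1.0, min(1.0, cosine(x, y)))
    return math.acos(c) / math.pi

unchanged = ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])  # block leaves the state as-is
rotated = ([1.0, 0.0], [0.0, 1.0])              # block rotates the state by 90 degrees

for name, (x, y) in {"unchanged": unchanged, "rotated": rotated}.items():
    print(name, round(bi_score(x, y), 6), round(angular_score(x, y), 6))
```

A block whose output barely differs from its input therefore gets a score near 0 and becomes a pruning candidate.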


In [5]:
def block_influence(
    input_hidden_state: torch.Tensor,
    output_hidden_state: torch.Tensor,
    angular=False,
):
    """
    input_hidden_state: B, S, D
    output_hidden_state: B, S, D
    """
    _, _, d = input_hidden_state.shape
    input_hidden_state = input_hidden_state.reshape(-1, d)
    output_hidden_state = output_hidden_state.reshape(-1, d)

    norm_input = input_hidden_state.norm(dim=-1, keepdim=True)
    norm_output = output_hidden_state.norm(dim=-1, keepdim=True)

    sim = (input_hidden_state @ output_hidden_state.T) / (norm_input * norm_output)
    sim = sim.diagonal().nan_to_num(nan=0.5)

    if angular:
        return (torch.arccos(sim) / torch.pi)

    return 1 - sim

### `ShortHFModel` Class

A wrapper around HuggingFace's language models for pruning transformer layers based on their importance.

#### `__init__`
- Loads a HuggingFace model and tokenizer.
- Finds the model layers using the provided path (e.g. `"model.layers"`).
- Sets up a list to store layer importance scores.

#### `remove_layers()`
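- Removes `n_prune_layers` layers based on importance, or an explicit list of layer indices if one is passed.
- If `angular=True`, picks the window of `n_prune_layers` consecutive layers with the smallest angular distance between its input and output.
- Otherwise, removes the `n_prune_layers` layers with the lowest importance scores, wherever they sit in the stack.
- Layers are deleted from the highest index down so that earlier deletions do not shift the indices of layers still to be removed.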

#### `compute_bi()`
- Calculates how much each layer changes the hidden states.
- Uses the `block_influence()` function.

#### `eval_importance()`
- Feeds input prompts into the model using a sliding window.
- Collects hidden states across layers without generating output.
- Computes and updates layer importance scores from the outputs.

Used to rank and prune LLaMA layers in a structured way.
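
The reverse-order deletion noted for `remove_layers()` can be seen on an ordinary Python list standing in for the layer container. This is a toy sketch, and `delete_descending` and `delete_ascending` are hypothetical names. Deleting in ascending order shifts the remaining elements left, so later indices hit the wrong layers.

```python
def delete_descending(layers, indices):
    """Delete from the highest index down, as remove_layers() does."""
    layers = list(layers)
    for idx in sorted(indices, reverse=True):
        del layers[idx]
    return layers

def delete_ascending(layers, indices):
    """Naive variant: earlier deletions shift the positions of later targets."""
    layers = list(layers)
    for idx in sorted(indices):
        del layers[idx]
    return layers

layers = [f"layer_{i}" for i in range(8)]
targets = [2, 3, 4]

print(delete_descending(layers, targets))  # keeps layers 0, 1, 5, 6, 7
print(delete_ascending(layers, targets))   # keeps layers 0, 1, 3, 5, 7 (wrong)
```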


In [6]:
class ShortHFModel():

    def __init__(self, model_name: str, layers_path: str, n_prune_layers: Optional[int] = None):
        """
        HuggingFace Model Wrapper

        Args:
            model_name (str): HuggingFace model name
            layers_path (str): String in dot notation demonstrating how to access layers of the model. Ex: "model.layers"
            (Optional) n_prune_layers (int): Number of layers to prune. Defaults to None.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        # self.model.params = self.model.to_fp16(self.model.params)
        self.model.to("cuda")

        modules = layers_path.split(".")
        mod = self.model
        for m in modules:
            mod = getattr(mod, m)
        self.layers = mod

        self.n_prune_layers = n_prune_layers
        self.importances = [0 for _ in self.layers]  # layer-wise importance scores

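    def remove_layers(
        self,
        layers_to_remove: Optional[List[int]] = None,
        angular: Optional[bool] = False
    ):
        # copy into a fresh list; avoids the shared mutable default `[]`
        layers_to_remove = list(layers_to_remove) if layers_to_remove else []
        if angular:
            assert self.importances, "Need to compute importances with eval_importance()"
            assert self.n_prune_layers, "Need number of layers to prune, set `n_prune_layers`"
            # valid window starts are 0 .. len - n_prune_layers; an explicit end index
            # avoids the empty slice that `[:-n+1]` yields when n_prune_layers == 1
            n_starts = len(self.importances) - self.n_prune_layers + 1
            start_layer = int(np.argsort(np.array(self.importances[:n_starts]))[0])
            layers_to_remove = list(range(start_layer, start_layer + self.n_prune_layers))
        elif not layers_to_remove and self.n_prune_layers:
            assert self.importances, "Need to compute importances with eval_importance()"
            layers_to_remove = np.argsort(np.array(self.importances))[:self.n_prune_layers].tolist()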

        # remove layers in reverse to avoid indexing errors
        for layer_idx in sorted(layers_to_remove, reverse=True):
            try:
                del self.layers[layer_idx]
            except IndexError:
                print(f"layer {layer_idx} does not exist, function may have already been called")
                return []
        
        return layers_to_remove
    
    def compute_bi(self, hiddens: List[torch.Tensor], angular: bool):
        n = 1
        if angular:
            assert self.n_prune_layers is not None, "Set number of layers to prune to use angular importance"
            n = self.n_prune_layers

        for i in range(len(hiddens) - n):
            in_hidden = hiddens[i]
            out_hidden = hiddens[i+n]
            if angular:
                # use only last token for angular distance as described in section 3.2
                # https://arxiv.org/pdf/2403.17887.pdf
                in_hidden = in_hidden[:,-1:]
                out_hidden = out_hidden[:,-1:]
            
            self.importances[i] += block_influence(
                in_hidden,
                out_hidden,
                angular=angular
            ).sum().cpu().item()

    @torch.inference_mode()
    def eval_importance(
        self,
        prompts: List[str],
        max_seq_len: int,
        stride: int = 256,
        max_gen_len: int = 0,
        temperature: float = 0.6,
        top_p: float = 0.9,
        angular: Optional[bool] = False
    ):
        """
        Computes layer-wise importances over input texts.

        NOTE: ShortGPT paper performs no generation during importance computation, which suggests a `max_gen_len`= 0.

        Args:
            prompts (List[str]): List of prompts.
            max_seq_len (int): Maximum sequence length for model input, the sliding window size.
            (Optional) stride (int): Number of tokens to skip/shift between each window inference.
            (Optional) max_gen_len (int): Maximum length of the generated text sequence.
            (Optional) temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            (Optional) top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            (Optional) angular (bool): Whether to use angular distance. Defaults to False.

        Returns:
            None
        """
        prompt_tokens = self.tokenizer(
            prompts,
            padding=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = prompt_tokens.input_ids
        attn_mask = prompt_tokens.attention_mask

        max_prompt_len = max(len(t) for t in input_ids)

        # authors use a sliding window of size 1024 with a shift of 256
        for start in range(0, max_prompt_len, stride):
            seq_ids = (attn_mask.sum(dim=-1) > start).nonzero().squeeze()
            seq_ids = seq_ids.unsqueeze(0) if seq_ids.dim() == 0 else seq_ids  # ensure 1-D when a single sequence remains
            inputs = input_ids[seq_ids, start:start+max_seq_len]
            attn = attn_mask[seq_ids, start:start+max_seq_len]

            if max_gen_len == 0:
                outputs = self.model(
                    input_ids=inputs.to("cuda"),
                    attention_mask=attn.to("cuda"),
                    output_hidden_states=True,
                )
            else:
                outputs = self.model.generate(
                    input_ids=inputs.to("cuda"),
                    attention_mask=attn.to("cuda"),
                    max_new_tokens=max_gen_len, 
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                )
            
            self.compute_bi(outputs.hidden_states, angular=angular)

        return
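
To make the windowing in `eval_importance()` concrete, the sketch below lists the token ranges each iteration would slice for a single sequence. It is a simplified stand-in: the hypothetical `window_bounds` helper ignores batching and padding and only mirrors the `range(0, max_prompt_len, stride)` loop and the `start:start+max_seq_len` slice.

```python
from typing import List, Tuple

def window_bounds(prompt_len: int, max_seq_len: int, stride: int) -> List[Tuple[int, int]]:
    """Return the (start, end) token ranges visited by the sliding window."""
    bounds = []
    for start in range(0, prompt_len, stride):
        # A slice that runs past the end of the sequence is clipped
        end = min(start + max_seq_len, prompt_len)
        bounds.append((start, end))
    return bounds

# Window size equal to the stride: windows do not overlap
print(window_bounds(prompt_len=600, max_seq_len=256, stride=256))

# 1024-token window shifted by 256 tokens: windows overlap
print(window_bounds(prompt_len=1500, max_seq_len=1024, stride=256))
```

When the window is larger than the stride, tokens in the overlap are processed in more than one window, so in the non-angular mode they contribute to the importance sums more than once.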

In [7]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# Step 1: Load the HellaSwag validation split
data = load_dataset("hellaswag", split="validation", trust_remote_code=True)

# Step 2: Join each context with its correct ending to form one calibration text
def process_hellaswag(examples):
    texts = []
    for ctx, endings_list, label in zip(examples["ctx"], examples["endings"], examples["label"]):
        combined_text = ctx + " " + endings_list[int(label)]  # cast the label to int before indexing
        texts.append(combined_text)
    return {"text": texts}


processed_data = data.map(process_hellaswag, batched=True)

# Step 3: Create the DataLoader after processing
dataloader = DataLoader(
    processed_data,
    batch_size=1,
)
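
For a sense of what `process_hellaswag` produces, here is the same joining logic applied to a made-up batch in the column layout that `batched=True` passes in. The example rows are invented for illustration and are not taken from HellaSwag. The labels are written as strings to show why the `int()` cast is applied.

```python
def process_hellaswag(examples):
    """Join each context with the ending selected by its label."""
    texts = []
    for ctx, endings_list, label in zip(examples["ctx"], examples["endings"], examples["label"]):
        texts.append(ctx + " " + endings_list[int(label)])
    return {"text": texts}

# Hypothetical two-row batch (one list per column)
toy_batch = {
    "ctx": ["A man picks up a guitar.", "She opens the oven door."],
    "endings": [
        ["He eats it.", "He starts to play.", "He swims away.", "He falls asleep."],
        ["She takes out the tray.", "She rides a bike.", "She paints a wall.", "She sings."],
    ],
    "label": ["1", "0"],
}

print(process_hellaswag(toy_batch)["text"])
```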

### Model Setup

Initializes the pruning wrapper with the base LLaMA 3.1 8B model.

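- `MAX_SEQ_LEN = 1024`: Intended window size for input sequences. Note that the importance-evaluation cell further down passes `max_seq_len=256` directly rather than this constant.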
- `ShortHFModel(...)`: Loads the model and prepares to prune 5 layers from `model.layers`.


In [8]:
MAX_SEQ_LEN = 1024
short_model = ShortHFModel(
    model_name="meta-llama/Llama-3.1-8B",
    layers_path="model.layers",
    n_prune_layers=5,
)


In [9]:
short_model.model


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [10]:
short_model.model.config

LlamaConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 8.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.51.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [11]:
# sample generation
inputs = short_model.tokenizer(["Dhaka is the capital city of"], return_tensors='pt').to("cuda")
gen = short_model.model.generate(
    **inputs,  # input_ids plus attention_mask
    max_new_tokens=50,
    pad_token_id=short_model.tokenizer.eos_token_id,
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)


['Dhaka is the capital city of Bangladesh, located in the center of the country. It is the largest city in Bangladesh and one of the most densely populated cities in the world. Dhaka is a bustling metropolis with a rich cultural heritage and a thriving economy.\nDhaka is']

In [12]:
# Step 4: Run importance evaluation
for i, batch in enumerate(tqdm(dataloader)):
    prompts = batch['text']

    short_model.eval_importance(
        prompts=prompts,
        max_seq_len=256,
        stride=256,
        max_gen_len=0,
        angular=True  # ✅ angular enabled
    )

100%|██████████| 10042/10042 [05:54<00:00, 28.31it/s]


In [13]:
short_model.importances

[4596.36279296875,
 3559.147216796875,
 3543.085205078125,
 3428.084228515625,
 3486.94287109375,
 3372.723876953125,
 3314.354736328125,
 3309.061767578125,
 3233.232666015625,
 3259.996826171875,
 3117.794921875,
 3121.95751953125,
 3185.0390625,
 3159.41552734375,
 2998.45654296875,
 2843.5714111328125,
 2648.7491455078125,
 2451.3165283203125,
 2269.86328125,
 2119.666748046875,
 1983.0262451171875,
 1926.4088134765625,
 1806.6669921875,
 1745.3758544921875,
 1764.9443359375,
 1890.9305419921875,
 2553.4669189453125,
 4105.641357421875,
 0,
 0,
 0,
 0]

In [14]:
param_size = 0
for param in short_model.model.parameters():
    param_size += param.nelement() * param.element_size()

buffer_size = 0
for buffer in short_model.model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

# Total size in bytes to GB
size_all_gb = (param_size + buffer_size) / 1024**3
print('Model size: {:.3f} GB'.format(size_all_gb))

Model size: 14.958 GB


In [15]:
short_model.remove_layers(angular=True)


[23, 24, 25, 26, 27]
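
The returned indices can be reproduced from the scores printed earlier. The sketch below copies the 28 nonzero importances (rounded to two decimals) and applies the same rule as the angular branch: choose the start index with the smallest score, then take `n_prune_layers` consecutive layers. The final four entries of `importances` stay at 0 because, with 33 hidden states (the embedding output plus 32 layers, as returned with `output_hidden_states=True`) and a span of 5, only start positions 0 through 27 are scored.

```python
# Scores for window starts 0..27, copied from the notebook output (rounded)
scores = [
    4596.36, 3559.15, 3543.09, 3428.08, 3486.94, 3372.72, 3314.35,
    3309.06, 3233.23, 3260.00, 3117.79, 3121.96, 3185.04, 3159.42,
    2998.46, 2843.57, 2648.75, 2451.32, 2269.86, 2119.67, 1983.03,
    1926.41, 1806.67, 1745.38, 1764.94, 1890.93, 2553.47, 4105.64,
]
n_layers = 32
n_prune_layers = 5

# One score per valid start of a 5-layer window: 33 hidden states minus 5
n_hidden_states = n_layers + 1
assert len(scores) == n_hidden_states - n_prune_layers

# Index of the smallest score is the start of the window to remove
start = min(range(len(scores)), key=scores.__getitem__)
layers_to_remove = list(range(start, start + n_prune_layers))
print(layers_to_remove)  # [23, 24, 25, 26, 27]
```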

In [16]:
short_model.layers

ModuleList(
  (0-26): 27 x LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
    (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
  )
)

In [17]:
short_model.model.config

LlamaConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 8.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.51.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [18]:
# reassign layer_idx to attentions for caching
for layer_idx, module in enumerate(short_model.layers):
    module.self_attn.layer_idx = layer_idx
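
Why the renumbering matters can be sketched without the model. Each attention module keeps the `layer_idx` it was built with, so after layers 23 to 27 are deleted the surviving modules carry indices with a gap. The toy `Block` class below is a hypothetical stand-in for a decoder layer, and the assumption that the per-layer cache is addressed by `layer_idx` follows the comment in the cell above.

```python
class Block:
    """Minimal stand-in for a decoder layer that remembers its build-time index."""
    def __init__(self, layer_idx: int):
        self.layer_idx = layer_idx

blocks = [Block(i) for i in range(32)]

# Delete layers 23..27 from the highest index down
for idx in sorted([23, 24, 25, 26, 27], reverse=True):
    del blocks[idx]

stale = [b.layer_idx for b in blocks]
print(stale[-5:])  # [22, 28, 29, 30, 31]: indices no longer match list positions

# Renumber so that layer_idx equals the position in the pruned stack
for position, block in enumerate(blocks):
    block.layer_idx = position

print([b.layer_idx for b in blocks][-5:])  # [22, 23, 24, 25, 26]
```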


In [20]:
inputs = short_model.tokenizer(["Dhaka is the capital city of"], return_tensors='pt').to("cuda")
gen = short_model.model.generate(
    **inputs,  # input_ids plus attention_mask
    max_new_tokens=50,
    use_cache=True,
    pad_token_id=short_model.tokenizer.eos_token_id,
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)


['Dhaka is the capital city of Bangladesh and the largest city in the South East Asia. It is located in the central central-eastern region of the country and is the largest urban area in the eastern hemisphere.\nThe city is the largest city in the South East Asia. It is located']

In [21]:
param_size = 0
for param in short_model.model.parameters():
    param_size += param.nelement() * param.element_size()

buffer_size = 0
for buffer in short_model.model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

# Total size in bytes to GB
size_all_gb = (param_size + buffer_size) / 1024**3
print('Model size: {:.3f} GB'.format(size_all_gb))

Model size: 12.926 GB
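
The drop from 14.958 GB to 12.926 GB is consistent with the layer shapes printed earlier. As a back-of-the-envelope check (plain arithmetic, assuming 2 bytes per float16 parameter and counting only the weights listed in the `LlamaDecoderLayer` printout), five decoder layers account for about 2.03 GB:

```python
hidden, kv_dim, intermediate = 4096, 1024, 14336
bytes_per_param = 2  # float16

attn = 2 * hidden * hidden + 2 * hidden * kv_dim  # q_proj, o_proj, k_proj, v_proj
mlp = 3 * hidden * intermediate                   # gate_proj, up_proj, down_proj
norms = 2 * hidden                                # input and post-attention RMSNorm weights

params_per_layer = attn + mlp + norms
removed_gb = 5 * params_per_layer * bytes_per_param / 1024**3

print(params_per_layer)
print(f"{removed_gb:.3f} GB")
```

This agrees with the observed difference of roughly 2.03 GB between the two measurements.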


In [22]:
short_model.model.config.num_hidden_layers = len(short_model.layers)

In [23]:
short_model.model.config


LlamaConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 27,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 8.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.51.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [24]:
param_size = 0
for param in short_model.model.parameters():
    param_size += param.nelement() * param.element_size()

buffer_size = 0
for buffer in short_model.model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

# Total size in bytes to GB
size_all_gb = (param_size + buffer_size) / 1024**3
print('Model size: {:.3f} GB'.format(size_all_gb))

Model size: 12.926 GB


In [25]:
import os
new_model_name = 'shortgpt_llama3.1_8B_hellaswag_angular'
output_dir = './' + new_model_name
os.makedirs(output_dir, exist_ok=True)  # no-op if the directory already exists

short_model.model.save_pretrained(output_dir)
short_model.tokenizer.save_pretrained(output_dir)
print(f"Pruned model saved to {output_dir}")

Pruned model saved to ./shortgpt_llama3.1_8B_hellaswag_angular


In [26]:
# Push the model to your Hugging Face repository

short_model.model.push_to_hub(new_model_name, private=False)
short_model.tokenizer.push_to_hub(new_model_name)


CommitInfo(commit_url='https://huggingface.co/Shahrukh0/shortgpt_llama3.1_8B_hellaswag_angular/commit/51f44ec9460ca19896cc9a2b5b8b9f8e3ea43a15', commit_message='Upload tokenizer', commit_description='', oid='51f44ec9460ca19896cc9a2b5b8b9f8e3ea43a15', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Shahrukh0/shortgpt_llama3.1_8B_hellaswag_angular', endpoint='https://huggingface.co', repo_type='model', repo_id='Shahrukh0/shortgpt_llama3.1_8B_hellaswag_angular'), pr_revision=None, pr_num=None)