# Libraries

In [1]:
#!pip install -q transformers==4.47.1
#!pip install -q datasets==3.2.0
#!pip install -q torch==2.5.1
#!pip install -q lm-eval==0.4.7

In [2]:
import logging
import math
import os
import sys
import shutil
from copy import deepcopy

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import os

# Model Loading


In [3]:
# device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device: {device}")


Using device: {device}


In [4]:
torch.cuda.empty_cache()
torch.cuda.memory_summary()



In [6]:
model_name = 'meta-llama/Llama-3.2-3B'
model = AutoModelForCausalLM.from_pretrained(model_name,offload_buffers=True, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# structure of the model
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm

# Prunning


## Support functions

In [8]:
def measure_unpruned_layer_importances(pruned_model, tokenizer, input_text):
    """
    Measures and returns importance scores for all unpruned (non-bypassed) layers.
    """
    # PREPARATION
    """
    set the model to evaluation mode to ensure that no gradients
    are computed during the forward pass.
    """
    pruned_model.eval()
    device = next(pruned_model.parameters()).device

    """
    The provided input text (input_text) is tokenized into tensors
    suitable for processing by the model.
    """
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    """This will hold tuples of (layer_idx, importance_score)"""
    importance_scores = []

    # IDENTIFY UNPRUNED LAYERS & CREATING HOOKS
    """
    We'll register hooks for only layers that are NOT in drop_attn_list
    The list of attention layers that have already been pruned,
    is stored in a variable in the model's config: pruned_model.config.drop_attn_list.
    """
    unpruned_layer_indices = [
        idx for idx in range(len(pruned_model.model.layers))
        if idx not in pruned_model.config.drop_attn_list
    ]

    """
    Temporary storage for each layer's input/output
    We'll store them by layer index
    """
    layer_inputs = {}
    layer_outputs = {}

    """
    Create 2 hooks to capture the input and the output of the layers.
    These hooks store the inputs and outputs in dictionaries
    (layer_inputs and layer_outputs) for later analysis
    """
    #Allows capture the input to the query projection (q_proj)
    def q_proj_input_hook(layer_idx):
        def _hook(module, module_input):
            # module_input can be a tuple depending on PyTorch version
            inp = module_input[0] if isinstance(module_input, tuple) else module_input
            layer_inputs[layer_idx] = inp.detach().clone()
        return _hook

    # Allows capture the output from the output projection (o_proj)
    def o_proj_output_hook(layer_idx):
        def _hook(module, module_input, module_output):
            out = module_output[0] if isinstance(module_output, tuple) else module_output
            layer_outputs[layer_idx] = out.detach().clone()
        return _hook

    # Register hooks for each unpruned layer
    handles = []
    for idx in unpruned_layer_indices:
        layer = pruned_model.model.layers[idx]
        handles.append(layer.self_attn.q_proj.register_forward_pre_hook(q_proj_input_hook(idx)))
        handles.append(layer.self_attn.o_proj.register_forward_hook(o_proj_output_hook(idx)))

    # FORWARD PASS
    """
    Single forward pass (no gradient needed)
    A single forward pass is performed on the input text.
    During this pass, the hooks capture the inputs and outputs of the unpruned layers.
    This step is done with torch.no_grad(),
    ensuring no gradients are calculated, which saves memory and computation.
    """
    with torch.no_grad():
        _ = pruned_model(**inputs)

    """
    The hooks are removed after the forward pass
    to avoid memory leaks or interference with subsequent operations.
    """
    for h in handles:
        h.remove()


    #COMPUTE IMPORTANCE SCORES
    """
    For each unpruned layer, the inputs and outputs are flattened into vectors for comparison.

    Cosine Similarity: The similarity between the input and output vectors is
    computed using cosine similarity. Layers with outputs that are very similar
    to their inputs likely contribute less to the model’s overall computation.

    Importance Score: The importance score for each layer is calculated as 1−similarity
    A higher score indicates that the layer transforms its input significantly
    and is therefore more important to the model's function.
    """
    for idx in unpruned_layer_indices:
        if idx in layer_inputs and idx in layer_outputs:
            inp = layer_inputs[idx]
            out = layer_outputs[idx]

            inp_flat = inp.view(inp.size(0), -1)
            out_flat = out.view(out.size(0), -1)

            similarity = F.cosine_similarity(inp_flat, out_flat, dim=1).mean().item()
            importance_score = 1 - similarity
            importance_scores.append((idx, importance_score))

            print(f"[Iterative] Layer {idx} importance score: {importance_score:.4f}")

    """A list of tuples is returned, where each tuple contains the layer index
    and its calculated importance score."""
    return importance_scores

In [9]:
def bypass_single_layer(pruned_model, layer_idx):
    """
    Modifies the specified layer's forward method so that attention is bypassed.
    """
    layer = pruned_model.model.layers[layer_idx]
    # Store the original forward.
    if not hasattr(layer.self_attn, '_original_forward'):
        layer.self_attn._original_forward = layer.self_attn.forward

    # A new forward that checks whether to bypass
    def new_attention_forward(self, hidden_states, attention_mask=None, position_ids=None,
                              past_key_value=None, output_attentions=False, use_cache=False,
                              **kwargs):
        # If this layer is in drop_attn_list, bypass
        if getattr(self, 'layer_idx', -1) in pruned_model.config.drop_attn_list:
            return hidden_states, None, None
        # Otherwise, use the original forward
        return self._original_forward(hidden_states, attention_mask, position_ids,
                                      past_key_value, output_attentions, use_cache, **kwargs)

    # Set the layer index and forward
    layer.self_attn.layer_idx = layer_idx
    layer.self_attn.forward = new_attention_forward.__get__(layer.self_attn, type(layer.self_attn))


In [10]:
def iterative_pruning(model, tokenizer, input_text, num_layers_to_prune):
    """
    Iteratively:
      1) Measures importance of unpruned layers,
      2) Prunes (bypasses) the least important layer,
      3) Repeats until num_layers_to_prune layers are pruned.
    """
    # Create a copy of the model so we don't modify the original
    pruned_model = deepcopy(model)

    # Make sure we have a list of pruned layers in config
    pruned_model.config.drop_attn_list = []

    total_layers = len(pruned_model.model.layers)
    print(f"Total layers: {total_layers}")

    for step in range(num_layers_to_prune):
        print(f"\n--- Iteration {step + 1} of {num_layers_to_prune} ---")

        # 1) Measure importance scores for all unpruned layers
        importance_scores = measure_unpruned_layer_importances(pruned_model, tokenizer, input_text)
        if not importance_scores:
            print("No unpruned layers found or no importance scores computed.")
            break

        # 2) Pick layer with the lowest importance
        layer_to_bypass, min_score = min(importance_scores, key=lambda x: x[1])

        # 3) Bypass that layer
        pruned_model.config.drop_attn_list.append(layer_to_bypass)
        bypass_single_layer(pruned_model, layer_to_bypass)

        print(f"Bypassing layer {layer_to_bypass} with importance score {min_score:.4f}")
        print(f"Current bypass list: {pruned_model.config.drop_attn_list}")

    print(f"\nFinal bypassed layers: {sorted(pruned_model.config.drop_attn_list)}")
    print(f"Number of bypassed layers: {len(pruned_model.config.drop_attn_list)}")

    return pruned_model

## Prunning the model

In [11]:
pruned_model = iterative_pruning(
      model,
      tokenizer,
       "Hi I'm a sample text, use to calculate teh cosine difference between input and output.",
      num_layers_to_prune=4
)

Total layers: 28

--- Iteration 1 of 4 ---
[Iterative] Layer 0 importance score: 1.0713
[Iterative] Layer 1 importance score: 0.9355
[Iterative] Layer 2 importance score: 0.9594
[Iterative] Layer 3 importance score: 0.9596
[Iterative] Layer 4 importance score: 0.9597
[Iterative] Layer 5 importance score: 0.9989
[Iterative] Layer 6 importance score: 1.0272
[Iterative] Layer 7 importance score: 1.0362
[Iterative] Layer 8 importance score: 1.1125
[Iterative] Layer 9 importance score: 1.1243
[Iterative] Layer 10 importance score: 1.0488
[Iterative] Layer 11 importance score: 0.9959
[Iterative] Layer 12 importance score: 1.0859
[Iterative] Layer 13 importance score: 1.0245
[Iterative] Layer 14 importance score: 0.9680
[Iterative] Layer 15 importance score: 0.9707
[Iterative] Layer 16 importance score: 0.9034
[Iterative] Layer 17 importance score: 0.9299
[Iterative] Layer 18 importance score: 1.0231
[Iterative] Layer 19 importance score: 0.7273
[Iterative] Layer 20 importance score: 0.8678
[

# Test the Model

In [12]:
import time

def get_output(prompt, model=model, tokenizer=tokenizer, num_runs=1, max_length=50):
    total_time = 0
    generated_outputs = []

    for run in range(num_runs):
        # Start timing
        start_time = time.time()

        # Tokenization time
        token_start = time.time()
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        token_time = time.time() - token_start

        # Generation time
        gen_start = time.time()
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            temperature=None,
            top_p=None,
            do_sample=False,  # Disable sampling
            num_beams=5,      # Use beam search
            early_stopping=True,  # Stop when end-of-sequence token is generated
            no_repeat_ngram_size=2  # Prevent repetition of 2-grams
        )
        gen_time = time.time() - gen_start

        # Decoding time
        decode_start = time.time()
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        decode_time = time.time() - decode_start

        # Total time for this run
        total_time += time.time() - start_time
        generated_outputs.append(generated)

        # Measure memory usage
        memory_allocated = torch.cuda.memory_allocated() / (1024 ** 2)  # In MB
        memory_reserved = torch.cuda.memory_reserved() / (1024 ** 2)  # In MB

        print(f"Memory Allocated: {memory_allocated:.2f} MB")
        print(f"Memory Reserved: {memory_reserved:.2f} MB")

        if num_runs > 1:
            print(f"\nRun {run + 1}:")
        print(f"Tokenization time: {token_time*1000:.2f} ms")
        print(f"Generation time: {gen_time*1000:.2f} ms")
        print(f"Decoding time: {decode_time*1000:.2f} ms")
        print(f"Total time: {(time.time() - start_time)*1000:.2f} ms")

    if num_runs > 1:
        avg_time = total_time / num_runs
        print(f"\nAverage time over {num_runs} runs: {avg_time*1000:.2f} ms")

    return generated_outputs[0] if num_runs == 1 else generated_outputs

In [13]:
# Test the original model
prompt = "Dhaka is the capital of"
generated = get_output(prompt, num_runs=2)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Memory Allocated: 12264.82 MB
Memory Reserved: 12398.00 MB

Run 1:
Tokenization time: 0.44 ms
Generation time: 3160.33 ms
Decoding time: 0.26 ms
Total time: 3162.14 ms
Memory Allocated: 12264.82 MB
Memory Reserved: 12398.00 MB

Run 2:
Tokenization time: 0.64 ms
Generation time: 3093.70 ms
Decoding time: 0.19 ms
Total time: 3095.05 ms

Average time over 2 runs: 3127.78 ms
Generated text: ['Dhaka is the capital of Bangladesh. It is located on the banks of the Buriganga River, which flows into the Bay of Bengal. The city has a population of over 10 million people, making it the largest city in Bangladesh', 'Dhaka is the capital of Bangladesh. It is located on the banks of the Buriganga River, which flows into the Bay of Bengal. The city has a population of over 10 million people, making it the largest city in Bangladesh']


In [14]:
# Test the pruned model
generated = get_output(prompt, pruned_model, num_runs=2)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Memory Allocated: 12264.82 MB
Memory Reserved: 12398.00 MB

Run 1:
Tokenization time: 1.00 ms
Generation time: 2980.12 ms
Decoding time: 0.18 ms
Total time: 2981.82 ms
Memory Allocated: 12264.82 MB
Memory Reserved: 12398.00 MB

Run 2:
Tokenization time: 1.18 ms
Generation time: 2981.45 ms
Decoding time: 0.20 ms
Total time: 2983.37 ms

Average time over 2 runs: 2982.07 ms
Generated text: ['Dhaka is the capital of Bangladesh Bangladesh is a country located in Asia-Pacific region. It is also known as “””Bang Bang Bang bang bangbangbang-b-b-B B B b bB B-B-B- B.B.B-B', 'Dhaka is the capital of Bangladesh Bangladesh is a country located in Asia-Pacific region. It is also known as “””Bang Bang Bang bang bangbangbang-b-b-B B B b bB B-B-B- B.B.B-B']


# Storing the model


In [15]:
new_model_name = 'attnprun-llama-3.2-3B'
output_dir = './'+new_model_name
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

pruned_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
#new_config.save_pretrained(output_dir)
print(f"Pruned model saved to {output_dir}")

Pruned model saved to ./attnprun-llama-3.2-3B


In [18]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [19]:
# Push the model to your Hugging Face repository
pruned_model.push_to_hub(new_model_name, private=False)
tokenizer.push_to_hub(new_model_name)

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Shahrukh0/attnprun-llama-3.2-3B/commit/759df17f502e6f279e73d3ef4a2f3f4361975e55', commit_message='Upload tokenizer', commit_description='', oid='759df17f502e6f279e73d3ef4a2f3f4361975e55', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Shahrukh0/attnprun-llama-3.2-3B', endpoint='https://huggingface.co', repo_type='model', repo_id='Shahrukh0/attnprun-llama-3.2-3B'), pr_revision=None, pr_num=None)