<a href="https://colab.research.google.com/github/peremartra/Large-Language-Model-Notebooks-Course/blob/inference-adaptative-attention-pruning/6-PRUNING/6_6b_Adaptive_Inference_Attention_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<div>
    <h1>Large Language Models Projects</h1>
    <h3>Apply and Implement Strategies for Large Language Models</h3>
    <h2>Pruning Attention Layers</h2>
    <h3>Not All Attention is needed</h3>
</div>

by [Pere Martra](https://www.linkedin.com/in/pere-martra/)

_______
Models: meta-llama/Llama-3.2

Colab Environment: GPU L4 for 3B Models

T4 for 1B Model.

Keys:
* Pruning
* Attention

References:
* [What Matters in Transformers? Not All Attention is Needed](https://arxiv.org/abs/2406.15786)
* [Resource-Efficient Transformer Pruning for Finetuning of Large Models](https://openaccess.thecvf.com/content/CVPR2024/html/Ilhan_Resource-Efficient_Transformer_Pruning_for_Finetuning_of_Large_Models_CVPR_2024_paper.html)

_______
**Disclaimer: The pruning / knowledge distillation sections were created after the first edition of the book was published. They are not included in the book’s original content but are intended to supplement and expand on the topics covered.**

This is the unofficial repository for the book:
        <a href="https://amzn.to/4eanT1g"> <b>Large Language Models:</b> Apply and Implement Strategies for Large Language Models</a> (Apress).
        The book is based on the content of this repository, but the notebooks are being updated, and I am incorporating new examples and chapters.
        If you are looking for the official repository for the book, with the original notebooks, you should visit the
        <a href="https://github.com/Apress/Large-Language-Models-Projects">Apress repository</a>, where you can find all the notebooks in their original format as they appear in the book.

______
# Introduction
This notebook implements the paper: [What Matters in Transformers? Not all Attention is Needed](https://arxiv.org/abs/2406.15786).
Although I followed the paper's guidelines, I made some adjustments to make the code clearer and easier to understand.

The original paper demonstrates that larger models tend to have excessive redundancy in their attention layers. The authors report a 48.4% inference speedup for a Llama-2-70B model with only a 2.4% drop in response quality, **simply by bypassing 50% of the attention layers!**

In this notebook, tests have been conducted using Llama-3.2-1B and 3B models. With these models, I found that removing 50% of the attention layers significantly impacted the model's functionality. However, the 3B model handled the removal of these layers much better. This suggests that redundancy may become more pronounced as model size increases.

# Methodology.
To identify which layers contribute the least, the cosine distance between the layer's input and output is measured. In the paper, this distance is calculated using a test dataset, while in the notebook, I used a simple prompt to activate the layers.  
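
As a minimal sketch of this score, assume the layer input and output have already been flattened into plain lists of floats (the notebook itself works on PyTorch tensors). The importance is then one minus the cosine similarity. The function name `cosine_distance` is illustrative and not part of the notebook.

```python
import math

def cosine_distance(layer_input, layer_output):
    """Return 1 - cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(layer_input, layer_output))
    norm_in = math.sqrt(sum(a * a for a in layer_input))
    norm_out = math.sqrt(sum(b * b for b in layer_output))
    if norm_in == 0.0 or norm_out == 0.0:
        raise ValueError("cosine distance is undefined for a zero vector")
    return 1.0 - dot / (norm_in * norm_out)

# Output almost parallel to the input: the layer changed very little
nearly_unchanged = cosine_distance([1.0, 2.0, 3.0], [1.1, 2.0, 2.9])
# Output orthogonal to the input: the layer changed the direction completely
strongly_changed = cosine_distance([1.0, 0.0], [0.0, 1.0])
print(f"{nearly_unchanged:.4f} {strongly_changed:.4f}")
```

A layer whose score stays close to 0 barely changes the direction of its input, which is what marks it as a candidate for bypassing.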

This method of measuring the importance of attention layers and their contribution to the model allows pruning to be tailored to a specific dataset. This approach can lead to the creation of more efficient models for specialized sectors such as healthcare or finance.


Once the layer contributing the least to the final output is identified (the one with the smallest difference between input and output), it is added to a list included in the configuration file.

This list is then referenced during inference by a new forward function that replaces the original one for attention layers. When this new function detects that a layer is in the list, it skips its execution and simply returns the input without modifications.
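
The bypass logic can be illustrated with a toy example. This is only a sketch: `ToyAttention` and `forward_with_bypass` are hypothetical stand-ins, not the notebook's classes, and the "attention" here just adds a constant so the effect of skipping is easy to see.

```python
class ToyAttention:
    """Stand-in for an attention layer: adds a fixed offset to every value."""

    def __init__(self, layer_idx, offset):
        self.layer_idx = layer_idx
        self.offset = offset
        self.calls = 0  # how many times the real computation ran

    def compute(self, hidden_states):
        self.calls += 1
        return [h + self.offset for h in hidden_states]


def forward_with_bypass(layer, hidden_states, skip_layers):
    """Run the layer unless its index is listed in skip_layers."""
    if layer.layer_idx in skip_layers:
        return hidden_states  # bypass: the input is returned unchanged
    return layer.compute(hidden_states)


layers = [ToyAttention(i, offset=float(i + 1)) for i in range(4)]
skip_layers = {1, 3}  # hypothetical list read from the configuration

hidden = [0.0, 0.0]
for layer in layers:
    hidden = forward_with_bypass(layer, hidden, skip_layers)

print(hidden)                             # only layers 0 and 2 contributed
print([layer.calls for layer in layers])  # skipped layers were never executed
```

Storing the indices in a set keeps the membership check cheap, which matters because it runs on every forward pass.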

The process of identifying layers to deactivate and marking them as non-executable is one-shot: all the layers to skip are determined in a single pass, as recommended in the paper.

The iterative implementation has a significant drawback: the test dataset must be processed for each layer to be deactivated. The paper's authors note that while the iterative method may bring slight improvements, the added computational cost is not justified. This notebook follows that recommendation: the calibration prompts are processed once, and all layers are ranked from that single pass.

This pruning method does not produce a smaller model, as the layers are not physically removed. They remain in the model but are not executed, resulting in improved inference response times.
______

# Install libraries & Configure variables.

In [1]:
!pip install -q torch==2.6.0
!pip install -q torchvision==0.21.0
!pip install -q transformers==4.51.3
!pip install -q datasets==3.6.0
!pip install -q lm-eval==0.4.8

!pip install hf_xet #To speed up downloads from HF.


In [2]:
import json
import logging
import math
import os
import sys
import shutil
from copy import deepcopy

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


# Adaptive Configuration.

In [18]:
# =============================================================================
# ADAPTIVE ATTENTION BYPASS (AAB) CONFIGURATION - CORRECTED SCALING TO 100%
# =============================================================================

GLOBAL_COMPLEXITIES = [0.1, 0.3, 0.5, 0.7, 0.9]

ADAPTIVE_CONFIG = {
    # Model size-based ratios with proportional scaling to 100%
    "model_size_ratios": {
        "70B+": {
            "trivial": {"min_ratio": 0.15, "scaling_factor": 0.85},
            "simple": {"min_ratio": 0.35, "scaling_factor": 0.65},
            "medium": {"min_ratio": 0.55, "scaling_factor": 0.45},
            "complex": {"min_ratio": 0.75, "scaling_factor": 0.25},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "30B-70B": {
            "trivial": {"min_ratio": 0.25, "scaling_factor": 0.75},
            "simple": {"min_ratio": 0.40, "scaling_factor": 0.60},
            "medium": {"min_ratio": 0.60, "scaling_factor": 0.40},
            "complex": {"min_ratio": 0.80, "scaling_factor": 0.20},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "10B-30B": {
            "trivial": {"min_ratio": 0.30, "scaling_factor": 0.75},
            "simple": {"min_ratio": 0.45, "scaling_factor": 0.55},
            "medium": {"min_ratio": 0.65, "scaling_factor": 0.35},
            "complex": {"min_ratio": 0.82, "scaling_factor": 0.18},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "5B-10B": {
            "trivial": {"min_ratio": 0.35, "scaling_factor": 0.65},
            "simple": {"min_ratio": 0.55, "scaling_factor": 0.45},
            "medium": {"min_ratio": 0.75, "scaling_factor": 0.25},
            "complex": {"min_ratio": 0.87, "scaling_factor": 0.13},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "2B-5B": {
            "trivial": {"min_ratio": 0.60, "scaling_factor": 0.40},
            "simple": {"min_ratio": 0.75, "scaling_factor": 0.20},
            "medium": {"min_ratio": 0.90, "scaling_factor": 0.20},
            "complex": {"min_ratio": 0.95, "scaling_factor": 0.10},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "<2B": {
            "trivial": {"min_ratio": 0.80, "scaling_factor": 0.20},
            "simple": {"min_ratio": 0.85, "scaling_factor": 0.15},
            "medium": {"min_ratio": 0.90, "scaling_factor": 0.10},
            "complex": {"min_ratio": 0.95, "scaling_factor": 0.05},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        }
    },

    # 5-level complexity thresholds and descriptions
    "complexity_levels": {
        "trivial": {
            "range": [0.0, 0.2],
            "description": "Single word answers, basic arithmetic",
            "layer_groups": ["imprescindibles", "critical"]
        },
        "simple": {
            "range": [0.2, 0.4],
            "description": "Simple factual questions",
            "layer_groups": ["imprescindibles", "critical", "important"]
        },
        "medium": {
            "range": [0.4, 0.6],
            "description": "Knowledge retrieval, completion tasks",
            "layer_groups": ["imprescindibles", "critical", "important", "standard"]
        },
        "complex": {
            "range": [0.6, 0.8],
            "description": "Reasoning, explanations",
            "layer_groups": ["imprescindibles", "critical", "important", "standard", "optional"]
        },
        "very_complex": {
            "range": [0.8, 1.0],
            "description": "Deep reasoning, analysis, creativity",
            "layer_groups": ["imprescindibles", "critical", "important", "standard", "optional", "dispensable"]
        }
    },

    # Layer activation thresholds (1:1 correspondence)
    "activation_thresholds": {
        "critical": 0.0,      # Always active above trivial level
        "important": 0.2,     # Active from simple level
        "standard": 0.4,      # Active from medium level
        "optional": 0.6,      # Active from complex level
        "dispensable": 0.8    # Active from very_complex level
    }
}

print("✅ CORRECTED 5-level adaptive configuration loaded!")
print(f"📊 All models can now reach 100% layers for very_complex prompts")
print(f"🎯 5 complexity levels: {list(ADAPTIVE_CONFIG['complexity_levels'].keys())}")

✅ CORRECTED 5-level adaptive configuration loaded!
📊 All models can now reach 100% layers for very_complex prompts
🎯 5 complexity levels: ['trivial', 'simple', 'medium', 'complex', 'very_complex']


## Support & calculate functions

In [4]:
def detect_model_size_category(model):
    """
    Automatically detect model size category from model parameters
    """
    try:
        total_params = sum(p.numel() for p in model.parameters())
        size_billion = total_params / 1e9

        print(f"🔍 Detected model size: {size_billion:.2f}B parameters")

        if size_billion >= 70:
            return "70B+"
        elif size_billion >= 30:
            return "30B-70B"
        elif size_billion >= 10:
            return "10B-30B"
        elif size_billion >= 5:
            return "5B-10B"
        elif size_billion >= 2:
            return "2B-5B"
        else:
            return "<2B"

    except Exception as e:
        print(f"⚠️ Error detecting model size: {e}")
        # Fall back to the most conservative category that exists in ADAPTIVE_CONFIG
        return "<2B"


def get_context_window_size(model, tokenizer=None):
    """
    Tries to detect the model's maximum context window size (in tokens).
    """
    # Try model config first (most reliable)
    if hasattr(model, 'config'):
        config = model.config

        # Common attribute names for context window
        context_attrs = ['max_position_embeddings', 'n_positions', 'max_seq_len',
                        'seq_length', 'max_sequence_length', 'context_length']

        for attr in context_attrs:
            if hasattr(config, attr):
                value = getattr(config, attr)
                if isinstance(value, int) and value > 0:
                    return value

    # Try tokenizer as fallback
    if tokenizer:
        if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length != int(1e30):
            return tokenizer.model_max_length

    return None


def count_attention_layers_correctly(model):
    """
    Correctly count attention layers by finding main decoder/transformer layers
    """
    # Method 1: Count main decoder layers directly (most reliable)
    decoder_layer_count = 0
    for name, module in model.named_modules():
        module_type = type(module).__name__
        # Look for main transformer/decoder layers
        if any(layer_type in module_type for layer_type in
               ['DecoderLayer', 'TransformerBlock', 'Block', 'Layer']) and \
           all(exclude not in module_type for exclude in
               ['Embedding', 'Norm', 'Linear', 'MLP', 'Attention']):
            # Make sure it's a numbered layer (e.g., layers.0, layers.1, etc.)
            if '.layers.' in name and name.count('.') == 2:  # e.g., "model.layers.0"
                decoder_layer_count += 1

    if decoder_layer_count > 0:
        return decoder_layer_count

    # Method 2: Use model config as fallback
    try:
        if hasattr(model, 'config'):
            config_attrs = ['num_hidden_layers', 'n_layer', 'num_layers', 'n_layers']
            for attr in config_attrs:
                if hasattr(model.config, attr):
                    return getattr(model.config, attr)
    except Exception:
        pass

    # Method 3: Direct access to layers ModuleList
    try:
        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
            return len(model.model.layers)
    except Exception:
        pass

    return 16  # Conservative fallback


def classify_complexity_level(complexity_score):
    """
    Classify complexity score into one of 5 levels

    Args:
        complexity_score (float): Complexity score (0.0-1.0)

    Returns:
        str: Complexity level ("trivial", "simple", "medium", "complex", "very_complex")
    """
    levels = ADAPTIVE_CONFIG["complexity_levels"]

    for level_name, level_config in levels.items():
        min_val, max_val = level_config["range"]
        if min_val <= complexity_score < max_val:
            return level_name

    # Handle edge case for exactly 1.0
    if complexity_score >= 0.8:
        return "very_complex"

    return "trivial"  # Fallback


def calculate_active_layers(total_layers, model_size_category, complexity_score):
    """
    Calculate number of active layers based on complexity and model size

    Args:
        total_layers (int): Total number of attention layers
        model_size_category (str): Model size category
        complexity_score (float): Complexity score (0.0-1.0)

    Returns:
        tuple: (active_layers_count, complexity_level, layer_groups_used, min_guaranteed, max_possible)
    """
    # Classify complexity level
    complexity_level = classify_complexity_level(complexity_score)

    # Get configuration for this model size and complexity
    config = ADAPTIVE_CONFIG["model_size_ratios"][model_size_category][complexity_level]
    min_ratio = config["min_ratio"]
    scaling_factor = config["scaling_factor"]

    # Calculate layer counts
    min_guaranteed = int(total_layers * min_ratio)
    remaining_layers = total_layers - min_guaranteed
    additional_layers = int(complexity_score * scaling_factor * remaining_layers)
    active_layers = min_guaranteed + additional_layers

    # Ensure we don't exceed total layers
    active_layers = min(active_layers, total_layers)
    max_possible = total_layers  # Always can reach 100%

    # Get which layer groups should be used
    layer_groups_used = ADAPTIVE_CONFIG["complexity_levels"][complexity_level]["layer_groups"]

    return active_layers, complexity_level, layer_groups_used, min_guaranteed, max_possible


def get_complexity_info(complexity_level):
    """
    Get detailed information about a complexity level
    """
    return ADAPTIVE_CONFIG["complexity_levels"][complexity_level]


print("✅ Enhanced 5-level calculation functions loaded successfully!")

✅ Enhanced 5-level calculation functions loaded successfully!


# Download the Model.

In [5]:
# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [6]:
#model_name = 'meta-llama/Llama-3.2-1B'
model_name = 'meta-llama/Llama-3.2-3B'
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
#tokenizer.pad_token = tokenizer.eos_token  # Set pad token


## Test

In [19]:
# Test the configuration with clean, simplified output
if 'model' in locals():
    # Get model information with improved detection
    total_attention_layers = count_attention_layers_correctly(model)
    model_category = detect_model_size_category(model)
    context_window = get_context_window_size(model, tokenizer if 'tokenizer' in locals() else None)

    print(f"\n🏗️ Model Analysis:")
    print(f"   Attention layers: {total_attention_layers}")
    print(f"   Size category: {model_category}")
    print(f"   Context window: {context_window if context_window else 'Unknown'} tokens")
    print(f"   Architecture: {type(model).__name__}")

    # Show layer detection verification
    print(f"\n🔍 Layer Detection Verification:")
    decoder_layers = [name for name, module in model.named_modules()
                     if 'DecoderLayer' in type(module).__name__ and '.layers.' in name]
    print(f"   Found DecoderLayers: {len(decoder_layers)}")

    # Test all 5 complexity levels with simplified table
    test_complexities = GLOBAL_COMPLEXITIES

    print("\n🧪 Layer Activation by Complexity Level:")
    print("=" * 50)
    print(f"{'Level':<12} {'Active Layers':<15} {'Usage Ratio':<12}")
    print("-" * 50)

    for complexity in test_complexities:
        active, level, groups, min_guaranteed, max_possible = calculate_active_layers(
            total_attention_layers, model_category, complexity
        )
        ratio = active / total_attention_layers

        print(f"{level.capitalize():<12} {active:<15} {ratio:<12.1%}")

    print(f"\n📊 Summary for {model_category} model:")
    trivial_config = ADAPTIVE_CONFIG['model_size_ratios'][model_category]['trivial']
    trivial_min = int(total_attention_layers * trivial_config['min_ratio'])
    print(f"   • Range: {trivial_min}-{total_attention_layers} layers ({trivial_min/total_attention_layers:.1%}-100%)")
    print(f"   • Context window: {context_window if context_window else 'Unknown'} tokens")
    print(f"   • All complexity levels can reach 100% layer usage")

else:
    print("⏳ Load your model first to test the configuration")
    print("\nTo test, make sure you have:")
    print("1. model = ... (your loaded model)")
    print("2. tokenizer = ... (optional, your tokenizer)")

🔍 Detected model size: 3.21B parameters

🏗️ Model Analysis:
   Attention layers: 28
   Size category: 2B-5B
   Context window: 131072 tokens
   Architecture: LlamaForCausalLM

🔍 Layer Detection Verification:
   Found DecoderLayers: 28

🧪 Layer Activation by Complexity Level:
Level        Active Layers   Usage Ratio 
--------------------------------------------------
Trivial      16              57.1%       
Simple       21              75.0%       
Medium       25              89.3%       
Complex      26              92.9%       
Very_complex 28              100.0%      

📊 Summary for 2B-5B model:
   • Range: 16-28 layers (57.1%-100%)
   • Context window: 131072 tokens
   • All complexity levels can reach 100% layer usage
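
The table above can be reproduced by hand. As a sketch, assuming the `2B-5B` ratios from `ADAPTIVE_CONFIG` and the 28 layers reported for Llama-3.2-3B, the following standalone snippet repeats the arithmetic of `calculate_active_layers`. The helper name `active_layers_for` is illustrative.

```python
# (min_ratio, scaling_factor) pairs copied from the "2B-5B" entry of ADAPTIVE_CONFIG
RATIOS_2B_5B = {
    "trivial": (0.60, 0.40),
    "simple": (0.75, 0.20),
    "medium": (0.90, 0.20),
    "complex": (0.95, 0.10),
    "very_complex": (1.0, 0.0),
}

def active_layers_for(total_layers, min_ratio, scaling_factor, score):
    min_guaranteed = int(total_layers * min_ratio)        # truncated guaranteed share
    remaining = total_layers - min_guaranteed
    additional = int(score * scaling_factor * remaining)  # extra layers earned by the score
    return min(min_guaranteed + additional, total_layers)

# Same sample scores as GLOBAL_COMPLEXITIES, one per level
scores = {"trivial": 0.1, "simple": 0.3, "medium": 0.5, "complex": 0.7, "very_complex": 0.9}
results = {
    level: active_layers_for(28, *RATIOS_2B_5B[level], scores[level])
    for level in scores
}
print(results)
```

For the trivial level, `int(28 * 0.60)` gives 16 guaranteed layers, and the extra term `int(0.1 * 0.40 * 12)` truncates to 0, so 16 layers stay active. With only 28 layers, the extra term truncates to 0 for all five sample scores, so in this run the result is determined by `min_ratio` alone.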


## Study the structure.
* Llama-3.2-1B
```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)
```


The model follows the typical structure of modern Llama models, consisting of blocks made up of an Attention layer and an MLP layer with a GLU structure.

> If you want to see an example of how to perform pruning on the MLP layers of the model, you can check out the notebook [Pruning Llama 3.2](https://github.com/peremartra/Large-Language-Model-Notebooks-Course/blob/main/6-PRUNING/6_3_pruning_structured_llama3.2-1b_OK.ipynb) and read the paper [Exploring GLU expansion ratios: Structured pruning in Llama-3.2 models](https://osf.io/preprints/osf/qgxea).



Since the layers form a block, the attention layer cannot be removed without also removing the accompanying MLP layer. For this reason, the decision was made to deactivate their execution during inference.

The 1B model has 16 layers, as shown in the structure above, while the 3B model has 28 layers.


# Inference function & Test Base Model

The `get_output` function generates text and measures the time taken by each stage of the process: tokenization, generation, and decoding.

It provides insights into the performance of the model and can be used to evaluate the efficiency of text generation.
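
The per-stage timing pattern can be sketched in isolation. This is an illustrative alternative rather than the notebook's code: `stage_timer` is a hypothetical helper, the tokenizer and model calls are replaced by placeholders, and `time.perf_counter` is used because it is a monotonic, high-resolution clock intended for measuring short intervals.

```python
import time
from contextlib import contextmanager

@contextmanager
def stage_timer(name, timings):
    """Record the elapsed wall-clock time of a stage, in milliseconds."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = (time.perf_counter() - start) * 1000

timings = {}
with stage_timer("tokenization", timings):
    tokens = "Paris is the capital of".split()  # placeholder for the tokenizer call
with stage_timer("generation", timings):
    time.sleep(0.01)                             # placeholder for model.generate
for name, elapsed_ms in timings.items():
    print(f"{name}: {elapsed_ms:.2f} ms")
```

Collecting the results in a dictionary makes it straightforward to average each stage over several runs instead of only printing them.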

In [8]:
import time

def get_output(prompt, model=model, tokenizer=tokenizer, num_runs=1, max_length=50):
    total_time = 0
    generated_outputs = []

    for run in range(num_runs):
        # Start timing
        start_time = time.time()

        # Tokenization time
        token_start = time.time()
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        token_time = time.time() - token_start

        # Generation time
        gen_start = time.time()
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            temperature=None,
            top_p=None,
            do_sample=False,  # Disable sampling
            num_beams=5,      # Use beam search
            early_stopping=True,  # Stop beam search once num_beams finished candidates are found
            no_repeat_ngram_size=2  # Prevent repetition of 2-grams
        )
        gen_time = time.time() - gen_start

        # Decoding time
        decode_start = time.time()
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        decode_time = time.time() - decode_start

        # Total time for this run
        total_time += time.time() - start_time
        generated_outputs.append(generated)

        if num_runs > 1:
            print(f"\nRun {run + 1}:")
        print(f"Tokenization time: {token_time*1000:.2f} ms")
        print(f"Generation time: {gen_time*1000:.2f} ms")
        print(f"Decoding time: {decode_time*1000:.2f} ms")
        print(f"Total time: {(time.time() - start_time)*1000:.2f} ms")

    if num_runs > 1:
        avg_time = total_time / num_runs
        print(f"\nAverage time over {num_runs} runs: {avg_time*1000:.2f} ms")

    return generated_outputs[0] if num_runs == 1 else generated_outputs

In [9]:
# Test the original model
prompt = "Paris is the capital of"
generated = get_output(prompt, num_runs=2)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Run 1:
Tokenization time: 3.10 ms
Generation time: 4195.37 ms
Decoding time: 0.32 ms
Total time: 4198.89 ms

Run 2:
Tokenization time: 0.65 ms
Generation time: 3044.44 ms
Decoding time: 0.23 ms
Total time: 3045.43 ms

Average time over 2 runs: 3622.05 ms
Generated text: ['Paris is the capital of France. It is located in the north-central part of the country, on the river Seine. The city has a population of over 2 million people, making it the largest city in France and the second-largest city', 'Paris is the capital of France. It is located in the north-central part of the country, on the river Seine. The city has a population of over 2 million people, making it the largest city in France and the second-largest city']


The text generation of the original model, as expected, works perfectly and returns a correct and meaningful sentence.

In [28]:
model.to("cpu")               # move the model weights back to CPU memory
torch.cuda.empty_cache()      # release cached GPU memory blocks held by the allocator

# Pruning the Model.

In [10]:
#import torch
#import torch.nn as nn
#from torch.nn import functional as F
#from copy import deepcopy

## Execute Pruning.

**Disclaimer**

I'm using only a handful of illustrative prompts so that the code path is easy to follow. In any research or production setting, you must feed hundreds or thousands of diverse prompts before deciding which layers to deactivate.
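
Whatever the number of prompts, the per-prompt scores have to be combined into one ranking. A minimal sketch of that aggregation, using made-up scores and a hypothetical helper `average_importance`, averages each layer's score and sorts the layers from most to least important:

```python
def average_importance(per_prompt_scores):
    """Average each layer's score over all calibration prompts."""
    num_prompts = len(per_prompt_scores)
    layer_ids = per_prompt_scores[0].keys()
    return {
        layer: sum(scores[layer] for scores in per_prompt_scores) / num_prompts
        for layer in layer_ids
    }

# Hypothetical scores (1 - cosine similarity) for 3 layers over 2 prompts
per_prompt_scores = [
    {0: 0.50, 1: 0.10, 2: 0.30},
    {0: 0.70, 1: 0.20, 2: 0.40},
]
avg = average_importance(per_prompt_scores)
ranking = sorted(avg, key=avg.get, reverse=True)  # most important layer first
print(avg, ranking)
```

With more diverse prompts, a single unusual prompt has less influence on the averaged ranking.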

In [11]:
# Using multiple prompts for calibration
calibration_prompts = [
    "Hi I'm a sample text, used to calculate the cosine difference between input and output.",
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning models can be optimized through various techniques, explain the principals.",
    "2+2=",
    "What is the meaning of life, the universe, and everything?"
]

In [20]:
# =============================================================================
# SIMPLE AAB CALIBRATION - MINIMAL CODE
# =============================================================================
def measure_layer_importance_simple(model, tokenizer, prompts):
    """Simple layer importance measurement - FIXED using original notebook pattern"""
    model.eval()
    device = next(model.parameters()).device
    total_layers = len(model.model.layers)

    # Accumulate importance scores across all prompts
    importance_acc = {idx: 0.0 for idx in range(total_layers)}

    print(f"📊 Processing {len(prompts)} prompts across {total_layers} layers...")

    for prompt_idx, prompt in enumerate(prompts):
        print(f"   Processing prompt {prompt_idx + 1}/{len(prompts)}")

        # Tokenize input (following original notebook pattern)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Storage for this prompt's layer inputs/outputs
        layer_inputs = {}
        layer_outputs = {}

        # Create hooks (EXACTLY like the original function)
        def q_proj_input_hook(layer_idx):
            def _hook(module, module_input):
                # Handle tuple input (following original pattern)
                inp = module_input[0] if isinstance(module_input, tuple) else module_input
                layer_inputs[layer_idx] = inp.detach().clone()
            return _hook

        def o_proj_output_hook(layer_idx):
            def _hook(module, module_input, module_output):
                # Handle tuple output (following original pattern)
                out = module_output[0] if isinstance(module_output, tuple) else module_output
                layer_outputs[layer_idx] = out.detach().clone()
            return _hook

        # Register hooks for ALL layers (not just unpruned ones)
        handles = []
        for idx in range(total_layers):
            layer = model.model.layers[idx]
            handles.append(layer.self_attn.q_proj.register_forward_pre_hook(q_proj_input_hook(idx)))
            handles.append(layer.self_attn.o_proj.register_forward_hook(o_proj_output_hook(idx)))

        # Forward pass (following original pattern)
        with torch.no_grad():
            _ = model(**inputs)

        # Remove hooks (following original pattern)
        for h in handles:
            h.remove()

        # Calculate importance for each layer (EXACTLY like original)
        for idx in range(total_layers):
            if idx in layer_inputs and idx in layer_outputs:
                inp = layer_inputs[idx]
                out = layer_outputs[idx]

                # Flatten tensors (following original pattern)
                inp_flat = inp.view(inp.size(0), -1)
                out_flat = out.view(out.size(0), -1)

                # Calculate similarity and importance (following original pattern)
                similarity = F.cosine_similarity(inp_flat, out_flat, dim=1).mean().item()
                importance_score = 1 - similarity
                importance_acc[idx] += importance_score

    # Average across all prompts
    avg_importance = {idx: importance_acc[idx] / len(prompts) for idx in range(total_layers)}

    print("✅ Layer importance measurement complete!")
    return avg_importance


def create_adaptive_config_simple(model, tokenizer, prompts):
    """Create OPTIMIZED adaptive config - ultra-simple format for efficient inference"""
    print("🚀 Creating optimized adaptive config...")

    # Step 1: Analyze model
    model_size_category = detect_model_size_category(model)
    total_layers = count_attention_layers_correctly(model)
    context_window = get_context_window_size(model, tokenizer)

    # Step 2: Measure importance
    print("📊 Measuring layer importance...")
    importance_scores = measure_layer_importance_simple(model, tokenizer, prompts)

    # Step 3: Create layers_by_importance (sorted list)
    print("🏆 Creating layers_by_importance list...")
    sorted_layers = sorted(importance_scores.items(), key=lambda x: x[1], reverse=True)
    layers_by_importance = [layer_idx for layer_idx, _ in sorted_layers]

    # Step 4: Calculate complexity thresholds using existing notebook functions
    print("🎯 Calculating complexity thresholds...")
    complexity_scores = GLOBAL_COMPLEXITIES
    complexity_thresholds = {}

    print("📊 Using notebook functions to get exact layer counts:")
    for score in complexity_scores:
        active_layers_count, _, _, _, _ = calculate_active_layers(
            total_layers, model_size_category, score
        )
        complexity_thresholds[score] = active_layers_count
        level_name = classify_complexity_level(score)
        print(f"   Score {score:3.1f} ({level_name:12}) → {active_layers_count:2d}/{total_layers} layers")

    # Step 5: Build OPTIMIZED config
    print("⚙️ Building optimized configuration...")
    config = {
        "model_info": {
            "name": getattr(model.config, '_name_or_path', 'unknown'),
            "total_parameters": f"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B",
            "size_category": model_size_category,
            "total_layers": total_layers,
            "context_window": context_window,
            "architecture": type(model).__name__
        },
        "layers_by_importance": layers_by_importance,
        "complexity_thresholds": complexity_thresholds
    }

    # Step 6: Save optimized config
    with open("adaptive_config.json", "w") as f:
        json.dump(config, f, indent=2)

    print("✅ OPTIMIZED adaptive_config.json created!")

    # Show optimized results
    print(f"📊 Model: {total_layers} layers, {model_size_category}")
    print(f"🏆 Layers by importance: {layers_by_importance[:5]}... (showing first 5)")
    print("🎯 Complexity thresholds:")
    for threshold, count in complexity_thresholds.items():
        percentage = (count / total_layers) * 100
        level = classify_complexity_level(threshold)
        print(f"   {threshold:3.1f} ({level:12}): {count:2d} layers ({percentage:4.1f}%)")

    print("\n🚀 ULTRA-EFFICIENT RUNTIME FORMAT:")
    print("   • No redundant data")
    print("   • Direct score → layer count lookup")
    print("   • Minimal JSON size")
    print("   • Fastest possible inference decisions")

    return config




🚀 CREATING ULTRA-EFFICIENT ADAPTIVE CONFIG
🚀 Creating optimized adaptive config...
🔍 Detected model size: 3.21B parameters
📊 Measuring layer importance...
📊 Processing 5 prompts across 28 layers...
   Processing prompt 1/5
   Processing prompt 2/5
   Processing prompt 3/5
   Processing prompt 4/5
   Processing prompt 5/5
✅ Layer importance measurement complete!
🏆 Creating layers_by_importance list...
🎯 Calculating complexity thresholds...
📊 Using notebook functions to get exact layer counts:
   Score 0.1 (trivial     ) → 16/28 layers
   Score 0.3 (simple      ) → 21/28 layers
   Score 0.5 (medium      ) → 25/28 layers
   Score 0.7 (complex     ) → 26/28 layers
   Score 0.9 (very_complex) → 28/28 layers
⚙️ Building optimized configuration...
✅ OPTIMIZED adaptive_config.json created!
📊 Model: 28 layers, 2B-5B
🏆 Layers by importance: [9, 8, 12, 10, 7]... (showing first 5)
🎯 Complexity thresholds:
   0.1 (trivial     ): 16 layers (57.1%)
   0.3 (simple      ): 21 layers (75.0%)
   0.5 (med

In [None]:
# =============================================================================
# SIMPLE EXECUTION - OPTIMIZED VERSION
# =============================================================================

print("🚀 CREATING ULTRA-EFFICIENT ADAPTIVE CONFIG")
print("=" * 50)

# Create the OPTIMIZED adaptive config using existing calibration_prompts
adaptive_config = create_adaptive_config_simple(model, tokenizer, calibration_prompts)

print(f"\n🎉 DONE! Optimized adaptive_config.json ready for AAB!")
print("=" * 50)
print("📋 What's in your optimized config:")
print("   ✅ model_info - Essential model metadata")
print("   ✅ layers_by_importance - Ordered list [most → least important]")
print("   ✅ complexity_thresholds - Direct score → layer count mapping")
print("   ❌ NO redundant data (groups, rankings, ratios)")
print("   ❌ NO unnecessary metadata")
print(f"\n💡 Runtime efficiency:")
print("   • JSON size: ~80% smaller")
print("   • Lookup time: O(1) direct access")
print("   • Memory usage: Minimal")
print("   • Perfect for production inference!")

In [21]:
adaptive_config

{'model_info': {'name': 'meta-llama/Llama-3.2-3B',
  'total_parameters': '3.21B',
  'size_category': '2B-5B',
  'total_layers': 28,
  'context_window': 131072,
  'architecture': 'LlamaForCausalLM'},
 'layers_by_importance': [9,
  8,
  12,
  10,
  7,
  0,
  6,
  27,
  13,
  18,
  11,
  5,
  14,
  3,
  4,
  2,
  15,
  1,
  21,
  17,
  25,
  24,
  16,
  22,
  20,
  26,
  23,
  19],
 'complexity_thresholds': {0.1: 16, 0.3: 21, 0.5: 25, 0.7: 26, 0.9: 28}}
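
As a rough illustration of how this configuration could be consumed at inference time, the sketch below maps a complexity score to a set of active attention layers. The helper `select_active_layers`, the nearest-bucket rule, and the small example config are assumptions for illustration only; they are not part of the notebook. Note that JSON object keys are always strings, so the float keys in `complexity_thresholds` come back as strings once `adaptive_config.json` is reloaded.

```python
import json


def select_active_layers(config, complexity_score):
    """Return the attention layers to keep active for a given complexity score.

    Hypothetical helper: it picks the predefined score closest to
    `complexity_score` and keeps that many of the most important layers.
    """
    # JSON turns float keys into strings, so convert them back to floats.
    thresholds = {float(k): v for k, v in config["complexity_thresholds"].items()}
    # Assumption: use the nearest predefined complexity bucket.
    nearest = min(thresholds, key=lambda t: abs(t - complexity_score))
    n_active = thresholds[nearest]
    # layers_by_importance is ordered from most to least important.
    return sorted(config["layers_by_importance"][:n_active])


# Small made-up config with the same structure as adaptive_config.
example_config = {
    "layers_by_importance": [9, 8, 12, 10, 7, 0, 6, 27],
    "complexity_thresholds": {0.1: 4, 0.5: 6, 0.9: 8},
}

# Round-trip through JSON, as happens when the file is saved and reloaded.
reloaded = json.loads(json.dumps(example_config))
active = select_active_layers(reloaded, 0.45)
print(active)
```

Any layer that is not in the returned list would be a candidate for bypassing on that request.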

# Test Pruned Models


Now, let's test the pruned model, which is a Llama-3.2-3B model where I have marked 4 Attention layers to be bypassed.

In [None]:
# Test the pruned model
pruned_model = pruned_model.to(device) #Move the model to GPU again.
generated = get_output(prompt, pruned_model, num_runs=2)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Run 1:
Tokenization time: 0.46 ms
Generation time: 1251.89 ms
Decoding time: 0.23 ms
Total time: 1252.68 ms

Run 2:
Tokenization time: 0.59 ms
Generation time: 1245.58 ms
Decoding time: 0.21 ms
Total time: 1246.48 ms

Average time over 2 runs: 1249.48 ms
Generated text: ['Paris is the capital of France and/or world-wide fame for its beautiful skyline skyline-top-top-bottom-bottomside-side-side sidesidednessnessNESSNESSnessinessinessnessesivenessivenessfulnessnessfulnessfulnessinessesenessesinesssinessESness', 'Paris is the capital of France and/or world-wide fame for its beautiful skyline skyline-top-top-bottom-bottomside-side-side sidesidednessnessNESSNESSnessinessinessnessesivenessivenessfulnessnessfulnessfulnessinessesenessesinesssinessESness']



This second model runs slightly faster than the base model. The generated text starts out coherently, but it degenerates into repeated fragments towards the end of the sentence.
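
One rough way to put a number on this kind of repetition is the share of distinct character n-grams in the generated text. This is only a minimal sketch: `distinct_ngram_ratio` is a hypothetical helper, not something defined in the notebook, and it is not a substitute for a proper evaluation benchmark.

```python
def distinct_ngram_ratio(text, n=3):
    """Fraction of character n-grams that are unique (lower = more repetition)."""
    grams = [text[i:i + n] for i in range(len(text) - n + 1)]
    if not grams:
        # Text shorter than n characters: nothing can repeat.
        return 1.0
    return len(set(grams)) / len(grams)


no_repeats = distinct_ngram_ratio("abcdef")    # every 3-gram is different
all_repeats = distinct_ngram_ratio("aaaaaa")   # a single 3-gram repeated 4 times
print(no_repeats, all_repeats)
```

Comparing this ratio for the base and pruned outputs on the same prompt gives a quick, if crude, signal of how much the pruned model loops.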

# Store the Model.


In [None]:
new_model_name = 'attnprun-llama-3.2-3B'
output_dir = './'+new_model_name
os.makedirs(output_dir, exist_ok=True)

pruned_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
#new_config.save_pretrained(output_dir)
print(f"Pruned model saved to {output_dir}")

Pruned model saved to ./attnprun-llama-3.2-3B


In [None]:
# 2. Check that config contains layers to skip
from transformers import AutoConfig
config = AutoConfig.from_pretrained(output_dir)

if hasattr(config, "drop_attn_list"):
    print(f"drop_attn_list stored: {config.drop_attn_list}")
else:
    print("drop_attn_list isn't present.")


drop_attn_list stored: [14, 13, 12, 11]


## Upload to Hugging Face.

Uploading this model to Hugging Face is slightly more involved, because you need to store not only the model itself but also the code of the `_bypass_single_layer` function. As you will recall, this is the function that decides whether an attention layer is executed or simply bypassed.

In [None]:
from huggingface_hub import HfApi, upload_folder, whoami

In [None]:
# Step 1: Get your HF username from the current token
username = whoami()["name"]  # Returns a dict like {'name': 'your_username', 'email': ...}
username

'oopere'

In [None]:
# Step 2: Define repo name
repo_id = f"{username}/{new_model_name}"

In [None]:
# Step 3: Define path to your model
output_dir = "./"+new_model_name


The function must be saved in a .py file, but since this notebook runs on Colab, I’ve decided the best approach is to create a cell that generates the file to be uploaded.

The file contains the custom class PrunedLlamaForCausalLM, which extends Hugging Face’s LlamaForCausalLM.

This custom class calls the base constructor and then makes sure the model's configuration has a `drop_attn_list` attribute (defaulting to an empty list), which specifies the layers whose attention should be skipped.

The forward function is modified only for the layers that need to be skipped; the rest continue executing their standard forward function.


In [None]:
custom_model_code = '''
from transformers.models.llama.modeling_llama import LlamaForCausalLM

class PrunedLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        if not hasattr(config, "drop_attn_list"):
            config.drop_attn_list = []

        for idx in config.drop_attn_list:
            self._bypass_single_layer(idx)

    def _bypass_single_layer(self, layer_idx):
        """
        Modifies the specified layer's forward method so that attention is bypassed.
        """
        layer = self.model.layers[layer_idx]
        if not hasattr(layer.self_attn, "_original_forward"):
            layer.self_attn._original_forward = layer.self_attn.forward

        def new_attention_forward(self, hidden_states, attention_mask=None, position_ids=None,
                                  past_key_value=None, output_attentions=False, use_cache=False,
                                  **kwargs):
            if getattr(self, "layer_idx", -1) in self.config.drop_attn_list:
                # Bypass: return the input unchanged and no attention weights,
                # regardless of whether the cache is in use.
                return hidden_states, None
            return self._original_forward(hidden_states, attention_mask, position_ids,
                                          past_key_value, output_attentions, use_cache, **kwargs)

        layer.self_attn.layer_idx = layer_idx
        layer.self_attn.forward = new_attention_forward.__get__(layer.self_attn, type(layer.self_attn))

'''

# Define path and write the file
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "modeling_attnprun_llama.py"), "w") as f:
    f.write(custom_model_code.strip())

print("Custom model script modeling_attnprun_llama.py created successfully.")


Custom model script modeling_attnprun_llama.py created successfully.
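
The per-layer override pattern used by `_bypass_single_layer` can be illustrated without loading a model. The sketch below uses a made-up `ToyAttention` class (an assumption for illustration, not a Transformers class) to show how `__get__` binds a replacement `forward` to selected instances only, while the remaining layers keep their standard method.

```python
class ToyAttention:
    """Stand-in for an attention module; doubles its input."""

    def __init__(self, layer_idx):
        self.layer_idx = layer_idx

    def forward(self, hidden_states):
        return [h * 2 for h in hidden_states], "attn_weights"


def bypass(layer):
    """Replace forward on this instance only, keeping the original for later."""
    if not hasattr(layer, "_original_forward"):
        layer._original_forward = layer.forward

    def new_forward(self, hidden_states):
        # Bypassed layer: pass the input through untouched.
        return hidden_states, None

    # __get__ turns the plain function into a method bound to this instance.
    layer.forward = new_forward.__get__(layer, type(layer))


layers = [ToyAttention(i) for i in range(4)]
drop_attn_list = [1, 2]  # illustrative indices
for idx in drop_attn_list:
    bypass(layers[idx])

outputs = [layer.forward([1.0, 2.0])[0] for layer in layers]
print(outputs)

# The override is reversible because the original method was kept.
layers[1].forward = layers[1]._original_forward
restored = layers[1].forward([1.0, 2.0])[0]
```

Because the replacement is attached to the instance rather than the class, the other layers are unaffected and the change can be undone at any time.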


Now the model's configuration file is updated by adding the `auto_map` field, which tells the Transformers library which class to use to construct the model: `modeling_attnprun_llama.PrunedLlamaForCausalLM`.


In [None]:
import json
import os

# Path to the config file
config_path = os.path.join(output_dir, "config.json")

# Load the existing config
with open(config_path, "r") as f:
    config = json.load(f)

# Add or update the auto_map section
config["auto_map"] = {
    "AutoModelForCausalLM": "modeling_attnprun_llama.PrunedLlamaForCausalLM"
}

# Optional: ensure the architecture field is aligned
config["architectures"] = ["PrunedLlamaForCausalLM"]

# Save the updated config
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)

print("config.json updated with auto_map and architecture.")

config.json updated with auto_map and architecture.
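
Before uploading, it can be worth confirming that every `auto_map` entry points to a script that really exists in the folder. The check below is a minimal sketch: `check_auto_map` is a hypothetical helper, it assumes entries use the plain `module.ClassName` form shown above, and it runs against a throwaway directory rather than the real `output_dir`.

```python
import json
import os
import tempfile


def check_auto_map(model_dir):
    """Return a list of problems found in the auto_map section of config.json."""
    with open(os.path.join(model_dir, "config.json")) as f:
        config = json.load(f)

    problems = []
    for auto_class, target in config.get("auto_map", {}).items():
        # Assumption: target looks like "module_name.ClassName".
        module_name, _, class_name = target.rpartition(".")
        script = os.path.join(model_dir, module_name + ".py")
        if not os.path.isfile(script):
            problems.append(f"{auto_class}: missing {module_name}.py")
            continue
        with open(script) as f:
            if f"class {class_name}" not in f.read():
                problems.append(f"{auto_class}: {class_name} not defined in {module_name}.py")
    return problems


# Demo on a temporary folder that mimics the layout of the model directory.
with tempfile.TemporaryDirectory() as demo_dir:
    with open(os.path.join(demo_dir, "config.json"), "w") as f:
        json.dump(
            {"auto_map": {"AutoModelForCausalLM": "modeling_attnprun_llama.PrunedLlamaForCausalLM"}},
            f,
        )
    before = check_auto_map(demo_dir)  # script not written yet

    with open(os.path.join(demo_dir, "modeling_attnprun_llama.py"), "w") as f:
        f.write("class PrunedLlamaForCausalLM:\n    pass\n")
    after = check_auto_map(demo_dir)   # script now present

print(before)
print(after)
```

An empty list means the configuration and the script agree; calling `check_auto_map(output_dir)` would apply the same check to the folder about to be uploaded.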


Time to upload the folder containing the weights, the config file, and the custom model script to Hugging Face.

In [None]:
# Step 4: Upload the folder to the Hub
upload_folder(
    folder_path=output_dir,
    path_in_repo="",  # Upload everything to root
    repo_id=repo_id,
    repo_type="model"
)

print(f"Model uploaded successfully to https://huggingface.co/{repo_id}")

No files have been modified since last commit. Skipping to prevent empty commit.


Model uploaded successfully to https://huggingface.co/oopere/attnprun-llama-3.2-3B


## Download model from Hugging Face.

In [None]:
import gc
del pruned_model
del tokenizer
del model

# 2. Free the GPU cache
torch.cuda.empty_cache()
torch.cuda.ipc_collect()  # Optional, helps on Colab

# 3. Force Python garbage collection
gc.collect()

186

The model is downloaded normally from Hugging Face, but you must remember to set `trust_remote_code=True` since the model includes the custom code you previously created and uploaded.


In [None]:
pruned_model_name="oopere/attnprun-llama-3.2-3B"

model_hf = AutoModelForCausalLM.from_pretrained(
    pruned_model_name,
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(pruned_model_name)

In [None]:
model_hf = model_hf.to(device) #Move the model to GPU again.
generated = get_output(prompt, model_hf, num_runs=2)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Run 1:
Tokenization time: 0.77 ms
Generation time: 1245.88 ms
Decoding time: 0.23 ms
Total time: 1246.97 ms

Run 2:
Tokenization time: 0.49 ms
Generation time: 1239.63 ms
Decoding time: 0.19 ms
Total time: 1240.41 ms

Average time over 2 runs: 1243.60 ms
Generated text: ['Paris is the capital of France and/or world-wide fame for its beautiful skyline skyline-top-top-bottom-bottomside-side-side sidesidednessnessNESSNESSnessinessinessnessesivenessivenessfulnessnessfulnessfulnessinessesenessesinesssinessESness', 'Paris is the capital of France and/or world-wide fame for its beautiful skyline skyline-top-top-bottom-bottomside-side-side sidesidednessnessNESSNESSnessinessinessnessesivenessivenessfulnessnessfulnessfulnessinessesenessesinesssinessESness']


# Conclusion.
Based on the findings in the paper and the results obtained, I believe this type of pruning may work better with larger models where attention layers tend to have redundancy.

Since this type of pruning does not alter the model's structure, it does not result in a reduction in its size or the memory required to load it. The main advantage of using this pruning approach is the reduction of computational load during inference, leading to a more efficient model with faster responses and lower resource consumption.

Unlike the original paper, which describes "removing" selected attention layers but provides limited implementation details, this implementation takes a transparent functional approach by explicitly overriding the `forward` method only in the specified layers. As a result, the model retains its full architecture and parameter set, but selectively skips computations at runtime. This makes the method reversible, modular, and fully compatible with the Hugging Face ecosystem using `trust_remote_code=True`. While both approaches achieve similar computational savings, this one emphasizes clarity, portability, and practical integration.


# Author's Note.

In addition to creating content like this notebook and offering it under the MIT license, I have also contributed to repositories such as those of Hugging Face and Google Gemini.

I am especially proud of my book: [Large Language Models: Apply and Implement Strategies for Large Language Models (Apress)](https://amzn.to/3DSepLb).

You can find it on both [Amazon](https://amzn.to/3DSepLb) and [Springer](https://link.springer.com/book/10.1007/979-8-8688-0515-8), where they often have good deals on the purchase price.

If you take a look and end up purchasing it, keep in mind that you can reach out with any questions via the Discussions section of this same repository or on any of my social media channels. I’ll do my best to respond as quickly as possible.