<a href="https://colab.research.google.com/github/peremartra/Large-Language-Model-Notebooks-Course/blob/inference-adaptative-attention-pruning/6-PRUNING/6_6b_Adaptive_Inference_Attention_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<div>
    <h1>Large Language Models Projects</a></h1>
    <h3>Apply and Implement Strategies for Large Language Models</h3>
    <h2>Adaptative Attention Bypass</h2>
    <h3>Sometimes, not All Attention is needed</h3>
</div>

by [Pere Martra](https://www.linkedin.com/in/pere-martra/)

_______
Models: meta-llama/Llama-3.2

Colab Environment: GPU L4 for 3B Models

T4 for 1B Model.

Keys:
* Pruning
* Attention

References:
* [Resource-Efficient Transformer Pruning for Finetuning of Large Models](https://openaccess.thecvf.com/content/CVPR2024/html/Ilhan_Resource-Efficient_Transformer_Pruning_for_Finetuning_of_Large_Models_CVPR_2024_paper.html)

_______
**disclaimer: The pruning / knowledge distillation section has been created after the first edition of the book was published. They are not included in the book’s original content but are intended to supplement and expand on the topics covered.**

This is the unofficial repository for the book:
        <a href="https://amzn.to/4eanT1g"> <b>Large Language Models:</b> Apply and Implement Strategies for Large Language Models</a> (Apress).
        The book is based on the content of this repository, but the notebooks are being updated, and I am incorporating new examples and chapters.
        If you are looking for the official repository for the book, with the original notebooks, you should visit the
        <a href="https://github.com/Apress/Large-Language-Models-Projects">Apress repository</a>, where you can find all the notebooks in their original format as they appear in the book.

______
# Introduction
En este notebook se presenta  un enfoque innovador **Adaptive Attention Bypass, AAB**.

Permite al modelo decidir dinámicamente cuántas capas de atención utilizar en función de la complejidad de cada prompt de entrada. De esta manera, los prompts sencillos se procesan más rápido y consumen menos recursos, mientras que los prompts complejos mantienen la máxima calidad al usar todas las capas disponibles.

Actualmente la capa de atención es una de las que más redundancia tiene dentro de los modelos modernos, debido a que deben dar respuesta a ventanas de contexto desmesuradas.

Con AAB el modelo escogera para cada prompt el número de capas necesarias para realizar su labor. En el caso de chatbots es especialmente útil, ya que al principio de la conversación podria utilizarse un porcentaje muy bajo de capas, y a medida que el tamaño del prompt aumenmta con toda la conversación el modelo puede ir incorporando capas hasta llegar al 100%.

Este enfoque es compatible con modelos ya entrenados (no requiere reentrenamiento) y puede combinarse con técnicas clásicas de pruning estructurado para maximizar la eficiencia en producción.

A lo largo de este tutorial, veremos cómo configurar el modelo para que decida cuántas capas activar, cómo mide la importancia de sus capas y cómo omite la ejecución de las que no son necesarias para un prompt específico.

# Methodology.

La metodología implementada en este notebook sigue los siguientes pasos clave:

**Calibración de importancia de capas**:Se utilizan una serie de prompts para medir la importancia relativa de cada capa de atención del modelo, asignando una puntuación a cada una de ellas según su contribución al resultado.

**Cálculo de la complejidad del prompt**: Para cada prompt de entrada, se calcula de forma ultraligera un score de complejidad, configurable,  que combina:

* La longitud del prompt (número de tokens, normalizado).

* La diversidad semántica (varianza de los embeddings de entrada).

**Asignación adaptativa de capas activas**: Dependiendo del score de complejidad y del tamaño del modelo, se determina cuántas capas deben estar activas, utilizando una función continua parametrizada que evita saltos bruscos y permite una transición suave entre niveles de dificultad. Como mayor es el modelo más capas soporta que se bypaseen.

**Ejecución dinámica**: Durante la inferencia, sólo las capas de atención más importantes, considerando el score del prompt, son ejecutadas. El resto son “bypasseadas”, es decir, su computación se omite para ahorrar tiempo y recursos.

**Configuración flexible**: Todo el sistema se controla mediante un archivo de configuración (adaptive_config.json) que permite adaptar el método a distintos tamaños de modelo, dominios y requisitos de eficiencia.

# Principales usos y ventajas.
1. Optimización de modelos para sectores especificos.
2. Aceleración de la inferencia en producción.
3. Reducción consumo del modelo.
4. Chatbots y asistentes conversacionales.
5. Compatible con otras tecnicas como Quantization o Pruning estructurado.
6. No necesita recuperación mediante fine-Tuning o Knowledge Distillation.
______

# Install libraries & Configure variables.

In [1]:
!pip install -q torch==2.6.0
!pip install -q torchvision==0.21.0
!pip install -q transformers==4.51.3
!pip install -q datasets==3.6.0
!pip install -q lm-eval==0.4.8

!pip install hf_xet #To speed up downloads from HF.

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m79.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m89.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m54.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m40.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import logging
import math
import os
import sys
import shutil
from copy import deepcopy

import torch
import torch.nn.functional as F
import json
from transformers import AutoModelForCausalLM, AutoTokenizer


In [3]:
logging.basicConfig(level=logging.INFO)

# AAB Configuration.
En este apartado se definen los parámetros clave que controlan el comportamiento del bypass adaptativo de capas de atención (AAB). Estos parámetros permiten ajustar el sistema según el tamaño del modelo y la dificultad del prompt, logrando un equilibrio entre eficiencia y calidad en la respuesta.

**GLOBAL_COMPLEXITIES**: Una lista de puntuaciones de complejidad predefinidas. Estos valores se utilizarán más adelante, por ejemplo, para probar cómo responde el sistema a diferentes niveles de complejidad o durante la calibración.

**COMPLEXITY_WEIGHTS**: Un diccionario que asigna pesos a las diferentes métricas que usaremos para calcular la complejidad de un prompt. En esta primera versión de AAB se consideran el "conteo de tokens" (token_count) y la "varianza de los embeddings" (embedding_variance).

In [4]:
GLOBAL_COMPLEXITIES = [0.1, 0.3, 0.5, 0.7, 0.9]

COMPLEXITY_WEIGHTS = {
    "token_count": 0.75,
    "embedding_variance": 0.25
}

*ADAPTIVE_CONFIG*: Este es el diccionario principal que contiene la lógica de adaptación, en la que se decide el número de capas a bypasear dependiendo del tamaño del modelo.
Se divide en dos partes fundamentales:
* **model_size_ratios**: Define, para diferentes rangos de tamaño de modelo, cómo se calcula el número de capas activas. Para cada tamaño y nivel de complejidad se especifica un min_ratio y un scaling_factor, que indica cómo escalar el uso de capas adicionales en función del score de complejidad. La idea es que modelos más grandes pueden permitirse omitir un porcentaje mayor de capas en prompts sencillos.
* **complexity_levels**: Establece los umbrales para categorizar un prompt en uno de los cinco niveles de complejidad, "trivial", "simple", "medium", "complex", "very_complex",  basándose en su score de complejidad calculado, que va de 0.0 a 1.0.

In [5]:
# Tan solo han sido probadas las configuraciones de los modelos 1B y 3B.
ADAPTIVE_CONFIG = {
    # Model size-based ratios with proportional scaling to 100%
    "model_size_ratios": {
        "70B+": {
            "trivial": {"min_ratio": 0.15, "scaling_factor": 0.85},
            "simple": {"min_ratio": 0.35, "scaling_factor": 0.65},
            "medium": {"min_ratio": 0.55, "scaling_factor": 0.45},
            "complex": {"min_ratio": 0.75, "scaling_factor": 0.25},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "30B-70B": {
            "trivial": {"min_ratio": 0.25, "scaling_factor": 0.75},
            "simple": {"min_ratio": 0.40, "scaling_factor": 0.60},
            "medium": {"min_ratio": 0.60, "scaling_factor": 0.40},
            "complex": {"min_ratio": 0.80, "scaling_factor": 0.20},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "10B-30B": {
            "trivial": {"min_ratio": 0.30, "scaling_factor": 0.75},
            "simple": {"min_ratio": 0.45, "scaling_factor": 0.55},
            "medium": {"min_ratio": 0.65, "scaling_factor": 0.35},
            "complex": {"min_ratio": 0.82, "scaling_factor": 0.18},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "5B-10B": {
            "trivial": {"min_ratio": 0.45, "scaling_factor": 0.60},
            "simple": {"min_ratio": 0.55, "scaling_factor": 0.45},
            "medium": {"min_ratio": 0.75, "scaling_factor": 0.25},
            "complex": {"min_ratio": 0.87, "scaling_factor": 0.13},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "2B-5B": {
            "trivial": {"min_ratio": 0.80, "scaling_factor": 0.55},
            "simple": {"min_ratio": 0.87, "scaling_factor": 0.55},
            "medium": {"min_ratio": 0.90, "scaling_factor": 0.30},
            "complex": {"min_ratio": 0.95, "scaling_factor": 0.10},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "<2B": {
            "trivial": {"min_ratio": 0.85, "scaling_factor": 0.50},
            "simple": {"min_ratio": 0.90, "scaling_factor": 0.35},
            "medium": {"min_ratio": 0.93, "scaling_factor": 0.35},
            "complex": {"min_ratio": 0.97, "scaling_factor": 0.05},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        }
    },

    # 5-level complexity thresholds and descriptions
    "complexity_levels": {
        "trivial": {
            "range": [0.0, 0.2],
        },
        "simple": {
            "range": [0.2, 0.4],
        },
        "medium": {
            "range": [0.4, 0.6],
        },
        "complex": {
            "range": [0.6, 0.8],
        },
        "very_complex": {
            "range": [0.8, 1.0],
        }
    },
}


## Support & calculate functions
Una vez definidas las variables de configuración principales: GLOBAL_COMPLEXITIES, COMPLEXITY_WEIGHTS y ADAPTIVE_CONFIG.  Se crea un conjunto de funciones auxiliares para interpretar y aplicar esta configuración de manera efectiva.

Estas funciones nos permiten a interactuar con el modelo y usar los scores de complejidad para decidir cuántas capas de atención deben permanecer activas.


**detect_model_size_category**: Inspecciona el modelo cargado y, basándose en el número total de sus parámetros, lo clasifica en una de las categorías definidas en ADAPTIVE_CONFIG.

El código de la función se ha mantenido simple por motivos de comprensión en el notebook, pero se debe tener en cuenta que debe retornar exactamente el mismo nombre contenido en la variable ADAPTATIVE_CONFIG.  De no hacerlo  el sistema no detectara correctamente a que categoria pertenece el modelo, y no le aplicará los rangos definidos para su tamaño.

In [6]:
def detect_model_size_category(model):
    """
    Automatically detect model size category from model parameters
    """
    try:
        total_params = sum(p.numel() for p in model.parameters())
        size_billion = total_params / 1e9

        print(f"🔍 Detected model size: {size_billion:.2f}B parameters")

        if size_billion >= 70:
            return "70B+"
        elif size_billion >= 30:
            return "30B-70B"
        elif size_billion >= 10:
            return "10B-30B"
        elif size_billion >= 5:
            return "5B-10B"
        elif size_billion >= 2:
            return "2B-5B"
        else:
            return "<2B"

    except Exception as e:
        print(f"⚠️ Error detecting model size: {e}")
        return "1B-3B"


**count_attention_layers_correctly**: Para poder determinar un porcentaje de capas activas, primero necesitamos saber con precisión cuántas capas de atención contiene el modelo.

Esta función se encarga de contar estas capas buscando los módulos relevantes dentro de la arquitectura del modelo.

In [7]:
def count_attention_layers_correctly(model):
    """
    Correctly count attention layers by finding main decoder/transformer layers
    """
    # Method 1: Count main decoder layers directly (most reliable)
    decoder_layer_count = 0
    for name, module in model.named_modules():
        module_type = type(module).__name__
        # Look for main transformer/decoder layers
        if any(layer_type in module_type for layer_type in
               ['DecoderLayer', 'TransformerBlock', 'Block', 'Layer']) and \
           any(exclude not in module_type for exclude in
               ['Embedding', 'Norm', 'Linear', 'MLP', 'Attention']):
            # Make sure it's a numbered layer (e.g., layers.0, layers.1, etc.)
            if '.layers.' in name and name.count('.') == 2:  # e.g., "model.layers.0"
                decoder_layer_count += 1

    if decoder_layer_count > 0:
        return decoder_layer_count

    # Method 2: Use model config as fallback
    try:
        if hasattr(model, 'config'):
            config_attrs = ['num_hidden_layers', 'n_layer', 'num_layers', 'n_layers']
            for attr in config_attrs:
                if hasattr(model.config, attr):
                    return getattr(model.config, attr)
    except:
        pass

    # Method 3: Direct access to layers ModuleList
    try:
        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
            return len(model.model.layers)
    except:
        pass

    return 16  # Conservative fallback

**classify_complexity_level**: Recibe el score de complejidad numérico,valor entre 0 y 1,que se calcula para cada prompt y lo asigna a uno de los niveles de complejidad predefinidos en ADAPTIVE_CONFIG.


In [8]:
def classify_complexity_level(complexity_score):
    """
    Classify complexity score into one of 5 levels

    Args:
        complexity_score (float): Complexity score (0.0-1.0)

    Returns:
        str: Complexity level ("trivial", "simple", "medium", "complex", "very_complex")
    """
    levels = ADAPTIVE_CONFIG["complexity_levels"]

    for level_name, level_config in levels.items():
        min_val, max_val = level_config["range"]
        if min_val <= complexity_score < max_val:
            return level_name

    # Handle edge case for exactly 1.0
    if complexity_score >= 0.8:
        return "very_complex"

    return "trivial"  # Fallback


**calculate_active_layers**: Integra la información de las funciones anteriores.

Utiliza el número total de capas del modelo, su categoría de tamaño y el score de complejidad del prompt para determinar exactamente cuántas capas de atención deben activarse.

Aplica los min_ratio y scaling_factor correspondientes, definidos en ADAPTIVE_CONFIG["model_size_ratios"],  para calcular este número, asegurando que el modelo adapte el número. decapas de atención activas de forma dinámica y según lo configurado.

In [9]:
def calculate_active_layers(total_layers, model_size_category, complexity_score):
    """
    Calculate number of active layers based on complexity and model size

    Args:
        total_layers (int): Total number of attention layers
        model_size_category (str): Model size category
        complexity_score (float): Complexity score (0.0-1.0)

    Returns:
        tuple: (active_layers_count, complexity_level, layer_groups_used, min_guaranteed, max_possible)
    """
    # Classify complexity level
    complexity_level = classify_complexity_level(complexity_score)

    # Get configuration for this model size and complexity
    config = ADAPTIVE_CONFIG["model_size_ratios"][model_size_category][complexity_level]
    min_ratio = config["min_ratio"]
    scaling_factor = config["scaling_factor"]

    # Calculate layer counts
    min_guaranteed = int(total_layers * min_ratio)
    remaining_layers = total_layers - min_guaranteed
    additional_layers = int(complexity_score * scaling_factor * remaining_layers)
    active_layers = min_guaranteed + additional_layers

    # Ensure we don't exceed total layers
    active_layers = min(active_layers, total_layers)
    max_possible = total_layers  # Always can reach 100%


    return active_layers, complexity_level,  min_guaranteed, max_possible


# Download & Study the Model.
Descargamos el modelo desde Hugging Face y estudiamos un poco su estructura.

Aunque AAB esta pensada para que sea aganostica de la estructura del modelo, este notebook tan solo se ha probado con dos modelos de la familia Llama: Llama-3.2-1B y Llama-3.2-3B.

In [10]:
# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
#model_name = 'meta-llama/Llama-3.2-1B'
model_name = 'meta-llama/Llama-3.2-3B'
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
#tokenizer.pad_token = tokenizer.eos_token  # Set pad token

## Study the structure.
* Llama-3.2-1B
```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)
```


The model follows the typical structure of modern Llama models, consisting of blocks made up of an Attention layer and an MLP layer with a GLU structure.

> If you want to see an example of how to perform pruning on the MLP layers of the model, you can check out the notebook:[Pruning Llama 3.2.](https://github.com/peremartra/Large-Language-Model-Notebooks-Course/blob/main/6-PRUNING/6_3_pruning_structured_llama3.2-1b_OK.ipynb) y leer el paper [Exploring GLU expansion ratios: Structured pruning in Llama-3.2 models](https://osf.io/preprints/osf/qgxea)


Since the layers form a block, the attention layer cannot be removed without also removing the accompanying MLP layer. For this reason, the decision was made to bypass their execution during inference.

The 1B model has 16 layers, as shown in the structure above, while the 3B model has 28 layers.


## Testing AAB configuration.
Este bloque de código  analiza el modelo cargado y simula cuántas capas de atención se activarían para diferentes niveles de complejidad de prompt, basándose en las funciones y configuración establecida anteriormente.

* Utiliza **count_attention_layers_correctly** para obtener el número total de capas de atención del modelo.

* **Llama a detect_model_size_category** para determinar la categoría de tamaño del modelo.

* Itera a través de una lista de scores de complejidad predefinidos para calcular las capas que estarian activas llamando a **calculate_active_layers**

* Imprime la información del modelo, y cuantas capas se activarian para los diferentes scores de complejidad.

El modelo llama-3.2-3B tendria entre 22 y 28 capas activas dependiendo de la complejidad del prompt.

In [12]:
# Test the configuration with clean, simplified output
if 'model' in locals():
    # Get model information with improved detection
    total_attention_layers = count_attention_layers_correctly(model)
    model_category = detect_model_size_category(model)

    print(f"\n Model Analysis:")
    print(f"   Attention layers: {total_attention_layers}")
    print(f"   Size category: {model_category}")
    print(f"   Architecture: {type(model).__name__}")

    # Show layer detection verification
    print(f"\n Layer Detection Verification:")
    decoder_layers = [name for name, module in model.named_modules()
                     if 'DecoderLayer' in type(module).__name__ and '.layers.' in name]
    print(f"   Found DecoderLayers: {len(decoder_layers)}")

    # Test all 5 complexity levels with simplified table
    test_complexities = GLOBAL_COMPLEXITIES

    print("\n Layer Activation by Complexity Level:")
    print("=" * 50)
    print(f"{'Level':<12} {'Active Layers':<15} {'Usage Ratio':<12}")
    print("-" * 50)

    for complexity in test_complexities:
        active, level, min_guaranteed, max_possible = calculate_active_layers(
            total_attention_layers, model_category, complexity
        )
        ratio = active / total_attention_layers

        print(f"{level.capitalize():<12} {active:<15} {ratio:<12.1%}")

    print(f"\n Summary for {model_category} model:")
    trivial_config = ADAPTIVE_CONFIG['model_size_ratios'][model_category]['trivial']
    trivial_min = int(total_attention_layers * trivial_config['min_ratio'])
    print(f"   • Range: {trivial_min}-{total_attention_layers} layers ({trivial_min/total_attention_layers:.1%}-100%)")
    print(f"   • All complexity levels can reach 100% layer usage")

else:
    print(" Load your model first to test the configuration")
    print("\nTo test, make sure you have:")
    print("1. model = ... (your loaded model)")
    print("2. tokenizer = ... (optional, your tokenizer)")

🔍 Detected model size: 3.21B parameters

 Model Analysis:
   Attention layers: 28
   Size category: 2B-5B
   Architecture: LlamaForCausalLM

 Layer Detection Verification:
   Found DecoderLayers: 28

 Layer Activation by Complexity Level:
Level        Active Layers   Usage Ratio 
--------------------------------------------------
Trivial      22              78.6%       
Simple       24              85.7%       
Medium       25              89.3%       
Complex      26              92.9%       
Very_complex 28              100.0%      

 Summary for 2B-5B model:
   • Range: 22-28 layers (78.6%-100%)
   • All complexity levels can reach 100% layer usage


## Inference function & Test Base Model

The `get_output` function is designed to generate text  and measure the time taken for different stages of the generation process.

It provides insights into the performance of the model and can be used to evaluate the efficiency of text generation.

In [13]:
import time

def get_output(prompt, model=model, tokenizer=tokenizer, num_runs=1, max_length=50):
    print(f"--- get_output ENTERED. Prompt (first 30 chars): '{prompt[:30]}...' ---") # New log

    total_time = 0
    generated_outputs = []

    for run in range(num_runs):
        # Start timing
        start_time = time.time()

        # Tokenization time
        token_start = time.time()
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        token_time = time.time() - token_start

        # Generation time
        gen_start = time.time()
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            temperature=None,
            top_p=None,
            do_sample=False,  # Disable sampling
            num_beams=5,      # Use beam search
            early_stopping=True,  # Stop when end-of-sequence token is generated
            no_repeat_ngram_size=2  # Prevent repetition of 2-grams
        )
        gen_time = time.time() - gen_start

        # Decoding time
        decode_start = time.time()
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        decode_time = time.time() - decode_start

        # Total time for this run
        total_time += time.time() - start_time
        generated_outputs.append(generated)

        if num_runs > 1:
            print(f"\nRun {run + 1}:")
        print(f"Tokenization time: {token_time*1000:.2f} ms")
        print(f"Generation time: {gen_time*1000:.2f} ms")
        print(f"Decoding time: {decode_time*1000:.2f} ms")
        print(f"Total time: {(time.time() - start_time)*1000:.2f} ms")

    if num_runs > 1:
        avg_time = total_time / num_runs
        print(f"\nAverage time over {num_runs} runs: {avg_time*1000:.2f} ms")

    return generated_outputs[0] if num_runs == 1 else generated_outputs

In [14]:
# Test the original model
prompt = "Paris is the capital of"
generated = get_output(prompt, num_runs=2)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


--- get_output ENTERED. Prompt (first 30 chars): 'Paris is the capital of...' ---


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Run 1:
Tokenization time: 3.43 ms
Generation time: 3905.31 ms
Decoding time: 0.32 ms
Total time: 3909.15 ms

Run 2:
Tokenization time: 0.57 ms
Generation time: 3005.81 ms
Decoding time: 0.20 ms
Total time: 3006.67 ms

Average time over 2 runs: 3457.83 ms
Generated text: ['Paris is the capital of France. It is located in the north-central part of the country, on the river Seine. The city has a population of over 2 million people, making it the largest city in France and the second-largest city', 'Paris is the capital of France. It is located in the north-central part of the country, on the river Seine. The city has a population of over 2 million people, making it the largest city in France and the second-largest city']


The text generation of the original model, as expected, works perfectly and returns a correct and meaningful sentence.

In [15]:
model.to("cpu")               # actual data moves ↙
torch.cuda.empty_cache()      # allocator drops cached blocks

# Model Calibration

Como ya se ha explicado al principio del notebook es imprescindible realizar una evaluación de que layers són las que modifican más la salida del modelo, para poder decidir cuales deben bypasearse.

El proceso de evaluación de capas se ha mantenido, expresamente, lo más simple posible, usando una sola metrica ya usada en el notebook [6_6_pruning_attention_layers.ipynb](https://github.com/peremartra/Large-Language-Model-Notebooks-Course/blob/main/6-PRUNING/6_6_pruning_attention_layers.ipynb) the cosine distance between the layer's input and output.

A diferencia del notebook anterior que se midió esta distancia usando tan solo un bprompt de ejemplo, en este se ha usado un conjunto de prompts con diferentes complejidades. [link text](https://)



In [16]:
# Using multiple prompts for calibration
# Using multiple prompts for calibration
calibration_prompts = [
    "Hi",
    "2+2=",
    "Hello.",
    "What is 2+2?",
    "What is the capital of France?",
    "Paris is the capital of "
    "Tell me a joke.",
    "Name the capital of Catalonia.",
    "Who wrote 'To Kill a Mockingbird'?",
    "Explain the basic principles of machine learning and how neural networks work.",
    "What are the main causes of climate change and what can individuals do to help?",
    "Summarize the plot of 'The Matrix' in one sentence.",
    "List three benefits of regular exercise.",
    "Compare and contrast the economic policies of Keynesian and Austrian schools of thought, analyzing their effectiveness during different historical periods and explaining which approach would be most suitable for addressing current global economic challenges.",
    "Design a comprehensive strategy for a small tech startup to compete against established giants like Google and Microsoft in the cloud computing market, considering market positioning, technological differentiation, partnerships, and funding requirements.",
    "The sky appears blue during the day, during the night you can see ",
    "Describe how a neural network learns from data.",
    "Write a detailed philosophical essay examining the ethical implications of artificial intelligence consciousness, incorporating perspectives from utilitarian, deontological, and virtue ethics frameworks, while addressing counterarguments and proposing a novel ethical framework for AI development that balances technological progress with human values and societal well-being.",
    "Develop a multidisciplinary research proposal that integrates quantum computing, biotechnology, and environmental science to address food security challenges in the context of climate change, including methodology, timeline, budget considerations, potential collaborations, risk assessment, and expected societal impact over the next two decades."
    "Given current economic trends, predict one challenge global markets may face in the next decade.",
    "Write a short poem about the experience of learning something new.",
    "Produce a 450-word technical tutorial that walks through implementing a transformer-based language model from scratch in NumPy, including positional encoding and scaled-dot-product attention."
    "As an expert in global macroeconomics, geopolitical risk assessment, and artificial intelligence ethics, write an in-depth policy advisory report for a coalition of G20 nations facing simultaneous systemic challenges, including post-pandemic inflation volatility, supply chain reconfiguration due to AI-driven automation, increasing regional instability in energy markets, and declining trust in democratic institutions. Your report should propose a coordinated strategy that balances fiscal stimulus with monetary restraint, integrates quantum-secure blockchain for supply chain transparency, and includes AI oversight frameworks aligned with both utilitarian and deontological ethical models. Additionally, evaluate how international institutions like the IMF and the World Bank could modernize their governance structures to reflect multipolar power dynamics, and assess the feasibility of adopting an intergovernmental AI alignment charter inspired by the Paris Agreement model. Your recommendations must be actionable, globally inclusive, and anticipate sociopolitical backlash from both populist and nationalist movements.",
    """
    Draft Integrated Strategic White-Paper for Inter-Agency Review—

Executive Overview:
This document synthesises cutting-edge research in climate science, planetary boundaries, quantum-enhanced computation, synthetic bio-manufacturing, neuro-symbolic artificial intelligence, behavioural economics, geopolitics, space-based energy infrastructure, and post-growth macro-finance. It is intended for cabinet-level policymakers across the G20, the African Union, and APEC, as well as multilateral lenders, sovereign wealth funds, philanthropic megadonors, and fourth-sector cooperative alliances.

Section 1 – Macroeconomic Volatility & Post-Pandemic Debt Overhang
1.1 Analyse the persistence of stagflationary pressures under divergent monetary regimes.
1.2 Model cascading default scenarios using agent-based stress tests that incorporate climate-induced supply-chain interruptions, semiconductor chokepoints in Taiwan and the Netherlands, and maritime bottlenecks in the Suez and Panama Canals.
1.3 Propose a menu of fiscal-monetary coordination instruments—helicopter stabilisation bonds, biodiversity-linked debt swaps, and anti-fragile carbon border adjustments—scaled to emerging-market liquidity traps.

Section 2 – Planetary Health & Regenerative Bio-Economy
2.1 Summarise findings from IPCC AR7 draft chapters on irreversible cryosphere tipping points.
2.2 Evaluate next-generation direct air capture catalysis that leverages metal-organic frameworks seeded by engineered extremophilic microbes.
2.3 Draft a governance blueprint for a Global Soil Microbiome Commons, incorporating indigenous data sovereignty protocols, fair-benefit-sharing algorithms, and quantum-secured telemetry for real-time biodiversity crediting.

Section 3 – Quantum-Classical Hybrid Infrastructure
3.1 Detail a phased roadmap for 1 000-qubit photonic processors coupled to error-mitigated superconducting qubits for combinatorial optimisation in logistics, drug-discovery, and lattice-QCD.
3.2 Define open-standard interfaces that allow sovereign cloud providers to interoperate with NATO-grade zero-trust enclaves and NIST-post-quantum cryptographic suites.
3.3 Recommend incentives for talent-mobility corridors bridging quantum start-up clusters in Toronto, Delft, Shenzhen, Sydney, and Kigali.

Section 4 – Neuro-Symbolic AI & Alignment Governance
4.1 Compare scaling-law extrapolations for transformers, mixture-of-experts, retrieval-augmented decoders, and recursive reasoning agents.
4.2 Propose a multi-layer safety stack: interpretability probes, causal influence diagrams, counterfactual policy evaluation, and cooperative inverse-reinforcement architectures monitored by open-weight red-team sandboxes.
4.3 Outline a treaty-grade AI Alignment Accord modelled after the Paris Agreement, featuring dynamic capability thresholds, compute-cluster registration, differential privacy audits, and a tiered sanctions regime enforced via programmable CBDCs.

Section 5 – Security, Geopolitics & Space-Based Energy
5.1 Assess escalation risks stemming from fractional-orbital bombardment systems, low-cost hypersonic glide vehicles, and AI-directed drone swarms.
5.2 Present techno-economic viability of kilometre-scale solar power satellites in sun-synchronous orbit, with microwave beaming arrays utilising adaptive phased-conjugate mirrors.
5.3 Recommend confidence-building measures: reciprocal on-site inspection, open telemetry APIs, catastrophe-bond insurance pools, and an International Orbital Commons Authority.

Section 6 – Behavioural & Cultural Dynamics
6.1 Integrate behavioural-nudge frameworks, narrative foresight, and social-network epistemic resilience analytics to counter disinformation loops.
6.2 Design outcome-oriented citizen deliberation platforms that leverage quadratic voting, verifiable credentials, and language-agnostic dialogue agents with embedded bias-mitigation layers.

Section 7 – Financing Mechanisms & Implementation Timeline
7.1 Catalogue blended-finance instruments: catalytic first-loss capital, sovereign green sukuk, resilience impact derivatives, and decentralized autonomous project bonds.
7.2 Map a ten-year Gantt chart with critical path analysis, specifying TRL-milestones, regulatory sandboxes, and adaptive procurement clauses.

Call to Action:
Conclude by articulating how cooperative mission-oriented investment, science-diplomacy trust architecture, and inclusive technology governance can converge to safeguard planetary health while enabling equitable prosperity within the safe-and-just operating space for humanity.
    """
]

Para medir la importancia de las diferentes layers se utiliza la función **measure_layer_importance_simple**.

Se ejecuta una pasada forward por cada uno de los prompts de calibración.

Mediante el uso de hooks se captura la entrada `q_proj` y salida `o_proj`y se calcula la similitud cosina entre ellas. Las capas con menor similitud entre entrada y salida són las que mas aportan.

In [17]:
def measure_layer_importance_simple(model, tokenizer, prompts):
    """Simple layer importance measurement - FIXED using original notebook pattern"""
    model.eval()
    device = next(model.parameters()).device
    total_layers = len(model.model.layers)

    # Accumulate importance scores across all prompts
    importance_acc = {idx: 0.0 for idx in range(total_layers)}

    print(f"📊 Processing {len(prompts)} prompts across {total_layers} layers...")

    for prompt_idx, prompt in enumerate(prompts):
        print(f"   Processing prompt {prompt_idx + 1}/{len(prompts)}")

        # Tokenize input (following original notebook pattern)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Storage for this prompt's layer inputs/outputs
        layer_inputs = {}
        layer_outputs = {}

        # Create hooks (EXACTLY like the original function)
        def q_proj_input_hook(layer_idx):
            def _hook(module, module_input):
                # Handle tuple input (following original pattern)
                inp = module_input[0] if isinstance(module_input, tuple) else module_input
                layer_inputs[layer_idx] = inp.detach().clone()
            return _hook

        def o_proj_output_hook(layer_idx):
            def _hook(module, module_input, module_output):
                # Handle tuple output (following original pattern)
                out = module_output[0] if isinstance(module_output, tuple) else module_output
                layer_outputs[layer_idx] = out.detach().clone()
            return _hook

        # Register hooks for ALL layers (not just unpruned ones)
        handles = []
        for idx in range(total_layers):
            layer = model.model.layers[idx]
            handles.append(layer.self_attn.q_proj.register_forward_pre_hook(q_proj_input_hook(idx)))
            handles.append(layer.self_attn.o_proj.register_forward_hook(o_proj_output_hook(idx)))

        # Forward pass (following original pattern)
        with torch.no_grad():
            _ = model(**inputs)

        # Remove hooks (following original pattern)
        for h in handles:
            h.remove()

        # Calculate importance for each layer (EXACTLY like original)
        for idx in range(total_layers):
            if idx in layer_inputs and idx in layer_outputs:
                inp = layer_inputs[idx]
                out = layer_outputs[idx]

                # Flatten tensors (following original pattern)
                inp_flat = inp.view(inp.size(0), -1)
                out_flat = out.view(out.size(0), -1)

                # Calculate similarity and importance (following original pattern)
                similarity = F.cosine_similarity(inp_flat, out_flat, dim=1).mean().item()
                importance_score = 1 - similarity
                importance_acc[idx] += importance_score

    # Average across all prompts
    avg_importance = {idx: importance_acc[idx] / len(prompts) for idx in range(total_layers)}

    print("✅ Layer importance measurement complete!")
    return avg_importance



La función **create_adaptive_config_simple**, es el cerebro de la fase de calibración de nuestro sistema AAB.

Su misión es tomar el modelo, un conjunto de prompts de ejemplo y la configuración global que se ha definido, para generar y guardar un archivo de configuración detallado: "adaptive_config.json.

Este archivo será la "hoja de ruta" que el modelo consultará en tiempo real para decidir cuántas capas de atención activar.

In [18]:
def create_adaptive_config_simple(model, tokenizer, prompts):
    """Create OPTIMIZED adaptive config - ultra-simple format for efficient inference"""
    print("🚀 Creating optimized adaptive config...")

    # Step 1: Analyze model
    model_size_category = detect_model_size_category(model)
    total_layers = count_attention_layers_correctly(model)

    # Step 2: Measure importance
    print("📊 Measuring layer importance...")
    importance_scores = measure_layer_importance_simple(model, tokenizer, prompts)

    # Step 3: Create layers_by_importance (sorted list)
    print("🏆 Creating layers_by_importance list...")
    sorted_layers = sorted(importance_scores.items(), key=lambda x: x[1], reverse=True)
    layers_by_importance = [layer_idx for layer_idx, _ in sorted_layers]

    # Step 4: Calculate complexity thresholds using existing notebook functions
    print("🎯 Calculating complexity thresholds...")
    complexity_scores = GLOBAL_COMPLEXITIES
    complexity_thresholds = {}

    print("📊 Using notebook functions to get exact layer counts:")
    for score in complexity_scores:
        active_layers_count, _, _, _ = calculate_active_layers(
            total_layers, model_size_category, score
        )
        complexity_thresholds[score] = active_layers_count
        level_name = classify_complexity_level(score)
        print(f"   Score {score:3.1f} ({level_name:12}) → {active_layers_count:2d}/{total_layers} layers")

    # Step 5: Build OPTIMIZED config
    print("⚙️ Building optimized configuration...")
    config = {
        "model_info": {
            "name": getattr(model.config, '_name_or_path', 'unknown'),
            "total_parameters": f"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B",
            "size_category": model_size_category,
            "total_layers": total_layers,
            "architecture": type(model).__name__
        },
        "layers_by_importance": layers_by_importance,
        "complexity_thresholds": complexity_thresholds,
        "complexity_weights": COMPLEXITY_WEIGHTS
    }

    # Step 6: Save optimized config
    with open("adaptive_config.json", "w") as f:
        json.dump(config, f, indent=2)

    print("✅ OPTIMIZED adaptive_config.json created!")

    # Show optimized results
    print(f"📊 Model: {total_layers} layers, {model_size_category}")
    print(f"🏆 Layers by importance: {layers_by_importance[:5]}... (showing first 5)")
    print("🎯 Complexity thresholds:")
    for threshold, count in complexity_thresholds.items():
        percentage = (count / total_layers) * 100
        level = classify_complexity_level(threshold)
        print(f"   {threshold:3.1f} ({level:12}): {count:2d} layers ({percentage:4.1f}%)")

    print("\n🚀 ULTRA-EFFICIENT RUNTIME FORMAT:")

    return config


la variable adaptive_config contendrá el fichero de configuración que marca la importancia de las capas.

In [19]:
# Create the OPTIMIZED adaptive config using existing calibration_prompts
adaptive_config = create_adaptive_config_simple(model, tokenizer, calibration_prompts)

print(f"\n🎉 DONE! Optimized adaptive_config.json ready for AAB!")

🚀 Creating optimized adaptive config...
🔍 Detected model size: 3.21B parameters
📊 Measuring layer importance...
📊 Processing 21 prompts across 28 layers...
   Processing prompt 1/21
   Processing prompt 2/21
   Processing prompt 3/21
   Processing prompt 4/21
   Processing prompt 5/21
   Processing prompt 6/21
   Processing prompt 7/21
   Processing prompt 8/21
   Processing prompt 9/21
   Processing prompt 10/21
   Processing prompt 11/21
   Processing prompt 12/21
   Processing prompt 13/21
   Processing prompt 14/21
   Processing prompt 15/21
   Processing prompt 16/21
   Processing prompt 17/21
   Processing prompt 18/21
   Processing prompt 19/21
   Processing prompt 20/21
   Processing prompt 21/21
✅ Layer importance measurement complete!
🏆 Creating layers_by_importance list...
🎯 Calculating complexity thresholds...
📊 Using notebook functions to get exact layer counts:
   Score 0.1 (trivial     ) → 22/28 layers
   Score 0.3 (simple      ) → 24/28 layers
   Score 0.5 (medium      

In [20]:
adaptive_config

{'model_info': {'name': 'meta-llama/Llama-3.2-3B',
  'total_parameters': '3.21B',
  'size_category': '2B-5B',
  'total_layers': 28,
  'architecture': 'LlamaForCausalLM'},
 'layers_by_importance': [8,
  9,
  12,
  10,
  7,
  0,
  6,
  27,
  13,
  5,
  11,
  14,
  18,
  4,
  3,
  15,
  2,
  1,
  17,
  21,
  25,
  16,
  24,
  22,
  20,
  26,
  23,
  19],
 'complexity_thresholds': {0.1: 22, 0.3: 24, 0.5: 25, 0.7: 26, 0.9: 28},
 'complexity_weights': {'token_count': 0.75, 'embedding_variance': 0.25}}

# Test prompt complexity
Esta función es una de las más importantes de todo el notebook y una de las mas críticas. No tan solo es usado en el proceso de calibración, en el euqe se decide la importancia de las capas, sino que también debe usarse en tiempo de inferencia para clasificar el prompt dependiendo de su complejidad.

Se calcula un score de complejidad del prompt, entre 0 y 1. El calculo se realiza teniendo en cuenta dos variables: la longitud del prompt y la varianza de los embeddings del prompt.

El calculo se ha mantenido simple por que debe ejecutarse en la recepción de cada prompt y no debe añadir tiempo decomputación al modelo.

In [21]:
def analyze_prompt_complexity(prompts, config, model, tokenizer, verbose: bool = False):
    """
    Compute a complexity score in [0, 1] for each prompt.

    Parameters
    ----------
    prompts : list[str]
        The text prompts to score.
    config : dict
        adaptive_config.json already loaded as dict.
    model : transformers.PreTrainedModel
        The HF model (on CPU or GPU).
    tokenizer : transformers.PreTrainedTokenizer
        Matching tokenizer.
    verbose : bool
        If True, print a per-prompt breakdown.

    Returns
    -------
    list[tuple[str, float]]
        (prompt, complexity_score) for each input string.
    """

    # Get model size and device

    device = next(model.parameters()).device
    total_params = sum(p.numel() for p in model.parameters())
    size_billion = total_params / 1e9
    MIN_TOKENS = 4
    # Unified size adjustment factor
    # Small models (< 2B) get boost, large models (> 10B) get dampening
    size_factor = 1.0 + (2.0 - size_billion) * 0.1
    size_factor = max(0.5, min(2.0, size_factor))  # Clamp between 0.5 and 2.0


    # Length reference scaled by model size
    # Smaller models reach max complexity with shorter prompts
    base_length = 2000
    length_reference = base_length / size_factor
    variance_saturation = length_reference / 15

    # Get weights from config
    weights = config.get("complexity_weights", {
        "token_count": 0.65,
        "embedding_variance": 0.35
    })

    results = []

    for prompt in prompts:
        # Tokenize
        ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0].to(device)
        n_tokens = ids.size(0)

        # 1. TOKEN SCORE - Simple logarithmic scaling
        # Maps token count to [0, 1] with smooth growth
        token_score = math.log1p(n_tokens) / math.log1p(length_reference)
        token_score = min(token_score * size_factor, 1.0)
        if n_tokens < MIN_TOKENS:
          dampening = (n_tokens / MIN_TOKENS) ** 2  # Quadratic dampening
          token_score = token_score * dampening

        # 2. EMBEDDING VARIANCE - Semantic diversity
        with torch.no_grad():
            emb = model.get_input_embeddings()(ids.unsqueeze(0)).squeeze(0).float()
            n = emb.size(0)

            if n < 3:
                # Too few tokens for meaningful variance
                emb_variance = 0.0
            else:
                # Normalize embeddings
                norm_emb = torch.nn.functional.normalize(emb, p=2, dim=1)

                # Compute pairwise cosine similarities
                sim_matrix = torch.matmul(norm_emb, norm_emb.t())

                # Get off-diagonal elements (exclude self-similarity)
                mask = ~torch.eye(n, dtype=bool, device=device)
                off_diag_sim = sim_matrix[mask]

                # Variance = 1 - mean similarity
                # Higher variance = more diverse embeddings
                emb_variance = 1.0 - off_diag_sim.mean().item()

                # Scale by length (longer prompts naturally have more variance)
                length_scale = min(n_tokens / variance_saturation, 1.0)
                emb_variance = emb_variance * length_scale

        # 3. FINAL SCORE - Weighted combination
        complexity_score = (
            weights["token_count"] * token_score +
            weights["embedding_variance"] * emb_variance
        )
        complexity_score = max(0.0, min(complexity_score, 1.0))

        if verbose:
            prompt_preview = (prompt[:57] + "…") if len(prompt) > 60 else prompt
            print(f"{prompt_preview:<60} | "
                  f"score={complexity_score:.3f} | "
                  f"tokens={n_tokens} "
                  f"[tok={token_score:.3f} var={emb_variance:.3f}]")

        results.append((prompt, round(complexity_score, 4)))

    return results

In [22]:
analyze_prompt_complexity(calibration_prompts, adaptive_config, model,  tokenizer, verbose=True)

Hi                                                           | score=0.023 | tokens=2 [tok=0.031 var=0.000]
2+2=                                                         | score=0.159 | tokens=5 [tok=0.204 var=0.026]
Hello.                                                       | score=0.072 | tokens=3 [tok=0.089 var=0.021]
What is 2+2?                                                 | score=0.199 | tokens=8 [tok=0.250 var=0.046]
What is the capital of France?                               | score=0.200 | tokens=8 [tok=0.250 var=0.050]
Paris is the capital of Tell me a joke.                      | score=0.229 | tokens=11 [tok=0.282 var=0.069]
Name the capital of Catalonia.                               | score=0.189 | tokens=7 [tok=0.236 var=0.046]
Who wrote 'To Kill a Mockingbird'?                           | score=0.230 | tokens=11 [tok=0.282 var=0.072]
Explain the basic principles of machine learning and how …   | score=0.260 | tokens=15 [tok=0.315 var=0.096]
What are the main causes 

[('Hi', 0.0234),
 ('2+2=', 0.1594),
 ('Hello.', 0.0716),
 ('What is 2+2?', 0.1987),
 ('What is the capital of France?', 0.1999),
 ('Paris is the capital of Tell me a joke.', 0.2292),
 ('Name the capital of Catalonia.', 0.1888),
 ("Who wrote 'To Kill a Mockingbird'?", 0.2298),
 ('Explain the basic principles of machine learning and how neural networks work.',
  0.2603),
 ('What are the main causes of climate change and what can individuals do to help?',
  0.2723),
 ("Summarize the plot of 'The Matrix' in one sentence.", 0.2601),
 ('List three benefits of regular exercise.', 0.2003),
 ('Compare and contrast the economic policies of Keynesian and Austrian schools of thought, analyzing their effectiveness during different historical periods and explaining which approach would be most suitable for addressing current global economic challenges.',
  0.3723),
 ('Design a comprehensive strategy for a small tech startup to compete against established giants like Google and Microsoft in the cloud

# AAB Implementation
En esta sección se definen las clases y funciones que modificarán el comportamiento del modelo para permitirle omitir dinámicamente capas de atención basándose en la complejidad del prompt calculada en tiempo real.

In [23]:
from typing import Dict, List, Tuple, Optional, Union
import logging

La clase **LayerActivationMask** es como un "panel de control" externo que decide y recuerda qué capas de atención del modelo deben trabajar y cuáles pueden descansar (ser omitidas) para un prompt determinado.  Su diseño busca mantener esta lógica de activación separada del código interno del modelo, lo que resulta en un sistema más limpio y modular.

Algunas de las funciones estan tan solo para obtener más información durante la ejecución en el notebook, pero no son necesarias para el código final.


In [24]:
class LayerActivationMask:
    """
    External mask system to control which attention layers are active at inference time.
    Keeps a clean separation from model internals and allows dynamic updates per prompt.
    """
    def __init__(self, total_layers: int):
        self.total_layers = total_layers
        # Boolean mask: True means this layer is active for the current inference
        self.active_mask = [True] * total_layers
        # The latest prompt complexity score (float between 0 and 1)
        self.current_complexity = None
        # How many layers are currently active
        self.current_active_count = total_layers

        # --- Debug and tracking variables ---
        # Detailed log of which layers were executed or bypassed for each inference
        self.execution_log = []
        # Unique ID for each inference pass (useful for debugging multiple calls)
        self.current_inference_id = 0
        # Sequence length tracking for special triggers (e.g., layer 0 activation)
        self.last_sequence_length = 0

    def update_for_prompt(self, active_layer_indices: List[int], complexity_score: float):
        """
        Update the active mask for the current prompt.
        Should be called before inference, after computing prompt complexity.
        """
        self.active_mask = [i in active_layer_indices for i in range(self.total_layers)]
        self.current_complexity = complexity_score
        self.current_active_count = len(active_layer_indices)
        # Reset the execution log for this new inference
        self.execution_log = []
        self.current_inference_id += 1

    def is_layer_active(self, layer_idx: int) -> bool:
        """
        Returns True if the given layer should be active for this inference.
        """
        return self.active_mask[layer_idx]

    def get_stats(self) -> Dict:
        """
        Returns a summary of the current mask status.
        Includes complexity score, number of active layers, and ratio.
        """
        return {
            'complexity_score': self.current_complexity,
            'active_layers': self.current_active_count,
            'total_layers': self.total_layers,
            'usage_ratio': self.current_active_count / self.total_layers if self.current_active_count else 0,
            'initialized': self.current_complexity is not None
        }

    def log_layer_execution(self, layer_idx: int, executed: bool):
        """
        (DEBUG) Log whether a layer was actually executed or bypassed in this inference pass.
        """
        self.execution_log.append({
            'inference_id': self.current_inference_id,
            'layer_idx': layer_idx,
            'executed': executed,
            'expected_active': self.active_mask[layer_idx]
        })

    def get_execution_stats(self) -> Dict:
        """
        (DEBUG) Return detailed statistics about which layers were executed or bypassed,
        and whether execution matched the expected mask.
        """
        if not self.execution_log:
            return {
                'inference_id': self.current_inference_id,
                'layers_executed': [],
                'layers_bypassed': [],
                'total_calls': 0,
                'execution_matches_mask': True
            }

        executed = [log['layer_idx'] for log in self.execution_log if log['executed']]
        bypassed = [log['layer_idx'] for log in self.execution_log if not log['executed']]

        # Check if execution matches what the mask specified
        execution_matches = True
        for log in self.execution_log:
            if log['executed'] != log['expected_active']:
                execution_matches = False
                break

        return {
            'inference_id': self.current_inference_id,
            'layers_executed': sorted(executed),
            'layers_bypassed': sorted(bypassed),
            'total_calls': len(self.execution_log),
            'execution_matches_mask': execution_matches,
            'expected_active': [i for i, active in enumerate(self.active_mask) if active],
            'expected_bypassed': [i for i, active in enumerate(self.active_mask) if not active]
        }


Aunque en el notebook tan solo se han realizado pruebas con modelos llama, AAB esta pensado para ser facilmente adaptable a otras familias de modelos.

Esta función permite identificar las principales familias y facilitar la posterior adaptación del código de otras funciones, siempre que sea necesario.

In [25]:
def detect_model_architecture(model) -> str:
    """
    Automatically detect model architecture for compatibility
    """
    model_class = model.__class__.__name__.lower()
    model_name = getattr(model.config, '_name_or_path', '').lower()

    if 'llama' in model_class or 'llama' in model_name:
        return 'llama'
    elif 'mistral' in model_class or 'mistral' in model_name:
        return 'mistral'
    elif 'gpt2' in model_class or 'gpt2' in model_name:
        return 'gpt2'
    else:
        # Default to generic transformer approach
        return 'generic'



In [26]:
def get_attention_layers(model, architecture: str) -> List:
    """
    Get attention layers based on architecture
    """
    if architecture in ['llama', 'mistral']:
        return model.model.layers
    elif architecture == 'gpt2':
        return model.transformer.h
    else:
        # Generic approach - try common patterns
        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
            return model.model.layers
        elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
            return model.transformer.h
        else:
            raise ValueError(f"Cannot find attention layers for architecture: {architecture}")



In [27]:
def get_attention_module(layer, architecture: str):
    """
    Get the attention module from a layer based on architecture
    """
    if architecture in ['llama', 'mistral']:
        return layer.self_attn
    elif architecture == 'gpt2':
        return layer.attn
    else:
        # Generic approach
        if hasattr(layer, 'self_attn'):
            return layer.self_attn
        elif hasattr(layer, 'attn'):
            return layer.attn
        else:
            raise ValueError(f"Cannot find attention module for architecture: {architecture}")



Esta función calcula el score de complejidad para un único prompt, utilizando la misma lógica que la función de calibración **analyze_prompt_complexity**, pero optimizada para uso en tiempo real durante la inferencia.



In [28]:
def compute_prompt_complexity_runtime(prompt: str, model, tokenizer, config: Dict) -> float:
    """
    Computes a complexity score in [0, 1] for a single prompt, using the same
    core logic as analyze_prompt_complexity. Optimized for runtime inference.
    """
    device = next(model.parameters()).device

    # --- Efficiently get model parameters and derive calculation constants ---
    # Get size_billion from the pre-calculated config
    param_str = config["model_info"]["total_parameters"]  # e.g., "3.21B"
    size_billion = float(param_str.rstrip("B"))     # e.g., 3.21

    MIN_TOKENS = 4

    # Unified size adjustment factor (logic from analyze_prompt_complexity)
    # Small models (< 2B) get boost, large models (> 10B) get dampening
    size_factor = 1.0 + (2.0 - size_billion) * 0.1
    size_factor = max(0.5, min(2.0, size_factor))  # Clamp between 0.5 and 2.0

    # Length reference scaled by model size (logic from analyze_prompt_complexity)
    # Smaller models reach max complexity with shorter prompts
    base_length = 2000  # Using base_length from analyze_prompt_complexity
    length_reference = base_length / size_factor
    variance_saturation = length_reference / 15

    # Get weights from config (logic from analyze_prompt_complexity)
    weights = config.get("complexity_weights", {
        "token_count": 0.65,  # Default fallback
        "embedding_variance": 0.35  # Default fallback
    })

    # --- Process the single prompt ---
    # Tokenize
    ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0].to(device)
    n_tokens = ids.size(0)

    # 1. TOKEN SCORE (logic from analyze_prompt_complexity)
    token_score = math.log1p(n_tokens) / math.log1p(length_reference)
    token_score = min(token_score * size_factor, 1.0)
    if n_tokens < MIN_TOKENS:
        dampening = (n_tokens / MIN_TOKENS) ** 2  # Quadratic dampening
        token_score = token_score * dampening

    # 2. EMBEDDING VARIANCE (logic from analyze_prompt_complexity)
    emb_variance = 0.0 # Default value
    with torch.no_grad():
        emb = model.get_input_embeddings()(ids.unsqueeze(0)).squeeze(0).float()
        n_emb_tokens = emb.size(0) # Use n_emb_tokens for clarity in this block

        if n_emb_tokens < 3:
            # Too few tokens for meaningful variance
            emb_variance = 0.0
        else:
            # Normalize embeddings
            norm_emb = torch.nn.functional.normalize(emb, p=2, dim=1)

            # Compute pairwise cosine similarities
            sim_matrix = torch.matmul(norm_emb, norm_emb.t())

            # Get off-diagonal elements (exclude self-similarity)
            # Create mask on the correct device
            mask = ~torch.eye(n_emb_tokens, dtype=bool, device=sim_matrix.device)
            off_diag_sim = sim_matrix[mask]

            if off_diag_sim.numel() > 0: # Ensure there are elements to mean
                emb_variance = 1.0 - off_diag_sim.mean().item()
            else: # Should not happen if n_emb_tokens >= 3 and mask is correct
                emb_variance = 0.0

            # Scale by length (longer prompts naturally have more variance)
            length_scale = min(n_tokens / variance_saturation, 1.0) # Use n_tokens from original prompt
            emb_variance = emb_variance * length_scale

    # 3. FINAL SCORE (logic from analyze_prompt_complexity)
    complexity_score = (
        weights["token_count"] * token_score +
        weights["embedding_variance"] * emb_variance
    )
    complexity_score = max(0.0, min(complexity_score, 1.0)) # Clamp for safety

    return complexity_score

Una vez se ha calculado el **complexity_score** para un prompt utilizando **compute_prompt_complexity_runtime**, el siguiente paso es decidir exactamente cuántas capas de atención deben activarse y, lo más importante, cuáles de ellas.

La función **get_active_layers_for_prompt** utiliza los umbrales de complejidad y la lista de capas ordenadas por importancia, almacenados en config,  para decidir las layers a ejecutar.

In [29]:
def get_active_layers_for_prompt(complexity_score: float, config: Dict) -> List[int]:
    """
    Use your pre-computed complexity_thresholds instead of recalculating.
    This respects your original calibration work exactly!
    """
    layers_by_importance = config["layers_by_importance"]
    complexity_thresholds = config["complexity_thresholds"]

    # Convert string keys to float and sort (EXACT logic from your original design)
    thresholds = [(float(k), v) for k, v in complexity_thresholds.items()]
    thresholds.sort()

    # Find the appropriate number of layers to activate (EXACT logic)
    num_layers_to_activate = thresholds[-1][1]  # Default to max

    for threshold, num_layers in thresholds:
        if complexity_score <= threshold:
            num_layers_to_activate = num_layers
            break

    # Return the most important N layers using your ranking (EXACT COPY)
    return layers_by_importance[:num_layers_to_activate]



Los metodos contenidos en **add_manual_complexity_methods**, se utilizan para equipar el modelo con un conjunto de herramientas que permiten probar y depurar manualmente el sistema AAB.

Estos métodos operan de forma independiente del sistema completamente automático que se activa durante la generación de texto normal, permitiendo realizar pruebas atomicas para analizar el comportamiento con la configuración creada, y poder probar como reaccionaria a prompts especificos.

Las he usado durante el desarrollo, para afinar tanto las funciones como la configuración y se mantiene por su valor informativo. Actualemte se usan en la parte final del notebook.


In [30]:
def add_manual_complexity_methods(model, tokenizer, config: Dict):
    """
    Add manual methods for complexity calculation and debugging.
    These work independently of the automatic system.
    """
    def manual_complexity_calculation(prompt: str) -> float:
        """Calculate exact prompt complexity manually"""
        return compute_prompt_complexity_runtime(prompt, model, tokenizer, config)

    def manual_mask_update(complexity_score: float):
        """Manually update the adaptive mask"""
        active_layers = get_active_layers_for_prompt(complexity_score, config)
        model._adaptive_mask.update_for_prompt(active_layers, complexity_score)
        return model._adaptive_mask.get_stats()

    def get_debug_info():
        """Get comprehensive debug information"""
        stats = model._adaptive_mask.get_stats()
        execution_stats = model._adaptive_mask.get_execution_stats()

        return {
            'mask_stats': stats,
            'execution_stats': execution_stats,
            'config_thresholds': config['complexity_thresholds'],
            'layers_by_importance': config['layers_by_importance'][:10]  # First 10
        }

    def test_prompt_processing(prompt: str, verbose: bool = True):
        """Test end-to-end prompt processing"""
        if verbose:
            print(f"🧪 Testing prompt: '{prompt[:50]}{'...' if len(prompt) > 50 else ''}'")

        # Step 1: Calculate complexity
        complexity = manual_complexity_calculation(prompt)
        if verbose:
            print(f"   Complexity: {complexity:.4f}")

        # Step 2: Update mask
        stats = manual_mask_update(complexity)
        if verbose:
            print(f"   Active layers: {stats['active_layers']}/{stats['total_layers']} "
                  f"({stats['usage_ratio']:.1%})")

        # Step 3: Simulate inference (tokenize)
        inputs = tokenizer(prompt, return_tensors='pt').to(next(model.parameters()).device)

        # Step 4: Test forward pass
        with torch.no_grad():
            result = model.forward(input_ids=inputs['input_ids'])

        # Step 5: Get execution stats
        exec_stats = model._adaptive_mask.get_execution_stats()
        if verbose:
            print(f"   Executed layers: {exec_stats['layers_executed']}")
            print(f"   Bypassed layers: {exec_stats['layers_bypassed']}")
            print(f"   Execution matches mask: {exec_stats['execution_matches_mask']}")

        return {
            'complexity': complexity,
            'mask_stats': stats,
            'execution_stats': exec_stats
        }

    # Add methods to model
    model.manual_complexity = manual_complexity_calculation
    model.manual_mask_update = manual_mask_update
    model.get_debug_info = get_debug_info
    model.test_prompt = test_prompt_processing

    return model

A continuación se encuentra una de las funciónes más importantes del notebook: modifica dinámicamente el método forward del modelo para que, cada vez que se procese un nuevo prompt (en la primera pasada de generación), el sistema:

* Calcule automáticamente la complejidad del prompt.

* Determine cuántas capas de atención deben estar activas (bypass adaptativo) en función de esa complejidad.

* Actualice la máscara adaptativa antes de ejecutar la inferencia real.

Así, el modelo adapta su eficiencia a cada prompt sin intervención manual, integrando AAB de forma **transparente** en el ciclo de inferencia.

Uno de los principales retos fue identificar cuando se producia la primera ejecución del prompt y no realizar el calculo. dela complejidad en los forwards recurrentes con. lacreación de nuevos tokens.

In [32]:
import traceback # For printing stack traces in exceptions

def add_automatic_complexity_computation(model, tokenizer):
    """
    Add automatic complexity computation to model's forward method.
    This will automatically update the adaptive mask when new prompts are processed.
    """
    if not hasattr(model, '_adaptive_mask') or not hasattr(model, '_adaptive_config'):
        # Ensure the model has been prepared by create_adaptive_model first
        print("ERROR: Model is not set up for AAB. Please call create_adaptive_model() first.") # User-friendly error
        raise ValueError("Model must be created with create_adaptive_model() first. Missing AAB attributes.")

    # For tutorial/debug purposes, show the state of model.forward before and after modification
    print(f"Modifying model.forward. Original: {model.forward}")

    # Store the original forward method if it hasn't been stored already
    if not hasattr(model, '_original_forward'):
        model._original_forward = model.forward
        print(f"   Original model.forward stored as _original_forward: {model._original_forward}")


    # Define the new forward method that will replace the original one
    def adaptive_model_forward(self, input_ids=None, **kwargs): # 'self' here is the model instance
        # Attempt to get input_ids, whether passed directly or in kwargs
        current_call_input_ids = input_ids
        if current_call_input_ids is None and 'input_ids' in kwargs:
            current_call_input_ids = kwargs['input_ids']

        # Essential AAB attributes must be present on the model
        if not hasattr(self, '_adaptive_config') or not hasattr(self, '_adaptive_mask'):
            print("ERROR: AAB attributes (_adaptive_config or _adaptive_mask) missing during forward pass!") # User-friendly error
            # If critical AAB attributes are missing but we have the original forward, try to use it.
            if hasattr(self, '_original_forward'):
                return self._original_forward(input_ids=input_ids, **kwargs)
            # If _original_forward is also missing, it's a critical setup error.
            raise RuntimeError("Critical AAB setup error: _original_forward and AAB attributes missing.")

        # --- Determine if this is the first effective pass for a new prompt ---
        # This is crucial because complexity should only be calculated once per prompt.
        # Generation involves multiple forward passes: one for the prompt, then one for each new token.
        past_key_values = kwargs.get('past_key_values')
        is_effectively_first_pass = False

        if past_key_values is None:
            # No past_key_values typically means it's the first pass with the initial prompt.
            is_effectively_first_pass = True
        elif hasattr(past_key_values, 'seen_tokens'):
            # Check for Hugging Face DynamicCache objects (used in model.generate())
            # 'seen_tokens' attribute indicates how many tokens are already in the KV cache.
            current_cache_seq_len = past_key_values.seen_tokens
            if current_cache_seq_len == 0:
                # If seen_tokens is 0, the cache is empty, indicating a new generation sequence.
                is_effectively_first_pass = True
        elif (isinstance(past_key_values, tuple) and
              len(past_key_values) > 0 and
              isinstance(past_key_values[0], tuple) and len(past_key_values[0]) > 0 and
              hasattr(past_key_values[0][0], 'shape') and
              # Check if the sequence length dimension of the key/value tensors in the cache is 0.
              # This typically corresponds to the second to last dimension (e.g., [batch_size, num_heads, sequence_length, head_dim]).
              # For Llama-like models, KV cache shape is often [bsz, num_heads, seq_len, head_dim].
              # We check the seq_len part of the first layer's key cache.
              past_key_values[0][0].shape[-2] == 0):
            # This handles standard tuple-based KV caches when they are empty.
            is_effectively_first_pass = True

        # Check if input_ids are valid for decoding a prompt
        can_get_prompt_for_complexity = (current_call_input_ids is not None and
                                         current_call_input_ids.ndim == 2 and # Expected [batch_size, sequence_length]
                                         current_call_input_ids.shape[0] > 0 and
                                         current_call_input_ids.shape[1] > 0)

        # --- Main AAB Logic: Calculate complexity and update mask on the first pass ---
        if is_effectively_first_pass and can_get_prompt_for_complexity:
            try:
                # Decode the prompt text from the first item in the batch
                prompt_text = tokenizer.decode(current_call_input_ids[0], skip_special_tokens=True)

                # Calculate complexity using the runtime function
                complexity_score = compute_prompt_complexity_runtime(
                    prompt_text, self, tokenizer, self._adaptive_config
                )
                # Determine which layers to activate based on the score and config
                active_layers = get_active_layers_for_prompt(complexity_score, self._adaptive_config)
                # Update the shared activation mask
                self._adaptive_mask.update_for_prompt(active_layers, complexity_score)

                stats = self._adaptive_mask.get_stats()
                # This is an informative print for the tutorial user to see AAB in action
                print(f"AAB Activated: Complexity {complexity_score:.3f} -> "
                      f"{stats['active_layers']}/{stats['total_layers']} layers active "
                      f"({stats['usage_ratio']:.1%})")
            except Exception as e:
                print(f"ERROR during AAB complexity calculation/mask update: {e}") # User-friendly error
                traceback.print_exc() # Print full traceback for debugging
        # else:
            # Not the first pass, or input_ids are not suitable for complexity calculation.
            # No AAB logic is run; mask remains as set by the last "first pass".
            # print("DEBUG: Not a first pass or invalid inputs for complexity calculation. Skipping AAB logic.")


        if not hasattr(self, '_original_forward'):
            # This should not happen if the setup logic at the beginning of
            # add_automatic_complexity_computation ran correctly.
            print("CRITICAL ERROR: _original_forward method is missing on model instance!") # User-friendly error
            raise RuntimeError("Cannot call missing _original_forward. Critical AAB setup error.")

        # Always call the original forward method to perform the actual model computation
        return self._original_forward(input_ids=input_ids, **kwargs)

    # --- Replace the model's original forward method with our adaptive_model_forward ---
    # The use of .__get__(model, type(model)) ensures that 'adaptive_model_forward'
    # is correctly bound as a method to the 'model' instance, so that 'self'
    # inside 'adaptive_model_forward' refers to the model object.
    model.forward = adaptive_model_forward.__get__(model, type(model))
    print(f"   New model.forward set to: {model.forward}")
    print("Automatic complexity computation hooked into model.forward.")
    return model

La función **create_adaptive_attention_forward** actúa como una "fábrica" que crea un nuevo método forward personalizado para cada capa de atención individual del nuevo modelo.

Este nuevo método, llamado  **adaptive_forward**, se especializa para cada capa. Utiliza una clausura para "recordar" dos datos cruciales específicos de la capa a la que se adjunta: su layer_idx (para saber si debe activarse según la máscara) y su original_forward (el comportamiento de atención original de esa capa, al que llamará si está activa).

Así, el adaptive_forward generado para cada capa:

* Consulta la LayerActivationMask usando su layer_idx único.
* Si la capa debe estar activa, ejecuta el original_forward que tenía guardado.
* Si está inactiva, omite los costosos cálculos de atención, pasando los hidden_states sin modificar y devolviendo (hidden_states, None) para mantener la compatibilidad de la salida.

Este mecanismo permite que cada capa decida si procesar o saltarse la atención, basándose en la LayerActivationMask que ya ha sido actualizada por la lógica principal del sistema AAB.

In [33]:
def create_adaptive_attention_forward(original_forward, layer_idx: int, mask: LayerActivationMask,
                                    architecture: str, model, tokenizer, config: Dict):
    """
    Create a new forward method that respects the activation mask.
    Layer 0 acts as automatic trigger for complexity calculation.
    """
    def adaptive_forward(self, hidden_states, *args, **kwargs):
        # Check if this layer should be active
        is_active = mask.is_layer_active(layer_idx)
        mask.log_layer_execution(layer_idx, is_active)

        if is_active:
            # Execute normal attention
            result = original_forward(hidden_states, *args, **kwargs)
            return result
        else:
            # Bypass attention
            print(f"--- Layer-level adaptive_forward PRINT: Bypassing Layer {layer_idx} ---") # layer_idx from closure
            use_cache_flag = kwargs.get('use_cache', False) # For logging
            print(f"Layer {layer_idx} bypass: use_cache={use_cache_flag}")

            # Always return a 2-tuple as per the ValueError (expected 2)
            # and similar to potentially working static pruner.
            print(f"Layer {layer_idx} bypass: Now returning (hidden_states, None) (2-tuple)")
            return (hidden_states, None)

    return adaptive_forward

In [36]:
def create_adaptive_model(model, config: Dict, verbose: bool = True):
    """
    Create an adaptive model that dynamically adjusts active layers based on prompt complexity.
    MODIFIED: Uses the new create_adaptive_attention_forward with layer 0 trigger.
    """
    # Detect architecture
    architecture = detect_model_architecture(model)
    if verbose:
        print(f"🔍  Detected architecture: {architecture}")

    # Get attention layers
    try:
        attention_layers = get_attention_layers(model, architecture)
        total_layers = len(attention_layers)
        if verbose:
            print(f"📊  Found {total_layers} attention layers")
    except Exception as e:
        raise ValueError(f"Failed to get attention layers: {e}")

    # Create activation mask
    mask = LayerActivationMask(total_layers)

    # Store references in model for access during inference
    model._adaptive_mask = mask
    model._adaptive_config = config
    model._adaptive_architecture = architecture

    # Modify attention layers (CHANGED to pass all parameters to create_adaptive_attention_forward)
    modified_layers = 0
    for layer_idx, layer in enumerate(attention_layers):
        try:
            attention_module = get_attention_module(layer, architecture)

            # Store original forward if not already stored
            if not hasattr(attention_module, '_original_forward'):
                attention_module._original_forward = attention_module.forward

            # Create adaptive forward method (CHANGED - now includes model, tokenizer, config)
            adaptive_forward = create_adaptive_attention_forward(
                attention_module._original_forward,
                layer_idx,
                mask,
                architecture,
                model,  # NEW
                tokenizer,  # NEW (note: will be passed when called)
                config  # NEW
            )

            # Replace forward method
            attention_module.forward = adaptive_forward.__get__(attention_module, type(attention_module))
            modified_layers += 1

        except Exception as e:
            logger.warning(f"Failed to modify layer {layer_idx}: {e}")

    if verbose:
        print(f"✅  Successfully modified {modified_layers}/{total_layers} attention layers")
        print(f"🎯  Complexity thresholds: {config['complexity_thresholds']}")
        print(f"⚡ Ready for adaptive inference with Layer 0 auto-trigger!")

    return model

In [35]:
def setup_adaptive_model_complete(model, tokenizer, config: Dict, verbose: bool = True):
    """
    Complete setup of adaptive model with automatic complexity computation.
    MODIFIED: Uses manual methods instead of automatic model.forward hooking.
    """
    if verbose:
        print("🚀  Setting up Adaptive Attention Bypass (AAB) system...")
        print("=" * 60)

    # Step 1: Create adaptive model structure
    adaptive_model = create_adaptive_model(model, config, verbose=verbose)


    # Step 2: Add automatic complexity computation to hook into model.forward
    adaptive_model = add_automatic_complexity_computation(adaptive_model, tokenizer)

    # Step 2b: Add manual complexity methods, for testing.
    adaptive_model = add_manual_complexity_methods(adaptive_model, tokenizer, config)

    if verbose:
        print("=" * 60)
        print("✅  AAB setup complete! Model ready for adaptive inference.")
        print(f"📈  Usage will vary from {min(config['complexity_thresholds'].values())}"
              f" to {max(config['complexity_thresholds'].values())} layers based on prompt complexity")
        print("\n🔧  Available methods:")
        print("   • model.test_prompt(prompt) - Test end-to-end processing")
        print("   • model.get_debug_info() - Get comprehensive debug info")
        print("   • model.manual_complexity(prompt) - Calculate complexity manually")
        print("   • model.get_adaptive_stats() - Get current mask stats")

    return adaptive_model

In [37]:
# Step 1: REPLACE your current setup function call
print("🔄 Creating new adaptive model with Layer 0 trigger...")
adaptive_model = setup_adaptive_model_complete(model, tokenizer, adaptive_config, verbose=True)


🔄 Creating new adaptive model with Layer 0 trigger...
🚀  Setting up Adaptive Attention Bypass (AAB) system...
🔍  Detected architecture: llama
📊  Found 28 attention layers
✅  Successfully modified 28/28 attention layers
🎯  Complexity thresholds: {0.1: 22, 0.3: 24, 0.5: 25, 0.7: 26, 0.9: 28}
⚡ Ready for adaptive inference with Layer 0 auto-trigger!
Modifying model.forward. Original: <bound method LlamaForCausalLM.forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, o

In [38]:
print(f"ID of adaptive_model after setup: {id(adaptive_model)}")
print(f"adaptive_model.forward after setup: {adaptive_model.forward}")

ID of adaptive_model after setup: 134346889892496
adaptive_model.forward after setup: <bound method add_automatic_complexity_computation.<locals>.adaptive_model_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Llam

In [39]:
# Step 2: Move to device
adaptive_model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((3072,), eps=1e-05)
    (rotary_emb

In [40]:
print(f"The model's forward method is: {adaptive_model.forward}")
if hasattr(adaptive_model, '_original_forward'):
    print(f"The model's _original_forward is: {adaptive_model._original_forward}")
else:
    print("The model does NOT have an _original_forward attribute.")

The model's forward method is: <bound method add_automatic_complexity_computation.<locals>.adaptive_model_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_la

In [41]:
prompt = "Paris is the capital of "
# The layer 0 trigger should work automatically during generate()
print(f"ID of adaptive_model before get_output: {id(adaptive_model)}")
print(f"adaptive_model.forward before get_output: {adaptive_model.forward}")
generated = get_output(prompt, adaptive_model, num_runs=1)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.



3️⃣ Testing original prompt with get_output:
ID of adaptive_model before get_output: 134346889892496
adaptive_model.forward before get_output: <bound method add_automatic_complexity_computation.<locals>.adaptive_model_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          

In [42]:
prompt = "The sky appears blue during the day, during the night you can see wow it is totally different, and "
# The layer 0 trigger should work automatically during generate()
generated = get_output(prompt, adaptive_model, num_runs=1)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



5 Testing a third prompt with get_output:
ID of adaptive_model before get_output: 134346889892496
adaptive_model.forward before get_output: <bound method add_automatic_complexity_computation.<locals>.adaptive_model_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (ac

## Manual testing.

In [43]:
# Test 1: Simple prompt (should use fewer layers)
print("\n1️⃣ Testing simple prompt:")
simple_result = adaptive_model.test_prompt("Hi", verbose=True)



1️⃣ Testing simple prompt:
🧪 Testing prompt: 'Hi'
   Complexity: 0.0234
   Active layers: 22/28 (78.6%)
--- adaptive_model_forward (nested in add_auto) PRINT: ENTERED ---
adaptive_model_forward (nested in add_auto) PRINT: 'input_ids' provided directly. Shape: torch.Size([1, 2])
adaptive_model_forward (nested in add_auto) PRINT: Initial past_key_values is None: True
adaptive_model_forward (nested in add_auto) PRINT: >>> past_key_values is None. is_effectively_first_pass = True.
adaptive_model_forward (nested in add_auto) PRINT: FINAL is_effectively_first_pass: True
adaptive_model_forward (nested in add_auto) PRINT: can_get_prompt_for_complexity: True
adaptive_model_forward (nested in add_auto) PRINT: Effective first pass & valid inputs. Entering main logic block.
adaptive_model_forward (nested in add_auto) PRINT: Decoded prompt for complexity (from batch item 0): 'Hi...'
🎯 PRINT Prompt complexity: 0.023 | Active layers: 22/28 (78.6%)
adaptive_model_forward (nested in add_auto) PRINT: M

In [44]:
# Test 1: Simple prompt (should use fewer layers)
print("\n1️⃣ Paris is the capital of ")
simple_result = adaptive_model.test_prompt("Paris is the capital of ", verbose=True)



1️⃣ Paris is the capital of 
🧪 Testing prompt: 'Paris is the capital of '
   Complexity: 0.1881
   Active layers: 24/28 (85.7%)
--- adaptive_model_forward (nested in add_auto) PRINT: ENTERED ---
adaptive_model_forward (nested in add_auto) PRINT: 'input_ids' provided directly. Shape: torch.Size([1, 7])
adaptive_model_forward (nested in add_auto) PRINT: Initial past_key_values is None: True
adaptive_model_forward (nested in add_auto) PRINT: >>> past_key_values is None. is_effectively_first_pass = True.
adaptive_model_forward (nested in add_auto) PRINT: FINAL is_effectively_first_pass: True
adaptive_model_forward (nested in add_auto) PRINT: can_get_prompt_for_complexity: True
adaptive_model_forward (nested in add_auto) PRINT: Effective first pass & valid inputs. Entering main logic block.
adaptive_model_forward (nested in add_auto) PRINT: Decoded prompt for complexity (from batch item 0): 'Paris is the capital of ...'
🎯 PRINT Prompt complexity: 0.188 | Active layers: 24/28 (85.7%)
adapti

In [45]:
# Test 2: Complex prompt (should use more layers)
print("\n2️⃣ Testing complex prompt:")
complex_result = adaptive_model.test_prompt(
    "Analyze the geopolitical implications of quantum computing on global cybersecurity frameworks",
    verbose=True
)


2️⃣ Testing complex prompt:
🧪 Testing prompt: 'Analyze the geopolitical implications of quantum c...'
   Complexity: 0.2461
   Active layers: 24/28 (85.7%)
--- adaptive_model_forward (nested in add_auto) PRINT: ENTERED ---
adaptive_model_forward (nested in add_auto) PRINT: 'input_ids' provided directly. Shape: torch.Size([1, 13])
adaptive_model_forward (nested in add_auto) PRINT: Initial past_key_values is None: True
adaptive_model_forward (nested in add_auto) PRINT: >>> past_key_values is None. is_effectively_first_pass = True.
adaptive_model_forward (nested in add_auto) PRINT: FINAL is_effectively_first_pass: True
adaptive_model_forward (nested in add_auto) PRINT: can_get_prompt_for_complexity: True
adaptive_model_forward (nested in add_auto) PRINT: Effective first pass & valid inputs. Entering main logic block.
adaptive_model_forward (nested in add_auto) PRINT: Decoded prompt for complexity (from batch item 0): 'Analyze the geopolitical implications of quantum c...'
🎯 PRINT Prompt 

In [46]:
# Test 4: Comprehensive debug info
print("\n4️⃣ Full debug info:")
debug_info = adaptive_model.get_debug_info()
print(f"   Mask stats: {debug_info['mask_stats']}")
print(f"   Total execution calls: {debug_info['execution_stats']['total_calls']}")
print(f"   Most important layers: {debug_info['layers_by_importance']}")


4️⃣ Full debug info:
   Mask stats: {'complexity_score': 0.2461123389241667, 'active_layers': 24, 'total_layers': 28, 'usage_ratio': 0.8571428571428571, 'initialized': True}
   Total execution calls: 28
   Most important layers: [8, 9, 12, 10, 7, 0, 6, 27, 13, 5]


In [47]:
# Test 6: Verify different complexity levels work
print("\n6️⃣ Testing different complexity levels:")
test_prompts = [
    ("Simple", "2+2="),
    ("Medium", "Explain machine learning basics"),
    ("Complex", "Write a comprehensive analysis of the economic implications of artificial intelligence")
]

for level, test_prompt in test_prompts:
    print(f"\n{level}: '{test_prompt[:50]}{'...' if len(test_prompt) > 50 else ''}'")
    result = adaptive_model.test_prompt(test_prompt, verbose=False)
    print(f"   Complexity: {result['complexity']:.3f}")
    print(f"   Layers: {result['mask_stats']['active_layers']}/{result['mask_stats']['total_layers']} "
          f"({result['mask_stats']['usage_ratio']:.1%})")
    print(f"   Executed: {len(result['execution_stats']['layers_executed'])}, "
          f"Bypassed: {len(result['execution_stats']['layers_bypassed'])}")


6️⃣ Testing different complexity levels:

Simple: '2+2='
--- adaptive_model_forward (nested in add_auto) PRINT: ENTERED ---
adaptive_model_forward (nested in add_auto) PRINT: 'input_ids' provided directly. Shape: torch.Size([1, 5])
adaptive_model_forward (nested in add_auto) PRINT: Initial past_key_values is None: True
adaptive_model_forward (nested in add_auto) PRINT: >>> past_key_values is None. is_effectively_first_pass = True.
adaptive_model_forward (nested in add_auto) PRINT: FINAL is_effectively_first_pass: True
adaptive_model_forward (nested in add_auto) PRINT: can_get_prompt_for_complexity: True
adaptive_model_forward (nested in add_auto) PRINT: Effective first pass & valid inputs. Entering main logic block.
adaptive_model_forward (nested in add_auto) PRINT: Decoded prompt for complexity (from batch item 0): '2+2=...'
🎯 PRINT Prompt complexity: 0.159 | Active layers: 24/28 (85.7%)
adaptive_model_forward (nested in add_auto) PRINT: Main logic block COMPLETED.
--- Layer-level ad

In [48]:
generated = get_output(prompt, adaptive_model, num_runs=2)
print(f"Generated text: {generated}")

--- get_output ENTERED. Prompt (first 30 chars): 'The sky appears blue during th...' ---


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


--- adaptive_model_forward (nested in add_auto) PRINT: ENTERED ---
adaptive_model_forward (nested in add_auto) PRINT: 'input_ids' provided directly. Shape: torch.Size([5, 16])
adaptive_model_forward (nested in add_auto) PRINT: Initial past_key_values is None: False
adaptive_model_forward (nested in add_auto) PRINT: Detected Cache object with 'seen_tokens'. Type: <class 'transformers.cache_utils.DynamicCache'>. seen_tokens: 0
adaptive_model_forward (nested in add_auto) PRINT: >>> Cache 'seen_tokens' is 0. is_effectively_first_pass = True.
adaptive_model_forward (nested in add_auto) PRINT: FINAL is_effectively_first_pass: True
adaptive_model_forward (nested in add_auto) PRINT: can_get_prompt_for_complexity: True
adaptive_model_forward (nested in add_auto) PRINT: Effective first pass & valid inputs. Entering main logic block.
adaptive_model_forward (nested in add_auto) PRINT: Decoded prompt for complexity (from batch item 0): 'The sky appears blue during the day, during the ni...'
🎯 PRINT

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Run 1:
Tokenization time: 30.54 ms
Generation time: 2305.21 ms
Decoding time: 0.20 ms
Total time: 2336.02 ms
--- adaptive_model_forward (nested in add_auto) PRINT: ENTERED ---
adaptive_model_forward (nested in add_auto) PRINT: 'input_ids' provided directly. Shape: torch.Size([5, 16])
adaptive_model_forward (nested in add_auto) PRINT: Initial past_key_values is None: False
adaptive_model_forward (nested in add_auto) PRINT: Detected Cache object with 'seen_tokens'. Type: <class 'transformers.cache_utils.DynamicCache'>. seen_tokens: 0
adaptive_model_forward (nested in add_auto) PRINT: >>> Cache 'seen_tokens' is 0. is_effectively_first_pass = True.
adaptive_model_forward (nested in add_auto) PRINT: FINAL is_effectively_first_pass: True
adaptive_model_forward (nested in add_auto) PRINT: can_get_prompt_for_complexity: True
adaptive_model_forward (nested in add_auto) PRINT: Effective first pass & valid inputs. Entering main logic block.
adaptive_model_forward (nested in add_auto) PRINT: Deco

### complementary tests

In [None]:
test_complexity = compute_prompt_complexity_runtime("Paris is the capital of", adaptive_model, tokenizer, adaptive_config)
print(f"Manual complexity test: {test_complexity}")

In [None]:
print("Original forward?", hasattr(adaptive_model, '_original_forward'))
print("Forward actual:", type(adaptive_model.forward))

In [None]:
# Llamar directamente al forward (sin generate)
inputs = tokenizer("Paris is the capital of", return_tensors='pt').to(device)
try:
    result = adaptive_model.forward(input_ids=inputs['input_ids'])
    print("Manual forward call worked")
except Exception as e:
    print(f"Manual forward failed: {e}")

In [None]:
# Ver si generate() usa forward() o algo diferente
print("Generate method:", adaptive_model.generate.__func__.__name__)
print("Model class:", type(adaptive_model).__name__)

In [None]:
# Después de la llamada manual anterior, verificar si se calculó complejidad
stats_after_manual = adaptive_model.get_adaptive_stats()
print("Stats after manual forward:", stats_after_manual)
print("Complexity calculated?", stats_after_manual['initialized'])

In [None]:
# Resetear el sistema y probar manual de nuevo
inputs = tokenizer("Paris is the capital of", return_tensors='pt').to(device)
print("=== Manual forward call ===")
result = adaptive_model.forward(input_ids=inputs['input_ids'])
print("=== End manual call ===")

In [None]:
inputs = tokenizer("Paris is the capital of", return_tensors='pt').to(device)
result = adaptive_model.forward(input_ids=inputs['input_ids'])

In [None]:
stats_after_manual = adaptive_model.get_adaptive_stats()
print("Complexity after manual:", stats_after_manual['complexity_score'])
print("Initialized?", stats_after_manual['initialized'])

In [None]:
# Resetear el sistema para el test
adaptive_model._adaptive_mask.current_complexity = None