<a href="https://colab.research.google.com/github/peremartra/Large-Language-Model-Notebooks-Course/blob/inference-adaptative-attention-pruning/6-PRUNING/6_6b_Adaptive_Inference_Attention_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<div>
    <h1>Large Language Models Projects</h1>
    <h3>Apply and Implement Strategies for Large Language Models</h3>
    <h2>Pruning Attention Layers</h2>
    <h3>Not All Attention Is Needed</h3>
</div>

by [Pere Martra](https://www.linkedin.com/in/pere-martra/)

_______
Models: meta-llama/Llama-3.2

Colab Environment: GPU L4 for the 3B model, T4 for the 1B model.

Keys:
* Pruning
* Attention

References:
* [Resource-Efficient Transformer Pruning for Finetuning of Large Models](https://openaccess.thecvf.com/content/CVPR2024/html/Ilhan_Resource-Efficient_Transformer_Pruning_for_Finetuning_of_Large_Models_CVPR_2024_paper.html)

_______
**Disclaimer: The pruning / knowledge distillation section was created after the first edition of the book was published. It is not included in the book’s original content but is intended to supplement and expand on the topics covered.**

This is the unofficial repository for the book:
        <a href="https://amzn.to/4eanT1g"> <b>Large Language Models:</b> Apply and Implement Strategies for Large Language Models</a> (Apress).
        The book is based on the content of this repository, but the notebooks are being updated, and I am incorporating new examples and chapters.
        If you are looking for the official repository for the book, with the original notebooks, you should visit the
        <a href="https://github.com/Apress/Large-Language-Models-Projects">Apress repository</a>, where you can find all the notebooks in their original format as they appear in the book.

______
# Introduction


# Methodology.

______

# Install libraries & Configure variables.

In [1]:
!pip install -q torch==2.6.0
!pip install -q torchvision==0.21.0
!pip install -q transformers==4.51.3
!pip install -q datasets==3.6.0
!pip install -q lm-eval==0.4.8

!pip install hf_xet  # To speed up downloads from Hugging Face.

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m90.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m88.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m45.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import logging
import math
import os
import sys
import shutil
from copy import deepcopy

import torch
import torch.nn.functional as F
import json
from transformers import AutoModelForCausalLM, AutoTokenizer


# Adaptive Configuration.

In [3]:
# =============================================================================
# ADAPTIVE ATTENTION BYPASS (AAB) CONFIGURATION - CORRECTED SCALING TO 100%
# =============================================================================

GLOBAL_COMPLEXITIES = [0.1, 0.3, 0.5, 0.7, 0.9]

COMPLEXITY_WEIGHTS = {
    "token_count": 0.65,
    "embedding_variance": 0.35
}

ADAPTIVE_CONFIG = {
    # Model size-based ratios with proportional scaling to 100%
    "model_size_ratios": {
        "70B+": {
            "trivial": {"min_ratio": 0.15, "scaling_factor": 0.85},
            "simple": {"min_ratio": 0.35, "scaling_factor": 0.65},
            "medium": {"min_ratio": 0.55, "scaling_factor": 0.45},
            "complex": {"min_ratio": 0.75, "scaling_factor": 0.25},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "30B-70B": {
            "trivial": {"min_ratio": 0.25, "scaling_factor": 0.75},
            "simple": {"min_ratio": 0.40, "scaling_factor": 0.60},
            "medium": {"min_ratio": 0.60, "scaling_factor": 0.40},
            "complex": {"min_ratio": 0.80, "scaling_factor": 0.20},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "10B-30B": {
            "trivial": {"min_ratio": 0.30, "scaling_factor": 0.75},
            "simple": {"min_ratio": 0.45, "scaling_factor": 0.55},
            "medium": {"min_ratio": 0.65, "scaling_factor": 0.35},
            "complex": {"min_ratio": 0.82, "scaling_factor": 0.18},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "5B-10B": {
            "trivial": {"min_ratio": 0.45, "scaling_factor": 0.60},
            "simple": {"min_ratio": 0.55, "scaling_factor": 0.45},
            "medium": {"min_ratio": 0.75, "scaling_factor": 0.25},
            "complex": {"min_ratio": 0.87, "scaling_factor": 0.13},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "2B-5B": {
            "trivial": {"min_ratio": 0.60, "scaling_factor": 0.40},
            "simple": {"min_ratio": 0.75, "scaling_factor": 0.20},
            "medium": {"min_ratio": 0.90, "scaling_factor": 0.20},
            "complex": {"min_ratio": 0.95, "scaling_factor": 0.10},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        },
        "<2B": {
            "trivial": {"min_ratio": 0.80, "scaling_factor": 0.20},
            "simple": {"min_ratio": 0.85, "scaling_factor": 0.15},
            "medium": {"min_ratio": 0.90, "scaling_factor": 0.10},
            "complex": {"min_ratio": 0.95, "scaling_factor": 0.05},
            "very_complex": {"min_ratio": 1.0, "scaling_factor": 0.0}
        }
    },

    # 5-level complexity thresholds and descriptions
    "complexity_levels": {
        "trivial": {
            "range": [0.0, 0.2],
        },
        "simple": {
            "range": [0.2, 0.4],
        },
        "medium": {
            "range": [0.4, 0.6],
        },
        "complex": {
            "range": [0.6, 0.8],
        },
        "very_complex": {
            "range": [0.8, 1.0],
        }
    },
}
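
`COMPLEXITY_WEIGHTS` suggests that the prompt complexity score is a weighted sum of two normalised signals: token count and embedding variance. The scoring function itself is not part of this section, so the sketch below is only an assumption. It normalises each signal into [0, 1] by dividing by a hypothetical maximum (`max_tokens`, `max_variance`), combines them with the 0.65 / 0.35 weights, and maps the result onto the five ranges of `complexity_levels`. The helpers `complexity_score` and `level_for` are illustrative names, not functions from this notebook.

```python
# Hypothetical sketch: how a complexity score *could* be built from
# COMPLEXITY_WEIGHTS. The normalisation constants are assumptions.
COMPLEXITY_WEIGHTS = {"token_count": 0.65, "embedding_variance": 0.35}

# Same five ranges as ADAPTIVE_CONFIG["complexity_levels"]
LEVEL_RANGES = [
    ("trivial", 0.0, 0.2),
    ("simple", 0.2, 0.4),
    ("medium", 0.4, 0.6),
    ("complex", 0.6, 0.8),
    ("very_complex", 0.8, 1.0),
]


def complexity_score(num_tokens, embedding_variance,
                     max_tokens=512, max_variance=1.0):
    """Weighted sum of two signals, each clipped to [0, 1] (assumed normalisation)."""
    token_signal = min(num_tokens / max_tokens, 1.0)
    variance_signal = min(embedding_variance / max_variance, 1.0)
    return (COMPLEXITY_WEIGHTS["token_count"] * token_signal
            + COMPLEXITY_WEIGHTS["embedding_variance"] * variance_signal)


def level_for(score):
    """Map a score in [0, 1] to a level name; 1.0 falls into very_complex."""
    for name, low, high in LEVEL_RANGES:
        if low <= score < high:
            return name
    return "very_complex"


score = complexity_score(num_tokens=10, embedding_variance=0.2, max_tokens=100)
print(f"{score:.3f} -> {level_for(score)}")
```

Clipping both signals keeps the weighted sum inside [0, 1], so every score lands in exactly one of the five ranges.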


## Support & calculate functions

In [4]:
def detect_model_size_category(model):
    """
    Automatically detect the model size category from its parameter count.
    The category selects the ratio table in ADAPTIVE_CONFIG["model_size_ratios"].
    """
    try:
        total_params = sum(p.numel() for p in model.parameters())
        size_billion = total_params / 1e9

        print(f"🔍 Detected model size: {size_billion:.2f}B parameters")

        if size_billion >= 70:
            return "70B+"
        elif size_billion >= 30:
            return "30B-70B"
        elif size_billion >= 10:
            return "10B-30B"
        elif size_billion >= 5:
            return "5B-10B"
        elif size_billion >= 2:
            return "2B-5B"
        else:
            return "<2B"

    except Exception as e:
        print(f"⚠️ Error detecting model size: {e}")
        return "<2B"  # Fallback must be a key of ADAPTIVE_CONFIG["model_size_ratios"]


def count_attention_layers_correctly(model):
    """
    Correctly count attention layers by finding main decoder/transformer layers
    """
    # Method 1: Count main decoder layers directly (most reliable)
    decoder_layer_count = 0
    for name, module in model.named_modules():
        module_type = type(module).__name__
        # Look for main transformer/decoder layers
        if any(layer_type in module_type for layer_type in
               ['DecoderLayer', 'TransformerBlock', 'Block', 'Layer']) and \
           all(exclude not in module_type for exclude in
               ['Embedding', 'Norm', 'Linear', 'MLP', 'Attention']):
            # Make sure it's a numbered layer (e.g., layers.0, layers.1, etc.)
            if '.layers.' in name and name.count('.') == 2:  # e.g., "model.layers.0"
                decoder_layer_count += 1

    if decoder_layer_count > 0:
        return decoder_layer_count

    # Method 2: Use model config as fallback
    try:
        if hasattr(model, 'config'):
            config_attrs = ['num_hidden_layers', 'n_layer', 'num_layers', 'n_layers']
            for attr in config_attrs:
                if hasattr(model.config, attr):
                    return getattr(model.config, attr)
    except Exception:
        pass

    # Method 3: Direct access to layers ModuleList
    try:
        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
            return len(model.model.layers)
    except Exception:
        pass

    return 16  # Conservative fallback


def classify_complexity_level(complexity_score):
    """
    Classify complexity score into one of 5 levels

    Args:
        complexity_score (float): Complexity score (0.0-1.0)

    Returns:
        str: Complexity level ("trivial", "simple", "medium", "complex", "very_complex")
    """
    levels = ADAPTIVE_CONFIG["complexity_levels"]

    for level_name, level_config in levels.items():
        min_val, max_val = level_config["range"]
        if min_val <= complexity_score < max_val:
            return level_name

    # Handle edge case for exactly 1.0
    if complexity_score >= 0.8:
        return "very_complex"

    return "trivial"  # Fallback


def calculate_active_layers(total_layers, model_size_category, complexity_score):
    """
    Calculate number of active layers based on complexity and model size

    Args:
        total_layers (int): Total number of attention layers
        model_size_category (str): Model size category
        complexity_score (float): Complexity score (0.0-1.0)

    Returns:
        tuple: (active_layers_count, complexity_level, min_guaranteed, max_possible)
    """
    # Classify complexity level
    complexity_level = classify_complexity_level(complexity_score)

    # Get configuration for this model size and complexity
    config = ADAPTIVE_CONFIG["model_size_ratios"][model_size_category][complexity_level]
    min_ratio = config["min_ratio"]
    scaling_factor = config["scaling_factor"]

    # Calculate layer counts
    min_guaranteed = int(total_layers * min_ratio)
    remaining_layers = total_layers - min_guaranteed
    additional_layers = int(complexity_score * scaling_factor * remaining_layers)
    active_layers = min_guaranteed + additional_layers

    # Ensure we don't exceed total layers
    active_layers = min(active_layers, total_layers)
    max_possible = total_layers  # Always can reach 100%


    return active_layers, complexity_level, min_guaranteed, max_possible



print("✅ Enhanced 5-level calculation functions loaded successfully!")

✅ Enhanced 5-level calculation functions loaded successfully!


# Download the Model.

In [5]:
# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
#model_name = 'meta-llama/Llama-3.2-1B'
model_name = 'meta-llama/Llama-3.2-3B'
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
#tokenizer.pad_token = tokenizer.eos_token  # Set pad token

## Test

In [7]:
# Test the configuration with clean, simplified output
if 'model' in locals():
    # Get model information with improved detection
    total_attention_layers = count_attention_layers_correctly(model)
    model_category = detect_model_size_category(model)

    print(f"\n🏗️ Model Analysis:")
    print(f"   Attention layers: {total_attention_layers}")
    print(f"   Size category: {model_category}")
    print(f"   Architecture: {type(model).__name__}")

    # Show layer detection verification
    print(f"\n🔍 Layer Detection Verification:")
    decoder_layers = [name for name, module in model.named_modules()
                     if 'DecoderLayer' in type(module).__name__ and '.layers.' in name]
    print(f"   Found DecoderLayers: {len(decoder_layers)}")

    # Test all 5 complexity levels with simplified table
    test_complexities = GLOBAL_COMPLEXITIES

    print("\n🧪 Layer Activation by Complexity Level:")
    print("=" * 50)
    print(f"{'Level':<12} {'Active Layers':<15} {'Usage Ratio':<12}")
    print("-" * 50)

    for complexity in test_complexities:
        active, level, min_guaranteed, max_possible = calculate_active_layers(
            total_attention_layers, model_category, complexity
        )
        ratio = active / total_attention_layers

        print(f"{level.capitalize():<12} {active:<15} {ratio:<12.1%}")

    print(f"\n📊 Summary for {model_category} model:")
    trivial_config = ADAPTIVE_CONFIG['model_size_ratios'][model_category]['trivial']
    trivial_min = int(total_attention_layers * trivial_config['min_ratio'])
    print(f"   • Range: {trivial_min}-{total_attention_layers} layers ({trivial_min/total_attention_layers:.1%}-100%)")
    print(f"   • All complexity levels can reach 100% layer usage")

else:
    print("⏳ Load your model first to test the configuration")
    print("\nTo test, make sure you have:")
    print("1. model = ... (your loaded model)")
    print("2. tokenizer = ... (optional, your tokenizer)")

🔍 Detected model size: 3.21B parameters

🏗️ Model Analysis:
   Attention layers: 28
   Size category: 2B-5B
   Architecture: LlamaForCausalLM

🔍 Layer Detection Verification:
   Found DecoderLayers: 28

🧪 Layer Activation by Complexity Level:
Level        Active Layers   Usage Ratio 
--------------------------------------------------
Trivial      16              57.1%       
Simple       21              75.0%       
Medium       25              89.3%       
Complex      26              92.9%       
Very_complex 28              100.0%      

📊 Summary for 2B-5B model:
   • Range: 16-28 layers (57.1%-100%)
   • All complexity levels can reach 100% layer usage
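
The rows of the table follow directly from `calculate_active_layers`: for `L` total layers, a level with `min_ratio` r and `scaling_factor` s, and a complexity score c, the active count is `int(L*r) + int(c*s*(L - int(L*r)))`, capped at L. The stdlib-only check below repeats that arithmetic for the 28-layer 3B model, using the `2B-5B` ratios copied from `ADAPTIVE_CONFIG`. Because the second term is truncated with `int()`, it evaluates to 0 for every level here, so each row equals its guaranteed minimum and the score only matters through the level it falls into.

```python
# 2B-5B (min_ratio, scaling_factor) pairs copied from
# ADAPTIVE_CONFIG["model_size_ratios"]["2B-5B"].
RATIOS_2B_5B = {
    "trivial": (0.60, 0.40),
    "simple": (0.75, 0.20),
    "medium": (0.90, 0.20),
    "complex": (0.95, 0.10),
    "very_complex": (1.0, 0.0),
}
# Representative scores, as in GLOBAL_COMPLEXITIES.
SCORES = {"trivial": 0.1, "simple": 0.3, "medium": 0.5,
          "complex": 0.7, "very_complex": 0.9}


def active_layers(total_layers, min_ratio, scaling_factor, score):
    """Same arithmetic as calculate_active_layers, int() truncation included."""
    min_guaranteed = int(total_layers * min_ratio)
    extra = int(score * scaling_factor * (total_layers - min_guaranteed))
    return min(min_guaranteed + extra, total_layers), min_guaranteed, extra


results = {}
for level, (min_ratio, scaling) in RATIOS_2B_5B.items():
    results[level] = active_layers(28, min_ratio, scaling, SCORES[level])
    active, guaranteed, extra = results[level]
    print(f"{level:<12} guaranteed={guaranteed:2d} extra={extra} active={active}")
```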


## Study the structure.
* Llama-3.2-1B
```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)
```


The model follows the typical structure of modern Llama models, consisting of blocks made up of an Attention layer and an MLP layer with a GLU structure.
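
To put the two sublayers in perspective, their parameter counts can be read straight off the `Linear` shapes printed above for Llama-3.2-1B. The sketch below does that arithmetic (all projections are bias-free; the RMSNorm weights are left out) and writes the gated MLP, `down_proj(SiLU(gate_proj(x)) * up_proj(x))`, as a scalar toy function. It is illustrative only and is not the transformers implementation.

```python
import math

# Shapes taken from the printed Llama-3.2-1B structure (bias=False everywhere).
HIDDEN, KV_DIM, INTERMEDIATE = 2048, 512, 8192


def linear_params(in_features, out_features):
    """Number of weights in a bias-free nn.Linear."""
    return in_features * out_features


attn_params = (linear_params(HIDDEN, HIDDEN)      # q_proj
               + linear_params(HIDDEN, KV_DIM)    # k_proj
               + linear_params(HIDDEN, KV_DIM)    # v_proj
               + linear_params(HIDDEN, HIDDEN))   # o_proj
# gate_proj and up_proj: 2048 -> 8192, down_proj: 8192 -> 2048
mlp_params = 3 * linear_params(HIDDEN, INTERMEDIATE)

share = attn_params / (attn_params + mlp_params)
print(f"attention: {attn_params:,}  mlp: {mlp_params:,}  attention share: {share:.1%}")


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def glu_mlp_scalar(x, w_gate, w_up, w_down):
    """Scalar toy version of the gated MLP: down(SiLU(gate(x)) * up(x))."""
    return w_down * (silu(w_gate * x) * (w_up * x))
```

With these shapes, attention holds 5/29 (about 17%) of each block's projection weights. Bypassing an attention sublayer therefore skips that share of the block's projection compute plus the attention-score computation, whose cost grows with sequence length, while the weights themselves remain loaded.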

> If you want to see an example of how to prune the MLP layers of the model, you can check out the notebook [Pruning Llama 3.2](https://github.com/peremartra/Large-Language-Model-Notebooks-Course/blob/main/6-PRUNING/6_3_pruning_structured_llama3.2-1b_OK.ipynb) and read the paper [Exploring GLU expansion ratios: Structured pruning in Llama-3.2 models](https://osf.io/preprints/osf/qgxea).


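Since each attention layer and its MLP layer form a single decoder block, the attention layer cannot be removed without also removing the accompanying MLP layer. For this reason, instead of removing layers from the model, the approach taken here is to bypass the execution of the attention layers during inference.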

The 1B model has 16 layers, as shown in the structure above, while the 3B model has 28 layers.
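
A minimal sketch of what bypassing means inside one pre-norm decoder block, assuming the standard Llama residual layout (`h = x + attn(norm(x))`, then `h + mlp(norm(h))`) and assuming only the attention branch is skipped. When attention is bypassed, its residual branch contributes nothing and the MLP still runs. The function `decoder_block` and its arguments are hypothetical and operate on plain Python numbers so the data flow is easy to follow; this is not the `LlamaDecoderLayer` API.

```python
from typing import Callable

Fn = Callable[[float], float]


def decoder_block(x: float, attn: Fn, mlp: Fn, input_norm: Fn,
                  post_attn_norm: Fn, bypass_attention: bool = False) -> float:
    """Toy pre-norm decoder block with an optional attention bypass."""
    if bypass_attention:
        h = x                                # attention branch skipped: residual only
    else:
        h = x + attn(input_norm(x))          # attention branch added to the residual
    return h + mlp(post_attn_norm(h))        # the MLP branch always runs


# Toy sublayers, chosen only to make the arithmetic easy to follow.
def identity(v: float) -> float:
    return v


def double(v: float) -> float:
    return 2 * v


def plus_one(v: float) -> float:
    return v + 1


full = decoder_block(1.0, double, plus_one, identity, identity)
skipped = decoder_block(1.0, double, plus_one, identity, identity,
                        bypass_attention=True)
print(full, skipped)  # 7.0 3.0
```

Because the residual connection carries `x` past the skipped branch, the block still produces a valid hidden state of the same shape, which is what makes bypassing possible without touching the rest of the model.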


# Inference function & Test Base Model

The `get_output` function generates text and measures the time spent in each stage of the process: tokenization, generation, and decoding.

When `num_runs` is greater than 1, it also reports the average total time per run, which gives a latency baseline for evaluating the efficiency of text generation.
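
The per-stage timing in `get_output` repeats the same start/stop pattern three times. As a hedged alternative, that bookkeeping can be factored into a small context manager built on `time.perf_counter()`, which is monotonic and intended for measuring short intervals, unlike `time.time()`. The helper name `stage_timer` is hypothetical, and the two timed blocks below are stand-ins for the tokenizer and `model.generate` calls. On a GPU you may also want to call `torch.cuda.synchronize()` before reading the clock so that queued kernels are included in the measurement.

```python
import time
from contextlib import contextmanager


@contextmanager
def stage_timer(timings, name):
    """Store the wall-clock duration of the enclosed block in timings[name] (seconds)."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = time.perf_counter() - start


timings = {}
with stage_timer(timings, "tokenization"):
    tokens = "Paris is the capital of".split()   # stand-in for tokenizer(...)
with stage_timer(timings, "generation"):
    time.sleep(0.01)                            # stand-in for model.generate(...)

for stage, seconds in timings.items():
    print(f"{stage}: {seconds * 1000:.2f} ms")
```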

In [8]:
import time

def get_output(prompt, model=model, tokenizer=tokenizer, num_runs=1, max_length=50):
    total_time = 0
    generated_outputs = []

    for run in range(num_runs):
        # Start timing
        start_time = time.time()

        # Tokenization time
        token_start = time.time()
        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)  # follow the model's device
        token_time = time.time() - token_start

        # Generation time
        gen_start = time.time()
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            temperature=None,
            top_p=None,
            do_sample=False,  # Disable sampling
            num_beams=5,      # Use beam search
            early_stopping=True,  # Beam search stops once num_beams finished candidates exist
            no_repeat_ngram_size=2  # Prevent repetition of 2-grams
        )
        gen_time = time.time() - gen_start

        # Decoding time
        decode_start = time.time()
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        decode_time = time.time() - decode_start

        # Total time for this run
        run_time = time.time() - start_time
        total_time += run_time
        generated_outputs.append(generated)

        if num_runs > 1:
            print(f"\nRun {run + 1}:")
        print(f"Tokenization time: {token_time*1000:.2f} ms")
        print(f"Generation time: {gen_time*1000:.2f} ms")
        print(f"Decoding time: {decode_time*1000:.2f} ms")
        print(f"Total time: {run_time*1000:.2f} ms")

    if num_runs > 1:
        avg_time = total_time / num_runs
        print(f"\nAverage time over {num_runs} runs: {avg_time*1000:.2f} ms")

    return generated_outputs[0] if num_runs == 1 else generated_outputs

In [9]:
# Test the original model
prompt = "Paris is the capital of"
generated = get_output(prompt, num_runs=2)
print(f"Generated text: {generated}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Run 1:
Tokenization time: 1.66 ms
Generation time: 3947.34 ms
Decoding time: 0.35 ms
Total time: 3949.46 ms

Run 2:
Tokenization time: 0.61 ms
Generation time: 3031.68 ms
Decoding time: 0.23 ms
Total time: 3032.61 ms

Average time over 2 runs: 3490.94 ms
Generated text: ['Paris is the capital of France. It is located in the north-central part of the country, on the river Seine. The city has a population of over 2 million people, making it the largest city in France and the second-largest city', 'Paris is the capital of France. It is located in the north-central part of the country, on the river Seine. The city has a population of over 2 million people, making it the largest city in France and the second-largest city']


As expected, the original model generates a fluent and meaningful continuation of the prompt.

In [10]:
model.to("cpu")               # move the weights from GPU memory to the CPU
torch.cuda.empty_cache()      # release the now-unused cached GPU memory blocks

# Pruning the Model.

In [11]:
#import torch
#import torch.nn as nn
#from torch.nn import functional as F
#from copy import deepcopy

## Execute Pruning.

**Disclaimer**

I'm using a small set of illustrative prompts so that the code path is easy to follow. In any research or production setting, you must feed hundreds or thousands of diverse prompts before deciding which layers to deactivate.
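
One way to check whether a calibration set is large enough, offered here as an assumption rather than as part of this notebook, is to rank the layers separately on two disjoint halves of the prompts and measure how much the top-k sets agree. Low agreement suggests the ranking still depends heavily on the particular prompts chosen. The helper `top_k_agreement` is hypothetical and works on any two importance rankings in the same format as `layers_by_importance` (most important first).

```python
def top_k_agreement(ranking_a, ranking_b, k):
    """Fraction of layers shared by the top-k of two importance rankings."""
    if k <= 0:
        raise ValueError("k must be positive")
    return len(set(ranking_a[:k]) & set(ranking_b[:k])) / k


# Hypothetical rankings (made-up numbers) from two halves of a calibration set.
half_1 = [3, 1, 4, 0, 2, 5]
half_2 = [1, 3, 0, 5, 4, 2]
print(top_k_agreement(half_1, half_2, k=4))  # 0.75
```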

In [12]:
# Using multiple prompts for calibration
calibration_prompts = [
    "Hi I'm a sample text, used to calculate the cosine difference between input and output.",
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning models can be optimized through various techniques, explain the principals.",
    "2+2=",
    "What is the meaning of life, the universe, and everything?"
]

In [13]:
# =============================================================================
# SIMPLE AAB CALIBRATION - MINIMAL CODE
# =============================================================================
def measure_layer_importance_simple(model, tokenizer, prompts):
    """Score each attention layer as 1 - cosine similarity between its q_proj input and its o_proj output, averaged over all prompts."""
    model.eval()
    device = next(model.parameters()).device
    total_layers = len(model.model.layers)

    # Accumulate importance scores across all prompts
    importance_acc = {idx: 0.0 for idx in range(total_layers)}

    print(f"📊 Processing {len(prompts)} prompts across {total_layers} layers...")

    for prompt_idx, prompt in enumerate(prompts):
        print(f"   Processing prompt {prompt_idx + 1}/{len(prompts)}")

        # Tokenize input (following original notebook pattern)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Storage for this prompt's layer inputs/outputs
        layer_inputs = {}
        layer_outputs = {}

        # Create hooks (EXACTLY like the original function)
        def q_proj_input_hook(layer_idx):
            def _hook(module, module_input):
                # Handle tuple input (following original pattern)
                inp = module_input[0] if isinstance(module_input, tuple) else module_input
                layer_inputs[layer_idx] = inp.detach().clone()
            return _hook

        def o_proj_output_hook(layer_idx):
            def _hook(module, module_input, module_output):
                # Handle tuple output (following original pattern)
                out = module_output[0] if isinstance(module_output, tuple) else module_output
                layer_outputs[layer_idx] = out.detach().clone()
            return _hook

        # Register hooks for ALL layers (not just unpruned ones)
        handles = []
        for idx in range(total_layers):
            layer = model.model.layers[idx]
            handles.append(layer.self_attn.q_proj.register_forward_pre_hook(q_proj_input_hook(idx)))
            handles.append(layer.self_attn.o_proj.register_forward_hook(o_proj_output_hook(idx)))

        # Forward pass (following original pattern)
        with torch.no_grad():
            _ = model(**inputs)

        # Remove hooks (following original pattern)
        for h in handles:
            h.remove()

        # Calculate importance for each layer (EXACTLY like original)
        for idx in range(total_layers):
            if idx in layer_inputs and idx in layer_outputs:
                inp = layer_inputs[idx]
                out = layer_outputs[idx]

                # Flatten tensors (following original pattern)
                inp_flat = inp.view(inp.size(0), -1)
                out_flat = out.view(out.size(0), -1)

                # Calculate similarity and importance (following original pattern)
                similarity = F.cosine_similarity(inp_flat, out_flat, dim=1).mean().item()
                importance_score = 1 - similarity
                importance_acc[idx] += importance_score

    # Average across all prompts
    avg_importance = {idx: importance_acc[idx] / len(prompts) for idx in range(total_layers)}

    print("✅ Layer importance measurement complete!")
    return avg_importance


def create_adaptive_config_simple(model, tokenizer, prompts):
    """Create OPTIMIZED adaptive config - ultra-simple format for efficient inference"""
    print("🚀 Creating optimized adaptive config...")

    # Step 1: Analyze model
    model_size_category = detect_model_size_category(model)
    total_layers = count_attention_layers_correctly(model)

    # Step 2: Measure importance
    print("📊 Measuring layer importance...")
    importance_scores = measure_layer_importance_simple(model, tokenizer, prompts)

    # Step 3: Create layers_by_importance (sorted list)
    print("🏆 Creating layers_by_importance list...")
    sorted_layers = sorted(importance_scores.items(), key=lambda x: x[1], reverse=True)
    layers_by_importance = [layer_idx for layer_idx, _ in sorted_layers]

    # Step 4: Calculate complexity thresholds using existing notebook functions
    print("🎯 Calculating complexity thresholds...")
    complexity_scores = GLOBAL_COMPLEXITIES
    complexity_thresholds = {}

    print("📊 Using notebook functions to get exact layer counts:")
    for score in complexity_scores:
        active_layers_count, _, _, _ = calculate_active_layers(
            total_layers, model_size_category, score
        )
        complexity_thresholds[score] = active_layers_count
        level_name = classify_complexity_level(score)
        print(f"   Score {score:3.1f} ({level_name:12}) → {active_layers_count:2d}/{total_layers} layers")

    # Step 5: Build OPTIMIZED config
    print("⚙️ Building optimized configuration...")
    config = {
        "model_info": {
            "name": getattr(model.config, '_name_or_path', 'unknown'),
            "total_parameters": f"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B",
            "size_category": model_size_category,
            "total_layers": total_layers,
            "architecture": type(model).__name__
        },
        "layers_by_importance": layers_by_importance,
        "complexity_thresholds": complexity_thresholds,
        "complexity_weights": COMPLEXITY_WEIGHTS
    }

    # Step 6: Save optimized config
    with open("adaptive_config.json", "w") as f:
        json.dump(config, f, indent=2)

    print("✅ OPTIMIZED adaptive_config.json created!")

    # Show optimized results
    print(f"📊 Model: {total_layers} layers, {model_size_category}")
    print(f"🏆 Layers by importance: {layers_by_importance[:5]}... (showing first 5)")
    print("🎯 Complexity thresholds:")
    for threshold, count in complexity_thresholds.items():
        percentage = (count / total_layers) * 100
        level = classify_complexity_level(threshold)
        print(f"   {threshold:3.1f} ({level:12}): {count:2d} layers ({percentage:4.1f}%)")

    print("\n🚀 ULTRA-EFFICIENT RUNTIME FORMAT:")

    return config
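
`measure_layer_importance_simple` scores each attention layer as 1 minus the cosine similarity between the tensor entering `q_proj` and the tensor leaving `o_proj`, averages that score over the calibration prompts, and `create_adaptive_config_simple` then sorts the layers from most to least important. The stdlib-only sketch below reproduces that arithmetic on toy flattened vectors so the metric is easy to inspect. The helper names are hypothetical; the notebook itself works on flattened PyTorch tensors via `F.cosine_similarity`.

```python
import math


def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity of two equal-length flat vectors (eps avoids division by zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def rank_layers(per_prompt_pairs):
    """
    per_prompt_pairs: one dict per prompt, {layer_idx: (flat_input, flat_output)}.
    Returns (average importance per layer, layer indices sorted most important first).
    """
    n_prompts = len(per_prompt_pairs)
    acc = {}
    for pairs in per_prompt_pairs:
        for idx, (inp, out) in pairs.items():
            acc[idx] = acc.get(idx, 0.0) + (1.0 - cosine_similarity(inp, out))
    avg = {idx: total / n_prompts for idx, total in acc.items()}
    ranking = sorted(avg, key=avg.get, reverse=True)
    return avg, ranking


# Toy data: layer 0 only rescales its input (same direction),
# layer 1 always produces an output orthogonal to its input.
prompts = [
    {0: ([1.0, 0.0], [2.0, 0.0]), 1: ([1.0, 0.0], [0.0, 1.0])},
    {0: ([0.0, 3.0], [0.0, 1.0]), 1: ([1.0, 1.0], [-1.0, 1.0])},
]
avg, ranking = rank_layers(prompts)
print(avg, ranking)
```

A layer whose output points in the same direction as its input scores 0 and ends up at the bottom of the ranking, which makes it the first candidate for bypassing.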


In [14]:
# =============================================================================
# SIMPLE EXECUTION - OPTIMIZED VERSION
# =============================================================================

print("🚀 CREATING ULTRA-EFFICIENT ADAPTIVE CONFIG")
print("=" * 50)

# Create the OPTIMIZED adaptive config using existing calibration_prompts
adaptive_config = create_adaptive_config_simple(model, tokenizer, calibration_prompts)

print(f"\n🎉 DONE! Optimized adaptive_config.json ready for AAB!")

🚀 CREATING ULTRA-EFFICIENT ADAPTIVE CONFIG
🚀 Creating optimized adaptive config...
🔍 Detected model size: 3.21B parameters
📊 Measuring layer importance...
📊 Processing 5 prompts across 28 layers...
   Processing prompt 1/5
   Processing prompt 2/5
   Processing prompt 3/5
   Processing prompt 4/5
   Processing prompt 5/5
✅ Layer importance measurement complete!
🏆 Creating layers_by_importance list...
🎯 Calculating complexity thresholds...
📊 Using notebook functions to get exact layer counts:
   Score 0.1 (trivial     ) → 16/28 layers
   Score 0.3 (simple      ) → 21/28 layers
   Score 0.5 (medium      ) → 25/28 layers
   Score 0.7 (complex     ) → 26/28 layers
   Score 0.9 (very_complex) → 28/28 layers
⚙️ Building optimized configuration...
✅ OPTIMIZED adaptive_config.json created!
📊 Model: 28 layers, 2B-5B
🏆 Layers by importance: [9, 8, 12, 10, 7]... (showing first 5)
🎯 Complexity thresholds:
   0.1 (trivial     ): 16 layers (57.1%)
   0.3 (simple      ): 21 layers (75.0%)
   0.5 (med

In [15]:
adaptive_config

{'model_info': {'name': 'meta-llama/Llama-3.2-3B',
  'total_parameters': '3.21B',
  'size_category': '2B-5B',
  'total_layers': 28,
  'architecture': 'LlamaForCausalLM'},
 'layers_by_importance': [9, 8, 12, 10, 7, 0, 6, 27, 13, 18, 11, 5, 14, 3, 4, 2, 15, 1, 21, 17, 25, 24, 16, 22, 20, 26, 23, 19],
 'complexity_thresholds': {0.1: 16, 0.3: 21, 0.5: 25, 0.7: 26, 0.9: 28},
 'complexity_weights': {'token_count': 0.65, 'embedding_variance': 0.35}}
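
The saved config only stores a layer ranking and a layer count per representative score; how the runtime combines them is not part of this section. One plausible reading, sketched below as an assumption, is to place the prompt's complexity score in one of the five ranges, look up the layer count stored for that range's representative score, keep that many layers from the front of `layers_by_importance`, and bypass attention in the rest. Note that `json.dump` turns the float keys of `complexity_thresholds` into strings (`"0.1"`), so the sketch converts them back after loading. `select_active_layers` is a hypothetical helper.

```python
import json

# Values copied from the adaptive_config printed above (Llama-3.2-3B run).
config = {
    "layers_by_importance": [9, 8, 12, 10, 7, 0, 6, 27, 13, 18, 11, 5, 14, 3,
                             4, 2, 15, 1, 21, 17, 25, 24, 16, 22, 20, 26, 23, 19],
    "complexity_thresholds": {0.1: 16, 0.3: 21, 0.5: 25, 0.7: 26, 0.9: 28},
}
# Round-trip through JSON, as when reading adaptive_config.json back from disk.
loaded = json.loads(json.dumps(config))
thresholds = {float(k): v for k, v in loaded["complexity_thresholds"].items()}

# Boundaries of the five ranges in ADAPTIVE_CONFIG["complexity_levels"].
LEVEL_BOUNDS = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]


def select_active_layers(score, layers_by_importance, thresholds):
    """Hypothetical runtime lookup: return (active, bypassed) layer indices."""
    reps = sorted(thresholds)            # one representative score per range
    level = len(reps) - 1                # default: top range (covers score == 1.0)
    for i in range(len(reps)):
        if LEVEL_BOUNDS[i] <= score < LEVEL_BOUNDS[i + 1]:
            level = i
            break
    n_active = thresholds[reps[level]]
    active = sorted(layers_by_importance[:n_active])
    bypassed = sorted(layers_by_importance[n_active:])
    return active, bypassed


active, bypassed = select_active_layers(0.15, loaded["layers_by_importance"], thresholds)
print(len(active), bypassed)
```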

## Test prompt complexity

In [16]:
# Test prompts grouped by expected complexity level
calibration_prompts = [
    # Trivial (0.0-0.2) - Short, simple
    "Hi",
    "2+2=",
    "Hello.",
    "What is 2+2?",

    # Simple (0.2-0.4) - Basic questions
    "What is the capital of France?",
    "Tell me a joke.",
    "Name the capital of Catalonia.",
    "Who wrote 'To Kill a Mockingbird'?",

    # Medium (0.4-0.6) - Knowledge retrieval, moderate length
    "Explain the basic principles of machine learning and how neural networks work.",
    "What are the main causes of climate change and what can individuals do to help?",
    "Summarize the plot of 'The Matrix' in one sentence.",
    "List three benefits of regular exercise.",

    # Complex (0.6-0.8) - Multi-step reasoning, analysis
    "Compare and contrast the economic policies of Keynesian and Austrian schools of thought, analyzing their effectiveness during different historical periods and explaining which approach would be most suitable for addressing current global economic challenges.",
    "Design a comprehensive strategy for a small tech startup to compete against established giants like Google and Microsoft in the cloud computing market, considering market positioning, technological differentiation, partnerships, and funding requirements.",
    "Explain why the sky appears blue during the day.",
    "Describe how a neural network learns from data.",

    # Very Complex (0.8-1.0) - Deep analysis, creativity, long form
    "Write a detailed philosophical essay examining the ethical implications of artificial intelligence consciousness, incorporating perspectives from utilitarian, deontological, and virtue ethics frameworks, while addressing counterarguments and proposing a novel ethical framework for AI development that balances technological progress with human values and societal well-being.",
    "Develop a multidisciplinary research proposal that integrates quantum computing, biotechnology, and environmental science to address food security challenges in the context of climate change, including methodology, timeline, budget considerations, potential collaborations, risk assessment, and expected societal impact over the next two decades.",
    "Given current economic trends, predict one challenge global markets may face in the next decade.",
    "Write a short poem about the experience of learning something new.",
    "Produce a 450-word technical tutorial that walks through implementing a transformer-based language model from scratch in NumPy, including positional encoding and scaled-dot-product attention.",
    "As an expert in global macroeconomics, geopolitical risk assessment, and artificial intelligence ethics, write an in-depth policy advisory report for a coalition of G20 nations facing simultaneous systemic challenges, including post-pandemic inflation volatility, supply chain reconfiguration due to AI-driven automation, increasing regional instability in energy markets, and declining trust in democratic institutions. Your report should propose a coordinated strategy that balances fiscal stimulus with monetary restraint, integrates quantum-secure blockchain for supply chain transparency, and includes AI oversight frameworks aligned with both utilitarian and deontological ethical models. Additionally, evaluate how international institutions like the IMF and the World Bank could modernize their governance structures to reflect multipolar power dynamics, and assess the feasibility of adopting an intergovernmental AI alignment charter inspired by the Paris Agreement model. Your recommendations must be actionable, globally inclusive, and anticipate sociopolitical backlash from both populist and nationalist movements.",
    """
    Draft Integrated Strategic White-Paper for Inter-Agency Review—

Executive Overview:
This document synthesises cutting-edge research in climate science, planetary boundaries, quantum-enhanced computation, synthetic bio-manufacturing, neuro-symbolic artificial intelligence, behavioural economics, geopolitics, space-based energy infrastructure, and post-growth macro-finance. It is intended for cabinet-level policymakers across the G20, the African Union, and APEC, as well as multilateral lenders, sovereign wealth funds, philanthropic megadonors, and fourth-sector cooperative alliances.

Section 1 – Macroeconomic Volatility & Post-Pandemic Debt Overhang
1.1 Analyse the persistence of stagflationary pressures under divergent monetary regimes.
1.2 Model cascading default scenarios using agent-based stress tests that incorporate climate-induced supply-chain interruptions, semiconductor chokepoints in Taiwan and the Netherlands, and maritime bottlenecks in the Suez and Panama Canals.
1.3 Propose a menu of fiscal-monetary coordination instruments—helicopter stabilisation bonds, biodiversity-linked debt swaps, and anti-fragile carbon border adjustments—scaled to emerging-market liquidity traps.

Section 2 – Planetary Health & Regenerative Bio-Economy
2.1 Summarise findings from IPCC AR7 draft chapters on irreversible cryosphere tipping points.
2.2 Evaluate next-generation direct air capture catalysis that leverages metal-organic frameworks seeded by engineered extremophilic microbes.
2.3 Draft a governance blueprint for a Global Soil Microbiome Commons, incorporating indigenous data sovereignty protocols, fair-benefit-sharing algorithms, and quantum-secured telemetry for real-time biodiversity crediting.

Section 3 – Quantum-Classical Hybrid Infrastructure
3.1 Detail a phased roadmap for 1 000-qubit photonic processors coupled to error-mitigated superconducting qubits for combinatorial optimisation in logistics, drug-discovery, and lattice-QCD.
3.2 Define open-standard interfaces that allow sovereign cloud providers to interoperate with NATO-grade zero-trust enclaves and NIST-post-quantum cryptographic suites.
3.3 Recommend incentives for talent-mobility corridors bridging quantum start-up clusters in Toronto, Delft, Shenzhen, Sydney, and Kigali.

Section 4 – Neuro-Symbolic AI & Alignment Governance
4.1 Compare scaling-law extrapolations for transformers, mixture-of-experts, retrieval-augmented decoders, and recursive reasoning agents.
4.2 Propose a multi-layer safety stack: interpretability probes, causal influence diagrams, counterfactual policy evaluation, and cooperative inverse-reinforcement architectures monitored by open-weight red-team sandboxes.
4.3 Outline a treaty-grade AI Alignment Accord modelled after the Paris Agreement, featuring dynamic capability thresholds, compute-cluster registration, differential privacy audits, and a tiered sanctions regime enforced via programmable CBDCs.

Section 5 – Security, Geopolitics & Space-Based Energy
5.1 Assess escalation risks stemming from fractional-orbital bombardment systems, low-cost hypersonic glide vehicles, and AI-directed drone swarms.
5.2 Present techno-economic viability of kilometre-scale solar power satellites in sun-synchronous orbit, with microwave beaming arrays utilising adaptive phased-conjugate mirrors.
5.3 Recommend confidence-building measures: reciprocal on-site inspection, open telemetry APIs, catastrophe-bond insurance pools, and an International Orbital Commons Authority.

Section 6 – Behavioural & Cultural Dynamics
6.1 Integrate behavioural-nudge frameworks, narrative foresight, and social-network epistemic resilience analytics to counter disinformation loops.
6.2 Design outcome-oriented citizen deliberation platforms that leverage quadratic voting, verifiable credentials, and language-agnostic dialogue agents with embedded bias-mitigation layers.

Section 7 – Financing Mechanisms & Implementation Timeline
7.1 Catalogue blended-finance instruments: catalytic first-loss capital, sovereign green sukuk, resilience impact derivatives, and decentralized autonomous project bonds.
7.2 Map a ten-year Gantt chart with critical path analysis, specifying TRL-milestones, regulatory sandboxes, and adaptive procurement clauses.

Call to Action:
Conclude by articulating how cooperative mission-oriented investment, science-diplomacy trust architecture, and inclusive technology governance can converge to safeguard planetary health while enabling equitable prosperity within the safe-and-just operating space for humanity.
    """,
]
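Python joins adjacent string literals at compile time, so a missing comma between two entries in a prompt list silently merges them into one longer prompt. That changes both the number of prompts and the complexity score each one receives. The sketch below shows the effect on two toy prompts and adds a rough check; `suspicious_joins` is a hypothetical helper, not part of the notebook, and its regex (sentence punctuation immediately followed by an uppercase letter) will also flag abbreviations such as "U.S.A.".

```python
import re

# Adjacent string literals are concatenated at compile time, so the missing
# comma after the first item turns two prompts into a single list element.
prompts_missing_comma = [
    "Write a short poem about learning."
    "Tell me a joke.",
]
prompts_with_comma = [
    "Write a short poem about learning.",
    "Tell me a joke.",
]


def suspicious_joins(prompts):
    """Hypothetical heuristic: return prompts in which sentence-ending punctuation
    is immediately followed by an uppercase letter, a typical trace of a missing
    comma between two literals. Abbreviations like "U.S.A." also match."""
    return [p for p in prompts if re.search(r"[.?!][A-Z]", p)]


print(len(prompts_missing_comma))  # 1
print(len(prompts_with_comma))     # 2
print(suspicious_joins(prompts_missing_comma))
print(suspicious_joins(prompts_with_comma))  # []
```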

In [19]:
import math
import torch


def analyze_prompt_complexity(prompts, config, model, tokenizer, verbose: bool = True):
    """
    Compute a complexity score in [0, 1] for each prompt.

    Parameters
    ----------
    prompts : list[str]
        The text prompts to score.
    config : dict
        adaptive_config.json already loaded as dict.
    model : transformers.PreTrainedModel
        The HF model (on CPU or GPU).
    tokenizer : transformers.PreTrainedTokenizer
        Matching tokenizer.
    verbose : bool
        If True, print a per-prompt breakdown.

    Returns
    -------
    list[tuple[str, float]]
        (prompt, complexity_score) for each input string.
    """

    # --- weights (fallback to defaults if not in config) -----------
    weights = config.get("complexity_weights", {
        "token_count": 0.65,
        "embedding_variance": 0.35,
    })

    results = []

    device  = next(model.parameters()).device
    param_str    = config["model_info"]["total_parameters"]  # e.g. "3.21B"
    param_count  = float(param_str.rstrip("B"))              # 3.21

    # 1) Normalize to a 7B reference and clamp
    param_factor = param_count / 7.0
    param_factor = max(min(param_factor, 2.0), 0.5)

    # 2) Compute a length budget scaled by model size
    #    (small models saturate sooner, large ones later)
    base_length = 4000
    length_reference = base_length * param_factor

    total_params = sum(p.numel() for p in model.parameters())
    size_billion = total_params / 1e9
    token_multiplier = max(0.85, 1.3 - size_billion * 0.1)
    variance_multiplier = max(0.9, 1.2 - size_billion * 0.08)


    for p in prompts:
        # --- Tokenise on the model's device ------------------------
        ids = tokenizer(p, return_tensors="pt")["input_ids"][0].to(device)
        n_tokens = ids.size(0)

        # A) LENGTH: log-scaled relative to the size-scaled length budget
        # 3) Raw length score (log-scaled)
        raw_score = math.log1p(n_tokens) / math.log1p(length_reference)
        # 4) Size adjustment: boost small models, damp large ones
        size_adjust = 1.0 / param_factor
        adj_score   = raw_score * size_adjust
        # Baseline: a MIN_TOKS-token prompt maps to a token score of 0
        MIN_TOKS      = 4
        baseline_raw  = math.log1p(MIN_TOKS) / math.log1p(length_reference)
        baseline_adj  = baseline_raw * size_adjust

        # 5) Final token score: rescale so the baseline maps to 0, clamp to
        #    [0, 1], then apply the size-dependent multiplier
        token_score = (adj_score - baseline_adj) / (1.0 - baseline_adj)
        raw_token_score = max(0.0, min(token_score, 1.0))
        token_score = min(raw_token_score * token_multiplier, 1.0)

        # B) EMBEDDING VARIANCE: 1 - mean pairwise cosine similarity, length-weighted
        with torch.no_grad():
            emb = model.get_input_embeddings()(ids.unsqueeze(0)).squeeze(0).float()
            n = emb.size(0)
            if n < 3:                       # 1- or 2-token prompt → no diversity
                emb_var_norm = 0.0
            else:
                # 1.  ℓ2-normalise each embedding vector
                norm_emb = torch.nn.functional.normalize(emb, p=2, dim=1)    # (n, d)

                # 2.  Full cosine-similarity matrix
                sim = torch.matmul(norm_emb, norm_emb.t())                   # (n, n)

                # 3.  Remove self-similarities (diagonal) and compute mean
                off_diag = sim[~torch.eye(n, dtype=torch.bool, device=sim.device)]  # (n²-n,)
                # 0 = identical, 1 = orthogonal (slightly > 1 if vectors are opposed)
                base_var = 1.0 - off_diag.mean().item()

                # 4.  Length factor: log-scaled, reaches 1 at length_reference tokens
                len_fac = math.log1p(n) / math.log1p(length_reference)

                # 5.  Combine, clamp, and apply the size-dependent multiplier
                raw_emb_var_norm = min(base_var * len_fac, 1.0)
                emb_var_norm = min(raw_emb_var_norm * variance_multiplier, 1.0)


        # --- Weighted combination ---------------------------------
        score = (
            weights["token_count"]        * token_score   +
            weights["embedding_variance"] * emb_var_norm
        )
        score = max(0.0, min(score, 1.0))  # clamp for safety

        if verbose:
            head = (p[:57] + "…") if len(p) > 60 else p
            print(f"{head:<60} | "
                  f"score={score:.3f}  "
                  f"| length={n_tokens} "
                  f"[tok {token_score:.3f}  var {emb_var_norm:.3f}]")

        results.append((p, round(score, 4)))

    return results
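Both versions of `analyze_prompt_complexity` measure semantic diversity as one minus the mean off-diagonal cosine similarity between the prompt's token embeddings. A dependency-free sketch of that quantity on toy vectors makes its range easier to see; the `cosine` and `embedding_diversity` helpers are illustrative, not functions from the notebook, and they omit the length scaling. The value is 0 when all embeddings point the same way, 1 when they are mutually orthogonal, and it can exceed 1 when some vectors point in opposite directions, which is one reason the scores are clamped.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def embedding_diversity(vectors):
    """1 - mean cosine similarity over all ordered pairs (i, j) with i != j.

    Mirrors the off-diagonal mean used in analyze_prompt_complexity, without
    the length scaling; fewer than 3 vectors are treated as 'no diversity'.
    """
    n = len(vectors)
    if n < 3:
        return 0.0
    sims = [cosine(vectors[i], vectors[j])
            for i in range(n) for j in range(n) if i != j]
    return 1.0 - sum(sims) / len(sims)


identical = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
orthogonal = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
opposed = [[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0]]

print(embedding_diversity(identical))   # 0.0
print(embedding_diversity(orthogonal))  # 1.0
print(embedding_diversity(opposed))     # about 1.333
```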


In [29]:
import math
import torch


def analyze_prompt_complexity(prompts, config, model, tokenizer, verbose: bool = True):
    """
    Compute a complexity score in [0, 1] for each prompt.

    Parameters
    ----------
    prompts : list[str]
        The text prompts to score.
    config : dict
        adaptive_config.json already loaded as dict.
    model : transformers.PreTrainedModel
        The HF model (on CPU or GPU).
    tokenizer : transformers.PreTrainedTokenizer
        Matching tokenizer.
    verbose : bool
        If True, print a per-prompt breakdown.

    Returns
    -------
    list[tuple[str, float]]
        (prompt, complexity_score) for each input string.
    """

    # Get model device and size
    device = next(model.parameters()).device
    total_params = sum(p.numel() for p in model.parameters())
    size_billion = total_params / 1e9

    # Prompts shorter than this get their token score dampened quadratically
    MIN_TOKENS = 10

    # Unified size adjustment factor: models below 2B get a boost (> 1),
    # larger models get dampened (< 1); the 0.5 floor is reached at 7B
    size_factor = 1.0 + (2.0 - size_billion) * 0.1
    size_factor = max(0.5, min(2.0, size_factor))  # Clamp between 0.5 and 2.0


    # Length reference scaled by model size:
    # smaller models reach max complexity with shorter prompts
    base_length = 2000
    length_reference = base_length / size_factor
    # Token count at which the embedding-variance length scale reaches 1
    variance_saturation = length_reference / 15

    # Get weights from config
    weights = config.get("complexity_weights", {
        "token_count": 0.65,
        "embedding_variance": 0.35
    })

    results = []

    for prompt in prompts:
        # Tokenize
        ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0].to(device)
        n_tokens = ids.size(0)

        # 1. TOKEN SCORE - Simple logarithmic scaling
        # Maps token count to [0, 1] with smooth growth
        token_score = math.log1p(n_tokens) / math.log1p(length_reference)
        token_score = min(token_score * size_factor, 1.0)
        if n_tokens < MIN_TOKENS:
            # Quadratic dampening for very short prompts
            dampening = (n_tokens / MIN_TOKENS) ** 2
            token_score = token_score * dampening

        # 2. EMBEDDING VARIANCE - Semantic diversity
        with torch.no_grad():
            emb = model.get_input_embeddings()(ids.unsqueeze(0)).squeeze(0).float()
            n = emb.size(0)

            if n < 3:
                # Too few tokens for meaningful variance
                emb_variance = 0.0
            else:
                # Normalize embeddings
                norm_emb = torch.nn.functional.normalize(emb, p=2, dim=1)

                # Compute pairwise cosine similarities
                sim_matrix = torch.matmul(norm_emb, norm_emb.t())

                # Get off-diagonal elements (exclude self-similarity)
                mask = ~torch.eye(n, dtype=torch.bool, device=sim_matrix.device)
                off_diag_sim = sim_matrix[mask]

                # Variance = 1 - mean similarity
                # Higher variance = more diverse embeddings
                emb_variance = 1.0 - off_diag_sim.mean().item()

                # Scale by length (longer prompts naturally have more variance)
                length_scale = min(n_tokens / variance_saturation, 1.0)
                emb_variance = emb_variance * length_scale

        # 3. FINAL SCORE - Weighted combination
        complexity_score = (
            weights["token_count"] * token_score +
            weights["embedding_variance"] * emb_variance
        )
        complexity_score = max(0.0, min(complexity_score, 1.0))

        if verbose:
            prompt_preview = (prompt[:57] + "…") if len(prompt) > 60 else prompt
            print(f"{prompt_preview:<60} | "
                  f"score={complexity_score:.3f} | "
                  f"tokens={n_tokens} "
                  f"[tok={token_score:.3f} var={emb_variance:.3f}]")

        results.append((prompt, round(complexity_score, 4)))

    return results
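In this version the token score depends only on the prompt length and the model size, so it can be reproduced without loading a model. The sketch below re-implements that component with the standard library, assuming a model of roughly 3.21B parameters (about the size of Llama-3.2-3B); `size_factor` and `token_score` are illustrative helpers, not functions from the notebook. For 2, 8 and 11 tokens it lands close to the `tok=` values printed by the next cell (0.005, 0.160, 0.282), and it shows how the quadratic dampening suppresses very short prompts.

```python
import math


def size_factor(size_billion):
    """Boost (> 1) below 2B parameters, dampen (< 1) above, clamped to [0.5, 2.0]."""
    return max(0.5, min(2.0, 1.0 + (2.0 - size_billion) * 0.1))


def token_score(n_tokens, size_billion, base_length=2000, min_tokens=10):
    """Length component of the complexity score, mirroring analyze_prompt_complexity."""
    factor = size_factor(size_billion)
    length_reference = base_length / factor  # smaller models saturate sooner
    score = math.log1p(n_tokens) / math.log1p(length_reference)
    score = min(score * factor, 1.0)
    if n_tokens < min_tokens:
        # Quadratic dampening for very short prompts
        score *= (n_tokens / min_tokens) ** 2
    return score


# Assumption: ~3.21B parameters, roughly Llama-3.2-3B
for n in (2, 8, 11, 15, 100, 1000):
    print(f"{n:>5} tokens -> token_score={token_score(n, 3.21):.3f}")
```

Because the factor is below 1 for a 3B model, the length term only saturates for prompts of several thousand tokens, whereas a 1.0B-parameter model reaches 1.0 at roughly 900 tokens.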


In [30]:
analyze_prompt_complexity(calibration_prompts, adaptive_config, model, tokenizer)

Hi                                                           | score=0.003 | tokens=2 [tok=0.005 var=0.000]
2+2=                                                         | score=0.042 | tokens=5 [tok=0.051 var=0.026]
Hello.                                                       | score=0.016 | tokens=3 [tok=0.014 var=0.021]
What is 2+2?                                                 | score=0.120 | tokens=8 [tok=0.160 var=0.046]
What is the capital of France?                               | score=0.122 | tokens=8 [tok=0.160 var=0.050]
Tell me a joke.                                              | score=0.066 | tokens=6 [tok=0.080 var=0.040]
Name the capital of Catalonia.                               | score=0.091 | tokens=7 [tok=0.116 var=0.046]
Who wrote 'To Kill a Mockingbird'?                           | score=0.209 | tokens=11 [tok=0.282 var=0.072]
Explain the basic principles of machine learning and how …   | score=0.238 | tokens=15 [tok=0.315 var=0.096]
What are the main causes o

[('Hi', 0.0032),
 ('2+2=', 0.0424),
 ('Hello.', 0.0164),
 ('What is 2+2?', 0.1199),
 ('What is the capital of France?', 0.1215),
 ('Tell me a joke.', 0.0657),
 ('Name the capital of Catalonia.', 0.0913),
 ("Who wrote 'To Kill a Mockingbird'?", 0.2088),
 ('Explain the basic principles of machine learning and how neural networks work.',
  0.2384),
 ('What are the main causes of climate change and what can individuals do to help?',
  0.2498),
 ("Summarize the plot of 'The Matrix' in one sentence.", 0.2381),
 ('List three benefits of regular exercise.', 0.1221),
 ('Compare and contrast the economic policies of Keynesian and Austrian schools of thought, analyzing their effectiveness during different historical periods and explaining which approach would be most suitable for addressing current global economic challenges.',
  0.3546),
 ('Design a comprehensive strategy for a small tech startup to compete against established giants like Google and Microsoft in the cloud computing market, consi