# Neural Plasticity in Transformer Models

This notebook walks through one complete neural plasticity cycle (prune → measure → grow → learn) on a small transformer model. The goal is a more efficient and adaptive model: underutilized attention heads are removed, and new ones are grown where they are expected to help most.

## Overview of Neural Plasticity

Neural plasticity in transformer models follows a four-stage cycle:

1. **Prune**: Remove underutilized attention heads based on metrics like entropy or magnitude
2. **Measure**: Evaluate model performance and identify areas for improvement
3. **Grow**: Strategically add new attention heads where they would be most beneficial
4. **Learn**: Fine-tune the model with differential learning rates for new heads

This cycle mimics biological neural plasticity, where neural connections are constantly being pruned and regrown based on usage patterns.

## Setup

First, let's install the required dependencies and set up the environment.

In [None]:
# Install dependencies
!pip install -q transformers jax jaxlib optax flax tqdm matplotlib

# Clone the repository
!git clone https://github.com/yourusername/sentinel-ai.git
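# Note: `!cd` runs in a temporary subshell and does not change the notebook's
# working directory, so the repository is added to sys.path in the next cell instead.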

In [None]:
# Import necessary libraries
import os
import sys
import json
import time
import pickle
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer

# Add the cloned repository to the import path
sys.path.append('sentinel-ai')

# Import Sentinel AI modules
from utils.pruning.pruning_module import PruningModule
from utils.pruning.strategies import get_strategy as get_pruning_strategy
from utils.pruning.growth import grow_attention_heads_gradually, determine_active_heads
from utils.head_lr_manager import HeadLRManager

## 1. Load and Analyze Initial Model

We'll start by loading a pre-trained transformer model and analyzing its structure.

In [None]:
# Set random seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Model selection
MODEL_NAME = "distilgpt2"  # A smaller model for faster experimentation

# Create pruning module
pruning_module = PruningModule(MODEL_NAME)

# Load model
print(f"Loading model {MODEL_NAME}...")
success = pruning_module.load_model()

if not success:
    raise RuntimeError(f"Failed to load model {MODEL_NAME}")

# Get original parameters
original_params = pruning_module.model.params

# Model structure information
print(f"Model name: {MODEL_NAME}")
print(f"Number of layers: {pruning_module.num_layers}")
print(f"Heads per layer: {pruning_module.num_heads}")
print(f"Total heads: {pruning_module.num_layers * pruning_module.num_heads}")

### Analyze Head Activity in Original Model

Let's check which attention heads are active in the original model and visualize their distribution.

In [None]:
# Determine active heads in original model
original_active_heads = determine_active_heads(pruning_module, original_params)
print(f"Original model has {len(original_active_heads)} active heads out of {pruning_module.num_layers * pruning_module.num_heads} total")

# Function to visualize head map
def visualize_head_map(pruning_module, active_heads, title="Attention Head Map"):
    """Create a visual representation of active/inactive heads"""
    num_layers = pruning_module.num_layers
    num_heads = pruning_module.num_heads
    
    # Create a matrix of active heads (1=active, 0=inactive)
    head_matrix = np.zeros((num_layers, num_heads))
    for layer_idx, head_idx in active_heads:
        head_matrix[layer_idx, head_idx] = 1
    
    # Create figure
    plt.figure(figsize=(10, 6))
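    # Fix the color range so a fully active (all-ones) matrix still maps to the "active" color
    plt.imshow(head_matrix, cmap='viridis', interpolation='none', vmin=0, vmax=1)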
    plt.title(title)
    plt.xlabel('Head Index')
    plt.ylabel('Layer Index')
    plt.colorbar(ticks=[0, 1], label='Active Status')
    plt.tight_layout()
    plt.show()

# Visualize original model's head map
visualize_head_map(pruning_module, original_active_heads, "Original Model Head Map")
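
How `determine_active_heads` decides that a head is active is internal to the repository and is not shown in this notebook. One plausible convention, assumed here purely for illustration, is that a pruned head has its weights zeroed, so a head counts as active when at least one of its weights is nonzero. The sketch below uses a hypothetical nested-list layout `weights[layer][head]` in plain Python; it is not the library's implementation.

```python
def find_active_heads(weights, tol=0.0):
    """Return (layer_idx, head_idx) pairs whose weights are not all (near) zero.

    `weights[layer][head]` is a flat list of floats. This layout is a
    hypothetical stand-in for real model parameters.
    """
    active = []
    for layer_idx, layer in enumerate(weights):
        for head_idx, head_weights in enumerate(layer):
            if any(abs(w) > tol for w in head_weights):
                active.append((layer_idx, head_idx))
    return active


# Two layers with two heads each; head (0, 1) has been zeroed out.
toy_weights = [
    [[0.2, -0.1], [0.0, 0.0]],
    [[0.5, 0.3], [-0.4, 0.1]],
]
active = find_active_heads(toy_weights)
print(active)
```

The result has the same `(layer_idx, head_idx)` shape that `visualize_head_map` iterates over.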

### Evaluate Original Model Performance

Let's establish a baseline by evaluating the original model's performance.

In [None]:
# Evaluation samples
eval_samples = [
    "The neural network model processes data through multiple layers of computation.",
    "Artificial intelligence systems can learn from experience and improve over time.",
    "The transformer architecture revolutionized natural language processing tasks.",
    "Self-attention mechanisms enable models to focus on relevant parts of the input.",
    "Neural plasticity allows models to adapt their structure during training."
]

def evaluate_model(pruning_module, params, eval_samples):
    """Evaluate model performance on sample text"""
    results = []
    perplexities = []
    
    for sample in eval_samples:
        # Calculate perplexity; NaN and infinite values are excluded from the average
        perplexity = pruning_module.evaluate_perplexity(params, sample)
        is_valid = bool(jnp.isfinite(perplexity))
        if is_valid:
            perplexities.append(float(perplexity))

        # Generate text from the first 30 characters of the sample
        prompt = sample[:30]
        generation = pruning_module.generate_text(params, prompt, max_length=100)

        results.append({
            "prompt": prompt,
            "perplexity": float(perplexity) if is_valid else None,
            "generation": generation
        })
    
    # Calculate average perplexity
    avg_perplexity = sum(perplexities) / len(perplexities) if perplexities else float('nan')
    
    return {
        "samples": results,
        "average_perplexity": float(avg_perplexity),
        "perplexities": [float(p) for p in perplexities]
    }

# Evaluate original model
print("Evaluating original model...")
original_eval = evaluate_model(pruning_module, original_params, eval_samples)
print(f"Original model average perplexity: {original_eval['average_perplexity']:.4f}")

# Show sample generations
print("\nSample generations from original model:")
for i, sample in enumerate(original_eval['samples'][:2]):
    print(f"\nPrompt {i+1}: {sample['prompt']}")
    print(f"Generation: {sample['generation']}")
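
`evaluate_perplexity` is provided by the repository, and its internals are not shown here. For orientation, perplexity is conventionally defined as the exponential of the mean negative log-likelihood per token. The sketch below computes it in plain Python from a list of per-token probabilities; the function name and its input format are illustrative assumptions.

```python
import math


def perplexity_from_token_probs(token_probs):
    """Perplexity = exp(mean negative log-probability of the observed tokens)."""
    if not token_probs:
        raise ValueError("need at least one token probability")
    negative_log_likelihoods = [-math.log(p) for p in token_probs]
    return math.exp(sum(negative_log_likelihoods) / len(negative_log_likelihoods))


# A model that spreads probability evenly over 4 choices at every step
uniform_ppl = perplexity_from_token_probs([0.25, 0.25, 0.25, 0.25])

# A model that is fairly sure of each token it sees
confident_ppl = perplexity_from_token_probs([0.9, 0.8, 0.95])

print(f"uniform: {uniform_ppl:.2f}, confident: {confident_ppl:.2f}")
```

A model that assigns probability 1/4 to every observed token has a perplexity of 4, which is why lower values indicate a better fit to the text.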

## 2. Prune the Model

Now, let's prune the model by removing less important attention heads.

In [None]:
# Pruning parameters
PRUNING_LEVEL = 0.3  # Remove 30% of heads
PRUNING_STRATEGY = "entropy"  # Options: "random", "magnitude", "entropy"

def prune_model(pruning_module, params, pruning_level, strategy_name):
    """Prune the model using specified strategy and level"""
    # Get pruning strategy
    strategy = get_pruning_strategy(strategy_name, pruning_module)
    
    # Calculate importance scores for all heads
    head_importance = strategy.get_head_importance(params)
    
    # Sort by importance (ascending, so least important first) without mutating the input
    ranked_heads = sorted(head_importance, key=lambda x: x[2])

    # Calculate total heads and number to prune
    total_heads = pruning_module.num_layers * pruning_module.num_heads
    num_to_prune = int(total_heads * pruning_level)

    # Select heads to prune (least important first)
    heads_to_prune = [(layer_idx, head_idx) for layer_idx, head_idx, _ in ranked_heads[:num_to_prune]]

    # Prune the selected heads
    pruned_params = params.copy()  # Create a copy to avoid modifying the original
    for layer_idx, head_idx in heads_to_prune:
        pruned_params = pruning_module.prune_head(pruned_params, layer_idx, head_idx)

    return pruned_params, heads_to_prune

# Prune the model
print(f"Pruning model with {PRUNING_STRATEGY} strategy at {PRUNING_LEVEL*100:.1f}% level...")
pruned_params, pruned_heads = prune_model(pruning_module, original_params, PRUNING_LEVEL, PRUNING_STRATEGY)

# Get active heads in pruned model
pruned_active_heads = determine_active_heads(pruning_module, pruned_params)
print(f"Pruned {len(pruned_heads)} heads, {len(pruned_active_heads)} active heads remaining")

# Visualize pruned model's head map
visualize_head_map(pruning_module, pruned_active_heads, "Pruned Model Head Map")
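
The `entropy` strategy itself lives in the repository. As a rough illustration of the idea, the sketch below scores each head by the average Shannon entropy of its attention rows and prunes the highest-entropy heads first, on the assumption that diffuse attention marks a less useful head. That sign convention, the function names, and the toy attention patterns are all assumptions made for this example.

```python
import math


def row_entropy(probs):
    """Shannon entropy (in nats) of one attention distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def mean_attention_entropy(attention_rows):
    """Average entropy over all query positions of one head."""
    return sum(row_entropy(row) for row in attention_rows) / len(attention_rows)


def select_heads_to_prune(head_attention, pruning_level):
    """head_attention maps (layer_idx, head_idx) -> list of attention rows."""
    # Importance is the negative entropy, so an ascending sort puts the
    # most diffuse (assumed least important) heads first.
    scored = [
        (layer_idx, head_idx, -mean_attention_entropy(rows))
        for (layer_idx, head_idx), rows in head_attention.items()
    ]
    scored.sort(key=lambda x: x[2])
    num_to_prune = int(len(scored) * pruning_level)
    return [(layer_idx, head_idx) for layer_idx, head_idx, _ in scored[:num_to_prune]]


toy_attention = {
    (0, 0): [[0.97, 0.01, 0.01, 0.01], [0.01, 0.97, 0.01, 0.01]],  # sharply focused
    (0, 1): [[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]],  # uniform
    (1, 0): [[0.70, 0.10, 0.10, 0.10], [0.10, 0.70, 0.10, 0.10]],  # moderately focused
    (1, 1): [[0.40, 0.30, 0.20, 0.10], [0.10, 0.20, 0.30, 0.40]],  # fairly diffuse
}
to_prune = select_heads_to_prune(toy_attention, pruning_level=0.5)
print(to_prune)
```

The selection step mirrors `prune_model`: sort ascending by importance, then take the first `int(total_heads * pruning_level)` entries.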

### Evaluate Pruned Model Performance

Let's see how pruning affected the model's performance.

In [None]:
# Evaluate pruned model
print("Evaluating pruned model...")
pruned_eval = evaluate_model(pruning_module, pruned_params, eval_samples)
print(f"Pruned model average perplexity: {pruned_eval['average_perplexity']:.4f}")

# Show sample generations
print("\nSample generations from pruned model:")
for i, sample in enumerate(pruned_eval['samples'][:2]):
    print(f"\nPrompt {i+1}: {sample['prompt']}")
    print(f"Generation: {sample['generation']}")

# Compare perplexities
perplexity_change = pruned_eval['average_perplexity'] - original_eval['average_perplexity']
print(f"\nPerplexity change after pruning: {perplexity_change:+.4f} ({perplexity_change/original_eval['average_perplexity']*100:+.2f}%)")

## 3. Grow New Heads

Now let's strategically grow new attention heads where they would be most beneficial.

In [None]:
# Growth parameters
GROWTH_PERCENTAGE = 0.1  # Add back 10% of total heads
GROWTH_STRATEGY = "gradient_sensitivity"  # Options: "gradient_sensitivity", "entropy_gap", "balanced", "random"
INITIAL_SCALE = 0.01  # Initial small weight scale for new heads

# Grow new heads
print(f"Growing heads with {GROWTH_STRATEGY} strategy at {GROWTH_PERCENTAGE*100:.1f}% level...")
grown_params, added_count, added_heads, warmup_schedule = grow_attention_heads_gradually(
    pruning_module,
    params=pruned_params,
    active_heads=pruned_active_heads,
    growth_percentage=GROWTH_PERCENTAGE,
    strategy=GROWTH_STRATEGY,
    initial_scale=INITIAL_SCALE
)

# Get active heads in grown model
grown_active_heads = determine_active_heads(pruning_module, grown_params)
print(f"Added {added_count} heads, now have {len(grown_active_heads)} active heads")

# Visualize grown model's head map
visualize_head_map(pruning_module, grown_active_heads, "Grown Model Head Map")

# Highlight newly added heads
num_layers = pruning_module.num_layers
num_heads = pruning_module.num_heads

# Create a matrix: 0=inactive, 1=active (existing), 2=active (newly added)
head_matrix = np.zeros((num_layers, num_heads))
for layer_idx, head_idx in grown_active_heads:
    if (layer_idx, head_idx) in pruned_active_heads:
        head_matrix[layer_idx, head_idx] = 1  # Existing head
    else:
        head_matrix[layer_idx, head_idx] = 2  # Newly added head

# Create custom colormap: inactive=white, existing=blue, new=red
from matplotlib.colors import ListedColormap
cmap = ListedColormap(['#f0f0f0', '#4363d8', '#e6194B'])

# Create figure
plt.figure(figsize=(10, 6))
plt.imshow(head_matrix, cmap=cmap, interpolation='none', vmin=0, vmax=2)
plt.title("Head Growth Map")
plt.xlabel('Head Index')
plt.ylabel('Layer Index')
cbar = plt.colorbar(ticks=[0, 1, 2])
cbar.set_ticklabels(['Inactive', 'Existing', 'Newly Added'])
plt.tight_layout()
plt.show()
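
The growth strategies named above are implemented in the repository. To illustrate the general shape of such a selection step, the sketch below ranks the currently inactive head slots by a per-slot score and re-enables the top few. The score dictionary stands in for whatever signal a strategy such as `gradient_sensitivity` computes; the function name, the scores, and the rounding rule are assumptions.

```python
def select_heads_to_grow(active_heads, scores, num_layers, num_heads, growth_percentage):
    """Pick the inactive (layer_idx, head_idx) slots with the highest score.

    `scores` maps slots to a hypothetical benefit estimate; missing slots score 0.
    """
    active = set(active_heads)
    inactive = [
        (layer_idx, head_idx)
        for layer_idx in range(num_layers)
        for head_idx in range(num_heads)
        if (layer_idx, head_idx) not in active
    ]
    num_to_grow = int(num_layers * num_heads * growth_percentage)
    ranked = sorted(inactive, key=lambda slot: scores.get(slot, 0.0), reverse=True)
    return ranked[:num_to_grow]


# Toy model: 2 layers x 4 heads, with 4 heads still active after pruning
toy_active = [(0, 0), (0, 1), (1, 0), (1, 2)]
toy_scores = {(0, 2): 0.10, (0, 3): 0.70, (1, 1): 0.40, (1, 3): 0.05}
to_grow = select_heads_to_grow(toy_active, toy_scores, num_layers=2, num_heads=4,
                               growth_percentage=0.25)
print(to_grow)
```

Only inactive slots are candidates, so the existing heads are never duplicated or overwritten by this step.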

### Evaluate Model with New Heads (Initial State)

Let's evaluate the model with newly added heads, before they've been properly trained.

In [None]:
# Evaluate grown model with initial scaling
print("Evaluating grown model (with initial scaling)...")
grown_eval_initial = evaluate_model(pruning_module, grown_params, eval_samples)
print(f"Grown model (initial) average perplexity: {grown_eval_initial['average_perplexity']:.4f}")

# Show sample generations
print("\nSample generations from grown model (initial):")
for i, sample in enumerate(grown_eval_initial['samples'][:2]):
    print(f"\nPrompt {i+1}: {sample['prompt']}")
    print(f"Generation: {sample['generation']}")
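
New heads start at `INITIAL_SCALE`, so at this point they contribute very little to the output. `grow_attention_heads_gradually` also returns a `warmup_schedule`, whose exact form is not shown in this notebook. One common way to integrate new heads without disrupting the existing model is to ramp their scale from the initial value up to 1.0 over a fixed number of steps. The sketch below assumes a linear ramp; the function name and the `warmup_steps` parameter are hypothetical.

```python
def linear_warmup_scale(step, warmup_steps, initial_scale=0.01):
    """Scale applied to a new head's output at a given training step."""
    if warmup_steps <= 0:
        return 1.0
    # Clamp progress to [0, 1] so the scale never exceeds 1.0
    progress = min(max(step / warmup_steps, 0.0), 1.0)
    return initial_scale + (1.0 - initial_scale) * progress


scales = [linear_warmup_scale(step, warmup_steps=100) for step in (0, 50, 100, 200)]
print([round(s, 3) for s in scales])
```

Starting small keeps the grown model's predictions close to the pruned model's, which is why the initial perplexity is expected to be similar to the pruned figure.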

## 4. Learn: Adapt the New Heads

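Now let's simulate the learning process for the new heads. In a real implementation, this stage would fine-tune the model with differential learning rates, giving the new heads a higher rate than the established ones. The placeholder in this notebook only evaluates the model at regular intervals and never updates its parameters, so the resulting learning curve is expected to be flat.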

In [None]:
# Learning parameters
LEARNING_STEPS = 50  # In a real scenario, this would be much higher
LEARNING_RATE = 5e-5
NEW_HEAD_LR_MULTIPLIER = 5.0  # Higher learning rate for new heads

def simulate_learning(pruning_module, params, active_heads, added_heads,
                      learning_steps=100, learning_rate=5e-5, head_lr_multiplier=5.0,
                      eval_samples=None):
    """Placeholder for the learning phase after head growth.

    No parameters are updated here. active_heads, added_heads, learning_rate and
    head_lr_multiplier are accepted only so the signature matches a real
    training loop, and the returned learning curve is expected to be flat.
    """
    # Create evaluation data if not provided
    if eval_samples is None:
        eval_samples = [
            "The neural network model processes data through multiple layers of computation.",
            "Artificial intelligence systems can learn from experience and improve over time."
        ]

    # Track learning progress
    learning_curve = []
    current_params = params.copy()

    # Evaluate roughly five times during the run; max(1, ...) avoids a
    # modulo-by-zero error when learning_steps is smaller than 5
    eval_interval = max(1, learning_steps // 5)

    for step in tqdm(range(learning_steps)):
        # A real implementation would apply a gradient update to current_params here
        if step % eval_interval == 0 or step == learning_steps - 1:
            # Evaluate at regular intervals
            eval_result = evaluate_model(pruning_module, current_params, eval_samples)
            learning_curve.append({
                "step": step,
                "perplexity": eval_result["average_perplexity"]
            })

    # Final evaluation
    final_eval = evaluate_model(pruning_module, current_params, eval_samples)

    return current_params, learning_curve, final_eval

# Simulate learning process
print(f"Simulating learning process with {LEARNING_STEPS} steps...")
learned_params, learning_curve, learned_eval = simulate_learning(
    pruning_module,
    grown_params,
    grown_active_heads,
    added_heads,
    learning_steps=LEARNING_STEPS,
    learning_rate=LEARNING_RATE,
    head_lr_multiplier=NEW_HEAD_LR_MULTIPLIER,
    eval_samples=eval_samples
)

# Plot learning curve
plt.figure(figsize=(10, 6))
plt.plot(
    [point["step"] for point in learning_curve],
    [point["perplexity"] for point in learning_curve],
    'o-'
)
plt.grid(True, linestyle='--', alpha=0.7)
plt.title("Learning Curve After Head Growth")
plt.xlabel('Step')
plt.ylabel('Perplexity')
plt.show()
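
The simulation above applies no updates. In a real training loop, differential learning rates mean that each parameter group gets its own step size. The sketch below shows a single plain-SGD step in pure Python in which parameters belonging to newly added heads use `learning_rate * head_lr_multiplier`. The flat dictionary keyed by `(layer_idx, head_idx)` and the function name are assumptions for illustration and do not reflect the repository's `HeadLRManager` API.

```python
def sgd_step_with_head_lrs(params, grads, new_heads,
                           learning_rate=5e-5, head_lr_multiplier=5.0):
    """Return updated params; heads listed in `new_heads` use a scaled learning rate.

    params and grads map (layer_idx, head_idx) -> list of floats (hypothetical layout).
    """
    new_heads = set(new_heads)
    updated = {}
    for head, weights in params.items():
        lr = learning_rate * head_lr_multiplier if head in new_heads else learning_rate
        updated[head] = [w - lr * g for w, g in zip(weights, grads[head])]
    return updated


# Identical weights and gradients, but head (0, 1) is newly added
toy_params = {(0, 0): [1.0, 1.0], (0, 1): [1.0, 1.0]}
toy_grads = {(0, 0): [2.0, -2.0], (0, 1): [2.0, -2.0]}
updated = sgd_step_with_head_lrs(toy_params, toy_grads, new_heads=[(0, 1)],
                                 learning_rate=0.1, head_lr_multiplier=5.0)
print(updated)
```

With the same gradient, the new head moves five times as far as the established one. In a JAX and Optax setup the same idea is usually expressed by labeling parameter groups and giving each label its own optimizer, for example with `optax.multi_transform`.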

### Evaluate Final Model Performance

Let's evaluate the final model after the learning phase.

In [None]:
# Evaluate final model
print("Evaluating final model after learning...")
final_eval = evaluate_model(pruning_module, learned_params, eval_samples)
print(f"Final model average perplexity: {final_eval['average_perplexity']:.4f}")

# Show sample generations
print("\nSample generations from final model:")
for i, sample in enumerate(final_eval['samples'][:2]):
    print(f"\nPrompt {i+1}: {sample['prompt']}")
    print(f"Generation: {sample['generation']}")

## 5. Compare Results Across All Stages

Let's visualize how the model's performance changed across the entire neural plasticity cycle.

In [None]:
# Collect metrics from all stages
metrics = {
    "perplexity": {
        "Original": original_eval["average_perplexity"],
        "Pruned": pruned_eval["average_perplexity"],
        "Grown (Initial)": grown_eval_initial["average_perplexity"],
        "Grown (Final)": final_eval["average_perplexity"]
    },
    "active_heads": {
        "Original": len(original_active_heads),
        "Pruned": len(pruned_active_heads),
        "Grown (Initial)": len(grown_active_heads),
        "Grown (Final)": len(grown_active_heads)
    }
}

# Create bar charts for metrics
fig, axes = plt.subplots(2, 1, figsize=(12, 10))

# Perplexity chart (lower is better)
stages = list(metrics["perplexity"].keys())
perplexities = list(metrics["perplexity"].values())
bars = axes[0].bar(stages, perplexities, color=['#3274A1', '#E1812C', '#3A923A', '#C03D3E'])
axes[0].set_title('Perplexity Across Stages (Lower is Better)')
axes[0].set_ylabel('Perplexity')
axes[0].grid(axis='y', linestyle='--', alpha=0.7)

# Add value labels
for bar in bars:
    height = bar.get_height()
    axes[0].text(bar.get_x() + bar.get_width()/2., height + 0.1,
                f'{height:.2f}', ha='center', va='bottom')

# Active heads chart
active_heads = list(metrics["active_heads"].values())
bars = axes[1].bar(stages, active_heads, color=['#3274A1', '#E1812C', '#3A923A', '#C03D3E'])
axes[1].set_title('Active Attention Heads Across Stages')
axes[1].set_ylabel('Number of Active Heads')
axes[1].grid(axis='y', linestyle='--', alpha=0.7)

# Add value labels and percentage of original
original_heads = metrics["active_heads"]["Original"]
for bar in bars:
    height = bar.get_height()
    percentage = (height / original_heads) * 100
    axes[1].text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{int(height)} ({percentage:.1f}%)', ha='center', va='bottom')

plt.tight_layout()
plt.show()

## 6. Summary of Neural Plasticity Cycle

Let's summarize the results of our neural plasticity experiment.

In [None]:
# Calculate percentage changes
original_ppl = original_eval["average_perplexity"]
pruned_ppl = pruned_eval["average_perplexity"]
final_ppl = final_eval["average_perplexity"]

pruned_ppl_change = ((pruned_ppl / original_ppl) - 1) * 100
final_ppl_change = ((final_ppl / original_ppl) - 1) * 100

original_heads = len(original_active_heads)
final_heads = len(grown_active_heads)
head_reduction = ((original_heads - final_heads) / original_heads) * 100

print(f"Neural Plasticity Cycle Summary for {MODEL_NAME}")
print(f"======================================================")
print(f"Pruning Strategy: {PRUNING_STRATEGY}, Level: {PRUNING_LEVEL*100:.1f}%")
print(f"Growth Strategy: {GROWTH_STRATEGY}, Level: {GROWTH_PERCENTAGE*100:.1f}%")
print(f"\nHeads:")
print(f"- Original: {original_heads}")
print(f"- Pruned: {len(pruned_active_heads)} ({len(pruned_active_heads)/original_heads*100:.1f}% of original)")
print(f"- Final: {final_heads} ({final_heads/original_heads*100:.1f}% of original)")
print(f"- Net reduction: {head_reduction:.1f}%")
print(f"\nPerplexity:")
print(f"- Original: {original_ppl:.4f}")
print(f"- Pruned: {pruned_ppl:.4f} ({pruned_ppl_change:+.2f}% change)")
print(f"- Final: {final_ppl:.4f} ({final_ppl_change:+.2f}% change)")
print(f"\nConclusion:")

if final_ppl <= original_ppl * 1.05 and head_reduction > 10:  # Allow 5% perplexity increase
    print(f"SUCCESS! Achieved {head_reduction:.1f}% head reduction with minimal performance impact.")
elif final_ppl <= original_ppl * 1.1:  # Allow 10% perplexity increase
    print(f"PARTIAL SUCCESS. Achieved {head_reduction:.1f}% head reduction with acceptable performance trade-off.")
else:
    print(f"MIXED RESULTS. Head reduction of {head_reduction:.1f}% came with significant performance cost.")

print(f"\nThis experiment demonstrates the neural plasticity cycle, showing how models")
print(f"can be made more efficient through strategic pruning and targeted regrowth.")
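
To make the summary arithmetic concrete, the sketch below works through the head counts for the settings used in this notebook. It assumes that all 72 heads of `distilgpt2` (6 layers × 12 heads) start out active and that both the pruning and growth counts are computed as `int(total_heads * fraction)`. The growth function may round differently, so treat the numbers as indicative rather than as the notebook's actual output.

```python
def head_budget(total_heads, pruning_level, growth_percentage):
    """Return (pruned, grown, final, net_reduction_percent) for one cycle."""
    pruned = int(total_heads * pruning_level)
    grown = int(total_heads * growth_percentage)  # assumed rounding rule
    final = total_heads - pruned + grown
    reduction_pct = (total_heads - final) / total_heads * 100
    return pruned, grown, final, reduction_pct


pruned, grown, final, reduction_pct = head_budget(72, pruning_level=0.3,
                                                  growth_percentage=0.1)
print(f"pruned={pruned}, grown={grown}, final={final}, "
      f"net reduction={reduction_pct:.1f}%")
```

Under these assumptions the net reduction is about 19.4%, which clears the 10% threshold in the verdict logic, so the outcome depends on whether the final perplexity stays within 5% of the original.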

## 7. Further Experiments and Extensions

There are many ways to extend and enhance this neural plasticity approach:

1. **Iterative Cycles**: Run multiple pruning-growth cycles to progressively refine the model
2. **Different Strategies**: Compare various pruning and growth strategies
3. **Task Adaptation**: Use neural plasticity to adapt models to specific tasks
4. **Larger Models**: Apply this approach to larger models for greater efficiency gains
5. **Differential Learning Rates**: Implement true differential learning rates for new heads
6. **U-Net Skip Connections**: Add skip connections to help new heads learn from similar positions

These extensions could further improve efficiency, performance, and adaptability.
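
The first extension, iterative cycles, can be sketched as a driver loop that repeats the cycle while a lower-is-better metric stays within a tolerance of the baseline. Everything here is a hypothetical scaffold: `cycle_fn` and `measure_fn` are caller-supplied stand-ins for one prune/grow/learn pass and a perplexity evaluation, and the toy functions exist only to exercise the control flow.

```python
def run_plasticity_cycles(state, cycle_fn, measure_fn, max_cycles=5, tolerance=0.05):
    """Repeat cycle_fn while the metric stays within `tolerance` of the baseline.

    cycle_fn(state) -> new state after one prune/grow/learn pass.
    measure_fn(state) -> metric where lower is better, e.g. perplexity.
    """
    baseline = measure_fn(state)
    history = [baseline]
    for _ in range(max_cycles):
        candidate = cycle_fn(state)
        metric = measure_fn(candidate)
        if metric > baseline * (1.0 + tolerance):
            break  # Reject the candidate and keep the last acceptable state
        state = candidate
        history.append(metric)
    return state, history


# Toy stand-ins: the state is just the number of active heads.
def toy_cycle(heads):
    return heads - 10  # each cycle nets a loss of 10 heads


def toy_perplexity(heads):
    return 20.0 + (72 - heads) * 0.03  # perplexity creeps up as heads are removed


final_heads, history = run_plasticity_cycles(72, toy_cycle, toy_perplexity)
print(final_heads, [round(m, 2) for m in history])
```

The loop keeps the last state that met the tolerance, which matches the 5% perplexity allowance used in the summary verdict.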

## Conclusion

This notebook has demonstrated the complete neural plasticity cycle:

1. **Prune**: We removed underutilized attention heads
2. **Measure**: We evaluated performance after pruning
3. **Grow**: We strategically added new heads where they would be most beneficial
4. **Learn**: We simulated the learning process for the new heads

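Because the learning stage here is simulated rather than trained, the perplexity figures reported above describe how the pruned and regrown architecture behaves before fine-tuning, not the performance a fully trained model would reach. With real fine-tuning in place, the same cycle offers a way to make transformer models more efficient and to let them reorganize their architecture as task demands change.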