# Neural Plasticity Demo: Dynamic Pruning & Regrowth (v0.0.5)

This notebook demonstrates Sentinel AI's neural plasticity system, which allows transformer models to dynamically prune and regrow attention heads during training based on utility metrics.

## What is Neural Plasticity?

Neural plasticity is the ability of neural networks to adapt their structure over time through pruning (removing unused connections) and regrowth (restoring useful connections). This mimics how biological brains form efficient neural pathways.

In this demo, we:
1. Track the entropy and gradient patterns of each attention head
2. Dynamically prune high-entropy, low-gradient heads (unfocused, less useful)
3. Selectively revive low-entropy, higher-gradient heads (potentially useful)
4. Visualize the "brain dynamics" over time

This allows models to form more efficient neural structures during training.
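
As a rough illustration of the first step, a head's entropy can be computed from its attention rows (one probability distribution over keys per query position) and averaged. The sketch below is plain Python so it runs on its own. The names `attention_entropy` and `head_entropy` are hypothetical, and dividing by `log(n)` to normalize into the range 0 to 1 is an assumption, not necessarily what Sentinel AI's controller does.

```python
import math

def attention_entropy(probs, normalize=True):
    """Shannon entropy (in nats) of one attention row."""
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    if normalize and len(probs) > 1:
        # Assumption: scale by the maximum possible entropy, log(n)
        entropy /= math.log(len(probs))
    return entropy

def head_entropy(attention_rows, normalize=True):
    """Mean entropy over all query positions of a single head."""
    total = sum(attention_entropy(row, normalize) for row in attention_rows)
    return total / len(attention_rows)

focused = [[0.97, 0.01, 0.01, 0.01]] * 3  # attends almost entirely to one token
diffuse = [[0.25, 0.25, 0.25, 0.25]] * 3  # attends uniformly to every token

focused_h = head_entropy(focused)
diffuse_h = head_entropy(diffuse)
print(f"focused head: {focused_h:.3f}, diffuse head: {diffuse_h:.3f}")
```

Under this normalization a uniformly attending head scores close to 1 and a sharply focused head scores close to 0, which is the sense in which high-entropy heads are called "unfocused" above.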

In [None]:
# Install required packages
!pip install -q torch transformers datasets matplotlib seaborn

# Clone the Sentinel AI repository
!git clone -b feature/implement-adaptive-plasticity https://github.com/CambrianTech/sentinel-ai.git
%cd sentinel-ai

# Add repository to path
import sys
sys.path.append('.')

## Configure the Experiment

Let's set up our configuration for the neural plasticity experiment.

In [ ]:
# Configure experiment
MODEL_NAME = "distilgpt2"  # Small GPT-2 model for faster demonstration
DATASET = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
MAX_LENGTH = 128
BATCH_SIZE = 4
NUM_EPOCHS = 100
LEARNING_RATE = 5e-5
WARMUP_STEPS = 100    # Learning-rate warmup steps for the scheduler
WARMUP_EPOCHS = 1     # Number of epochs in the warm-up phase before the baseline
EVAL_INTERVAL = 50    # Evaluate and apply plasticity every 50 steps

# Configure pruning mode
from sentinel.pruning.dual_mode_pruning import PruningMode

# Set pruning mode (ADAPTIVE allows recovery, COMPRESSED prevents recovery)
PRUNING_MODE = PruningMode.ADAPTIVE  # Change to PruningMode.COMPRESSED for permanent pruning

# Configure plasticity thresholds (more aggressive than v0.0.4; see Version History)
HIGH_ENTROPY_THRESHOLD = 0.4  # Heads with entropy above this are candidates for pruning (lowered from 0.6)
LOW_ENTROPY_THRESHOLD = 0.2   # Pruned heads with entropy below this are candidates for revival (lowered from 0.3)
GRAD_THRESHOLD = 1e-3         # Heads with gradient norm below this are candidates for pruning (raised from 5e-5)
MIN_ZERO_EPOCHS = 1           # Minimum epochs a head should remain pruned

## Load Model and Dataset

Now we'll load the model and prepare the dataset for training.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    default_data_collator,
    get_linear_schedule_with_warmup
)
from torch.utils.data import DataLoader
from datasets import load_dataset
from sentinel.pruning.plasticity_controller import create_plasticity_controller

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load model and tokenizer
print(f"Loading model: {MODEL_NAME}")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set pad token if needed
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load datasets
print(f"Loading dataset: {DATASET}/{DATASET_CONFIG}")
train_dataset = load_dataset(DATASET, DATASET_CONFIG, split="train")
validation_dataset = load_dataset(DATASET, DATASET_CONFIG, split="validation")

# Define tokenization function
def tokenize_function(examples):
    return tokenizer(
        examples["text"], 
        padding="max_length", 
        truncation=True, 
        max_length=MAX_LENGTH
    )

# Tokenize datasets
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
validation_dataset = validation_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

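# Drop empty lines, which tokenize to padding only
def has_tokens(example):
    return sum(example["attention_mask"]) > 0

# Add labels for language modeling; padded positions get -100 so the loss ignores them
def add_labels(example):
    example["labels"] = [
        token if mask == 1 else -100
        for token, mask in zip(example["input_ids"], example["attention_mask"])
    ]
    return example

train_dataset = train_dataset.filter(has_tokens).map(add_labels)
validation_dataset = validation_dataset.filter(has_tokens).map(add_labels)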

# Set format
train_dataset = train_dataset.with_format("torch")
validation_dataset = validation_dataset.with_format("torch")

# Create dataloaders
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=default_data_collator
)

validation_dataloader = DataLoader(
    validation_dataset, 
    batch_size=BATCH_SIZE, 
    collate_fn=default_data_collator
)

print(f"Train dataset size: {len(train_dataset)} examples")
print(f"Validation dataset size: {len(validation_dataset)} examples")

## Define Evaluation Function

Let's define a function to evaluate our model's performance.
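
The evaluation function in this section reports perplexity as the exponential of the average cross-entropy loss. As a minimal, framework-free sketch of that relationship (the helper name `perplexity_from_losses` is hypothetical), a model that guesses uniformly over the GPT-2 vocabulary has loss `log(V)` and therefore perplexity `V`:

```python
import math

def perplexity_from_losses(batch_losses):
    """Perplexity as exp of the mean per-batch cross-entropy loss (in nats)."""
    if not batch_losses:
        return float("inf")
    return math.exp(sum(batch_losses) / len(batch_losses))

# A uniform guess over V tokens has cross-entropy log(V), so its perplexity is V.
vocab_size = 50257  # GPT-2 vocabulary size
uniform_ppl = perplexity_from_losses([math.log(vocab_size)] * 10)
print(f"uniform-guess perplexity: {uniform_ppl:.1f}")
```

Averaging per-batch mean losses matches the token-level mean only when every batch contributes the same number of scored tokens, so treat the reported perplexity as an approximation.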

In [None]:
def evaluate_model(model, dataloader):
    """Evaluate model on the provided dataloader."""
    model.eval()
    total_loss = 0.0
    total_steps = 0
    
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            
            # Forward pass
            outputs = model(**batch)
            loss = outputs.loss
            
            total_loss += loss.item()
            total_steps += 1
            
            # Limit evaluation to 10 steps for speed
            if total_steps >= 10:
                break
    
    avg_loss = total_loss / total_steps if total_steps > 0 else float("inf")
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    
    return avg_loss, perplexity

def generate_text(prompt, max_length=100):
    """Generate text from the model."""
    # Set model to evaluation mode
    model.eval()
    
    # Encode prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    # Generate text
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode and return text
    return tokenizer.decode(output[0], skip_special_tokens=True)

## Run Model Warm-up

Before measuring baseline performance, we'll run a brief warm-up phase to stabilize the model parameters.
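
The warm-up loop uses a linear schedule with warmup. The sketch below reproduces the learning-rate multiplier that such a schedule applies, as far as we understand it: a linear ramp from 0 to 1 over the warmup steps, followed by a linear decay to 0. The function name `linear_warmup_decay` and the `total_steps` value are illustrative assumptions; the notebook itself derives the total from `len(train_dataloader) * WARMUP_EPOCHS`.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """LR multiplier: linear ramp from 0 to 1, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

peak_lr = 5e-5       # LEARNING_RATE in this notebook
warmup_steps = 100   # WARMUP_STEPS in this notebook
total_steps = 9000   # hypothetical value, for illustration only

for step in (0, 50, 100, 4550, 9000):
    lr = peak_lr * linear_warmup_decay(step, warmup_steps, total_steps)
    print(f"step {step:>5}: lr = {lr:.2e}")
```

One consequence worth noting: because the warm-up loop stops after about 50 steps while `WARMUP_STEPS` is 100, the learning rate only climbs to roughly half of `LEARNING_RATE` during this phase.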

In [None]:
# Initialize optimizer and scheduler for warm-up
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_dataloader) * WARMUP_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=WARMUP_STEPS, 
    num_training_steps=total_steps
)

print(f"Running warm-up for {WARMUP_EPOCHS} epoch(s)...")

# Warm-up training loop
model.train()
warmup_losses = []

for epoch in range(WARMUP_EPOCHS):
    epoch_loss = 0.0
    epoch_steps = 0
    
    for step, batch in enumerate(train_dataloader):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        
        # Backward pass
        loss.backward()
        
        # Update weights
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
        # Track loss
        epoch_loss += loss.item()
        epoch_steps += 1
        
        # Print progress every 10 steps
        if step % 10 == 0:
            warmup_losses.append(loss.item())
            print(f"Warm-up Epoch {epoch+1}, Step {step}: Loss = {loss.item():.4f}\r", end="")
            
        # Stop after 50 steps for faster execution in demo
        if step + 1 >= 50:
            break
    
    print(f"\nWarm-up Epoch {epoch+1} completed: Average Loss = {epoch_loss / epoch_steps:.4f}")

# Plot warm-up loss
plt.figure(figsize=(10, 5))
plt.plot(range(0, 10 * len(warmup_losses), 10), warmup_losses)
plt.title("Warm-up Loss")
plt.xlabel("Step (loss logged every 10 steps)")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

## Evaluate Baseline Model

Now let's measure the baseline performance after warm-up.

In [None]:
# Evaluate baseline model after warm-up
baseline_loss, baseline_perplexity = evaluate_model(model, validation_dataloader)
print(f"Baseline evaluation after warm-up: Loss = {baseline_loss:.4f}, Perplexity = {baseline_perplexity:.2f}")

# Generate text with baseline model
prompt = "Once upon a time"
baseline_text = generate_text(prompt)
print(f"\nPrompt: {prompt}")
print(f"Generated text:\n{baseline_text}")

## Create Neural Plasticity Controller

Now we'll create the plasticity controller that will monitor head metrics and dynamically prune/revive heads during training.
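
To make the thresholds concrete, here is a minimal sketch of a per-head decision rule. It is not the Sentinel AI controller's actual logic. The pruning condition follows the rule used in this notebook (entropy above the high threshold and gradient norm below the gradient threshold), while the revival condition, the `HeadState` class, and the `decide` function are assumptions made for illustration.

```python
from dataclasses import dataclass

@dataclass
class HeadState:
    pruned: bool = False
    epochs_pruned: int = 0

def decide(state, entropy, grad_norm, high_entropy=0.4, low_entropy=0.2,
           grad_threshold=1e-3, min_zero_epochs=1):
    """Return 'prune', 'revive', or 'keep' for one head (hypothetical rule)."""
    if not state.pruned:
        # Unfocused head that is also barely learning: prune it
        if entropy > high_entropy and grad_norm < grad_threshold:
            return "prune"
        return "keep"
    # Assumption: a pruned head must stay pruned for min_zero_epochs, and is
    # revived only if it looks focused and has a non-trivial gradient
    if (state.epochs_pruned >= min_zero_epochs
            and entropy < low_entropy
            and grad_norm >= grad_threshold):
        return "revive"
    return "keep"

print(decide(HeadState(), entropy=0.7, grad_norm=1e-4))                # prune
print(decide(HeadState(), entropy=0.7, grad_norm=5e-3))                # keep
print(decide(HeadState(True, 2), entropy=0.1, grad_norm=5e-3))         # revive
print(decide(HeadState(True, 0), entropy=0.1, grad_norm=5e-3))         # keep
```

Keeping the two entropy thresholds apart (0.4 versus 0.2 here) gives the rule some hysteresis, so a head hovering near one threshold is not pruned and revived on alternating steps.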

In [ ]:
# Create plasticity controller
controller = create_plasticity_controller(
    model=model,
    mode=PRUNING_MODE,
    high_entropy_threshold=HIGH_ENTROPY_THRESHOLD,
    low_entropy_threshold=LOW_ENTROPY_THRESHOLD,
    grad_threshold=GRAD_THRESHOLD,
    min_zero_epochs=MIN_ZERO_EPOCHS
)

# Display initial model stats
initial_stats = controller.get_summary()
print(f"Model has {initial_stats['total_heads']} attention heads across {controller.total_layers} layers")

# Debug: Let's check the actual entropy values we're dealing with
print("\nCollecting initial entropy and gradient metrics for debugging...")
debug_entropy, debug_grads = controller.collect_head_metrics(
    validation_dataloader,
    num_batches=2,
    verbose=True
)

# Calculate statistics to help with threshold setting
print("\nEntropy statistics:")
print(f"Mean entropy: {debug_entropy.mean().item():.4f}")
print(f"Min entropy: {debug_entropy.min().item():.4f}")
print(f"Max entropy: {debug_entropy.max().item():.4f}")
print(f"25th percentile: {torch.quantile(debug_entropy.flatten(), 0.25).item():.4f}")
print(f"50th percentile: {torch.quantile(debug_entropy.flatten(), 0.5).item():.4f}")
print(f"75th percentile: {torch.quantile(debug_entropy.flatten(), 0.75).item():.4f}")

print("\nGradient norm statistics:")
print(f"Mean grad norm: {debug_grads.mean().item():.6f}")
print(f"Min grad norm: {debug_grads.min().item():.6f}")
print(f"Max grad norm: {debug_grads.max().item():.6f}")
print(f"25th percentile: {torch.quantile(debug_grads.flatten(), 0.25).item():.6f}")
print(f"50th percentile: {torch.quantile(debug_grads.flatten(), 0.5).item():.6f}")
print(f"75th percentile: {torch.quantile(debug_grads.flatten(), 0.75).item():.6f}")

# Print current thresholds
print("\nCurrent thresholds:")
print(f"HIGH_ENTROPY_THRESHOLD: {HIGH_ENTROPY_THRESHOLD:.4f}")
print(f"LOW_ENTROPY_THRESHOLD: {LOW_ENTROPY_THRESHOLD:.4f}")
print(f"GRAD_THRESHOLD: {GRAD_THRESHOLD:.6f}")

# Count how many heads would be pruned with current thresholds
would_prune = (debug_entropy > HIGH_ENTROPY_THRESHOLD) & (debug_grads < GRAD_THRESHOLD)
print(f"\nWith current thresholds, {would_prune.sum().item()} heads would be eligible for pruning.")

## Collect Initial Head Metrics

Let's look at the initial entropy and gradient patterns of our attention heads.
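
For intuition about the second metric, a per-head gradient norm can be thought of as the L2 norm of the gradient slice that belongs to each head. The sketch below assumes a layout in which gradient rows are stored head by head with an equal number of rows per head; both that layout and the helper `per_head_grad_norms` are illustrative assumptions, and the controller may aggregate gradients differently.

```python
import math

def per_head_grad_norms(grad_rows, num_heads):
    """L2 norm of the gradient block owned by each head.

    grad_rows is a list of rows (lists of floats), assumed to be laid out
    head by head with the same number of rows for every head.
    """
    if len(grad_rows) % num_heads != 0:
        raise ValueError("rows must divide evenly among heads")
    rows_per_head = len(grad_rows) // num_heads
    norms = []
    for head in range(num_heads):
        block = grad_rows[head * rows_per_head:(head + 1) * rows_per_head]
        norms.append(math.sqrt(sum(g * g for row in block for g in row)))
    return norms

# Two heads, two rows each: the first is still learning, the second barely moves
grads = [[3.0, 4.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.001]]
norms = per_head_grad_norms(grads, num_heads=2)
print(norms)
```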

In [None]:
# Collect initial head metrics
entropy_values, grad_norm_values = controller.collect_head_metrics(
    validation_dataloader, 
    num_batches=2
)

# Plot entropy heatmap
plt.figure(figsize=(10, 6))
plt.title("Initial Head Entropy (higher = less focused attention)")
entropy_map = plt.imshow(entropy_values.detach().cpu().numpy(), cmap="viridis", aspect="auto")
plt.colorbar(entropy_map, label="Entropy")
plt.xlabel("Head Index")
plt.ylabel("Layer Index")
plt.tight_layout()
plt.show()

# Plot gradient norm heatmap
plt.figure(figsize=(10, 6))
plt.title("Initial Head Gradient Norms (higher = more learning)")
grad_map = plt.imshow(grad_norm_values.detach().cpu().numpy(), cmap="plasma", aspect="auto")
plt.colorbar(grad_map, label="Gradient Norm")
plt.xlabel("Head Index")
plt.ylabel("Layer Index")
plt.tight_layout()
plt.show()

## Training with Neural Plasticity

Now let's train the model with neural plasticity enabled, dynamically pruning and reviving attention heads.
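
The training loop tracks how many heads are pruned and revived at each plasticity step, along with the resulting sparsity. A simple way to picture this bookkeeping is a per-layer mask over heads. The sketch below is an illustration under stated assumptions: the mask representation, the `(layer, head)` pair format, and the head-level definition of sparsity are all hypothetical, and the controller's own metric may be defined differently.

```python
NUM_LAYERS, NUM_HEADS = 6, 12  # distilgpt2 has 6 layers with 12 heads each

def new_head_mask(num_layers, num_heads):
    """Build a mask where 1 marks an active head and 0 a pruned one."""
    return [[1] * num_heads for _ in range(num_layers)]

def apply_plasticity(mask, pruned, revived):
    """Update the mask in place from lists of (layer, head) pairs."""
    for layer, head in pruned:
        mask[layer][head] = 0
    for layer, head in revived:
        mask[layer][head] = 1

def sparsity(mask):
    """Fraction of heads that are currently pruned."""
    total = sum(len(row) for row in mask)
    return (total - sum(map(sum, mask))) / total

mask = new_head_mask(NUM_LAYERS, NUM_HEADS)
apply_plasticity(mask, pruned=[(0, 1), (0, 5), (3, 7)], revived=[])
after_prune = sparsity(mask)   # 3 of 72 heads pruned
apply_plasticity(mask, pruned=[], revived=[(0, 5)])
after_revive = sparsity(mask)  # 2 of 72 heads pruned
print(f"after pruning: {after_prune:.4f}, after revival: {after_revive:.4f}")
```

A mask of this kind keeps the pruned weights in place, so a head can be switched back on later; that is the behavior the ADAPTIVE mode is described as allowing, in contrast to COMPRESSED.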

In [ ]:
# Initialize training components
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_dataloader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=WARMUP_STEPS, 
    num_training_steps=total_steps
)

# Initialize metrics tracking
metrics_history = {
    "train_loss": [],
    "eval_loss": [],
    "pruned_heads": [],
    "revived_heads": [],
    "sparsity": [],
    "step": []
}

# Training loop
global_step = 0

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    model.train()
    
    epoch_loss = 0.0
    epoch_steps = 0
    
    for step, batch in enumerate(train_dataloader):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        
        # Backward pass
        loss.backward()
        
        # Update weights
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
        # Track loss
        epoch_loss += loss.item()
        epoch_steps += 1
        global_step += 1
        
        # Periodically evaluate and apply plasticity
        if global_step % EVAL_INTERVAL == 0:
            # Evaluate
            model.eval()
            eval_loss, eval_perplexity = evaluate_model(model, validation_dataloader)
            
            # Apply plasticity
            pruned, revived, plasticity_metrics = controller.step(
                validation_dataloader, 
                num_batches=2,
                verbose=True
            )
            
            # Update metrics
            metrics_history["train_loss"].append(epoch_loss / epoch_steps)
            metrics_history["eval_loss"].append(eval_loss)
            metrics_history["pruned_heads"].append(len(pruned))
            metrics_history["revived_heads"].append(len(revived))
            metrics_history["sparsity"].append(plasticity_metrics["sparsity"])
            metrics_history["step"].append(global_step)
            
            # Print status
            print(f"  Step {global_step} - Train loss: {epoch_loss / epoch_steps:.4f}, Eval loss: {eval_loss:.4f}")
            print(f"  Pruned: {len(pruned)} heads, Revived: {len(revived)} heads, Total pruned: {plasticity_metrics['total_pruned']}")
            print(f"  Sparsity: {plasticity_metrics['sparsity']:.4f}")
            
            # Reset for next interval
            epoch_loss = 0.0
            epoch_steps = 0
            
            # Back to training mode; clear any gradients left over from metric collection
            model.train()
            optimizer.zero_grad()

## Visualize Training Progress

Let's visualize the training progress, including loss metrics and head pruning/revival.

In [None]:
# Visualize training metrics
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True)

# Plot losses
ax1.plot(metrics_history["step"], metrics_history["train_loss"], label="Train Loss")
ax1.plot(metrics_history["step"], metrics_history["eval_loss"], label="Eval Loss")
ax1.set_ylabel("Loss")
ax1.set_title("Training and Evaluation Loss")
ax1.legend()
ax1.grid(True)

# Plot pruning metrics
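# Bar width is scaled to the evaluation interval so bars stay visible on the step axis
ax2.bar(metrics_history["step"], metrics_history["pruned_heads"], width=EVAL_INTERVAL * 0.8, alpha=0.5, label="Pruned Heads", color="blue")
ax2.bar(metrics_history["step"], metrics_history["revived_heads"], width=EVAL_INTERVAL * 0.8, alpha=0.5, label="Revived Heads", color="green")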
ax2.set_xlabel("Step")
ax2.set_ylabel("Count")
ax2.set_title("Head Pruning and Revival")
ax2.legend(loc="upper left")
ax2.grid(True)

# Add sparsity line on secondary axis
ax3 = ax2.twinx()
ax3.plot(metrics_history["step"], metrics_history["sparsity"], "r-", label="Sparsity")
ax3.set_ylabel("Sparsity")
ax3.legend(loc="upper right")

plt.tight_layout()
plt.show()

# Visualize head dynamics
controller.visualize_head_dynamics(metric='entropy')
plt.show()

controller.visualize_head_dynamics(metric='decision')
plt.show()

## Final Evaluation

Let's evaluate the final model to see how it compares to the baseline.

In [None]:
# Final evaluation
final_loss, final_perplexity = evaluate_model(model, validation_dataloader)
print(f"Final evaluation: Loss = {final_loss:.4f}, Perplexity = {final_perplexity:.2f}")
print(f"Baseline:         Loss = {baseline_loss:.4f}, Perplexity = {baseline_perplexity:.2f}")
print(f"Improvement:      {((baseline_loss - final_loss) / baseline_loss * 100):.2f}%")

# Get final summary
summary = controller.get_summary()
print("\nFinal Controller Summary:")
print(f"  Total heads: {summary['total_heads']}")
print(f"  Pruned heads: {summary['pruned_heads']} ({summary['pruning_rate']:.2%})")
print(f"  Model sparsity: {summary['sparsity']:.4f}")
print(f"  Model size: {summary['model_size_mb']:.2f} MB")

## Generate Text with Final Model

Let's generate text with the final model to see if there are any quality differences.

In [None]:
# Generate text with final model
final_text = generate_text(prompt)

print("Baseline Model Output:")
print(baseline_text)
print("\nPlasticity-Optimized Model Output:")
print(final_text)

## Save the Model

Let's save the optimized model for later use.

In [None]:
# Create output directory
import os
from datetime import datetime

output_dir = os.path.join("output", "plasticity", f"run_{datetime.now().strftime('%Y%m%d-%H%M%S')}")
os.makedirs(output_dir, exist_ok=True)

# Save model and tokenizer
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"Model saved to {output_dir}")

## Try Different Prompts

Let's try generating text with different prompts to evaluate the model's capabilities.

In [None]:
prompts = [
    "The meaning of life is",
    "In a distant galaxy",
    "The future of AI will be",
    "Scientists recently discovered"
]

for prompt in prompts:
    print(f"Prompt: {prompt}")
    generated = generate_text(prompt)
    print(f"Generated: {generated}\n")

## Conclusion

In this notebook, we demonstrated Sentinel AI's neural plasticity system, which enables transformer models to dynamically prune and revive attention heads during training based on their utility.

Key findings:
1. The plasticity system successfully pruned high-entropy, low-gradient heads
2. Some heads were revived when they showed potential for useful learning
3. The final model achieved comparable quality with fewer active heads
4. The brain dynamics visualization shows how attention heads evolve over time

This approach mimics biological neural plasticity, where brains form efficient neural pathways by pruning unused connections and strengthening useful ones.

## Version History

- v0.0.5: Significantly more aggressive pruning thresholds (HIGH_ENTROPY_THRESHOLD: 0.6→0.4, LOW_ENTROPY_THRESHOLD: 0.3→0.2, GRAD_THRESHOLD: 5e-5→1e-3)
- v0.0.4: Adjusted pruning thresholds for more aggressive pruning behavior (HIGH_ENTROPY_THRESHOLD: 0.8→0.6, LOW_ENTROPY_THRESHOLD: 0.4→0.3, GRAD_THRESHOLD: 1e-4→5e-5)
- v0.0.3: Removed hard-coded 200-step limit to allow full NUM_EPOCHS training
- v0.0.2: Added warmup phase to get more accurate baseline measurements, improved visualization of head metrics, fixed perplexity calculation issues
- v0.0.1: Initial implementation of neural plasticity demo