
# Neural Plasticity Demo: Dynamic Pruning & Regrowth (v0.0.52 2025-04-19 19:07:02)

### New in v0.0.52:
- Fixed GPU tensor visualization errors
- Fixed visualization utilities integration
- Ensured proper tensor detachment and CPU conversion for visualization
- Integrated with utils.colab.visualizations module
- Added %matplotlib inline for Colab compatibility
- Added system dependency checks
- Improved error handling in training loop
- Fixed TensorBoard visualizations
- Enhanced memory management
- Deduplicated import statements
- Fixed cell execution counts for better notebook flow

This notebook demonstrates Sentinel AI's neural plasticity system, which allows transformer models to dynamically prune and regrow attention heads during training based on utility metrics.

## What is Neural Plasticity?

Neural plasticity is the ability of neural networks to adapt their structure over time through pruning (removing unused connections) and regrowth (restoring useful connections). This mimics how biological brains form efficient neural pathways.

In this demo, we:
1. Track the entropy and gradient patterns of each attention head
2. Dynamically prune high-entropy, low-gradient heads (unfocused, less useful)
3. Selectively revive low-entropy, higher-gradient heads (potentially useful)
4. Visualize the "brain dynamics" over time

This allows models to form more efficient neural structures during training.
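
The prune/revive rule in steps 2 and 3 can be sketched as a small decision function. This is a minimal illustration, not Sentinel AI's implementation: `decide_head_action` is a hypothetical helper, entropy is assumed to be normalized to [0, 1], and the default thresholds simply mirror the values passed to `create_plasticity_controller` later in this notebook.

```python
def decide_head_action(entropy, grad_norm, is_pruned,
                       high_entropy=0.8, low_entropy=0.4, grad_threshold=1e-3):
    """Return 'prune', 'revive', or 'keep' for one attention head.

    Hypothetical helper: entropy is assumed normalized to [0, 1].
    """
    if not is_pruned and entropy > high_entropy and grad_norm < grad_threshold:
        return "prune"    # unfocused attention and little learning signal
    if is_pruned and entropy < low_entropy and grad_norm > grad_threshold:
        return "revive"   # focused attention and a useful gradient
    return "keep"         # leave the head in its current state


print(decide_head_action(0.92, 5e-4, is_pruned=False))  # prune
print(decide_head_action(0.30, 2e-2, is_pruned=True))   # revive
print(decide_head_action(0.55, 2e-2, is_pruned=False))  # keep
```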

### New in v0.0.51:
- Fixed 5 GPU tensor visualization errors (`plt.imshow(...)` on CUDA tensors)
- Ensured all visualizations are Colab-compatible
- Replaced unsafe tensor access with `.detach().cpu().numpy()` where needed

In [1]:
# Check and install system dependencies if needed
!apt-get update -qq > /dev/null
!apt-get install -qq libopenblas-dev > /dev/null  # For better performance

In [2]:
# Install required packages
!pip install -q torch transformers datasets matplotlib seaborn

# Clone the Sentinel AI repository
!git clone -b feature/implement-adaptive-plasticity https://github.com/CambrianTech/sentinel-ai.git
%cd sentinel-ai

# Add repository to path
import sys
sys.path.append('.')

# Configure the Experiment

Let's set up the configuration for the neural plasticity experiment.

In [3]:
# Configure experiment
MODEL_NAME = "distilgpt2"  # Small GPT-2 model for faster demonstration
DATASET = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
MAX_LENGTH = 128
BATCH_SIZE = 4
NUM_EPOCHS = 100      # Run for many epochs if needed
LEARNING_RATE = 5e-5
WARMUP_STEPS = 100
WARMUP_MAX_EPOCHS = 1     # Maximum number of warmup epochs (will stop earlier if loss stabilizes)
EVAL_INTERVAL = 50    # Evaluate every 50 steps
VISUALIZATION_INTERVAL = 100  # Show visuals every 100 steps
INFERENCE_INTERVAL = 500      # Run inference every 500 steps
CHECKPOINT_INTERVAL = 500    # Save checkpoint more frequently (was 1000)
MAX_STEPS_PER_EPOCH = None    # Set to a number to limit steps per epoch, or None for unlimited

# Set to True to enable continuous training for long periods
ENABLE_LONG_TRAINING = False  # Set to False for demo purposes to avoid memory/runtime issues

# If ENABLE_LONG_TRAINING is True, run with unlimited steps per epoch
# If ENABLE_LONG_TRAINING is False, override to a reasonable limit for demo purposes
if not ENABLE_LONG_TRAINING:
    MAX_STEPS_PER_EPOCH = 200 # Limit steps per epoch for demo purposes
    NUM_EPOCHS = 3            # Limit epochs for demo purposes

# Configure pruning mode
from sentinel.pruning.dual_mode_pruning import PruningMode

# Set pruning mode (ADAPTIVE allows recovery, COMPRESSED prevents recovery)
PRUNING_MODE = PruningMode.ADAPTIVE  # Change to PruningMode.COMPRESSED for permanent pruning

# Configure statistical-based pruning strategy
# Instead of fixed thresholds, we'll use percentile-based thresholds
ENTROPY_PERCENTILE = 70  # Heads with entropy above the 70th percentile are candidates for pruning
GRADIENT_PERCENTILE = 30  # Heads with gradient below the 30th percentile are candidates for pruning
PRUNE_PERCENT = 0.1      # Target to prune approximately 10% of heads in each step
MIN_ZERO_EPOCHS = 1      # Minimum epochs a head should remain pruned
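
A rough sketch of how percentile thresholds such as `ENTROPY_PERCENTILE` and `GRADIENT_PERCENTILE` could turn per-head metrics into pruning candidates. The helper names (`percentile`, `pruning_candidates`) and the toy metric values are made up for illustration; the controller's actual selection logic may differ.

```python
def percentile(values, q):
    """Linearly interpolated q-th percentile (0-100) of a list of numbers."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lower = int(pos)
    upper = min(lower + 1, len(ordered) - 1)
    frac = pos - lower
    return ordered[lower] * (1 - frac) + ordered[upper] * frac


def pruning_candidates(entropies, grad_norms, entropy_pct=70, gradient_pct=30):
    """Indices of heads with entropy above and gradient norm below the cut-offs."""
    entropy_cut = percentile(entropies, entropy_pct)
    gradient_cut = percentile(grad_norms, gradient_pct)
    return [
        i for i, (h, g) in enumerate(zip(entropies, grad_norms))
        if h > entropy_cut and g < gradient_cut
    ]


# Toy metrics for six heads (illustrative values only)
entropies = [0.2, 0.9, 0.5, 0.95, 0.4, 0.7]
grad_norms = [0.05, 0.001, 0.03, 0.002, 0.04, 0.02]
print(pruning_candidates(entropies, grad_norms))  # [1, 3]
```

Because the cut-offs are computed from the current distribution of metrics, the number of candidates adapts to the model instead of depending on hand-tuned absolute thresholds.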

# Load Model and Dataset

Now we'll load the model and prepare the dataset for training.

In [4]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    default_data_collator,
    get_linear_schedule_with_warmup
)
from torch.utils.data import DataLoader
from datasets import load_dataset
from sentinel.pruning.plasticity_controller import create_plasticity_controller
from sentinel.pruning.dual_mode_pruning import prune_head_in_model, get_model_info

# Import visualization utilities
from utils.colab.visualizations import TrainingMonitor, visualize_gradient_norms, visualize_attention_heatmap, visualize_head_entropy


# Import helper for safe tensor visualization
from utils.colab.helpers import safe_tensor_imshow

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load model and tokenizer
print(f"Loading model: {MODEL_NAME}")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set pad token if needed
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load datasets
print(f"Loading dataset: {DATASET}/{DATASET_CONFIG}")
train_dataset = load_dataset(DATASET, DATASET_CONFIG, split="train")
validation_dataset = load_dataset(DATASET, DATASET_CONFIG, split="validation")

# Define tokenization function
def tokenize_function(examples):
    return tokenizer(
        examples["text"], 
        padding="max_length", 
        truncation=True, 
        max_length=MAX_LENGTH
    )

# Tokenize datasets
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
validation_dataset = validation_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Add labels for language modeling
def add_labels(examples):
    examples["labels"] = examples["input_ids"].copy()
    return examples

train_dataset = train_dataset.map(add_labels)
validation_dataset = validation_dataset.map(add_labels)

# Set format
train_dataset = train_dataset.with_format("torch")
validation_dataset = validation_dataset.with_format("torch")

# Create dataloaders
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=default_data_collator
)

validation_dataloader = DataLoader(
    validation_dataset, 
    batch_size=BATCH_SIZE, 
    collate_fn=default_data_collator
)

print(f"Train dataset size: {len(train_dataset)} examples")
print(f"Validation dataset size: {len(validation_dataset)} examples")
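
One caveat with the `add_labels` step above: every example is padded to `MAX_LENGTH` and the pad token is the EOS token, so copying `input_ids` into `labels` means padded positions also contribute to the loss. A common alternative, sketched here as an assumption rather than what this notebook does, is to set labels to -100 wherever `attention_mask` is 0, since PyTorch's cross-entropy loss ignores that value by default. `mask_padding_labels` is a hypothetical helper name.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def mask_padding_labels(input_ids, attention_mask):
    """Copy input_ids into labels, replacing padded positions with IGNORE_INDEX."""
    return [
        token if keep == 1 else IGNORE_INDEX
        for token, keep in zip(input_ids, attention_mask)
    ]


# Toy example: three real tokens followed by two padding tokens (id 50256)
input_ids = [464, 3290, 318, 50256, 50256]
attention_mask = [1, 1, 1, 0, 0]
print(mask_padding_labels(input_ids, attention_mask))
# [464, 3290, 318, -100, -100]
```

Inside a non-batched `map` call this could be applied per example, using the `attention_mask` column that the tokenizer already produces.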

# Define Evaluation Function

Let's define functions to evaluate our model's performance and to generate sample text.

In [5]:
def evaluate_model(model, dataloader):
    """Evaluate model on the provided dataloader."""
    model.eval()
    total_loss = 0.0
    total_steps = 0
    
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            
            # Forward pass
            outputs = model(**batch)
            loss = outputs.loss
            
            total_loss += loss.item()
            total_steps += 1
            
            # Limit evaluation to 10 steps for speed
            if total_steps >= 10:
                break
    
    avg_loss = total_loss / total_steps if total_steps > 0 else float("inf")
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    
    return avg_loss, perplexity

def generate_text(prompt, max_length=100):
    """Generate text from the model."""
    # Set model to evaluation mode
    model.eval()
    
    # Encode prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    # Generate text
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode and return text
    return tokenizer.decode(output[0], skip_special_tokens=True)
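
`evaluate_model` reports perplexity as the exponential of the average cross-entropy loss. A small worked example with made-up loss values (standard library only; `perplexity` is an illustrative helper, not part of the notebook):

```python
import math


def perplexity(batch_losses):
    """Perplexity as exp of the mean per-batch loss (natural-log cross-entropy)."""
    if not batch_losses:
        return float("inf")
    return math.exp(sum(batch_losses) / len(batch_losses))


# Illustrative loss values, not results from this notebook
print(f"{perplexity([3.2, 3.0, 3.4]):.2f}")  # 24.53
```

Averaging per-batch losses matches the token-level average only when every batch contributes the same number of labeled tokens.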

## Run Model Warm-up

Before measuring baseline performance and applying neural plasticity, we'll run a brief warm-up phase to get initial attention patterns and stabilize metrics.
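
The idea of running until the loss stabilizes can be illustrated with a self-contained check that compares two adjacent rolling windows. This is a simplified sketch with hypothetical names (`has_plateaued`, `tolerance`); the warm-up cell below additionally tracks how many steps have passed since the last decrease.

```python
def has_plateaued(losses, min_steps=50, window_size=5, tolerance=0.01):
    """Return True when the latest window's mean loss is no longer at least
    `tolerance` (relative) below the previous window's mean."""
    if len(losses) < max(min_steps, 2 * window_size):
        return False  # not enough history yet
    recent = sum(losses[-window_size:]) / window_size
    previous = sum(losses[-2 * window_size:-window_size]) / window_size
    return recent >= previous * (1 - tolerance)


# Synthetic loss curves for illustration
falling = [5.0 - 0.05 * i for i in range(60)]            # still improving
flat = [5.0 - 0.05 * i for i in range(50)] + [2.5] * 10  # levelled off
print(has_plateaued(falling))  # False
print(has_plateaued(flat))     # True
```

Taking all state as arguments, rather than reading a module-level counter, makes the check easy to test on synthetic curves like the two above.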

In [6]:
# Initialize optimizer and scheduler for warm-up
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_dataloader) * WARMUP_MAX_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=WARMUP_STEPS, 
    num_training_steps=total_steps
)

print(f"Running warm-up until loss stabilizes (max {WARMUP_MAX_EPOCHS} epochs)...")

# Warm-up training loop
model.train()
warmup_losses = []
warmup_step_losses = []
last_loss_decrease = 0
patience = 15      # Number of steps with no decrease to consider stabilized
min_warmup_steps = 50  # Minimum number of warm-up steps
max_warmup_steps = 150  # Maximum number of warm-up steps per epoch

# Helper function to calculate if loss has stabilized 
def is_loss_stabilized(losses, min_steps, patience_steps, window_size=5):
    # Not enough steps yet
    if len(losses) < min_steps:
        return False, 0

    # Not enough steps since last decrease
    steps_since_decrease = len(losses) - last_loss_decrease
    if steps_since_decrease < patience_steps:
        return False, steps_since_decrease
    
    # Check if recent trend is flat or increasing using rolling average
    if len(losses) >= window_size * 2:
        recent_window = sum(losses[-window_size:]) / window_size
        previous_window = sum(losses[-(window_size*2):-window_size]) / window_size
        # If recent average is lower than previous, we're still decreasing
        if recent_window < previous_window * 0.99:  # Allow 1% variation
            return False, steps_since_decrease
            
    return True, steps_since_decrease

try:
    for epoch in range(WARMUP_MAX_EPOCHS):
        epoch_loss = 0.0
        epoch_steps = 0
        
        for step, batch in enumerate(train_dataloader):
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            
            # Forward pass
            outputs = model(**batch)
            loss = outputs.loss
            
            # Backward pass
            loss.backward()
            
            # Update weights
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            # Track loss
            loss_val = loss.item()
            epoch_loss += loss_val
            epoch_steps += 1
            warmup_losses.append(loss_val)
            
            # Check if we've met the minimum steps and loss has stabilized
            if len(warmup_losses) > 1:
                # Track non-increasing steps
                if loss_val <= warmup_losses[-2]:
                    last_loss_decrease = len(warmup_losses)
                
                # For visualization, track a smoothed version (rolling average of 5)
                if len(warmup_losses) % 5 == 0:
                    avg_loss = sum(warmup_losses[-5:]) / 5
                    warmup_step_losses.append(avg_loss)
            
            # Print progress every 5 steps
            if step % 5 == 0:
                print(f"Warm-up Epoch {epoch+1}, Step {step}: Loss = {loss_val:.4f}", end='\r')
            
            # Check if loss has stabilized
            is_stable, steps_without_decrease = is_loss_stabilized(
                warmup_losses, min_warmup_steps, patience
            )
            
            if is_stable:
                print(f"\nWarm-up loss stabilized after {len(warmup_losses)} steps")
                print(f"Loss has been non-decreasing for {steps_without_decrease} steps")
                break
                
            # Stop after max_warmup_steps for faster execution in demo
            if step >= max_warmup_steps:
                print(f"\nReached maximum warm-up steps per epoch ({max_warmup_steps})")
                break
        
        print(f"\nWarm-up Epoch {epoch+1} completed: Average Loss = {epoch_loss / epoch_steps:.4f}")
        
        # Check if loss has stabilized across epochs
        is_stable, steps_without_decrease = is_loss_stabilized(
            warmup_losses, min_warmup_steps, patience
        )
        
        if is_stable:
            print(f"Loss has stabilized with {steps_without_decrease} steps without significant decrease.")
            print(f"Ending warm-up early after {epoch+1} epochs.")
            break
    
    # Plot warm-up loss
    plt.figure(figsize=(12, 8))
    
    # Raw loss
    plt.subplot(2, 1, 1)
    plt.plot(warmup_losses)
    plt.title("Warm-up Loss (Raw)")
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.grid(True)
    
    # Smoothed loss if we have enough data
    if len(warmup_step_losses) > 1:
        plt.subplot(2, 1, 2)
        plt.plot(range(0, len(warmup_step_losses)*5, 5), warmup_step_losses)
        plt.title("Warm-up Loss (5-step Rolling Average)")
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.grid(True)
        
        # Add trend line to smoothed plot
        from scipy.stats import linregress
        x = range(0, len(warmup_step_losses)*5, 5)
        slope, intercept, r_value, p_value, std_err = linregress(x, warmup_step_losses)
        plt.plot(x, [slope*xi + intercept for xi in x], 'r--', 
                 label=f'Trend: slope={slope:.6f}, R²={r_value**2:.2f}')
        plt.legend()
    
    # Consider using constrained_layout=True instead of tight_layout()
    plt.tight_layout()
    plt.show()
    
    # Segment analysis - compare first third vs last third of training
    if len(warmup_losses) > 6:
        segment_size = len(warmup_losses) // 3
        first_segment = warmup_losses[:segment_size]
        last_segment = warmup_losses[-segment_size:]
        first_avg = sum(first_segment) / len(first_segment)
        last_avg = sum(last_segment) / len(last_segment)
        
        print(f"\nWarm-up Segment Analysis:")
        print(f"First {segment_size} steps average loss: {first_avg:.4f}")
        print(f"Last {segment_size} steps average loss: {last_avg:.4f}")
        print(f"Improvement during warm-up: {(1 - last_avg/first_avg)*100:.1f}%")
        
        # Calculate if still improving significantly
        still_improving = (first_avg - last_avg) / first_avg > 0.01  # More than 1% improvement
        print(f"Is model still significantly improving? {'Yes' if still_improving else 'No'}")
    
    # Print warm-up summary
    print(f"\nWarm-up completed with {len(warmup_losses)} steps across {epoch+1} epochs")
    print(f"Initial loss: {warmup_losses[0]:.4f}")
    print(f"Final loss: {warmup_losses[-1]:.4f}")
    print(f"Overall loss reduction: {(1 - warmup_losses[-1]/warmup_losses[0])*100:.1f}%")

except Exception as e:
    # No checkpoint is attempted here: save_checkpoint and global_step
    # are not defined at this point in the notebook.
    print(f"\nError during warm-up: {e}")


# Evaluate Baseline Model

Now let's measure the baseline performance after warm-up.

In [7]:
# Evaluate baseline model after warm-up
baseline_loss, baseline_perplexity = evaluate_model(model, validation_dataloader)
print(f"Baseline evaluation after warm-up: Loss = {baseline_loss:.4f}, Perplexity = {baseline_perplexity:.2f}")

# Generate text with baseline model
prompt = "Once upon a time"
baseline_text = generate_text(prompt)
print(f"\nPrompt: {prompt}")
print(f"Generated text:\n{baseline_text}")

## Create Neural Plasticity Controller

Now we'll create our neural plasticity controller that will monitor attention heads and make pruning decisions.

In [8]:
# Create a custom statistical pruning function based only on gradients
def gradient_based_pruning(grad_norm_values, prune_percent=0.1):
    """
    Make pruning decisions based only on gradient norms.
    We want to prune heads with LOWEST gradient norms, as they're
    learning the least.
    
    Args:
        grad_norm_values: Tensor of gradient norm values for all heads
        prune_percent: Target fraction of heads to prune (0-1)
        
    Returns:
        pruning_mask: Boolean tensor where True indicates a head should be pruned
    """
    # Flatten tensor for calculating percentiles
    flat_grad_norm = grad_norm_values.view(-1)
    
    # Calculate how many heads we want to prune
    total_heads = grad_norm_values.numel()
    target_prune_count = int(total_heads * prune_percent)
    
    # Get the indices of the heads with the LOWEST gradient norms
    # largest=False makes topk return the smallest values
    _, indices = torch.topk(flat_grad_norm, k=target_prune_count, largest=False)
    
    # Create pruning mask where True = head should be pruned (low gradient norm)
    pruning_mask = torch.zeros_like(grad_norm_values, dtype=torch.bool)
    pruning_mask.view(-1)[indices] = True
    
    print(f"Gradient-based pruning - target: {target_prune_count} heads")
    print(f"Final pruning decision: pruning {pruning_mask.sum().item()} heads")
    print(f"Average grad norm of pruned heads: {grad_norm_values[pruning_mask].mean().item():.6f}")
    print(f"Average grad norm of kept heads: {grad_norm_values[~pruning_mask].mean().item():.6f}")
    return pruning_mask
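
To make the selection in `gradient_based_pruning` concrete, here is the same lowest-k rule in plain Python on a toy grid. distilgpt2 has 6 layers with 12 heads each, so a 10% target gives `int(72 * 0.1) = 7` heads. The 2x3 grid and the helper name `lowest_k_heads` are illustrative assumptions.

```python
def lowest_k_heads(grad_norms, prune_percent=0.1):
    """Return (layer, head) pairs with the smallest gradient norms.

    grad_norms is a list of per-layer lists; mirrors topk(..., largest=False).
    """
    flat = [
        (norm, layer, head)
        for layer, row in enumerate(grad_norms)
        for head, norm in enumerate(row)
    ]
    k = int(len(flat) * prune_percent)  # truncates, so small grids may give k == 0
    return [(layer, head) for _, layer, head in sorted(flat)[:k]]


toy = [[0.30, 0.02, 0.15],
       [0.01, 0.25, 0.40]]
print(lowest_k_heads(toy, prune_percent=0.34))  # [(1, 0), (0, 1)]
print(int(6 * 12 * 0.1))                        # 7
```

When the target count truncates to 0, no head is selected, so the averages printed for pruned heads in `gradient_based_pruning` would have nothing to average.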



## Important Note on Cell Execution Order

⚠️ **Critical**: Cells in this notebook must be executed in order.

The next cell creates the plasticity controller that's used throughout the notebook. Make sure to run:

1. The cell below that creates the controller
2. Then the debug cell after it
3. Then continue with the rest of the notebook in sequence

If you get `NameError: name 'controller' is not defined`, go back and run the controller creation cell first.

In [9]:
# NOTE: This cell defines the controller that later cells depend on
# Create plasticity controller with default thresholds
controller = create_plasticity_controller(
    model=model,
    mode=PRUNING_MODE,
    high_entropy_threshold=0.8,  # These will be ignored by our custom approach
    low_entropy_threshold=0.4,   # but we need to provide values
    grad_threshold=1e-3,
    min_zero_epochs=MIN_ZERO_EPOCHS
)

# Display initial model stats
initial_stats = controller.get_summary()
print(f"Model has {initial_stats['total_heads']} attention heads across {controller.total_layers} layers")

# Override the controller's entropy calculation to fix zero entropy issues
# We're replacing the internal calculation method with a more numerically stable version
def better_entropy_calculation(attn_probs, eps=1e-6):
    """Calculate entropy with better numerical stability."""
    # Add small epsilon to avoid log(0) issues
    attn_probs = attn_probs.clamp(min=eps)
    
    # Normalize to ensure it's a proper probability distribution
    attn_probs = attn_probs / attn_probs.sum(dim=-1, keepdim=True)
    
    # Standard entropy calculation
    return -torch.sum(attn_probs * torch.log(attn_probs), dim=-1)

# Monkey-patch the controller's entropy calculation method
import types
if hasattr(controller, 'calculate_attention_entropy'):
    # For controllers with direct entropy calculation method
    controller.calculate_attention_entropy = types.MethodType(
        lambda self, attention_maps: better_entropy_calculation(attention_maps), controller)
    print("Patched controller's entropy calculation")
elif hasattr(controller, '_compute_entropy'):
    # For controllers with internal _compute_entropy method
    old_compute_entropy = controller._compute_entropy
    
    def patched_compute_entropy(self, attention_maps, eps=1e-6):
        """Improved entropy calculation with better numerical stability."""
        if not attention_maps:
            return 0.0
        
        # Concatenate all maps
        maps = torch.cat(attention_maps, dim=0)
        
        # Apply the better entropy calculation to the raw attention maps
        entropies = better_entropy_calculation(maps, eps=eps)
        
        # Average over batch and sequence length
        avg_entropy = entropies.mean().item()
        
        # Normalize to [0,1] range
        max_entropy = torch.log(torch.tensor(maps.size(-1), dtype=torch.float))
        normalized_entropy = avg_entropy / max_entropy.item()
        
        return normalized_entropy
    
    controller._compute_entropy = types.MethodType(patched_compute_entropy, controller)
    print("Patched controller's _compute_entropy method")

# Wrap collect_head_metrics so each call also prints diagnostics
# about the raw attention probabilities (the returned metrics are unchanged)
old_collect_metrics = controller.collect_head_metrics

def patched_collect_metrics(self, dataloader, num_batches=5):
    # Call original method
    entropy, grads = old_collect_metrics(dataloader, num_batches)
    
    # Print diagnostic information about raw attention values
    print("\nDIAGNOSTIC: Attention and entropy statistics")
    try:
        # Get a data batch
        inputs = next(iter(dataloader))
        if isinstance(inputs, dict):
            inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
        else:
            inputs = {"input_ids": inputs[0].to(device)}
        
        # Run model to get attention values
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        
        if hasattr(outputs, 'attentions') and outputs.attentions is not None:
            attn = outputs.attentions[0]  # First layer
            print(f"Attn tensor shape: {attn.shape}")
            print(f"min/max/mean: {attn.min().item():.2e}/{attn.max().item():.2e}/{attn.mean().item():.2e}")
            print(f"sum=1 check: {torch.allclose(attn.sum(dim=-1), torch.ones_like(attn.sum(dim=-1)))}")
    except Exception as e:
        print(f"Error in diagnostic: {e}")
    
    return entropy, grads

# Apply the patch
controller.collect_head_metrics = types.MethodType(patched_collect_metrics, controller)
print("Applied diagnostic patch to controller")
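
The normalization in `patched_compute_entropy` divides by `log(n)`, the entropy of a uniform distribution over `n` positions. A plain-Python sketch of the same clamp, renormalize, and scale steps for a single attention row (the function name `normalized_entropy` is an assumption for illustration):

```python
import math


def normalized_entropy(probs, eps=1e-6):
    """Entropy of one attention row, scaled to [0, 1] by log(len(probs))."""
    clamped = [max(p, eps) for p in probs]   # avoid log(0)
    total = sum(clamped)
    normed = [p / total for p in clamped]    # make the row sum to 1 again
    entropy = -sum(p * math.log(p) for p in normed)
    return entropy / math.log(len(probs))


print(round(normalized_entropy([0.25, 0.25, 0.25, 0.25]), 4))  # 1.0 (uniform)
print(normalized_entropy([1.0, 0.0, 0.0, 0.0]) < 0.001)        # True (focused)
```

For a causal model such as GPT-2, query position `i` attends to only `i + 1` tokens, so the attainable maximum for that row is `log(i + 1)` rather than `log(n)`. Dividing every row by `log(n)` therefore understates how diffuse the early positions are.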


In [10]:
# Debug: Let's check the actual entropy values we're dealing with
print("\nCollecting initial entropy and gradient metrics for debugging...")

# Make absolutely sure controller, model, and validation_dataloader are defined
try:
    # Check if we have all the necessary variables
    controller
    model
    validation_dataloader
    device
except NameError as e:
    print(f"ERROR: Missing variable: {e}")
    print("Please run the previous cells first to set up model, controller, and data.")
    # Define a placeholder so later cells can reference `controller`; this cell still stops at the `raise` below
    if 'controller' not in globals():
        print("Creating placeholder controller...")
        from types import SimpleNamespace
        controller = SimpleNamespace()
        controller.collect_head_metrics = lambda *args, **kwargs: (torch.zeros(12, 12), torch.zeros(12, 12))
    raise

# This will collect attention entropy and gradient values
try:
    debug_entropy, debug_grads = controller.collect_head_metrics(
        validation_dataloader,
        num_batches=2
    )
    
    # Print entropy statistics
    print("\nEntropy statistics:")
    print(f"Mean entropy: {debug_entropy.mean().item():.4f}")
    print(f"Min entropy: {debug_entropy.min().item():.4f}")
    print(f"Max entropy: {debug_entropy.max().item():.4f}")
    print(f"25th percentile: {torch.quantile(debug_entropy.flatten(), 0.25).item():.4f}")
    print(f"50th percentile: {torch.quantile(debug_entropy.flatten(), 0.5).item():.4f}")
    print(f"75th percentile: {torch.quantile(debug_entropy.flatten(), 0.75).item():.4f}")
    print(f"Are all entropy values the same? {torch.allclose(debug_entropy, debug_entropy[0,0])}")
    print(f"Non-zero values: {torch.count_nonzero(debug_entropy)}/{debug_entropy.numel()}")
    
    # Add diagnostic to debug attention probability tensor
    print("\nDIAGNOSTIC: Checking raw attention probability distributions...")
    try:
        # Get a data batch safely
        inputs = next(iter(validation_dataloader))
        if isinstance(inputs, dict):
            inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
        else:
            inputs = {"input_ids": inputs[0].to(device)}
        
        # Get model outputs with attention
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        
        # Analyze attention tensors
        if hasattr(outputs, 'attentions') and outputs.attentions is not None:
            attn_tensors = outputs.attentions
            layer_idx = 0  # Check first layer
            
            if len(attn_tensors) > 0:
                attn = attn_tensors[layer_idx]  # First layer attention
                
                # Print attention tensor stats to verify it's a valid probability distribution
                print(f"Attention tensor shape: {attn.shape}")
                print(f"Attention tensor dtype: {attn.dtype}")
                print(f"Attention tensor stats: min={attn.min().item():.6e}, max={attn.max().item():.6e}, mean={attn.mean().item():.6e}")
                
                # Check if values sum to 1 along attention dimension
                attn_sum = attn.sum(dim=-1)
                print(f"Sum along attention dimension: min={attn_sum.min().item():.6f}, max={attn_sum.max().item():.6f}")
                print(f"Close to 1.0? {torch.allclose(attn_sum, torch.ones_like(attn_sum), rtol=1e-3)}")
                
                # Check for very small values that might cause log(0) issues
                small_values = (attn < 1e-6).float().mean().item() * 100
                print(f"Percentage of very small values (<1e-6): {small_values:.2f}%")
                
                # Check for NaN or infinity
                print(f"Contains NaN: {torch.isnan(attn).any().item()}")
                print(f"Contains Inf: {torch.isinf(attn).any().item()}")
                
                # Fix entropy calculation function with better defaults
                def improved_entropy_calculation(attn_probs, eps=1e-8):
                    """Compute entropy with better numerical stability."""
                    # Ensure valid probability distribution
                    attn_probs = attn_probs.clamp(min=eps)
                    normalized_probs = attn_probs / attn_probs.sum(dim=-1, keepdim=True)
                    
                    # Compute entropy
                    log_probs = torch.log(normalized_probs)
                    entropy = -torch.sum(normalized_probs * log_probs, dim=-1)
                    return entropy
                
                # Calculate entropy using improved function
                improved_entropy = improved_entropy_calculation(attn).mean(dim=(0, 1))
                print("\nImproved entropy calculation results:")
                print(f"Mean entropy: {improved_entropy.mean().item():.4f}")
                print(f"Min entropy: {improved_entropy.min().item():.4f}")
                print(f"Max entropy: {improved_entropy.max().item():.4f}")
                
                # Add visualization of attention patterns
                print("\nVisualizing attention pattern for one head...")
                head_idx = 0
                plt.figure(figsize=(8, 6))
                attention_map = safe_tensor_imshow(attn[0, head_idx], title=f'Attention pattern (layer {layer_idx}, head {head_idx})')
                plt.clim(0, 1.0)  # Ensure proper scaling for attention values
                plt.colorbar(label='Attention probability')
                plt.title(f'Attention pattern (layer {layer_idx}, head {head_idx})')
                plt.xlabel('Sequence position (to)')
                plt.ylabel('Sequence position (from)')
                plt.show()
                
                # Add histogram of attention values
                plt.figure(figsize=(8, 4))
                plt.hist(attn[0, head_idx].flatten().cpu().numpy(), bins=50, alpha=0.7)
                plt.title('Histogram of attention probabilities')
                plt.xlabel('Probability value')
                plt.ylabel('Frequency')
                plt.grid(alpha=0.3)
                plt.show()
            else:
                print("No attention tensors found in the model output")
        else:
            print("Model output doesn't have attention tensors. Check if output_attentions=True is supported.")
    except Exception as e:
        print(f"Error in attention diagnostic: {e}")
        
    # Print gradient statistics
    print("\nGradient statistics:")
    print(f"Mean gradient norm: {debug_grads.mean().item():.4f}")
    print(f"Min gradient norm: {debug_grads.min().item():.4f}")
    print(f"Max gradient norm: {debug_grads.max().item():.4f}")
    print(f"25th percentile: {torch.quantile(debug_grads.flatten(), 0.25).item():.4f}")
    print(f"50th percentile: {torch.quantile(debug_grads.flatten(), 0.5).item():.4f}")
    print(f"75th percentile: {torch.quantile(debug_grads.flatten(), 0.75).item():.4f}")
except Exception as e:
    print(f"Error collecting metrics: {e}")

In [11]:
# Enhanced Entropy Analysis
# Run this after collecting the metrics to better understand the entropy issues

# Function to compute improved entropy with diagnostics
def compute_improved_entropy(attn_probs, eps=1e-8, debug=True):
    """Compute entropy with better numerical stability and detailed diagnostics."""
    if debug:
        # Print raw attention stats
        print(f"Raw attention shape: {attn_probs.shape}")
        print(f"Raw min/max/mean: {attn_probs.min().item():.6e}/{attn_probs.max().item():.6e}/{attn_probs.mean().item():.6e}")
        
        # Check for numerical issues
        print(f"Contains zeros: {(attn_probs == 0).any().item()}")
        print(f"Contains NaN: {torch.isnan(attn_probs).any().item()}")
        print(f"Contains Inf: {torch.isinf(attn_probs).any().item()}")
        
        # Check distribution validity
        row_sums = attn_probs.sum(dim=-1)
        print(f"Row sums min/max/mean: {row_sums.min().item():.6f}/{row_sums.max().item():.6f}/{row_sums.mean().item():.6f}")
        print(f"Rows sum to ~1: {torch.allclose(row_sums, torch.ones_like(row_sums), rtol=1e-2)}")
    
    # Apply numerical safeguards
    # 1. Ensure positive values
    attn_probs = attn_probs.clamp(min=eps)
    
    # 2. Normalize to ensure it sums to 1.0 along attention dimension
    attn_probs = attn_probs / attn_probs.sum(dim=-1, keepdim=True)
    
    if debug:
        print("\nAfter preprocessing:")
        print(f"Min/max/mean: {attn_probs.min().item():.6e}/{attn_probs.max().item():.6e}/{attn_probs.mean().item():.6e}")
        row_sums = attn_probs.sum(dim=-1)
        print(f"Row sums min/max/mean: {row_sums.min().item():.6f}/{row_sums.max().item():.6f}/{row_sums.mean().item():.6f}")
    
    # Compute entropy: -sum(p * log(p))
    log_probs = torch.log(attn_probs)
    entropy = -torch.sum(attn_probs * log_probs, dim=-1)
    
    if debug:
        print("\nEntropy results:")
        print(f"Entropy shape: {entropy.shape}")
        print(f"Entropy min/max/mean: {entropy.min().item():.4f}/{entropy.max().item():.4f}/{entropy.mean().item():.4f}")
        
        # Compute theoretical maximum entropy (uniform distribution)
        seq_len = attn_probs.size(-1)
        max_entropy = torch.log(torch.tensor(seq_len, dtype=torch.float))
        print(f"Theoretical max entropy (log(seq_len)): {max_entropy.item():.4f}")
        
        # Check if entropy is at maximum (uniform attention)
        print(f"Percentage of maximum entropy: {entropy.mean().item()/max_entropy.item()*100:.2f}%")
    
    return entropy

# Get the raw attention patterns from the model for analysis
try:
    # Get a batch of data
    inputs = next(iter(validation_dataloader))
    if isinstance(inputs, dict):
        inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
    else:
        inputs = {"input_ids": inputs[0].to(device)}
    
    # Run model with attention outputs
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    
    # Extract attention patterns
    if hasattr(outputs, 'attentions') and outputs.attentions is not None:
        attn_list = outputs.attentions
        if len(attn_list) > 0:
            # Create a detailed visualization of attention patterns and entropy
            num_layers = len(attn_list)
            fig, axes = plt.subplots(num_layers, 2, figsize=(12, num_layers*3))
            
            layer_entropies = []
            layer_entropies_norm = []
            
            for layer_idx in range(num_layers):
                attn = attn_list[layer_idx]
                
                # Compute entropy for this layer's attention
                print(f"\n=== Analyzing Layer {layer_idx} Attention ===")
                layer_entropy = compute_improved_entropy(attn, debug=True)
                
                # Save mean entropy per head
                head_entropies = layer_entropy.mean(dim=(0, 2))  # Entropy is [batch, heads, seq]: average over batch and sequence
                layer_entropies.append(head_entropies)
                
                # Normalize by max possible entropy
                seq_len = attn.size(-1)
                max_entropy = torch.log(torch.tensor(seq_len, dtype=torch.float, device=attn.device))
                norm_entropies = head_entropies / max_entropy.item()
                layer_entropies_norm.append(norm_entropies)
                
                # Plot attention pattern for first head
                if isinstance(axes, np.ndarray) and len(axes.shape) > 1:  # multiple rows and cols
                    ax1 = axes[layer_idx, 0]
                    ax2 = axes[layer_idx, 1]
                else:  # only 1 layer, so axes is 1D
                    ax1 = axes[0]
                    ax2 = axes[1]
                
                # Plot attention pattern
                attn_pattern = attn[0, 0].cpu().numpy()  # First batch, first head
                im = ax1.imshow(attn_pattern, cmap='viridis')
                ax1.set_title(f'Layer {layer_idx} - Head 0 Attention')
                ax1.set_xlabel('Position (To)')
                ax1.set_ylabel('Position (From)')
                plt.colorbar(im, ax=ax1)
                # Set proper limits for attention values (0 to 1)
                im.set_clim(0, 1.0)
                
                # Plot entropy values for all heads
                ax2.bar(range(len(head_entropies)), head_entropies.cpu().numpy())
                ax2.axhline(y=max_entropy.item(), color='r', linestyle='--', alpha=0.7, label='Max Entropy')
                ax2.set_title(f'Layer {layer_idx} - Head Entropies')
                ax2.set_xlabel('Head Index')
                ax2.set_ylabel('Entropy')
                ax2.legend()
                
                # Add entropy values as text on the bars
                for i, v in enumerate(head_entropies):
                    ax2.text(i, v.item() + 0.1, f'{v.item():.2f}', ha='center')
            
            plt.tight_layout()
            plt.show()
            
            # Create a heatmap of entropy across all layers and heads
            if num_layers > 1:
                all_entropies = torch.stack(layer_entropies).cpu().numpy()
                plt.figure(figsize=(10, 6))
                plt.imshow(all_entropies, cmap='viridis', aspect='auto')  # all_entropies is already a NumPy array
                plt.clim(0, max(0.1, all_entropies.max()))  # Ensure non-zero range
                plt.colorbar(label='Entropy')
                plt.title('Entropy Heatmap Across All Layers and Heads')
                plt.xlabel('Head Index')
                plt.ylabel('Layer Index')
                
                # Add text annotations for each cell
                for i in range(all_entropies.shape[0]):
                    for j in range(all_entropies.shape[1]):
                        text = plt.text(j, i, f'{all_entropies[i, j]:.2f}',
                                      ha="center", va="center", color="w")
                
                plt.tight_layout()
                plt.show()
                
                # Plot normalized entropy (as percentage of maximum)
                all_norm_entropies = torch.stack(layer_entropies_norm).cpu().numpy() * 100  # as percentage
                plt.figure(figsize=(10, 6))
                plt.imshow(all_norm_entropies, cmap='viridis', aspect='auto', vmin=0, vmax=100)
                plt.colorbar(label='% of Max Entropy')
                plt.title('Normalized Entropy (% of Maximum)')
                plt.xlabel('Head Index')
                plt.ylabel('Layer Index')
                
                # Add text annotations for each cell
                for i in range(all_norm_entropies.shape[0]):
                    for j in range(all_norm_entropies.shape[1]):
                        text = plt.text(j, i, f'{all_norm_entropies[i, j]:.1f}%',
                                      ha="center", va="center", color="w")
                
                plt.tight_layout()
                plt.show()
        else:
            print("No attention tensors returned by the model")
    else:
        print("Model outputs don't include attention weights")
except Exception as e:
    print(f"Error in entropy analysis: {e}")
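
The normalization used in the entropy plots above can be checked on synthetic data. The following is a minimal sketch, not the notebook's own implementation: `normalized_head_entropy` is a hypothetical helper, and the two toy heads (one uniform, one one-hot) are assumptions chosen so the expected values are known. A uniform head should score close to 1.0 (100% of maximum entropy) and a fully focused head close to 0.

```python
import math
import torch

def normalized_head_entropy(attn, eps=1e-8):
    """Mean attention entropy per head, scaled to [0, 1] by log(key_len).

    attn: tensor of shape [batch, heads, query_len, key_len] whose rows sum to ~1.
    (Hypothetical helper for illustration only.)
    """
    p = attn.clamp(min=eps)
    p = p / p.sum(dim=-1, keepdim=True)        # re-normalize after clamping
    entropy = -(p * p.log()).sum(dim=-1)       # [batch, heads, query_len]
    per_head = entropy.mean(dim=(0, 2))        # average over batch and query positions
    return per_head / math.log(attn.size(-1))  # divide by the maximum possible entropy

seq_len = 8
uniform = torch.full((1, 1, seq_len, seq_len), 1.0 / seq_len)  # every key weighted equally
focused = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)   # each query attends to one key
attn = torch.cat([uniform, focused], dim=1)                    # head 0 uniform, head 1 one-hot

scores = normalized_head_entropy(attn)
print(scores)  # head 0 is near 1.0, head 1 is near 0.0
```

If every real head scores near 1.0 under this measure, the attention is close to uniform, which is the situation the diagnostics in the cell above are meant to expose.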

In [12]:
# Test our gradient-only pruning approach

# Make sure we have debug_grads
try:
    # Check if debug_grads is defined
    debug_grads
    print("Using existing debug_grads")
except NameError:
    print("debug_grads not found, collecting metrics...")
    # Try to collect metrics
    try:
        debug_entropy, debug_grads = controller.collect_head_metrics(
            validation_dataloader,
            num_batches=2
        )
    except Exception as e:
        print(f"Error collecting metrics: {e}")
        # Create a dummy tensor if everything fails
        print("Creating dummy debug_grads")
        debug_grads = torch.zeros(6, 12)  # Default size for distilgpt2 (6 layers, 12 heads)

pruning_mask = gradient_based_pruning(
    debug_grads,
    prune_percent=PRUNE_PERCENT
)

# Visualize pruning mask
plt.figure(figsize=(10, 6))
plt.imshow(pruning_mask.detach().cpu().float().numpy(), cmap='Reds', aspect='auto')
plt.colorbar(label='Prune')
plt.title(f'Pruning Mask (prune {PRUNE_PERCENT*100:.0f}% of heads)')
plt.xlabel('Head')
plt.ylabel('Layer')
plt.show()

print(f"\nPruning Analysis:")
pruned_count = pruning_mask.sum().item()
total_count = pruning_mask.numel()
print(f"Pruning {pruned_count}/{total_count} heads ({pruned_count/total_count*100:.1f}%)")
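
`gradient_based_pruning` is defined elsewhere in the notebook, so its exact behavior is not shown here. As a rough sketch of the idea, assuming it marks the fraction of heads with the smallest gradient norms, a stand-in could look like the following. `lowest_gradient_mask` and the toy `grad_norms` values are hypothetical names and data used only for illustration.

```python
import torch

def lowest_gradient_mask(grad_norms, prune_percent):
    """Boolean mask that is True for the prune_percent fraction of heads
    with the smallest gradient norms. (Hypothetical stand-in.)"""
    flat = grad_norms.reshape(-1)
    k = int(flat.numel() * prune_percent)      # number of heads to prune, rounded down
    mask = torch.zeros(flat.numel(), dtype=torch.bool, device=flat.device)
    if k > 0:
        # largest=False selects the k smallest values
        _, idx = torch.topk(flat, k=k, largest=False)
        mask[idx] = True
    return mask.reshape(grad_norms.shape)

# Toy gradient norms for 2 layers x 3 heads
grad_norms = torch.tensor([[0.90, 0.10, 0.50],
                           [0.05, 0.70, 0.30]])
mask = lowest_gradient_mask(grad_norms, prune_percent=0.34)
print(mask)  # True marks the two lowest-gradient heads (0.05 and 0.10)
```

Rounding the prune count down means a small `prune_percent` on a small model can select zero heads, which is why the sketch guards the `k > 0` case.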

In [13]:
# Create a visual comparing entropy and gradient distributions
plt.figure(figsize=(10, 6))

# Function to properly calculate entropy
def calculate_proper_entropy(attn_tensor, eps=1e-8):
    # Calculate entropy with proper normalization and numerical stability
    # Get attention shape
    batch_size, num_heads, seq_len, _ = attn_tensor.shape
    
    # Reshape for processing
    attn_flat = attn_tensor.view(batch_size * num_heads * seq_len, -1)
    
    # Handle numerical issues - ensure positive values and proper normalization
    attn_flat = attn_flat.clamp(min=eps)
    attn_flat = attn_flat / attn_flat.sum(dim=-1, keepdim=True)
    
    # Calculate entropy: -sum(p * log(p))
    entropy = -torch.sum(attn_flat * torch.log(attn_flat), dim=-1)
    
    # Reshape back to per-head format and average
    entropy = entropy.view(batch_size, num_heads, seq_len)
    entropy = entropy.mean(dim=(0, 2))  # Average over batch and sequence
    
    # Return raw (unnormalized) entropy in nats as a [1, num_heads] row;
    # the caller concatenates rows across layers into a layers x heads matrix
    return entropy.view(-1, num_heads)

# Get attention outputs to calculate entropy directly
try:
    # Get sample data
    inputs = next(iter(validation_dataloader))
    if isinstance(inputs, dict):
        inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
    else:
        inputs = {"input_ids": inputs[0].to(device)}
    
    # Forward pass with attention outputs
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    
    # Calculate entropy from raw attention
    if hasattr(outputs, 'attentions') and outputs.attentions is not None:
        # Extract attention tensors
        attentions = outputs.attentions
        
        # Print diagnostic info
        print(f"Number of attention layers: {len(attentions)}")
        first_attn = attentions[0]
        print(f"Attention shape: {first_attn.shape}")
        print(f"Attention statistics - min: {first_attn.min().item():.6f}, max: {first_attn.max().item():.6f}")
        
        # Check if attention sums to 1 along correct dimension
        attn_sum = first_attn.sum(dim=-1)
        print(f"Attention sum along last dim - min: {attn_sum.min().item():.6f}, max: {attn_sum.max().item():.6f}")
        
        # Calculate proper entropy for all layers
        all_entropies = torch.cat([calculate_proper_entropy(attn) for attn in attentions])
        
        # Print entropy statistics
        print(f"Calculated entropy shape: {all_entropies.shape}")
        print(f"Entropy statistics - min: {all_entropies.min().item():.6f}, max: {all_entropies.max().item():.6f}")
        
        # Side by side plots
        plt.subplot(1, 2, 1)
        im1 = plt.imshow(all_entropies.detach().cpu().numpy(), cmap='viridis', aspect='auto')
        plt.clim(0, max(0.1, all_entropies.max().item()))  # Ensure proper visualization range
        plt.colorbar(im1, label='Entropy')
        plt.title(f'Properly Calculated Attention Entropy (max={all_entropies.max().item():.4f})')
        plt.xlabel('Head Index')
        plt.ylabel('Layer Index')
        
        # Use the gradient tensor from debug metrics
        plt.subplot(1, 2, 2)
        im2 = plt.imshow(debug_grads.detach().cpu().numpy(), cmap='plasma', aspect='auto')
        plt.colorbar(im2, label='Gradient Norm')
        plt.title('Gradient Norms')
        plt.xlabel('Head Index')
        plt.ylabel('Layer Index')
        
        plt.tight_layout()
        plt.show()
        
        # Create a scatter plot to show relationship
        plt.figure(figsize=(8, 6))
        entropy_flat = all_entropies.flatten().cpu().numpy()
        grad_flat = debug_grads.flatten().cpu().numpy()
        
        plt.scatter(entropy_flat, grad_flat, alpha=0.7)
        plt.xlabel('Entropy (higher = less focused)')
        plt.ylabel('Gradient Norm (higher = more impact)')
        plt.title('Entropy vs Gradient Relationship')
        plt.grid(alpha=0.3)
        plt.show()
        
    else:
        print("Model did not return attention tensors")
except Exception as e:
    print(f"Error in entropy calculation: {e}")
    
    # Fallback - use debug_entropy that was already collected
    # Plot entropy with a manual scale to force visibility
    plt.subplot(1, 2, 1)
    
    # Enforce a minimum scale for visibility
    entropy_data = debug_entropy.detach().cpu().numpy()
    im1 = plt.imshow(entropy_data, cmap='viridis', aspect='auto')  # entropy_data is already a NumPy array
    plt.clim(0, max(0.1, entropy_data.max()))  # Ensure proper entropy range
    plt.colorbar(im1, label='Entropy')
    plt.title(f'Attention Entropy Values (max={entropy_data.max():.4f})')
    plt.xlabel('Head Index')
    plt.ylabel('Layer Index')

    # Gradient subplot
    plt.subplot(1, 2, 2)
    im2 = plt.imshow(debug_grads.detach().cpu().numpy(), cmap='plasma', aspect='auto')
    plt.colorbar(im2, label='Gradient Norm')
    plt.title('Gradient Norms')
    plt.xlabel('Head Index')
    plt.ylabel('Layer Index')
    
    plt.tight_layout()
    plt.show()

## Collect Initial Head Metrics

Let's look at the initial head metrics to establish our baseline.

In [14]:
# NOTE: This cell requires the controller to be defined
# Collect initial head metrics
# Check if controller exists
try:
    controller
except NameError:
    print("ERROR: The controller variable is not defined. Please run the cell that creates the plasticity controller first.")
    # Create a simple dummy controller to avoid breaking the notebook flow
    controller = SimpleNamespace()
    controller.collect_head_metrics = lambda *args, **kwargs: (None, None)
    controller.display_stats = lambda *args, **kwargs: None
    controller.stats = {}
    controller.total_layers = 0
    controller.heads_per_layer = 0


entropy_values, grad_norm_values = controller.collect_head_metrics(
    validation_dataloader, 
    num_batches=2
)

# Function to visualize gradients without relying on Unicode
def visualize_gradient_norms(grad_norm_values, pruned_heads=None, revived_heads=None, title="Gradient Norms", save_path=None):
    """Create a visualization of gradient norms with markers for pruned/revived heads"""
    plt.figure(figsize=(10, 5))
    plt.imshow(grad_norm_values.detach().cpu().numpy(), cmap="plasma", aspect="auto")
    plt.colorbar(label="Gradient Norm")
    
    # Mark pruned heads with 'P'
    if pruned_heads:
        for layer, head in pruned_heads:
            plt.text(head, layer, "P", ha="center", va="center", 
                     color="white", weight="bold", bbox=dict(facecolor='red', alpha=0.5))
    
    # Mark revived heads with 'R'
    if revived_heads:
        for layer, head in revived_heads:
            plt.text(head, layer, "R", ha="center", va="center", 
                     color="white", weight="bold", bbox=dict(facecolor='green', alpha=0.5))
    
    plt.title(title)
    plt.xlabel("Head Index")
    plt.ylabel("Layer Index")
    # Consider using constrained_layout=True instead of tight_layout()
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=100, bbox_inches='tight')
    
    return plt.gcf()

# Create a better pruning mask for visualization
# Get the indices of the heads with the LOWEST gradient norms
flat_grad_norm = grad_norm_values.view(-1)
total_heads = grad_norm_values.numel()
target_prune_count = int(total_heads * PRUNE_PERCENT)
_, indices = torch.topk(flat_grad_norm, k=target_prune_count, largest=False)
pruning_mask = torch.zeros_like(grad_norm_values, dtype=torch.bool)
pruning_mask.view(-1)[indices] = True

# Create a comprehensive visualization showing gradient norms with markers for pruning decisions
plt.figure(figsize=(10, 5))
plt.title("Initial Head Gradient Norms - Pruning Candidates")
plt.imshow(grad_norm_values.detach().cpu().numpy(), cmap="plasma", aspect="auto")
plt.colorbar(label="Gradient Norm")

# Add text markers for the heads with LOWEST gradient norms (candidates for pruning)
for layer in range(controller.total_layers):
    for head in range(controller.heads_per_layer):
        if pruning_mask[layer, head]:  # True means prune this head (lowest gradients)
            plt.text(head, layer, "P", ha="center", va="center", 
                     color="white", weight="bold", bbox=dict(facecolor='red', alpha=0.5))

plt.xlabel("Head Index")
plt.ylabel("Layer Index")
# Consider using constrained_layout=True instead of tight_layout()
plt.tight_layout()
plt.show()

# Now add the new visualization that combines gradient norms with pruning status
print("\nInitial Head Gradient Norms with Pruning Candidates:")

# Create a list of (layer, head) tuples for heads marked for pruning
pruning_candidates = [(layer, head) for layer in range(controller.total_layers) 
                      for head in range(controller.heads_per_layer) 
                      if pruning_mask[layer, head]]  # True means prune (low gradient)

visualize_gradient_norms(
    grad_norm_values=grad_norm_values,
    pruned_heads=pruning_candidates,  # Mark candidates as if they were pruned
    title="Initial Head Gradient Norms with Pruning Candidates"
)
plt.show()

# Also plot standard visualizations for comparison
# Plot entropy heatmap
plt.figure(figsize=(10, 6))
plt.title("Initial Head Entropy (higher = less focused attention)")
entropy_map = plt.imshow(entropy_values.detach().cpu().numpy(), cmap="viridis", aspect="auto")
# Ensure entropy visualization has some range
plt.clim(0, max(0.1, entropy_values.max().item()))
plt.colorbar(entropy_map, label="Entropy")
plt.xlabel("Head Index")
plt.ylabel("Layer Index")
# Consider using constrained_layout=True instead of tight_layout()
plt.tight_layout()
plt.show()

# Create a visualization highlighting the relationship between gradient norms and pruning decisions
plt.figure(figsize=(12, 8))
grad_data = grad_norm_values.detach().cpu().numpy()
mask_data = pruning_mask.detach().cpu().numpy()

# Create a masked array where pruned heads are highlighted
masked_grads = np.ma.array(grad_data, mask=~mask_data)

# Base plot with all gradient values
plt.imshow(grad_data, aspect='auto')
# Overlay plot with pruned heads highlighted
plt.imshow(masked_grads, cmap='Reds', aspect='auto')
plt.colorbar(label='Gradient Norm')
plt.title('Gradient Norms with Low-Gradient Heads Highlighted for Pruning')
plt.xlabel('Head Index')
plt.ylabel('Layer Index')
# Consider using constrained_layout=True instead of tight_layout()
plt.tight_layout()
plt.show()


## Training with Neural Plasticity

Now let's train the model with neural plasticity enabled, allowing it to adaptively prune and restore attention heads.
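
To make "prune" and "restore" concrete, here is a minimal sketch under a stated assumption: that pruning a head means zeroing its output and restoring it means letting the output through again. The notebook's `prune_head_in_model` may work differently. `head_mask` and `apply_head_mask` are hypothetical names, and the tensor sizes are toy values.

```python
import torch

num_layers, num_heads, head_dim = 2, 4, 3
head_mask = torch.ones(num_layers, num_heads)  # 1 = active head, 0 = pruned head

def apply_head_mask(head_outputs, layer_idx):
    """Scale each head's output by its mask entry.

    head_outputs: [batch, heads, seq, head_dim]. (Hypothetical helper.)
    """
    return head_outputs * head_mask[layer_idx].view(1, -1, 1, 1)

head_outputs = torch.ones(1, num_heads, 5, head_dim)

# Prune layer 0, head 2: its output is zeroed, the other heads are untouched
head_mask[0, 2] = 0.0
pruned = apply_head_mask(head_outputs, layer_idx=0)
sparsity = 1.0 - head_mask.mean().item()       # fraction of pruned heads: 1 of 8

# Restore the same head: its output passes through again
head_mask[0, 2] = 1.0
restored = apply_head_mask(head_outputs, layer_idx=0)

print(f"Sparsity while pruned: {sparsity:.3f}")
```

A mask of this kind keeps the weights in place, so a pruned head can be revived without re-initialization. Whether the controller does this depends on its `mode` setting, which is not shown in this section.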

In [15]:
# Initialize training components
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_dataloader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=WARMUP_STEPS, 
    num_training_steps=total_steps
)

# Print epoch information
print(f"Dataset size: {len(train_dataset)} examples")
print(f"Batch size: {BATCH_SIZE}")
print(f"Steps per epoch: {len(train_dataloader)}")
print(f"Total epochs: {NUM_EPOCHS}")
print(f"Maximum steps per epoch: {MAX_STEPS_PER_EPOCH if MAX_STEPS_PER_EPOCH else 'Unlimited'}")
print(f"Expected total steps: {NUM_EPOCHS * (MAX_STEPS_PER_EPOCH or len(train_dataloader))}")
print(f"Eval interval: {EVAL_INTERVAL} steps")
print(f"Visualization interval: {VISUALIZATION_INTERVAL} steps")
print(f"Inference interval: {INFERENCE_INTERVAL} steps")



In [16]:
# Initialize metrics tracking
metrics_history = {
    "train_loss": [],
    "eval_loss": [],
    "pruned_heads": [],
    "revived_heads": [],
    "sparsity": [],
    "step": [],
    "epoch": [],  # Track epoch number for each step
    "perplexity": [],  # Track perplexity
    "inference_samples": []  # Store sample generations
}

# Import visualization utilities from utils.colab
from utils.colab.visualizations import TrainingMonitor, visualize_gradient_norms

# Create pruning monitor widget
pruning_monitor = TrainingMonitor(
    title="Neural Plasticity Training Progress",
    metrics_to_track=["step", "epoch", "train_loss", "eval_loss", 
                     "pruned_heads", "revived_heads", "sparsity", "perplexity"]
)


# Create output directory for visualizations and checkpoints
import os
output_dir = "pruning_visualizations"
checkpoint_dir = os.path.join(output_dir, "checkpoints")
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# Inference prompts for consistent tracking
inference_prompts = [
    "Once upon a time",
    "The future of artificial intelligence",
    "In a distant galaxy",
    "Scientists recently discovered"
]



In [17]:
# NOTE: This cell requires the controller to be defined
# Custom function to apply pruning based purely on gradients - CORRECTED VERSION
# Check if controller exists
try:
    controller
except NameError:
    print("ERROR: The controller variable is not defined. Please run the cell that creates the plasticity controller first.")
    # Create a simple dummy controller to avoid breaking the notebook flow
    controller = SimpleNamespace()
    controller.collect_head_metrics = lambda *args, **kwargs: (None, None)
    controller.display_stats = lambda *args, **kwargs: None
    controller.stats = {}
    controller.total_layers = 0
    controller.heads_per_layer = 0


def apply_gradient_pruning(grad_norm_values):
    """Apply gradient-based pruning targeting heads with lowest gradient norms."""
    # Get pruning decisions
    # Get the indices of the heads with the LOWEST gradient norms
    flat_grad_norm = grad_norm_values.view(-1)
    total_heads = grad_norm_values.numel()
    target_prune_count = int(total_heads * PRUNE_PERCENT)
    
    # Key fix: use largest=False to get lowest gradient heads
    _, indices = torch.topk(flat_grad_norm, k=target_prune_count, largest=False)
    
    # Create pruning mask - True means "prune this head"
    pruning_mask = torch.zeros_like(grad_norm_values, dtype=torch.bool)
    pruning_mask.view(-1)[indices] = True
    
    # Convert to list of (layer, head) tuples for pruning
    pruned_heads = []
    for layer in range(controller.total_layers):
        for head in range(controller.heads_per_layer):
            if pruning_mask[layer, head]:  # True means prune (low gradient)
                # Check if head is already pruned
                if not controller.stats[layer][head]['is_zeroed']:
                    pruned_heads.append((layer, head))
    
    # Apply pruning
    for layer, head in pruned_heads:
        result = prune_head_in_model(
            controller.model, 
            layer, 
            head, 
            mode=controller.mode, 
            verbose=False  # Reduce verbosity
        )
        if result:
            # Update controller stats
            controller.stats[layer][head]['is_zeroed'] = True
            controller.stats[layer][head]['zeroed_epochs'] = 1
    
    # Update controller hooks
    controller._update_pruning_hooks(verbose=False)  # Reduce verbosity
    
    # Print stats about the pruned and kept heads - but only if we actually pruned something
    if pruned_heads:
        print(f"Pruned {len(pruned_heads)} heads with lowest gradient norms")
        # Only show detailed metrics at the start
        if not hasattr(apply_gradient_pruning, "has_pruned_before"):
            avg_pruned = grad_norm_values[pruning_mask].mean().item()
            avg_kept = grad_norm_values[~pruning_mask].mean().item()
            print(f"Average gradient of pruned heads: {avg_pruned:.6f}")
            print(f"Average gradient of kept heads: {avg_kept:.6f}")
            if avg_pruned > 0:  # Guard against division by zero when pruned heads have zero gradient
                print(f"Ratio (kept/pruned): {avg_kept/avg_pruned:.2f}x")
            # Set flag to avoid showing these details every time
            apply_gradient_pruning.has_pruned_before = True
    
    return pruned_heads
# Import visualization utilities
from utils.pruning.visualization_additions import (
    visualize_gradient_norms,
    visualize_attention_matrix,
    visualize_entropy_heatmap,
    visualize_normalized_entropy,
    visualize_entropy_vs_gradient,
    visualize_training_progress
)

In [18]:
# Convert stats dict to regular dict for serialization
def convert_stats_for_checkpoint(stats_dict):
    """Convert defaultdict to regular dict for pickle serialization."""
    regular_dict = {}
    for layer, heads in stats_dict.items():
        regular_dict[layer] = {}
        for head, values in heads.items():
            regular_dict[layer][head] = dict(values)
    return regular_dict
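
The conversion matters when the stats container is a nested `defaultdict` whose default factory is a lambda, because the standard `pickle` module cannot serialize lambdas. The sketch below assumes that `controller.stats` has this layer -> head -> values shape. The real structure is defined by the controller and may differ, and the `stats` object here is toy data.

```python
import pickle
from collections import defaultdict

def convert_stats_for_checkpoint(stats_dict):
    """Convert nested defaultdicts to plain dicts (same logic as the cell above)."""
    regular_dict = {}
    for layer, heads in stats_dict.items():
        regular_dict[layer] = {}
        for head, values in heads.items():
            regular_dict[layer][head] = dict(values)
    return regular_dict

# Assumed shape of the stats: layer -> head -> per-head values
stats = defaultdict(
    lambda: defaultdict(lambda: {"is_zeroed": False, "zeroed_epochs": 0})
)
stats[0][1]["is_zeroed"] = True  # touching an entry creates it with the defaults

plain = convert_stats_for_checkpoint(stats)

# The plain dict survives a pickle round trip unchanged
restored = pickle.loads(pickle.dumps(plain))
print(restored)
```

Only entries that were actually touched exist in the `defaultdict`, so the converted dictionary contains just those layers and heads.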



In [19]:
# NOTE: This cell requires the controller to be defined
# Function to save checkpoint
# Check if controller exists
try:
    controller
except NameError:
    print("ERROR: The controller variable is not defined. Please run the cell that creates the plasticity controller first.")
    # Create a simple dummy controller to avoid breaking the notebook flow
    controller = SimpleNamespace()
    controller.collect_head_metrics = lambda *args, **kwargs: (None, None)
    controller.display_stats = lambda *args, **kwargs: None
    controller.stats = {}
    controller.total_layers = 0
    controller.heads_per_layer = 0


def save_checkpoint(step, epoch):
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_step_{step}.pt")
    # Convert stats_dict to regular dict to avoid pickle issues
    stats_dict = convert_stats_for_checkpoint(controller.stats)
    torch.save({
        'step': step,
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'controller_stats': stats_dict,
        'metrics_history': metrics_history
    }, checkpoint_path)
    print(f"  Checkpoint saved at step {step} (epoch {epoch})")
    return checkpoint_path



In [20]:
# Function to run inference
def run_model_inference():
    model.eval()
    inference_results = {}
    
    for prompt in inference_prompts:
        generated_text = generate_text(prompt, max_length=50)  # Keep it shorter for quick visualization
        inference_results[prompt] = generated_text
        
    print("\n=== Sample Generations ===")
    for prompt, text in inference_results.items():
        print(f"Prompt: {prompt}")
        print(f"Generated: {text[:100]}...")  # Truncate for display
        print("-" * 40)
    
    return inference_results



In [21]:
# Training loop
global_step = 0



In [22]:
# NOTE: This cell requires the controller to be defined
# Check if controller exists
try:
    controller
except NameError:
    print("ERROR: The controller variable is not defined. Please run the cell that creates the plasticity controller first.")
    # Create a simple dummy controller to avoid breaking the notebook flow
    controller = SimpleNamespace()
    controller.collect_head_metrics = lambda *args, **kwargs: (None, None)
    controller.display_stats = lambda *args, **kwargs: None
    controller.stats = {}
    controller.total_layers = 0
    controller.heads_per_layer = 0

# Add memory management utilities
import gc

def clear_memory():
    '''Clear GPU memory cache and run garbage collection'''
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

# Render training metrics with matplotlib directly

def plot_training_metrics(metrics_history):
    steps = metrics_history["step"]
    fig, axs = plt.subplots(3, 1, figsize=(10, 8), sharex=True)

    axs[0].plot(steps, metrics_history["train_loss"], label="Train Loss")
    axs[0].plot(steps, metrics_history["eval_loss"], label="Eval Loss")
    axs[0].set_ylabel("Loss")
    axs[0].legend()

    axs[1].plot(steps, metrics_history["perplexity"], label="Perplexity")
    axs[1].set_ylabel("Perplexity")
    axs[1].legend()

    axs[2].plot(steps, metrics_history["total_pruned"], label="Total Pruned")
    axs[2].plot(steps, metrics_history["sparsity"], label="Sparsity")
    axs[2].set_xlabel("Steps")
    axs[2].legend()

    plt.tight_layout()
    plt.show()


# Initialize metric tracking dictionary
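metrics_history = {
    "step": [], "epoch": [], "train_loss": [], "eval_loss": [],
    "perplexity": [], "pruned_heads": [], "revived_heads": [],
    "sparsity": [], "total_pruned": [],
    "inference_samples": []  # Kept from the earlier initialization so sample generations can still be stored
}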

# Removed TrainingMonitor widget
# Using pruning_monitor already created above

try:
    # Track previous state to reduce logging
    last_total_pruned = 0
    last_visualization_step = 0

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
        model.train()
        
        epoch_loss = 0.0
        epoch_steps = 0
        
        # For each batch in the dataloader
        for step, batch in enumerate(train_dataloader):
            # Check if we've reached MAX_STEPS_PER_EPOCH for this epoch
            if MAX_STEPS_PER_EPOCH is not None and step >= MAX_STEPS_PER_EPOCH:
                print(f"  Reached maximum steps per epoch ({MAX_STEPS_PER_EPOCH}). Moving to next epoch.")
                break
                
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            
            # Forward pass
            outputs = model(**batch)
            loss = outputs.loss
            
            # Backward pass
            loss.backward()
            
            # Update weights
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            # Track loss
            epoch_loss += loss.item()
            epoch_steps += 1
            global_step += 1
            
            # Periodically evaluate
            if global_step % EVAL_INTERVAL == 0:
                # Evaluate
                model.eval()
                eval_loss, eval_perplexity = evaluate_model(model, validation_dataloader)
                
                # Collect metrics - we only need gradient norms
                _, grad_norm_values = controller.collect_head_metrics(
                    validation_dataloader, 
                    num_batches=2
                )
                
                # Apply gradient-based pruning
                pruned_heads = apply_gradient_pruning(grad_norm_values)
                
                # In this simplified version, we don't revive heads
                revived_heads = []
                
                # Get model info
                model_info = get_model_info(model)
                total_pruned = controller._count_pruned_heads()
                
                # Update metrics
                metrics_history["train_loss"].append(epoch_loss / epoch_steps)
                metrics_history["eval_loss"].append(eval_loss)
                metrics_history["pruned_heads"].append(len(pruned_heads))
                metrics_history["revived_heads"].append(len(revived_heads))
                metrics_history["sparsity"].append(model_info["sparsity"])
                metrics_history["step"].append(global_step)
                metrics_history["epoch"].append(epoch + 1)
                metrics_history["perplexity"].append(eval_perplexity)
                
                # Only print pruning details if something changed
                if total_pruned != last_total_pruned:
                    pruning_info = f"{len(pruned_heads)} new heads, {total_pruned} total ({model_info['sparsity']:.2%} sparsity)"
                    last_total_pruned = total_pruned
                else:
                    pruning_info = f"No new heads pruned. Total: {total_pruned} ({model_info['sparsity']:.2%} sparsity)"
                
                # Get the current pruned heads from controller stats
                all_pruned_heads = []
                for layer in range(controller.total_layers):
                    for head in range(controller.heads_per_layer):
                        if controller.stats[layer][head]['is_zeroed']:
                            all_pruned_heads.append((layer, head))
                
                # Create metrics dictionary for the monitor
                current_metrics = {
                    "step": global_step,
                    "epoch": epoch + 1,
                    "train_loss": epoch_loss / epoch_steps if epoch_steps > 0 else 0,
                    "eval_loss": eval_loss,
                    "perplexity": eval_perplexity,
                    "new_pruned": len(pruned_heads),
                    "total_pruned": total_pruned,
                    "sparsity": model_info["sparsity"]
                }
                
                # Create a figure for gradient norms if we have pruned heads
                def create_gradient_fig():
                    return visualize_gradient_norms(
                        grad_norm_values=grad_norm_values,
                        pruned_heads=all_pruned_heads,
                        revived_heads=revived_heads,
                        title=f"Head Gradient Norms with Pruning Status (Step {global_step}, Epoch {epoch+1})",
                    )
                
                # Update the persistent visualization
                # Determine if we should show the graph based on whether anything changed
                if pruned_heads or (global_step - last_visualization_step >= VISUALIZATION_INTERVAL * 5):
                    # Update with both metrics and figure
                    last_visualization_step = global_step
                    
                    # Only show gradient figure if we have pruned heads
                    if all_pruned_heads:
                        pruning_monitor.update_metrics(
                            current_metrics, 
                            step=global_step, 
                            epoch=epoch + 1,
                            plot=False  # Don't auto-plot, we'll show our custom figure
                        )
                        # Add our custom gradient figure below the metrics
                        pruning_monitor.update_with_figure(
                            create_gradient_fig,
                            caption=f"Pruning Status: {len(pruned_heads)} new heads pruned, {total_pruned} total pruned",
                            clear=False  # Don't clear since we just displayed metrics
                        )
                    else:
                        # Just show metrics without figure if no pruning has happened
                        pruning_monitor.update_metrics(
                            current_metrics, 
                            step=global_step, 
                            epoch=epoch + 1
                        )
                else:
                    # Just update metrics without visualization
                    pruning_monitor.update_metrics(
                        current_metrics, 
                        step=global_step, 
                        epoch=epoch + 1,
                        plot=False  # Simple update without plots
                    )
                
                # Print minimal status to console
                print(f"  Step {global_step} (Epoch {epoch+1}) - Train loss: {epoch_loss / epoch_steps:.4f}, "
                      f"Eval loss: {eval_loss:.4f}, Perplexity: {eval_perplexity:.2f}")
                print(f"  Pruning: {pruning_info}")
                
                # Run model inference at regular intervals
                if global_step % INFERENCE_INTERVAL == 0:
                    inference_results = run_model_inference()
                    metrics_history["inference_samples"].append({
                        "step": global_step,
                        "epoch": epoch + 1,
                        "results": inference_results
                    })
                
                # Save checkpoint at regular intervals
                if global_step % CHECKPOINT_INTERVAL == 0:
                    save_checkpoint(global_step, epoch + 1)
                
                # Reset for next interval
                epoch_loss = 0.0
                epoch_steps = 0
                
                # Back to training mode; clear any gradients left over from metric collection
                optimizer.zero_grad()
                model.train()
            
            # Just update progress in persistent display occasionally without full metrics
            elif global_step % 50 == 0:
                # Simple progress update
                pruning_monitor.update(
                    f"""
                    <p><b>Progress Update:</b> Step {global_step} (Epoch {epoch+1})</p>
                    <p>Current loss: {epoch_loss / epoch_steps:.4f}</p>
                    <p>Full metrics will be displayed at next evaluation step.</p>
                    """,
                    notify=False
                )
                print(f"  Progress: Step {global_step} (Epoch {epoch+1})")
        
        print(f"Completed Epoch {epoch+1} - Total steps: {global_step}")
        # Clear memory at the end of each epoch
        clear_memory()
    
    # Save final checkpoint
    final_checkpoint_path = save_checkpoint(global_step, epoch + 1)
    print(f"Training completed! Final checkpoint saved at {final_checkpoint_path}")
    
# Add more specific error handling for common issues
except (MemoryError, RuntimeError) as e:
    print(f"\nMemory or Runtime error: {e}")
    print("Attempting to recover and save checkpoint...")
    # Force cleanup
    clear_memory()
    try:
        recovery_checkpoint_path = save_checkpoint(global_step, epoch + 1)
        print(f"Recovery checkpoint saved at {recovery_checkpoint_path}")
    except Exception as save_error:
        print(f"Could not save checkpoint during recovery: {save_error}")
except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
    # Save checkpoint on interrupt
    interrupt_checkpoint_path = save_checkpoint(global_step, epoch + 1)
    print(f"Checkpoint saved at {interrupt_checkpoint_path}")
except Exception as e:
    print(f"\nTraining error: {e}")
    # Try to save checkpoint on error
    try:
        error_checkpoint_path = save_checkpoint(global_step, epoch + 1)
        print(f"Checkpoint saved at {error_checkpoint_path}")
    except Exception as save_error:
        print(f"Could not save checkpoint: {save_error}")
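
The training loop calls `apply_gradient_pruning`, which is defined earlier in the notebook and is not reproduced in this section. As a rough illustration of the idea only, the sketch below picks the heads with the smallest gradient norms from a `[layer][head]` grid. The function name `select_heads_to_prune`, the `prune_fraction` parameter, and the sample values are hypothetical and are not taken from the notebook's implementation.

```python
def select_heads_to_prune(grad_norms, prune_fraction=0.25, already_pruned=()):
    """Return (layer, head) pairs with the lowest gradient norms.

    grad_norms: nested list where grad_norms[layer][head] is a float
    prune_fraction: share of all heads to target (hypothetical knob)
    already_pruned: (layer, head) pairs to skip
    """
    skip = set(already_pruned)
    # Flatten the grid into (norm, layer, head) tuples, skipping pruned heads
    candidates = [
        (norm, layer, head)
        for layer, row in enumerate(grad_norms)
        for head, norm in enumerate(row)
        if (layer, head) not in skip
    ]
    total_heads = sum(len(row) for row in grad_norms)
    target = int(total_heads * prune_fraction)
    candidates.sort()  # ascending by gradient norm
    return [(layer, head) for _, layer, head in candidates[:target]]


# Illustrative values for a 2-layer, 4-head model (not real measurements)
example_norms = [
    [0.90, 0.05, 0.40, 0.70],
    [0.02, 0.60, 0.80, 0.30],
]
to_prune = select_heads_to_prune(example_norms, prune_fraction=0.25)
print(to_prune)
```

Passing the heads that are already zeroed as `already_pruned` keeps the selection from returning the same heads at every evaluation step.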


## Visualize Training Progress

Let's visualize the training history to see how neural plasticity affected the model.

In [23]:
# Update the pruning monitor with metrics
pruning_monitor.update_metrics(
    current_metrics,
    step=global_step,
    epoch=epoch + 1,
    plot=True
)
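
Besides the monitor display, the raw `metrics_history` dictionary can be inspected directly. The following is a minimal sketch, assuming it holds the parallel lists filled in by the training loop (`step`, `train_loss`, `eval_loss`, `perplexity`, `sparsity`). The helper `format_history` and the sample numbers are made up for illustration.

```python
def format_history(history, keys=("step", "train_loss", "eval_loss", "perplexity", "sparsity")):
    """Render parallel metric lists as a fixed-width text table."""
    lines = [" ".join(f"{key:>12}" for key in keys)]
    # zip(*) walks the lists in lockstep, one row per evaluation step
    for values in zip(*(history[key] for key in keys)):
        cells = []
        for value in values:
            if isinstance(value, float):
                cells.append(f"{value:>12.4f}")
            else:
                cells.append(f"{value:>12}")
        lines.append(" ".join(cells))
    return "\n".join(lines)


# Illustrative values only, not results from this run
sample_history = {
    "step": [50, 100],
    "train_loss": [3.2, 2.9],
    "eval_loss": [3.3, 3.0],
    "perplexity": [27.1, 20.1],
    "sparsity": [0.0, 0.125],
}
table = format_history(sample_history)
print(table)
```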

## Evaluate Final Model

Let's evaluate the final model against the baseline and summarize the controller's pruning state.

In [24]:
# NOTE: This cell requires the controller to be defined
# Final evaluation
from types import SimpleNamespace

# Check if controller exists
try:
    controller
except NameError:
    print("ERROR: The controller variable is not defined. Please run the cell that creates the plasticity controller first.")
    # Create a simple dummy controller to avoid breaking the notebook flow
    controller = SimpleNamespace()
    controller.collect_head_metrics = lambda *args, **kwargs: (None, None)
    controller.display_stats = lambda *args, **kwargs: None
    # Placeholder summary so the report below still runs; all values are zero
    controller.get_summary = lambda: {
        "total_heads": 0,
        "pruned_heads": 0,
        "pruning_rate": 0.0,
        "sparsity": 0.0,
        "model_size_mb": 0.0,
    }
    controller.stats = {}
    controller.total_layers = 0
    controller.heads_per_layer = 0


final_loss, final_perplexity = evaluate_model(model, validation_dataloader)
print(f"Final evaluation: Loss = {final_loss:.4f}, Perplexity = {final_perplexity:.2f}")
print(f"Baseline:         Loss = {baseline_loss:.4f}, Perplexity = {baseline_perplexity:.2f}")
print(f"Improvement:      {((baseline_loss - final_loss) / baseline_loss * 100):.2f}%")

# Get final summary
summary = controller.get_summary()
print("\nFinal Controller Summary:")
print(f"  Total heads: {summary['total_heads']}")
print(f"  Pruned heads: {summary['pruned_heads']} ({summary['pruning_rate']:.2%})")
print(f"  Model sparsity: {summary['sparsity']:.4f}")
print(f"  Model size: {summary['model_size_mb']:.2f} MB")
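
The improvement figure printed above is the relative reduction in loss against the baseline. The definition of `evaluate_model` is not shown in this section, so the sketch below assumes it reports perplexity as the exponential of the mean cross-entropy loss, which is the usual convention for language models. The helper names and numbers are illustrative, not results from this run.

```python
import math


def relative_improvement(baseline, final):
    """Percentage reduction of `final` relative to `baseline`."""
    return (baseline - final) / baseline * 100


def perplexity_from_loss(mean_loss):
    """Perplexity under the assumed convention: exp(mean cross-entropy)."""
    return math.exp(mean_loss)


# Illustrative losses only
example_baseline_loss = 4.0
example_final_loss = 3.5
improvement = relative_improvement(example_baseline_loss, example_final_loss)
print(f"Improvement: {improvement:.2f}%")
print(f"Perplexity:  {perplexity_from_loss(example_final_loss):.2f}")
```

A negative improvement means the final loss is higher than the baseline, which can happen when pruning removes heads the model still relied on.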

## Generate Text with Final Model

Let's generate text with our plasticity-enhanced model to see the results.

In [25]:
# Generate text with final model
final_text = generate_text(prompt)

print("Baseline Model Output:")
print(baseline_text)
print("\nPlasticity-Optimized Model Output:")
print(final_text)

## Save Final Model

Let's save the trained model and tokenizer to a timestamped output directory.

In [26]:
# Create output directory
from datetime import datetime

output_dir = os.path.join("output", "plasticity", f"run_{datetime.now().strftime('%Y%m%d-%H%M%S')}")
os.makedirs(output_dir, exist_ok=True)

# Save model and tokenizer
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"Model saved to {output_dir}")

## Try Different Prompts

Let's try generating text with different prompts to see how the model performs.

In [27]:
prompts = [
    "The meaning of life is",
    "In a distant galaxy",
    "The future of AI will be",
    "Scientists recently discovered"
]

for prompt in prompts:
    print(f"Prompt: {prompt}")
    generated = generate_text(prompt)
    print(f"Generated: {generated}\n")

# Conclusion

In this notebook, we demonstrated Sentinel AI's neural plasticity system, which enables transformer models to dynamically prune and revive attention heads during training based on their utility.

Key findings:
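1. The full plasticity system targets high-entropy, low-gradient heads; the simplified training loop in this notebook pruned on gradient norms alone
2. Head revival is part of the plasticity system, but it was disabled in this simplified loop, so no heads were revived here
3. The final evaluation compares loss and perplexity against the baseline, which shows how much quality is retained with fewer active heads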
4. The brain dynamics visualization shows how attention heads evolve over time

This approach mimics biological neural plasticity, where brains form efficient neural pathways by pruning unused connections and strengthening useful ones.
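
The entropy signal behind "high-entropy" heads can be made concrete. The following is a minimal sketch, assuming a head's attention for one query position is a probability distribution over key positions. `attention_entropy` is an illustrative helper, not part of the Sentinel AI API, and the two distributions are invented.

```python
import math


def attention_entropy(probs):
    """Shannon entropy (in nats) of one attention distribution."""
    # Zero-probability entries contribute nothing and are skipped to avoid log(0)
    return -sum(p * math.log(p) for p in probs if p > 0)


focused = [0.97, 0.01, 0.01, 0.01]   # attends almost entirely to one token
diffuse = [0.25, 0.25, 0.25, 0.25]   # spreads attention evenly

print(f"focused head entropy: {attention_entropy(focused):.3f}")
print(f"diffuse head entropy: {attention_entropy(diffuse):.3f}")
```

A uniform distribution over `n` positions gives the maximum entropy `log(n)`, so heads near that ceiling are the unfocused ones that a pruning criterion of this kind would flag.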