# Neural Plasticity Demo: Ultra-Minimal Version (v0.0.54 2025-04-19 20:54:11)

This notebook demonstrates Sentinel AI's neural plasticity system, which allows transformer models to dynamically prune and regrow attention heads during training based on utility metrics. [ID: ba34f047]

### Changes in v0.0.54:
- Fixed GPU tensor handling for visualizations
- Fixed redundant tensor conversion patterns
- Improved numerical stability in entropy calculations
- Created ultra-minimal version for local testing
- Added synthetic data to avoid dataset dependencies

## ⚠️ Ultra-Minimal Settings Version ⚠️
This is a streamlined version for local testing with:
- A small pretrained model (distilgpt2)
- Synthetic data generated on-the-fly
- Minimal training steps
- Core functionality only

For full features, use the standard version in Google Colab with GPU acceleration.

In [None]:
# Set environment variables for BLAS stability.
# Every BLAS / OpenMP backend is limited to a single thread. These variables
# are generally read when the numerical libraries are first loaded, which is
# why this cell runs before the cell that imports torch and numpy.
import os
os.environ.update({
    'OMP_NUM_THREADS': '1',
    'OPENBLAS_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
    'NUMEXPR_NUM_THREADS': '1',
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # only affects child processes, not the already-running kernel.
    'PYTHONHASHSEED': '0',
})

# Add tensor handling safety utilities
import gc
def clear_memory():
    """Run garbage collection and release cached CUDA memory.

    Safe on CPU-only machines: the CUDA calls are skipped when no GPU is
    available. Returns ``None``.
    """
    # Imported locally because this helper is defined in a cell that runs
    # before the cell importing torch at module level; without this, calling
    # it early raises NameError.
    import torch

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

In [None]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import math
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    get_linear_schedule_with_warmup
)

# Import neural plasticity modules
import sys
if os.getcwd() not in sys.path:
    # Make packages under the current working directory (presumably the repo
    # root, which holds ``utils``) importable; only append once.
    sys.path.append(os.getcwd())
from utils.neural_plasticity.core import (
    calculate_head_entropy,
    calculate_head_gradients,
    generate_pruning_mask,
    apply_pruning_mask,
    evaluate_model,
    detect_model_structure
)
from utils.neural_plasticity.visualization import (
    visualize_head_entropy,
    visualize_head_gradients,
    visualize_pruning_decisions
)
from utils.colab.helpers import safe_tensor_imshow

print("Neural plasticity imports successful")

In [None]:
# Configure experiment with minimal settings

# --- Model and sequence shape ---
MODEL_NAME = "distilgpt2"  # distilled GPT-2 checkpoint, small enough for CPU runs
MAX_LENGTH = 32            # tokens per sample after padding/truncation
BATCH_SIZE = 2             # samples per batch

# --- Optimisation schedule ---
NUM_EPOCHS = 1             # NOTE: the training loop below is bounded by MAX_STEPS instead
LEARNING_RATE = 5e-5
WARMUP_STEPS = 10          # linear warmup steps for the LR scheduler
EVAL_INTERVAL = 10         # run validation every N optimizer steps
MAX_STEPS = 20             # total optimizer steps

# --- Pruning ---
PRUNE_PERCENT = 0.1        # fraction of heads targeted for pruning (~10%)

# Prefer a CUDA device when torch can see one
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load model and tokenizer
print(f"Loading model: {MODEL_NAME}")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# GPT-2 family tokenizers may ship without a pad token; reuse EOS so that
# fixed-length padding in the data pipeline works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Layer / head counts, used later for shapes and fallbacks
num_layers, num_heads = detect_model_structure(model)
print(f"Model has {num_layers} layers and {num_heads} heads per layer")

In [None]:
# Create synthetic data for testing
def create_synthetic_data(num_samples=50, seq_length=MAX_LENGTH):
    """Build tokenized causal-LM samples from a small pool of fixed sentences.

    Each sample's sentence is drawn uniformly at random (via ``np.random``),
    tokenized with the module-level ``tokenizer``, and padded/truncated to
    exactly ``seq_length`` tokens.

    Args:
        num_samples: Number of samples to generate.
        seq_length: Fixed length of every sample after padding/truncation.

    Returns:
        A list of dicts, each holding 1-D tensors of length ``seq_length``
        under ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
        ``"labels"`` equals ``"input_ids"`` except at padding positions
        (``attention_mask == 0``), which are set to -100.
    """
    # Short but structured sentences
    sample_texts = [
        "The quick brown fox jumps over the lazy dog. This sentence contains many common letters.",
        "Machine learning models can be trained to recognize patterns in data and make predictions.",
        "Neural networks consist of layers of interconnected nodes that process information.",
        "Attention mechanisms allow models to focus on relevant parts of the input data.",
        "Transformers have revolutionized natural language processing with their ability to handle long-range dependencies."
    ]

    samples = []
    for _ in range(num_samples):
        text = sample_texts[np.random.randint(0, len(sample_texts))]
        encoding = tokenizer(text, max_length=seq_length, padding="max_length", truncation=True)

        input_ids = torch.tensor(encoding["input_ids"])
        attention_mask = torch.tensor(encoding["attention_mask"])
        # Causal LM: the inputs double as labels. Padding positions get -100,
        # the index ignored by the Hugging Face LM loss. Otherwise the padding
        # token (EOS when the tokenizer had no pad token) would dominate the
        # training loss and perplexity.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        samples.append({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        })

    return samples

# Create synthetic train and validation datasets
print("Creating synthetic datasets...")
train_data = create_synthetic_data(num_samples=50)
val_data = create_synthetic_data(num_samples=20)
print(f"Created {len(train_data)} training samples and {len(val_data)} validation samples")

# Create a simple data loader
class SimpleDataLoader:
    """Minimal batching iterator over a list of sample dicts.

    Every sample must be a dict with the same keys, each mapping to a tensor
    whose shape is the same across samples. A batch stacks those tensors along
    a new leading dimension. Iterating the loader resets it, reshuffling the
    sample order with ``np.random`` when ``shuffle`` is true. The last batch
    may be smaller than ``batch_size``, so a full pass yields exactly
    ``len(loader)`` batches.

    Note: ``__iter__`` returns the loader itself, so two simultaneous
    iterations over one instance share a single position.
    """

    def __init__(self, data, batch_size=None, shuffle=True):
        """
        Args:
            data: Sequence of sample dicts (see class docstring).
            batch_size: Samples per batch. Defaults to the notebook-level
                ``BATCH_SIZE``.
            shuffle: Whether to reshuffle the sample order on every reset.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Resolved lazily so that re-running the configuration cell with a new
        # BATCH_SIZE takes effect without redefining this class.
        if batch_size is None:
            batch_size = BATCH_SIZE
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(data)))
        self.reset()

    def reset(self):
        """Rewind to the first batch, reshuffling first if enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)
        self.current = 0

    def __iter__(self):
        self.reset()
        return self

    def __next__(self):
        # Stop only once every sample has been served. A trailing partial batch
        # is still yielded, which keeps iteration consistent with __len__
        # (which rounds up). Empty data stops immediately.
        if self.current >= len(self.data):
            self.reset()
            raise StopIteration

        batch_indices = self.indices[self.current:self.current + self.batch_size]
        self.current += len(batch_indices)

        # Stack each field across the selected samples
        return {
            key: torch.stack([self.data[i][key] for i in batch_indices])
            for key in self.data[0]
        }

    def __len__(self):
        """Number of batches per pass, counting a final partial batch."""
        return math.ceil(len(self.data) / self.batch_size)

# Build loaders: shuffled batches for training, fixed order for evaluation
train_loader = SimpleDataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)
val_loader = SimpleDataLoader(val_data, shuffle=False, batch_size=BATCH_SIZE)
print(f"Created data loaders with {len(train_loader)} training batches")

In [None]:
# Define evaluation functions
def evaluate_perplexity(model, val_loader, max_eval_batches=5, device=None):
    """Estimate validation loss and perplexity from the first few batches.

    Args:
        model: Causal LM called as ``model(**batch)``; the result must expose
            a scalar ``.loss``.
        val_loader: Iterable of dicts of tensors. ``"input_ids"`` must be
            present; its first dimension is used as the batch size.
        max_eval_batches: Maximum number of batches to evaluate.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        ``(avg_loss, perplexity)``. ``avg_loss`` is the per-batch loss
        averaged with batch-size weights, and ``perplexity = exp(avg_loss)``.
        Both are ``inf`` when no batch is seen.

    The model's previous train/eval mode is restored before returning. This
    makes the function safe to call from inside a training loop.
    """
    if device is None:
        device = next(model.parameters()).device

    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_samples = 0

    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                if batch_idx >= max_eval_batches:
                    break

                batch = {k: v.to(device) for k, v in batch.items()}
                loss = model(**batch).loss

                # Weight each batch's mean loss by its number of samples
                batch_size = batch["input_ids"].size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size
    finally:
        # Without this, a caller that was training would silently continue
        # in eval mode (dropout disabled).
        model.train(was_training)

    avg_loss = total_loss / total_samples if total_samples > 0 else float("inf")
    # torch.exp saturates to inf instead of raising OverflowError like math.exp
    perplexity = torch.exp(torch.tensor(avg_loss)).item()

    return avg_loss, perplexity

def generate_text(model, tokenizer, prompt, max_length=50, device=None):
    """Sample a continuation of ``prompt`` and return it decoded.

    Sampling uses ``top_k=50``, ``top_p=0.95`` and ``temperature=0.7``, with
    ``max_length`` counting the prompt tokens. An explicit attention mask is
    passed, because the pad token id handed to ``generate`` is the EOS id.
    Without the mask, padding cannot be told apart from real EOS tokens.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to continue.
        max_length: Maximum total length (prompt + generated) in tokens.
        device: Device for the encoded prompt. Defaults to the device of the
            model's first parameter.

    Returns:
        The decoded first sequence, with special tokens skipped.
    """
    if device is None:
        device = next(model.parameters()).device

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id
            )
    finally:
        model.train(was_training)

    return tokenizer.decode(output[0], skip_special_tokens=True)

In [None]:
# Analyze attention entropy
def analyze_attention_entropy(model, dataloader, num_batches=2):
    """Compute attention entropy per layer over the first few batches.

    Runs the model with ``output_attentions=True`` on up to ``num_batches``
    batches. Each layer's attention maps are concatenated across those batches
    along dim 0, which is assumed to be the batch dimension (HF attentions are
    documented as ``(batch, heads, seq, seq)``). Each layer's combined maps go
    through ``calculate_head_entropy``, the result is averaged over its last
    dimension, and the per-layer results are stacked into one row per layer.

    Returns:
        A tensor with one row per layer, presumably ``[layers, heads]``
        (TODO confirm against ``calculate_head_entropy``'s output shape).
        If the model returns no attention maps, a *random*
        ``[num_layers, num_heads]`` tensor is returned after printing a
        warning.
    """
    model.eval()
    # per_layer_maps[i] collects layer i's attention maps from every batch
    per_layer_maps = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_batches:
                break

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch, output_attentions=True)

            if not outputs.attentions:
                continue
            if not per_layer_maps:
                per_layer_maps = [[] for _ in outputs.attentions]
            for layer_maps, attn in zip(per_layer_maps, outputs.attentions):
                layer_maps.append(attn)

    if per_layer_maps:
        # Concatenating across batches needs equal sequence lengths, which
        # holds for max_length-padded data. Grouping by layer avoids producing
        # one row per (batch, layer) pair.
        layer_entropies = [
            calculate_head_entropy(torch.cat(maps, dim=0)) for maps in per_layer_maps
        ]
        return torch.stack([e.mean(dim=-1) for e in layer_entropies])

    # Best-effort fallback kept so the demo still runs, but make it visible:
    # these values are random, not measured.
    print("Warning: model returned no attention maps; using RANDOM entropy values")
    return torch.rand(num_layers, num_heads)

# Measure per-head entropy on the validation split, then plot it
print("Analyzing attention entropy...")
entropy_values = analyze_attention_entropy(model, val_loader)
print(f"Calculated entropy tensor of shape {entropy_values.shape}")

fig1 = visualize_head_entropy(
    entropy_values=entropy_values,
    title="Attention Entropy Across Heads",
    annotate=True,
)
plt.tight_layout()
plt.show()

In [None]:
# Calculate gradient norms
def calculate_gradients(model, dataloader, num_batches=2):
    """Return per-head gradient norms computed by ``calculate_head_gradients``.

    Clears existing gradients, switches the model to training mode, and
    delegates to ``calculate_head_gradients`` for ``num_batches`` batches on
    the notebook-level ``device``.

    Gradients are cleared again afterwards. Measuring gradients presumably
    leaves ``.grad`` populated (TODO confirm in ``calculate_head_gradients``),
    and an optimizer created later would otherwise fold those stale gradients
    into its first update.
    """
    model.zero_grad()
    model.train()

    try:
        return calculate_head_gradients(
            model=model,
            dataloader=dataloader,
            num_batches=num_batches,
            device=device
        )
    finally:
        model.zero_grad()

# Measure per-head gradient norms on training batches, then plot them
print("Calculating head gradients...")
grad_norm_values = calculate_gradients(model, train_loader)
print(f"Calculated gradient norms tensor of shape {grad_norm_values.shape}")

fig2 = visualize_head_gradients(
    grad_norm_values=grad_norm_values,
    title="Gradient Norms Across Heads",
)
plt.tight_layout()
plt.show()

In [None]:
# Apply pruning based on gradients
def apply_gradient_pruning(model, grad_norm_values, prune_percent=0.1):
    """Choose heads to prune from gradient norms and apply the pruning mask.

    Args:
        model: Model handed to ``apply_pruning_mask``.
        grad_norm_values: Per-head gradient norms used to rank heads.
        prune_percent: Target fraction of heads to prune.

    Returns:
        ``(pruning_mask, pruned_heads)``: the mask from
        ``generate_pruning_mask`` (gradient strategy) and whatever
        ``apply_pruning_mask`` returns.
    """
    mask = generate_pruning_mask(
        grad_norm_values=grad_norm_values,
        prune_percent=prune_percent,
        strategy="gradient",
    )
    return mask, apply_pruning_mask(model, mask)

# Prune the target fraction of heads, then show which heads were chosen
print(f"Applying gradient-based pruning with {PRUNE_PERCENT*100:.0f}% target...")
pruning_mask, pruned_heads = apply_gradient_pruning(
    model,
    grad_norm_values=grad_norm_values,
    prune_percent=PRUNE_PERCENT,
)
print(f"Pruned {len(pruned_heads)} heads: {pruned_heads}")

fig3 = visualize_pruning_decisions(
    grad_norm_values=grad_norm_values,
    pruning_mask=pruning_mask,
    title="Pruning Decisions - Heads with Lowest Gradients",
)
plt.tight_layout()
plt.show()

In [None]:
# Evaluate the pruned model before fine-tuning.
# NOTE: pruning was already applied in the previous cell, so these numbers
# describe the freshly pruned model, not the original unpruned one.
print("Evaluating pruned model before fine-tuning...")
baseline_loss, baseline_perplexity = evaluate_perplexity(model, val_loader)
print(f"Baseline (pruned, before fine-tuning): Loss = {baseline_loss:.4f}, Perplexity = {baseline_perplexity:.2f}")

# Sample a continuation from the pruned model for qualitative comparison
prompt = "The neural network was trained to"
baseline_text = generate_text(model, tokenizer, prompt)
print(f"\nPrompt: {prompt}")
print(f"Generated text:\n{baseline_text}")

In [None]:
# Train the pruned model
def train_pruned_model(model, train_loader, val_loader, learning_rate=LEARNING_RATE, max_steps=MAX_STEPS):
    """Fine-tune ``model`` for ``max_steps`` optimizer steps and record metrics.

    Uses AdamW with a linear warmup/decay schedule (``WARMUP_STEPS`` warmup
    steps over ``max_steps`` total). Batches are drawn from ``train_loader``,
    which is restarted whenever it is exhausted. Every ``EVAL_INTERVAL`` steps,
    and after the final step, the model is evaluated on ``val_loader`` with
    ``evaluate_perplexity``.

    Args:
        model: Model to train in place.
        train_loader: Iterable of training batches (dicts of tensors).
        val_loader: Iterable of validation batches.
        learning_rate: Peak AdamW learning rate.
        max_steps: Number of optimizer steps.

    Returns:
        Dict of lists ``train_loss``, ``eval_loss``, ``perplexity`` and
        ``steps``, with one entry per evaluation. ``steps`` holds the 1-based
        step number, matching the printed "Step N" lines. ``train_loss`` is the
        loss of the single batch at that step, not a running average.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=max_steps
    )
    # Drop gradients already stored on the parameters (e.g. from an earlier
    # gradient analysis) so they are not folded into the first update.
    optimizer.zero_grad()

    metrics = {
        "train_loss": [],
        "eval_loss": [],
        "perplexity": [],
        "steps": []
    }

    print(f"Training pruned model for {max_steps} steps...")
    train_iter = iter(train_loader)

    # NOTE(review): if apply_pruning_mask prunes by zeroing weights rather than
    # through a persistent mask/hook, these updates may revive pruned heads.
    # Confirm against the pruning implementation.
    with tqdm(total=max_steps, desc="Training") as progress_bar:
        for step in range(max_steps):
            step_num = step + 1  # 1-based, used for logging and metrics

            # Evaluation calls model.eval(); re-enable training mode (dropout
            # etc.) at the start of every step so it cannot leak into training.
            model.train()

            # Next batch, restarting the loader when it is exhausted
            try:
                batch = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                batch = next(train_iter)
            batch = {k: v.to(device) for k, v in batch.items()}

            loss = model(**batch).loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            train_loss = loss.item()
            progress_bar.update(1)
            progress_bar.set_postfix(loss=f"{train_loss:.4f}")

            if step_num % EVAL_INTERVAL == 0 or step_num == max_steps:
                eval_loss, perplexity = evaluate_perplexity(model, val_loader)

                metrics["train_loss"].append(train_loss)
                metrics["eval_loss"].append(eval_loss)
                metrics["perplexity"].append(perplexity)
                metrics["steps"].append(step_num)

                print(f"\nStep {step_num}: Train loss = {train_loss:.4f}, "
                      f"Eval loss = {eval_loss:.4f}, "
                      f"Perplexity = {perplexity:.2f}")

            # Release cached memory every 5 steps (starting with the first)
            if step % 5 == 0:
                clear_memory()

    return metrics

# Fine-tune the pruned model, then plot the metrics recorded at each evaluation
training_metrics = train_pruned_model(model, train_loader, val_loader)

eval_steps = training_metrics["steps"]
metrics_fig, (loss_ax, ppl_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: training vs. validation loss
loss_ax.plot(eval_steps, training_metrics["train_loss"], label="Train Loss")
loss_ax.plot(eval_steps, training_metrics["eval_loss"], label="Eval Loss")
loss_ax.set_xlabel("Steps")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Training and Evaluation Loss")
loss_ax.legend()
loss_ax.grid(alpha=0.3)

# Right panel: validation perplexity
ppl_ax.plot(eval_steps, training_metrics["perplexity"], color="green")
ppl_ax.set_xlabel("Steps")
ppl_ax.set_ylabel("Perplexity")
ppl_ax.set_title("Perplexity")
ppl_ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Final evaluation
# NOTE: the "baseline" numbers were measured after pruning but before
# fine-tuning, so this compares the fine-tuned model with the freshly pruned one.
print("Final evaluation after training...")
final_loss, final_perplexity = evaluate_perplexity(model, val_loader)
print(f"Final evaluation:         Loss = {final_loss:.4f}, Perplexity = {final_perplexity:.2f}")
print(f"Pruned, before training:  Loss = {baseline_loss:.4f}, Perplexity = {baseline_perplexity:.2f}")

# Relative loss reduction. Guarded because the baseline can be 0 or inf
# (evaluate_perplexity returns inf when it sees no batches), which would
# otherwise raise ZeroDivisionError or print nan.
if math.isfinite(baseline_loss) and baseline_loss != 0:
    print(f"Loss reduction:           {((baseline_loss - final_loss) / baseline_loss * 100):.2f}%")
else:
    print("Loss reduction:           n/a (baseline loss is zero or not finite)")

# Generate text with the fine-tuned model
final_text = generate_text(model, tokenizer, prompt)
print(f"\nPrompt: {prompt}")
print(f"Final generated text:\n{final_text}")

## Conclusion

In this ultra-minimal demonstration, we've shown the core functionality of the neural plasticity system:

1. We calculated attention entropy and gradient norms for all heads
2. We pruned the heads with the lowest gradient norms, using gradient magnitude as a proxy for how useful each head is for learning
3. We trained the pruned model and observed its performance

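This demonstrates the end-to-end pipeline for identifying and pruning attention heads with low gradient norms. This run uses synthetic data and only a few training steps, and its "baseline" evaluation is taken after pruning. On its own, it therefore does not show that pruning preserves model quality or improves efficiency; measuring that trade-off requires comparing against the unpruned model on real data.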

The techniques used here form the foundation for more advanced neural plasticity operations that enable continual growth, adaptation, and pruning for transformers in the full version of this system.