# Microglia-Inspired Dynamic Pruning - Full Experiment

This notebook runs the complete pruning experiment on Phi-3-Mini with GSM8K.

**What we're doing:** Training small "agent" networks to learn which attention heads can be pruned during inference, inspired by how microglia prune synapses in the brain.
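As a rough illustration of the idea, a per-head gate can be sketched in plain Python: each head gets a score, a temperature-scaled sigmoid turns it into a keep-probability, and heads below a threshold are masked out. The function names, the scores, and the 0.5 threshold are assumptions made for this sketch; the agents in the repository are learned networks and may work differently.

```python
import math

def keep_probabilities(scores, temperature=1.0):
    """Map raw per-head scores to keep-probabilities with a temperature-scaled sigmoid."""
    return [1.0 / (1.0 + math.exp(-s / temperature)) for s in scores]

def head_mask(scores, temperature=1.0, threshold=0.5):
    """Return 1 for heads to keep and 0 for heads to prune (hypothetical hard gate)."""
    probs = keep_probabilities(scores, temperature)
    return [1 if p >= threshold else 0 for p in probs]

def sparsity(mask):
    """Fraction of heads that the mask prunes."""
    return 1.0 - sum(mask) / len(mask)

# Made-up scores for 8 heads: positive leans "keep", negative leans "prune"
example_scores = [2.1, -0.7, 0.3, -1.5, 1.2, 0.0, -0.2, 0.9]
example_mask = head_mask(example_scores)
print(example_mask)            # [1, 0, 1, 0, 1, 1, 0, 1]
print(sparsity(example_mask))  # 0.375
```

A higher temperature flattens the probabilities toward 0.5, which keeps the gate soft early in training; a hard threshold like the one here would typically be applied only at inference time.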

**Expected results:**
- 20-30% of attention heads pruned
- ~15% latency improvement
- <2% accuracy loss
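These three targets can be pinned down with simple ratios. The sketch below assumes that latency improvement is measured relative to the unpruned baseline and that accuracy loss is expressed in percentage points; the function name and the sample numbers are illustrative only and are not results from this experiment.

```python
def summarize_pruning(pruned_heads, total_heads,
                      base_latency_s, pruned_latency_s,
                      base_accuracy, pruned_accuracy):
    """Compute the three headline metrics from raw measurements."""
    return {
        'sparsity': pruned_heads / total_heads,
        'latency_improvement': 1.0 - pruned_latency_s / base_latency_s,
        'accuracy_loss_points': (base_accuracy - pruned_accuracy) * 100.0,
    }

# Illustrative inputs only (not measured): 8 of 32 heads pruned,
# 2.0 s vs 1.7 s per generation, 80% vs 79% accuracy
summary = summarize_pruning(8, 32, 2.0, 1.7, 0.80, 0.79)
for name, value in summary.items():
    print(f'{name}: {value:.3f}')
```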

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Tommaso-R-Marena/microglia-pruning/blob/main/notebooks/microglia_pruning_demo.ipynb)

## Setup

This takes ~2 minutes on a T4 GPU.

In [None]:
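# Work from /content so this cell can be re-run safely: deleting the clone
# while it is the current directory would make the git clone below fail
%cd /content

# Remove any existing clone and get a fresh copy
import os
import shutil

if os.path.exists('/content/microglia-pruning'):
    shutil.rmtree('/content/microglia-pruning')
    print('Removed old clone')

# Clone the repo with the latest code
!git clone https://github.com/Tommaso-R-Marena/microglia-pruning.git
%cd microglia-pruning

# Show the most recent commit to confirm which version was cloned
!git log --oneline -1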

In [None]:
# Install dependencies
!pip install -q torch transformers accelerate bitsandbytes peft datasets scipy numpy tqdm matplotlib
!pip install -q fvcore

print('Installation complete')

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import sys
import time

sys.path.insert(0, '/content/microglia-pruning')
from src.system import MicrogliaPruningSystem

print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

torch.manual_seed(42)
np.random.seed(42)

## Part 1: Load Base Model

We're using Phi-3-Mini (3.8B parameters). The first run downloads ~7.5 GB of weights, which takes ~3-5 minutes.

**You should see:** a 'Fixing Phi-3 EOS token issue...' message.

In [None]:
system = MicrogliaPruningSystem(
    model='microsoft/phi-3-mini-4k-instruct',
    num_heads=32,
    hidden_dim=128,
    temperature=1.0
)

print('\nSystem initialized!')

# Count parameters once and reuse the totals
model_params = sum(p.numel() for p in system.model.parameters())
agent_params = sum(p.numel() for p in system.agents.parameters())
print(f'\nModel size: {model_params / 1e9:.2f}B parameters')
print(f'Agent size: {agent_params / 1e6:.2f}M parameters')
print(f'\nAgent overhead: {agent_params / model_params * 100:.3f}%')

## Part 2: Test Baseline Performance

Test the baseline model on a few simple math problems.

**You should see sensible answers** ($10, 3 candies, and 40 square meters), not garbage output.
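One rough way to check the answers automatically is to pull the last number out of each generated answer and compare it with the value you expect. `extract_final_number` is a hypothetical helper written for this sketch, not part of the repository, and the sample strings are made up to show the parsing.

```python
import re

def extract_final_number(text):
    """Return the last number found in the text as a float, or None if there is none."""
    matches = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
    if not matches:
        return None
    # Drop thousands separators before converting
    return float(matches[-1].replace(',', ''))

# Made-up model outputs paired with the correct answers to the three questions
samples = [
    ('Sarah spends 5 * 2 = $10.', 10),
    ('He gives away 3 * 4 = 12 candies, so 15 - 12 = 3 are left.', 3),
    ('The area is 8 * 5 = 40 square meters.', 40),
]
for text, expected in samples:
    print(extract_final_number(text) == expected)
```

This heuristic takes the last number it sees, so an answer that keeps talking after the result (for example by starting a new question) can fool it.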

In [None]:
test_questions = [
    'A store sells apples for $2 each. If Sarah buys 5 apples, how much does she spend?',
    'John has 15 candies. He gives 3 to each of his 4 friends. How many candies does he have left?',
    'A rectangle has a length of 8 meters and a width of 5 meters. What is its area?'
]

print('Testing baseline (unpruned) model:\n')

for i, question in enumerate(test_questions, 1):
    prompt = f'Question: {question}\nAnswer:'
    start = time.perf_counter()  # monotonic clock, better suited to timing than time.time()
    output = system.generate(prompt, max_new_tokens=100)
    elapsed = time.perf_counter() - start
    answer = output.split('Answer:')[-1].strip()[:200]
    print(f'Q{i}: {question}')
    print(f'A{i}: {answer}')
    print(f'Time: {elapsed:.2f}s\n')
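A single timed call is noisy, and the first call often includes one-time warm-up cost. For a latency comparison between the baseline and the pruned model, a more stable estimate comes from discarding a few warm-up runs and taking the median of several repeats. `benchmark` below is a hypothetical helper, not part of the repository; it times any zero-argument callable.

```python
import statistics
import time

def benchmark(fn, warmup=1, repeats=5):
    """Return the median wall-clock time of fn() in seconds, after warm-up runs."""
    for _ in range(warmup):
        fn()  # warm-up calls are executed but not timed
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

# Stand-in workload so the sketch runs without a GPU or the model
median_s = benchmark(lambda: sum(range(100_000)), warmup=1, repeats=5)
print(f'Median time: {median_s:.4f}s')
```

In this notebook the callable could be something like `lambda: system.generate(prompt, max_new_tokens=100)`. If the timed function can return before GPU work has finished, call `torch.cuda.synchronize()` inside it so the clock covers the full computation.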

## Part 3: Train Pruning Agents

Train the agents to learn which heads to prune.

**Training time:** ~15-20 minutes on a T4 for 3 epochs

In [None]:
system.train(
    dataset_name='gsm8k',
    num_epochs=3,
    batch_size=2,
    learning_rate=1e-4,
    alpha_schedule=(0.01, 0.2),
    use_lora=False
)

print('\nTraining complete!')
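The `alpha_schedule=(0.01, 0.2)` argument suggests that the weight on the sparsity term ramps up during training. The sketch below assumes a linear ramp across epochs and a total loss of `task_loss + alpha * sparsity_loss`; both are assumptions for illustration, and the repository may interpolate per step or combine the losses differently.

```python
def alpha_for_epoch(epoch, num_epochs, alpha_schedule=(0.01, 0.2)):
    """Linearly interpolate alpha from the start value to the end value (epoch is 0-based)."""
    start, end = alpha_schedule
    if num_epochs <= 1:
        return end
    return start + (end - start) * epoch / (num_epochs - 1)

def total_loss(task_loss, sparsity_loss, alpha):
    """Assumed objective: task loss plus an alpha-weighted sparsity penalty."""
    return task_loss + alpha * sparsity_loss

# Alpha for each of the 3 epochs used in this notebook
for epoch in range(3):
    print(epoch, round(alpha_for_epoch(epoch, 3), 3))
```

Starting with a small alpha lets the agents first learn which heads matter for the task before the pruning pressure grows.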

## Part 4: Visualize Training Progress

In [None]:
history = system.training_history

if history:
    epochs = range(1, len(history) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    ax1.plot(epochs, [h['task_loss'] for h in history], 'b-o', linewidth=2, markersize=8)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Task Loss', fontsize=12)
    ax1.set_title('Math Problem Solving Performance', fontsize=13, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    
    ax2.plot(epochs, [h['sparsity_loss'] for h in history], 'r-o', linewidth=2, markersize=8)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Sparsity Loss', fontsize=12)
    ax2.set_title('Pruning Pressure', fontsize=13, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print('No training history available')

## Part 5: Evaluate Pruned Model

In [None]:
results = system.evaluate(
    dataset_name='gsm8k',
    split='test',
    max_samples=200
)

print('\n' + '='*50)
print('FINAL RESULTS')
print('='*50)
print(f"Accuracy: {results['accuracy']:.1%}")
print(f"Correct: {results['correct']}/{results['total']}")
print(f"Sparsity: {results['sparsity']:.1%} heads pruned")
print('='*50)
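With `max_samples=200`, the reported accuracy carries noticeable sampling noise, which matters when the target is an accuracy loss under 2%. The sketch below uses the normal approximation to the binomial to estimate a 95% margin of error; `accuracy_margin` is a hypothetical helper, and the 0.5 in the example is the worst case, not a measured accuracy.

```python
import math

def accuracy_margin(accuracy, n, z=1.96):
    """Approximate 95% margin of error for an accuracy measured on n samples."""
    return z * math.sqrt(accuracy * (1.0 - accuracy) / n)

# Worst case (accuracy = 0.5) on 200 samples
print(f'{accuracy_margin(0.5, 200):.3f}')  # prints 0.069
```

In the worst case the margin is about 7 percentage points on 200 samples, so a gap of under 2 points between the baseline and the pruned model is hard to resolve at this sample size. Evaluating both models on the same, larger set of questions gives a more reliable comparison.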