# The Cost of 1 Million Brains
### Current AI Implementations Are Inefficient, and Existing Energy Tech Cannot Support Tomorrow's Models

<div style="background: linear-gradient(135deg, #9c31e4ff 0%, #292f56 100%); padding: 20px; border-radius: 10px; margin: 20px 0;">
    <h3 style="color: white; margin: 0;">Some Numbers</h3>
    <p style="color: white; margin: 10px 0;">GPT-4 consumes <b>1 million times</b> more power than a human brain while approaching similar intelligence levels. This notebook demonstrates why this is unsustainable and previews a solution that achieves <b>10-100× efficiency gains</b> using brain-inspired computing.</p>
</div>

---

**Brain-Inspired AI Series**  
**[Part 1: The Problem]** | [Part 2: Biology →](link) | [Part 3: Crisis →](link) | [Part 4: Solution →](link) | [Part 5: Impact →](link)

---

## AI Energy Consumption Is Shocking

The AI revolution is running into an energy crisis. The table below quantifies the staggering efficiency gap between large-scale AI models and the human brain, from initial "training" through continuous operation to single tasks. This notebook confronts the challenge directly by implementing and analyzing brain-inspired Spiking Neural Networks (SNNs), which point to a viable path toward sustainable AI by achieving a **10-100× reduction** in computational energy without sacrificing accuracy.

| Comparison | Metric | GPT-4 | Human Brain | Bottom Line |
|---|---|---|---|---|
| **Upfront "Training"** | Total Energy to Maturity | ~75 GWh<sup>a</sup> | ~4,800 kWh<sup>b</sup> | **>10,000×** more energy for GPT-4's "childhood" |
| **Operational Power** | Continuous Power Draw | ~20 MW<sup>c</sup> | 20 W | GPT-4 runs on the power of **1 million brains** |
| **Task Energy** | Energy for a Single Action | ~1,440 J / query<sup>c</sup> | ~1,200 J / minute | A single query costs more than a minute of thought |

<sup>a</sup>Based on scaling laws from GPT-3 (Patterson et al., 2021): 50-100 GWh

<sup>b</sup>Accounting for higher power consumption during adolescence: 25W × 22 years × 365.25 days/year × 24 hours/day continuous operation

<sup>c</sup>Estimated from datacenter deployments, i.e. query volume and server energy use, which place the likely power consumption in the 14-21 MW range: 2.5 billion queries per day at 0.3 - 0.4 Wh per query divided by 24 hours/day
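
The headline ratios in the table follow from simple unit conversions. The sketch below recomputes them from the table's own inputs. The variable names are illustrative, and the inputs remain rough estimates rather than measurements.

```python
# Recompute the table's headline ratios from its own (estimated) inputs.
HOURS_PER_YEAR = 365.25 * 24

# "Training": a 25 W brain for 22 years vs. ~75 GWh for GPT-4
brain_training_kwh = 25 * 22 * HOURS_PER_YEAR / 1000
gpt4_training_kwh = 75e6  # 75 GWh expressed in kWh
training_ratio = gpt4_training_kwh / brain_training_kwh

# Operation: ~20 MW service vs. a 20 W brain
power_ratio = 20e6 / 20

# Single task: 0.4 Wh per query vs. one minute of brain activity
query_joules = 0.4 * 3600      # 1 Wh = 3600 J
brain_minute_joules = 20 * 60  # 20 W sustained for 60 s

print(f"Brain energy to maturity: {brain_training_kwh:,.0f} kWh")
print(f"Training energy ratio:    {training_ratio:,.0f}x")
print(f"Operational power ratio:  {power_ratio:,.0f}x")
print(f"Query vs. minute of thought: {query_joules:.0f} J vs. {brain_minute_joules} J")
```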

As we will see below, this inefficiency is driving modern AI toward an insurmountable wall.

In [None]:
# Install dependencies
import sys
import subprocess
import importlib.util
import warnings
warnings.filterwarnings('ignore')

def install_package(package):
    """Install using pip"""
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])

def check_and_install(package_name, import_name=None):
    """Check if installed, if not then install"""
    if import_name is None:
        import_name = package_name
    
    spec = importlib.util.find_spec(import_name)
    if spec is None:
        print(f"Installing {package_name}...")
        try:
            install_package(package_name)
            print(f"{package_name} installed successfully")
        except Exception as e:
            print(f"Failed to install {package_name}: {e}")
            return False
    return True

print("Setting up environment...")

required_packages = [
    # (package_name, import_name)
    ('numpy', 'numpy'),
    ('matplotlib', 'matplotlib'),
    ('torch', 'torch'),
    ('torchvision', 'torchvision'),
    ('tqdm', 'tqdm'),
    ('ipywidgets', 'ipywidgets'),
]

all_installed = True
for package, import_name in required_packages:
    if not check_and_install(package, import_name):
        all_installed = False

if not all_installed:
    print("\nSome packages failed to install. Please install manually:")
    print(f"pip install {' '.join(name for name, _ in required_packages)}")
else:
    print("Environment setup complete.")

# Imports, config, styling, and device setup
print("\nImporting dependencies, configuring styles and device...")
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Circle, Rectangle, FancyBboxPatch, FancyArrowPatch
from matplotlib.animation import FuncAnimation
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, MaxNLocator
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import datasets, transforms
from IPython.display import HTML, display, clear_output
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import json
import os

# Create directories for outputs
Path("figures").mkdir(exist_ok=True)
Path("models").mkdir(exist_ok=True)
Path("data").mkdir(exist_ok=True)

# Define color palette
COLORS = {
    'primary': '#667eea',      # Purple
    'secondary': '#f56565',    # Red  
    'success': '#48bb78',      # Green
    'warning': '#ed8936',      # Orange
    'severe': '#9b59b6',       # Purple
    'info': '#4299e1',         # Blue
    'dark': '#2d3748',         # Dark gray
    'light': '#f7fafc',        # Light gray
    'ann': '#e74c3c',          # ANN Red
    'snn': '#27ae60',          # SNN Green
    'brain': '#3498db',        # Brain Blue
    'dark_text': '#34495e'
}

# Configure matplotlib
plt.style.use('default')
plt.rcParams.update({
    'figure.figsize': (14, 7),
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'axes.edgecolor': '#CCCCCC',
    'axes.linewidth': 1.5,
    'font.size': 11,
    'axes.titlesize': 16,
    'axes.titleweight': 'bold',
    'axes.titlepad': 20,
    'axes.labelsize': 13,
    'axes.labelweight': 'bold',
    'axes.labelpad': 10,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 11,
    'legend.frameon': True,
    'legend.fancybox': True,
    'legend.shadow': True,
    'figure.titlesize': 18,
    'lines.linewidth': 2.5,
    'lines.markersize': 8,
    'grid.alpha': 0.3,
    'grid.linestyle': '--',
    'axes.grid': True,
    'axes.spines.top': False,
    'axes.spines.right': False,
})
plt.rcParams['animation.embed_limit'] = 50.0

# Set random seeds for reproducibility
def set_seeds(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seeds(42)

# Configure device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(f"Using device: {device}")
print(f"{'GPU enabled.' if device.type == 'cuda' else 'Running on CPU (may be slower).'}")

## Part 1: Visualizing the Energy Gap

Let's start with order-of-magnitude power estimates for real systems.

In [None]:
def create_energy_comparison():
    """Compare energy consumption of human brain and modern AI systems."""
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
    fig.patch.set_facecolor('white')
    
    systems = ['Human\nBrain', 'GPT-3\n(Inference)', 'GPT-4\n(Inference)', 'GPT-4\n(Training)', 'GPT-5\n(Training Est.)']
    power_watts = [20, 100, 200, 50_000_000, 175_000_000]
    intelligence_relative = [100, 70, 85, 85, 92] # Relative to human = 100
    
    # Subplot 1: Power Consumption (Log Scale)
    colors = [COLORS['brain'], COLORS['warning'], COLORS['secondary'], COLORS['dark'], COLORS['severe']]
    bars = ax1.bar(systems, power_watts, color=colors, alpha=0.8, edgecolor='black', linewidth=2)

    for bar, power in zip(bars, power_watts):
        height = bar.get_height()
        if power < 1000:
            label = f'{power}W'
        elif power < 1_000_000:
            label = f'{power/1000:.0f}kW'
        else:
            label = f'{power/1_000_000:.0f}MW'
        
        # Adjust label position for log scale
        y_pos = height * 1.8
        ax1.text(bar.get_x() + bar.get_width()/2., y_pos,
                 label, ha='center', va='bottom', fontweight='bold', fontsize=13)
    
    ax1.set_yscale('log')
    ax1.set_ylabel('Power Consumption (Watts)')
    ax1.set_title('Power Requirements of Human Brain vs AI Models')
    ax1.set_ylim([10, 2_000_000_000])
    # ax1.set_xlim([-1, 5])
    ax1.grid(True, alpha=0.3, which='both')
    
    ax1.annotate('8.75 MILLION× more power', 
                 xy=(4, 175_000_000), xytext=(1.5, 1_000_000_000),
                 arrowprops=dict(arrowstyle='-', color='black', lw=1.5),
                 fontsize=14, fontweight='bold', color='black',
                 bbox=dict(boxstyle='round,pad=0.7', facecolor='yellow', 
                           edgecolor='black', alpha=0.9))
    
    # Subplot 2: Efficiency (Intelligence per Watt)
    efficiency = [intel/power for intel, power in zip(intelligence_relative, power_watts)]
    
    bars2 = ax2.bar(systems, efficiency, color=colors, alpha=0.8, edgecolor='black', linewidth=2)
    
    for bar, eff in zip(bars2, efficiency):
        height = bar.get_height()
        if eff >= 0.001:
            label = f'{eff:.3f}'
        else:
            label = f'{eff:.2e}'
        y_pos = height * 1.5
        ax2.text(bar.get_x() + bar.get_width()/2., y_pos,
                 label, ha='center', va='bottom', fontweight='bold', fontsize=11)
    
    ax2.set_yscale('log')
    ax2.set_ylabel('Intelligence per Watt (Efficiency Score)')
    ax2.set_title('Computational Efficiency Comparison')
    ax2.set_ylim([1e-10, 10])
    ax2.grid(True, alpha=0.3, which='both')
    
    # Add reference line for brain efficiency
    ax2.axhline(y=efficiency[0], color='green', linestyle='--', alpha=0.5, linewidth=2.5)
    ax2.text(4.5, efficiency[0] * 1.3, 'Brain efficiency baseline', 
             ha='right', va='bottom', color='green', fontweight='bold', fontsize=11)
    
    # Overall title
    fig.suptitle('Brain-Inspired Computing is Essential To Avert The Energy Crisis of Modern AI', 
                 fontsize=18, fontweight='bold', y=1.02)
    
    plt.tight_layout()
    return fig

# Create and display the comparison
fig1 = create_energy_comparison()
# fsave = 'figures/energy_comparison.png'
# plt.savefig(fsave, dpi=150, bbox_inches='tight', facecolor='white')
plt.show()
# print(f"Plot saved to {fsave}")

For reference, the "intelligence" score in the right-hand plot above is a rough average of the following dimensions, with human performance normalized to 1.0:

| Intelligence Facet       | Human | GPT-3 | GPT-4 | GPT-5 |
| ------------------------ | ----- | ----- | ----- | ----- |
| **Analytical Reasoning** | 1.0   | ~0.6  | ~0.9  | >1.0  |
| **Factual Knowledge** | 1.0   | >1.0  | >1.0  | >1.0  |
| **Linguistic Ability** | 1.0   | ~0.8  | ~0.95 | >1.0  |
| **Creativity (Ideation)**| 1.0   | ~0.7  | ~0.9  | ~1.0  |
| **Emotional Intelligence**| 1.0   | ~0.4  | ~0.6  | ~0.7  |
| **Common Sense Reasoning**| 1.0   | ~0.5  | ~0.7  | ~0.8  |
| **Embodied Cognition** | 1.0   | 0.0   | 0.0   | 0.0   |

These ratios are estimations based on available research and are illustrative at best. The field of AI is evolving rapidly, and these comparisons will change.
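
The plot uses relative scores of 70, 85, and 92 for GPT-3, GPT-4, and GPT-5. One way to approximately recover them from the table is sketched below. It rests on two assumptions that the notebook does not state: entries marked ">1.0" are capped at 1.0, and the "Embodied Cognition" row is left out of the average.

```python
# Facet scores from the table above, in row order, omitting "Embodied Cognition".
# Assumption: entries marked ">1.0" are capped at 1.0.
facet_scores = {
    "Human": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    "GPT-3": [0.6, 1.0, 0.8, 0.7, 0.4, 0.5],
    "GPT-4": [0.9, 1.0, 0.95, 0.9, 0.6, 0.7],
    "GPT-5": [1.0, 1.0, 1.0, 1.0, 0.7, 0.8],
}

# Mean facet score, rescaled so that human = 100
relative_score = {
    name: 100 * sum(scores) / len(scores)
    for name, scores in facet_scores.items()
}

for name, score in relative_score.items():
    print(f"{name}: {score:.1f}")
```

Under these assumptions the means land within a few points of the plotted values, which is consistent with the scores being illustrative rather than precise.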

<div style="background: #fff5f5; border-left: 4px solid #f56565; padding: 15px; margin: 20px 0;">
    <b>The human brain operates on just 20 watts</b>, whereas the global <b>GPT-4 service requires an estimated 20 MW to operate</b>, equivalent to <b>1 million human brains</b>. This immense energy gap is a fundamental challenge to the future of scalable AI.
</div>

## Part 2: The Exponential Scaling Crisis

The energy gap is only part of the problem. For the past decade, AI progress has focused on **bigger models = better performance**. Although this formula has driven remarkable advances, its exponential growth trajectory is approaching physical and economic limits. 

### The Math

The relationship between model size and computational requirements is problematic:
- Model sizes double every 6-12 months.
- Training energy scaling follows power laws.
- Data center capacity and chip manufacturing impose physical constraints on scaling.
- Training costs are approaching hundreds of millions of dollars.

#### Energy Scaling Laws

The relationship between model parameters and energy consumption follows empirical laws (Kaplan et al. 2020):

**E = k × N^α**

where:
- E = Training energy consumption
- N = Number of parameters  
- α ≈ 1.3 (empirically observed for transformers)
- k = Hardware efficiency constant

This superlinear scaling means that **doubling model size requires ~2.5× more energy**, and 
- 10× larger model → ~20× more energy
- 100× larger model → ~400× more energy
- Physical limits reached around 10¹³ parameters
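
To make the power law concrete, the sketch below evaluates the energy multiplier implied by E = k × N^α for a few size ratios. It uses α = 1.3 as stated above. The constant k cancels out of the ratio, and the function name is illustrative.

```python
ALPHA = 1.3  # scaling exponent quoted in the text

def energy_multiplier(size_ratio, alpha=ALPHA):
    """Relative training energy when the parameter count grows by size_ratio.

    From E = k * N**alpha, the ratio E2 / E1 equals (N2 / N1)**alpha,
    so the hardware constant k drops out.
    """
    return size_ratio ** alpha

for ratio in (2, 10, 100):
    print(f"{ratio:>3}x parameters -> {energy_multiplier(ratio):.1f}x energy")
```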

#### Physical Constraints

##### 1. **Chip Manufacturing Limits**
- Current GPU dies approach the lithographic reticle limit
- Heat dissipation becomes impossible at larger scales
- Interconnect bandwidth saturates with system size

##### 2. **Power Grid Capacity**
- GPT-6 training would require a dedicated power plant
- Data centers already strain local electrical infrastructure
- Cooling requirements scale non-linearly with power consumption

##### 3. **Economic Barriers**
- Training costs are growing exponentially: $1M → $100M → $10B
- Only ~5 organizations can afford state-of-the-art training
- R&D investments are approaching the GDP of small nations

### Historical Data (2018-2023)

| Model | Year | Parameters | Training Energy | Reference |
|-------|------|------------|-----------------|-------------------|
| **BERT-Large** | 2018 | 340M | ~1 MWh | Strubell et al. 2019 |
| **GPT-2** | 2019 | 1.5B | ~10 MWh | OpenAI scaling estimates |
| **GPT-3** | 2020 | 175B | 1,287 MWh | Patterson et al. 2021 |
| **PaLM** | 2022 | 540B | ~15 GWh | Google Research estimates |
| **GPT-4** | 2023 | ~1.8T* | ~50 GWh* | Industry analysis |

*Estimated values based on scaling laws

### Critical Projections (2025-2027)

**Conservative estimates** based on current scaling trends (assuming no fundamental algorithmic breakthroughs) project larger models requiring significantly more resources:

- **GPT-5 (2025)**: 10T parameters, 800 GWh training energy
- **GPT-6 (2027)**: 100T parameters, 12,000 GWh training energy

### Carbon Footprint Analysis

Using reported emissions for GPT-3 and estimates or projections for the later models:

| Model | Training Emissions | Equivalent To |
|-------|-------------------|---------------|
| GPT-3 | 552 tons CO₂ | 120 cars for 1 year |
| GPT-4 | ~2,500 tons CO₂ | 540 cars for 1 year |
| GPT-5 (proj.) | 40,000 tons CO₂ | 8,700 cars for 1 year |
| GPT-6 (proj.) | 600,000 tons CO₂ | 130,000 cars for 1 year |

**GPT-6 training would emit more CO₂ than a coal power plant running for 6 months.**
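
The "Equivalent To" column can be reproduced with a single conversion factor. The sketch below assumes about 4.6 metric tons of CO₂ per passenger car per year. That value is inferred from the table's own rows (552 tons ↔ 120 cars) rather than stated in the notebook, and it matches the commonly cited EPA figure for a typical passenger vehicle.

```python
# Assumed conversion factor (inferred from the table, not stated in the notebook)
TONS_CO2_PER_CAR_YEAR = 4.6

training_emissions_tons = {
    "GPT-3": 552,
    "GPT-4": 2_500,
    "GPT-5 (proj.)": 40_000,
    "GPT-6 (proj.)": 600_000,
}

# Number of cars that would emit the same amount of CO2 in one year
car_years = {
    model: tons / TONS_CO2_PER_CAR_YEAR
    for model, tons in training_emissions_tons.items()
}

for model, cars in car_years.items():
    print(f"{model}: ~{cars:,.0f} cars for 1 year")
```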

### Resource Competition

Sustaining this exponential scaling would demand resources that cannot realistically be supplied:

- Compute equal to 100% of global GPU production
- Electricity equal to the entire output of power plants
- Data center cooling demand approaching municipal supply levels
- The exhaustion of available rare earth elements and global chip manufacturing capacity

In [None]:
def create_scaling_visualization():
    """Shows how the energy problem explodes with model size."""
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
    
    # Model scaling data
    years = np.array([2018, 2019, 2020, 2022, 2023, 2025, 2027])
    model_names = ['BERT', 'GPT-2', 'GPT-3', 'PaLM', 'GPT-4', 'GPT-5*', 'GPT-6*']

    # Parameter counts (billions)
    parameters = np.array([
        0.34,    # BERT-Large: 340M parameters (Devlin et al. 2018)
        1.5,     # GPT-2: 1.5B parameters (Radford et al. 2019)
        175,     # GPT-3: 175B parameters (Brown et al. 2020)
        540,     # PaLM: 540B parameters (Chowdhery et al. 2022)
        1800,    # GPT-4: ~1.8T parameters (estimated, OpenAI hasn't disclosed)
        10000,   # GPT-5: 10T parameters (projected based on scaling trends)
        100000   # GPT-6: 100T parameters (projected, approaching physical limits)
    ])
    
    # Training energy consumption (GWh)
    energy_training = np.array([
        0.001,   # BERT: ~1 MWh (Strubell et al. 2019, "Energy and Policy Considerations for Deep Learning in NLP")
        0.01,    # GPT-2: ~10 MWh (estimated from compute requirements)
        1.3,     # GPT-3: 1,287 MWh (Patterson et al. 2021, "Carbon Emissions and Large Neural Network Training")
        15,      # PaLM: ~15 GWh (estimated from Google's reported compute)
        50,      # GPT-4: ~50 GWh (estimated based on training compute and efficiency)
        800,     # GPT-5: 800 GWh (projected using E ∝ N^1.3 scaling law)
        12000    # GPT-6: 12,000 GWh (exceeds small country consumption)
    ])
    
    # Mark projected vs historical data
    projected_mask = np.array([False, False, False, False, False, True, True])

    # Create color palette
    colors = {
        'historical': '#2E86AB',    # Blue for verified data
        'projected': '#F24236',     # Red for projections
        'crisis': '#FF6B6B',        # Bright red for crisis zones
        'warning': '#F18F01',       # Orange for warning zones
        'safe': '#4CAF50'           # Green for manageable zones
    }
    
    # Subplot 1: Parameter Growth
    ax1.semilogy(years[~projected_mask], parameters[~projected_mask], 'o-', 
                 color=colors['historical'], linewidth=3, markersize=10, 
                 label='Historical', alpha=0.9)
    ax1.semilogy(years[projected_mask], parameters[projected_mask], 's--', 
                 color=colors['projected'], linewidth=3, markersize=10, 
                 label='Projected', alpha=0.9)
    
    # Add red dashed line connecting historical to projected
    last_historical_idx = np.where(~projected_mask)[0][-1]
    first_projected_idx = np.where(projected_mask)[0][0]
    ax1.plot(
        [years[last_historical_idx], years[first_projected_idx]],
        [parameters[last_historical_idx], parameters[first_projected_idx]],
        '--', color=colors['projected'], linewidth=3, alpha=0.9
    )
    
    # Annotate models
    annotation_offsets = [15, -25, 15, -25, 15, -25, 15]
    for i, (year, name, param) in enumerate(zip(years, model_names, parameters)):
        color = colors['projected'] if projected_mask[i] else colors['historical']
        weight = 'bold' if projected_mask[i] else 'normal'
        ax1.annotate(name, (year, param), textcoords="offset points", 
                    xytext=(0, annotation_offsets[i]), ha='center', 
                    fontsize=10, fontweight=weight, color=color)
    
    # Add trend line to emphasize exponential growth
    z = np.polyfit(years, np.log10(parameters), 1)
    trend_line = 10**(z[0] * years + z[1])
    ax1.plot(years, trend_line, ':', color='gray', alpha=0.7, linewidth=2, 
             label=f'Exponential Trend (10^{z[0]:.1f}x per year)')
    
    ax1.set_xlabel('Year', fontsize=13, fontweight='bold')
    ax1.set_ylabel('Parameters (Billions)', fontsize=13, fontweight='bold')
    ax1.set_title('Exponential Growth in Model Size', 
                  fontsize=15, fontweight='bold')
    ax1.grid(True, alpha=0.3, which='both')
    ax1.legend(fontsize=11, loc='lower right')
    ax1.set_xlim(2017, 2028)
    ax1.set_ylim(0.1, 200000)
    
    # Subplot 2: Energy Crisis Visualization
    ax2.semilogy(years[~projected_mask], energy_training[~projected_mask], 'o-', 
                 color=colors['historical'], linewidth=3, markersize=10, 
                 label='Historical', alpha=0.9)
    ax2.semilogy(years[projected_mask], energy_training[projected_mask], 's--', 
                 color=colors['projected'], linewidth=3, markersize=10, 
                 label='Projected', alpha=0.9)
    
    # Add red dashed line connecting historical to projected
    ax2.plot(
        [years[last_historical_idx], years[first_projected_idx]],
        [energy_training[last_historical_idx], energy_training[first_projected_idx]],
        '--', color=colors['projected'], linewidth=3, alpha=0.9
    )
    
    # Add context zones with real-world energy consumption references
    ax2.axhspan(1000, 100000, alpha=0.2, color=colors['crisis'], 
                label='UNSUSTAINABLE (Small Country)')
    ax2.axhspan(100, 1000, alpha=0.2, color=colors['warning'], 
                label='CHALLENGING (Major City)')  
    ax2.axhspan(0.0001, 100, alpha=0.2, color=colors['safe'], 
                label='MANAGEABLE (Data Center)')
    
    # Add specific energy consumption context labels
#     energy_contexts = [
#         (2027.2, 15000, 'ENTIRE\nICELAND\nANNUAL\nCONSUMPTION', 'darkred'),
#         (2027.2, 500, 'NEW YORK CITY\nMONTHLY\nCONSUMPTION', 'darkorange'),
#         (2027.2, 10, 'LARGE DATA\nCENTER\nCAPACITY', 'darkgreen')
#     ]
    
#     for year, energy, text, color in energy_contexts:
#         ax2.text(year, energy, text, fontsize=9, fontweight='bold', 
#                 color=color, ha='left', va='center',
#                 bbox=dict(boxstyle='round,pad=0.3', facecolor='white', 
#                          edgecolor=color, alpha=0.8))
    
    ax2.set_xlabel('Year', fontsize=13, fontweight='bold')
    ax2.set_ylabel('Training Energy (GWh)', fontsize=13, fontweight='bold')
    ax2.set_title('Energy Requirements Approaching Physical Limits', 
                  fontsize=15, fontweight='bold')
    ax2.grid(True, alpha=0.3, which='both')
    ax2.legend(loc='lower right', fontsize=10)
    ax2.set_xlim(2017, 2028)
    ax2.set_ylim(0.0001, 100000)
    
    # Add crisis annotation
#     ax2.annotate('PHYSICAL\nLIMITS\nREACHED', 
#                 xy=(2027, 12000), xytext=(2024, 40000),
#                 arrowprops=dict(arrowstyle='->', color=colors['crisis'], lw=3),
#                 fontsize=12, fontweight='bold', color=colors['crisis'],
#                 bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', 
#                          edgecolor=colors['crisis'], alpha=0.9, linewidth=2))
    
    # Add scaling law to plot
    ax2.text(2019, 5000, 'Energy ∝ Parameters^1.3\n(Empirical Scaling Law)', 
            fontsize=11, style='italic', 
            bbox=dict(boxstyle='round,pad=0.4', facecolor='lightblue', alpha=0.7))
    
    # Main title
    fig.suptitle('Current Scaling Trends of AI Model Size and Energy ' +
                'Consumption Lead to Impossible Energy Requirements', 
                fontsize=16, fontweight='bold', y=0.98)
    
    plt.tight_layout()
    return fig

fig2 = create_scaling_visualization()
# fsave = 'figures/scaling_crisis.png'
# plt.savefig(fsave, dpi=150, bbox_inches='tight')
plt.show()
# print(f"Plot saved to {fsave}")

## Part 3: Innovation To Avert Energy Crisis

The energy scaling crisis has immediate consequences, all of them limiting:

- Training costs will exclude all but a handful of organizations
- Research democratization becomes impossible without efficiency breakthroughs
- Carbon regulations may limit AI development in many countries

### Sparked Interest In Neuromorphic Computing 

Taking cues from nature, which has optimized neural networks over millions of years of R&D, researchers have been drawn to **neuromorphic computing**: hardware and algorithms that mimic the brain's energy-efficient information processing. 

At the heart of this approach are **Spiking Neural Networks (SNNs)**, which process information through discrete spikes rather than continuous activations, and completely change the energy equation:

- **Traditional Artificial Neural Networks (ANNs)**: Energy ∝ (Number of model parameters)^1.3
- **Spiking Neural Networks (SNNs)**: Energy ∝ (Spike activity)

In SNNs, spiking activity is sparse: typically only 1-10% of neurons are active at once, similar to the human brain. The brain also shows what's possible:
- **20 watts total power consumption**
- **86 billion neurons, 100 trillion synapses**
- **Equivalent to 10^15 operations/second at 20W**
- **10^6 times more efficient than current ANN-based AI systems**

The potential efficiency gains over ANNs are possible because SNNs:
- Process info only when spikes occur
- Eliminate unnecessary matrix multiplications
- Encode info in spike timing instead of amplitude
- Leverage hardware that is optimized for spike processing
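
The effect of sparsity on operation counts can be sketched with simple arithmetic. The example below compares the multiply-accumulates of one dense layer with the synaptic operations of an event-driven layer of the same shape. The 784 → 128 layer size and the 10 timesteps match the defaults used later in this notebook. The helper names and the uniform spike rate are simplifying assumptions.

```python
def dense_macs(n_in, n_out):
    """Multiply-accumulates for one dense forward pass: every input meets every output."""
    return n_in * n_out

def event_driven_synops(n_in, n_out, spike_rate, timesteps):
    """Synaptic operations when only spiking inputs trigger downstream work.

    Assumes the same average spike rate for every input at every timestep.
    """
    active_inputs_per_step = n_in * spike_rate
    return active_inputs_per_step * n_out * timesteps

n_in, n_out, timesteps = 784, 128, 10
print(f"Dense layer: {dense_macs(n_in, n_out):,} MACs")
for rate in (0.01, 0.05, 0.10):
    synops = event_driven_synops(n_in, n_out, rate, timesteps)
    print(f"Spike rate {rate:.0%}: {synops:,.0f} synaptic ops over {timesteps} timesteps")
```

Because the spiking layer repeats its work at every timestep, the operation count only drops below the dense case when activity is sparse enough. The remaining advantage comes from each synaptic operation being a cheap accumulate rather than a full multiply-accumulate.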

### Key Differences Between ANNs and SNNs

| Aspect | Traditional ANNs | Spiking Neural Networks |
|--------|------------------|-------------------------|
| **Info Encoding** | Continuous values (e.g. 0.0 to 1.0) | Binary spikes (0 or 1) |
| **Computation** | Matrix multiplications | Event-driven spike processing |
| **Hardware** | GPUs/CPUs (power-hungry) | Neuromorphic chips (ultra-low power) |
| **Biological Realism** | Abstract mathematical functions | Mimics actual neuron behavior |
| **Energy Efficiency** | High power consumption | Orders of magnitude less power |

### See It For Yourself

In the following, we build an ANN and an SNN, process the same visual data through both networks, and estimate their energy consumption as they run. For each network, we count the mathematical operations actually performed and convert them to an estimated energy cost.

### Technical Implementation

#### Surrogate Gradient Learning for SNNs

The biggest challenge in training SNNs is that the spike function is a non-differentiable step: its gradient is zero almost everywhere, so standard backpropagation cannot pass a learning signal through it. We work around this with surrogate gradients: the forward pass keeps the exact spike behavior, while the backward pass substitutes a smooth approximation, e.g. the derivative of a sigmoid.
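
To see what the surrogate looks like, the sketch below evaluates a sigmoid-derivative surrogate at a few membrane potentials. It assumes the surrogate is centered on the firing threshold and uses a steepness of 10. The function name is illustrative and the example uses NumPy only.

```python
import numpy as np

def surrogate_grad(v, threshold=1.0, alpha=10.0):
    """Derivative of sigmoid(alpha * (v - threshold)) with respect to v.

    Stands in for the derivative of the hard step (v >= threshold),
    which is zero everywhere except at the threshold itself.
    """
    s = 1.0 / (1.0 + np.exp(-alpha * (v - threshold)))
    return alpha * s * (1.0 - s)

for v in (0.0, 0.8, 1.0, 1.2, 2.0):
    print(f"v = {v:.1f} -> surrogate gradient = {surrogate_grad(v):.4f}")
```

The surrogate peaks at the threshold with value alpha / 4 and decays quickly on either side, so neurons close to firing receive the strongest learning signal.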

In [None]:
class SurrogateGradientLIF(torch.autograd.Function):
    """Surrogate gradient for training SNNs."""
    @staticmethod
    def forward(ctx, input, threshold=1.0):
        ctx.save_for_backward(input)
        ctx.threshold = threshold
        return (input >= threshold).float()
    
    @staticmethod
    def backward(ctx, grad_output):
        v, = ctx.saved_tensors
        alpha = 10.0  # Controls surrogate gradient steepness
        # Sigmoid derivative centered on the firing threshold as a smooth approximation
        sig = torch.sigmoid(alpha * (v - ctx.threshold))
        grad = grad_output * alpha * sig * (1 - sig)
        return grad, None  # No gradient for the threshold argument


#### Biologically-Inspired Spiking Neural Network

To finish building out our SNN, we use a Leaky Integrate-and-Fire (LIF) neuron model with learnable parameters and biologically inspired features:

- Membrane potential that leaks over time (alpha parameter)
- Spike threshold crossing that triggers output spike
- Membrane that resets after spiking (beta parameter)
- Temporal dynamics across multiple time steps
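
The dynamics listed above can be illustrated with a single neuron in plain Python. This is a simplified sketch, not the notebook's network: it assumes a hard reset to zero after each spike and a fixed leak, and the function and parameter names are illustrative.

```python
def simulate_lif(input_current, leak=0.9, threshold=0.5):
    """Simulate one leaky integrate-and-fire neuron over a sequence of inputs.

    Returns a list with 1 where the neuron spiked and 0 elsewhere.
    """
    v = 0.0
    spikes = []
    for current in input_current:
        v = leak * v + current   # leaky integration of the input
        fired = v >= threshold   # threshold crossing triggers a spike
        spikes.append(int(fired))
        if fired:
            v = 0.0              # hard reset after spiking (simplifying assumption)
    return spikes

# A constant input of 0.2 charges the membrane until it crosses the threshold
print(simulate_lif([0.2] * 6))  # -> [0, 0, 1, 0, 0, 1]
```

With a constant input the neuron settles into a regular firing pattern. Stronger inputs shorten the interval between spikes, which is the basis of rate coding.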

In [None]:
class SpikingNeuralNet(nn.Module):
    """Biologically-inspired SNN with training."""
    
    def __init__(self, input_size=784, hidden_size=128, output_size=10, 
                 timesteps=10, v_threshold=0.5):
        super().__init__()
        
        # Standard linear layers for synaptic connections
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        
        # Learnable neuron dynamics parameters
        self.alpha = nn.Parameter(torch.ones(1) * 0.9)  # Membrane leak rate
        self.beta = nn.Parameter(torch.ones(1) * 0.8)   # Reset strength after spike
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.timesteps = timesteps
        self.v_threshold = v_threshold
        
        # Initialize weights with proper scaling for spiking networks
        nn.init.normal_(self.fc1.weight, mean=0, std=np.sqrt(2/input_size))
        nn.init.normal_(self.fc2.weight, mean=0, std=np.sqrt(2/hidden_size))
        
        # Track spike statistics for analysis
        self.spike_rates = {'input': 0, 'hidden': 0, 'output': 0}
        
    def encode_input(self, x):
        """
        Convert static input to temporal spike trains using rate encoding.
        Higher pixel intensities → higher spike probability
        """
        batch_size = x.shape[0]
        x_normalized = (x - x.min()) / (x.max() - x.min() + 1e-8)
        
        spike_trains = []
        for t in range(self.timesteps):
            # Add temporal variation to make encoding more realistic
            phase = (t / self.timesteps) * 2 * np.pi
            rate_modulation = 0.5 + 0.5 * np.sin(phase)
            spike_prob = x_normalized * rate_modulation
            spikes = torch.bernoulli(spike_prob)
            spike_trains.append(spikes)
            
        return spike_trains
    
    def forward(self, x, meter=None):
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)
        
        # Encode input as spike trains
        input_spikes = self.encode_input(x)
        
        # Initialize membrane potentials
        v1 = torch.zeros(batch_size, self.hidden_size, device=x.device)
        v2 = torch.zeros(batch_size, self.output_size, device=x.device)
        
        # Accumulate output spikes over time
        output_spikes = torch.zeros(batch_size, self.output_size, device=x.device)
        
        # Track total spikes for efficiency analysis
        total_input_spikes = 0
        total_hidden_spikes = 0
        total_output_spikes = 0
        
        # Process each time step
        for t in range(self.timesteps):
            # Layer 1: Input → Hidden
            h1 = self.fc1(input_spikes[t])
            v1 = self.alpha * v1 + h1  # Leaky integration
            
            # Generate spikes using surrogate gradient
            spike_func = SurrogateGradientLIF.apply
            spikes1 = spike_func(v1, self.v_threshold)
            v1 = v1 * (1 - spikes1) * self.beta  # Zero spiked neurons; beta scales the remaining potentials
            
            # Layer 2: Hidden → Output
            h2 = self.fc2(spikes1)
            v2 = self.alpha * v2 + h2
            spikes2 = spike_func(v2, self.v_threshold)
            v2 = v2 * (1 - spikes2) * self.beta
            
            output_spikes += spikes2
            
            # Count spikes for energy analysis
            input_spike_count = input_spikes[t].sum().item()
            hidden_spike_count = spikes1.sum().item()
            output_spike_count = spikes2.sum().item()
            
            total_input_spikes += input_spike_count
            total_hidden_spikes += hidden_spike_count
            total_output_spikes += output_spike_count
            
            # Energy accounting for neuromorphic hardware simulation
            if meter:
                active_synapses = (
                    input_spike_count * self.hidden_size +
                    hidden_spike_count * self.output_size
                )
                
                # Memory access for spike storage/routing
                memory_bytes = 4 * (input_spike_count + hidden_spike_count + output_spike_count)
                
                meter.add_operations(
                    spikes=input_spike_count + hidden_spike_count + output_spike_count,
                    synaptic_ops=active_synapses,
                    memory_bytes=memory_bytes
                )
        
        # Calculate network-wide spike rates for analysis
        total_neurons = self.input_size + self.hidden_size + self.output_size
        total_possible_spikes = total_neurons * self.timesteps * batch_size
        actual_spikes = total_input_spikes + total_hidden_spikes + total_output_spikes
        
        self.spike_rates = {
            'input': total_input_spikes / (self.input_size * self.timesteps * batch_size),
            'hidden': total_hidden_spikes / (self.hidden_size * self.timesteps * batch_size),
            'output': total_output_spikes / (self.output_size * self.timesteps * batch_size),
            'overall': actual_spikes / total_possible_spikes
        }
        
        # Return average spike count as network output
        return output_spikes / self.timesteps
    


### Hardware-Realistic Energy Modeling

Our energy model is based on published measurements of NVIDIA's V100 (for our ANN) and Intel's Loihi (SNN).
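
As a rough illustration of how per-operation constants turn into energy figures, the sketch below applies the compute constants used in this notebook's `EnergyMeter` to a 784 → 128 → 10 network. It ignores memory traffic and idle power, the function names are illustrative, and the spike and synaptic-operation counts in the usage lines are hypothetical placeholders rather than measured values.

```python
# Per-operation compute constants taken from this notebook's EnergyMeter
JOULES_PER_MAC = 4.6e-12     # GPU multiply-accumulate
JOULES_PER_SPIKE = 23e-12    # neuromorphic spike event
JOULES_PER_SYNOP = 120e-15   # neuromorphic synaptic operation

def ann_compute_energy(layer_sizes):
    """Compute-only energy (J) of one dense forward pass on the GPU model."""
    macs = sum(n_in * n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))
    return macs * JOULES_PER_MAC

def snn_compute_energy(spikes, synaptic_ops):
    """Compute-only energy (J) of an SNN inference on the neuromorphic model."""
    return spikes * JOULES_PER_SPIKE + synaptic_ops * JOULES_PER_SYNOP

ann_joules = ann_compute_energy([784, 128, 10])
# Hypothetical activity for one inference: 500 spikes, 50,000 synaptic operations
snn_joules = snn_compute_energy(spikes=500, synaptic_ops=50_000)

print(f"ANN compute energy: {ann_joules:.3e} J")
print(f"SNN compute energy: {snn_joules:.3e} J")
```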

In [None]:
class EnergyMeter:
    """Simulated energy modeling based on published hardware measurements:
    - NVIDIA V100: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
    - Intel Loihi: Davies et al., "Loihi: A Neuromorphic Manycore Processor", IEEE Micro 2018
    
    These are theoretical calculations, not actual measurements.
    Real-world energy consumption varies based on:
    - Workload characteristics
    - Cooling overhead (adds 30-40%)
    - Power delivery efficiency (85-90%)"""
    
    def __init__(self, name, device_type='gpu'):
        self.name = name
        self.device_type = device_type
        
        if device_type == 'gpu':
            # NVIDIA V100 measurements from datasheet
            self.joules_per_mac = 4.6e-12         # 4.6 pJ per multiply-accumulate
            self.memory_energy_per_byte = 2.6e-9  # HBM2: 2.6 nJ per byte access
            self.idle_power = 10.0                # Idle GPU power consumption (W)
            self.active_power_multiplier = 15.0   # Active processing power boost
            
        elif device_type == 'neuromorphic':
            # Intel Loihi measurements from research papers
            self.joules_per_spike = 23e-12       # 23 pJ per spike event
            self.joules_per_synop = 120e-15      # 120 fJ per synaptic operation
            self.memory_energy_per_byte = 0.1e-9 # SRAM: 0.1 nJ per byte (much lower than HBM2)
            self.idle_power = 0.050              # Loihi idle: 50mW (vs 10W for GPU)
            self.active_power_multiplier = 3.0   # Lower power boost for neuromorphic
            
        self.reset_counters()
        
    def reset_counters(self):
        """Reset energy tracking counters"""
        self.total_energy = 0
        self.operations = 0
        self.spike_events = 0
        self.synaptic_ops = 0
        self.memory_bytes = 0
        self.instant_power = 0
        self.time_elapsed = 0
        self.is_active = False
        
    def add_operations(self, macs=0, spikes=0, synaptic_ops=0, memory_bytes=0):
        """Add computational operations for the current frame."""
        if self.device_type == 'gpu':
            self.operations += macs
        else: # neuromorphic
            # Spikes and synaptic ops have different per-event costs, so count them separately
            self.spike_events += spikes
            self.synaptic_ops += synaptic_ops
            self.operations += spikes + synaptic_ops
            
        self.memory_bytes += memory_bytes
        self.is_active = (self.operations > 0)
    
    def compute_energy(self, dt=0.001):
        """Calculate realistic energy consumption with activity-dependent power.
        Energy = (Static Power) * Time + sum of (operation count * energy per operation)
        """
        # Static power rises by a device-specific multiplier while processing
        if self.is_active:
            base_power = self.idle_power * self.active_power_multiplier
        else:
            base_power = self.idle_power
        static_energy = base_power * dt
        
        if self.device_type == 'gpu':
            # GPU: High static power, high dynamic energy
            dynamic_energy = (self.operations * self.joules_per_mac + 
                            self.memory_bytes * self.memory_energy_per_byte)
            
        else: # neuromorphic
            # Neuromorphic: Very low static power, very low dynamic energy.
            # Each event type is charged at its own rate.
            event_energy = (self.spike_events * self.joules_per_spike +
                            self.synaptic_ops * self.joules_per_synop)
            memory_energy = self.memory_bytes * self.memory_energy_per_byte
            dynamic_energy = event_energy + memory_energy
        
        frame_energy = dynamic_energy + static_energy
        self.total_energy += frame_energy
        self.instant_power = frame_energy / dt if dt > 0 else 0
        self.time_elapsed += dt
        
        # Reset per-frame counters
        self.operations = 0
        self.spike_events = 0
        self.synaptic_ops = 0
        self.memory_bytes = 0
        self.is_active = False
        
        return self.instant_power, self.total_energy
    
    def get_metrics(self):
        """Return energy metrics."""
        avg_power = self.total_energy / max(self.time_elapsed, 1e-9)
        
        # Battery capacity calculation for 3.7V Li-ion battery
        battery_voltage = 3.7
        mah = (self.total_energy / battery_voltage) / 3.6  # Convert J to mAh
        
        return {
            'total_energy_j': self.total_energy,
            'avg_power_w': avg_power,
            'instant_power_w': self.instant_power,
            'battery_mah': mah,
            'time_elapsed_s': self.time_elapsed
        }
   

### Real-Time ANN vs SNN Energy Consumption While Processing MNIST Digits

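The following visualization compares the ANN (on a simulated GPU) and the SNN (on a simulated neuromorphic chip) on network activity, output predictions, instantaneous power, cumulative energy, and efficiency gain.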
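When reading the cumulative-energy and efficiency panels, it helps to know how much of the gap comes from the static power term alone. The sketch below is a minimal illustration, assuming the idle power and active multipliers from `EnergyMeter`, 25 ms frames, and one inference every 20 frames. It ignores dynamic (per-operation) energy entirely, and the function name and schedule parameters are hypothetical.

```python
def static_energy(idle_power_w: float, active_multiplier: float,
                  n_frames: int, dt: float = 0.025,
                  sample_interval: int = 20) -> float:
    """Static energy (J) for a device that is active on one frame per
    interval and idle on all other frames."""
    energy = 0.0
    for frame in range(n_frames):
        active = (frame % sample_interval == 0)
        power = idle_power_w * (active_multiplier if active else 1.0)
        energy += power * dt
    return energy


# 200 frames of 25 ms each = 5 simulated seconds
gpu_j = static_energy(idle_power_w=10.0, active_multiplier=15.0, n_frames=200)
loihi_j = static_energy(idle_power_w=0.050, active_multiplier=3.0, n_frames=200)

print(f"GPU static energy:   {gpu_j:.2f} J")
print(f"Loihi static energy: {loihi_j:.4f} J")
print(f"Static-only ratio:   {gpu_j / loihi_j:.0f}x")
```

Under these assumptions the per-operation energy is small next to the static term, so the plotted efficiency curve is largely determined by the assumed idle and active power levels.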

In [None]:
def create_metric_card(ax, title, value, unit, color='#667eea'):
    """Create metric display cards for the dashboard."""
    ax.clear()
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')
    
    # Draw card background with rounded corners
    fancy_box = FancyBboxPatch(
        (0.05, 0.05), 0.9, 0.9,
        boxstyle="round,pad=0.03",
        facecolor='#ffffff',
        edgecolor=color,
        linewidth=2.5,
        alpha=0.8
    )
    ax.add_patch(fancy_box)
    
    # Add text elements
    ax.text(0.5, 0.82, title.upper(), fontsize=9, ha='center', va='center',
            color='#495057', fontweight='bold')
    ax.text(0.5, 0.5, f"{value}", fontsize=14, ha='center', va='center',
            color=color, fontweight='bold')
    ax.text(0.5, 0.25, unit, fontsize=8, ha='center', va='center',
            color='#6c757d')

def create_visualization():
    """Compare ANN vs SNN performance:
    1. Input images being processed
    2. Network activity patterns
    3. Output predictions and accuracy
    4. Real-time power consumption
    5. Cumulative energy usage
    6. Energy efficiency advantage
    """
    
    # Load MNIST dataset for demo
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    
    demo_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True
    )
    
    # Initialize both network types for comparison
    ann = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    ).to(device)
    
    snn = SpikingNeuralNet(
        input_size=784, 
        hidden_size=128, 
        output_size=10,
        timesteps=10
    ).to(device)
    
    # Quick training phase to ensure meaningful predictions
    print("Training models for demo...")
    optimizer_ann = torch.optim.Adam(ann.parameters(), lr=0.001)
    optimizer_snn = torch.optim.Adam(snn.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, shuffle=True
    )
    
    # Train both networks briefly (the loop exits after 50 batches)
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx >= 50:
            break
            
        data, target = data.to(device), target.to(device)
        
        # Train ANN
        optimizer_ann.zero_grad()
        output_ann = ann(data)
        loss_ann = criterion(output_ann, target)
        loss_ann.backward()
        optimizer_ann.step()
        
        # Train SNN
        optimizer_snn.zero_grad()
        output_snn = snn(data)
        loss_snn = criterion(output_snn, target)
        loss_snn.backward()
        optimizer_snn.step()
    
    print("Training complete. Generating animation frames for live comparison...")
    
    # Initialize energy meters for both architectures
    ann_meter = EnergyMeter("ANN", device_type='gpu')
    snn_meter = EnergyMeter("SNN", device_type='neuromorphic')
    
    # Create dashboard
    fig = plt.figure(figsize=(20, 18))
    fig.patch.set_facecolor('#f8f9fa')
    
    # Main title
    fig.suptitle(
        'Real-Time Energy Comparison: ' +
        'Traditional Neural Networks (GPU) vs Spiking Neural Networks (Neuromorphic)',
        fontsize=16, fontweight='bold', y=0.98
    )
    
    # Create grid layout for multiple plots
    gs = gridspec.GridSpec(10, 6, figure=fig,
                          height_ratios=[0.8, 0.8, 0.8, 0.8, 1, 1, 1, 1, 1, 1],
                          width_ratios=[1, 1, 1, 1, 1, 1],
                          hspace=1.1, wspace=0.5,
                          left=0.05, right=0.95, top=0.92, bottom=0.05)

    # Define subplot positions
    ax_input = fig.add_subplot(gs[1:3, 0])
    ax_ann_activity = fig.add_subplot(gs[:2, 1:3])
    ax_ann_output = fig.add_subplot(gs[:2, 3:5])
    ax_ann_accuracy = fig.add_subplot(gs[:2, 5])
    ax_snn_activity = fig.add_subplot(gs[2:4, 1:3])
    ax_snn_output = fig.add_subplot(gs[2:4, 3:5])
    ax_snn_accuracy = fig.add_subplot(gs[2:4, 5])
    ax_power = fig.add_subplot(gs[4:6, :5])    
    ax_power_ann_card = fig.add_subplot(gs[4, 5])
    ax_power_snn_card = fig.add_subplot(gs[5, 5])
    ax_energy = fig.add_subplot(gs[6:8, :5])
    ax_energy_ann_card = fig.add_subplot(gs[6, 5])
    ax_energy_snn_card = fig.add_subplot(gs[7, 5])
    ax_efficiency = fig.add_subplot(gs[8:, :5])
    ax_energy_card = fig.add_subplot(gs[8:, 5])
    
    # Configure input image display
    ax_input.set_title('Input Image', fontsize=10, fontweight='bold', pad=5)
    ax_input.axis('off')
    input_img = ax_input.imshow(np.zeros((28, 28)), cmap='viridis', vmin=-1, vmax=1)
    
    # Configure network activity displays
    ax_ann_activity.set_title('ANN Activity', fontsize=10, fontweight='bold', pad=5)
    ax_ann_activity.set_ylim(0, 105)
    ax_ann_activity.set_ylabel('Activation Level (%)', fontsize=11, fontweight='bold')
    ax_ann_activity.set_xticks([0, 1, 2])
    ax_ann_activity.set_xticklabels(['Input', 'Hidden', 'Output'], fontsize=8)
    ax_ann_activity.spines['top'].set_visible(False)
    ax_ann_activity.spines['right'].set_visible(False)
    
    ax_snn_activity.set_title('SNN Activity', fontsize=10, fontweight='bold', pad=5)
    ax_snn_activity.set_ylim(0, 105)
    ax_snn_activity.set_ylabel('Spike Rate (%)', fontsize=11, fontweight='bold')
    ax_snn_activity.set_xticks([0, 1, 2])
    ax_snn_activity.set_xticklabels(['Input', 'Hidden', 'Output'], fontsize=8)
    ax_snn_activity.spines['top'].set_visible(False)
    ax_snn_activity.spines['right'].set_visible(False)
    
    # Initialize activity bar charts
    ann_bars = ax_ann_activity.bar([0, 1, 2], [0, 0, 0], color=COLORS['ann'], alpha=0.7)
    snn_bars = ax_snn_activity.bar([0, 1, 2], [0, 0, 0], color=COLORS['snn'], alpha=0.7)
    
    # Configure output prediction displays
    ax_ann_output.set_title('Output Predictions', fontsize=10, fontweight='bold', pad=5)
    ax_ann_output.set_ylim(0, 1.05)
    ax_ann_output.set_ylabel('Confidence', fontsize=9)
    ax_ann_output.set_xticks(range(10))
    ax_ann_output.set_xticklabels(range(10), fontsize=8)
    
    ax_snn_output.set_title('Output Predictions', fontsize=10, fontweight='bold', pad=5)
    ax_snn_output.set_ylim(0, 1.05)
    ax_snn_output.set_ylabel('Confidence', fontsize=9)
    ax_snn_output.set_xticks(range(10))
    ax_snn_output.set_xticklabels(range(10), fontsize=8)
    
    ann_output_bars = ax_ann_output.bar(range(10), np.zeros(10), color='salmon', alpha=0.7)
    snn_output_bars = ax_snn_output.bar(range(10), np.zeros(10), color='lightgreen', alpha=0.7)
    
    # Configure power consumption log plot
    ax_power.set_title('Instantaneous Power Draw', fontsize=11, fontweight='bold', pad=15)
    ax_power.set_xlabel('Time (seconds)', fontsize=9)
    ax_power.set_ylabel('Power (W)', fontsize=9)
    ax_power.set_xlim(0, 5)
    ax_power.set_yscale('log')
    ax_power.set_ylim(0.01, 1000)  # Active GPU frames reach ~150 W (10 W idle x 15)
    ax_power.grid(True, alpha=0.3, which='both')
    
    ann_power_line, = ax_power.plot([], [], 'r-', linewidth=2.5, label='ANN (GPU)', alpha=0.8)
    snn_power_line, = ax_power.plot([], [], 'g-', linewidth=2.5, label='SNN (Neuromorphic)', alpha=0.8)
    ax_power.legend(loc='upper right', fontsize=9)
    
    # Configure cumulative energy plot
    ax_energy.set_title('Cumulative Energy Consumption', fontsize=11, fontweight='bold', pad=15)
    ax_energy.set_xlabel('Time (seconds)', fontsize=9)
    ax_energy.set_ylabel('Energy (J)', fontsize=9)
    ax_energy.set_xlim(0, 5)
    ax_energy.grid(True, alpha=0.3)
    
    ann_energy_line, = ax_energy.plot([], [], 'r-', linewidth=2.5, label='ANN', alpha=0.8)
    snn_energy_line, = ax_energy.plot([], [], 'g-', linewidth=2.5, label='SNN', alpha=0.8)
    ax_energy.legend(loc='upper left', fontsize=9)
    
    # Configure efficiency advantage plot
    ax_efficiency.set_title('Energy Efficiency Advantage', fontsize=11, fontweight='bold', pad=15)
    ax_efficiency.set_xlabel('Time (seconds)', fontsize=9)
    ax_efficiency.set_ylabel('SNN Efficiency Gain (×)', fontsize=9)
    ax_efficiency.set_xlim(0, 5)
    ax_efficiency.set_ylim(0, 500)
    ax_efficiency.grid(True, alpha=0.3)
    
    efficiency_line, = ax_efficiency.plot([], [], 'b-', linewidth=3, alpha=0.8)
    efficiency_fill = None # Will be updated dynamically
    ax_efficiency.axhline(y=1, color='gray', linestyle='--', alpha=0.5, label='Break-even')
    ax_efficiency.legend(loc='upper left', fontsize=9)
    
    # Initialize data storage for time series
    time_data = []
    ann_power_data = []
    snn_power_data = []
    ann_energy_data = []
    snn_energy_data = []
    efficiency_data = []
    
    # Animation state variables
    data_iter = iter(demo_loader)
    samples_processed = 0
    ann_correct = 0
    snn_correct = 0
    processing_new_sample = False
    ann_pred_display = None
    snn_pred_display = None
    
    def animate(frame):
        """Animation function called for each frame of visualization."""
        nonlocal data_iter, samples_processed, ann_correct, snn_correct, processing_new_sample
        nonlocal efficiency_fill, ann_pred_display, snn_pred_display
        
        dt = 0.025 # Simulated seconds per frame (1/dt = 40 frames per second)
        current_time = frame * dt
        
        # Process new sample every 0.5 seconds
        if frame % int(0.5 / dt) == 0:
            processing_new_sample = True
            
            try:
                image, label = next(data_iter)
            except StopIteration:
                data_iter = iter(demo_loader)
                image, label = next(data_iter)
            
            image, label = image.to(device), label.to(device)
            
            # Run inference with energy tracking
            with torch.no_grad():
                # ANN inference - traditional approach
                ann_output = ann(image)
                ann_probs = F.softmax(ann_output, dim=1).squeeze().cpu().numpy()
                ann_pred = ann_output.argmax(1).item()
                
                # Count operations for ANN (matrix multiplications)
                ann_meter.add_operations(
                    macs=784 * 128 + 128 * 10, # Two dense layer operations
                    memory_bytes=4 * (784 + 128 + 10) * 2 # float32 activations, counted twice; weight reads are not included
                )
                
                # SNN inference - event-driven
                snn_output = snn(image, meter=snn_meter)
                snn_probs = F.softmax(snn_output, dim=1).squeeze().cpu().numpy()
                snn_pred = snn_output.argmax(1).item()

                ann_pred_display = ann_pred
                snn_pred_display = snn_pred
                
                # Track accuracy for both networks
                if ann_pred == label.item():
                    ann_correct += 1
                if snn_pred == label.item():
                    snn_correct += 1
                    
                samples_processed += 1
            
            # Update input image display
            input_img.set_data(image.squeeze().cpu().numpy())
            
            # Flash ANN activity bars (merely for visual feedback)
            ann_bars[0].set_height(100)
            ann_bars[0].set_color('#ff6b6b')
            ann_bars[1].set_height(100)
            ann_bars[1].set_color('#ff6b6b')
            ann_bars[2].set_height(100)
            ann_bars[2].set_color('#ff6b6b')
            
            # Update SNN activity with measured spike rates
            spike_rates = [
                snn.spike_rates['input'] * 100,
                snn.spike_rates['hidden'] * 100,
                snn.spike_rates['output'] * 100
            ]
            snn_bars[0].set_height(spike_rates[0])
            snn_bars[1].set_height(spike_rates[1])
            snn_bars[2].set_height(spike_rates[2])
            
            # Update output prediction displays with color coding
            for i, bar in enumerate(ann_output_bars):
                bar.set_height(ann_probs[i])
                if i == ann_pred:
                    bar.set_color(COLORS['ann'] if ann_pred == label.item() else '#ff9999')
                else:
                    bar.set_color('salmon')
            
            for i, bar in enumerate(snn_output_bars):
                bar.set_height(snn_probs[i])
                if i == snn_pred:
                    bar.set_color(COLORS['snn'] if snn_pred == label.item() else '#99ff99')
                else:
                    bar.set_color('lightgreen')
        else:
            # Fade ANN activity bars back to normal color (just for visual feedback)
            if processing_new_sample:
                ann_bars[0].set_color(COLORS['ann'])
                ann_bars[1].set_color(COLORS['ann'])
                ann_bars[2].set_color(COLORS['ann'])
                processing_new_sample = False
        
        # Update energy measurements for both systems
        ann_power, ann_energy = ann_meter.compute_energy(dt)
        snn_power, snn_energy = snn_meter.compute_energy(dt)
        
        # Store time series data
        time_data.append(current_time)
        ann_power_data.append(ann_power)
        snn_power_data.append(snn_power)
        ann_energy_data.append(ann_energy)
        snn_energy_data.append(snn_energy)
        
        # Calculate efficiency advantage of SNN compared to ANN
        if snn_energy > 1e-12:
            efficiency = ann_energy / snn_energy
            efficiency_data.append(min(efficiency, 10000)) # Cap for display purposes
        else:
            efficiency_data.append(0)
        
        # Update all time series plots
        ann_power_line.set_data(time_data, ann_power_data)
        snn_power_line.set_data(time_data, snn_power_data)
        ann_energy_line.set_data(time_data, ann_energy_data)
        snn_energy_line.set_data(time_data, snn_energy_data)
        efficiency_line.set_data(time_data, efficiency_data)
        
        # Update efficiency fill area
        if efficiency_fill:
            efficiency_fill.remove()
        efficiency_fill = ax_efficiency.fill_between(
            time_data, 1, efficiency_data, 
            where=[e > 1 for e in efficiency_data],
            color='blue', alpha=0.2
        )
        
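        # Dynamic y-axis scaling for energy plot
        if len(ann_energy_data) > 0:
            max_energy = max(max(ann_energy_data), max(snn_energy_data)) * 1.2
            ax_energy.set_ylim(0, max_energy)
        
        # Extend the time axes once the run passes the initial 5-second window
        # (1000 frames at dt = 0.025 s cover 25 simulated seconds)
        if current_time > 5:
            for ax in (ax_power, ax_energy, ax_efficiency):
                ax.set_xlim(0, current_time)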
        
        # Update all metric cards with current values
        ann_metrics = ann_meter.get_metrics()
        snn_metrics = snn_meter.get_metrics()
        
        # Power consumption cards
        create_metric_card(ax_power_ann_card, "ANN Power", 
                         f"{ann_metrics['instant_power_w']:.1f}", "W", COLORS['ann'])
        create_metric_card(ax_power_snn_card, "SNN Power",
                         f"{snn_metrics['instant_power_w']*1000:.1f}", "mW", COLORS['snn'])

        # Cumulative energy cards
        create_metric_card(ax_energy_ann_card, "ANN Energy", 
                         f"{ann_metrics['total_energy_j']:.2f}", "J", COLORS['ann'])   
        create_metric_card(ax_energy_snn_card, "SNN Energy",
                         f"{snn_metrics['total_energy_j']*1000:.2f}", "mJ", COLORS['snn'])
        
        # Energy efficiency advantage card
        if len(efficiency_data) > 0 and efficiency_data[-1] > 0:
            create_metric_card(ax_energy_card, "SNN Advantage",
                             f"{efficiency_data[-1]:.1f}×", "less energy", '#3498db')
        
        # Accuracy cards with prediction display
        if samples_processed > 0:
            ann_acc = (ann_correct / samples_processed) * 100
            snn_acc = (snn_correct / samples_processed) * 100

            ann_unit_text = f"Predicted: {ann_pred_display}\nAccuracy: {ann_acc:.0f}%"
            snn_unit_text = f"Predicted: {snn_pred_display}\nAccuracy: {snn_acc:.0f}%"

            create_metric_card(ax_ann_accuracy, "ANN Performance", 
                            ann_unit_text, f"{samples_processed} samples", COLORS['ann'])
            create_metric_card(ax_snn_accuracy, "SNN Performance",
                            snn_unit_text, f"{samples_processed} samples", COLORS['snn'])
        
        return [ann_power_line, snn_power_line, ann_energy_line, snn_energy_line,
                efficiency_line, input_img] + list(ann_bars) + list(snn_bars) + \
               list(ann_output_bars) + list(snn_output_bars)
    
    # Create and return animation
    anim = FuncAnimation(fig, animate, frames=1000, interval=25, blit=False)
    
    return fig, anim

# if __name__ == "__main__":
    
#     # Create and display the energy comparison
#     fig, anim = create_visualization()
    
#     # For Jupyter/Colab
#     display(HTML(anim.to_jshtml()))

In [None]:
def create_visualization():
    """Compare ANN vs SNN energy use and accuracy:
    1. Input images being processed
    2. Network activity patterns
    3. Output predictions and accuracy
    4. Real-time power consumption
    5. Cumulative energy usage
    6. Energy efficiency advantage
    """
    
    DEMO_CONFIG = {
        'n_samples': 50,            # Number of MNIST samples to process
        'training_batches': 200,   # Pre-training iterations (more = better accuracy)
        'animation_speed': 100,    # ms per frame (lower = faster)
        'timesteps': 20,            # SNN temporal resolution
        'sample_interval': 20,     # Frames between new samples
    }
    
    # Load MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    
    demo_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True
    )
    
    # Initialize networks with better architecture
    ann = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 10)
    ).to(device)
    
    snn = SpikingNeuralNet(
        input_size=784, 
        hidden_size=256, 
        output_size=10,
        timesteps=DEMO_CONFIG['timesteps']
    ).to(device)
    
    print("Training networks for meaningful predictions...")
    optimizer_ann = torch.optim.Adam(ann.parameters(), lr=0.001)
    optimizer_snn = torch.optim.Adam(snn.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, shuffle=True
    )
    
    # Training loop with progress
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx >= DEMO_CONFIG['training_batches']:
            break
            
        if batch_idx % 20 == 0:
            print(f"   Batch {batch_idx}/{DEMO_CONFIG['training_batches']}")
            
        data, target = data.to(device), target.to(device)
        
        # Train ANN
        optimizer_ann.zero_grad()
        output_ann = ann(data)
        loss_ann = criterion(output_ann, target)
        loss_ann.backward()
        optimizer_ann.step()
        
        # Train SNN
        optimizer_snn.zero_grad()
        output_snn = snn(data)
        loss_snn = criterion(output_snn, target)
        loss_snn.backward()
        optimizer_snn.step()
    
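    # Switch to inference mode so the ANN's Dropout layers are disabled during the demo
    ann.eval()
    snn.eval()
    
    print("Training complete. Now calculating energy consumption...")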
    
    # Animation with improved layout
    fig = plt.figure(figsize=(16, 10))
    fig.patch.set_facecolor("#f8f9fa")

    fig.suptitle(
        'Real-Time Neural Network Comparison: Traditional AI vs Brain-Inspired AI',
        fontsize=20, fontweight='bold', y=0.96, color='#2c3e50')

    # Improved gridspec with dedicated flow row at top
    gs = gridspec.GridSpec(4, 8, figure=fig,
                          height_ratios=[0.4, 1.2, 1.2, 0.4],
                          width_ratios=[1.0, 0.18, 1.5, 0.18, 1.5, 0.18, 1.0, 0.1],
                          hspace=0.35, wspace=0.18,
                          left=0.06, right=0.94, top=0.92, bottom=0.08)        
    
    # Input image (spans both ANN/SNN rows for visual prominence)
    ax_input = fig.add_subplot(gs[1:3, 0])
    ax_input.set_title('Input Digit', fontsize=16, fontweight='bold', pad=15, color='#34495e')
    ax_input.axis('off')
    input_img = ax_input.imshow(np.zeros((28, 28)), cmap='gray', vmin=-1, vmax=1)

    # Network activity comparison - same width as prediction plots
    ax_ann_activity = fig.add_subplot(gs[1, 2])
    ax_ann_activity.set_title('ANN Activity Pattern', fontsize=13, fontweight='bold', 
                             color=COLORS['ann'], pad=12)
    ax_ann_activity.set_ylabel('Layer Activity (%)', fontsize=11)
    ax_ann_activity.set_ylim(0, 100)
    ax_ann_activity.set_xticks([0, 1, 2])
    ax_ann_activity.set_xticklabels(['Input', 'Hidden', 'Output'], fontsize=10)
    ax_ann_activity.grid(True, alpha=0.3)
    
    ax_snn_activity = fig.add_subplot(gs[2, 2])
    ax_snn_activity.set_title('SNN Spike Pattern', fontsize=13, fontweight='bold', 
                             color=COLORS['snn'], pad=12)
    ax_snn_activity.set_ylabel('Spike Rate (%)', fontsize=11)
    ax_snn_activity.set_ylim(0, 100)
    ax_snn_activity.set_xticks([0, 1, 2])
    ax_snn_activity.set_xticklabels(['Input', 'Hidden', 'Output'], fontsize=10)
    ax_snn_activity.grid(True, alpha=0.3)
    
    # Prediction displays - same width as activity plots
    ax_ann_pred = fig.add_subplot(gs[1, 4])
    ax_ann_pred.set_title('ANN Prediction', fontsize=13, fontweight='bold', 
                         color=COLORS['ann'], pad=12)
    ax_ann_pred.set_ylabel('Confidence', fontsize=11)
    ax_ann_pred.set_ylim(0, 1)
    ax_ann_pred.set_xticks(range(10))
    ax_ann_pred.grid(True, alpha=0.3)
    
    ax_snn_pred = fig.add_subplot(gs[2, 4])
    ax_snn_pred.set_title('SNN Prediction', fontsize=13, fontweight='bold', 
                         color=COLORS['snn'], pad=12)
    ax_snn_pred.set_ylabel('Confidence', fontsize=11)
    ax_snn_pred.set_ylim(0, 1)
    ax_snn_pred.set_xticks(range(10))
    ax_snn_pred.grid(True, alpha=0.3)
    
    # Large digit display for predictions
    ax_ann_digit = fig.add_subplot(gs[1, 6])
    ax_ann_digit.set_title('ANN Guess', fontsize=13, fontweight='bold', pad=12, color='#34495e')
    ax_ann_digit.axis('off')
    
    ax_snn_digit = fig.add_subplot(gs[2, 6])
    ax_snn_digit.set_title('SNN Guess', fontsize=13, fontweight='bold', pad=12, color='#34495e') 
    ax_snn_digit.axis('off')
    
    # Status bar for current sample info
    ax_status = fig.add_subplot(gs[3, :])
    ax_status.axis('off')

    # Style improvements for all axes
    for ax in (ax_ann_activity, ax_snn_activity, ax_ann_pred, ax_snn_pred):
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_color('#bdc3c7')
        ax.spines['bottom'].set_color('#bdc3c7')
        ax.tick_params(colors='#34495e', labelsize=9)
    
    # Initialize visualization elements
    ann_bars = ax_ann_activity.bar([0, 1, 2], [0, 0, 0], color=COLORS['ann'], alpha=0.8, width=0.6)
    snn_bars = ax_snn_activity.bar([0, 1, 2], [0, 0, 0], color=COLORS['snn'], alpha=0.8, width=0.6)
    ann_pred_bars = ax_ann_pred.bar(range(10), np.zeros(10), color=COLORS['ann'], alpha=0.8, width=0.7)
    snn_pred_bars = ax_snn_pred.bar(range(10), np.zeros(10), color=COLORS['snn'], alpha=0.8, width=0.7)

    # Create flow diagram in top row
    ax_flow = fig.add_subplot(gs[0, :])
    ax_flow.set_axis_off()
    
    # Flow elements with proper alignment to columns below
    flow_axes   = [ax_input, ax_ann_activity, ax_ann_pred, ax_ann_digit]
    flow_labels = ['Input Digit', 'Processor Activity', 'Analyze Predictions', 'Output Digit']
    flow_colors = ['#34495e', '#7f8c8d', '#3498db', "#b587e9"]
    
    # Vertical band (within the top row area) for the flow boxes
    flow_bbox = ax_flow.get_position()           # figure coords (0..1)
    y0 = flow_bbox.y0 + 0.18*flow_bbox.height    # a bit of top/bottom padding
    y1 = flow_bbox.y1 - 0.18*flow_bbox.height
    h  = y1 - y0

    rects_xyw = []
    for ax in flow_axes:
        bb = ax.get_position()                   # figure coords for each target axis
        x, w = bb.x0, bb.width                   # use *exact* width & left edge
        rects_xyw.append((x, w))

    # Draw boxes, labels, and arrows in figure coordinates for perfect alignment
    for i, ((x, w), label, color) in enumerate(zip(rects_xyw, flow_labels, flow_colors)):
        rect = Rectangle((x, y0), w, h,
                        transform=fig.transFigure,  # <- figure coords
                        facecolor=color, alpha=0.15,
                        edgecolor=color, linewidth=2, zorder=10)
        fig.add_artist(rect)

        fig.text(x + w/2, y0 + h/2, label,
                transform=fig.transFigure,
                ha='center', va='center', fontsize=11, fontweight='bold',
                color=color, zorder=11)

        if i < len(rects_xyw) - 1:
            x_next, _ = rects_xyw[i+1]
            arrow = FancyArrowPatch((x + w, y0 + h/2), (x_next, y0 + h/2),
                                    transform=fig.transFigure,
                                    arrowstyle='-|>', linewidth=2.5,
                                    mutation_scale=10, mutation_aspect=1.1,
                                    color='#2c3e50', alpha=0.7, zorder=12)
            fig.add_artist(arrow)
    
    # Initialize energy tracking
    ann_meter = EnergyMeter("ANN", device_type='gpu')
    snn_meter = EnergyMeter("SNN", device_type='neuromorphic')
    
    # Data collection for static plots
    analysis_data = {
        'samples': [],
        'ann_predictions': [],
        'snn_predictions': [],
        'true_labels': [],
        'ann_energy': [],
        'snn_energy': [],
        'ann_power': [],
        'snn_power': [],
        'snn_spike_rates': [],
        'processing_times': []
    }
    
    # Animation state
    data_iter = iter(demo_loader)
    samples_processed = 0
    current_image = None
    current_label = None
    
    def animate(frame):
        nonlocal data_iter, samples_processed, current_image, current_label
        
        # Process new sample at intervals
        if frame % DEMO_CONFIG['sample_interval'] == 0 and samples_processed < DEMO_CONFIG['n_samples']:
            
            try:
                current_image, current_label = next(data_iter)
            except StopIteration:
                data_iter = iter(demo_loader)
                current_image, current_label = next(data_iter)
            
            current_image, current_label = current_image.to(device), current_label.to(device)
            
            # Update input display
            input_img.set_data(current_image.squeeze().cpu().numpy())
            
            # Run inference
            start_time = time.time()
            with torch.no_grad():
                # ANN inference
                ann_output = ann(current_image)
                ann_probs = F.softmax(ann_output, dim=1).squeeze().cpu().numpy()
                ann_pred = ann_output.argmax(1).item()
                
                # ANN energy
                ann_meter.add_operations(
                    macs=784 * 256 + 256 * 128 + 128 * 10,
                    memory_bytes=4 * (784 + 256 + 128 + 10) * 2
                )
                ann_power, ann_energy = ann_meter.compute_energy(0.001)
                
                # SNN inference
                snn_output = snn(current_image, meter=snn_meter)
                snn_probs = F.softmax(snn_output, dim=1).squeeze().cpu().numpy()
                snn_pred = snn_output.argmax(1).item()

                # SNN energy
                snn_power, snn_energy = snn_meter.compute_energy(0.001)
                
            processing_time = time.time() - start_time
            
            # Store data for plots
            analysis_data['samples'].append(samples_processed)
            analysis_data['ann_predictions'].append(ann_pred)
            analysis_data['snn_predictions'].append(snn_pred)
            analysis_data['true_labels'].append(current_label.item())
            analysis_data['ann_energy'].append(ann_energy * 1000) # Convert to mJ
            analysis_data['snn_energy'].append(snn_energy * 1000) # Convert to mJ
            analysis_data['ann_power'].append(ann_power)
            analysis_data['snn_power'].append(snn_power)
            analysis_data['snn_spike_rates'].append(snn.spike_rates.get('overall', 0))
            analysis_data['processing_times'].append(processing_time)
            
            # Update activity displays
            # ANN: Estimate realistic activity heights for dense computation
            input_activity = (current_image.squeeze() > 0).float().mean().item() * 100
            hidden_activity = 50 + np.random.normal(0, 5) # ~50% after ReLU
            output_activity = 100 # All output neurons active
            
            ann_bars[0].set_height(input_activity)
            ann_bars[1].set_height(hidden_activity) 
            ann_bars[2].set_height(output_activity)
            for bar in ann_bars:
                bar.set_color(COLORS['ann'])
                bar.set_alpha(0.8)
                
            # SNN: Show actual sparse activity
            spike_rates = [
                snn.spike_rates.get('input', 0) * 100,
                snn.spike_rates.get('hidden', 0) * 100,
                snn.spike_rates.get('output', 0) * 100
            ]
            for i, bar in enumerate(snn_bars):
                bar.set_height(spike_rates[i])
                bar.set_alpha(0.8)
            
            # Update prediction displays
            for i, (ann_bar, snn_bar) in enumerate(zip(ann_pred_bars, snn_pred_bars)):
                ann_bar.set_height(ann_probs[i])
                snn_bar.set_height(snn_probs[i])
                
                # Highlight most likely prediction
                ann_color = COLORS['ann'] if i == ann_pred else '#ffb3b3'
                snn_color = COLORS['snn'] if i == snn_pred else '#b3ffb3'
                ann_bar.set_color(ann_color)
                snn_bar.set_color(snn_color)
                ann_bar.set_alpha(0.9 if i == ann_pred else 0.4)
                snn_bar.set_alpha(0.9 if i == snn_pred else 0.4)
            
            # Display predicted digits
            ax_ann_digit.clear()
            ax_ann_digit.text(0.5, 0.5, str(ann_pred), fontsize=60, 
                             fontweight='bold', ha='center', va='center',
                             color=COLORS['ann'] if ann_pred == current_label.item() else '#ffb3b3')
            ax_ann_digit.set_xlim(0, 1)
            ax_ann_digit.set_ylim(0, 1)
            ax_ann_digit.axis('off')
            
            ax_snn_digit.clear()
            ax_snn_digit.text(0.5, 0.5, str(snn_pred), fontsize=60,
                             fontweight='bold', ha='center', va='center', 
                             color=COLORS['snn'] if snn_pred == current_label.item() else '#b3ffb3')
            ax_snn_digit.set_xlim(0, 1)
            ax_snn_digit.set_ylim(0, 1)
            ax_snn_digit.axis('off')
            
            # Update status
            ax_status.clear()
            status_text = (f"Sample {samples_processed + 1}/{DEMO_CONFIG['n_samples']} | "
                          f"True Label: {current_label.item()} | "
                          f"ANN: {ann_pred} {'✓' if ann_pred == current_label.item() else '✗'} | "
                          f"SNN: {snn_pred} {'✓' if snn_pred == current_label.item() else '✗'} | "
                          f"Energy Ratio: {ann_energy/max(snn_energy, 1e-12):.0f}:1")
            
            ax_status.text(0.5, 0.5, status_text,
                          ha='center', va='center', fontsize=13, 
                          fontweight='bold', color='#2c3e50')
            ax_status.set_xlim(0, 1)
            ax_status.set_ylim(0, 1)
            ax_status.axis('off')            
            
            samples_processed += 1
            
        return []
    
    # Create animation
    frames = DEMO_CONFIG['sample_interval'] * DEMO_CONFIG['n_samples'] + 20
    anim = FuncAnimation(fig, animate, frames=frames, 
                        interval=DEMO_CONFIG['animation_speed'], blit=False)
    
    return fig, anim, analysis_data

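def create_static_analysis(analysis_data):
    """Plot running accuracy, cumulative energy, power, and the ANN/SNN energy ratio for the demo run."""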
    if len(analysis_data['samples']) == 0:
        print("No data collected for analysis")
        return None
    
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(8, 14))
    fig.patch.set_facecolor('white')
    # fig.suptitle('Traditional AI (ANN) vs Brain-Inspired AI (SNN)', fontsize=16, fontweight='bold')
    
    # Calculate running accuracy metrics
    ann_running_accuracy = []
    snn_running_accuracy = []
    
    ann_correct_count = 0
    snn_correct_count = 0
    
    for i, (ann_pred, snn_pred, true_label) in enumerate(zip(
        analysis_data['ann_predictions'], 
        analysis_data['snn_predictions'], 
        analysis_data['true_labels'])):
        
        if ann_pred == true_label:
            ann_correct_count += 1
        if snn_pred == true_label:
            snn_correct_count += 1
            
        # Running accuracy = correct so far / samples so far
        ann_running_accuracy.append((ann_correct_count / (i + 1)) * 100)
        snn_running_accuracy.append((snn_correct_count / (i + 1)) * 100)
    
    final_ann_accuracy = ann_running_accuracy[-1] if ann_running_accuracy else 0
    final_snn_accuracy = snn_running_accuracy[-1] if snn_running_accuracy else 0
    
    total_ann_energy = sum(analysis_data['ann_energy'])
    total_snn_energy = sum(analysis_data['snn_energy'])
    energy_ratio = total_ann_energy / max(total_snn_energy, 1e-9)
    
    # 1. Running Accuracy Comparison
    ax1.plot(analysis_data['samples'], ann_running_accuracy, 
            'r-', linewidth=2, label=f'ANN (Final: {final_ann_accuracy:.1f}%)')
    ax1.plot(analysis_data['samples'], snn_running_accuracy, 
            'g-', linewidth=2, label=f'SNN (Final: {final_snn_accuracy:.1f}%)')
    ax1.set_xlabel('Sample Number')
    ax1.set_ylabel('Running Accuracy (%)')
    ax1.set_title('Classification Accuracy Over Time')
    ax1.set_ylim(0, 105)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # 2. Energy Consumption Over Time
    ax2.plot(analysis_data['samples'], np.cumsum(analysis_data['ann_energy']), 
            'r-', linewidth=2, label='ANN')
    ax2.plot(analysis_data['samples'], np.cumsum(analysis_data['snn_energy']), 
            'g-', linewidth=2, label='SNN')
    ax2.set_xlabel('Sample Number')
    ax2.set_ylabel('Cumulative Energy (mJ)')
    ax2.set_title('Energy Consumption Over Time')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # 3. Power Draw Comparison
    ax3.plot(analysis_data['samples'], analysis_data['ann_power'], 
            'r-', alpha=0.7, label='ANN')
    ax3.plot(analysis_data['samples'], np.array(analysis_data['snn_power']), 
            'g-', alpha=0.7, label='SNN')
    ax3.set_xlabel('Sample Number') 
    ax3.set_ylabel('Power (W)')
    ax3.set_title('Instantaneous Power Draw')
    ax3.set_yscale('log')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # 4. Efficiency Advantage
    efficiency_ratios = [a/max(s, 1e-12) for a, s in 
                        zip(analysis_data['ann_energy'], analysis_data['snn_energy'])]
    ax4.plot(analysis_data['samples'], efficiency_ratios, 'b-', linewidth=2)
    # ax4.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
    ax4.set_xlabel('Sample Number')
    ax4.set_ylabel('Energy Efficiency (ANN/SNN)')
    ax4.set_title(f'SNN Efficiency Advantage (Avg: {np.mean(efficiency_ratios):.0f}×)')
    ax4.grid(True, alpha=0.3)

    xmin, xmax = analysis_data['samples'][0], analysis_data['samples'][-1]

    for ax in (ax1, ax2, ax3, ax4):
        ax.set_title(ax.get_title(), pad=8)
        ax.xaxis.labelpad = 6
        ax.yaxis.labelpad = 6
        ax.set_xlim(xmin, xmax+0.5)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=8, integer=True, prune='both'))
        ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))

    plt.tight_layout() 
    
    return fig

if __name__ == "__main__":
    fig_anim, anim, data = create_visualization()
    
    display(HTML(anim.to_jshtml()))
    
    fig_static = create_static_analysis(data)
    if fig_static:
        plt.show()

## Part 4: SNNs Are A Promising Alternative To ANNs

The visualization above shows that **SNNs can use >100× less energy** than dense ANNs, as estimated by the notebook's `EnergyMeter`, while maintaining competitive accuracy.

Beyond shrinking the carbon footprint of AI systems, this efficiency is especially critical for deploying AI on battery-powered devices and always-on sensors, and for real-time, low-latency processing in robotics and autonomous systems.

### The Biological Inspiration

The dramatic improvement in energy efficiency is achieved through **sparse activity, event-driven processing, and temporal dynamics**. 

| **Aspect** | **Traditional ANNs** | **Brain-Inspired SNNs** | **Impact** |
|------------|---------------------|------------------------|------------|
| **Information Encoding** | Continuous values (0.0-1.0) | Binary spikes over time | 10× less data movement |
| **Computation Model** | Synchronous matrix operations | Asynchronous event processing | 100× fewer operations |
| **Activity Pattern** | ~50% of neurons active on every input | ~5% of neurons fire, only when needed | 10× less power draw |
| **Hardware Utilization** | Constant high power draw | Power scales with activity | Days vs. hours of battery |
| **Biological Fidelity** | Mathematical abstraction | Mimics actual neurons | Future-proof design |
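
To make the "Computation Model" row more concrete, here is a minimal sketch that counts operations for a single 784→256 layer under both models. It is not the notebook's `SpikingNeuralNet`; the 5% spike rate, the 10 timesteps, and every variable name below are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

n_in, n_out = 784, 256   # layer shape borrowed from the notebook's first layer
timesteps = 10           # assumed SNN simulation window
spike_rate = 0.05        # assumed fraction of inputs spiking per timestep

weights = rng.standard_normal((n_in, n_out))

# Dense ANN layer: every input contributes one multiply-accumulate per output.
dense_macs = n_in * n_out

# Event-driven layer: only inputs that spiked add their weight row to the
# membrane potentials, so each spike costs n_out additions and no multiplies.
spikes = rng.random((timesteps, n_in)) < spike_rate
potentials = np.zeros(n_out)
event_adds = 0
for t in range(timesteps):
    active = np.flatnonzero(spikes[t])          # indices of inputs that fired
    potentials += weights[active].sum(axis=0)   # accumulate only those rows
    event_adds += active.size * n_out

print(f"dense MACs:        {dense_macs}")
print(f"event-driven adds: {event_adds}")
```

In this sketch the event-driven count is approximately `timesteps × spike_rate × n_in × n_out`, so the saving in raw operations depends heavily on the spike rate and the length of the time window. An addition is also typically cheaper in hardware than a multiply-accumulate, which an operation count alone does not capture.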

<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 25px; border-radius: 15px; margin: 25px 0; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
    <h3 style="color: white; margin: 0; font-size: 24px;">The Biological Blueprint</h3>
    <p style="color: white; margin: 15px 0; font-size: 16px; line-height: 1.6;">
        Biological neurons fire only when necessary, which amounts to 1-5% of the time. This sparse activity pattern underlies the brain's extraordinary efficiency. By mimicking this principle, Spiking Neural Networks (SNNs) achieve significant energy savings while maintaining computational accuracy.
    </p>
</div>
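
The "fire only when necessary" behavior can be sketched with a leaky integrate-and-fire (LIF) neuron, a common simplified spiking model. This is an illustrative stand-in, not necessarily the neuron model inside the notebook's `SpikingNeuralNet`; `simulate_lif`, `beta`, and `threshold` are hypothetical names and values.

```python
import numpy as np

def simulate_lif(input_current, beta=0.9, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron and return its 0/1 spike train."""
    v = 0.0  # membrane potential
    spikes = np.zeros(len(input_current), dtype=int)
    for t, current in enumerate(input_current):
        v = beta * v + current   # leak a fraction of the potential, then integrate input
        if v >= threshold:
            spikes[t] = 1        # emit a spike
            v = 0.0              # hard reset after firing
    return spikes

# Constant drives of increasing strength (arbitrary units, illustrative only)
for drive in (0.05, 0.2, 0.5):
    rate = simulate_lif(np.full(300, drive)).mean()
    print(f"drive={drive:.2f} -> firing rate {rate:.1%}")
```

With the weakest drive the potential settles below the threshold and the neuron stays silent. Stronger drives make it fire more often, so activity (and, on event-driven hardware, the work done) tracks how much there is to signal.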

### Seeing Sparsity in Action

The following code block shows how sparse, event-driven computation achieves these efficiency gains:

In [None]:
def create_sparse_activity_visualization():
    """Compare dense ANN and sparse SNN computation patterns using EnergyMeter estimates."""

    # Run actual inference to get metered energy estimates
    print("Computing energy estimates...")
    
    # Prepare test data: one batch of 32 random-noise "images"
    test_batch = torch.randn(32, 1, 28, 28).to(device)
    
    # Initialize networks
    ann_model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    ).to(device)
    
    snn_model = SpikingNeuralNet(
        input_size=784, 
        hidden_size=256, 
        output_size=10, 
        timesteps=10
    ).to(device)
    
    # Measure energy consumption
    ann_meter = EnergyMeter("ANN", device_type='gpu')
    snn_meter = EnergyMeter("SNN", device_type='neuromorphic')
    
    with torch.no_grad():

        # ANN forward pass
        ann_output = ann_model(test_batch)
        ann_meter.add_operations(
            macs=32 * (784*256 + 256*128 + 128*10),  # Batch size * layer operations
            memory_bytes=4 * 32 * (784 + 256 + 128 + 10) * 2
        )
        _, ann_energy_j = ann_meter.compute_energy(0.001)
        
        # SNN forward pass
        snn_output = snn_model(test_batch, meter=snn_meter)
        _, snn_energy_j = snn_meter.compute_energy(0.001)
    
    # Calculate metrics
    ann_energy_mj = ann_energy_j * 1000
    snn_energy_mj = snn_energy_j * 1000
    efficiency_gain = ann_energy_mj / max(snn_energy_mj, 1e-9) # Avoid divide by zero
    snn_sparsity = 100 - (snn_model.spike_rates.get('overall', 0.05) * 100)
    
    # Set up figure and grid layout
    fig = plt.figure(figsize=(14, 14))
    fig.patch.set_facecolor('#f8f9fa')

    gs = gridspec.GridSpec(
        3, 2, 
        figure=fig,
        height_ratios=[1.5, 1.5, 1.0],
        width_ratios=[1, 1],
        hspace=0.35, 
        wspace=0.25,
        left=0.06, right=0.94, top=0.90, bottom=0.05
    )
    
    # Main title
    fig.suptitle(
        f"Sparse Activity Is Key To {efficiency_gain:.0f}× Better Energy Efficiency",
        fontsize=22, fontweight='bold', y=0.98
    )
    
    # Subtitle for context
    fig.text(0.5, 0.94, 
             'Comparison of traditional AI (ANNs) vs. brain-inspired AI (SNNs)',
             ha='center', fontsize=14, style='italic', color='#555')
    
    # ANN Activity Heatmap
    ax_ann_activity = fig.add_subplot(gs[0, 0])
    
    # Generate realistic ANN activity
    np.random.seed(42)
    neurons_per_layer = 30
    time_steps = 20
    
    # ANN: Dense activity (~50% active after ReLU)
    ann_activity = np.random.beta(2, 2, (neurons_per_layer, time_steps))
    ann_activity = (ann_activity > 0.5).astype(float) * np.random.uniform(0.3, 1.0, (neurons_per_layer, time_steps))
    
    im_ann = ax_ann_activity.imshow(
        ann_activity, 
        aspect='auto', 
        cmap='hot',
        interpolation='nearest',
        vmin=0, vmax=1
    )
    
    ax_ann_activity.set_title('Traditional ANN: Dense Activity', 
                              fontsize=14, fontweight='bold', color=COLORS['ann'], pad=10)
    ax_ann_activity.set_xlabel('Time Steps', fontsize=11)
    ax_ann_activity.set_ylabel('Neurons', fontsize=11)
    
    # Add activity percentage annotation
    ann_active_pct = (ann_activity > 0.01).mean() * 100
    ax_ann_activity.text(0.98, 0.98, f'{ann_active_pct:.0f}% Active', 
                         transform=ax_ann_activity.transAxes,
                         ha='right', va='top',
                         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                         fontsize=12, fontweight='bold', color=COLORS['ann'])
    
    # Colorbar for ANN
    cbar_ann = plt.colorbar(im_ann, ax=ax_ann_activity, fraction=0.046, pad=0.04)
    cbar_ann.set_label('Activation Level', fontsize=10)
    
    # SNN Spike Raster
    ax_snn_activity = fig.add_subplot(gs[0, 1])
    
    # SNN: Sparse spikes (~5% active)
    snn_spikes = np.random.random((neurons_per_layer, time_steps)) < 0.05
    
    # Create visually distinct spikes
    spike_display = np.zeros((neurons_per_layer, time_steps, 3))
    spike_positions = np.where(snn_spikes)
    
    for i, j in zip(spike_positions[0], spike_positions[1]):
        spike_display[i, j] = [1, 0.8, 0] # Yellow spikes
    
    ax_snn_activity.imshow(spike_display, aspect='auto', interpolation='nearest')
    
    # Add spike markers for clarity
    for i, j in zip(spike_positions[0], spike_positions[1]):
        ax_snn_activity.plot(j, i, 'o', color='yellow', markersize=4, 
                            markeredgecolor='orange', markeredgewidth=0.5)
    
    ax_snn_activity.set_title('Brain-Inspired SNN: Sparse Spikes', 
                              fontsize=14, fontweight='bold', color=COLORS['snn'], pad=10)
    ax_snn_activity.set_xlabel('Time Steps', fontsize=11)
    ax_snn_activity.set_ylabel('Neurons', fontsize=11)
    ax_snn_activity.set_facecolor('#1a1a1a')
    
    # Add sparsity annotation
    snn_active_pct = snn_spikes.mean() * 100
    ax_snn_activity.text(0.98, 0.98, f'{snn_active_pct:.1f}% Active', 
                         transform=ax_snn_activity.transAxes,
                         ha='right', va='top',
                         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                         fontsize=12, fontweight='bold', color=COLORS['snn'])
    
    # Energy comparison
    ax_energy = fig.add_subplot(gs[1, 0])
    
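    # Create bar chart. Energy is metered for one 32-sample batch; the per-second
    # and per-hour bars scale that by 1,000 batches per second.
    categories = ['Per Batch\n(32 samples)', 'Per\nSecond', 'Per\nHour']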
    ann_values = [ann_energy_mj, ann_energy_mj*1000, ann_energy_mj*3600000]
    snn_values = [snn_energy_mj, snn_energy_mj*1000, snn_energy_mj*3600000]
    
    x = np.arange(len(categories))
    width = 0.35
    
    bars_ann = ax_energy.bar(x - width/2, ann_values, width, 
                             label='ANN', color=COLORS['ann'], 
                             alpha=0.8, edgecolor='black', linewidth=1.5)
    bars_snn = ax_energy.bar(x + width/2, snn_values, width,
                             label='SNN', color=COLORS['snn'],
                             alpha=0.8, edgecolor='black', linewidth=1.5)
    
    # Add labels
    for bars in [bars_ann, bars_snn]:
        for bar in bars:
            height = bar.get_height()
            # Bar heights are in mJ; pick a readable unit for each label
            if height > 1000000:
                label = f'{height/1000000:.1f}kJ'
            elif height > 1000:
                label = f'{height/1000:.1f}J'
            else:
                label = f'{height:.2f}mJ'
            
            ax_energy.text(bar.get_x() + bar.get_width()/2., height,
                          label, ha='center', va='bottom',
                          fontsize=9, fontweight='bold')
    
    ax_energy.set_yscale('log')
    ax_energy.set_ylabel('Energy Consumption', fontsize=11)
    ax_energy.set_title('Energy Usage Comparison', fontsize=14, fontweight='bold', pad=10)
    ax_energy.set_xticks(x)
    ax_energy.set_xticklabels(categories)
    ax_energy.legend(loc='upper left', fontsize=11)
    ax_energy.grid(True, alpha=0.3, axis='y', which='both')
    
    # Activity Distribution
    ax_distribution = fig.add_subplot(gs[1, 1])
    
    # Calculate per-neuron activity stats
    ann_per_neuron = (ann_activity > 0.01).mean(axis=1) * 100
    snn_per_neuron = snn_spikes.mean(axis=1) * 100
    
    # Create overlapping histograms
    bins = np.linspace(0, max(ann_per_neuron.max(), 20), 15)
    
    ax_distribution.hist(ann_per_neuron, bins=bins, alpha=0.6, 
                        color=COLORS['ann'], edgecolor='black',
                        label=f'ANN (μ={ann_per_neuron.mean():.0f}%)', 
                        linewidth=1.5)
    ax_distribution.hist(snn_per_neuron, bins=bins, alpha=0.6,
                        color=COLORS['snn'], edgecolor='black', 
                        label=f'SNN (μ={snn_per_neuron.mean():.1f}%)',
                        linewidth=1.5)
    
    ax_distribution.set_xlabel('Activity Rate (%)', fontsize=11)
    ax_distribution.set_ylabel('Number of Neurons', fontsize=11)
    ax_distribution.set_title('Activity Distribution', fontsize=14, fontweight='bold', pad=10)
    ax_distribution.legend(loc='upper right', fontsize=10)
    ax_distribution.grid(True, alpha=0.3, axis='y')
    
    # Mark the nominal activity rates: ~50% for the ANN (red), ~5% for the SNN (green)
    ax_distribution.axvline(x=50, color='red', linestyle='--', alpha=0.3, linewidth=2)
    ax_distribution.axvline(x=5, color='green', linestyle='--', alpha=0.3, linewidth=2)
    
    # Scaling projection
    ax_scaling = fig.add_subplot(gs[2, :])
    
    model_sizes = np.logspace(0, 4, 50) # 1 to 10,000 relative size
    ann_energy_scaling = model_sizes ** 1.3 # Assumed exponent, for illustration
    snn_energy_scaling = model_sizes ** 0.8 # Assumed better scaling with sparsity
    
    ax_scaling.loglog(model_sizes, ann_energy_scaling, '-',
                     color=COLORS['ann'], linewidth=3, label='ANN: O(n^1.3)')
    ax_scaling.loglog(model_sizes, snn_energy_scaling, '-',
                     color=COLORS['snn'], linewidth=3, label='SNN: O(n^0.8)')
    
    # Fill the gap to highlight the advantage
    ax_scaling.fill_between(model_sizes, ann_energy_scaling, snn_energy_scaling,
                           where=(ann_energy_scaling > snn_energy_scaling),
                           alpha=0.2, color='green', label='Energy Saved')
    
    # Denote current and future model sizes
    current_size = 1000
    future_size = 10000
    
    ax_scaling.scatter([current_size], [current_size**1.3], s=100, color=COLORS['ann'],
                      marker='o', zorder=5, edgecolor='black', linewidth=2)
    ax_scaling.scatter([current_size], [current_size**0.8], s=100, color=COLORS['snn'],
                      marker='o', zorder=5, edgecolor='black', linewidth=2)
    
    # Add annotations for gap at scale
    gap_current = current_size**1.3 / current_size**0.8
    gap_future = future_size**1.3 / future_size**0.8
    
    ax_scaling.annotate(f'{gap_current:.0f}× gap', 
                       xy=(current_size, current_size**1.05),
                       xytext=(current_size*2, current_size**1.1),
                       arrowprops=dict(arrowstyle='->', color='black', lw=1.5),
                       fontsize=11, fontweight='bold',
                       bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7))
    
    ax_scaling.set_xlabel('Model Size (relative)', fontsize=11)
    ax_scaling.set_ylabel('Energy Consumption (relative)', fontsize=11)
    ax_scaling.set_title('Energy Scaling Gap Widens with Model Size', 
                        fontsize=14, fontweight='bold', pad=10)
    ax_scaling.legend(loc='upper left', fontsize=10)
    ax_scaling.grid(True, alpha=0.3, which='both')
    
    plt.tight_layout()
    
    # Return fig and metrics for subsequent use
    metrics_dict = {
        'ann_energy_mj': ann_energy_mj,
        'snn_energy_mj': snn_energy_mj,
        'efficiency_gain': efficiency_gain,
        'ann_activity': ann_active_pct,
        'snn_activity': snn_active_pct,
        'snn_sparsity': snn_sparsity
    }
    
    return fig, metrics_dict

fig_comprehensive, sparsity_metrics = create_sparse_activity_visualization()
plt.show()

# print("\n📊 Sparsity Analysis Complete:")
# print(f"   • ANN Activity: {sparsity_metrics['ann_activity']:.1f}%")
# print(f"   • SNN Activity: {sparsity_metrics['snn_activity']:.1f}%") 
# print(f"   • Efficiency Gain: {sparsity_metrics['efficiency_gain']:.0f}×")
# print(f"   • Energy Saved: {(1 - sparsity_metrics['snn_energy_mj']/sparsity_metrics['ann_energy_mj'])*100:.1f}%")

## A Paradigm Shift in AI Computing

We've demonstrated that brain-inspired computing (SNNs) achieves accuracy comparable to traditional AI (ANNs) while reducing energy consumption by TBD-TBD×. This improvement makes SNNs one promising answer to the energy crisis facing AI development.
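
When reporting a headline efficiency figure, it helps to be explicit about how it is aggregated. `create_static_analysis` computes both a ratio of total energies and a mean of per-sample ratios, and the two can differ noticeably. The sketch below shows the difference; `summarize_efficiency` is a hypothetical helper and the energy values are made up for illustration, not results from this notebook.

```python
import numpy as np

def summarize_efficiency(ann_energy_mj, snn_energy_mj, eps=1e-12):
    """Summarize ANN/SNN energy ratios from per-sample energies (in mJ)."""
    ann = np.asarray(ann_energy_mj, dtype=float)
    snn = np.asarray(snn_energy_mj, dtype=float)
    per_sample = ann / np.maximum(snn, eps)   # guard against division by zero
    return {
        "ratio_of_totals": ann.sum() / max(snn.sum(), eps),
        "mean_of_ratios": per_sample.mean(),
        "min_ratio": per_sample.min(),
        "max_ratio": per_sample.max(),
    }

# Made-up per-sample energies in mJ (illustration only)
ann = [2.0, 2.0, 2.0, 2.0]
snn = [0.02, 0.04, 0.1, 0.2]

summary = summarize_efficiency(ann, snn)
for name, value in summary.items():
    print(f"{name}: {value:.1f}x")
```

The mean of per-sample ratios is pulled upward by samples where the SNN barely spikes, whereas the ratio of totals reflects the energy actually spent over the whole run. Quoting a range (minimum to maximum) alongside either figure avoids overstating the result.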

### The Implications Are Exciting

1. Today, SNNs can be deployed on existing hardware with **10-100× efficiency gains**
2. Tomorrow, neuromorphic chips may achieve >1000× improvements
3. SNNs make AGI-scale models **more environmentally and economically viable**
4. This efficiency enables sophisticated AI on edge devices and in developing regions

<div style="background: linear-gradient(to right, #667eea, #764ba2); padding: 2px; border-radius: 10px; margin: 20px 0;">
    <div style="background: white; border-radius: 8px; padding: 20px;">
        <h4 style="color: #667eea; margin-top: 0;">Coming in Notebook 2: The Biological Blueprint</h4>
        <p style="color: #4a5568; line-height: 1.8;">
            We'll analyze real neural recordings (from my Ph.D. research in Donhee Ham's group at Harvard) to understand how biological neurons achieve such remarkable efficiency.
        </p>
        <p style="color: #667eea; font-weight: bold; margin-top: 15px;">
            Preview: TBD
        </p>
    </div>
</div>

**Brain-Inspired AI Series**  
**[Part 1: The Energy Crisis ✓]** | [Part 2: Biological Blueprint →](link) | [Part 3: Mathematics →](link) | [Part 4: Implementation →](link) | [Part 5: Deployment →](link)