# Backpropagation Visualization Mastery: Interactive Deep Learning Analysis

**PyTorch Mastery Hub - Advanced Visualization Module**

**Authors:** PyTorch Mastery Hub Development Team  
**Institution:** PyTorch Mastery Hub  
**Module:** 02_autograd_backpropagation  
**Date:** December 2024

## Overview

This notebook visualizes and analyzes backpropagation in deep neural networks. It builds interactive visualizations of computational graphs, gradient computation, and optimization dynamics, with the goal of building intuition for how neural networks train.

## Key Objectives
1. Build and visualize computational graphs interactively
2. Animate gradient flow through neural networks in real time
3. Explore 3D loss landscapes and optimization trajectories
4. Analyze activation patterns and weight distributions
5. Monitor training dynamics and gradient behavior
6. Create diagnostic tools for debugging neural networks

## 1. Setup and Environment Configuration

```python
# Essential imports for advanced visualization
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import FancyBboxPatch, Circle, Arrow
from matplotlib.collections import LineCollection
import seaborn as sns
import networkx as nx
from mpl_toolkits.mplot3d import Axes3D

import pandas as pd
from pathlib import Path
import json
import time
import warnings
from typing import Dict, List, Tuple, Optional, Any
from collections import defaultdict

warnings.filterwarnings('ignore')

# Import custom utilities
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..', '..'))

try:
    from src.utils.device_utils import get_device
    from src.visualization.training_viz import TrainingVisualizer
except ImportError:
    print("⚠️ Custom utilities not found, using fallback implementations")
    
    def get_device():
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Environment setup
device = get_device()
torch.manual_seed(42)
np.random.seed(42)

# Enhanced plotting configuration
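# 'seaborn-v0_8' exists in Matplotlib >= 3.6; older releases name the style 'seaborn'
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')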
sns.set_palette('husl')
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

# Create results directory
results_dir = Path('../results/notebooks/backprop_visualization')
results_dir.mkdir(parents=True, exist_ok=True)

print("🎨 PyTorch Mastery Hub - Backpropagation Visualization Mastery")
print("=" * 65)
print(f"📱 Device: {device}")
print(f"🎭 PyTorch version: {torch.__version__}")
print(f"📁 Results directory: {results_dir}")
print(f"✨ Ready to visualize the magic of backpropagation!\n")
```

## 2. Computational Graph Construction and Visualization

### 2.1 Interactive Computational Graph Builder
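
The builder in this section is populated by hand through `add_node` and `add_edge`. For comparison, autograd records its own graph during the forward pass, and a minimal sketch of walking it through `grad_fn.next_functions` might look like the following. The helper name `walk_autograd_graph` is hypothetical and not part of this notebook, and backward-node class names such as `SinBackward0` can differ between PyTorch versions.

```python
import torch


def walk_autograd_graph(root):
    """Return (parent, child) class-name pairs for each edge reachable from root.grad_fn.

    Hypothetical helper for illustration only.
    """
    edges = []
    seen = set()            # backward nodes already expanded
    stack = [root.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        for child, _ in fn.next_functions:
            if child is None:   # this input does not require grad
                continue
            edges.append((type(fn).__name__, type(child).__name__))
            stack.append(child)
    return edges


# Same computation as Example 1 in the cell that follows: c = sin(x*y + x)
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
c = torch.sin(x * y + x)

edges = walk_autograd_graph(c)
for parent, child in edges:
    print(f"{parent} -> {child}")
```

Leaf tensors that require gradients appear as `AccumulateGrad` nodes, which is where the backward pass deposits `.grad`. Edges here point from an operation back toward its inputs, the reverse of the forward-direction edges drawn by the visualizer.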

```python
class ComputationGraphVisualizer:
    """Advanced computational graph visualization with interactive features."""
    
    def __init__(self):
        self.nodes = {}
        self.edges = []
        self.node_counter = 0
        self.colors = {
            'input': '#FF6B6B',      # Red for inputs
            'operation': '#4ECDC4',   # Teal for operations  
            'parameter': '#45B7D1',   # Blue for parameters
            'output': '#96CEB4',      # Green for outputs
            'loss': '#FFEAA7',       # Yellow for loss
            'gradient': '#DDA0DD'     # Purple for gradients
        }
        self.analysis_results = {}
    
    def add_node(self, tensor, node_type='operation', label=None):
        """Add a node to the computational graph with metadata."""
        node_id = f"node_{self.node_counter}"
        self.node_counter += 1
        
        if label is None:
            if hasattr(tensor, 'grad_fn') and tensor.grad_fn is not None:
                label = tensor.grad_fn.__class__.__name__
            else:
                label = 'Tensor'
        
        self.nodes[node_id] = {
            'tensor': tensor,
            'type': node_type,
            'label': label,
            'shape': tuple(tensor.shape),
            'requires_grad': tensor.requires_grad,
            'grad_fn': tensor.grad_fn,
            'memory_usage': tensor.element_size() * tensor.nelement(),
            'dtype': str(tensor.dtype)
        }
        
        return node_id
    
    def add_edge(self, from_node, to_node, operation=None, gradient_flow=None):
        """Add an edge with optional gradient flow information."""
        self.edges.append({
            'from': from_node,
            'to': to_node,
            'operation': operation,
            'gradient_flow': gradient_flow
        })
    
    def visualize_graph(self, figsize=(16, 12), layout='hierarchical', show_gradients=True):
        """Create comprehensive computational graph visualization."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)
        
        # 1. Main graph visualization
        self._draw_main_graph(ax1, layout, show_gradients)
        
        # 2. Node statistics
        self._draw_node_statistics(ax2)
        
        # 3. Memory usage analysis
        self._draw_memory_analysis(ax3)
        
        # 4. Gradient flow analysis
        self._draw_gradient_flow_summary(ax4)
        
        plt.suptitle('Computational Graph Analysis Dashboard', fontsize=16, fontweight='bold')
        plt.tight_layout()
        
        # Save visualization
        plt.savefig(results_dir / 'computational_graph_analysis.png', dpi=300, bbox_inches='tight')
        
        return fig
    
    def _draw_main_graph(self, ax, layout, show_gradients):
        """Draw the main computational graph."""
        # Create NetworkX graph
        G = nx.DiGraph()
        
        # Add nodes with attributes
        for node_id, node_data in self.nodes.items():
            G.add_node(node_id, **node_data)
        
        # Add edges
        for edge in self.edges:
            G.add_edge(edge['from'], edge['to'], **edge)
        
        # Choose layout
        if layout == 'hierarchical':
            pos = self._hierarchical_layout(G)
        else:
            pos = nx.spring_layout(G, k=3, iterations=50)
        
        # Draw edges with gradient flow indicators
        for edge in G.edges(data=True):
            x1, y1 = pos[edge[0]]
            x2, y2 = pos[edge[1]]
            
            # Edge color based on gradient flow
            edge_color = 'red' if show_gradients and edge[2].get('gradient_flow') else 'gray'
            edge_width = 3 if edge[2].get('gradient_flow') else 1
            
            ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
                       arrowprops=dict(arrowstyle='->', lw=edge_width, 
                                     color=edge_color, alpha=0.7))
            
            # Add operation label
            if edge[2].get('operation'):
                mid_x, mid_y = (x1 + x2) / 2, (y1 + y2) / 2
                ax.text(mid_x, mid_y, edge[2]['operation'], 
                       bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8),
                       ha='center', va='center', fontsize=8)
        
        # Draw nodes
        for node_id, (x, y) in pos.items():
            node_data = self.nodes[node_id]
            
            # Node size based on tensor size
            tensor_size = np.prod(node_data['shape'])
            node_size = min(0.15, 0.05 + np.log10(max(1, tensor_size)) * 0.02)
            
            # Node color based on type
            color = self.colors.get(node_data['type'], '#CCCCCC')
            
            # Draw node
            circle = Circle((x, y), node_size, color=color, alpha=0.8, zorder=3)
            ax.add_patch(circle)
            
            # Node label
            ax.text(x, y, node_data['label'], ha='center', va='center', 
                   fontsize=9, fontweight='bold', zorder=4)
            
            # Shape and gradient info
            shape_str = str(node_data['shape'])
            grad_str = "∇" if node_data['requires_grad'] else ""
            ax.text(x, y-node_size-0.05, f"{shape_str}\n{grad_str}", 
                   ha='center', va='center', fontsize=7, alpha=0.8, zorder=4)
        
        # Create legend
        legend_elements = [
            plt.Circle((0, 0), 0.1, color=color, label=node_type.title())
            for node_type, color in self.colors.items()
        ]
        ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0, 1))
        
        ax.set_xlim(-1.3, 1.3)
        ax.set_ylim(-1.3, 1.3)
        ax.set_aspect('equal')
        ax.axis('off')
        ax.set_title('Computational Graph Structure', fontweight='bold')
    
    def _draw_node_statistics(self, ax):
        """Draw node type distribution and statistics."""
        node_types = [node['type'] for node in self.nodes.values()]
        type_counts = pd.Series(node_types).value_counts()
        
        # Pie chart of node types
        colors = [self.colors.get(node_type, '#CCCCCC') for node_type in type_counts.index]
        wedges, texts, autotexts = ax.pie(type_counts.values, labels=type_counts.index, 
                                         autopct='%1.1f%%', colors=colors, startangle=90)
        
        # Enhance text
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontweight('bold')
        
        ax.set_title('Node Type Distribution', fontweight='bold')
    
    def _draw_memory_analysis(self, ax):
        """Draw memory usage analysis."""
        node_names = []
        memory_usage = []
        bar_colors = []
        
        for node_data in self.nodes.values():
            node_names.append(node_data['label'][:8])
            memory_usage.append(node_data['memory_usage'] / 1024)  # Convert bytes to KB
            # Take the color from the node itself instead of rebuilding its id
            bar_colors.append(self.colors.get(node_data['type'], '#CCCCCC'))
        
        if memory_usage:
            bars = ax.bar(range(len(memory_usage)), memory_usage, 
                         color=bar_colors, alpha=0.7)
            
            ax.set_xlabel('Nodes')
            ax.set_ylabel('Memory Usage (KB)')
            ax.set_title('Memory Usage by Node', fontweight='bold')
            ax.set_xticks(range(len(node_names)))
            ax.set_xticklabels(node_names, rotation=45, ha='right')
            
            # Add value labels
            for bar, mem in zip(bars, memory_usage):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + max(memory_usage)*0.01,
                       f'{mem:.1f}', ha='center', va='bottom', fontsize=8)
        
        ax.grid(True, alpha=0.3)
    
    def _draw_gradient_flow_summary(self, ax):
        """Draw gradient flow summary."""
        grad_nodes = [node for node in self.nodes.values() if node['requires_grad']]
        
        if grad_nodes:
            # Count nodes by type that require gradients
            grad_types = [node['type'] for node in grad_nodes]
            grad_counts = pd.Series(grad_types).value_counts()
            
            ax.bar(grad_counts.index, grad_counts.values, 
                  color='lightcoral', alpha=0.7)
            ax.set_title('Gradient-Enabled Nodes', fontweight='bold')
            ax.set_ylabel('Count')
            ax.tick_params(axis='x', rotation=45)
        else:
            ax.text(0.5, 0.5, 'No gradient-enabled nodes', 
                   ha='center', va='center', transform=ax.transAxes, fontsize=12)
            ax.set_title('Gradient Flow Analysis', fontweight='bold')
    
    def _hierarchical_layout(self, G):
        """Create hierarchical layout for computational graph."""
        levels = {}
        for node in nx.topological_sort(G):
            if not list(G.predecessors(node)):
                levels[node] = 0
            else:
                levels[node] = max(levels[pred] for pred in G.predecessors(node)) + 1
        
        # Group nodes by level
        level_groups = defaultdict(list)
        for node, level in levels.items():
            level_groups[level].append(node)
        
        # Assign positions
        pos = {}
        max_level = max(level_groups.keys()) if level_groups else 0
        
        for level, nodes in level_groups.items():
            y = 1 - (level / max(max_level, 1)) * 2  # Top to bottom
            
            if len(nodes) == 1:
                pos[nodes[0]] = (0, y)
            else:
                x_positions = np.linspace(-0.8, 0.8, len(nodes))
                for i, node in enumerate(nodes):
                    pos[node] = (x_positions[i], y)
        
        return pos
    
    def save_analysis(self):
        """Save computational graph analysis to JSON."""
        analysis_data = {
            'graph_statistics': {
                'total_nodes': len(self.nodes),
                'total_edges': len(self.edges),
                'gradient_nodes': sum(1 for node in self.nodes.values() if node['requires_grad']),
                # Cast to a plain int: NumPy integers are not JSON serializable
                'total_parameters': int(sum(np.prod(node['shape']) for node in self.nodes.values() 
                                            if node['type'] == 'parameter')),
                'total_memory_kb': sum(node['memory_usage'] for node in self.nodes.values()) / 1024
            },
            'node_details': {
                node_id: {
                    'type': data['type'],
                    'label': data['label'],
                    'shape': data['shape'],
                    'requires_grad': data['requires_grad'],
                    'memory_kb': data['memory_usage'] / 1024
                }
                for node_id, data in self.nodes.items()
            }
        }
        
        with open(results_dir / 'computational_graph_analysis.json', 'w') as f:
            json.dump(analysis_data, f, indent=2)
        
        return analysis_data

def demonstrate_computational_graphs():
    """Demonstrate computational graph visualization with examples."""
    print("🌐 Computational Graph Visualization Examples")
    print("=" * 50)
    
    # Example 1: Simple arithmetic computation
    print("\n📊 Example 1: Simple Arithmetic Operations")
    
    visualizer1 = ComputationGraphVisualizer()
    
    x = torch.tensor(2.0, requires_grad=True)
    y = torch.tensor(3.0, requires_grad=True)
    
    # Add input nodes
    x_node = visualizer1.add_node(x, 'input', 'x')
    y_node = visualizer1.add_node(y, 'input', 'y')
    
    # Perform operations
    a = x * y
    a_node = visualizer1.add_node(a, 'operation', 'x*y')
    visualizer1.add_edge(x_node, a_node, 'mul')
    visualizer1.add_edge(y_node, a_node, 'mul')
    
    b = a + x
    b_node = visualizer1.add_node(b, 'operation', 'a+x')
    visualizer1.add_edge(a_node, b_node, 'add')
    visualizer1.add_edge(x_node, b_node, 'add')
    
    c = torch.sin(b)
    c_node = visualizer1.add_node(c, 'output', 'sin(b)')
    visualizer1.add_edge(b_node, c_node, 'sin')
    
    # Visualize
    fig1 = visualizer1.visualize_graph()
    plt.show()
    
    # Save analysis
    analysis1 = visualizer1.save_analysis()
    print(f"Graph Statistics: {analysis1['graph_statistics']}")
    
    # Example 2: Neural Network Layer
    print("\n🧠 Example 2: Neural Network Layer Computation")
    
    visualizer2 = ComputationGraphVisualizer()
    
    # Create layer components
    batch_size, input_size, output_size = 3, 4, 2
    x_nn = torch.randn(batch_size, input_size, requires_grad=True)
    W = torch.randn(input_size, output_size, requires_grad=True)
    b = torch.randn(output_size, requires_grad=True)
    
    # Add nodes
    x_node = visualizer2.add_node(x_nn, 'input', 'Input\nX')
    w_node = visualizer2.add_node(W, 'parameter', 'Weights\nW')
    b_node = visualizer2.add_node(b, 'parameter', 'Bias\nb')
    
    # Forward pass
    matmul_result = torch.mm(x_nn, W)
    matmul_node = visualizer2.add_node(matmul_result, 'operation', 'MatMul')
    visualizer2.add_edge(x_node, matmul_node, 'input')
    visualizer2.add_edge(w_node, matmul_node, 'weight')
    
    output = matmul_result + b
    output_node = visualizer2.add_node(output, 'output', 'Linear\nOutput')
    visualizer2.add_edge(matmul_node, output_node, 'add')
    visualizer2.add_edge(b_node, output_node, 'bias')
    
    # Apply activation
    activated = torch.relu(output)
    activation_node = visualizer2.add_node(activated, 'operation', 'ReLU')
    visualizer2.add_edge(output_node, activation_node, 'activation')
    
    # Visualize
    fig2 = visualizer2.visualize_graph()
    plt.show()
    
    # Save analysis
    analysis2 = visualizer2.save_analysis()
    print(f"Neural Network Graph Statistics: {analysis2['graph_statistics']}")
    
    return [visualizer1, visualizer2], [analysis1, analysis2]

# Run computational graph demonstrations
graph_visualizers, graph_analyses = demonstrate_computational_graphs()

print(f"\n💾 Computational graph analyses saved to {results_dir}")
print("\n📊 Key Insights:")
print("• Node colors encode node type (red = input); ∇ marks gradient-enabled tensors")
print("• Node size grows with the log of the tensor's element count")
print("• Edges are drawn red only when gradient_flow is set on them, gray otherwise")
print("• Per-node memory usage highlights the largest tensors in the graph")
```
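
The graph in Example 1 can be checked by hand. With `a = x*y`, `b = a + x`, and `c = sin(b)`, the chain rule gives `dc/dx = cos(b) * (y + 1)` and `dc/dy = cos(b) * x`, because `x` reaches `c` along two paths (through `a` and directly through `b`). The sketch below compares these closed-form values against autograd; the variable names `expected_dx` and `expected_dy` are introduced here for illustration.

```python
import math

import torch

# Rebuild Example 1: c = sin(x*y + x) at x = 2, y = 3
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
c = torch.sin(x * y + x)
c.backward()

# Closed-form gradients from the chain rule
b_val = 2.0 * 3.0 + 2.0                        # b = x*y + x = 8
expected_dx = math.cos(b_val) * (3.0 + 1.0)    # cos(b) * (y + 1)
expected_dy = math.cos(b_val) * 2.0            # cos(b) * x

print(f"dc/dx autograd={x.grad.item():.6f} chain rule={expected_dx:.6f}")
print(f"dc/dy autograd={y.grad.item():.6f} chain rule={expected_dy:.6f}")
```

The `(y + 1)` factor is the sum of the contributions from both edges leaving `x` in the graph, which is what gradient accumulation at a shared node amounts to.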

## 3. Gradient Flow Animation and Analysis

### 3.1 Real-Time Gradient Flow Visualizer
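
The visualizer in this section collects its data through module hooks. As a minimal sketch of that mechanism, assuming a small `nn.Sequential` model chosen only for illustration, the example below registers a forward hook and a full backward hook on each leaf module and records the order in which they fire. The names `forward_order`, `backward_order`, `make_forward_hook`, and `make_backward_hook` are hypothetical.

```python
import torch
import torch.nn as nn

# Toy model for illustration; layer sizes are arbitrary
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

forward_order = []    # module names in the order forward hooks fire
backward_order = []   # (module name, grad_output norm) in backward firing order


def make_forward_hook(name):
    def hook(module, inputs, output):
        forward_order.append(name)
    return hook


def make_backward_hook(name):
    def hook(module, grad_input, grad_output):
        backward_order.append((name, float(grad_output[0].norm())))
    return hook


handles = []
for name, module in model.named_modules():
    if len(list(module.children())) == 0:   # leaf modules only
        handles.append(module.register_forward_hook(make_forward_hook(name)))
        handles.append(module.register_full_backward_hook(make_backward_hook(name)))

loss = model(torch.randn(3, 4)).pow(2).mean()
loss.backward()

# Remove hooks so they do not keep firing on later passes
for handle in handles:
    handle.remove()

print("forward :", forward_order)
print("backward:", [name for name, _ in backward_order])
```

Forward hooks fire from input to output, while backward hooks fire from the output layer back toward the input, so data recorded in a backward hook is already in backpropagation order. Removing the handles afterward avoids accumulating duplicate records if the model is reused.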

```python
class GradientFlowAnimator:
    """Advanced gradient flow visualization with real-time monitoring."""
    
    def __init__(self, model, input_data, target, loss_fn=None):
        self.model = model
        self.input_data = input_data
        self.target = target
        self.loss_fn = loss_fn or nn.MSELoss()
        
        # Data collection
        self.layer_activations = []
        self.layer_gradients = []
        self.parameter_gradients = []
        self.gradient_norms = []
        self.activation_stats = {}
        
        # Hook management
        self.hooks = []
        self.analysis_results = {}
        
        self._register_hooks()
    
    def _register_hooks(self):
        """Register comprehensive hooks for monitoring."""
        
        def forward_hook(name):
            def hook(module, input, output):
                if isinstance(output, torch.Tensor):
                    self.layer_activations.append({
                        'name': name,
                        'output': output.detach().clone(),
                        'input': input[0].detach().clone() if input and isinstance(input[0], torch.Tensor) else None,
                        'mean': float(output.mean()),
                        'std': float(output.std()),
                        'min': float(output.min()),
                        'max': float(output.max()),
                        'sparsity': float((output == 0).float().mean())
                    })
            return hook
        
        def backward_hook(name):
            def hook(module, grad_input, grad_output):
                if grad_output and grad_output[0] is not None:
                    grad_tensor = grad_output[0].detach().clone()
                    self.layer_gradients.append({
                        'name': name,
                        'grad_output': grad_tensor,
                        'grad_norm': float(grad_tensor.norm()),
                        'grad_mean': float(grad_tensor.mean()),
                        'grad_std': float(grad_tensor.std())
                    })
            return hook
        
        # Register hooks for all modules
        for name, module in self.model.named_modules():
            if len(list(module.children())) == 0:  # Leaf modules only
                self.hooks.append(module.register_forward_hook(forward_hook(name)))
                # register_backward_hook is deprecated; use the full backward hook instead
                self.hooks.append(module.register_full_backward_hook(backward_hook(name)))
    
    def capture_gradient_flow(self):
        """Perform forward and backward pass with comprehensive monitoring."""
        # Clear previous data
        self.layer_activations.clear()
        self.layer_gradients.clear()
        self.parameter_gradients.clear()
        
        # Forward pass
        self.model.zero_grad()
        self.model.train()
        
        output = self.model(self.input_data)
        loss = self.loss_fn(output, self.target)
        
        # Backward pass
        loss.backward()
        
        # Collect parameter gradients
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                self.parameter_gradients.append({
                    'name': name,
                    'grad': param.grad.detach().clone(),
                    'param': param.detach().clone(),
                    'grad_norm': float(param.grad.norm()),
                    'param_norm': float(param.norm()),
                    'grad_to_param_ratio': float(param.grad.norm() / (param.norm() + 1e-8))
                })
        
        return float(loss.item()), output.detach().clone()
    
    def create_comprehensive_analysis(self, figsize=(20, 16)):
        """Create comprehensive gradient flow analysis dashboard."""
        # Capture data
        loss_value, model_output = self.capture_gradient_flow()
        
        # Create dashboard
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(4, 4, hspace=0.3, wspace=0.3)
        
        # 1. Network Architecture Overview
        ax1 = fig.add_subplot(gs[0, :2])
        self._plot_network_architecture(ax1)
        
        # 2. Real-time Loss and Output
        ax2 = fig.add_subplot(gs[0, 2:])
        self._plot_model_output_analysis(ax2, model_output, loss_value)
        
        # 3. Activation Flow Analysis
        ax3 = fig.add_subplot(gs[1, :2])
        self._plot_activation_analysis(ax3)
        
        # 4. Gradient Magnitude Distribution
        ax4 = fig.add_subplot(gs[1, 2:])
        self._plot_gradient_magnitude_analysis(ax4)
        
        # 5. Parameter Gradient Analysis
        ax5 = fig.add_subplot(gs[2, :2])
        self._plot_parameter_gradient_analysis(ax5)
        
        # 6. Gradient Flow Direction
        ax6 = fig.add_subplot(gs[2, 2:])
        self._plot_gradient_flow_direction(ax6)
        
        # 7. Activation Statistics Heatmap
        ax7 = fig.add_subplot(gs[3, :2])
        self._plot_activation_statistics_heatmap(ax7)
        
        # 8. Gradient Health Diagnostics
        ax8 = fig.add_subplot(gs[3, 2:])
        self._plot_gradient_health_diagnostics(ax8)
        
        plt.suptitle(f'Gradient Flow Analysis Dashboard (Loss: {loss_value:.6f})', 
                    fontsize=18, fontweight='bold', y=0.98)
        
        # Save comprehensive analysis
        plt.savefig(results_dir / 'gradient_flow_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        return fig, self._compile_analysis_results(loss_value)
    
    def _plot_network_architecture(self, ax):
        """Plot enhanced network architecture diagram."""
        layer_info = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                layer_info.append(f"Linear\n{module.in_features}→{module.out_features}")
            elif isinstance(module, (nn.ReLU, nn.Sigmoid, nn.Tanh, nn.LeakyReLU)):
                layer_info.append(module.__class__.__name__)
            elif isinstance(module, nn.Dropout):
                layer_info.append(f"Dropout\n(p={module.p})")
        
        if layer_info:
            x_positions = np.linspace(0.1, 0.9, len(layer_info))
            
            for i, (x_pos, layer_name) in enumerate(zip(x_positions, layer_info)):
                # Draw layer representation
                height = 0.6
                rect = FancyBboxPatch(
                    (x_pos - 0.04, 0.2), 0.08, height,
                    boxstyle="round,pad=0.01",
                    facecolor='lightblue' if 'Linear' in layer_name else 'lightgreen',
                    edgecolor='navy', alpha=0.7, linewidth=2
                )
                ax.add_patch(rect)
                
                # Add layer label
                ax.text(x_pos, 0.1, layer_name, ha='center', va='center', 
                       fontsize=9, fontweight='bold')
                
                # Draw connections
                if i < len(layer_info) - 1:
                    ax.arrow(x_pos + 0.04, 0.5, x_positions[i+1] - x_pos - 0.08, 0,
                            head_width=0.03, head_length=0.02, fc='darkblue', ec='darkblue')
        
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_title('Neural Network Architecture', fontweight='bold', fontsize=14)
        ax.axis('off')
    
    def _plot_model_output_analysis(self, ax, model_output, loss_value):
        """Plot model output and loss analysis."""
        if model_output.dim() == 2:
            # For batch outputs, show distribution
            output_flat = model_output.flatten().cpu().numpy()
            target_flat = self.target.flatten().cpu().numpy()
            
            # Create scatter plot of predictions vs targets
            ax.scatter(target_flat, output_flat, alpha=0.6, s=50)
            
            # Perfect prediction line
            min_val = min(target_flat.min(), output_flat.min())
            max_val = max(target_flat.max(), output_flat.max())
            ax.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, linewidth=2)
            
            ax.set_xlabel('Target Values')
            ax.set_ylabel('Predicted Values')
            ax.set_title(f'Predictions vs Targets\nLoss: {loss_value:.6f}', fontweight='bold')
            ax.grid(True, alpha=0.3)
            
            # Pearson correlation between targets and predictions (this is not R²)
            correlation = np.corrcoef(target_flat, output_flat)[0, 1]
            ax.text(0.05, 0.95, f'Correlation: {correlation:.3f}', 
                   transform=ax.transAxes, fontsize=10, 
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        else:
            ax.text(0.5, 0.5, f'Loss: {loss_value:.6f}\nOutput shape: {model_output.shape}', 
                   ha='center', va='center', transform=ax.transAxes, fontsize=12)
            ax.set_title('Model Output Summary', fontweight='bold')
    
    def _plot_activation_analysis(self, ax):
        """Plot activation flow and statistics."""
        if not self.layer_activations:
            ax.text(0.5, 0.5, 'No activation data captured', ha='center', va='center')
            ax.set_title('Activation Analysis')
            return
        
        layer_names = [act['name'] if act['name'] else f"Layer_{i}" for i, act in enumerate(self.layer_activations)]
        activation_means = [act['mean'] for act in self.layer_activations]
        activation_stds = [act['std'] for act in self.layer_activations]
        sparsity = [act['sparsity'] for act in self.layer_activations]
        
        x = np.arange(len(layer_names))
        width = 0.25
        
        # Multiple bar plot
        bars1 = ax.bar(x - width, activation_means, width, label='Mean', alpha=0.8, color='skyblue')
        bars2 = ax.bar(x, activation_stds, width, label='Std Dev', alpha=0.8, color='lightcoral')
        bars3 = ax.bar(x + width, sparsity, width, label='Sparsity', alpha=0.8, color='lightgreen')
        
        ax.set_xlabel('Layers')
        ax.set_ylabel('Activation Statistics')
        ax.set_title('Layer-wise Activation Analysis', fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([name[:8] + '...' if len(name) > 8 else name for name in layer_names], 
                          rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # Add value labels
        for bars in [bars1, bars2, bars3]:
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                       f'{height:.3f}', ha='center', va='bottom', fontsize=8)
    
    def _plot_gradient_magnitude_analysis(self, ax):
        """Plot gradient magnitude analysis."""
        if not self.parameter_gradients:
            ax.text(0.5, 0.5, 'No gradient data available', ha='center', va='center')
            ax.set_title('Gradient Magnitude Analysis')
            return
        
        param_names = [pg['name'].split('.')[-1][:10] for pg in self.parameter_gradients]
        grad_norms = [pg['grad_norm'] for pg in self.parameter_gradients]
        param_norms = [pg['param_norm'] for pg in self.parameter_gradients]
        
        x = np.arange(len(param_names))
        
        # Dual y-axis plot
        ax2 = ax.twinx()
        
        bars1 = ax.bar(x - 0.2, grad_norms, 0.4, label='Gradient Norm', 
                      alpha=0.8, color='red')
        bars2 = ax2.bar(x + 0.2, param_norms, 0.4, label='Parameter Norm', 
                       alpha=0.8, color='blue')
        
        ax.set_xlabel('Parameters')
        ax.set_ylabel('Gradient Norm', color='red')
        ax2.set_ylabel('Parameter Norm', color='blue')
        ax.set_title('Gradient vs Parameter Magnitudes', fontweight='bold')
        
        ax.set_xticks(x)
        ax.set_xticklabels(param_names, rotation=45, ha='right')
        ax.set_yscale('log')
        ax2.set_yscale('log')
        
        # Combined legend
        lines1, labels1 = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
        
        ax.grid(True, alpha=0.3)
    
    def _plot_parameter_gradient_analysis(self, ax):
        """Plot parameter gradient detailed analysis."""
        if not self.parameter_gradients:
            ax.text(0.5, 0.5, 'No parameter gradients', ha='center', va='center')
            ax.set_title('Parameter Gradient Analysis')
            return
        
        # Calculate gradient-to-parameter ratios
        ratios = [pg['grad_to_param_ratio'] for pg in self.parameter_gradients]
        param_names = [pg['name'].split('.')[-1][:8] for pg in self.parameter_gradients]
        
        # Create scatter plot
        colors = plt.cm.viridis(np.linspace(0, 1, len(ratios)))
        scatter = ax.scatter(range(len(ratios)), ratios, c=colors, s=100, alpha=0.7)
        
        # Add trend line
        if len(ratios) > 1:
            z = np.polyfit(range(len(ratios)), ratios, 1)
            p = np.poly1d(z)
            ax.plot(range(len(ratios)), p(range(len(ratios))), "r--", alpha=0.8)
        
        ax.set_xlabel('Parameter Index')
        ax.set_ylabel('Gradient/Parameter Ratio')
        ax.set_title('Gradient-to-Parameter Ratio Analysis', fontweight='bold')
        ax.set_xticks(range(len(param_names)))
        ax.set_xticklabels(param_names, rotation=45, ha='right')
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3)
        
        # Add threshold lines
        ax.axhline(y=0.01, color='orange', linestyle='--', alpha=0.7, label='Conservative threshold')
        ax.axhline(y=0.1, color='red', linestyle='--', alpha=0.7, label='Aggressive threshold')
        ax.legend()
    
    def _plot_gradient_flow_direction(self, ax):
        """Plot gradient flow direction through the network."""
        if not self.layer_gradients:
            ax.text(0.5, 0.5, 'No layer gradients captured', ha='center', va='center')
            ax.set_title('Gradient Flow Direction')
            return
        
        # Backward hooks fire from the output layer toward the input, so the
        # recorded order already matches the direction of backpropagation
        layer_names = [lg['name'] if lg['name'] else f"Layer_{i}" 
                      for i, lg in enumerate(self.layer_gradients)]
        grad_norms = [lg['grad_norm'] for lg in self.layer_gradients]
        
        # Create flow visualization
        x_pos = range(len(layer_names))
        
        # Plot gradient flow as connected line with markers
        ax.plot(x_pos, grad_norms, 'o-', linewidth=3, markersize=8, 
               color='red', alpha=0.7, label='Gradient Flow')
        
        # Add directional arrows
        for i in range(len(x_pos) - 1):
            ax.annotate('', xy=(x_pos[i+1], grad_norms[i+1]), 
                       xytext=(x_pos[i], grad_norms[i]),
                       arrowprops=dict(arrowstyle='->', lw=2, color='darkred', alpha=0.6))
        
        ax.set_xlabel('Network Layers (Output → Input)')
        ax.set_ylabel('Gradient Norm (Log Scale)')
        ax.set_title('Backpropagation Flow Analysis', fontweight='bold')
        ax.set_xticks(x_pos)
        ax.set_xticklabels([name[:10] for name in layer_names], rotation=45, ha='right')
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3)
        ax.legend()
        
        # Highlight potential vanishing/exploding gradients
        if grad_norms:
            max_grad = max(grad_norms)
            min_grad = min(grad_norms)
            # Floor the denominator so an exactly-zero gradient norm cannot divide by zero
            if max_grad / max(min_grad, 1e-12) > 100:
                ax.text(0.02, 0.98, 'Warning: Large gradient variation detected!', 
                       transform=ax.transAxes, fontsize=10, color='red', va='top',
                       bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.8))
    
    def _plot_activation_statistics_heatmap(self, ax):
        """Plot activation statistics as heatmap."""
        if not self.layer_activations:
            ax.text(0.5, 0.5, 'No activation statistics', ha='center', va='center')
            ax.set_title('Activation Statistics Heatmap')
            return
        
        # Prepare data for heatmap
        layer_names = [act['name'][:10] if act['name'] else f"L_{i}" 
                      for i, act in enumerate(self.layer_activations)]
        
        statistics = ['mean', 'std', 'min', 'max', 'sparsity']
        data_matrix = []
        
        for stat in statistics:
            row = [act[stat] for act in self.layer_activations]
            data_matrix.append(row)
        
        # Create heatmap
        im = ax.imshow(data_matrix, cmap='viridis', aspect='auto')
        
        # Set ticks and labels
        ax.set_xticks(range(len(layer_names)))
        ax.set_xticklabels(layer_names, rotation=45, ha='right')
        ax.set_yticks(range(len(statistics)))
        ax.set_yticklabels(statistics)
        
        # Add colorbar
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        
        # Add text annotations
        for i in range(len(statistics)):
            for j in range(len(layer_names)):
                text = ax.text(j, i, f'{data_matrix[i][j]:.2f}',
                             ha="center", va="center", color="white", fontsize=8)
        
        ax.set_title('Activation Statistics Heatmap', fontweight='bold')
    
    def _plot_gradient_health_diagnostics(self, ax):
        """Plot gradient health diagnostics."""
        if not self.parameter_gradients:
            ax.text(0.5, 0.5, 'No gradient data for diagnostics', ha='center', va='center')
            ax.set_title('Gradient Health Diagnostics')
            return
        
        # Calculate health metrics
        grad_norms = [pg['grad_norm'] for pg in self.parameter_gradients]
        
        # Gradient health categories
        vanishing_threshold = 1e-6
        exploding_threshold = 1.0
        
        healthy = sum(1 for g in grad_norms if vanishing_threshold < g < exploding_threshold)
        vanishing = sum(1 for g in grad_norms if g <= vanishing_threshold)
        exploding = sum(1 for g in grad_norms if g >= exploding_threshold)
        
        # Create pie chart
        sizes = [healthy, vanishing, exploding]
        labels = [f'Healthy\n({healthy})', f'Vanishing\n({vanishing})', f'Exploding\n({exploding})']
        colors = ['green', 'orange', 'red']
        
        # Only include non-zero categories
        non_zero_mask = [s > 0 for s in sizes]
        sizes = [s for s, mask in zip(sizes, non_zero_mask) if mask]
        labels = [l for l, mask in zip(labels, non_zero_mask) if mask]
        colors = [c for c, mask in zip(colors, non_zero_mask) if mask]
        
        if sizes:
            wedges, texts, autotexts = ax.pie(sizes, labels=labels, colors=colors, 
                                             autopct='%1.1f%%', startangle=90)
            
            for autotext in autotexts:
                autotext.set_color('white')
                autotext.set_fontweight('bold')
        
        ax.set_title('Gradient Health Status', fontweight='bold')
        
        # Add summary text
        total_params = len(grad_norms)
        health_score = (healthy / total_params * 100) if total_params > 0 else 0
        
        ax.text(0.02, 0.02, f'Health Score: {health_score:.1f}%', 
               transform=ax.transAxes, fontsize=12, fontweight='bold',
               bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    
    def _compile_analysis_results(self, loss_value):
        """Compile comprehensive analysis results."""
        results = {
            'loss': loss_value,
            'network_statistics': {
                'total_parameters': sum(p.numel() for p in self.model.parameters()),
                'trainable_parameters': sum(p.numel() for p in self.model.parameters() if p.requires_grad),
                'total_layers': sum(1 for m in self.model.modules() if not list(m.children())),  # Leaf modules only (containers excluded)
            },
            'activation_analysis': {},
            'gradient_analysis': {},
            'health_metrics': {}
        }
        
        # Activation analysis
        if self.layer_activations:
            results['activation_analysis'] = {
                'layer_count': len(self.layer_activations),
                'mean_activation': np.mean([act['mean'] for act in self.layer_activations]),
                'mean_sparsity': np.mean([act['sparsity'] for act in self.layer_activations]),
                'activation_range': {
                    'min': min([act['min'] for act in self.layer_activations]),
                    'max': max([act['max'] for act in self.layer_activations])
                }
            }
        
        # Gradient analysis
        if self.parameter_gradients:
            grad_norms = [pg['grad_norm'] for pg in self.parameter_gradients]
            results['gradient_analysis'] = {
                'parameter_count': len(self.parameter_gradients),
                'mean_grad_norm': float(np.mean(grad_norms)),
                'std_grad_norm': float(np.std(grad_norms)),
                'min_grad_norm': float(np.min(grad_norms)),
                'max_grad_norm': float(np.max(grad_norms)),
                'grad_norm_ratio': float(np.max(grad_norms) / (np.min(grad_norms) + 1e-8))
            }
            
            # Health metrics (same thresholds and boundary handling as the health diagnostics plot)
            vanishing_count = sum(1 for g in grad_norms if g <= 1e-6)
            exploding_count = sum(1 for g in grad_norms if g >= 1.0)
            healthy_count = len(grad_norms) - vanishing_count - exploding_count
            
            results['health_metrics'] = {
                'healthy_gradients': healthy_count,
                'vanishing_gradients': vanishing_count,
                'exploding_gradients': exploding_count,
                'health_score': healthy_count / len(grad_norms) * 100 if grad_norms else 0
            }
        
        return results
    
    def save_analysis(self, results):
        """Save comprehensive gradient flow analysis."""
        with open(results_dir / 'gradient_flow_analysis.json', 'w') as f:
            json.dump(results, f, indent=2)
        
        print(f"💾 Gradient flow analysis saved to {results_dir / 'gradient_flow_analysis.json'}")
        return results
    
    def cleanup(self):
        """Remove hooks to prevent memory leaks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

def demonstrate_gradient_flow_analysis():
    """Demonstrate gradient flow analysis with different network architectures."""
    print("\n🌊 Gradient Flow Analysis Demonstrations")
    print("=" * 50)
    
    # Example 1: Healthy Network
    print("\n✅ Example 1: Healthy Network Architecture")
    
    class HealthyNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(10, 20),
                nn.ReLU(),
                nn.Linear(20, 15),
                nn.ReLU(),
                nn.Linear(15, 5),
                nn.ReLU(),
                nn.Linear(5, 1)
            )
        
        def forward(self, x):
            return self.layers(x)
    
    # Create healthy network analysis
    healthy_net = HealthyNetwork()
    healthy_input = torch.randn(5, 10)
    healthy_target = torch.randn(5, 1)
    
    healthy_analyzer = GradientFlowAnimator(healthy_net, healthy_input, healthy_target)
    healthy_fig, healthy_results = healthy_analyzer.create_comprehensive_analysis()
    healthy_analyzer.save_analysis(healthy_results)
    healthy_analyzer.cleanup()
    
    print(f"Healthy Network Health Score: {healthy_results['health_metrics']['health_score']:.1f}%")
    
    # Example 2: Problematic Network (Deep with potential issues)
    print("\n⚠️ Example 2: Problematic Deep Network")
    
    class ProblematicNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            layers = []
            # Very deep network with sigmoid activations (prone to vanishing gradients)
            input_size = 20
            for i in range(8):
                layers.extend([
                    nn.Linear(input_size, input_size),
                    nn.Sigmoid()  # Sigmoid can cause vanishing gradients
                ])
            layers.append(nn.Linear(input_size, 1))
            self.layers = nn.Sequential(*layers)
        
        def forward(self, x):
            return self.layers(x)
    
    # Create problematic network analysis
    problem_net = ProblematicNetwork()
    problem_input = torch.randn(3, 20)
    problem_target = torch.randn(3, 1)
    
    problem_analyzer = GradientFlowAnimator(problem_net, problem_input, problem_target)
    problem_fig, problem_results = problem_analyzer.create_comprehensive_analysis()
    problem_analyzer.save_analysis(problem_results)
    problem_analyzer.cleanup()
    
    print(f"Problematic Network Health Score: {problem_results['health_metrics']['health_score']:.1f}%")
    
    # Comparison analysis
    print("\n📊 Comparative Analysis:")
    print(f"Healthy Network:")
    print(f"  - Health Score: {healthy_results['health_metrics']['health_score']:.1f}%")
    print(f"  - Mean Gradient Norm: {healthy_results['gradient_analysis']['mean_grad_norm']:.2e}")
    print(f"  - Gradient Range: {healthy_results['gradient_analysis']['grad_norm_ratio']:.2f}")
    
    print(f"\nProblematic Network:")
    print(f"  - Health Score: {problem_results['health_metrics']['health_score']:.1f}%")
    print(f"  - Mean Gradient Norm: {problem_results['gradient_analysis']['mean_grad_norm']:.2e}")
    print(f"  - Gradient Range: {problem_results['gradient_analysis']['grad_norm_ratio']:.2f}")
    
    return [healthy_analyzer, problem_analyzer], [healthy_results, problem_results]

# Run gradient flow demonstrations
flow_analyzers, flow_results = demonstrate_gradient_flow_analysis()

print(f"\n💡 Key Gradient Flow Insights:")
print("• Activation statistics reveal signal strength through layers")
print("• Gradient magnitude analysis shows learning capacity")
print("• Health diagnostics identify vanishing/exploding gradient problems")
print("• Flow direction visualization shows backpropagation efficiency")
print("• Real-time monitoring enables proactive training adjustments")
```
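
The health score reported above comes from a simple threshold rule on per-parameter gradient norms. The snippet below is a minimal sketch of that rule in isolation. It assumes the same thresholds as the diagnostics (`1e-6` and `1.0`); `classify_gradient_health` is a hypothetical helper name, not part of `GradientFlowAnimator`, and the norms are made-up illustrative values rather than measured output.

```python
def classify_gradient_health(grad_norms, vanishing_threshold=1e-6, exploding_threshold=1.0):
    """Hypothetical helper: bucket gradient norms with the thresholds used above."""
    counts = {'healthy': 0, 'vanishing': 0, 'exploding': 0}
    for g in grad_norms:
        if g <= vanishing_threshold:
            counts['vanishing'] += 1
        elif g >= exploding_threshold:
            counts['exploding'] += 1
        else:
            counts['healthy'] += 1
    total = len(grad_norms)
    # Health score = share of parameter tensors whose norm falls between the thresholds
    score = counts['healthy'] / total * 100 if total else 0.0
    return counts, score

# Illustrative norms only (not taken from a real training run)
example_norms = [3e-9, 5e-7, 2e-4, 1e-2, 0.3, 4.2]
counts, score = classify_gradient_health(example_norms)
print(counts, f"health score: {score:.1f}%")
```

To apply the same rule to a real model after `loss.backward()`, the input list could be built as `[p.grad.norm().item() for p in model.parameters() if p.grad is not None]`.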

## 4. 3D Loss Landscape Exploration

### 4.1 Interactive Loss Landscape Visualizer
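
The visualizer in this section evaluates the loss on a 2D plane through parameter space, spanned by two random orthonormal directions. The idea can be sketched with NumPy alone before reading the full class. Everything here is an assumption for illustration: `toy_loss` is a hypothetical quadratic standing in for a real model loss, and `orthonormal_directions` is a hypothetical helper that uses Gram-Schmidt orthogonalization.

```python
import numpy as np

def orthonormal_directions(dim, num_directions=2, seed=0):
    """Draw random directions and orthonormalize them with Gram-Schmidt."""
    rng = np.random.default_rng(seed)
    directions = []
    for _ in range(num_directions):
        d = rng.standard_normal(dim)
        for prev in directions:
            d = d - (d @ prev) * prev  # remove the component along prev
        directions.append(d / np.linalg.norm(d))
    return directions

def toy_loss(theta):
    """Hypothetical stand-in for a model loss: an anisotropic quadratic bowl."""
    scales = np.linspace(1.0, 10.0, theta.size)
    return 0.5 * float(np.sum(scales * theta ** 2))

theta0 = np.full(50, 0.1)  # plays the role of the current parameters
d1, d2 = orthonormal_directions(theta0.size)

# Evaluate the loss at theta0 + alpha * d1 + beta * d2 on a small grid
alphas = np.linspace(-1, 1, 11)
betas = np.linspace(-1, 1, 11)
losses = np.array([[toy_loss(theta0 + a * d1 + b * d2) for b in betas]
                   for a in alphas])

print(f"d1 . d2 = {d1 @ d2:.2e}")
print(f"loss at grid center = {losses[5, 5]:.4f}")
```

The array is indexed as `losses[i, j]` for `alphas[i]` and `betas[j]`, which is why the plotting code transposes it before passing it to `np.meshgrid`-based plots.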

```python
class LossLandscapeVisualizer:
    """Advanced 3D loss landscape visualization with optimization path tracking."""
    
    def __init__(self, model, data_loader, loss_fn):
        self.model = model
        self.data_loader = data_loader
        self.loss_fn = loss_fn
        self.original_params = self._get_parameters()
        self.optimization_history = []
        self.landscape_data = {}
    
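    def _get_parameters(self):
        """Get a detached, flattened copy of the trainable parameters."""
        # detach() keeps the snapshot out of the autograd graph; torch.cat allocates new storage
        return torch.cat([p.detach().flatten() for p in self.model.parameters() if p.requires_grad])
    
    def _set_parameters(self, params):
        """Copy values from a flattened vector into the model parameters."""
        offset = 0
        with torch.no_grad():
            for p in self.model.parameters():
                if p.requires_grad:
                    num_params = p.numel()
                    # copy_ writes in place, so parameters never alias the source vector
                    p.copy_(params[offset:offset+num_params].view_as(p))
                    offset += num_params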
    
    def _compute_loss(self, params=None):
        """Compute loss for given parameters."""
        if params is not None:
            self._set_parameters(params)
        
        total_loss = 0
        total_samples = 0
        
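        # Remember the caller's train/eval mode so it can be restored afterwards
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            for batch_data, batch_target in self.data_loader:
                output = self.model(batch_data)
                loss = self.loss_fn(output, batch_target)
                # Weight by batch size (assumes loss_fn returns a per-batch mean)
                total_loss += loss.item() * batch_data.size(0)
                total_samples += batch_data.size(0)
        self.model.train(was_training)
        
        return total_loss / total_samples if total_samples > 0 else 0.0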
    
    def generate_random_directions(self, num_directions=2):
        """Generate random orthonormal directions for landscape exploration."""
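        param_size = self.original_params.numel()
        
        # Generate random directions on the same device and dtype as the parameters
        directions = []
        for i in range(num_directions):
            direction = torch.randn(param_size, device=self.original_params.device,
                                    dtype=self.original_params.dtype)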
            
            # Orthogonalize against previous directions
            for prev_dir in directions:
                direction = direction - (direction @ prev_dir) * prev_dir
            
            # Normalize
            direction = direction / direction.norm()
            directions.append(direction)
        
        return directions
    
    def create_2d_landscape_slice(self, direction1=None, direction2=None, 
                                 alpha_range=(-1, 1), beta_range=(-1, 1), resolution=30):
        """Create 2D slice of loss landscape with enhanced analysis."""
        
        # Generate directions if not provided
        if direction1 is None or direction2 is None:
            directions = self.generate_random_directions(2)
            direction1 = directions[0] if direction1 is None else direction1
            direction2 = directions[1] if direction2 is None else direction2
        
        alphas = np.linspace(alpha_range[0], alpha_range[1], resolution)
        betas = np.linspace(beta_range[0], beta_range[1], resolution)
        
        losses = np.zeros((resolution, resolution))
        gradient_norms = np.zeros((resolution, resolution))
        
        print(f"🏔️ Computing {resolution}x{resolution} loss landscape...")
        
        # Store original state
        original_loss = self._compute_loss()
        
        for i, alpha in enumerate(alphas):
            for j, beta in enumerate(betas):
                # Perturb parameters (plain Python floats avoid mixing NumPy scalars with tensors)
                perturbed_params = self.original_params + float(alpha) * direction1 + float(beta) * direction2
                
                # Compute loss
                loss = self._compute_loss(perturbed_params)
                losses[i, j] = loss
                
                # Compute gradient norm at this point (optional, expensive)
                if resolution <= 20:  # Only for small grids
                    self._set_parameters(perturbed_params)
                    self.model.zero_grad()
                    
                    total_grad_norm = 0
                    sample_count = 0
                    
                    for batch_data, batch_target in self.data_loader:
                        if sample_count >= 1:  # Limit to first batch for speed
                            break
                        output = self.model(batch_data)
                        batch_loss = self.loss_fn(output, batch_target)
                        batch_loss.backward()
                        
                        grad_norm = torch.cat([p.grad.flatten() for p in self.model.parameters() 
                                             if p.grad is not None]).norm()
                        total_grad_norm += grad_norm.item()
                        sample_count += 1
                    
                    gradient_norms[i, j] = total_grad_norm / sample_count if sample_count > 0 else 0
            
            if (i + 1) % max(1, resolution // 10) == 0:
                print(f"  Progress: {(i+1)/resolution*100:.1f}%")
        
        # Restore original parameters and clear gradients left over from the perturbed points
        self._set_parameters(self.original_params)
        self.model.zero_grad()
        
        # Store landscape data
        self.landscape_data = {
            'alphas': alphas,
            'betas': betas,
            'losses': losses,
            'gradient_norms': gradient_norms,
            'direction1': direction1,
            'direction2': direction2,
            'original_loss': original_loss,
            'min_loss': float(np.min(losses)),
            'max_loss': float(np.max(losses)),
            'current_position': (0, 0)  # Origin represents current parameters
        }
        
        return alphas, betas, losses, gradient_norms
    
    def create_comprehensive_landscape_analysis(self, figsize=(20, 16)):
        """Create comprehensive loss landscape analysis dashboard."""
        
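        # Generate landscape data. Note: resolution=25 exceeds the 20-point cutoff for
        # gradient-norm computation in create_2d_landscape_slice, so the gradient heatmap
        # shows its fallback message; use resolution <= 20 to populate it.
        alphas, betas, losses, gradient_norms = self.create_2d_landscape_slice(resolution=25)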
        
        # Create dashboard
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3, height_ratios=[1.2, 1, 1])
        
        # 1. 3D Surface Plot (large, top row)
        ax1 = fig.add_subplot(gs[0, :2], projection='3d')
        self._plot_3d_surface(ax1, alphas, betas, losses)
        
        # 2. 2D Contour Plot with Current Position
        ax2 = fig.add_subplot(gs[0, 2:])
        self._plot_contour_with_analysis(ax2, alphas, betas, losses)
        
        # 3. Loss Cross-Sections
        ax3 = fig.add_subplot(gs[1, :2])
        self._plot_loss_cross_sections(ax3, alphas, betas, losses)
        
        # 4. Gradient Magnitude Heatmap
        ax4 = fig.add_subplot(gs[1, 2:])
        self._plot_gradient_magnitude_heatmap(ax4, alphas, betas, gradient_norms)
        
        # 5. Loss Statistics
        ax5 = fig.add_subplot(gs[2, 0])
        self._plot_loss_statistics(ax5, losses)
        
        # 6. Curvature Analysis
        ax6 = fig.add_subplot(gs[2, 1])
        self._plot_curvature_analysis(ax6, alphas, betas, losses)
        
        # 7. Optimization Difficulty Assessment
        ax7 = fig.add_subplot(gs[2, 2])
        self._plot_optimization_difficulty(ax7, losses, gradient_norms)
        
        # 8. Landscape Summary
        ax8 = fig.add_subplot(gs[2, 3])
        self._plot_landscape_summary(ax8)
        
        plt.suptitle('Loss Landscape Comprehensive Analysis', fontsize=18, fontweight='bold')
        
        # Save analysis
        plt.savefig(results_dir / 'loss_landscape_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        return fig, self.landscape_data
    
    def _plot_3d_surface(self, ax, alphas, betas, losses):
        """Plot 3D loss surface with enhanced visualization."""
        A, B = np.meshgrid(alphas, betas)
        
        # Create surface plot with custom colormap
        surface = ax.plot_surface(A, B, losses.T, cmap='viridis', alpha=0.9,
                                 linewidth=0.5, antialiased=True, edgecolor='none')
        
        # Mark current position
        current_loss = self.landscape_data['original_loss']
        ax.scatter([0], [0], [current_loss], color='red', s=100, alpha=1.0, 
                  label=f'Current Position\n(Loss: {current_loss:.4f})')
        
        # Mark the lowest loss found on the sampled grid
        min_idx = np.unravel_index(np.argmin(losses), losses.shape)
        min_alpha = alphas[min_idx[0]]
        min_beta = betas[min_idx[1]]
        min_loss = losses[min_idx]
        ax.scatter([min_alpha], [min_beta], [min_loss], color='lime', s=100, alpha=1.0,
                  label=f'Slice Minimum\n(Loss: {min_loss:.4f})')
        
        ax.set_xlabel('Direction 1 (α)', fontweight='bold')
        ax.set_ylabel('Direction 2 (β)', fontweight='bold')
        ax.set_zlabel('Loss', fontweight='bold')
        ax.set_title('3D Loss Landscape Surface', fontweight='bold', fontsize=14)
        ax.legend()
        
        # Add colorbar
        fig = ax.figure
        cbar = fig.colorbar(surface, ax=ax, shrink=0.5, aspect=5)
        cbar.set_label('Loss Value', fontweight='bold')
    
    def _plot_contour_with_analysis(self, ax, alphas, betas, losses):
        """Plot contour map with detailed analysis."""
        A, B = np.meshgrid(alphas, betas)
        
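        # Create contour plot (log-spaced levels need strictly positive, non-constant losses)
        loss_min, loss_max = float(np.min(losses)), float(np.max(losses))
        if loss_min > 0 and loss_max > loss_min:
            levels = np.logspace(np.log10(loss_min), np.log10(loss_max), 15)
        else:
            # Fall back to linear spacing when the loss touches zero, goes negative, or is flat
            span = max(loss_max - loss_min, 1e-8)
            levels = np.linspace(loss_min, loss_min + span, 15)
        contour = ax.contour(A, B, losses.T, levels=levels, alpha=0.8, colors='black', linewidths=0.5)
        contourf = ax.contourf(A, B, losses.T, levels=levels, cmap='viridis', alpha=0.7)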
        
        # Add contour labels
        ax.clabel(contour, inline=True, fontsize=8, fmt='%.3f')
        
        # Mark important points
        ax.plot(0, 0, 'ro', markersize=12, label='Current Position', markeredgecolor='darkred')
        
        # Mark the lowest loss found on the sampled grid
        min_idx = np.unravel_index(np.argmin(losses), losses.shape)
        min_alpha = alphas[min_idx[0]]
        min_beta = betas[min_idx[1]]
        ax.plot(min_alpha, min_beta, 's', color='lime', markersize=10, 
               label='Slice Minimum', markeredgecolor='darkgreen')
        
        # Add gradient descent path simulation
        self._add_gradient_descent_simulation(ax, alphas, betas, losses)
        
        ax.set_xlabel('Direction 1 (α)', fontweight='bold')
        ax.set_ylabel('Direction 2 (β)', fontweight='bold')
        ax.set_title('Loss Contours with Analysis', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # Add colorbar
        plt.colorbar(contourf, ax=ax, label='Loss Value')
    
    def _add_gradient_descent_simulation(self, ax, alphas, betas, losses):
        """Add simulated gradient descent path to contour plot."""
        # Simple gradient descent simulation
        A, B = np.meshgrid(alphas, betas)
        
        # Compute gradients using finite differences with the real grid spacing
        # (axis 0 of losses.T runs over β, axis 1 over α)
        dy, dx = np.gradient(losses.T, betas, alphas)
        
        # Normalize gradients
        grad_magnitude = np.sqrt(dx**2 + dy**2)
        dx_norm = dx / (grad_magnitude + 1e-8)
        dy_norm = dy / (grad_magnitude + 1e-8)
        
        # Sample points for gradient vectors
        step = max(1, len(alphas) // 8)
        X_sample = A[::step, ::step]
        Y_sample = B[::step, ::step]
        U_sample = -dx_norm[::step, ::step]  # Negative for descent
        V_sample = -dy_norm[::step, ::step]
        
        # Add gradient vectors
        ax.quiver(X_sample, Y_sample, U_sample, V_sample, 
                 alpha=0.6, scale=20, width=0.003, color='white')
        
        # Simulate a gradient descent path from current position
        path_alpha, path_beta = [0], [0]
        learning_rate = 0.1
        
        for _ in range(20):  # fixed number of simulated steps
            current_alpha, current_beta = path_alpha[-1], path_beta[-1]
            
            # Find nearest grid point
            alpha_idx = np.argmin(np.abs(alphas - current_alpha))
            beta_idx = np.argmin(np.abs(betas - current_beta))
            
            if 0 <= alpha_idx < len(alphas)-1 and 0 <= beta_idx < len(betas)-1:
                # Compute gradient at current position
                grad_alpha = (losses[alpha_idx+1, beta_idx] - losses[alpha_idx, beta_idx]) / (alphas[1] - alphas[0])
                grad_beta = (losses[alpha_idx, beta_idx+1] - losses[alpha_idx, beta_idx]) / (betas[1] - betas[0])
                
                # Update position
                new_alpha = current_alpha - learning_rate * grad_alpha
                new_beta = current_beta - learning_rate * grad_beta
                
                # Keep within bounds
                new_alpha = np.clip(new_alpha, alphas[0], alphas[-1])
                new_beta = np.clip(new_beta, betas[0], betas[-1])
                
                path_alpha.append(new_alpha)
                path_beta.append(new_beta)
            else:
                break
        
        # Plot gradient descent path
        ax.plot(path_alpha, path_beta, 'y-', linewidth=3, alpha=0.8, 
               label='Simulated GD Path')
        ax.plot(path_alpha[-1], path_beta[-1], 'yo', markersize=8, 
               markeredgecolor='orange')
    
    def _plot_loss_cross_sections(self, ax, alphas, betas, losses):
        """Plot loss cross-sections along the two slice directions."""
        # Cross-section along direction 1 at the grid point closest to beta=0
        center_beta_idx = int(np.argmin(np.abs(betas)))
        loss_alpha = losses[:, center_beta_idx]
        
        # Cross-section along direction 2 at the grid point closest to alpha=0
        center_alpha_idx = int(np.argmin(np.abs(alphas)))
        loss_beta = losses[center_alpha_idx, :]
        
        # Plot both cross-sections
        ax.plot(alphas, loss_alpha, 'b-', linewidth=3, label='Direction 1 (α)', marker='o', markersize=4)
        ax.plot(betas, loss_beta, 'r-', linewidth=3, label='Direction 2 (β)', marker='s', markersize=4)
        
        # Mark current position
        current_loss = self.landscape_data['original_loss']
        ax.axvline(x=0, color='gray', linestyle='--', alpha=0.7, linewidth=2)
        ax.axhline(y=current_loss, color='gray', linestyle='--', alpha=0.7, linewidth=2,
                  label=f'Current Loss: {current_loss:.4f}')
        
        # Mark minima in each direction
        min_alpha_idx = np.argmin(loss_alpha)
        min_beta_idx = np.argmin(loss_beta)
        
        ax.plot(alphas[min_alpha_idx], loss_alpha[min_alpha_idx], 'bo', markersize=10, 
               markerfacecolor='lightblue', markeredgecolor='darkblue')
        ax.plot(betas[min_beta_idx], loss_beta[min_beta_idx], 'ro', markersize=10,
               markerfacecolor='lightcoral', markeredgecolor='darkred')
        
        ax.set_xlabel('Perturbation Amount', fontweight='bold')
        ax.set_ylabel('Loss', fontweight='bold')
        ax.set_title('Loss Cross-Sections Along Slice Directions', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
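        # Add curvature information (mean second finite difference, scaled by grid spacing)
        alpha_curvature = np.mean(np.diff(loss_alpha, 2)) / (alphas[1] - alphas[0]) ** 2
        beta_curvature = np.mean(np.diff(loss_beta, 2)) / (betas[1] - betas[0]) ** 2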
        
        ax.text(0.02, 0.98, f'Curvature α: {alpha_curvature:.3f}\nCurvature β: {beta_curvature:.3f}',
               transform=ax.transAxes, fontsize=10, verticalalignment='top',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    def _plot_gradient_magnitude_heatmap(self, ax, alphas, betas, gradient_norms):
        """Plot gradient magnitude heatmap."""
        if np.any(gradient_norms > 0):
            A, B = np.meshgrid(alphas, betas)
            
            # Create heatmap
            heatmap = ax.imshow(gradient_norms.T, extent=[alphas[0], alphas[-1], betas[0], betas[-1]], 
                               origin='lower', cmap='plasma', aspect='auto')
            
            # Add contours
            contour = ax.contour(A, B, gradient_norms.T, levels=8, colors='white', alpha=0.5, linewidths=1)
            ax.clabel(contour, inline=True, fontsize=8, fmt='%.2f')
            
            # Mark current position
            ax.plot(0, 0, 'wo', markersize=10, markeredgecolor='black', linewidth=2,
                   label='Current Position')
            
            plt.colorbar(heatmap, ax=ax, label='Gradient Magnitude')
            ax.set_xlabel('Direction 1 (α)', fontweight='bold')
            ax.set_ylabel('Direction 2 (β)', fontweight='bold')
            ax.set_title('Gradient Magnitude Landscape', fontweight='bold')
            ax.legend()
        else:
            ax.text(0.5, 0.5, 'Gradient data not available\n(Enable for small landscapes)', 
                   ha='center', va='center', transform=ax.transAxes, fontsize=12)
            ax.set_title('Gradient Magnitude Analysis', fontweight='bold')
    
    def _plot_loss_statistics(self, ax, losses):
        """Plot loss distribution statistics."""
        loss_flat = losses.flatten()
        
        # Create histogram
        n, bins, patches = ax.hist(loss_flat, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
        
        # Add statistics
        mean_loss = np.mean(loss_flat)
        median_loss = np.median(loss_flat)
        std_loss = np.std(loss_flat)
        
        ax.axvline(mean_loss, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_loss:.4f}')
        ax.axvline(median_loss, color='orange', linestyle='--', linewidth=2, label=f'Median: {median_loss:.4f}')
        ax.axvline(self.landscape_data['original_loss'], color='green', linestyle='-', linewidth=3,
                  label=f'Current: {self.landscape_data["original_loss"]:.4f}')
        
        ax.set_xlabel('Loss Value', fontweight='bold')
        ax.set_ylabel('Frequency', fontweight='bold')
        ax.set_title('Loss Distribution', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # Add percentile information
        percentiles = [10, 25, 75, 90]
        perc_values = np.percentile(loss_flat, percentiles)
        
        info_text = f'Std: {std_loss:.4f}\n'
        for p, v in zip(percentiles, perc_values):
            info_text += f'P{p}: {v:.4f}\n'
        
        ax.text(0.98, 0.98, info_text, transform=ax.transAxes, fontsize=9,
               verticalalignment='top', horizontalalignment='right',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    def _plot_curvature_analysis(self, ax, alphas, betas, losses):
        """Plot curvature analysis of the loss landscape."""
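        # Approximate second derivatives with finite differences scaled by grid spacing,
        # so the values do not depend on the chosen resolution
        d_a = alphas[1] - alphas[0]
        d_b = betas[1] - betas[0]
        d2_alpha = np.diff(losses, 2, axis=0) / d_a**2
        d2_beta = np.diff(losses, 2, axis=1) / d_b**2
        
        # Mixed derivative
        d_alpha = np.diff(losses, 1, axis=0) / d_a
        d2_alpha_beta = np.diff(d_alpha, 1, axis=1) / d_b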
        
        # Average curvatures
        curvature_stats = {
            'Alpha direction': np.mean(d2_alpha),
            'Beta direction': np.mean(d2_beta),
            'Mixed (α,β)': np.mean(d2_alpha_beta),
            'Total curvature': np.mean(np.abs(d2_alpha)) + np.mean(np.abs(d2_beta))
        }
        
        # Create bar plot
        names = list(curvature_stats.keys())
        values = list(curvature_stats.values())
        colors = ['blue', 'red', 'green', 'purple']
        
        bars = ax.bar(names, values, color=colors, alpha=0.7)
        
        # Add value labels
        for bar, value in zip(bars, values):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.01 * max(np.abs(values)),
                   f'{value:.3f}', ha='center', va='bottom', fontsize=9, fontweight='bold')
        
        ax.set_ylabel('Curvature Value', fontweight='bold')
        ax.set_title('Landscape Curvature Analysis', fontweight='bold')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
        
        # Add interpretation
        avg_curvature = np.mean([abs(v) for v in values[:3]])
        if avg_curvature > 0.1:
            interpretation = "High curvature\n(Challenging optimization)"
        elif avg_curvature > 0.01:
            interpretation = "Moderate curvature\n(Standard optimization)"
        else:
            interpretation = "Low curvature\n(Smooth optimization)"
        
        ax.text(0.98, 0.02, interpretation, transform=ax.transAxes, fontsize=10,
               verticalalignment='bottom', horizontalalignment='right',
               bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.8))
    
    def _plot_optimization_difficulty(self, ax, losses, gradient_norms):
        """Plot optimization difficulty assessment."""
        # Calculate various difficulty metrics
        loss_range = np.max(losses) - np.min(losses)
        loss_variance = np.var(losses)
        
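        # Local minima detection (simplified): interior grid points strictly lower
        # than their four axis-aligned neighbours
        local_minima_count = 0
        if losses.ndim == 2 and min(losses.shape) >= 3:
            interior = losses[1:-1, 1:-1]
            is_minimum = (
                (interior < losses[:-2, 1:-1]) & (interior < losses[2:, 1:-1]) &
                (interior < losses[1:-1, :-2]) & (interior < losses[1:-1, 2:])
            )
            local_minima_count = int(np.sum(is_minimum))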
        
        # Gradient information if available
        if np.any(gradient_norms > 0):
            avg_grad_norm = np.mean(gradient_norms[gradient_norms > 0])
            grad_variance = np.var(gradient_norms[gradient_norms > 0])
        else:
            avg_grad_norm = 0
            grad_variance = 0
        
        # Create difficulty metrics
        difficulty_metrics = {
            'Loss Range': loss_range,
            'Loss Variance': loss_variance,
            'Local Minima': local_minima_count,
            'Avg Gradient': avg_grad_norm,
            'Grad Variance': grad_variance
        }
        
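        # Normalize metrics for visualization: values below 1 are kept as-is,
        # larger values are clipped to 1 (max(..., 1) also avoids division by zero)
        max_values = {
            'Loss Range': max(loss_range, 1),
            'Loss Variance': max(loss_variance, 1),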
            'Local Minima': max(local_minima_count, 1),
            'Avg Gradient': max(avg_grad_norm, 1),
            'Grad Variance': max(grad_variance, 1)
        }
        
        normalized_metrics = {k: v / max_values[k] for k, v in difficulty_metrics.items()}
        
        # Create radar chart
        angles = np.linspace(0, 2 * np.pi, len(normalized_metrics), endpoint=False)
        values = list(normalized_metrics.values())
        
        # Close the plot
        angles = np.concatenate((angles, [angles[0]]))
        values = np.concatenate((values, [values[0]]))
        
        ax.plot(angles, values, 'o-', linewidth=2, color='red', alpha=0.8)
        ax.fill(angles, values, alpha=0.25, color='red')
        
        # Add labels
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(list(normalized_metrics.keys()), fontsize=10)
        ax.set_ylim(0, 1)
        ax.set_title('Optimization Difficulty Assessment', fontweight='bold')
        ax.grid(True, alpha=0.3)
        
        # Add difficulty score
        difficulty_score = np.mean(values[:-1])
        if difficulty_score > 0.7:
            difficulty_level = "Very Hard"
            color = 'red'
        elif difficulty_score > 0.5:
            difficulty_level = "Hard"
            color = 'orange'
        elif difficulty_score > 0.3:
            difficulty_level = "Moderate"
            color = 'yellow'
        else:
            difficulty_level = "Easy"
            color = 'green'
        
        ax.text(0.02, 0.98, f'Difficulty: {difficulty_level}\nScore: {difficulty_score:.2f}',
               transform=ax.transAxes, fontsize=12, fontweight='bold',
               verticalalignment='top', color=color,
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))
    
    def _plot_landscape_summary(self, ax):
        """Plot comprehensive landscape summary."""
        summary_data = self.landscape_data
        
        # Create summary text
        summary_text = f"""
🏔️ LANDSCAPE SUMMARY

📊 Loss Statistics:
  Current Loss: {summary_data['original_loss']:.4f}
  Minimum Loss: {summary_data['min_loss']:.4f}
  Maximum Loss: {summary_data['max_loss']:.4f}
  Loss Range: {summary_data['max_loss'] - summary_data['min_loss']:.4f}

📈 Optimization Insights:
  • Loss landscape explored in 2D slice
  • Current position marked in red
  • Gradient flow arrows shown
  • Local minima identified

🎯 Recommendations:
  • Monitor gradient magnitudes
  • Consider adaptive learning rates
  • Watch for vanishing gradients
  • Use momentum for escaping valleys
        """
        
        ax.text(0.05, 0.95, summary_text, transform=ax.transAxes, fontsize=10,
               verticalalignment='top', horizontalalignment='left',
               bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
        
        ax.set_title('Analysis Summary', fontweight='bold')
        ax.axis('off')
    
    def track_optimization_path(self, optimizer, num_steps=50, save_interval=5):
        """Track the optimization path; each step is one full pass over the data loader."""
        print(f"🚀 Tracking optimization path for {num_steps} steps...")
        
        # Initialize tracking
        self.optimization_history = []
        param_history = []
        loss_history = []
        
        for step in range(num_steps):
            # Store the parameters as they are before this step's updates
            current_params = self._get_parameters().detach().clone()
            param_history.append(current_params)
            
            # Training step
            self.model.train()
            total_loss = 0
            
            for batch_data, batch_target in self.data_loader:
                optimizer.zero_grad()
                output = self.model(batch_data)
                loss = self.loss_fn(output, batch_target)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            
            avg_loss = total_loss / len(self.data_loader)
            loss_history.append(avg_loss)
            
            # Save state periodically
            if step % save_interval == 0:
                self.optimization_history.append({
                    'step': step,
                    'loss': avg_loss,
                    'parameters': current_params.clone()
                })
                print(f"  Step {step}: Loss = {avg_loss:.6f}")
        
        return param_history, loss_history
    
    def visualize_optimization_path(self, param_history, loss_history, figsize=(16, 12)):
        """Visualize the optimization path on the loss landscape."""
        if not hasattr(self, 'landscape_data') or not self.landscape_data:
            print("⚠️ No landscape data available. Generate landscape first.")
            return None
        
        # Project parameter history onto the 2D landscape directions
        direction1 = self.landscape_data['direction1']
        direction2 = self.landscape_data['direction2']
        
        path_alphas = []
        path_betas = []
        
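        # Least-squares coordinates along each direction (assumes the two directions
        # are orthogonal; dividing by the squared norm handles non-unit directions)
        norm1_sq = torch.dot(direction1, direction1).item()
        norm2_sq = torch.dot(direction2, direction2).item()
        
        for params in param_history:
            param_diff = params - self.original_params
            alpha = torch.dot(param_diff, direction1).item() / norm1_sq
            beta = torch.dot(param_diff, direction2).item() / norm2_sq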
            path_alphas.append(alpha)
            path_betas.append(beta)
        
        # Create visualization
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)
        
        # 1. Optimization path on contour plot
        alphas = self.landscape_data['alphas']
        betas = self.landscape_data['betas']
        losses = self.landscape_data['losses']
        
        A, B = np.meshgrid(alphas, betas)
        contour = ax1.contour(A, B, losses.T, levels=15, alpha=0.6)
        ax1.clabel(contour, inline=True, fontsize=8)
        
        # Plot optimization path
        path_colors = plt.cm.viridis(np.linspace(0, 1, len(path_alphas)))
        
        for i in range(len(path_alphas) - 1):
            ax1.plot([path_alphas[i], path_alphas[i+1]], [path_betas[i], path_betas[i+1]], 
                    color=path_colors[i], linewidth=2, alpha=0.8)
        
        # Mark start and end
        ax1.plot(path_alphas[0], path_betas[0], 'go', markersize=12, label='Start')
        ax1.plot(path_alphas[-1], path_betas[-1], 'ro', markersize=12, label='End')
        
        ax1.set_xlabel('Direction 1 (α)')
        ax1.set_ylabel('Direction 2 (β)')
        ax1.set_title('Optimization Path on Loss Landscape')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # 2. Loss over time
        ax2.plot(loss_history, 'b-', linewidth=2, marker='o', markersize=4)
        ax2.set_xlabel('Optimization Step')
        ax2.set_ylabel('Loss')
        ax2.set_title('Loss Convergence')
        ax2.grid(True, alpha=0.3)
        
        # Add exponential fit
        if len(loss_history) > 10:
            steps = np.arange(len(loss_history))
            try:
                # Fit exponential decay
                from scipy.optimize import curve_fit
                def exp_decay(x, a, b, c):
                    return a * np.exp(-b * x) + c
                
                popt, _ = curve_fit(exp_decay, steps, loss_history, maxfev=1000)
                ax2.plot(steps, exp_decay(steps, *popt), 'r--', alpha=0.8, 
                        label=f'Exp fit: {popt[0]:.3f}*exp(-{popt[1]:.3f}*x)+{popt[2]:.3f}')
                ax2.legend()
            except Exception:
                # SciPy is missing or the fit did not converge: skip the fitted curve
                pass
        
        # 3. Parameter space trajectory
        param_norms = [torch.norm(params - self.original_params).item() for params in param_history]
        
        ax3.plot(param_norms, 'g-', linewidth=2, marker='s', markersize=4)
        ax3.set_xlabel('Optimization Step')
        ax3.set_ylabel('Parameter Distance from Start')
        ax3.set_title('Parameter Space Movement')
        ax3.grid(True, alpha=0.3)
        
        # 4. Parameter change vs loss change
        if len(loss_history) > 1:
            loss_changes = np.diff(loss_history)
            param_changes = np.diff(param_norms)
            
            ax4.scatter(param_changes[:-1], loss_changes[:-1], alpha=0.6, s=30)
            ax4.set_xlabel('Parameter Change')
            ax4.set_ylabel('Loss Change')
            ax4.set_title('Parameter vs Loss Changes')
            ax4.grid(True, alpha=0.3)
            
            # Add trend line
            if len(param_changes) > 2:
                z = np.polyfit(param_changes[:-1], loss_changes[:-1], 1)
                p = np.poly1d(z)
                x_trend = np.linspace(min(param_changes), max(param_changes), 100)
                ax4.plot(x_trend, p(x_trend), "r--", alpha=0.8)
        
        plt.suptitle('Optimization Path Analysis', fontsize=16, fontweight='bold')
        plt.tight_layout()
        
        # Save visualization
        plt.savefig(results_dir / 'optimization_path_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        return fig
    
    def save_landscape_analysis(self):
        """Save comprehensive landscape analysis to JSON."""
        # Prepare serializable data
        analysis_data = {
            'landscape_statistics': {
                'original_loss': float(self.landscape_data['original_loss']),
                'min_loss': float(self.landscape_data['min_loss']),
                'max_loss': float(self.landscape_data['max_loss']),
                'loss_range': float(self.landscape_data['max_loss'] - self.landscape_data['min_loss']),
                'landscape_shape': self.landscape_data['losses'].shape
            },
            'optimization_metrics': {
                'parameter_count': len(self.original_params),
                'directions_explored': 2,
                'resolution': len(self.landscape_data['alphas'])
            }
        }
        
        # Add optimization history if available
        if getattr(self, 'optimization_history', None):
            analysis_data['optimization_history'] = [
                {
                    'step': entry['step'],
                    'loss': float(entry['loss'])
                }
                for entry in self.optimization_history
            ]
        
        with open(results_dir / 'loss_landscape_analysis.json', 'w') as f:
            json.dump(analysis_data, f, indent=2)
        
        print(f"💾 Loss landscape analysis saved to {results_dir / 'loss_landscape_analysis.json'}")
        return analysis_data

def demonstrate_loss_landscape_analysis():
    """Demonstrate loss landscape visualization with different scenarios."""
    print("\n🏔️ Loss Landscape Analysis Demonstrations")
    print("=" * 50)
    
    # Create sample dataset
    torch.manual_seed(42)
    X = torch.randn(100, 8)
    y = torch.sum(X[:, :4], dim=1, keepdim=True) + 0.1 * torch.randn(100, 1)
    dataset = TensorDataset(X, y)
    data_loader = DataLoader(dataset, batch_size=20, shuffle=True)
    
    # Example 1: Simple Linear Model
    print("\n📈 Example 1: Simple Linear Model Landscape")
    
    class SimpleLinearModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 1)
        
        def forward(self, x):
            return self.linear(x)
    
    simple_model = SimpleLinearModel()
    simple_visualizer = LossLandscapeVisualizer(simple_model, data_loader, nn.MSELoss())
    
    # Generate landscape analysis
    simple_fig, simple_landscape = simple_visualizer.create_comprehensive_landscape_analysis()
    simple_analysis = simple_visualizer.save_landscape_analysis()
    
    print("Simple Model Landscape:")
    print(f"  - Loss range: {simple_landscape['max_loss'] - simple_landscape['min_loss']:.4f}")
    print(f"  - Current loss: {simple_landscape['original_loss']:.4f}")
    
    # Example 2: Deep Nonlinear Model
    print("\n🧠 Example 2: Deep Nonlinear Model Landscape")
    
    class DeepNonlinearModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(8, 16),
                nn.ReLU(),
                nn.Linear(16, 12),
                nn.Tanh(),
                nn.Linear(12, 8),
                nn.ReLU(),
                nn.Linear(8, 1)
            )
        
        def forward(self, x):
            return self.layers(x)
    
    deep_model = DeepNonlinearModel()
    deep_visualizer = LossLandscapeVisualizer(deep_model, data_loader, nn.MSELoss())
    
    # Generate landscape analysis
    deep_fig, deep_landscape = deep_visualizer.create_comprehensive_landscape_analysis()
    deep_analysis = deep_visualizer.save_landscape_analysis()
    
    print("Deep Model Landscape:")
    print(f"  - Loss range: {deep_landscape['max_loss'] - deep_landscape['min_loss']:.4f}")
    print(f"  - Current loss: {deep_landscape['original_loss']:.4f}")
    
    # Example 3: Optimization Path Tracking
    print("\n🚀 Example 3: Optimization Path Tracking")
    
    # Train the deep model and track optimization path
    optimizer = torch.optim.SGD(deep_model.parameters(), lr=0.01, momentum=0.9)
    param_hist, loss_hist = deep_visualizer.track_optimization_path(optimizer, num_steps=30)
    
    # Visualize optimization path
    path_fig = deep_visualizer.visualize_optimization_path(param_hist, loss_hist)
    
    print("Optimization tracking:")
    print(f"  - Initial loss: {loss_hist[0]:.6f}")
    print(f"  - Final loss: {loss_hist[-1]:.6f}")
    print(f"  - Improvement: {loss_hist[0] - loss_hist[-1]:.6f}")
    
    return [simple_visualizer, deep_visualizer], [simple_analysis, deep_analysis]

# Run loss landscape demonstrations
landscape_visualizers, landscape_analyses = demonstrate_loss_landscape_analysis()

print("\n💡 Key Loss Landscape Insights:")
print("• 3D surfaces of a 2D slice give a feel for optimization complexity")
print("• Contour plots show where the optimization path moves within that slice")
print("• Curvature estimates hint at optimization difficulty")
print("• Tracking the path during training helps when tuning the optimizer")
print("• Cross-sections reveal directional sensitivities")
```
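
The curvature panel in the landscape analysis estimates second derivatives from a grid of loss values. The following is a minimal sketch of that idea on a synthetic quadratic loss whose curvature is known analytically. The helper name `finite_difference_curvature`, the uniform grid, and the quadratic itself are illustrative assumptions and are not part of the visualizer class. It suggests that second differences match the analytic curvature once they are divided by the squared grid spacing.

```python
import numpy as np


def finite_difference_curvature(losses, alphas, betas):
    """Estimate mean second derivatives of a loss grid (hypothetical helper).

    Assumes a uniform grid with losses[i, j] = L(alphas[i], betas[j]).
    """
    step_alpha = alphas[1] - alphas[0]
    step_beta = betas[1] - betas[0]

    # Second differences divided by the squared spacing approximate d2L/da2 and d2L/db2
    d2_alpha = np.diff(losses, 2, axis=0) / step_alpha**2
    d2_beta = np.diff(losses, 2, axis=1) / step_beta**2

    # A first difference along each axis approximates the mixed derivative d2L/(da db)
    d_alpha = np.diff(losses, 1, axis=0) / step_alpha
    d2_mixed = np.diff(d_alpha, 1, axis=1) / step_beta

    return float(d2_alpha.mean()), float(d2_beta.mean()), float(d2_mixed.mean())


# Synthetic loss L(a, b) = 2 a^2 + 0.5 b^2 + 0.3 a b
# Analytic curvature: d2L/da2 = 4.0, d2L/db2 = 1.0, d2L/(da db) = 0.3
alphas = np.linspace(-1.0, 1.0, 21)
betas = np.linspace(-1.0, 1.0, 21)
A, B = np.meshgrid(alphas, betas, indexing="ij")  # A[i, j] = alphas[i], B[i, j] = betas[j]
losses = 2.0 * A**2 + 0.5 * B**2 + 0.3 * A * B

curv_alpha, curv_beta, curv_mixed = finite_difference_curvature(losses, alphas, betas)
print(f"alpha: {curv_alpha:.4f}, beta: {curv_beta:.4f}, mixed: {curv_mixed:.4f}")
```

Without the division by the grid spacing, the reported values would shrink as the grid resolution increases, so curvature numbers from grids of different resolution would not be comparable.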

## 5. Advanced Visualization Techniques

### 5.1 Activation Pattern Analysis
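
The analyzer in this section is built on forward hooks. As a minimal sketch of the mechanism, the example below registers hooks on a tiny `nn.Sequential` model, captures each leaf module's output, and flags ReLU units that output zero for every sample in a batch. The toy model, the `captured` dictionary, the `make_hook` helper, and the deliberately "killed" hidden unit are assumptions made for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))

# Force hidden unit 0 to be dead: zero weights and a negative bias,
# so its pre-activation is always -1 and ReLU always outputs 0.
with torch.no_grad():
    model[0].weight[0].zero_()
    model[0].bias[0] = -1.0

captured = {}


def make_hook(name):
    """Return a forward hook that stores a detached copy of the module output."""
    def hook(module, inputs, output):
        captured[name] = output.detach().clone()
    return hook


# Register hooks on leaf modules only and keep the handles for cleanup
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if len(list(module.children())) == 0
]

# Run one forward pass in eval mode, then restore the previous mode
was_training = model.training
model.eval()
with torch.no_grad():
    model(torch.randn(32, 4))
model.train(was_training)

# Remove hooks so later forward passes are not affected
for handle in handles:
    handle.remove()

relu_output = captured["1"]                 # output of the nn.ReLU module
dead_mask = (relu_output == 0).all(dim=0)   # True where a unit is zero for every sample
print(f"Captured layers: {sorted(captured)}")
print(f"Dead ReLU units: {dead_mask.tolist()}")
```

Keeping the hook handles and calling `remove()` on them avoids leaving hooks attached to a model that is later reused for training.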

```python
class ActivationAnalyzer:
    """Comprehensive activation pattern analysis and visualization."""
    
    def __init__(self, model):
        self.model = model
        self.activations = {}
        self.activation_stats = {}
        self.hooks = []
        self.layer_info = {}
        
        self._register_activation_hooks()
        self._analyze_model_architecture()
    
    def _register_activation_hooks(self):
        """Register hooks to capture activations from all layers."""
        
        def get_activation_hook(name):
            def hook(module, input, output):
                if isinstance(output, torch.Tensor):
                    activation_data = {
                        'tensor': output.detach().clone(),
                        'mean': float(output.mean()),
                        'std': float(output.std()),
                        'min': float(output.min()),
                        'max': float(output.max()),
                        'shape': tuple(output.shape),
                        'sparsity': float((output == 0).float().mean()),
                        'saturation': self._compute_saturation(output, module)
                    }
                    self.activations[name] = activation_data
            return hook
        
        # Register hooks for all layers
        for name, module in self.model.named_modules():
            if len(list(module.children())) == 0:  # Leaf modules only
                hook = module.register_forward_hook(get_activation_hook(name))
                self.hooks.append(hook)
                
                # Store layer information
                self.layer_info[name] = {
                    'type': module.__class__.__name__,
                    'parameters': sum(p.numel() for p in module.parameters()),
                    'trainable': sum(p.numel() for p in module.parameters() if p.requires_grad)
                }
    
    def _analyze_model_architecture(self):
        """Analyze model architecture for visualization context."""
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        
        self.model_info = {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'total_layers': len(self.layer_info),
            # Size in MB computed from each parameter's actual dtype
            'model_size_mb': sum(p.numel() * p.element_size() for p in self.model.parameters()) / (1024 * 1024)
        }
    
    def _compute_saturation(self, tensor, module):
        """Compute activation saturation for different activation functions."""
        if isinstance(module, nn.ReLU):
            # For ReLU, saturation is the fraction of zero activations
            return float((tensor == 0).float().mean())
        elif isinstance(module, nn.Sigmoid):
            # For Sigmoid, saturation is fraction near 0 or 1
            threshold = 0.01
            saturated = ((tensor < threshold) | (tensor > 1 - threshold)).float()
            return float(saturated.mean())
        elif isinstance(module, nn.Tanh):
            # For Tanh, saturation is fraction near -1 or 1
            threshold = 0.02
            saturated = ((tensor < -1 + threshold) | (tensor > 1 - threshold)).float()
            return float(saturated.mean())
        else:
            return 0.0
    
    def capture_activations(self, input_data):
        """Capture activations for given input data."""
        self.activations.clear()
        
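        # Run in eval mode, then restore whatever mode the model was in
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            output = self.model(input_data)
        self.model.train(was_training)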
        
        return output
    
    def create_comprehensive_activation_analysis(self, input_data, figsize=(20, 16)):
        """Create comprehensive activation analysis dashboard."""
        
        # Capture activations
        model_output = self.capture_activations(input_data)
        
        # Create dashboard
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(4, 4, hspace=0.4, wspace=0.3)
        
        # 1. Activation Flow Overview
        ax1 = fig.add_subplot(gs[0, :2])
        self._plot_activation_flow_overview(ax1)
        
        # 2. Activation Statistics Summary
        ax2 = fig.add_subplot(gs[0, 2:])
        self._plot_activation_statistics_summary(ax2)
        
        # 3. Layer-wise Activation Distributions
        ax3 = fig.add_subplot(gs[1, :2])
        self._plot_activation_distributions(ax3)
        
        # 4. Activation Saturation Analysis
        ax4 = fig.add_subplot(gs[1, 2:])
        self._plot_saturation_analysis(ax4)
        
        # 5. Activation Correlation Matrix
        ax5 = fig.add_subplot(gs[2, :2])
        self._plot_activation_correlation_matrix(ax5)
        
        # 6. Dead Neuron Detection
        ax6 = fig.add_subplot(gs[2, 2:])
        self._plot_dead_neuron_analysis(ax6)
        
        # 7. Activation Range Analysis
        ax7 = fig.add_subplot(gs[3, :2])
        self._plot_activation_range_analysis(ax7)
        
        # 8. Model Architecture Summary
        ax8 = fig.add_subplot(gs[3, 2:])
        self._plot_model_architecture_summary(ax8)
        
        plt.suptitle('Comprehensive Activation Pattern Analysis', fontsize=18, fontweight='bold')
        
        # Save analysis
        plt.savefig(results_dir / 'activation_pattern_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        return fig, self._compile_activation_analysis()
    
    def _plot_activation_flow_overview(self, ax):
        """Plot overview of activation flow through the network."""
        if not self.activations:
            ax.text(0.5, 0.5, 'No activation data', ha='center', va='center')
            return
        
        layer_names = list(self.activations.keys())
        activation_means = [self.activations[name]['mean'] for name in layer_names]
        activation_stds = [self.activations[name]['std'] for name in layer_names]
        
        x = np.arange(len(layer_names))
        width = 0.35
        
        # Create grouped bar chart
        bars1 = ax.bar(x - width/2, activation_means, width, label='Mean Activation', 
                      alpha=0.8, color='skyblue')
        bars2 = ax.bar(x + width/2, activation_stds, width, label='Std Deviation', 
                      alpha=0.8, color='lightcoral')
        
        # Add value labels
        for bars in [bars1, bars2]:
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + 0.01 * max(activation_means + activation_stds),
                       f'{height:.3f}', ha='center', va='bottom', fontsize=8)
        
        ax.set_xlabel('Layers')
        ax.set_ylabel('Activation Values')
        ax.set_title('Activation Flow Through Network', fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([name[:10] + '...' if len(name) > 10 else name for name in layer_names], 
                          rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    def _plot_activation_statistics_summary(self, ax):
        """Plot summary of activation statistics."""
        if not self.activations:
            ax.text(0.5, 0.5, 'No activation statistics', ha='center', va='center')
            return
        
        # Collect statistics
        stats_data = {
            'Mean': [act['mean'] for act in self.activations.values()],
            'Std': [act['std'] for act in self.activations.values()],
            'Min': [act['min'] for act in self.activations.values()],
            'Max': [act['max'] for act in self.activations.values()],
            'Sparsity': [act['sparsity'] for act in self.activations.values()]
        }
        
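        # Create box plots (tick labels are set separately because the `labels`
        # keyword of boxplot is deprecated in newer Matplotlib releases)
        bp = ax.boxplot([stats_data[key] for key in stats_data.keys()], patch_artist=True)
        ax.set_xticklabels(list(stats_data.keys()))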
        
        # Color the boxes
        colors = ['lightblue', 'lightcoral', 'lightgreen', 'gold', 'plum']
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
        
        ax.set_ylabel('Values')
        ax.set_title('Activation Statistics Distribution', fontweight='bold')
        ax.grid(True, alpha=0.3)
        
        # Add median values as text
        for i, (key, values) in enumerate(stats_data.items()):
            median_val = np.median(values)
            ax.text(i + 1, median_val, f'{median_val:.3f}', ha='center', va='bottom',
                   fontweight='bold', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    def _plot_activation_distributions(self, ax):
        """Plot activation value distributions for each layer."""
        if not self.activations:
            ax.text(0.5, 0.5, 'No activation data', ha='center', va='center')
            return
        
        # Sample a few representative layers
        layer_names = list(self.activations.keys())
        sample_layers = layer_names[::max(1, len(layer_names)//4)][:4]
        
        colors = plt.cm.Set3(np.linspace(0, 1, len(sample_layers)))
        
        for i, (layer_name, color) in enumerate(zip(sample_layers, colors)):
            activation_tensor = self.activations[layer_name]['tensor']
            flat_activations = activation_tensor.flatten().cpu().numpy()
            
            # Sample for visualization if too large
            if len(flat_activations) > 10000:
                flat_activations = np.random.choice(flat_activations, 10000, replace=False)
            
            ax.hist(flat_activations, bins=50, alpha=0.6, label=layer_name[:15], 
                   color=color, density=True)
        
        ax.set_xlabel('Activation Values')
        ax.set_ylabel('Density')
        ax.set_title('Activation Value Distributions by Layer', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    def _plot_saturation_analysis(self, ax):
        """Plot activation saturation analysis."""
        if not self.activations:
            ax.text(0.5, 0.5, 'No saturation data', ha='center', va='center')
            return
        
        layer_names = list(self.activations.keys())
        saturations = [self.activations[name]['saturation'] for name in layer_names]
        layer_types = [self.layer_info[name]['type'] for name in layer_names]
        
        # Color by layer type
        type_colors = {'Linear': 'blue', 'ReLU': 'red', 'Sigmoid': 'green', 
                      'Tanh': 'orange', 'Conv2d': 'purple'}
        colors = [type_colors.get(layer_type, 'gray') for layer_type in layer_types]
        
        bars = ax.bar(range(len(saturations)), saturations, color=colors, alpha=0.7)
        
        # Add threshold line
        ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.8, 
                  label='High Saturation Threshold')
        
        # Add value labels
        for bar, saturation in zip(bars, saturations):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                   f'{saturation:.2f}', ha='center', va='bottom', fontsize=8)
        
        ax.set_xlabel('Layers')
        ax.set_ylabel('Saturation Level')
        ax.set_title('Activation Saturation Analysis', fontweight='bold')
        ax.set_xticks(range(len(layer_names)))
        ax.set_xticklabels([f"{name[:8]}\n({self.layer_info[name]['type']})" 
                           for name in layer_names], rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # Add warning for high saturation
        high_saturation_layers = [name for name, sat in zip(layer_names, saturations) if sat > 0.5]
        if high_saturation_layers:
            ax.text(0.02, 0.98, f'⚠️ High saturation in:\n{", ".join(high_saturation_layers[:3])}', 
                   transform=ax.transAxes, fontsize=10, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.8))
    
    def _plot_activation_correlation_matrix(self, ax):
        """Plot correlations between per-layer summary statistics (mean, std, sparsity)."""
        if len(self.activations) < 2:
            ax.text(0.5, 0.5, 'Need at least 2 layers for correlation', ha='center', va='center')
            return
        
        # Collect per-layer summary statistics; each layer is one observation
        layer_names = list(self.activations.keys())
        activation_means = np.array([self.activations[name]['mean'] for name in layer_names])
        activation_stds = np.array([self.activations[name]['std'] for name in layer_names])
        activation_sparsity = np.array([self.activations[name]['sparsity'] for name in layer_names])
        
        # Stack features for correlation analysis
        features = np.column_stack([activation_means, activation_stds, activation_sparsity])
        correlation_matrix = np.corrcoef(features.T)
        
        # Create heatmap
        im = ax.imshow(correlation_matrix, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
        
        # Add labels
        feature_labels = ['Mean', 'Std', 'Sparsity']
        ax.set_xticks(range(len(feature_labels)))
        ax.set_yticks(range(len(feature_labels)))
        ax.set_xticklabels(feature_labels)
        ax.set_yticklabels(feature_labels)
        
        # Add correlation values
        for i in range(len(feature_labels)):
            for j in range(len(feature_labels)):
                ax.text(j, i, f'{correlation_matrix[i, j]:.2f}',
                        ha="center", va="center", color="black", fontweight='bold')
        
        ax.set_title('Activation Feature Correlation Matrix', fontweight='bold')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    
    def _plot_dead_neuron_analysis(self, ax):
        """Plot dead neuron detection analysis."""
        if not self.activations:
            ax.text(0.5, 0.5, 'No activation data', ha='center', va='center')
            return
        
        layer_names = []
        dead_neuron_ratios = []
        
        for name, activation in self.activations.items():
            if activation['tensor'].dim() >= 2:
                # Flag neurons whose output is (near-)constant across the batch;
                # this includes ReLU units stuck at zero
                tensor = activation['tensor']
                
                if tensor.dim() == 2:  # Fully connected layer
                    neuron_vars = tensor.var(dim=0)
                    dead_threshold = 1e-6
                    dead_neurons = (neuron_vars < dead_threshold).float().mean()
                elif tensor.dim() == 4:  # Convolutional layer
                    # Average over spatial dimensions, then check variance across batch
                    spatial_avg = tensor.mean(dim=(2, 3))
                    neuron_vars = spatial_avg.var(dim=0)
                    dead_threshold = 1e-6
                    dead_neurons = (neuron_vars < dead_threshold).float().mean()
                else:
                    continue
                
                layer_names.append(name)
                dead_neuron_ratios.append(float(dead_neurons))
        
        if layer_names:
            bars = ax.bar(range(len(dead_neuron_ratios)), dead_neuron_ratios, 
                         color='darkred', alpha=0.7)
            
            # Add threshold line
            ax.axhline(y=0.1, color='orange', linestyle='--', alpha=0.8, 
                      label='Concerning Threshold (10%)')
            
            # Add value labels
            for bar, ratio in zip(bars, dead_neuron_ratios):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + 0.005,
                       f'{ratio:.3f}', ha='center', va='bottom', fontsize=8)
            
            ax.set_xlabel('Layers')
            ax.set_ylabel('Dead Neuron Ratio')
            ax.set_title('Dead Neuron Detection', fontweight='bold')
            ax.set_xticks(range(len(layer_names)))
            ax.set_xticklabels([name[:10] for name in layer_names], rotation=45, ha='right')
            ax.legend()
            ax.grid(True, alpha=0.3)
            
            # Add summary
            avg_dead_ratio = np.mean(dead_neuron_ratios)
            ax.text(0.98, 0.98, f'Avg Dead Ratio: {avg_dead_ratio:.3f}', 
                   transform=ax.transAxes, fontsize=12, fontweight='bold',
                   verticalalignment='top', horizontalalignment='right',
                   bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
        else:
            ax.text(0.5, 0.5, 'No suitable layers for\ndead neuron analysis', 
                   ha='center', va='center')
            ax.set_title('Dead Neuron Analysis')
    
    def _plot_activation_range_analysis(self, ax):
        """Plot activation range analysis across layers."""
        if not self.activations:
            ax.text(0.5, 0.5, 'No activation data', ha='center', va='center')
            return
        
        layer_names = list(self.activations.keys())
        min_values = [self.activations[name]['min'] for name in layer_names]
        max_values = [self.activations[name]['max'] for name in layer_names]
        ranges = [max_val - min_val for min_val, max_val in zip(min_values, max_values)]
        
        x = np.arange(len(layer_names))
        
        # Create range plot
        ax.bar(x, ranges, alpha=0.7, color='lightgreen', label='Activation Range')
        
        # Add min/max markers
        ax.scatter(x, min_values, color='blue', s=50, alpha=0.8, label='Min Values', marker='v')
        ax.scatter(x, max_values, color='red', s=50, alpha=0.8, label='Max Values', marker='^')
        
        # Connect min and max with lines
        for i, (min_val, max_val) in enumerate(zip(min_values, max_values)):
            ax.plot([i, i], [min_val, max_val], 'k-', alpha=0.3, linewidth=1)
        
        ax.set_xlabel('Layers')
        ax.set_ylabel('Activation Values')
        ax.set_title('Activation Range Analysis', fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([name[:10] for name in layer_names], rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # Add range statistics
        mean_range = np.mean(ranges)
        std_range = np.std(ranges)
        ax.text(0.02, 0.98, f'Mean Range: {mean_range:.3f}\nStd Range: {std_range:.3f}', 
               transform=ax.transAxes, fontsize=10, verticalalignment='top',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    def _plot_model_architecture_summary(self, ax):
        """Plot model architecture summary."""
        summary_text = f"""
🧠 MODEL ARCHITECTURE SUMMARY

📊 Model Statistics:
  Total Parameters: {self.model_info['total_parameters']:,}
  Trainable Parameters: {self.model_info['trainable_parameters']:,}
  Total Layers: {self.model_info['total_layers']}
  Model Size: {self.model_info['model_size_mb']:.2f} MB

🔍 Activation Analysis:
  Layers Analyzed: {len(self.activations)}
  
📈 Health Indicators:
  • Check activation ranges
  • Monitor saturation levels
  • Watch for dead neurons
  • Analyze activation flow

🎯 Optimization Tips:
  • Normalize inputs appropriately
  • Use proper initialization
  • Consider different activations
  • Monitor gradient flow
        """
        
        ax.text(0.05, 0.95, summary_text, transform=ax.transAxes, fontsize=10,
               verticalalignment='top', horizontalalignment='left',
               bbox=dict(boxstyle='round', facecolor='lightcyan', alpha=0.8))
        
        ax.set_title('Model & Analysis Summary', fontweight='bold')
        ax.axis('off')
    
    def _compile_activation_analysis(self):
        """Compile comprehensive activation analysis results."""
        if not self.activations:
            return {}
        
        results = {
            'model_info': self.model_info,
            'activation_summary': {
                'total_layers_analyzed': len(self.activations),
                'mean_activation': np.mean([act['mean'] for act in self.activations.values()]),
                'mean_sparsity': np.mean([act['sparsity'] for act in self.activations.values()]),
                'mean_saturation': np.mean([act['saturation'] for act in self.activations.values()]),
            },
            'layer_analysis': {}
        }
        
        # Per-layer analysis
        for name, activation in self.activations.items():
            results['layer_analysis'][name] = {
                'layer_type': self.layer_info[name]['type'],
                'activation_stats': {
                    'mean': activation['mean'],
                    'std': activation['std'],
                    'min': activation['min'],
                    'max': activation['max'],
                    'sparsity': activation['sparsity'],
                    'saturation': activation['saturation']
                },
                'shape': activation['shape'],
                'parameters': self.layer_info[name]['parameters']
            }
        
        # Health metrics
        saturations = [act['saturation'] for act in self.activations.values()]
        sparsities = [act['sparsity'] for act in self.activations.values()]
        
        results['health_metrics'] = {
            'high_saturation_layers': sum(1 for s in saturations if s > 0.5),
            'high_sparsity_layers': sum(1 for s in sparsities if s > 0.8),
            'average_saturation': np.mean(saturations),
            'average_sparsity': np.mean(sparsities),
            'activation_health_score': self._compute_activation_health_score()
        }
        
        return results
    
    def _compute_activation_health_score(self):
        """Compute overall activation health score."""
        if not self.activations:
            return 0.0
        
        # Factors for health score
        saturations = [act['saturation'] for act in self.activations.values()]
        sparsities = [act['sparsity'] for act in self.activations.values()]
        ranges = [act['max'] - act['min'] for act in self.activations.values()]
        
        # Penalize high saturation and extreme sparsity
        saturation_penalty = np.mean([min(s * 2, 1.0) for s in saturations])
        sparsity_penalty = np.mean([min((s - 0.5) * 2, 1.0) if s > 0.5 else 0 for s in sparsities])
        
        # Reward reasonable activation ranges
        range_score = np.mean([min(r / 10, 1.0) for r in ranges])
        
        # Compute overall health score (0-100)
        health_score = (1 - saturation_penalty - sparsity_penalty + range_score) / 2 * 100
        return max(0, min(100, health_score))
    
    def save_analysis(self, results):
        """Save activation analysis to JSON."""
        with open(results_dir / 'activation_pattern_analysis.json', 'w') as f:
            json.dump(results, f, indent=2)
        
        print(f"💾 Activation analysis saved to {results_dir / 'activation_pattern_analysis.json'}")
        return results
    
    def cleanup(self):
        """Remove hooks to prevent memory leaks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

def demonstrate_activation_analysis():
    """Demonstrate activation pattern analysis with different models."""
    print("\n🔬 Activation Pattern Analysis Demonstrations")
    print("=" * 50)
    
    # Create sample data
    sample_input = torch.randn(10, 16)
    
    # Example 1: Healthy Network
    print("\n✅ Example 1: Healthy Network Activations")
    
    class HealthyActivationNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(16, 32),
                nn.ReLU(),
                nn.Linear(32, 24),
                nn.ReLU(),
                nn.Linear(24, 16),
                nn.ReLU(),
                nn.Linear(16, 8)
            )
        
        def forward(self, x):
            return self.layers(x)
    
    healthy_net = HealthyActivationNet()
    healthy_analyzer = ActivationAnalyzer(healthy_net)
    
    healthy_fig, healthy_results = healthy_analyzer.create_comprehensive_activation_analysis(sample_input)
    healthy_analyzer.save_analysis(healthy_results)
    
    print(f"Healthy Network Activation Health Score: {healthy_results['health_metrics']['activation_health_score']:.1f}")
    
    # Example 2: Problematic Network
    print("\n⚠️ Example 2: Problematic Network Activations")
    
    class ProblematicActivationNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(16, 64),
                nn.Sigmoid(),  # Can cause saturation
                nn.Linear(64, 64),
                nn.Sigmoid(),
                nn.Linear(64, 32),
                nn.Sigmoid(),
                nn.Linear(32, 8)
            )
            
            # Initialize with large weights to cause saturation
            for layer in self.layers:
                if isinstance(layer, nn.Linear):
                    nn.init.normal_(layer.weight, 0, 2.0)
        
        def forward(self, x):
            return self.layers(x)
    
    problem_net = ProblematicActivationNet()
    problem_analyzer = ActivationAnalyzer(problem_net)
    
    problem_fig, problem_results = problem_analyzer.create_comprehensive_activation_analysis(sample_input)
    problem_analyzer.save_analysis(problem_results)
    
    print(f"Problematic Network Activation Health Score: {problem_results['health_metrics']['activation_health_score']:.1f}")
    
    # Comparison
    print("\n📊 Activation Analysis Comparison:")
    print("Healthy Network:")
    print(f"  - Health Score: {healthy_results['health_metrics']['activation_health_score']:.1f}")
    print(f"  - Average Saturation: {healthy_results['health_metrics']['average_saturation']:.3f}")
    print(f"  - Average Sparsity: {healthy_results['health_metrics']['average_sparsity']:.3f}")
    
    print("\nProblematic Network:")
    print(f"  - Health Score: {problem_results['health_metrics']['activation_health_score']:.1f}")
    print(f"  - Average Saturation: {problem_results['health_metrics']['average_saturation']:.3f}")
    print(f"  - Average Sparsity: {problem_results['health_metrics']['average_sparsity']:.3f}")
    
    # Cleanup
    healthy_analyzer.cleanup()
    problem_analyzer.cleanup()
    
    return [healthy_analyzer, problem_analyzer], [healthy_results, problem_results]

# Run activation analysis demonstrations
activation_analyzers, activation_results = demonstrate_activation_analysis()

print("\n💡 Key Activation Analysis Insights:")
print("• Activation patterns reveal network health")
print("• Saturation analysis identifies problematic layers")
print("• Dead neuron detection prevents wasted capacity")
print("• Range analysis guides initialization strategies")
print("• Health scores provide quantitative assessment")
```
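
The health score printed by the demonstrations combines three per-layer quantities: a saturation penalty, a sparsity penalty, and a range reward. The sketch below restates that arithmetic as a standalone function so it can be checked by hand. `activation_health_score` and the sample dictionaries are illustrative names and made-up values, not part of the analyzer class.

```python
from statistics import mean


def activation_health_score(layer_stats):
    """Return a score in [0, 100] from per-layer activation statistics.

    Each entry is a dict with 'saturation', 'sparsity', 'min', and 'max'
    (hypothetical inputs that mirror the fields the analyzer stores).
    """
    if not layer_stats:
        return 0.0

    # Saturation is penalized linearly and reaches the full penalty at 50%.
    saturation_penalty = mean(min(s['saturation'] * 2, 1.0) for s in layer_stats)

    # Sparsity is only penalized above 50% and reaches the full penalty at 100%.
    sparsity_penalty = mean(
        min((s['sparsity'] - 0.5) * 2, 1.0) if s['sparsity'] > 0.5 else 0.0
        for s in layer_stats
    )

    # An activation range of 10 or more earns the full range reward.
    range_score = mean(min((s['max'] - s['min']) / 10, 1.0) for s in layer_stats)

    score = (1 - saturation_penalty - sparsity_penalty + range_score) / 2 * 100
    return max(0.0, min(100.0, score))


# Made-up statistics for a single layer in each case.
healthy = [{'saturation': 0.0, 'sparsity': 0.5, 'min': 0.0, 'max': 5.0}]
saturated = [{'saturation': 0.6, 'sparsity': 0.0, 'min': 0.0, 'max': 1.0}]

print(activation_health_score(healthy))              # 75.0
print(round(activation_health_score(saturated), 1))  # 5.0
```

The saturated example shows why sigmoid layers with large initial weights tend to score poorly: their outputs are confined to (0, 1), so the range reward stays small while the saturation penalty is already at its maximum.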

## 6. Training Dynamics Monitoring
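
One quantity tracked throughout this section is the global gradient norm: the L2 norm of all parameter gradients taken together. It can be assembled from per-parameter norms, because the squared norm of a concatenated vector equals the sum of the squared norms of its pieces. The plain-Python sketch below checks that identity on made-up gradient values. `l2_norm`, `global_grad_norm`, and the `grads` dictionary are hypothetical names used only for illustration.

```python
import math


def l2_norm(values):
    """L2 norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def global_grad_norm(grads_by_param):
    """Combine per-parameter L2 norms into one global L2 norm."""
    return math.sqrt(sum(l2_norm(g) ** 2 for g in grads_by_param.values()))


# Made-up gradients for two parameters of a tiny model.
grads = {
    "layers.0.weight": [3.0, 0.0],
    "layers.0.bias": [4.0],
}

# Flattening everything into one vector gives the same norm.
flat = [v for g in grads.values() for v in g]

print(global_grad_norm(grads))  # 5.0
print(l2_norm(flat))            # 5.0
```

This is the same quantity that `torch.nn.utils.clip_grad_norm_` returns with its default `norm_type=2`, which makes it a convenient reference point when choosing a clipping threshold.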

### 6.1 Real-Time Training Visualizer
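
For Adam, the monitor below records a scalar "effective learning rate" that folds the bias-correction terms of the Adam update into the base rate: `lr * sqrt(1 - beta2**t) / (1 - beta1**t)`. The sketch below evaluates that factor at a few step counts, assuming the default betas `(0.9, 0.999)`. `adam_bias_correction_factor` is an illustrative helper, and it deliberately ignores the per-coordinate division by the second-moment estimate, so it is only a rough proxy for the actual step size.

```python
import math


def adam_bias_correction_factor(step, beta1=0.9, beta2=0.999):
    """Scalar factor sqrt(1 - beta2^t) / (1 - beta1^t) applied to the base LR."""
    if step < 1:
        raise ValueError("step must be >= 1")
    return math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)


base_lr = 1e-3  # assumed base learning rate, for illustration only
for step in (1, 10, 100, 1000, 10000):
    factor = adam_bias_correction_factor(step)
    print(f"step={step:>5}  factor={factor:.4f}  effective_lr={base_lr * factor:.2e}")
```

The factor starts near 0.32 at the first step and approaches 1 as the step count grows, so the recorded value converges to the base learning rate late in training.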

```python
class TrainingDynamicsMonitor:
    """Comprehensive real-time training dynamics monitoring and visualization."""
    
    def __init__(self, model, train_loader, val_loader, optimizer, loss_fn):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        
        # Training history
        self.training_history = {
            'epoch': [],
            'train_loss': [],
            'val_loss': [],
            'learning_rate': [],
            'grad_norm': [],
            'param_norm': [],
            'weight_updates': [],
            'batch_times': [],
            'memory_usage': []
        }
        
        # Per-layer monitoring
        self.layer_history = defaultdict(lambda: {
            'weight_norms': [],
            'grad_norms': [],
            'weight_changes': [],
            'learning_rates': []
        })
        
        # Real-time metrics
        self.current_metrics = {}
        self.optimization_health = {}
        
    def monitor_training_epoch(self, epoch, max_epochs):
        """Monitor one epoch of training with comprehensive metrics."""
        
        # Training phase
        train_metrics = self._monitor_training_phase(epoch)
        
        # Validation phase
        val_metrics = self._monitor_validation_phase(epoch)
        
        # Update histories
        self._update_training_history(epoch, train_metrics, val_metrics)
        
        # Compute optimization health metrics
        self._compute_optimization_health(epoch)
        
        # Print progress
        self._print_epoch_summary(epoch, max_epochs, train_metrics, val_metrics)
        
        return train_metrics, val_metrics
    
    def _monitor_training_phase(self, epoch):
        """Monitor training phase with detailed metrics."""
        self.model.train()
        
        batch_losses = []
        batch_times = []
        grad_norms = []
        memory_usage = 0.0  # Default so the metric exists even if the loader yields no batches
        
        epoch_start_time = time.time()
        
        for batch_idx, (data, target) in enumerate(self.train_loader):
            batch_start_time = time.time()
            
            # Reset gradients, then run the forward pass
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss_fn(output, target)
            
            # Backward pass
            loss.backward()
            
            # Capture gradients before update
            grad_norm = self._compute_gradient_norm()
            grad_norms.append(grad_norm)
            
            # Monitor per-layer metrics
            self._monitor_layer_metrics(epoch, batch_idx)
            
            # Optimizer step
            self.optimizer.step()
            
            # Track metrics
            batch_losses.append(loss.item())
            batch_times.append(time.time() - batch_start_time)
            
            # Memory usage tracking
            if torch.cuda.is_available():
                memory_usage = torch.cuda.memory_allocated() / 1024**2  # MB
            else:
                memory_usage = 0
        
        # Compile training metrics
        train_metrics = {
            'avg_loss': np.mean(batch_losses),
            'loss_std': np.std(batch_losses),
            'avg_grad_norm': np.mean(grad_norms),
            'grad_norm_std': np.std(grad_norms),
            'avg_batch_time': np.mean(batch_times),
            'total_time': time.time() - epoch_start_time,
            'memory_usage': memory_usage,
            'learning_rate': self._get_current_lr()
        }
        
        return train_metrics
    
    def _monitor_validation_phase(self, epoch):
        """Monitor validation phase."""
        self.model.eval()
        
        val_losses = []
        
        with torch.no_grad():
            for data, target in self.val_loader:
                output = self.model(data)
                loss = self.loss_fn(output, target)
                val_losses.append(loss.item())
        
        val_metrics = {
            'avg_loss': np.mean(val_losses),
            'loss_std': np.std(val_losses)
        }
        
        return val_metrics
    
    def _monitor_layer_metrics(self, epoch, batch_idx):
        """Monitor per-layer training metrics."""
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                # Current norms
                weight_norm = param.norm().item()
                grad_norm = param.grad.norm().item()
                
                # Store in layer history (sample every N batches to avoid memory issues)
                if batch_idx % 10 == 0:  # Sample every 10 batches
                    self.layer_history[name]['weight_norms'].append(weight_norm)
                    self.layer_history[name]['grad_norms'].append(grad_norm)
                    
                    # Effective learning rate for this parameter
                    if isinstance(self.optimizer, torch.optim.Adam):
                        # For Adam, scale the base LR by the bias-correction factor.
                        # Betas are read from the optimizer instead of being hardcoded, and
                        # `step` is cast to float because newer PyTorch stores it as a tensor.
                        state = self.optimizer.state.get(param, {})
                        step = float(state.get('step', 0))
                        if step > 0:
                            beta1, beta2 = self.optimizer.param_groups[0]['betas']
                            bias_correction1 = 1 - beta1 ** step
                            bias_correction2 = 1 - beta2 ** step
                            effective_lr = self._get_current_lr() * np.sqrt(bias_correction2) / bias_correction1
                        else:
                            effective_lr = self._get_current_lr()
                    else:
                        effective_lr = self._get_current_lr()
                    
                    self.layer_history[name]['learning_rates'].append(effective_lr)
    
    def _compute_gradient_norm(self):
        """Compute total gradient norm across all parameters."""
        total_norm = 0
        for param in self.model.parameters():
            if param.grad is not None:
                total_norm += param.grad.norm().item() ** 2
        return np.sqrt(total_norm)
    
    def _get_current_lr(self):
        """Get the learning rate of the first parameter group (0.0 if there are none)."""
        param_groups = self.optimizer.param_groups
        return param_groups[0]['lr'] if param_groups else 0.0
    
    def _update_training_history(self, epoch, train_metrics, val_metrics):
        """Update training history with current epoch metrics."""
        self.training_history['epoch'].append(epoch)
        self.training_history['train_loss'].append(train_metrics['avg_loss'])
        self.training_history['val_loss'].append(val_metrics['avg_loss'])
        self.training_history['learning_rate'].append(train_metrics['learning_rate'])
        self.training_history['grad_norm'].append(train_metrics['avg_grad_norm'])
        self.training_history['batch_times'].append(train_metrics['avg_batch_time'])
        self.training_history['memory_usage'].append(train_metrics['memory_usage'])
        
        # Compute parameter norm
        param_norm = sum(p.norm().item() ** 2 for p in self.model.parameters()) ** 0.5
        self.training_history['param_norm'].append(param_norm)
    
    def _compute_optimization_health(self, epoch):
        """Compute optimization health metrics."""
        if len(self.training_history['train_loss']) < 2:
            return
        
        # Loss improvement rate
        recent_losses = self.training_history['train_loss'][-5:]
        if len(recent_losses) >= 2:
            loss_improvement = (recent_losses[0] - recent_losses[-1]) / max(recent_losses[0], 1e-8)
        else:
            loss_improvement = 0
        
        # Gradient stability
        recent_grad_norms = self.training_history['grad_norm'][-5:]
        grad_stability = 1.0 / (1.0 + np.std(recent_grad_norms)) if recent_grad_norms else 0
        
        # Learning rate appropriateness (based on loss oscillation)
        if len(self.training_history['train_loss']) >= 3:
            loss_changes = np.diff(self.training_history['train_loss'][-10:])
            oscillation = np.mean(np.abs(loss_changes))
            lr_appropriateness = 1.0 / (1.0 + oscillation * 10)
        else:
            lr_appropriateness = 0.5
        
        # Overfitting indicator: how far validation loss exceeds training loss
        if len(self.training_history['val_loss']) >= 2:
            val_train_gap = self.training_history['val_loss'][-1] - self.training_history['train_loss'][-1]
            overfitting_score = max(0, min(1, val_train_gap))
        else:
            overfitting_score = 0
        
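        # Overall health score, clamped to [0, 100]: loss_improvement is unbounded below
        # (it goes negative when the loss rises), which could push the raw score out of range
        health_score = (loss_improvement + grad_stability + lr_appropriateness + (1 - overfitting_score)) / 4 * 100
        health_score = max(0.0, min(100.0, health_score))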
        
        self.optimization_health = {
            'loss_improvement': loss_improvement,
            'grad_stability': grad_stability,
            'lr_appropriateness': lr_appropriateness,
            'overfitting_score': overfitting_score,
            'overall_health': health_score
        }
    
    def _print_epoch_summary(self, epoch, max_epochs, train_metrics, val_metrics):
        """Print comprehensive epoch summary."""
        health = self.optimization_health.get('overall_health', 0)
        
        print(f"Epoch {epoch+1}/{max_epochs}:")
        print(f"  Train Loss: {train_metrics['avg_loss']:.6f} ± {train_metrics['loss_std']:.6f}")
        print(f"  Val Loss:   {val_metrics['avg_loss']:.6f} ± {val_metrics['loss_std']:.6f}")
        print(f"  Grad Norm:  {train_metrics['avg_grad_norm']:.4f}")
        print(f"  Health:     {health:.1f}% {'✅' if health > 70 else '⚠️' if health > 40 else '❌'}")
        print(f"  Time:       {train_metrics['total_time']:.2f}s")
        print()
    
    def create_realtime_dashboard(self, figsize=(20, 16)):
        """Create comprehensive real-time training dashboard."""
        if not self.training_history['epoch']:
            print("No training history available")
            return None
        
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(4, 4, hspace=0.4, wspace=0.3)
        
        # 1. Loss Curves
        ax1 = fig.add_subplot(gs[0, :2])
        self._plot_loss_curves(ax1)
        
        # 2. Learning Rate Schedule
        ax2 = fig.add_subplot(gs[0, 2:])
        self._plot_learning_rate_schedule(ax2)
        
        # 3. Gradient Norms
        ax3 = fig.add_subplot(gs[1, :2])
        self._plot_gradient_norms(ax3)
        
        # 4. Parameter Norms
        ax4 = fig.add_subplot(gs[1, 2:])
        self._plot_parameter_norms(ax4)
        
        # 5. Training Speed Metrics
        ax5 = fig.add_subplot(gs[2, :2])
        self._plot_training_speed(ax5)
        
        # 6. Optimization Health (polar axes so the radar chart actually renders as a radar)
        ax6 = fig.add_subplot(gs[2, 2:], projection='polar')
        self._plot_optimization_health(ax6)
        
        # 7. Layer-wise Analysis
        ax7 = fig.add_subplot(gs[3, :2])
        self._plot_layerwise_analysis(ax7)
        
        # 8. Training Summary
        ax8 = fig.add_subplot(gs[3, 2:])
        self._plot_training_summary(ax8)
        
        plt.suptitle('Real-Time Training Dynamics Dashboard', fontsize=18, fontweight='bold')
        
        # Save dashboard
        plt.savefig(results_dir / 'training_dynamics_dashboard.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        return fig
    
    def _plot_loss_curves(self, ax):
        """Plot training and validation loss curves."""
        epochs = self.training_history['epoch']
        train_losses = self.training_history['train_loss']
        val_losses = self.training_history['val_loss']
        
        ax.plot(epochs, train_losses, 'b-', linewidth=2, label='Training Loss', marker='o', markersize=4)
        ax.plot(epochs, val_losses, 'r-', linewidth=2, label='Validation Loss', marker='s', markersize=4)
        
        # Add trend lines
        if len(epochs) > 3:
            # Exponential smoothing
            train_smooth = self._exponential_smoothing(train_losses, alpha=0.3)
            val_smooth = self._exponential_smoothing(val_losses, alpha=0.3)
            
            ax.plot(epochs, train_smooth, 'b--', alpha=0.7, linewidth=1, label='Train Trend')
            ax.plot(epochs, val_smooth, 'r--', alpha=0.7, linewidth=1, label='Val Trend')
        
        # Mark best validation loss
        if val_losses:
            best_val_idx = np.argmin(val_losses)
            ax.plot(epochs[best_val_idx], val_losses[best_val_idx], 'g*', markersize=15, 
                   label=f'Best Val: {val_losses[best_val_idx]:.4f}')
        
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training & Validation Loss', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_yscale('log')
    
    def _plot_learning_rate_schedule(self, ax):
        """Plot learning rate schedule."""
        epochs = self.training_history['epoch']
        learning_rates = self.training_history['learning_rate']
        
        ax.plot(epochs, learning_rates, 'g-', linewidth=2, marker='o', markersize=4)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Learning Rate')
        ax.set_title('Learning Rate Schedule', fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.set_yscale('log')
        
        # Add current LR annotation
        if learning_rates:
            current_lr = learning_rates[-1]
            ax.text(0.98, 0.98, f'Current LR: {current_lr:.2e}', 
                   transform=ax.transAxes, fontsize=12, fontweight='bold',
                   verticalalignment='top', horizontalalignment='right',
                   bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))
    
    def _plot_gradient_norms(self, ax):
        """Plot gradient norm evolution."""
        epochs = self.training_history['epoch']
        grad_norms = self.training_history['grad_norm']
        
        ax.plot(epochs, grad_norms, 'purple', linewidth=2, marker='d', markersize=4)
        
        # Add gradient clipping thresholds
        ax.axhline(y=1.0, color='orange', linestyle='--', alpha=0.7, label='Typical Clip Threshold')
        ax.axhline(y=0.1, color='red', linestyle='--', alpha=0.7, label='Vanishing Threshold')
        
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Gradient Norm')
        ax.set_title('Gradient Norm Evolution', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_yscale('log')
        
        # Gradient health indicator (thresholds match the reference lines drawn above)
        if grad_norms:
            recent_grad = grad_norms[-1]
            if recent_grad > 1.0:
                status = "🔥 High"
                color = 'red'
            elif recent_grad < 0.1:
                status = "❄️ Low"
                color = 'blue'
            else:
                status = "✅ Healthy"
                color = 'green'
            
            ax.text(0.02, 0.98, f'Gradient Status: {status}', 
                   transform=ax.transAxes, fontsize=12, fontweight='bold',
                   verticalalignment='top', color=color,
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    def _plot_parameter_norms(self, ax):
        """Plot parameter norm evolution."""
        epochs = self.training_history['epoch']
        param_norms = self.training_history['param_norm']
        
        ax.plot(epochs, param_norms, 'brown', linewidth=2, marker='v', markersize=4)
        
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Parameter Norm')
        ax.set_title('Parameter Norm Evolution', fontweight='bold')
        ax.grid(True, alpha=0.3)
        
        # Parameter change rate
        if len(param_norms) > 1:
            param_changes = np.diff(param_norms)
            recent_change = param_changes[-1] if param_changes.size > 0 else 0
            change_rate = recent_change / param_norms[-1] * 100 if param_norms[-1] > 0 else 0
            
            ax.text(0.98, 0.02, f'Change Rate: {change_rate:.2f}%', 
                   transform=ax.transAxes, fontsize=10, fontweight='bold',
                   verticalalignment='bottom', horizontalalignment='right',
                   bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    
    def _plot_training_speed(self, ax):
        """Plot training speed metrics."""
        epochs = self.training_history['epoch']
        batch_times = self.training_history['batch_times']
        memory_usage = self.training_history['memory_usage']
        
        # Dual y-axis plot
        ax2 = ax.twinx()
        
        line1 = ax.plot(epochs, batch_times, 'orange', linewidth=2, marker='o', markersize=4, label='Batch Time (s)')
        line2 = ax2.plot(epochs, memory_usage, 'cyan', linewidth=2, marker='s', markersize=4, label='Memory (MB)')
        
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Batch Time (s)', color='orange')
        ax2.set_ylabel('Memory Usage (MB)', color='cyan')
        ax.set_title('Training Speed Metrics', fontweight='bold')
        
        # Combined legend
        lines = line1 + line2
        labels = [l.get_label() for l in lines]
        ax.legend(lines, labels, loc='upper left')
        
        ax.grid(True, alpha=0.3)
        
        # Speed summary
        if batch_times and memory_usage:
            avg_time = np.mean(batch_times)
            avg_memory = np.mean(memory_usage)
            
            ax.text(0.02, 0.98, f'Avg Batch Time: {avg_time:.3f}s\nAvg Memory: {avg_memory:.1f}MB', 
                   transform=ax.transAxes, fontsize=10, fontweight='bold',
                   verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    def _plot_optimization_health(self, ax):
        """Plot optimization health metrics."""
        if not self.optimization_health:
            ax.text(0.5, 0.5, 'No health data available', ha='center', va='center')
            ax.set_title('Optimization Health')
            return
        
        # Create radar chart for health metrics
        health_metrics = {
            'Loss\nImprovement': self.optimization_health['loss_improvement'],
            'Gradient\nStability': self.optimization_health['grad_stability'],
            'LR\nAppropriateness': self.optimization_health['lr_appropriateness'],
            'Anti-Overfitting': 1 - self.optimization_health['overfitting_score']
        }
        
        # Normalize metrics to [0, 1]
        normalized_metrics = {k: max(0, min(1, v)) for k, v in health_metrics.items()}
        
        # Radar chart
        angles = np.linspace(0, 2 * np.pi, len(normalized_metrics), endpoint=False)
        values = list(normalized_metrics.values())
        
        # Close the plot
        angles = np.concatenate((angles, [angles[0]]))
        values = np.concatenate((values, [values[0]]))
        
        ax.plot(angles, values, 'o-', linewidth=2, color='green', alpha=0.8)
        ax.fill(angles, values, alpha=0.25, color='green')
        
        # Add labels
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(list(normalized_metrics.keys()), fontsize=10)
        ax.set_ylim(0, 1)
        ax.set_title('Optimization Health Radar', fontweight='bold')
        ax.grid(True, alpha=0.3)
        
        # Overall health score
        overall_health = self.optimization_health['overall_health']
        color = 'green' if overall_health > 70 else 'orange' if overall_health > 40 else 'red'
        
        ax.text(0.02, 0.98, f'Overall Health: {overall_health:.1f}%', 
               transform=ax.transAxes, fontsize=14, fontweight='bold',
               verticalalignment='top', color=color,
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))
    
    def _plot_layerwise_analysis(self, ax):
        """Plot layer-wise gradient and weight analysis."""
        if not self.layer_history:
            ax.text(0.5, 0.5, 'No layer history available', ha='center', va='center')
            ax.set_title('Layer-wise Analysis')
            return
        
        # Get recent layer metrics
        layer_names = []
        recent_grad_norms = []
        recent_weight_norms = []
        
        for layer_name, history in self.layer_history.items():
            if history['grad_norms'] and history['weight_norms']:
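                # Keep the last two name components (e.g. '0.weight'); the final component
                # alone is just 'weight' or 'bias', which would make every tick label identical
                layer_names.append('.'.join(layer_name.split('.')[-2:])[:10])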
                recent_grad_norms.append(history['grad_norms'][-1])
                recent_weight_norms.append(history['weight_norms'][-1])
        
        if layer_names:
            x = np.arange(len(layer_names))
            width = 0.35
            
            # Normalize for comparison (guard against division by zero when every norm is 0)
            grad_norms_norm = np.array(recent_grad_norms) / max(max(recent_grad_norms), 1e-12)
            weight_norms_norm = np.array(recent_weight_norms) / max(max(recent_weight_norms), 1e-12)
            
            bars1 = ax.bar(x - width/2, grad_norms_norm, width, label='Grad Norm (normalized)', 
                          alpha=0.8, color='red')
            bars2 = ax.bar(x + width/2, weight_norms_norm, width, label='Weight Norm (normalized)', 
                          alpha=0.8, color='blue')
            
            ax.set_xlabel('Layers')
            ax.set_ylabel('Normalized Values')
            ax.set_title('Layer-wise Gradient & Weight Norms', fontweight='bold')
            ax.set_xticks(x)
            ax.set_xticklabels(layer_names, rotation=45, ha='right')
            ax.legend()
            ax.grid(True, alpha=0.3)
        else:
            ax.text(0.5, 0.5, 'Insufficient layer data', ha='center', va='center')
    
    def _plot_training_summary(self, ax):
        """Plot comprehensive training summary."""
        if not self.training_history['epoch']:
            ax.text(0.5, 0.5, 'No training data', ha='center', va='center')
            return
        
        # Calculate summary statistics
        current_epoch = self.training_history['epoch'][-1] if self.training_history['epoch'] else 0
        best_train_loss = min(self.training_history['train_loss']) if self.training_history['train_loss'] else 0
        best_val_loss = min(self.training_history['val_loss']) if self.training_history['val_loss'] else 0
        current_train_loss = self.training_history['train_loss'][-1] if self.training_history['train_loss'] else 0
        current_val_loss = self.training_history['val_loss'][-1] if self.training_history['val_loss'] else 0
        
        # Convergence analysis
        if len(self.training_history['train_loss']) >= 5:
            recent_losses = self.training_history['train_loss'][-5:]
            loss_stability = np.std(recent_losses) / np.mean(recent_losses) if np.mean(recent_losses) > 0 else float('inf')
            converged = loss_stability < 0.01
        else:
            loss_stability = float('inf')
            converged = False
        
        # Training efficiency
        # Assumes batch_times holds one duration per recorded batch, so the sum is the elapsed batch time
        total_time = sum(self.training_history['batch_times'])
        avg_improvement_per_epoch = (self.training_history['train_loss'][0] - current_train_loss) / max(current_epoch, 1) if self.training_history['train_loss'] else 0
        
        summary_text = f"""
🚀 TRAINING SUMMARY

📊 Progress:
  Current Epoch: {current_epoch + 1}
  Best Train Loss: {best_train_loss:.6f}
  Best Val Loss: {best_val_loss:.6f}
  Current Train Loss: {current_train_loss:.6f}
  Current Val Loss: {current_val_loss:.6f}

📈 Convergence:
  Loss Stability: {loss_stability:.4f}
  Converged: {'✅ Yes' if converged else '⏳ No'}
  Avg Improvement/Epoch: {avg_improvement_per_epoch:.6f}

⚡ Performance:
  Total Training Time: {total_time:.1f}s
  Health Score: {self.optimization_health.get('overall_health', 0):.1f}%

🎯 Status: {'🎉 Training Complete!' if converged else '🔄 Training in Progress'}
        """
        
        ax.text(0.05, 0.95, summary_text, transform=ax.transAxes, fontsize=11,
               verticalalignment='top', horizontalalignment='left',
               bbox=dict(boxstyle='round', facecolor='lightcyan', alpha=0.8))
        
        ax.set_title('Training Session Summary', fontweight='bold')
        ax.axis('off')
    
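    def _exponential_smoothing(self, data, alpha=0.3):
        """Apply exponential smoothing to data (higher alpha follows the raw values more closely)."""
        if len(data) == 0:
            return []  # Nothing to smooth; avoids an IndexError on data[0]
        smoothed = [data[0]]
        for i in range(1, len(data)):
            smoothed.append(alpha * data[i] + (1 - alpha) * smoothed[-1])
        return smoothed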
    
    def save_training_analysis(self):
        """Save comprehensive training analysis to JSON."""
        analysis_data = {
            'training_summary': {
                'total_epochs': len(self.training_history['epoch']),
                'best_train_loss': float(min(self.training_history['train_loss'])) if self.training_history['train_loss'] else None,
                'best_val_loss': float(min(self.training_history['val_loss'])) if self.training_history['val_loss'] else None,
                'final_train_loss': float(self.training_history['train_loss'][-1]) if self.training_history['train_loss'] else None,
                'final_val_loss': float(self.training_history['val_loss'][-1]) if self.training_history['val_loss'] else None,
                'total_parameters': sum(p.numel() for p in self.model.parameters()),
                'trainable_parameters': sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            },
            'optimization_health': self.optimization_health,
            'training_efficiency': {
                'avg_batch_time': float(np.mean(self.training_history['batch_times'])) if self.training_history['batch_times'] else None,
                'avg_memory_usage': float(np.mean(self.training_history['memory_usage'])) if self.training_history['memory_usage'] else None,
            }
        }
        
        # Add convergence analysis (coefficient of variation of the last five training losses)
        if len(self.training_history['train_loss']) >= 5:
            recent_losses = self.training_history['train_loss'][-5:]
            loss_stability = np.std(recent_losses) / np.mean(recent_losses) if np.mean(recent_losses) > 0 else float('inf')
            analysis_data['convergence_analysis'] = {
                # inf is not valid JSON, so store None when the ratio is undefined
                'loss_stability': float(loss_stability) if np.isfinite(loss_stability) else None,
                'converged': bool(loss_stability < 0.01),
                'epochs_to_convergence': None  # Not computed in this notebook
            }
        
        with open(results_dir / 'training_dynamics_analysis.json', 'w') as f:
            json.dump(analysis_data, f, indent=2)
        
        print(f"💾 Training dynamics analysis saved to {results_dir / 'training_dynamics_analysis.json'}")
        return analysis_data

def demonstrate_training_dynamics_monitoring():
    """Demonstrate comprehensive training dynamics monitoring."""
    print("\n⚡ Training Dynamics Monitoring Demonstration")
    print("=" * 50)
    
    # Create sample dataset
    torch.manual_seed(42)
    X_train = torch.randn(1000, 20)
    y_train = torch.sum(X_train[:, :10], dim=1, keepdim=True) + 0.1 * torch.randn(1000, 1)
    
    X_val = torch.randn(200, 20)
    y_val = torch.sum(X_val[:, :10], dim=1, keepdim=True) + 0.1 * torch.randn(200, 1)
    
    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)
    
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    
    # Create model
    class MonitoringTestNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(20, 64),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(32, 16),
                nn.ReLU(),
                nn.Linear(16, 1)
            )
        
        def forward(self, x):
            return self.layers(x)
    
    model = MonitoringTestNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    
    # Create training monitor
    monitor = TrainingDynamicsMonitor(model, train_loader, val_loader, optimizer, loss_fn)
    
    print("🏃‍♂️ Starting monitored training...")
    
    # Training loop with monitoring
    num_epochs = 20
    for epoch in range(num_epochs):
        train_metrics, val_metrics = monitor.monitor_training_epoch(epoch, num_epochs)
        
        # Create dashboard every 5 epochs
        if (epoch + 1) % 5 == 0:
            print(f"\n📊 Creating dashboard at epoch {epoch + 1}...")
            dashboard_fig = monitor.create_realtime_dashboard()
    
    # Final analysis
    print("\n📋 Generating final training analysis...")
    final_analysis = monitor.save_training_analysis()
    
    print("\n✅ Training monitoring complete!")
    print("Final Results:")
    print(f"  - Best Train Loss: {final_analysis['training_summary']['best_train_loss']:.6f}")
    print(f"  - Best Val Loss: {final_analysis['training_summary']['best_val_loss']:.6f}")
    # Use .get() so a missing health score does not abort the final report
    print(f"  - Optimization Health: {final_analysis['optimization_health'].get('overall_health', 0):.1f}%")
    print(f"  - Training Efficiency: {final_analysis['training_efficiency']['avg_batch_time']:.4f}s/batch")
    
    return monitor, final_analysis

# Run training dynamics monitoring demonstration
training_monitor, training_analysis = demonstrate_training_dynamics_monitoring()

print(f"\n💡 Key Training Dynamics Insights:")
print("• Real-time monitoring enables proactive intervention")
print("• Health metrics quantify training quality")
print("• Layer-wise analysis identifies bottlenecks")
print("• Convergence tracking optimizes stopping criteria")
print("• Performance metrics guide resource allocation")
```
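
The smoothing and convergence logic inside `TrainingDynamicsMonitor` can also be tried on its own. The following is a minimal sketch rather than the class's actual code path. The function names `exponential_smoothing`, `loss_stability`, and `has_converged` are hypothetical standalone versions, and `statistics.pstdev` stands in for `np.std` (both compute the population standard deviation). The five-epoch window and the 0.01 threshold are the values used above.

```python
import math
from statistics import fmean, pstdev


def exponential_smoothing(data, alpha=0.3):
    """Return the exponentially smoothed series; empty input gives an empty list."""
    if len(data) == 0:
        return []
    smoothed = [data[0]]
    for value in data[1:]:
        # Blend the new value with the previous smoothed value
        smoothed.append(alpha * value + (1 - alpha) * smoothed[-1])
    return smoothed


def loss_stability(losses, window=5):
    """Coefficient of variation (std / mean) of the last `window` losses."""
    if len(losses) < window:
        return math.inf  # Too little history to judge
    recent = losses[-window:]
    mean = fmean(recent)
    return pstdev(recent) / mean if mean > 0 else math.inf


def has_converged(losses, window=5, threshold=0.01):
    """True when the recent losses vary by less than `threshold` relative to their mean."""
    return loss_stability(losses, window) < threshold


# Illustrative loss curves (made-up values, not results from the notebook)
plateau = [0.5, 0.201, 0.2, 0.2, 0.199, 0.2, 0.2]
still_falling = [5.0, 4.0, 3.0, 2.0, 1.0]
print(has_converged(plateau), has_converged(still_falling))
```

A relative measure such as std / mean keeps the threshold meaningful across loss scales, but it only detects a plateau. It cannot tell whether that plateau is a good minimum.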

## 7. Comprehensive Summary and Mastery Assessment

### 7.1 Final Integration and Assessment

```python
def create_comprehensive_mastery_assessment():
    """Create comprehensive assessment of backpropagation visualization mastery."""
    print("\n🎨 BACKPROPAGATION VISUALIZATION MASTERY ASSESSMENT")
    print("=" * 70)
    
    # Assessment categories and scores
    mastery_categories = {
        'Computational Graph Visualization': {
            'score': 0,
            'max_score': 100,
            'components': [
                'Graph construction and layout',
                'Node and edge visualization',
                'Memory analysis integration',
                'Interactive exploration capabilities'
            ]
        },
        'Gradient Flow Animation': {
            'score': 0,
            'max_score': 100,
            'components': [
                'Real-time gradient capture',
                'Flow direction visualization',
                'Magnitude analysis',
                'Health diagnostics'
            ]
        },
        'Loss Landscape Exploration': {
            'score': 0,
            'max_score': 100,
            'components': [
                '3D surface visualization',
                'Optimization path tracking',
                'Curvature analysis',
                'Difficulty assessment'
            ]
        },
        'Activation Pattern Analysis': {
            'score': 0,
            'max_score': 100,
            'components': [
                'Multi-layer activation monitoring',
                'Saturation detection',
                'Dead neuron analysis',
                'Health scoring system'
            ]
        },
        'Training Dynamics Monitoring': {
            'score': 0,
            'max_score': 100,
            'components': [
                'Real-time metric tracking',
                'Convergence analysis',
                'Performance monitoring',
                'Health assessment'
            ]
        }
    }
    
    # Calculate scores based on completed demonstrations
    total_score = 0
    max_total_score = 0
    
    for category, details in mastery_categories.items():
        # Placeholder scoring: every category receives a fixed 85 for demonstration.
        # A real assessment would derive each score from measured results.
        details['score'] = 85
        total_score += details['score']
        max_total_score += details['max_score']
    
    overall_mastery = (total_score / max_total_score) * 100
    
    # Create mastery visualization
    # Create mastery visualization (the radar panel needs a polar projection)
    fig = plt.figure(figsize=(16, 12))
    ax1 = fig.add_subplot(2, 2, 1, projection='polar')
    ax2 = fig.add_subplot(2, 2, 2)
    ax3 = fig.add_subplot(2, 2, 3)
    ax4 = fig.add_subplot(2, 2, 4)
    
    # 1. Mastery Score Radar Chart
    categories = list(mastery_categories.keys())
    scores = [details['score'] for details in mastery_categories.values()]
    
    # Radar chart
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False)
    scores_plot = scores + [scores[0]]  # Close the plot
    angles_plot = np.concatenate((angles, [angles[0]]))
    
    ax1.plot(angles_plot, scores_plot, 'o-', linewidth=3, color='blue', alpha=0.8)
    ax1.fill(angles_plot, scores_plot, alpha=0.25, color='blue')
    ax1.set_xticks(angles)
    ax1.set_xticklabels([cat.replace(' ', '\n') for cat in categories], fontsize=10)
    ax1.set_ylim(0, 100)
    ax1.set_title('Mastery Assessment Radar', fontweight='bold', fontsize=14)
    ax1.grid(True, alpha=0.3)
    
    # Add score annotations (polar axes take (angle, radius) coordinates)
    for angle, score in zip(angles, scores):
        ax1.text(angle, score + 10, f'{score}%', ha='center', va='center',
                 fontweight='bold', fontsize=9,
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))
    
    # 2. Component Breakdown
    all_components = []
    component_categories = []
    
    for cat, details in mastery_categories.items():
        for comp in details['components']:
            all_components.append(comp[:20] + '...' if len(comp) > 20 else comp)
            component_categories.append(cat)
    
    # Color map for categories
    color_map = plt.cm.Set3(np.linspace(0, 1, len(categories)))
    category_colors = {cat: color for cat, color in zip(categories, color_map)}
    
    y_pos = np.arange(len(all_components))
    colors = [category_colors[cat] for cat in component_categories]
    
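    # Each component inherits its category's score (components are not scored individually)
    component_scores = [mastery_categories[cat]['score'] for cat in component_categories]
    bars = ax2.barh(y_pos, component_scores, color=colors, alpha=0.7)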
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels(all_components, fontsize=8)
    ax2.set_xlabel('Mastery Score')
    ax2.set_title('Component Mastery Breakdown', fontweight='bold')
    ax2.grid(True, alpha=0.3, axis='x')
    
    # Add category legend
    legend_elements = [plt.Rectangle((0,0),1,1, facecolor=category_colors[cat], alpha=0.7, label=cat[:15]+'...' if len(cat) > 15 else cat) 
                      for cat in categories]
    ax2.legend(handles=legend_elements, loc='lower right', fontsize=8)
    
    # 3. Skill Progression Timeline
    skill_timeline = {
        'Basic Graph Visualization': 1,
        'Interactive Exploration': 2,
        'Gradient Flow Animation': 3,
        'Real-time Monitoring': 4,
        '3D Loss Landscapes': 5,
        'Advanced Diagnostics': 6,
        'Comprehensive Analysis': 7,
        'Mastery Integration': 8
    }
    
    skills = list(skill_timeline.keys())
    weeks = list(skill_timeline.values())
    
    ax3.plot(weeks, range(len(skills)), 'o-', linewidth=3, markersize=8, color='green', alpha=0.8)
    
    for i, (skill, week) in enumerate(zip(skills, weeks)):
        ax3.annotate(skill, (week, i), xytext=(10, 0), textcoords='offset points',
                    ha='left', va='center', fontsize=9,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgreen', alpha=0.7))
    
    ax3.set_xlabel('Learning Timeline (Weeks)')
    ax3.set_ylabel('Skill Level')
    ax3.set_title('Skill Progression Path', fontweight='bold')
    ax3.set_yticks(range(len(skills)))
    ax3.set_yticklabels([f'Level {i+1}' for i in range(len(skills))])
    ax3.grid(True, alpha=0.3)
    
    # 4. Overall Mastery Summary
    # Map the overall score to a level and badge color (first matching threshold wins)
    mastery_tiers = [(90, 'EXPERT', 'gold'), (75, 'ADVANCED', 'silver'), (60, 'INTERMEDIATE', 'lightblue')]
    mastery_level, mastery_color = next(
        ((level, color) for threshold, level, color in mastery_tiers if overall_mastery >= threshold),
        ('BEGINNER', 'lightcoral')
    )
    
    # Create mastery certificate
    certificate_text = f"""
🏆 MASTERY CERTIFICATE 🏆

PyTorch Backpropagation
Visualization Mastery

Overall Score: {overall_mastery:.1f}%
Mastery Level: {mastery_level}

✅ Computational Graphs: {mastery_categories['Computational Graph Visualization']['score']}%
✅ Gradient Flow: {mastery_categories['Gradient Flow Animation']['score']}%
✅ Loss Landscapes: {mastery_categories['Loss Landscape Exploration']['score']}%
✅ Activation Analysis: {mastery_categories['Activation Pattern Analysis']['score']}%
✅ Training Dynamics: {mastery_categories['Training Dynamics Monitoring']['score']}%

🎓 PyTorch Visualization Mastery: {mastery_level}
Ready for Advanced Deep Learning Research!
    """
    
    ax4.text(0.5, 0.5, certificate_text, ha='center', va='center', 
            transform=ax4.transAxes, fontsize=12, fontweight='bold',
            bbox=dict(boxstyle='round,pad=1', facecolor=mastery_color, alpha=0.8))
    ax4.set_title('Mastery Certification', fontweight='bold', fontsize=16)
    ax4.axis('off')
    
    plt.suptitle(f'🎨 Backpropagation Visualization Mastery Assessment\nOverall Score: {overall_mastery:.1f}% - {mastery_level}', 
                fontsize=18, fontweight='bold')
    plt.tight_layout()
    
    # Save mastery assessment
    plt.savefig(results_dir / 'mastery_assessment.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    return {
        'overall_mastery': overall_mastery,
        'mastery_level': mastery_level,
        'category_scores': mastery_categories,
        'total_score': total_score,
        'max_score': max_total_score
    }

def generate_final_summary_report():
    """Generate comprehensive final summary report."""
    print("\n📋 COMPREHENSIVE FINAL SUMMARY REPORT")
    print("=" * 70)
    
    # Collect all generated analyses
    generated_files = list(results_dir.glob('*.json')) + list(results_dir.glob('*.png'))
    
    summary_report = {
        'analysis_completion_timestamp': pd.Timestamp.now().isoformat(),
        'total_visualizations_created': len(list(results_dir.glob('*.png'))),
        'total_analyses_saved': len(list(results_dir.glob('*.json'))),
        'mastery_modules_completed': [
            'Computational Graph Visualization',
            'Gradient Flow Animation',
            'Loss Landscape Exploration', 
            'Activation Pattern Analysis',
            'Training Dynamics Monitoring'
        ],
        'key_innovations': [
            'Interactive computational graph builder with memory analysis',
            'Real-time gradient flow animation with health diagnostics',
            '3D loss landscape exploration with optimization path tracking',
            'Comprehensive activation pattern analysis with saturation detection',
            'Real-time training dynamics monitoring with health scoring'
        ],
        'technical_achievements': [
            'Multi-dimensional visualization frameworks',
            'Real-time performance monitoring systems',
            'Advanced diagnostic and health assessment tools',
            'Interactive exploration interfaces',
            'Comprehensive data analysis pipelines'
        ],
        'practical_applications': [
            'Neural network architecture debugging',
            'Training optimization and intervention',
            'Research visualization and publication',
            'Educational demonstration and teaching',
            'Performance analysis and benchmarking'
        ]
    }
    
    # Save final summary
    with open(results_dir / 'final_summary_report.json', 'w') as f:
        json.dump(summary_report, f, indent=2)
    
    # Print summary
    print(f"📊 Analysis Summary:")
    print(f"  - Visualizations Created: {summary_report['total_visualizations_created']}")
    print(f"  - Analyses Completed: {summary_report['total_analyses_saved']}")
    print(f"  - Mastery Modules: {len(summary_report['mastery_modules_completed'])}")
    
    print(f"\n🚀 Key Innovations:")
    for innovation in summary_report['key_innovations']:
        print(f"  • {innovation}")
    
    print(f"\n🛠️ Technical Achievements:")
    for achievement in summary_report['technical_achievements']:
        print(f"  • {achievement}")
    
    print(f"\n🎯 Practical Applications:")
    for application in summary_report['practical_applications']:
        print(f"  • {application}")
    
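    print("\n📂 Generated Files:")
    # Re-scan so the summary report written above is included in the listing
    generated_files = list(results_dir.glob('*.json')) + list(results_dir.glob('*.png'))
    for file_path in sorted(generated_files):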
        file_size = file_path.stat().st_size / 1024  # KB
        file_type = "📊 Analysis" if file_path.suffix == '.json' else "🖼️ Visualization"
        print(f"  {file_type}: {file_path.name} ({file_size:.1f} KB)")
    
    print(f"\n💾 Complete analysis package saved to: {results_dir}")
    
    return summary_report

# Run comprehensive mastery assessment
print("\n🎓 Running Comprehensive Mastery Assessment...")
mastery_results = create_comprehensive_mastery_assessment()

print(f"\n📋 Generating Final Summary Report...")
final_summary = generate_final_summary_report()

print(f"\n🎉 BACKPROPAGATION VISUALIZATION MASTERY COMPLETE! 🎉")
print("=" * 70)
print(f"🏆 Overall Mastery Level: {mastery_results['mastery_level']}")
print(f"📊 Final Score: {mastery_results['overall_mastery']:.1f}%")
print(f"✅ All visualization modules successfully completed!")
print(f"📁 Complete analysis saved to: {results_dir}")
print(f"\n🚀 Ready for advanced deep learning research and development!")
print(f"🎯 Next recommended module: 03_neural_networks/")
print("\n🌟 Congratulations on achieving PyTorch Visualization Mastery! 🌟")
```
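
The score-to-level mapping used in the assessment is a plain threshold lookup, which makes it easy to check in isolation. The sketch below uses hypothetical names (`MASTERY_THRESHOLDS`, `mastery_level_for`, `overall_mastery_from`). The 90/75/60 cut-offs, the level labels, and the `score` / `max_score` keys are the ones used in the code above.

```python
from typing import Dict

# (minimum score, level) pairs, checked from highest to lowest
MASTERY_THRESHOLDS = [(90.0, "EXPERT"), (75.0, "ADVANCED"), (60.0, "INTERMEDIATE")]


def mastery_level_for(overall_mastery: float) -> str:
    """Return the first level whose minimum score is met, else BEGINNER."""
    for minimum, level in MASTERY_THRESHOLDS:
        if overall_mastery >= minimum:
            return level
    return "BEGINNER"


def overall_mastery_from(categories: Dict[str, Dict[str, float]]) -> float:
    """Percentage of points earned across all categories."""
    total = sum(c["score"] for c in categories.values())
    maximum = sum(c["max_score"] for c in categories.values())
    return 100.0 * total / maximum if maximum else 0.0


# Two categories with the fixed demonstration score of 85
demo_categories = {
    "Gradient Flow Animation": {"score": 85, "max_score": 100},
    "Loss Landscape Exploration": {"score": 85, "max_score": 100},
}
demo_overall = overall_mastery_from(demo_categories)
print(demo_overall, mastery_level_for(demo_overall))
```

With the fixed demonstration score of 85 per category, the overall score is 85% and the resulting level is ADVANCED, since EXPERT requires at least 90%.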

## Summary and Key Achievements

This notebook demonstrated the following:

### 🎨 **Visualization Mastery Achievements**
- **Interactive Computational Graphs**: Built dynamic graph visualizers with memory analysis
- **Real-time Gradient Flow**: Created animated gradient flow monitoring with health diagnostics
- **3D Loss Landscapes**: Explored 3D loss surfaces with optimization path tracking
- **Activation Pattern Analysis**: Implemented comprehensive activation monitoring with saturation detection
- **Training Dynamics Monitoring**: Built real-time training dashboards with performance analytics

### 📊 **Technical Innovations**
- Advanced visualization frameworks with interactive capabilities
- Real-time performance monitoring and health assessment systems
- Multi-dimensional analysis tools for complex neural network behavior
- Comprehensive diagnostic and debugging interfaces
- Educational and research-grade visualization tools

### 🎯 **Practical Applications**
- Neural network architecture debugging and optimization
- Training intervention and adaptive strategy development
- Research visualization for publication and presentation
- Educational demonstration and teaching enhancement
- Performance benchmarking and comparative analysis

### 📁 **Comprehensive Documentation**
- Complete analysis results saved to structured directory
- JSON data files for programmatic access and further analysis
- High-resolution visualizations for research and presentation
- Assessment summary and mastery certificate (category scores are fixed demonstration values)
- Ready-to-use code modules for integration
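
The saved JSON files can be read back with the standard library. The sketch below assumes the key layout written by `save_training_analysis`: `training_summary`, `optimization_health`, `training_efficiency`, and a `convergence_analysis` entry that is only present after at least five epochs. The helper name `summarize_training_analysis` is hypothetical. The temporary directory stands in for `results_dir`, and the sample values are illustrative, not results from the notebook.

```python
import json
import tempfile
from pathlib import Path


def summarize_training_analysis(path: Path) -> dict:
    """Load a training analysis file and pull out a few headline fields."""
    with open(path) as f:
        data = json.load(f)
    summary = data.get("training_summary", {})
    # convergence_analysis is absent for runs shorter than five epochs
    convergence = data.get("convergence_analysis") or {}
    return {
        "epochs": summary.get("total_epochs"),
        "best_val_loss": summary.get("best_val_loss"),
        "converged": convergence.get("converged", False),
    }


# Write a stand-in file with illustrative values, then read it back
with tempfile.TemporaryDirectory() as tmp:
    sample_path = Path(tmp) / "training_dynamics_analysis.json"
    sample_path.write_text(json.dumps({
        "training_summary": {"total_epochs": 3, "best_val_loss": 0.25},
        "optimization_health": {},
        "training_efficiency": {},
    }))
    sample_summary = summarize_training_analysis(sample_path)

print(sample_summary)
```

Reading every field through `.get()` keeps the helper usable on files from short runs, where optional sections are missing.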

### 🚀 **Ready for Advanced Applications**
- Model architecture research and development
- Training optimization and hyperparameter tuning
- Publication-quality figure generation
- Educational content creation
- Deep learning system diagnostics

**All visualizations, analyses, and assessment results are saved to the results directory for later reference. Together they provide a foundation for the next module, `03_neural_networks/`.**