# Visualization Debugging and Troubleshooting

This notebook demonstrates techniques for debugging and troubleshooting visualization issues:
- Memory Leak Detection
- Performance Profiling
- Error Handling
- Visual Regression Testing

In [None]:
import cProfile
import gc
import io
import logging
import os
import pstats
import time

import imagehash
import matplotlib.pyplot as plt
import memory_profiler
import mlx.core as mx
import mlx.nn as nn
import networkx as nx
import numpy as np
import plotly.graph_objects as go
from PIL import Image
from plotly.subplots import make_subplots

from ncps.mlx import CfC, LTC
from ncps.mlx.visualization import WiringVisualizer, PerformanceVisualizer
from ncps.mlx.wirings import Random, NCP, AutoNCP

# Configure the root logger for the whole notebook. The DEBUG threshold lets
# every diagnostic message from the visualizers below through (note that
# basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level='DEBUG')
logger = logging.getLogger(name='visualization_debug')

## 1. Memory Leak Detection

Detect and fix memory leaks in visualizations:

In [None]:
class MemoryDebugVisualizer(WiringVisualizer):
    """Visualization with memory debugging capabilities.

    Each call to ``plot_with_memory_tracking`` samples process memory with
    ``memory_profiler.memory_usage`` before plotting, and again after the
    figure has been closed and garbage has been collected. The recorded
    difference is therefore memory *retained* by the call, not the transient
    cost of an open figure.
    """

    def __init__(self, wiring):
        super().__init__(wiring)
        # One record per plotting call: {'initial', 'final', 'diff'}, in the
        # units reported by memory_profiler.memory_usage (MiB).
        self.memory_tracker = []

    # memory_profiler.profile additionally prints a line-by-line memory
    # report for this method on every call.
    @memory_profiler.profile
    def plot_with_memory_tracking(self):
        """Draw the wiring graph with matplotlib and record memory usage.

        Appends one record to ``self.memory_tracker``. If drawing raises, the
        figure is still closed and no record is appended.
        """
        initial_mem = memory_profiler.memory_usage()[0]

        try:
            plt.figure(figsize=(10, 10))

            pos = nx.spring_layout(self.graph)
            nx.draw(
                self.graph,
                pos=pos,
                node_color='lightblue',
                with_labels=True
            )

            plt.show()

        finally:
            # Always release the figure, so that a failing iteration cannot
            # itself cause a leak.
            plt.close('all')

        # Matplotlib figures contain reference cycles; collect them so that
        # only genuinely retained memory shows up in the final sample.
        gc.collect()
        final_mem = memory_profiler.memory_usage()[0]
        self.memory_tracker.append({
            'initial': initial_mem,
            'final': final_mem,
            'diff': final_mem - initial_mem
        })

    def check_memory_leaks(self, n_iterations=10):
        """Plot ``n_iterations`` times and look for steadily growing memory.

        A leak is suspected when the post-cleanup memory sample rises on every
        iteration of this call. At least two iterations are needed to judge.

        Args:
            n_iterations: number of plotting calls to make.

        Returns:
            List of per-iteration memory differences (MiB) for this call only.
        """
        # Only analyse records produced by this call, not earlier ones.
        first_new = len(self.memory_tracker)
        for i in range(n_iterations):
            logger.info(f'Iteration {i+1}/{n_iterations}')
            self.plot_with_memory_tracking()

        records = self.memory_tracker[first_new:]
        diffs = [record['diff'] for record in records]
        finals = [record['final'] for record in records]

        if len(finals) < 2:
            logger.info('Not enough iterations to assess memory leaks')
            return diffs

        # Strictly increasing retained memory across iterations is the leak
        # signature; merely equal diffs (e.g. all zero) are not.
        growing = all(later > earlier for earlier, later in zip(finals, finals[1:]))

        if growing:
            logger.warning('Potential memory leak detected: Memory usage consistently increasing')
        else:
            logger.info('No memory leaks detected')

        return diffs

# Example usage
wiring = Random(units=100, sparsity_level=0.1)
visualizer = MemoryDebugVisualizer(wiring)

# Check for memory leaks
memory_diffs = visualizer.check_memory_leaks()

# Plot memory usage. Iterations are numbered from 1 to match the log output;
# memory_profiler reports memory in MiB.
iterations = range(1, len(memory_diffs) + 1)
plt.figure(figsize=(10, 5))
plt.plot(iterations, memory_diffs, marker='o')
plt.title('Memory Usage Over Iterations')
plt.xlabel('Iteration')
plt.ylabel('Memory Difference (MiB)')
plt.show()

## 2. Performance Profiling

Profile visualization performance:

In [None]:
class PerformanceDebugVisualizer(WiringVisualizer):
    """Visualization with performance debugging capabilities.

    Builds a Plotly figure of the wiring graph under ``cProfile``. For each
    run, it records the elapsed wall-clock time and a profile report sorted
    by cumulative time.
    """

    def __init__(self, wiring):
        super().__init__(wiring)
        # Profiler for the most recent run. analyze_performance replaces it
        # before every run, so reports do not accumulate across runs.
        self.profiler = cProfile.Profile()
        # One record per run: {'run', 'time', 'stats'}.
        self.performance_stats = []

    def profile_visualization(self):
        """Build the network figure while ``self.profiler`` is enabled.

        Returns:
            A ``go.Figure`` with one line trace for edges and one marker
            trace for nodes, positioned by ``nx.spring_layout``.
        """
        self.profiler.enable()

        try:
            fig = go.Figure()

            pos = nx.spring_layout(self.graph)

            # Edges: one polyline with None separators between segments.
            edge_x = []
            edge_y = []
            for edge in self.graph.edges():
                x0, y0 = pos[edge[0]]
                x1, y1 = pos[edge[1]]
                edge_x.extend([x0, x1, None])
                edge_y.extend([y0, y1, None])

            fig.add_trace(go.Scatter(
                x=edge_x,
                y=edge_y,
                mode='lines',
                line=dict(color='gray', width=1)
            ))

            # Nodes
            node_x = [pos[node][0] for node in self.graph.nodes()]
            node_y = [pos[node][1] for node in self.graph.nodes()]

            fig.add_trace(go.Scatter(
                x=node_x,
                y=node_y,
                mode='markers',
                marker=dict(
                    size=10,
                    color='lightblue'
                )
            ))

            return fig

        finally:
            self.profiler.disable()

    def analyze_performance(self, n_runs=5):
        """Profile ``n_runs`` independent figure builds.

        Args:
            n_runs: number of runs to profile.

        Returns:
            List of ``{'run', 'time', 'stats'}`` dicts for this call's runs.
            ``time`` is elapsed seconds; ``stats`` is the pstats text report
            for that run only.
        """
        first_new = len(self.performance_stats)
        for i in range(n_runs):
            logger.info(f'Run {i+1}/{n_runs}')

            # Reusing one Profile object would make every report include the
            # calls from all previous runs, so start each run fresh.
            self.profiler = cProfile.Profile()

            # perf_counter is the monotonic, high-resolution clock intended
            # for measuring intervals.
            start_time = time.perf_counter()
            self.profile_visualization()
            elapsed = time.perf_counter() - start_time

            report = io.StringIO()
            ps = pstats.Stats(self.profiler, stream=report).sort_stats('cumulative')
            ps.print_stats()

            self.performance_stats.append({
                'run': i + 1,
                'time': elapsed,
                'stats': report.getvalue()
            })

        return self.performance_stats[first_new:]

# Example usage
visualizer = PerformanceDebugVisualizer(wiring)

# Analyze performance
stats = visualizer.analyze_performance()

# Plot performance results against the 1-based run numbers used in the log.
runs = [stat['run'] for stat in stats]
times = [stat['time'] for stat in stats]
plt.figure(figsize=(10, 5))
plt.plot(runs, times, marker='o')
plt.title('Visualization Performance')
plt.xlabel('Run')
plt.ylabel('Time (seconds)')
plt.show()

## 3. Error Handling

Implement robust error handling:

In [None]:
class RobustVisualizer(WiringVisualizer):
    """Visualization with robust error handling.

    ``create_visualization`` catches failures in its individual steps and
    logs them. It records each one in ``self.error_log`` as a dict with
    ``'type'`` and ``'error'`` keys, and returns a fallback or error figure
    instead of raising.
    """

    def __init__(self, wiring, layout_seed=None):
        """Create the visualizer.

        Args:
            wiring: wiring object forwarded to ``WiringVisualizer``.
            layout_seed: optional seed passed to ``nx.spring_layout``. Set it
                to make node positions reproducible across calls (useful for
                visual regression tests). ``None`` keeps networkx's default.
        """
        super().__init__(wiring)
        self.layout_seed = layout_seed
        self.error_log = []

    def create_visualization(self):
        """Build the network figure, degrading gracefully on errors.

        Returns:
            One of three ``go.Figure`` objects:

            - the network figure, when every step succeeds;
            - a single-marker fallback figure, if adding the network
              elements failed;
            - an error figure, if validation failed or any step outside the
              inner handlers raised.
        """
        try:
            if not self._validate_input():
                raise ValueError('Invalid input data')

            fig = go.Figure()

            try:
                self._add_network_elements(fig)
            except Exception as e:
                logger.error(f'Error adding network elements: {e}', exc_info=True)
                self.error_log.append({
                    'type': 'network_elements',
                    'error': str(e)
                })
                # The failed step may already have added some traces. Start
                # from an empty figure so the fallback is not mixed with a
                # partially drawn network.
                fig = go.Figure()
                self._create_fallback_visualization(fig)

            try:
                # Layout tweaks are cosmetic, so a failure here is a warning.
                self._add_interactive_elements(fig)
            except Exception as e:
                logger.warning(f'Error adding interactive elements: {e}')
                self.error_log.append({
                    'type': 'interactive_elements',
                    'error': str(e)
                })

            return fig

        except Exception as e:
            logger.error(f'Critical error in visualization creation: {e}', exc_info=True)
            self.error_log.append({
                'type': 'critical',
                'error': str(e)
            })
            return self._create_error_visualization()

    def _validate_input(self):
        """Return True if there is something to draw.

        Requires a non-empty ``self.graph`` and a non-empty
        ``self.wiring.adjacency_matrix`` (anything exposing ``.size``). Any
        exception raised while checking is logged and treated as invalid.
        """
        try:
            if not self.graph or len(self.graph) == 0:
                return False

            if self.wiring.adjacency_matrix.size == 0:
                return False

            return True
        except Exception as e:
            logger.error(f'Validation error: {e}', exc_info=True)
            return False

    def _add_network_elements(self, fig):
        """Add an edge trace and a node trace for ``self.graph`` to *fig*."""
        pos = nx.spring_layout(self.graph, seed=self.layout_seed)

        # Edges: one polyline with None separators between segments.
        edge_x = []
        edge_y = []
        for edge in self.graph.edges():
            x0, y0 = pos[edge[0]]
            x1, y1 = pos[edge[1]]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        fig.add_trace(go.Scatter(
            x=edge_x,
            y=edge_y,
            mode='lines',
            line=dict(color='gray', width=1)
        ))

        # Nodes
        node_x = [pos[node][0] for node in self.graph.nodes()]
        node_y = [pos[node][1] for node in self.graph.nodes()]

        fig.add_trace(go.Scatter(
            x=node_x,
            y=node_y,
            mode='markers',
            marker=dict(size=10, color='lightblue')
        ))

    def _add_interactive_elements(self, fig):
        """Apply hover behaviour and margins to *fig*'s layout."""
        fig.update_layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40)
        )

    def _create_fallback_visualization(self, fig):
        """Add a single labelled marker explaining that a fallback is shown."""
        fig.add_trace(go.Scatter(
            x=[0],
            y=[0],
            mode='markers+text',
            text=['Error: Using fallback visualization'],
            textposition='bottom center'
        ))

    def _create_error_visualization(self):
        """Return a new figure that only reports that visualization failed."""
        fig = go.Figure()

        fig.add_trace(go.Scatter(
            x=[0],
            y=[0],
            mode='markers+text',
            text=['Error: Unable to create visualization'],
            textposition='bottom center'
        ))

        fig.update_layout(
            title='Visualization Error',
            showlegend=False
        )

        return fig

# Example usage
visualizer = RobustVisualizer(wiring)

# Build the figure (a fallback/error figure is returned on failure) and show it
fig = visualizer.create_visualization()
fig.show()

# Report anything that went wrong while building the figure
errors = visualizer.error_log
if errors:
    print('\nErrors encountered:')
    for entry in errors:
        kind, message = entry['type'], entry['error']
        print(f"Type: {kind}, Error: {message}")

## 4. Visual Regression Testing

Implement visual regression testing:

In [None]:
class VisualRegressionTester:
    """Visual regression testing for figures that support ``write_image``.

    References are stored as ``<reference_dir>/<name>.png``. A comparison
    renders the figure to ``current.png`` in the same directory and compares
    average hashes. On failure, it writes ``diff.png`` with differing pixels
    painted red.
    """

    def __init__(self, reference_dir='reference_images'):
        self.reference_dir = reference_dir
        os.makedirs(reference_dir, exist_ok=True)

    def capture_reference(self, name, fig):
        """Render *fig* to ``<reference_dir>/<name>.png``, overwriting it."""
        path = os.path.join(self.reference_dir, f'{name}.png')
        fig.write_image(path)
        logger.info(f'Captured reference image: {path}')

    def compare_with_reference(self, name, fig, threshold=5):
        """Compare *fig* against the stored reference image *name*.

        Args:
            name: reference name (file ``<name>.png`` in ``reference_dir``).
            fig: figure exposing ``write_image(path)``.
            threshold: largest allowed difference between the two images'
                ``imagehash.average_hash`` values.

        Returns:
            True if the hash difference is <= ``threshold``. False if it is
            larger (a diff image is then written) or if the reference image
            is missing.
        """
        reference_path = os.path.join(self.reference_dir, f'{name}.png')

        if not os.path.exists(reference_path):
            logger.warning(f'Reference image not found: {reference_path}')
            return False

        current_path = os.path.join(self.reference_dir, 'current.png')
        fig.write_image(current_path)

        # Context managers close the image files once they are hashed.
        with Image.open(reference_path) as reference_img:
            reference_hash = imagehash.average_hash(reference_img)
        with Image.open(current_path) as current_img:
            current_hash = imagehash.average_hash(current_img)

        difference = reference_hash - current_hash

        if difference > threshold:
            logger.warning(f'Visual regression detected: {difference} > {threshold}')
            self._generate_diff_image(reference_path, current_path)
            return False

        logger.info('Visual regression test passed')
        return True

    def _generate_diff_image(self, reference_path, current_path):
        """Write ``diff.png``: the reference image with changed pixels in red."""
        with Image.open(reference_path) as reference_img, \
                Image.open(current_path) as current_img:
            # convert('RGB') drops any alpha channel (PNG exports may be RGBA)
            # and returns a loaded copy that stays valid after the files close.
            reference = reference_img.convert('RGB')
            current = current_img.convert('RGB')

        if current.size != reference.size:
            logger.warning(
                f'Image sizes differ ({reference.size} vs {current.size}); '
                'resizing current image for the diff'
            )
            current = current.resize(reference.size)

        diff_pixels = self._highlight_differences(
            np.asarray(reference), np.asarray(current)
        )
        diff = Image.fromarray(diff_pixels)

        diff_path = os.path.join(self.reference_dir, 'diff.png')
        diff.save(diff_path)
        logger.info(f'Generated difference image: {diff_path}')

    @staticmethod
    def _highlight_differences(reference, current, tolerance=30):
        """Return a uint8 copy of *reference* with differing pixels painted red.

        Both inputs are (H, W, 3) RGB arrays of the same shape. A pixel
        differs when the sum of absolute per-channel differences exceeds
        ``tolerance``.
        """
        # Widen to int16 so that subtracting uint8 values cannot wrap around.
        ref = np.asarray(reference, dtype=np.int16)
        cur = np.asarray(current, dtype=np.int16)
        changed = np.abs(ref - cur).sum(axis=-1) > tolerance

        out = np.array(reference, dtype=np.uint8, copy=True)
        out[changed] = (255, 0, 0)
        return out

# Example usage
tester = VisualRegressionTester()

# nx.spring_layout called without a seed draws its initial positions from
# NumPy's global random state (per the networkx docs). Seeding it keeps the
# layout, and so the rendered image, reproducible between notebook sessions.
np.random.seed(42)

# Create visualization
visualizer = RobustVisualizer(wiring)
fig = visualizer.create_visualization()

# Capture the reference only the first time. Overwriting it on every run
# would make the comparison below pass trivially.
reference_path = os.path.join(tester.reference_dir, 'network.png')
if not os.path.exists(reference_path):
    tester.capture_reference('network', fig)

# Compare with the stored reference
passed = tester.compare_with_reference('network', fig)
print(f'Visual regression test passed: {passed}')

## Debugging Tips

1. **Memory Issues**
   - Profile memory usage
   - Track object lifecycles
   - Release unused resources (e.g. close Matplotlib figures with `plt.close`)
   - Monitor memory patterns

2. **Performance Issues**
   - Profile critical sections
   - Identify bottlenecks
   - Optimize algorithms
   - Cache expensive, reusable results (e.g. graph layouts)

3. **Error Handling**
   - Implement fallbacks
   - Log errors properly
   - Validate inputs
   - Handle edge cases

4. **Visual Testing**
   - Use reference images
   - Compare visual output
   - Track regressions
   - Document changes