# Visualization Performance Optimization

This notebook demonstrates techniques for optimizing visualization performance:
- Memory-Efficient Visualization
- Rendering Optimization
- Interactive Performance
- Large-Scale Visualization

In [None]:
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import networkx as nx
import memory_profiler
import time
from functools import lru_cache

from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Random, NCP, AutoNCP
from ncps.mlx.visualization import WiringVisualizer, PerformanceVisualizer

## 1. Memory-Efficient Visualization

Reduce memory usage by caching the graph layout and drawing edges in fixed-size chunks:

In [None]:
class MemoryEfficientVisualizer(WiringVisualizer):
    """Wiring visualizer that keeps peak memory low on large graphs.

    The spring layout is memoized per instance, and edges are rendered in
    fixed-size chunks so only one chunk of edge coordinates is materialized
    at a time.
    """

    def __init__(self, wiring, cache_size=100):
        # Created before super().__init__ so the layout memo already exists
        # even if the base initializer ends up calling _compute_layout().
        self._layout_cache = {}
        super().__init__(wiring)
        # Maximum number of entries _clear_old_cache() keeps in self.cache.
        self.cache_size = cache_size
        # Generic cache; _clear_old_cache() expects every value to be a dict
        # carrying a 'timestamp' key.
        self.cache = {}

    def _compute_layout(self, seed=None):
        """Return the spring layout of ``self.graph``, computed once per seed.

        A per-instance dict replaces the former ``@lru_cache`` decorator: an
        ``lru_cache`` on a method is shared by the whole class, keys on
        ``self``, and therefore keeps every visualizer (and its graph) alive
        for as long as the cache holds the entry.
        """
        try:
            return self._layout_cache[seed]
        except KeyError:
            layout = nx.spring_layout(self.graph, seed=seed)
            self._layout_cache[seed] = layout
            return layout

    def _clear_old_cache(self):
        """Evict the oldest entries (by ``'timestamp'``) until at most ``cache_size`` remain."""
        excess = len(self.cache) - self.cache_size
        if excess > 0:
            by_age = sorted(self.cache.items(), key=lambda item: item[1]['timestamp'])
            for key, _ in by_age[:excess]:
                del self.cache[key]

    def plot_efficient(self, chunk_size=1000):
        """Plot the network, materializing edge coordinates ``chunk_size`` edges at a time.

        Each chunk is drawn with ONE ``plt.plot`` call: the endpoints of every
        edge are packed into a single polyline with a NaN row after each edge,
        and matplotlib breaks the line at NaNs. Drawing one ``plt.plot`` per
        edge instead creates one Line2D artist per edge, which dominates both
        memory and draw time on large graphs.

        Args:
            chunk_size: Number of edges whose coordinates are built at once.

        Raises:
            ValueError: If ``chunk_size`` is not a positive integer.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        plt.figure(figsize=(10, 10))

        # Cached per instance, so repeated plots reuse the same layout.
        pos = self._compute_layout()

        edges = list(self.graph.edges())
        for start in range(0, len(edges), chunk_size):
            chunk = edges[start:start + chunk_size]
            # Shape (k, 3, 2): start point, end point, NaN separator per edge.
            path = np.full((len(chunk), 3, 2), np.nan)
            path[:, 0] = [pos[u] for u, _ in chunk]
            path[:, 1] = [pos[v] for _, v in chunk]
            path = path.reshape(-1, 2)
            plt.plot(path[:, 0], path[:, 1], color='gray', alpha=0.5, linewidth=0.5)
            # Release the chunk buffer before the next one is built.
            del path

        node_x = [coord[0] for coord in pos.values()]
        node_y = [coord[1] for coord in pos.values()]
        plt.scatter(node_x, node_y, c='lightblue', s=50)

        plt.title('Memory-Efficient Network Visualization')
        plt.axis('equal')
        plt.show()

# Example usage
# NOTE(review): in upstream ncps, `sparsity_level` is the fraction of synapses
# *removed*, so 0.1 would give a ~90%-dense 1000x1000 wiring (~900k edges) —
# confirm this is the intended "large network" for ncps.mlx.
wiring = Random(units=1000, sparsity_level=0.1)  # Large network
visualizer = MemoryEfficientVisualizer(wiring)

# Profile memory usage
# NOTE(review): memory_profiler itemizes lines of the decorated function only,
# so plot_efficient() likely shows up as a single line here — decorate it
# directly if a per-line breakdown of the plotting code is wanted (verify).
@memory_profiler.profile
def plot_with_profiling():
    """Render `visualizer` once under memory_profiler's decorator."""
    visualizer.plot_efficient()

plot_with_profiling()

## 2. Rendering Optimization

Speed up animation by caching the static background once and redrawing only the artists that change (blitting):

In [None]:
class OptimizedRenderer(WiringVisualizer):
    """Animated wiring plot that redraws only the moving nodes (blitting).

    Edges are drawn once and their pixels cached; every frame restores that
    snapshot, draws the node collection on top, and blits the axes region.
    """

    def __init__(self, wiring):
        super().__init__(wiring)
        self.figure = None      # matplotlib Figure created by _initialize_figure()
        self.background = None  # pixel snapshot of the axes containing only edges

    def _initialize_figure(self):
        """Create the figure, draw the static edges, and snapshot the axes.

        Returns:
            tuple: ``(ax, pos)`` — the axes holding the edges and the
            node -> (x, y) layout used to draw them.
        """
        self.figure = plt.figure(figsize=(10, 10))
        ax = self.figure.add_subplot(111)

        # Static content: drawn once and reused as the blit background.
        pos = nx.spring_layout(self.graph)
        nx.draw_networkx_edges(
            self.graph,
            pos,
            alpha=0.5,
            ax=ax
        )

        # A full draw is needed before the canvas pixels can be copied.
        self.figure.canvas.draw()
        self.background = self.figure.canvas.copy_from_bbox(ax.bbox)

        return ax, pos

    def plot_dynamic(self, n_frames=100, interval=0.05):
        """Animate the nodes for ``n_frames`` frames using blitting.

        Every node is shifted by the same offset ``0.1 * (sin(f/10), cos(f/10))``
        on frame ``f``. This call blocks for roughly ``n_frames * interval``
        seconds.

        Args:
            n_frames: Number of frames to render.
            interval: Seconds to sleep between frames.
        """
        ax, pos = self._initialize_figure()

        # animated=True keeps the collection out of the normal draw pipeline:
        # otherwise each set_offsets() marks the figure stale and, in
        # interactive mode, schedules a full redraw that defeats blitting.
        node_collection = ax.scatter(
            [],
            [],
            c='lightblue',
            s=50,
            animated=True
        )

        # Base node coordinates, built once instead of once per frame.
        base_xy = np.array(list(pos.values()), dtype=float).reshape(-1, 2)
        canvas = self.figure.canvas

        for frame in range(n_frames):
            canvas.restore_region(self.background)

            phase = frame / 10
            node_collection.set_offsets(
                base_xy + (0.1 * np.sin(phase), 0.1 * np.cos(phase))
            )

            # Only the changed artist is rasterized, then pushed to screen.
            ax.draw_artist(node_collection)
            canvas.blit(ax.bbox)
            canvas.flush_events()

            time.sleep(interval)

        # Hand the collection back to the regular draw pipeline so a later
        # full redraw still shows the final node positions.
        node_collection.set_animated(False)
        canvas.draw_idle()

# Example usage
# Reuses the module-level `wiring` created by the memory-efficiency example
# above and rebinds `visualizer`.
visualizer = OptimizedRenderer(wiring)
# Blocking animation loop: 100 frames with a 0.05 s sleep each (>= ~5 s).
visualizer.plot_dynamic()

## 3. Interactive Performance

Keep interactive plots responsive by downsampling long series and rendering them with WebGL (`go.Scattergl`):

In [None]:
class HighPerformanceInteractive:
    """Interactive Plotly view of a model's mean activity.

    Long series are stride-downsampled before plotting and rendered with a
    WebGL trace (``go.Scattergl``) so panning and zooming stay responsive.
    """

    def __init__(self, model):
        self.model = model
        self.wiring = model.wiring
        self.downsampled_data = None  # series actually sent to the figure
        self.full_data = None         # un-downsampled series from the last view

    @staticmethod
    def _stride(length, target_size):
        """Return the smallest step that keeps at most ``target_size`` of ``length`` items.

        Raises:
            ValueError: If ``target_size`` is not a positive integer.
        """
        if target_size < 1:
            raise ValueError(f"target_size must be >= 1, got {target_size}")
        # Ceiling division. Floor division could keep almost 2 * target_size
        # points (e.g. 1999 items with target 1000 -> stride 1 -> 1999 points).
        return max(1, -(-length // target_size))

    def _downsample(self, data, target_size=1000):
        """Downsample ``data`` by striding so at most ``target_size`` items remain.

        Sequences already within the target are returned unchanged (same
        object). Otherwise every ``stride``-th item is kept, starting at 0.
        """
        if len(data) <= target_size:
            return data
        return data[::self._stride(len(data), target_size)]

    def _create_figure(self):
        """Create an empty figure with a single WebGL line+marker trace."""
        fig = go.Figure()

        # Scattergl renders through WebGL, which scales to many more points
        # than the SVG-based go.Scatter.
        fig.add_trace(go.Scattergl(
            x=[],
            y=[],
            mode='lines+markers',
            name='Activity'
        ))

        fig.update_layout(
            uirevision=True,  # Preserve UI state (zoom/pan) across updates
            hovermode='closest',
            showlegend=False
        )

        return fig

    def create_interactive_view(self, input_data, target_size=1000):
        """Run the model on ``input_data`` and return an interactive activity plot.

        Args:
            input_data: Model input, passed unchanged to ``self.model``.
            target_size: Maximum number of points drawn (default 1000).

        Returns:
            A Plotly figure with one WebGL trace and a range slider.
        """
        # NOTE(review): assumes self.model(...) returns a single array; if the
        # ncps.mlx model returns (output, state) like upstream ncps, unpack it.
        output = self.model(input_data)
        # Mean over axes 0 and 2 leaves one value per index of axis 1
        # (presumably time if output is (batch, time, units) — confirm).
        activity = mx.mean(output, axis=(0, 2))

        self.full_data = activity
        self.downsampled_data = self._downsample(activity, target_size)

        # Downsample the sample indices with the same rule so the x-axis shows
        # original step numbers rather than 0..len(downsampled) - 1.
        steps = self._downsample(np.arange(len(activity)), target_size)

        fig = self._create_figure()
        fig.update_traces(
            x=steps.tolist(),
            # Convert explicitly rather than relying on Plotly to recognize
            # MLX arrays.
            y=np.array(self.downsampled_data)
        )

        # The range slider previews the same downsampled series shown above.
        fig.update_layout(
            xaxis=dict(
                rangeslider=dict(
                    visible=True,
                    thickness=0.1
                )
            )
        )

        return fig

# Example usage
# Rebinds module-level `model` / `visualizer`; `wiring` comes from the
# memory-efficiency example above.
model = CfC(wiring)
visualizer = HighPerformanceInteractive(model)

# Generate input data
# NOTE(review): presumably (batch, time, features) = (1, 1000, 2) — confirm
# against CfC's expected input layout. If axis 1 is time, the activity series
# has 1000 points, equal to the default target of 1000, so no downsampling
# actually happens in this example.
input_data = mx.random.normal((1, 1000, 2))  # Large sequence

# Create visualization
fig = visualizer.create_interactive_view(input_data)
fig.show()

## 4. Large-Scale Visualization

Make large networks readable by clustering neurons into communities and plotting the cluster-level graph:

In [None]:
class LargeScaleVisualizer:
    """Hierarchical (cluster-level) visualization for large wirings.

    Neurons are grouped into Louvain communities. Each community becomes one
    marker, and the edges between markers aggregate the synapses between
    communities.
    """

    def __init__(self, model, max_visible_nodes=1000):
        self.model = model
        self.wiring = model.wiring
        # NOTE(review): stored for callers but not read by any method here.
        self.max_visible_nodes = max_visible_nodes
        # Lazily computed list of node-index sets (see _cluster_nodes).
        self.clusters = None

    @staticmethod
    def _connection_count(adj, nodes_a, nodes_b):
        """Count synapses between two node sets, in either direction.

        Non-zero entries are counted rather than summed. A signed adjacency
        (e.g. +1 / -1 polarities) would otherwise cancel out and hide real
        connections, and a directed matrix read one way only would drop the
        b -> a synapses.
        """
        a, b = list(nodes_a), list(nodes_b)
        forward = np.count_nonzero(adj[np.ix_(a, b)])
        backward = np.count_nonzero(adj[np.ix_(b, a)])
        return int(forward + backward)

    @staticmethod
    def _marker_sizes(cluster_sizes, min_px=10.0, max_px=50.0):
        """Map cluster sizes to marker diameters in pixels.

        Diameters grow with sqrt(size / largest_size) and stay within
        ``[min_px, max_px]``. A linear ``size * 10`` rule yields markers
        thousands of pixels wide for clusters with hundreds of neurons.
        """
        if not cluster_sizes:
            return []
        sizes = np.asarray(cluster_sizes, dtype=float)
        largest = sizes.max()
        if largest <= 0:
            return [min_px] * len(sizes)
        return (min_px + (max_px - min_px) * np.sqrt(sizes / largest)).tolist()

    def _cluster_nodes(self):
        """Partition the wiring's neurons into Louvain communities.

        Returns:
            A list of sets of node indices (rows/columns of the adjacency matrix).
        """
        adj = np.asarray(self.wiring.adjacency_matrix)
        # Unsigned, symmetric connection counts: Louvain modularity expects
        # non-negative weights, and reciprocal synapses should both count.
        links = (adj != 0).astype(np.int8)
        graph = nx.from_numpy_array(links + links.T)
        return nx.community.louvain_communities(graph)

    def _create_cluster_graph(self):
        """Build the cluster graph: one node per community, edges weighted by synapse count.

        Nodes are named ``'Cluster {i}'`` and carry a ``size`` attribute (number
        of neurons). Edges carry a float ``weight`` and exist only when at least
        one synapse links the two clusters.
        """
        if self.clusters is None:
            self.clusters = self._cluster_nodes()

        cluster_graph = nx.Graph()
        names = [f'Cluster {i}' for i in range(len(self.clusters))]
        for name, cluster in zip(names, self.clusters):
            cluster_graph.add_node(name, size=len(cluster))

        adj = np.asarray(self.wiring.adjacency_matrix)
        n_clusters = len(self.clusters)
        for i in range(n_clusters):
            for j in range(i + 1, n_clusters):
                weight = self._connection_count(adj, self.clusters[i], self.clusters[j])
                if weight > 0:
                    cluster_graph.add_edge(names[i], names[j], weight=float(weight))

        return cluster_graph

    def plot_hierarchical(self, seed=None):
        """Return a Plotly figure of the cluster graph.

        Args:
            seed: Optional seed forwarded to ``nx.spring_layout`` for a
                reproducible layout (default: random, as before).
        """
        cluster_graph = self._create_cluster_graph()
        # spring_layout reads the 'weight' edge attribute by default.
        pos = nx.spring_layout(cluster_graph, seed=seed)

        fig = go.Figure()

        # All edges share one trace; None entries break the polyline.
        edge_x = []
        edge_y = []
        for u, v in cluster_graph.edges():
            (x0, y0), (x1, y1) = pos[u], pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        fig.add_trace(go.Scatter(
            x=edge_x,
            y=edge_y,
            mode='lines',
            line=dict(
                width=1,
                color='gray'
            ),
            hoverinfo='none'
        ))

        names = list(cluster_graph.nodes())
        sizes = [cluster_graph.nodes[name]['size'] for name in names]

        fig.add_trace(go.Scatter(
            x=[pos[name][0] for name in names],
            y=[pos[name][1] for name in names],
            mode='markers',
            marker=dict(
                size=self._marker_sizes(sizes),
                color='lightblue',
                line=dict(width=1)
            ),
            # Plotly hover labels are HTML, so line breaks need <br>, not "\n".
            text=[f"{name}<br>Size: {size}" for name, size in zip(names, sizes)],
            hoverinfo='text'
        ))

        fig.update_layout(
            title='Hierarchical Network Visualization',
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40)
        )

        return fig

# Example usage
# Rebinds the module-level `wiring`, `model` and `visualizer` names.
# NOTE(review): if `sparsity_level` follows upstream ncps (fraction of synapses
# *removed*), 0.01 yields a ~99%-dense 5000x5000 wiring (~25M synapses), which
# makes graph construction and Louvain clustering very slow — confirm intent.
wiring = Random(units=5000, sparsity_level=0.01)  # Very large network
model = CfC(wiring)
visualizer = LargeScaleVisualizer(model)

# Create visualization
fig = visualizer.plot_hierarchical()
fig.show()

## Optimization Tips

1. **Memory Optimization**
   - Use data chunking
   - Implement caching
   - Clear unused resources
   - Monitor memory usage

2. **Rendering Optimization**
   - Cache static elements
   - Use efficient updates
   - Optimize redraw calls
   - Consider hardware acceleration

3. **Interactive Performance**
   - Implement downsampling
   - Use WebGL rendering
   - Optimize event handling
   - Cache computations

4. **Large-Scale Visualization**
   - Use hierarchical views
   - Implement clustering
   - Optimize data structures
   - Consider streaming data