# Visualization Library Integrations

This notebook shows how to render neural circuit policy (NCP) wirings, training metrics, and network dynamics with popular Python visualization libraries:
- Bokeh Integration
- Altair Integration
- HoloViews Integration
- Dash Integration

In [None]:
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pandas as pd

# Visualization libraries
import bokeh.plotting as bk
import altair as alt
import holoviews as hv
import dash
from dash import dcc, html
from dash.dependencies import Input, Output

from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Random, NCP, AutoNCP
from ncps.mlx.visualization import WiringVisualizer, PerformanceVisualizer

# Initialize HoloViews, selecting 'bokeh' as its plotting backend so the
# HoloViews cell further down renders through Bokeh.
hv.extension('bokeh')

## 1. Bokeh Integration

Create interactive visualizations with Bokeh:

In [None]:
class BokehWiringVisualizer(WiringVisualizer):
    """Bokeh-based wiring visualization.

    Draws ``self.graph`` (presumably a networkx graph built by
    ``WiringVisualizer`` from the wiring — confirm) as an interactive
    node-link diagram.
    """

    def plot_interactive_wiring(self, width=800, height=600, seed=None):
        """Create an interactive Bokeh plot of the wiring graph.

        Args:
            width: Figure width in pixels.
            height: Figure height in pixels.
            seed: Optional random seed forwarded to
                ``networkx.spring_layout`` to make the node layout
                reproducible. ``None`` (default) keeps the previous,
                non-deterministic layout.

        Returns:
            A ``bokeh.plotting.figure`` containing the edges, the nodes, and a
            hover tool that reports each node's label and position.
        """
        import networkx as nx

        nodes = list(self.graph.nodes())
        edges = list(self.graph.edges())
        pos = nx.spring_layout(self.graph, seed=seed)

        # One two-point segment per edge. Bokeh documents NaN (not ``None``,
        # the Plotly convention) as its line-break marker; multi_line avoids
        # relying on either.
        edge_xs = [[pos[u][0], pos[v][0]] for u, v in edges]
        edge_ys = [[pos[u][1], pos[v][1]] for u, v in edges]

        p = bk.figure(
            width=width,
            height=height,
            title='Interactive Wiring Visualization',
            tools='pan,box_zoom,wheel_zoom,reset,save,hover',
        )

        # Edges
        p.multi_line(
            edge_xs,
            edge_ys,
            line_color='gray',
            line_alpha=0.5,
        )

        # Nodes
        node_source = bk.ColumnDataSource({
            'x': [pos[node][0] for node in nodes],
            'y': [pos[node][1] for node in nodes],
            'label': [str(node) for node in nodes],
            'size': [10] * len(nodes),
        })

        # ``scatter`` with a circle marker replaces ``circle(size=...)``,
        # which newer Bokeh releases deprecate.
        node_renderer = p.scatter(
            'x',
            'y',
            size='size',
            marker='circle',
            source=node_source,
            fill_color='lightblue',
            line_color='black',
        )

        # Only hover the nodes: the edge glyph has no 'label' column, so
        # hovering it would display '???'.
        p.hover.tooltips = [
            ('Node', '@label'),
            ('Position', '($x, $y)'),
        ]
        p.hover.renderers = [node_renderer]

        return p

# Example usage: build a small NCP wiring and render it interactively with Bokeh.
ncp_config = dict(
    inter_neurons=20,
    command_neurons=15,
    motor_neurons=5,
    sensory_fanout=3,
    inter_fanout=3,
    recurrent_command_synapses=5,
    motor_fanin=3,
)
wiring = NCP(**ncp_config)

visualizer = BokehWiringVisualizer(wiring)
p = visualizer.plot_interactive_wiring()
bk.show(p)

## 2. Altair Integration

Create declarative visualizations with Altair:

In [None]:
class AltairPerformanceVisualizer(PerformanceVisualizer):
    """Altair-based performance visualization.

    Plots ``self.history`` — iterated below as a mapping from metric name to
    a sequence of per-step values — as a layered line + point chart.
    """

    def plot_metrics_altair(self, width=600, height=400, title='Performance Metrics'):
        """Create an interactive metrics chart.

        Args:
            width: Chart width in pixels.
            height: Chart height in pixels.
            title: Chart title.

        Returns:
            An interactive (pan/zoom) Altair layered chart with one colored
            line per metric and hoverable points showing step, metric and
            value.
        """
        # Long-form ("tidy") data: one row per (metric, step) observation,
        # the shape Altair's field encodings expect.
        records = [
            {'Step': step, 'Metric': metric, 'Value': value}
            for metric, values in self.history.items()
            for step, value in enumerate(values)
        ]
        # Explicit columns keep the schema stable even when history is empty.
        df = pd.DataFrame(records, columns=['Step', 'Metric', 'Value'])

        # Shared encodings for both layers.
        base = alt.Chart(df).encode(
            x='Step:Q',
            y='Value:Q',
            color='Metric:N',
        )

        lines = base.mark_line()
        points = base.mark_circle(size=60).encode(
            tooltip=['Step:Q', 'Metric:N', 'Value:Q'],
        )

        chart = (lines + points).properties(
            width=width,
            height=height,
            title=title,
        ).interactive()

        return chart

# Example usage: feed synthetic, noisy training curves to the Altair visualizer.
visualizer = AltairPerformanceVisualizer()

for step in range(100):
    # Draw the loss noise before the accuracy noise, matching the call order
    # of a single add_metrics(loss=..., accuracy=...) expression.
    loss_value = np.exp(-step / 50) + 0.1 * np.random.randn()
    accuracy_value = 1 - np.exp(-step / 30) + 0.05 * np.random.randn()
    visualizer.add_metrics(loss=loss_value, accuracy=accuracy_value)

chart = visualizer.plot_metrics_altair()
chart

## 3. HoloViews Integration

Create composable visualizations with HoloViews:

In [None]:
class HoloViewsVisualizer:
    """HoloViews-based visualization of a model's output dynamics and wiring."""

    def __init__(self, model):
        """Keep a reference to *model* and its ``wiring`` attribute."""
        self.model = model
        self.wiring = model.wiring

    def plot_network_dynamics(self, input_data):
        """Create a two-panel layout: output evolution and connectivity matrix.

        Args:
            input_data: MLX array passed straight to the model. The model
                output is averaged over axes 0 and 2, so a 3-D output
                (presumably ``(batch, time, features)`` — confirm) is assumed.

        Returns:
            An ``hv.Layout`` with a state-evolution ``Curve`` and a
            connectivity ``HeatMap`` arranged in two columns.
        """
        result = self.model(input_data)
        # Recurrent models may return ``(outputs, hidden_state)``; keep only
        # the outputs. (Assumption — confirm against the model's API.)
        output = result[0] if isinstance(result, (tuple, list)) else result
        # HoloViews does not understand MLX arrays; convert to NumPy.
        states = np.array(mx.mean(output, axis=(0, 2)))

        state_curve = hv.Curve(
            (np.arange(len(states)), states),
            'Time Step',
            'State',
        ).opts(
            width=400,
            height=300,
            title='State Evolution',
        )

        # HeatMap needs (x, y, value) data with two key dimensions, not a
        # bare 2-D matrix. Entry [i, j] is plotted at x=i ("From"), y=j
        # ("To"), assuming the adjacency_matrix[src, dst] convention — confirm.
        adjacency = np.asarray(self.wiring.adjacency_matrix)
        src, dst = np.indices(adjacency.shape)
        connectivity = hv.HeatMap(
            (src.ravel(), dst.ravel(), adjacency.ravel()),
            kdims=['From Node', 'To Node'],
            vdims=['Weight'],
        ).opts(
            width=400,
            height=300,
            title='Connectivity Matrix',
            colorbar=True,
            tools=['hover'],
        )

        layout = (state_curve + connectivity).cols(2)

        return layout

# Example usage: run a CfC model on random input and inspect its dynamics.
# (Statement order matters: the model is built before the input is sampled.)
model = CfC(wiring)
visualizer = HoloViewsVisualizer(model)

# Random input of shape (1, 10, 2) — presumably (batch, time, features).
input_data = mx.random.normal(shape=(1, 10, 2))

layout = visualizer.plot_network_dynamics(input_data)
layout

## 4. Dash Integration

Create interactive web applications with Dash:

In [None]:
class DashVisualizer:
    """Dash-based interactive visualization.

    The app has three parts:

    - a slider that sets the length of a random input sequence;
    - a heatmap of the wiring's adjacency matrix;
    - a line plot of the model output averaged per time step.

    An ``Interval`` component re-runs the callback every second with fresh
    random input.
    """

    def __init__(self, model):
        """Store *model* and its wiring, then build the layout and callbacks.

        Args:
            model: Object with a ``wiring`` attribute, callable on an MLX array.
        """
        self.model = model
        self.wiring = model.wiring
        self.app = dash.Dash(__name__)
        self.setup_layout()
        self.setup_callbacks()

    def setup_layout(self):
        """Create the Dash layout: controls, two graphs and a refresh timer."""
        self.app.layout = html.Div([
            html.H1('Neural Circuit Policy Visualization'),

            html.Div([
                # The slider value is used as axis 1 of the random input,
                # i.e. the number of time steps in the sequence.
                html.Label('Sequence Length:'),
                dcc.Slider(
                    id='input-size-slider',
                    min=1,
                    max=50,
                    value=10,
                    # Keep every mark inside [min, max].
                    marks={i: str(i) for i in (1, 10, 20, 30, 40, 50)},
                ),
            ]),

            html.Div([
                dcc.Graph(id='network-graph'),
                dcc.Graph(id='state-graph'),
            ], style={'display': 'flex'}),

            dcc.Interval(
                id='interval-component',
                interval=1000,  # in milliseconds
                n_intervals=0,
            ),
        ])

    def _mean_states(self, input_data):
        """Return the model output averaged over axes 0 and 2 as a Python list."""
        result = self.model(input_data)
        # Recurrent models may return ``(outputs, hidden_state)``; keep only
        # the outputs. (Assumption — confirm against the model's API.)
        output = result[0] if isinstance(result, (tuple, list)) else result
        # Dash serializes figures to JSON, which cannot encode MLX arrays.
        return np.array(mx.mean(output, axis=(0, 2))).tolist()

    def _connectivity_figure(self):
        """Build the Plotly figure dict for the adjacency-matrix heatmap."""
        adjacency = np.asarray(self.wiring.adjacency_matrix)
        return {
            'data': [{
                'type': 'heatmap',
                # Plotly indexes z as z[y][x]. Transpose so that x is the
                # source ("From") node and y the target, assuming the
                # adjacency_matrix[src, dst] convention — confirm.
                'z': adjacency.T.tolist(),
                'colorscale': 'Viridis',
            }],
            'layout': {
                'title': {'text': 'Network Connectivity'},
                'xaxis': {'title': {'text': 'From Node'}},
                'yaxis': {'title': {'text': 'To Node'}},
            },
        }

    def setup_callbacks(self):
        """Register the callback that refreshes both graphs."""
        @self.app.callback(
            [Output('network-graph', 'figure'),
             Output('state-graph', 'figure')],
            [Input('input-size-slider', 'value'),
             Input('interval-component', 'n_intervals')]
        )
        def update_graphs(input_size, n_intervals):
            # n_intervals is only a trigger: each tick redraws with new input.
            input_data = mx.random.normal((1, int(input_size), 2))
            states = self._mean_states(input_data)

            state_fig = {
                'data': [{
                    'x': list(range(len(states))),
                    'y': states,
                    'type': 'scatter',
                    'mode': 'lines+markers',
                }],
                'layout': {
                    'title': {'text': 'Network States'},
                    'xaxis': {'title': {'text': 'Time Step'}},
                    'yaxis': {'title': {'text': 'State'}},
                },
            }

            return self._connectivity_figure(), state_fig

    def run(self, debug=True, port=8050):
        """Run the Dash application (serves until the server is stopped)."""
        # ``run_server`` was removed in Dash 3; prefer ``run`` when available.
        runner = getattr(self.app, 'run', None) or self.app.run_server
        runner(debug=debug, port=port)

# Example usage: serve the dashboard. Starting the server typically blocks
# this cell until the server is stopped.
model = CfC(wiring)
visualizer = DashVisualizer(model=model)
visualizer.run(debug=True, port=8050)

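## Integration Tips

1. **Library Selection**
   - Choose based on the output you need: inline notebook charts (Altair, HoloViews, Bokeh) or a standalone web application (Dash)
   - Consider how much interactivity is required (hover, zoom, widgets, live updates)
   - Check rendering performance on large wirings and long metric histories
   - Evaluate the features each library offers (linked plots, layouts, export formats)

2. **Implementation**
   - Follow each library's patterns (e.g. `ColumnDataSource` in Bokeh, long-form DataFrames in Altair)
   - Use built-in library features instead of re-implementing them
   - Handle data conversion: turn MLX arrays into NumPy arrays or Python lists before passing them to a plotting library, which cannot consume MLX arrays directly
   - Optimize rendering by aggregating or downsampling large datasets

3. **Best Practices**
   - Keep code modular: separate data preparation from plotting
   - Handle updates efficiently, e.g. update a data source rather than rebuilding the whole figure
   - Consider user interaction: label axes and add tooltips
   - Document clearly what each plot shows and which input shapes it expects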