# Advanced Visualization Use Cases

This notebook demonstrates advanced visualization techniques for specific use cases:
- Time Series Analysis
- Reinforcement Learning
- Anomaly Detection
- Real-time Control

In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import matplotlib.pyplot as plt
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Random, NCP, AutoNCP
from ncps.mlx.advanced_profiling import MLXProfiler
from ncps.mlx.visualization import WiringVisualizer, PerformanceVisualizer, ProfileVisualizer, plot_comparison

## 1. Time Series Analysis

Visualize temporal patterns and dependencies:

In [None]:
def analyze_temporal_patterns(model, data, targets):
    """Plot a 2x2 diagnostic figure for a sequence-prediction model.

    Panels:
      1. Input, target and prediction for the first sequence (feature 0).
      2. Per-time-step MSE across the batch, shown as mean +/- one std.
      3. Absolute values of ``model.cell.wiring.adjacency_matrix``.
      4. FFT magnitude of the first sequence's input vs. prediction
         (positive frequencies, DC term excluded).

    Args:
        model: Callable whose output is indexed as ``[batch, time, feature]``
            below and which exposes ``model.cell.wiring.adjacency_matrix``.
        data: Input batch, indexed as ``[batch, time, feature]``.
        targets: Targets; assumed to match the model output's shape, since
            they are subtracted element-wise from the predictions.
    """
    preds = model(data)

    # Convert to NumPy once so every plotting/statistics call below works on
    # plain ndarrays instead of relying on implicit MLX -> NumPy conversion
    # inside matplotlib.
    data_np = np.asarray(data)
    targets_np = np.asarray(targets)
    preds_np = np.asarray(preds)
    seq_idx = 0

    plt.figure(figsize=(15, 10))

    # Plot example sequence
    plt.subplot(221)
    plt.plot(data_np[seq_idx, :, 0], label='Input')
    plt.plot(targets_np[seq_idx, :, 0], label='Target')
    plt.plot(preds_np[seq_idx, :, 0], label='Prediction')
    plt.xlabel('Time Step')
    plt.ylabel('Value')
    plt.title('Example Sequence')
    plt.legend()
    plt.grid(True)

    # Plot prediction error over time: MSE over features -> (batch, time),
    # then mean/std over the batch, computed once and reused for the band.
    plt.subplot(222)
    error = np.mean((preds_np - targets_np) ** 2, axis=2)
    mean_error = error.mean(axis=0)
    std_error = error.std(axis=0)
    steps = np.arange(error.shape[1])
    plt.plot(steps, mean_error)
    plt.fill_between(
        steps,
        mean_error - std_error,
        mean_error + std_error,
        alpha=0.3
    )
    plt.xlabel('Time Step')
    plt.ylabel('MSE')
    plt.title('Prediction Error Over Time')
    plt.grid(True)

    # Plot wiring connectivity. This is the absolute adjacency matrix of the
    # wiring, not a learned attention map, so the panel is titled accordingly.
    plt.subplot(223)
    connectivity = np.abs(np.asarray(model.cell.wiring.adjacency_matrix))
    plt.imshow(connectivity, cmap='viridis')
    plt.colorbar(label='Connection Strength')
    plt.xlabel('To Node')
    plt.ylabel('From Node')
    plt.title('Wiring Connectivity')

    # Plot frequency response of the example sequence; keep only positive
    # frequencies and drop the DC bin.
    plt.subplot(224)
    n_steps = data_np.shape[1]
    half = n_steps // 2
    freqs = np.fft.fftfreq(n_steps)
    input_fft = np.abs(np.fft.fft(data_np[seq_idx, :, 0]))
    pred_fft = np.abs(np.fft.fft(preds_np[seq_idx, :, 0]))
    plt.plot(freqs[1:half], input_fft[1:half], label='Input')
    plt.plot(freqs[1:half], pred_fft[1:half], label='Prediction')
    plt.xlabel('Frequency')
    plt.ylabel('Magnitude')
    plt.title('Frequency Response')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# Generate time series data
def generate_time_series(n_samples=1000, seq_length=50):
    """Build a batch of sine-input / cosine-target sequences.

    For every sample a frequency ``1.0 + 0.1 * N(0, 1)`` and a phase drawn
    uniformly from ``[0, 2*pi)`` are sampled. The input is
    ``sin(freq * t + phase)`` and the target is ``cos(freq * t + phase)``,
    with ``t`` spanning ``[0, 4*pi]`` over ``seq_length`` steps.

    Returns:
        ``(inputs, targets)`` as MLX arrays of shape
        ``(n_samples, seq_length, 1)``.
    """
    time_axis = np.linspace(0, 4 * np.pi, seq_length)
    inputs = np.zeros((n_samples, seq_length, 1))
    outputs = np.zeros_like(inputs)

    for sample in range(n_samples):
        # Frequency is drawn before phase for each sample, so a fixed NumPy
        # seed reproduces the same data.
        frequency = 1.0 + 0.1 * np.random.randn()
        offset = 2 * np.pi * np.random.rand()
        angle = frequency * time_axis + offset
        inputs[sample, :, 0] = np.sin(angle)
        outputs[sample, :, 0] = np.cos(angle)

    return mx.array(inputs), mx.array(outputs)

# Create and train model
X, y = generate_time_series()
model = CfC(Random(units=100, sparsity_level=0.5))


def loss_fn(model, x, y):
    """Mean squared error between ``model(x)`` and ``y``."""
    pred = model(x)
    return mx.mean((pred - y) ** 2)


# Train model (full batch). nn.value_and_grad differentiates with respect to
# the model's trainable parameters; mx.value_and_grad takes a plain function
# (its second argument is ``argnums``) and cannot be used with a Module here.
optimizer = optim.Adam(learning_rate=0.001)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
for epoch in range(50):
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    # MLX evaluates lazily; materialize the updated parameters and optimizer
    # state each step so the computation graph does not keep growing.
    mx.eval(model.parameters(), optimizer.state)
    if epoch % 10 == 0:
        print(f"epoch {epoch:2d}  loss {loss.item():.4f}")

# Analyze temporal patterns
analyze_temporal_patterns(model, X[:10], y[:10])

## 2. Reinforcement Learning Analysis

Visualize RL training dynamics:

In [None]:
def analyze_rl_dynamics(model, states, actions, rewards):
    """Plot a 2x2 diagnostic figure for a state -> action model.

    Only the first time step of each sequence (index 0 on axis 1), the first
    two state dimensions and the first action dimension are plotted.

    Args:
        model: Callable mapping ``states`` to predicted actions; assumed to
            return the same shape as ``actions``, since they are subtracted
            element-wise.
        states: Array indexed as ``[batch, time, state_dim]`` (state_dim >= 2).
        actions: Array indexed as ``[batch, time, action_dim]``.
        rewards: Array indexed as ``[batch, time]`` of per-step rewards.
    """
    pred_actions = model(states)

    # Convert to NumPy once so matplotlib and the statistics below operate
    # on plain ndarrays.
    states_np = np.asarray(states)
    actions_np = np.asarray(actions)
    rewards_np = np.asarray(rewards)
    pred_np = np.asarray(pred_actions)

    plt.figure(figsize=(15, 10))

    # Plot state-action mapping
    plt.subplot(221)
    plt.scatter(states_np[:, 0, 0], states_np[:, 0, 1],
                c=actions_np[:, 0, 0], cmap='viridis')
    plt.colorbar(label='Action')
    plt.xlabel('State Dim 1')
    plt.ylabel('State Dim 2')
    plt.title('State-Action Mapping')

    # Plot reward landscape. These are the per-step rewards passed in, not an
    # estimated value function, so the panel is labelled accordingly.
    plt.subplot(222)
    plt.scatter(states_np[:, 0, 0], states_np[:, 0, 1],
                c=rewards_np[:, 0], cmap='viridis')
    plt.colorbar(label='Reward')
    plt.xlabel('State Dim 1')
    plt.ylabel('State Dim 2')
    plt.title('Reward Landscape')

    # Plot joint distribution of true vs. predicted actions
    plt.subplot(223)
    plt.hist2d(actions_np[:, 0, 0], pred_np[:, 0, 0], bins=50)
    plt.colorbar(label='Count')
    plt.xlabel('True Action')
    plt.ylabel('Predicted Action')
    plt.title('Action Distribution')

    # Plot reward vs. per-step action MSE (averaged over action dims)
    plt.subplot(224)
    action_error = np.mean((pred_np - actions_np) ** 2, axis=2)
    plt.scatter(rewards_np[:, 0], action_error[:, 0], alpha=0.5)
    plt.xlabel('Reward')
    plt.ylabel('Action Error')
    plt.title('Reward vs Action Error')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# Generate RL data
def generate_rl_data(n_samples=1000, seq_length=10):
    """Generate a synthetic ``(states, actions, rewards)`` batch.

    States (2-D) and actions (1-D) are independent standard-normal draws.
    Actions therefore carry no information about states, so a model can only
    learn their mean. The reward at each step is the squared norm of the
    state.

    Returns:
        MLX arrays of shapes ``(n_samples, seq_length, 2)``,
        ``(n_samples, seq_length, 1)`` and ``(n_samples, seq_length)``.
    """
    state_arr = np.random.randn(n_samples, seq_length, 2)
    action_arr = np.random.randn(n_samples, seq_length, 1)
    reward_arr = (state_arr ** 2).sum(axis=2)
    return tuple(mx.array(arr) for arr in (state_arr, action_arr, reward_arr))

# Create and train model
states, actions, rewards = generate_rl_data()
model = CfC(NCP(
    inter_neurons=50,
    command_neurons=30,
    motor_neurons=1,
    sensory_fanout=3,
    inter_fanout=3,
    recurrent_command_synapses=5,
    motor_fanin=3
))


def loss_fn(model, x, y):
    """Mean squared error between ``model(x)`` and ``y``."""
    pred = model(x)
    return mx.mean((pred - y) ** 2)


# Train model (full batch). nn.value_and_grad differentiates with respect to
# the model's trainable parameters; mx.value_and_grad takes a plain function
# (its second argument is ``argnums``) and cannot be used with a Module here.
optimizer = optim.Adam(learning_rate=0.001)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
for epoch in range(50):
    loss, grads = loss_and_grad_fn(model, states, actions)
    optimizer.update(model, grads)
    # MLX evaluates lazily; materialize the updated parameters and optimizer
    # state each step so the computation graph does not keep growing.
    mx.eval(model.parameters(), optimizer.state)
    if epoch % 10 == 0:
        print(f"epoch {epoch:2d}  loss {loss.item():.4f}")

# Analyze RL dynamics
analyze_rl_dynamics(model, states, actions, rewards)

## 3. Anomaly Detection

Visualize anomaly detection patterns:

In [None]:
def analyze_anomalies(model, data, anomaly_scores):
    """Plot a 2x2 diagnostic figure for a reconstruction-based anomaly model.

    The lowest- and highest-scored samples (by ``anomaly_scores``) are used
    as the "normal" and "anomalous" examples. The other panels show the
    reconstruction error per sample, per (sample, time step) and per
    (sample, feature).

    Args:
        model: Callable whose output is assumed to match ``data``'s shape,
            since it is subtracted element-wise as a reconstruction.
        data: Batch indexed as ``[batch, time, feature]``.
        anomaly_scores: One score per sample (length ``batch``).
    """
    # Convert to NumPy once so matplotlib and the statistics below operate
    # on plain ndarrays.
    reconstructions = np.asarray(model(data))
    data_np = np.asarray(data)
    scores = np.asarray(anomaly_scores)

    # Use Python ints, not NumPy scalars: NumPy integer scalars are not a
    # supported index type for every array library (e.g. MLX arrays).
    normal_idx = int(np.argmin(scores))
    anomaly_idx = int(np.argmax(scores))

    # Squared reconstruction error, reduced differently by three panels.
    sq_error = (reconstructions - data_np) ** 2

    plt.figure(figsize=(15, 10))

    # Plot normal vs anomalous patterns
    plt.subplot(221)
    plt.plot(data_np[normal_idx, :, 0], label='Normal')
    plt.plot(data_np[anomaly_idx, :, 0], label='Anomaly')
    plt.xlabel('Time Step')
    plt.ylabel('Value')
    plt.title('Normal vs Anomalous Patterns')
    plt.legend()
    plt.grid(True)

    # Plot reconstruction error distribution; the dashed line marks the
    # reconstruction error of the highest-scored sample.
    plt.subplot(222)
    error = sq_error.mean(axis=(1, 2))
    plt.hist(error, bins=50, density=True, alpha=0.7)
    plt.axvline(float(error[anomaly_idx]), color='r', linestyle='--',
                label='Max Anomaly')
    plt.xlabel('Reconstruction Error')
    plt.ylabel('Density')
    plt.title('Error Distribution')
    plt.legend()
    plt.grid(True)

    # Plot temporal anomaly scores: (batch, time) transposed so time is on y
    plt.subplot(223)
    temporal_error = sq_error.mean(axis=2)
    plt.imshow(temporal_error.T, aspect='auto', cmap='viridis')
    plt.colorbar(label='Error')
    plt.xlabel('Sample')
    plt.ylabel('Time Step')
    plt.title('Temporal Anomaly Scores')

    # Plot per-feature reconstruction error: (batch, feature), one box each
    plt.subplot(224)
    feature_error = sq_error.mean(axis=1)
    plt.boxplot([feature_error[:, i] for i in range(feature_error.shape[1])])
    plt.xlabel('Feature')
    plt.ylabel('Reconstruction Error')
    plt.title('Feature Importance')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# Generate anomaly data
def generate_anomaly_data(n_samples=1000, seq_length=50, n_features=3,
                          anomaly_fraction=0.1, anomaly_scale=3.0):
    """Generate Gaussian sequences with a fraction of them made anomalous.

    Every sequence starts as standard-normal noise. ``int(anomaly_fraction *
    n_samples)`` distinct sequences additionally receive
    ``anomaly_scale * N(0, 1)`` noise.

    Args:
        n_samples: Number of sequences.
        seq_length: Time steps per sequence.
        n_features: Features per time step.
        anomaly_fraction: Fraction of sequences to corrupt (0 to 1).
        anomaly_scale: Standard deviation of the added anomaly noise.

    Returns:
        ``(data, anomaly_scores)``: data of shape
        ``(n_samples, seq_length, n_features)`` and one score per sample,
        the mean squared value of that sequence. Both are MLX arrays.
    """
    data = np.random.randn(n_samples, seq_length, n_features)
    n_anomalies = int(anomaly_fraction * n_samples)
    # Sample indices without replacement. With replacement, duplicate indices
    # make the fancy-indexed ``+=`` below apply only once per duplicate, so
    # fewer than ``n_anomalies`` sequences would actually be corrupted.
    anomaly_idx = np.random.choice(n_samples, size=n_anomalies, replace=False)
    data[anomaly_idx] += anomaly_scale * np.random.randn(
        n_anomalies, seq_length, n_features
    )
    # Score each sequence by its mean squared magnitude
    anomaly_scores = np.mean(data ** 2, axis=(1, 2))
    return mx.array(data), mx.array(anomaly_scores)

# Create and train model
data, anomaly_scores = generate_anomaly_data()
model = CfC(AutoNCP(units=100, output_size=3, sparsity_level=0.5))


def loss_fn(model, x):
    """Reconstruction MSE between ``model(x)`` and ``x``."""
    pred = model(x)
    return mx.mean((pred - x) ** 2)


# Train model (full batch). nn.value_and_grad differentiates with respect to
# the model's trainable parameters; mx.value_and_grad takes a plain function
# (its second argument is ``argnums``) and cannot be used with a Module here.
optimizer = optim.Adam(learning_rate=0.001)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
for epoch in range(50):
    loss, grads = loss_and_grad_fn(model, data)
    optimizer.update(model, grads)
    # MLX evaluates lazily; materialize the updated parameters and optimizer
    # state each step so the computation graph does not keep growing.
    mx.eval(model.parameters(), optimizer.state)
    if epoch % 10 == 0:
        print(f"epoch {epoch:2d}  loss {loss.item():.4f}")

# Analyze anomalies
analyze_anomalies(model, data, anomaly_scores)

## 4. Real-time Control

Visualize control system behavior:

In [None]:
def analyze_control_system(model, states, actions, targets):
    """Plot a 2x2 diagnostic figure for a state -> control-signal model.

    Args:
        model: Callable mapping ``states`` to predicted control signals.
        states: Array indexed as ``[batch, time, state_dim]``. The first two
            state features are treated as (x, y) position, matching the
            ``[x, y, vx, vy]`` layout built by ``generate_control_data``.
        actions: True control signals, indexed as ``[batch, time, dim]``.
        targets: Target (x, y) positions, indexed as ``[batch, time, 2]``.
    """
    pred_actions = model(states)

    # Convert to NumPy once so matplotlib and the statistics below operate
    # on plain ndarrays.
    states_np = np.asarray(states)
    actions_np = np.asarray(actions)
    targets_np = np.asarray(targets)
    pred_np = np.asarray(pred_actions)

    # Only the position columns are comparable with the 2-D targets;
    # subtracting the full 4-feature state from targets fails to broadcast.
    positions = states_np[:, :, :2]

    plt.figure(figsize=(15, 10))

    # Plot trajectory of the first sequence against its target
    plt.subplot(221)
    plt.plot(positions[0, :, 0], positions[0, :, 1], label='Actual')
    plt.plot(targets_np[0, :, 0], targets_np[0, :, 1], '--', label='Target')
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.title('System Trajectory')
    plt.legend()
    plt.grid(True)

    # Plot control signals (first control dimension)
    plt.subplot(222)
    plt.plot(actions_np[0, :, 0], label='True Control')
    plt.plot(pred_np[0, :, 0], '--', label='Predicted Control')
    plt.xlabel('Time Step')
    plt.ylabel('Control Signal')
    plt.title('Control Signals')
    plt.legend()
    plt.grid(True)

    # Plot step-to-step displacement arrows along the first trajectory.
    # angles/scale_units='xy' with scale=1 draws each arrow as the actual
    # displacement in data coordinates instead of an auto-scaled glyph.
    plt.subplot(223)
    xs = positions[0, :, 0]
    ys = positions[0, :, 1]
    plt.quiver(xs[:-1], ys[:-1], np.diff(xs), np.diff(ys),
               angles='xy', scale_units='xy', scale=1)
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.title('Phase Portrait')
    plt.grid(True)

    # Plot Euclidean tracking error over time: mean +/- one std over batch
    plt.subplot(224)
    error = np.linalg.norm(positions - targets_np, axis=2)
    mean_error = error.mean(axis=0)
    std_error = error.std(axis=0)
    steps = np.arange(error.shape[1])
    plt.plot(steps, mean_error)
    plt.fill_between(
        steps,
        mean_error - std_error,
        mean_error + std_error,
        alpha=0.3
    )
    plt.xlabel('Time Step')
    plt.ylabel('Error')
    plt.title('Tracking Error')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# Generate control data
def generate_control_data(n_samples=1000, seq_length=50):
    """Generate noisy circular trajectories with derived velocities/controls.

    Each sample follows a circle of radius ``1.0 + 0.1 * N(0, 1)`` with a
    random phase, over ``t`` in ``[0, 2*pi]``. Observed positions add
    ``0.1 * N(0, 1)`` noise to the target circle. Velocities are
    ``np.gradient`` of the noisy positions, and the control actions are
    ``np.gradient`` of those velocities.

    NOTE(review): differentiating noisy positions twice amplifies the noise,
    so the actions are likely dominated by it rather than by the circle.

    Returns:
        ``(states, actions, targets)`` as MLX arrays of shapes
        ``(n_samples, seq_length, 4)`` ``[x, y, vx, vy]``,
        ``(n_samples, seq_length, 2)`` ``[ax, ay]`` and
        ``(n_samples, seq_length, 2)`` ``[x, y]``.
    """
    t = np.linspace(0, 2 * np.pi, seq_length)
    states = np.zeros((n_samples, seq_length, 4))  # [x, y, vx, vy]
    actions = np.zeros((n_samples, seq_length, 2))  # [ax, ay]
    targets = np.zeros((n_samples, seq_length, 2))  # [x, y]

    for i in range(n_samples):
        radius = 1.0 + 0.1 * np.random.randn()
        angle = t + 2 * np.pi * np.random.rand()

        # Target circle
        targets[i, :, 0] = radius * np.cos(angle)
        targets[i, :, 1] = radius * np.sin(angle)

        # Per axis: noisy position, its velocity, and its acceleration.
        # The x noise is drawn before the y noise, as required for a fixed
        # seed to reproduce the same data.
        for axis in range(2):
            position = targets[i, :, axis] + 0.1 * np.random.randn(seq_length)
            velocity = np.gradient(position, t)
            states[i, :, axis] = position
            states[i, :, 2 + axis] = velocity
            actions[i, :, axis] = np.gradient(velocity, t)

    return mx.array(states), mx.array(actions), mx.array(targets)

# Create and train model
states, actions, targets = generate_control_data()
model = CfC(NCP(
    inter_neurons=50,
    command_neurons=30,
    motor_neurons=2,
    sensory_fanout=3,
    inter_fanout=3,
    recurrent_command_synapses=5,
    motor_fanin=3
))


def loss_fn(model, x, y):
    """Mean squared error between ``model(x)`` and ``y``."""
    pred = model(x)
    return mx.mean((pred - y) ** 2)


# Train model (full batch). nn.value_and_grad differentiates with respect to
# the model's trainable parameters; mx.value_and_grad takes a plain function
# (its second argument is ``argnums``) and cannot be used with a Module here.
optimizer = optim.Adam(learning_rate=0.001)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
for epoch in range(50):
    loss, grads = loss_and_grad_fn(model, states, actions)
    optimizer.update(model, grads)
    # MLX evaluates lazily; materialize the updated parameters and optimizer
    # state each step so the computation graph does not keep growing.
    mx.eval(model.parameters(), optimizer.state)
    if epoch % 10 == 0:
        print(f"epoch {epoch:2d}  loss {loss.item():.4f}")

# Analyze control system
analyze_control_system(model, states, actions, targets)

## Visualization Insights

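After running the cells above, check each figure for the following:

1. **Time Series**
   - Temporal dependencies: does the prediction error stay low and stable across time steps?
   - Prediction accuracy: does the predicted sequence track the target's phase and amplitude?
   - Frequency patterns: do the input and prediction spectra peak at the same frequency?

2. **Reinforcement Learning**
   - State-action mapping: is there visible structure? (The synthetic actions here are random noise independent of the state, so no real mapping can be learned.)
   - Reward landscape: how does the immediate reward vary across the state space? (This is the per-step reward, not a learned value function.)
   - Action distribution: how closely do predicted actions follow true actions in the 2-D histogram?

3. **Anomaly Detection**
   - Anomaly patterns: do the anomalous sequences stand out from the normal ones?
   - Reconstruction: where does the highest-scored sample fall in the reconstruction-error distribution?
   - Feature importance: which features contribute most to the reconstruction error?

4. **Control Systems**
   - Tracking: how closely does the trajectory follow the target circle?
   - Control signals: how smooth are they? (The synthetic actions are second derivatives of noisy positions, so expect them to be noisy.)
   - Error convergence: does the tracking error stay bounded over time?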

Recommendations:
- Use task-specific visualizations
- Monitor multiple aspects
- Analyze temporal patterns
- Track system behavior