# Manifold Learning for Neural Population Dynamics

This tutorial demonstrates how to use MEA-Flow's manifold learning tools to analyze neural population dynamics and state-space trajectories.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mea_flow.data import SpikeList
from mea_flow.manifold import ManifoldAnalysis
from mea_flow.visualization import MEAPlotter

## 1. Generate Synthetic Population Data

In [None]:
# Create a synthetic population whose firing rates share slow temporal
# structure, so the manifold methods below have a trajectory to recover
# (uniformly scattered spike times would only produce a structureless cloud).
np.random.seed(42)  # also seeds any code that draws from the global NumPy RNG
n_channels = 32
duration = 120.0  # 2 minutes
latent_period = 40.0  # seconds per cycle of the shared slow modulation

spike_data = {}
for ch in range(1, n_channels + 1):
    # Each channel has its own baseline rate, modulation depth and phase
    # relative to one shared oscillation:
    #   r(t) = base_rate * (1 + depth * sin(2*pi*t / latent_period + phase))
    # Because phases differ across channels, population activity should trace
    # a loop in state space (about duration / latent_period cycles).
    base_rate = 2.0 + np.random.exponential(3.0)
    depth = np.random.uniform(0.3, 0.9)
    phase = np.random.uniform(0.0, 2 * np.pi)
    peak_rate = base_rate * (1.0 + depth)

    # Inhomogeneous Poisson spikes via thinning: draw candidate spikes at the
    # peak rate, then keep each one with probability r(t) / peak_rate.
    n_candidates = np.random.poisson(peak_rate * duration)
    candidates = np.sort(np.random.uniform(0.0, duration, n_candidates))
    rate_t = base_rate * (
        1.0 + depth * np.sin(2 * np.pi * candidates / latent_period + phase)
    )
    keep = np.random.uniform(0.0, peak_rate, n_candidates) < rate_t
    spike_data[ch] = candidates[keep]  # still sorted: subset of a sorted array

spike_list = SpikeList(spike_data, recording_length=duration)
print(f"Created population data: {len(spike_list.get_active_channels())} channels, "
      f"{sum(train.n_spikes for train in spike_list.spike_trains.values())} spikes")

## 2. Basic Manifold Analysis

In [None]:
# Fit the basic embeddings (PCA and MDS) to the binned population activity
# and report the shape of each embedding.
manifold = ManifoldAnalysis()

basic_methods = ['pca', 'mds']  # Start with basic methods
results = manifold.analyze_population_dynamics(
    spike_list,
    methods=basic_methods,
    bin_size=1.0,  # 1-second bins
    n_components=3,
)

print("Manifold analysis complete:")
for method_name, method_result in results.items():
    print(f"  {method_name.upper()}: {method_result['embedding'].shape}")

## 3. Visualize Embeddings

In [None]:
# Side-by-side 2D views (first two components) of each basic embedding,
# coloured by time bin so the temporal progression is visible.
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

for ax, method in zip(axes, ['pca', 'mds']):
    label = method.upper()
    embedding = results[method]['embedding']
    time_points = np.arange(len(embedding))

    scatter = ax.scatter(embedding[:, 0], embedding[:, 1],
                         c=time_points, cmap='viridis', alpha=0.7)
    ax.set_xlabel(f'{label} Component 1')
    ax.set_ylabel(f'{label} Component 2')
    ax.set_title(f'{label} Embedding')
    plt.colorbar(scatter, ax=ax, label='Time (bins)')

plt.tight_layout()
plt.show()

## 4. Advanced Methods (Optional Dependencies)

In [None]:
# Try advanced methods if available (UMAP / t-SNE need optional dependencies).
advanced_methods = ['pca', 'umap', 'tsne']
try:
    # Guard only the analysis call, so that unrelated plotting errors below
    # are not misreported as missing dependencies.
    advanced_results = manifold.analyze_population_dynamics(
        spike_list,
        methods=advanced_methods,
        bin_size=1.0,
        n_components=2
    )
except Exception as e:
    print(f"Advanced methods not available: {e}")
    print("Install with: pip install umap-learn")
else:
    print("Advanced methods available")

    # Plot comparison. A method missing from the results gets a labelled
    # placeholder panel instead of an unexplained blank axis.
    fig, axes = plt.subplots(1, len(advanced_methods), figsize=(18, 5))
    for ax, method in zip(axes, advanced_methods):
        label = method.upper()
        ax.set_title(f'{label} Embedding')
        if method not in advanced_results:
            ax.text(0.5, 0.5, 'not available', ha='center', va='center',
                    transform=ax.transAxes)
            ax.set_axis_off()
            continue

        embedding = advanced_results[method]['embedding']
        time_points = np.arange(len(embedding))
        scatter = ax.scatter(embedding[:, 0], embedding[:, 1],
                             c=time_points, cmap='plasma', alpha=0.7)
        ax.set_xlabel(f'{label} Component 1')
        ax.set_ylabel(f'{label} Component 2')
        plt.colorbar(scatter, ax=ax, label='Time (bins)')

    plt.tight_layout()
    plt.show()

## 5. Cross-Condition Comparison

In [None]:
# Create two conditions for comparison
def create_condition_data(base_rate_multiplier=1.0, seed=42, n_channels=16,
                          duration=60.0):
    """Generate a synthetic SpikeList for one experimental condition.

    Each channel gets a baseline rate of ``2 + Exponential(scale=2.0)`` Hz,
    multiplied by ``base_rate_multiplier``. It then gets
    ``int(rate * duration)`` spike times drawn uniformly over
    ``[0, duration)`` and sorted.

    Parameters
    ----------
    base_rate_multiplier : float
        Factor applied to every channel's baseline rate (e.g. 1.5 for a
        higher-activity condition).
    seed : int
        Seed for a private random generator. The global NumPy RNG is left
        untouched, so calling this does not perturb other cells.
    n_channels : int
        Number of channels, keyed 1..n_channels (default 16, small for demo).
    duration : float
        Recording length in seconds (default 60.0, i.e. 1 minute).

    Returns
    -------
    SpikeList
        Spike trains keyed by channel number, with
        ``recording_length=duration``.
    """
    # RandomState(seed) yields the same stream as np.random.seed(seed)
    # followed by np.random.* calls. The generated data therefore match the
    # previous global-seed version without the global side effect.
    rng = np.random.RandomState(seed)
    spike_data = {}
    for ch in range(1, n_channels + 1):
        base_rate = (2.0 + rng.exponential(2.0)) * base_rate_multiplier
        n_spikes = int(base_rate * duration)
        spike_data[ch] = np.sort(rng.uniform(0, duration, n_spikes))
    return SpikeList(spike_data, recording_length=duration)

# Build both conditions in order: a control and a treatment whose baseline
# rates are scaled by 1.5 (higher activity).
conditions = {
    'control': create_condition_data(1.0, seed=42),
    'treatment': create_condition_data(1.5, seed=123),  # Higher activity
}
control, treatment = conditions['control'], conditions['treatment']
print("Created condition data for comparison")

In [None]:
# Compare the PCA manifolds of the two conditions using Procrustes alignment,
# then print the alignment score if the result provides one.
comparison_kwargs = dict(
    methods=['pca'],
    bin_size=2.0,
    alignment='procrustes',
)
comparison = manifold.compare_conditions(conditions, **comparison_kwargs)

print("Cross-condition comparison complete")
alignment_score = comparison.get('alignment_score', 'N/A')
print(f"Alignment score: {alignment_score}")

## 6. Embedding Quality Assessment

In [None]:
# Evaluate embedding quality for the basic methods and print every metric.
import numbers

quality_metrics = manifold.evaluate_embedding_quality(
    spike_list,
    methods=['pca', 'mds'],
    bin_size=1.0
)

print("Embedding Quality Metrics:")
for method, metrics in quality_metrics.items():
    print(f"\n{method.upper()}:")
    for metric, value in metrics.items():
        # numbers.Real also matches NumPy scalars such as np.float32 / np.int64,
        # which are not subclasses of Python float/int. Bools are excluded so
        # they print as True/False rather than 1.000/0.000.
        if isinstance(value, numbers.Real) and not isinstance(value, bool):
            print(f"  {metric}: {value:.3f}")
        else:
            print(f"  {metric}: {value}")

## 7. Trajectory Analysis

In [None]:
# Visualize the PCA state-space trajectory (first two components), with
# arrows showing the temporal progression between time bins.
pca_embedding = results['pca']['embedding']
arrow_step = 5  # one arrow per 5 bins to avoid clutter

fig, ax = plt.subplots(figsize=(10, 8))

# Plot trajectory points coloured by time bin
time_points = np.arange(len(pca_embedding))
scatter = ax.scatter(pca_embedding[:, 0], pca_embedding[:, 1],
                     c=time_points, cmap='viridis', s=50, alpha=0.7)

# Arrows from bin i to bin i + arrow_step. With angles/scale_units='xy' and
# scale=1, quiver ends each arrow exactly on its target point, and its head
# size is relative to the shaft width rather than fixed in data units. The
# heads therefore stay visible whatever the scale of the PCA coordinates.
# (ax.arrow with head_width=0.1 does not, and by default it overshoots the
# target by the head length.)
starts = pca_embedding[:-arrow_step:arrow_step, :2]
ends = pca_embedding[arrow_step::arrow_step, :2]
if len(starts):
    deltas = ends - starts
    ax.quiver(starts[:, 0], starts[:, 1], deltas[:, 0], deltas[:, 1],
              angles='xy', scale_units='xy', scale=1,
              color='red', alpha=0.5, width=0.003)

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('Neural Population State Space Trajectory')
plt.colorbar(scatter, ax=ax, label='Time (bins)')
plt.tight_layout()
plt.show()

print("✓ Trajectory visualization complete")

## Summary

This tutorial covered:

✅ **Basic manifold analysis** with PCA and MDS  
✅ **Advanced methods** (UMAP, t-SNE) when available  
✅ **Cross-condition comparison** with Procrustes alignment  
✅ **Embedding quality assessment** with multiple metrics  
✅ **Trajectory visualization** in state space  

## Next Steps

1. Try with your own MEA data
2. Experiment with different bin sizes and time windows
3. Compare multiple experimental conditions
4. Explore the [Cross-Condition Analysis](04_cross_condition_analysis.ipynb) tutorial