# PyTau New Scikit-Learn Style API

This notebook demonstrates the new scikit-learn style API for PyTau, which works directly with NumPy arrays instead of requiring HDF5 files.

## Key Features
- Direct numpy array input (no HDF5 required)
- Scikit-learn style `fit()` and `predict()` methods
- Simplified parameter handling
- Backward compatibility maintained

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys

# Make the local pytau checkout importable (path is relative to the notebook's working directory)
sys.path.append('../../')

from pytau import ChangepointDetector, fit_changepoint_model

## Generate Synthetic Data

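Create synthetic spike counts with known changepoints to test the model: 25 trials of 8 neurons over 100 time bins, with Poisson rates of 1.0, 4.0, and 2.0 in three successive states separated at bins 30 and 70.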

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Parameters
n_trials = 25
n_neurons = 8
n_timepoints = 100
true_changepoints = [30, 70]  # True changepoint locations

# Generate synthetic data with different firing rates in each state
spike_data = np.zeros((n_trials, n_neurons, n_timepoints))

cp1, cp2 = true_changepoints

for trial in range(n_trials):
    for neuron in range(n_neurons):
        # State 1: low firing rate
        spike_data[trial, neuron, :cp1] = np.random.poisson(1.0, size=cp1)

        # State 2: high firing rate
        spike_data[trial, neuron, cp1:cp2] = np.random.poisson(4.0, size=cp2 - cp1)

        # State 3: medium firing rate
        spike_data[trial, neuron, cp2:] = np.random.poisson(2.0, size=n_timepoints - cp2)

print(f"Generated spike data shape: {spike_data.shape}")
print(f"True changepoints: {true_changepoints}")
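
The nested loops above can also be written as one vectorized draw, which makes the state layout explicit and allows a quick check that each segment's empirical mean lands near its intended rate. This is a minimal sketch using the same sizes and rates as the cell above. The names `state_rates`, `rate_per_bin`, and `spike_data_vec` are illustrative, and because it uses `np.random.default_rng` the samples will not match the loop version draw for draw.

```python
import numpy as np

rng = np.random.default_rng(42)

n_trials, n_neurons, n_timepoints = 25, 8, 100
true_changepoints = [30, 70]
state_rates = np.array([1.0, 4.0, 2.0])  # low, high, medium

# Segment boundaries [0, 30, 70, 100] give state durations [30, 40, 30]
edges = np.array([0, *true_changepoints, n_timepoints])
rate_per_bin = np.repeat(state_rates, np.diff(edges))  # shape (n_timepoints,)

# The per-bin rate vector broadcasts over the trial and neuron axes
spike_data_vec = rng.poisson(rate_per_bin, size=(n_trials, n_neurons, n_timepoints))

# Empirical mean count in each state, pooled over trials and neurons
segment_means = [
    spike_data_vec[..., start:stop].mean()
    for start, stop in zip(edges[:-1], edges[1:])
]
print(np.round(segment_means, 2))
```

With several thousand samples per segment, the printed means should sit close to 1.0, 4.0, and 2.0. A large deviation would point to a slicing mistake in the generator.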

## Visualize the Data

In [None]:
# Plot average firing rate across trials
avg_firing_rate = np.mean(spike_data, axis=(0, 1))

plt.figure(figsize=(12, 6))
plt.plot(avg_firing_rate, 'b-', linewidth=2, label='Average firing rate')
for cp in true_changepoints:
    plt.axvline(cp, color='red', linestyle='--', alpha=0.7, label=f'True changepoint: {cp}')
plt.xlabel('Time')
plt.ylabel('Average Firing Rate')
plt.title('Synthetic Spike Data with True Changepoints')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## Method 1: Using ChangepointDetector Class

In [None]:
# Create and configure the detector
detector = ChangepointDetector(
    model_type='single_taste_poisson',
    n_states=3,  # We know there are 3 states
    inference_method='advi',
    n_iterations=5000,
    random_state=42
)

print("Fitting changepoint model...")
detector.fit(spike_data)
print("Model fitted successfully!")

In [None]:
# Get predictions
predictions = detector.predict()

print("Predictions:")
for key, value in predictions.items():
    if isinstance(value, np.ndarray):
        print(f"{key}: shape {value.shape}")
        if key == 'changepoints':
            print(f"  Estimated changepoints: {value}")
            print(f"  True changepoints: {true_changepoints}")
            print(f"  Error: {np.abs(value - true_changepoints)}")
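
The error line above relies on the estimated changepoints broadcasting against `true_changepoints`. This notebook does not state the shape of the `changepoints` array, so the helper below is a hedged sketch that accepts either a single estimate per changepoint or one row of estimates per trial. `changepoint_errors` is a hypothetical name, not part of PyTau, and the input values are invented for illustration.

```python
import numpy as np

def changepoint_errors(estimated, true):
    """Mean absolute error per changepoint, averaged over any leading axes."""
    estimated = np.atleast_1d(np.asarray(estimated, dtype=float))
    true = np.atleast_1d(np.asarray(true, dtype=float))
    if estimated.shape[-1] != true.shape[0]:
        raise ValueError(
            f"Expected {true.shape[0]} changepoints on the last axis, "
            f"got {estimated.shape[-1]}"
        )
    abs_err = np.abs(estimated - true)
    # Collapse leading axes (e.g. trials) so one error remains per changepoint
    return abs_err.reshape(-1, true.shape[0]).mean(axis=0)

# Illustrative inputs only, not output from a fitted model
single = changepoint_errors([31.2, 68.5], [30, 70])
per_trial = changepoint_errors([[29, 71], [32, 69]], [30, 70])
print(single, per_trial)
```

Raising on a shape mismatch avoids silently broadcasting estimates against the wrong number of true changepoints.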

## Method 2: Using Convenience Function

In [None]:
# Alternative: Use convenience function for quick fitting
print("Fitting model using convenience function...")
quick_detector = fit_changepoint_model(
    spike_data,
    model_type='single_taste_poisson',
    n_states=3,
    n_iterations=3000,
    random_state=42
)

quick_predictions = quick_detector.predict()
print(f"Quick fit changepoints: {quick_predictions['changepoints']}")

## Visualize Results

In [None]:
# Plot results
plt.figure(figsize=(15, 8))

# Plot 1: Data and changepoints
plt.subplot(2, 1, 1)
plt.plot(avg_firing_rate, 'b-', linewidth=2, label='Average firing rate')

# True changepoints
for i, cp in enumerate(true_changepoints):
    plt.axvline(cp, color='red', linestyle='--', alpha=0.7, 
                label='True changepoint' if i == 0 else '')

# Estimated changepoints
if 'changepoints' in predictions:
    for i, cp in enumerate(predictions['changepoints']):
        plt.axvline(cp, color='green', linestyle='-', alpha=0.7,
                    label='Estimated changepoint' if i == 0 else '')

plt.xlabel('Time')
plt.ylabel('Firing Rate')
plt.title('Changepoint Detection Results')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: State sequence
plt.subplot(2, 1, 2)
if 'states' in predictions:
    plt.plot(predictions['states'], 'g-', linewidth=2, label='Predicted states')
    plt.ylabel('State')
    plt.xlabel('Time')
    plt.title('Predicted State Sequence')
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
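
The second panel is drawn only when the predictions contain a `states` entry. If only changepoint positions are available, a step-like state sequence can be derived from them. The sketch below assumes each changepoint marks the first time bin of the next state, matching how the synthetic data was generated. `states_from_changepoints` is a hypothetical helper, not a PyTau function.

```python
import numpy as np

def states_from_changepoints(changepoints, n_timepoints):
    """Return the state index of every time bin given sorted changepoint positions."""
    cps = np.sort(np.asarray(changepoints, dtype=float))
    # side="right" places a bin that equals a changepoint in the following state
    return np.searchsorted(cps, np.arange(n_timepoints), side="right")

states = states_from_changepoints([30, 70], 100)
print(states[[0, 29, 30, 69, 70, 99]])  # [0 0 1 1 2 2]
```

The resulting array can be passed to `plt.plot` in place of `predictions['states']`.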

## Model Comparison

Let's compare different model types on the same data.

In [None]:
# Compare different model types
model_types = ['single_taste_poisson', 'single_taste_poisson_varsig']
results = {}

for model_type in model_types:
    print(f"\nFitting {model_type}...")
    try:
        # Use loop-local names so the detector and predictions from Method 1 are not overwritten
        model = ChangepointDetector(
            model_type=model_type,
            n_states=3,
            n_iterations=3000,
            random_state=42
        )
        model.fit(spike_data)
        model_predictions = model.predict()

        if 'changepoints' in model_predictions:
            estimated = model_predictions['changepoints']
            results[model_type] = estimated
            error = np.mean(np.abs(estimated - true_changepoints))
            print(f"  Estimated changepoints: {estimated}")
            print(f"  Mean absolute error: {error:.2f}")
        else:
            print("  No changepoints found in predictions")

    except Exception as e:
        print(f"  Error fitting {model_type}: {e}")

print(f"\nTrue changepoints: {true_changepoints}")

## Parameter Exploration

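The detector follows the scikit-learn parameter convention: `get_params()` returns the current settings as a dictionary, and `set_params()` updates them by keyword.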

In [None]:
# Create a fresh detector with default settings (kept separate from the fitted detector)
param_detector = ChangepointDetector()

# Get current parameters
params = param_detector.get_params()
print("Default parameters:")
for key, value in params.items():
    print(f"  {key}: {value}")

# Update parameters
param_detector.set_params(n_states=4, n_iterations=2000)
print("\nUpdated parameters:")
updated_params = param_detector.get_params()
for key, value in updated_params.items():
    print(f"  {key}: {value}")
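
For readers new to this convention, the toy class below sketches how `get_params()` and `set_params()` typically behave in scikit-learn style estimators: constructor arguments are stored unchanged as attributes, read back by name, and updated with validation. `ToyDetector` is a hypothetical stand-in written for this illustration. It is not PyTau's implementation, and its default values are assumptions.

```python
import inspect

class ToyDetector:
    """Hypothetical estimator illustrating the parameter convention only."""

    def __init__(self, n_states=3, n_iterations=5000, random_state=None):
        # Store constructor arguments as-is, under the same names
        self.n_states = n_states
        self.n_iterations = n_iterations
        self.random_state = random_state

    def get_params(self):
        # Read parameter names from the constructor signature
        names = inspect.signature(self.__init__).parameters
        return {name: getattr(self, name) for name in names}

    def set_params(self, **params):
        valid = self.get_params()
        for name, value in params.items():
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r}")
            setattr(self, name, value)
        return self  # returning self allows call chaining

toy = ToyDetector().set_params(n_states=4, n_iterations=2000)
print(toy.get_params())
# {'n_states': 4, 'n_iterations': 2000, 'random_state': None}
```

Rejecting unknown names means a typo such as `n_state=4` fails immediately instead of silently creating an unused attribute.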

## Summary

The new PyTau API provides:

1. **Simple Interface**: Pass NumPy arrays directly to `fit()`, then call `predict()`
2. **Scikit-learn Compatibility**: Familiar methods like `get_params()`, `set_params()`, `score()`
3. **No HDF5 Dependency**: Works with in-memory NumPy arrays, such as the `(trials, neurons, timepoints)` array used here
4. **Flexible Models**: Support for multiple changepoint model types
5. **Backward Compatibility**: Original API still available

This lets researchers apply changepoint detection to neural data already held in memory, without first converting it to HDF5 or building a preprocessing pipeline.