# SingleTastePoisson Model Demo

This notebook demonstrates the `SingleTastePoisson` model for detecting changepoints in Poisson-distributed data (e.g., spike counts) for a single condition/taste.

## Model Overview
- **Data Type**: Poisson distributed data (spike counts)
- **Input Shape**: 3D array (neurons × trials × time)
- **Scope**: Single taste/condition
- **Use Case**: Neural spike train changepoint detection for one experimental condition

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys

# Add pytau to path
sys.path.append('../../')

from pytau.changepoint_model import SingleTastePoisson, gen_test_array
import pymc as pm

## Generate Synthetic Data

Create synthetic Poisson spike count data with known changepoints.
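
For intuition about what such data looks like, here is a minimal NumPy sketch of a hand-rolled generator. Each neuron gets one rate per state, and counts are drawn from a Poisson distribution whose rate switches at fixed changepoints. The helper `simulate_poisson_changepoints` and its arguments are hypothetical and are not part of pytau; `gen_test_array` may choose rates and place changepoints differently.

```python
import numpy as np

def simulate_poisson_changepoints(rates, changepoints, n_trials, n_timepoints, seed=0):
    """Draw Poisson counts with piecewise-constant rates (illustrative helper).

    rates        : (n_neurons, n_states) array of rates in spikes/bin
    changepoints : increasing bin indices at which the state switches
    Returns an integer array of shape (n_neurons, n_trials, n_timepoints).
    """
    rng = np.random.default_rng(seed)
    rates = np.asarray(rates, dtype=float)
    # State index of every bin: 0 before the first changepoint, 1 from it onward, ...
    state_of_bin = np.searchsorted(changepoints, np.arange(n_timepoints), side="right")
    # Rate for each neuron and bin, shape (n_neurons, n_timepoints)
    rate_per_bin = rates[:, state_of_bin]
    # Repeat the same rate profile for every trial, then sample counts
    lam = np.broadcast_to(rate_per_bin[:, None, :],
                          (rates.shape[0], n_trials, n_timepoints))
    return rng.poisson(lam)

# Two neurons, three states, changepoints at bins 20 and 50 (arbitrary values)
demo_rates = np.array([[1.0, 5.0, 2.0],
                       [4.0, 0.5, 3.0]])
demo_counts = simulate_poisson_changepoints(demo_rates, [20, 50],
                                            n_trials=10, n_timepoints=80)
print(demo_counts.shape)
```

Because the changepoints are passed in explicitly, data generated this way has a ground truth that can be compared directly with the inferred values.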

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Parameters
n_neurons = 6
n_trials = 15
n_timepoints = 80
n_states = 3

# Generate test data using the built-in function
data_array = gen_test_array((n_neurons, n_trials, n_timepoints), n_states=n_states, type="poisson")

print(f"Data shape: {data_array.shape}")
print(f"Data range: [{data_array.min()}, {data_array.max()}]")
print(f"Mean firing rate: {data_array.mean():.2f} spikes/bin")

## Visualize Input Data

Plot the synthetic spike count data to visualize the changepoints.
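
Count traces can be noisy, which makes rate steps hard to see by eye. One option, sketched here with plain NumPy, is to apply a moving average along the time axis before plotting. The helper `smooth_rates` and the window length are hypothetical choices, not part of pytau, and the example data is synthetic.

```python
import numpy as np

def smooth_rates(rates, window=5):
    """Moving average along the last (time) axis; output has the input's shape.

    Note: mode="same" zero-pads, so values within window // 2 bins of either
    edge are biased toward zero.
    """
    kernel = np.ones(window) / window
    return np.apply_along_axis(
        lambda row: np.convolve(row, kernel, mode="same"), -1, rates
    )

rng = np.random.default_rng(0)
# Synthetic traces for 2 neurons and 80 bins, with a rate step at bin 40
step_profile = np.r_[np.full(40, 2.0), np.full(40, 6.0)]
noisy = rng.poisson(step_profile, size=(2, 80)).astype(float)
smoothed = smooth_rates(noisy, window=5)
print(noisy.shape, smoothed.shape)
```

A wider window gives smoother traces but blurs the transition, so it is best kept well below the expected state duration.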

In [None]:
# Plot average firing rates across trials
mean_rates = np.mean(data_array, axis=1)  # Average across trials

fig, axes = plt.subplots(n_neurons, 1, figsize=(12, 2*n_neurons), sharex=True)
if n_neurons == 1:
    axes = [axes]

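# Expected changepoint locations, assuming states of equal duration.
# gen_test_array may place transitions elsewhere, so treat these as a reference grid.
changepoint_locs = np.linspace(0, n_timepoints, n_states+1)[1:-1]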

for i in range(n_neurons):
    axes[i].plot(mean_rates[i, :], 'b-', alpha=0.7, linewidth=2)
    axes[i].set_ylabel(f'Neuron {i+1}\nSpikes/bin')
    axes[i].grid(True, alpha=0.3)
    
    # Add vertical lines at expected changepoints
    for cp in changepoint_locs:
        axes[i].axvline(cp, color='red', linestyle='--', alpha=0.5)

axes[-1].set_xlabel('Time (bins)')
plt.suptitle('Synthetic Poisson Spike Data (Trial-Averaged)\n(Red dashed lines show true changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Plot raster plot for first few neurons and trials
n_show_neurons = min(3, n_neurons)
n_show_trials = min(5, n_trials)

fig, axes = plt.subplots(n_show_neurons, 1, figsize=(12, 2*n_show_neurons), sharex=True)
if n_show_neurons == 1:
    axes = [axes]

for neuron in range(n_show_neurons):
    for trial in range(n_show_trials):
        spike_times = np.where(data_array[neuron, trial, :] > 0)[0]
        spike_counts = data_array[neuron, trial, spike_times]
        
        # Plot spikes as vertical lines with height proportional to count
        for time, count in zip(spike_times, spike_counts):
            axes[neuron].vlines(time, trial, trial + count/data_array.max(), 
                              colors='black', alpha=0.7)
    
    axes[neuron].set_ylabel(f'Neuron {neuron+1}\nTrial')
    axes[neuron].set_ylim(0, n_show_trials + 1)
    axes[neuron].grid(True, alpha=0.3)
    
    # Add changepoints
    for cp in changepoint_locs:
        axes[neuron].axvline(cp, color='red', linestyle='--', alpha=0.5)

axes[-1].set_xlabel('Time (bins)')
plt.suptitle(f'Spike Raster Plot (First {n_show_neurons} neurons, {n_show_trials} trials)', fontsize=14)
plt.tight_layout()
plt.show()

## Create and Fit Model

Initialize the `SingleTastePoisson` model and fit it to the data.
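
The fitting cell below uses ADVI, which minimizes a loss (the negative ELBO) for a fixed number of iterations. A rough convergence check is to compare the mean loss over the last window of iterations with the window before it. The sketch below works on any 1-D loss array. In PyMC the loss history is typically available as `approx.hist` after `pm.fit`, but here a synthetic curve stands in for it. The helper `relative_loss_change` and the 1% tolerance are hypothetical choices.

```python
import numpy as np

def relative_loss_change(loss_history, window=500):
    """Relative change between the mean loss of the last two windows."""
    loss_history = np.asarray(loss_history, dtype=float)
    if loss_history.size < 2 * window:
        raise ValueError("loss history is shorter than two windows")
    previous = loss_history[-2 * window:-window].mean()
    latest = loss_history[-window:].mean()
    return abs(latest - previous) / abs(previous)

# Synthetic stand-in for a loss curve: exponential decay to a plateau plus noise
rng = np.random.default_rng(0)
steps = np.arange(6000)
demo_loss = 1000.0 + 5000.0 * np.exp(-steps / 500.0) + rng.normal(0.0, 5.0, steps.size)

change = relative_loss_change(demo_loss)
print(f"Relative change over the last two windows: {change:.4f}")
print("Looks converged" if change < 0.01 else "Consider more iterations")
```

If the change is still large, increasing `n` in `pm.fit` is the simplest remedy.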

In [None]:
# Create model instance
model_instance = SingleTastePoisson(data_array, n_states=n_states)

# Generate the PyMC model
model = model_instance.generate_model()

print("Model created successfully!")
print(f"Model variables: {[var.name for var in model.unobserved_RVs]}")

In [None]:
# Fit the model using variational inference
with model:
    # ADVI is typically faster than MCMC sampling, but it only approximates the posterior
    inference = pm.ADVI()
    approx = pm.fit(n=6000, method=inference)
    
    # Sample from the approximation
    trace = approx.sample(draws=1000)

print("Model fitting completed!")

## Analyze Results

Extract and visualize the inferred changepoints and firing rate parameters.
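
A mean and standard deviation summarize the posterior well only when it is roughly symmetric. A central credible interval is a common complement. The sketch below computes one per changepoint with `np.percentile`. The helper `summarize_samples` is hypothetical, and the draws are synthetic stand-ins, not output of the model.

```python
import numpy as np

def summarize_samples(samples, interval=0.95):
    """Mean and central credible interval for each column of `samples`."""
    samples = np.asarray(samples, dtype=float)
    tail = 100.0 * (1.0 - interval) / 2.0
    lower, upper = np.percentile(samples, [tail, 100.0 - tail], axis=0)
    return samples.mean(axis=0), lower, upper

# Synthetic posterior draws for two changepoints, shape (n_draws, n_changepoints)
rng = np.random.default_rng(0)
demo_tau = np.column_stack([rng.normal(27.0, 1.5, 1000),
                            rng.normal(53.0, 2.0, 1000)])

mean, lower, upper = summarize_samples(demo_tau)
for i, (m, lo, hi) in enumerate(zip(mean, lower, upper), start=1):
    print(f"Changepoint {i}: {m:.1f} (95% interval {lo:.1f} to {hi:.1f})")
```

The same function can be applied to an array shaped like `tau_samples` in the cell below, since it only assumes draws along the first axis.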

In [None]:
# Extract changepoint estimates
# reshape(-1, ...) pools every leading axis (chain, draw, and any further axes) into one
tau_samples = trace.posterior['tau'].values.reshape(-1, n_states-1)
tau_mean = np.mean(tau_samples, axis=0)
tau_std = np.std(tau_samples, axis=0)

print("Inferred Changepoints:")
for i, (mean_cp, std_cp) in enumerate(zip(tau_mean, tau_std)):
    print(f"  Changepoint {i+1}: {mean_cp:.2f} ± {std_cp:.2f}")

# Reference changepoints for comparison (evenly spaced, assuming equal-duration states)
true_changepoints = np.linspace(0, n_timepoints, n_states+1)[1:-1]
print(f"\nTrue Changepoints: {true_changepoints}")

In [None]:
# Plot changepoint posterior distributions
fig, axes = plt.subplots(1, n_states-1, figsize=(4*(n_states-1), 4))
if n_states-1 == 1:
    axes = [axes]

for i in range(n_states-1):
    axes[i].hist(tau_samples[:, i], bins=50, alpha=0.7, density=True)
    axes[i].axvline(tau_mean[i], color='red', linestyle='-', label=f'Mean: {tau_mean[i]:.1f}')
    axes[i].axvline(true_changepoints[i], color='green', linestyle='--', label=f'True: {true_changepoints[i]:.1f}')
    axes[i].set_xlabel('Time (bins)')
    axes[i].set_ylabel('Density')
    axes[i].set_title(f'Changepoint {i+1} Posterior')
    axes[i].legend()
    axes[i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Visualize Model Fit

Show the original data with inferred changepoints overlaid.

In [None]:
# Plot data with inferred changepoints
fig, axes = plt.subplots(n_neurons, 1, figsize=(12, 2*n_neurons), sharex=True)
if n_neurons == 1:
    axes = [axes]

for i in range(n_neurons):
    axes[i].plot(mean_rates[i, :], 'b-', alpha=0.7, linewidth=2, label='Data')
    axes[i].set_ylabel(f'Neuron {i+1}\nSpikes/bin')
    axes[i].grid(True, alpha=0.3)
    
    # Add inferred changepoints
    for j, cp in enumerate(tau_mean):
        axes[i].axvline(cp, color='red', linestyle='-', alpha=0.8, linewidth=2,
                       label='Inferred CP' if i == 0 and j == 0 else '')
    
    # Add true changepoints
    for j, cp in enumerate(true_changepoints):
        axes[i].axvline(cp, color='green', linestyle='--', alpha=0.5,
                       label='True CP' if i == 0 and j == 0 else '')

axes[0].legend()
axes[-1].set_xlabel('Time (bins)')
plt.suptitle('Single Taste Poisson Model Results\n(Red: Inferred, Green: True Changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

## Model Parameters

Examine the inferred firing rate parameters for each neuron and state.
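
The rates are expressed in spikes per bin, so converting them to Hz requires the bin width. The synthetic data here has no physical bin width, so the 50 ms value in this sketch is purely hypothetical, as is the helper `rates_to_hz`.

```python
import numpy as np

def rates_to_hz(rates_per_bin, bin_width_ms):
    """Convert rates in spikes/bin to spikes/second."""
    bins_per_second = 1000.0 / bin_width_ms
    return np.asarray(rates_per_bin, dtype=float) * bins_per_second

# Hypothetical per-bin rates for 2 neurons and 3 states
demo_lambda = np.array([[0.2, 0.8, 0.4],
                        [1.0, 0.5, 0.25]])
demo_hz = rates_to_hz(demo_lambda, bin_width_ms=50.0)
print(demo_hz)
```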

In [None]:
# Extract firing rate parameters
lambda_samples = trace.posterior['lambda'].values.reshape(-1, n_neurons, n_states)
lambda_mean = np.mean(lambda_samples, axis=0)
lambda_std = np.std(lambda_samples, axis=0)

print("Inferred Firing Rates by State:")
for state in range(n_states):
    print(f"\nState {state + 1}:")
    for neuron in range(n_neurons):
        print(f"  Neuron {neuron + 1}: λ = {lambda_mean[neuron, state]:.3f} ± {lambda_std[neuron, state]:.3f} spikes/bin")

In [None]:
# Plot firing rate evolution across states
fig, ax = plt.subplots(1, 1, figsize=(12, 8))

# Create a heatmap of firing rates
im = ax.imshow(lambda_mean, aspect='auto', cmap='viridis', interpolation='nearest')
ax.set_xlabel('State')
ax.set_ylabel('Neuron')
ax.set_title('Firing Rates by Neuron and State')
ax.set_xticks(range(n_states))
ax.set_xticklabels([f'State {i+1}' for i in range(n_states)])
ax.set_yticks(range(n_neurons))
ax.set_yticklabels([f'Neuron {i+1}' for i in range(n_neurons)])

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Firing Rate (spikes/bin)')

# Add text annotations
for i in range(n_neurons):
    for j in range(n_states):
        ax.text(j, i, f'{lambda_mean[i, j]:.2f}',
                ha="center", va="center", color="white", fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Plot firing rate trajectories for each neuron
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

for neuron in range(n_neurons):
    ax.plot(range(1, n_states+1), lambda_mean[neuron, :], 'o-', 
            label=f'Neuron {neuron+1}', linewidth=2, markersize=8)

ax.set_xlabel('State')
ax.set_ylabel('Firing Rate (spikes/bin)')
ax.set_title('Firing Rate Evolution Across States')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.grid(True, alpha=0.3)
ax.set_xticks(range(1, n_states+1))

plt.tight_layout()
plt.show()

## Model Validation

Compare model predictions with observed data.
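
The cell below reconstructs rates with a hard step function and notes that the actual model uses sigmoid transitions. As a sketch of a smoother alternative, each state's rate can be weighted by a sigmoid that rises at the preceding changepoint and one that falls at the following changepoint. The function name, the `sharpness` parameter, and the exact weighting are assumptions for illustration and may differ from pytau's internal parameterization.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_rate_reconstruction(lambdas, taus, n_timepoints, sharpness=1.0):
    """Blend per-state rates with sigmoid weights centered on the changepoints.

    lambdas : (n_neurons, n_states) rates in spikes/bin
    taus    : (n_states - 1,) changepoint positions in bins
    Returns an array of shape (n_neurons, n_timepoints).
    """
    lambdas = np.asarray(lambdas, dtype=float)
    taus = np.asarray(taus, dtype=float)
    t = np.arange(n_timepoints, dtype=float)
    # rise[k, t] approaches 1 once t has passed changepoint k
    rise = sigmoid(sharpness * (t[None, :] - taus[:, None]))
    ones = np.ones((1, n_timepoints))
    entered = np.vstack([ones, rise])          # state k has been entered
    not_left = np.vstack([1.0 - rise, ones])   # state k has not been left yet
    # Weights sum to roughly 1 per bin when changepoints are well separated
    weights = entered * not_left               # (n_states, n_timepoints)
    return lambdas @ weights

demo_lambda = np.array([[1.0, 5.0, 2.0]])
demo_pred = sigmoid_rate_reconstruction(demo_lambda, [20.0, 50.0], n_timepoints=80)
print(demo_pred.shape)
```

Larger `sharpness` values approach the step function, while smaller values spread each transition over more bins.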

In [None]:
# Generate predicted firing rates using inferred parameters
# This is a simplified reconstruction - the actual model uses sigmoid transitions

predicted_rates = np.zeros((n_neurons, n_timepoints))
time_points = np.arange(n_timepoints)

# Assign states based on changepoints (simplified step function)
state_assignments = np.zeros(n_timepoints, dtype=int)
changepoints_with_bounds = np.concatenate([[0], tau_mean, [n_timepoints]])

for i in range(n_states):
    start_idx = int(changepoints_with_bounds[i])
    end_idx = int(changepoints_with_bounds[i+1])
    state_assignments[start_idx:end_idx] = i

# Assign firing rates based on state
for t in range(n_timepoints):
    state = state_assignments[t]
    predicted_rates[:, t] = lambda_mean[:, state]

# Plot comparison
fig, axes = plt.subplots(min(4, n_neurons), 1, figsize=(12, 2*min(4, n_neurons)), sharex=True)
if min(4, n_neurons) == 1:
    axes = [axes]

for i in range(min(4, n_neurons)):
    axes[i].plot(mean_rates[i, :], 'b-', alpha=0.7, linewidth=2, label='Observed')
    axes[i].plot(predicted_rates[i, :], 'r--', alpha=0.8, linewidth=2, label='Predicted')
    axes[i].set_ylabel(f'Neuron {i+1}\nSpikes/bin')
    axes[i].grid(True, alpha=0.3)
    axes[i].legend()
    
    # Add changepoints
    for cp in tau_mean:
        axes[i].axvline(cp, color='gray', linestyle=':', alpha=0.5)

axes[-1].set_xlabel('Time (bins)')
plt.suptitle('Model Validation: Observed vs Predicted Firing Rates', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated the `SingleTastePoisson` model:

1. **Data Generation**: Created synthetic Poisson spike count data with known changepoints
2. **Data Visualization**: Showed trial-averaged firing rates and raster plots
3. **Model Fitting**: Used variational inference to fit the changepoint model
4. **Results Analysis**: Extracted and visualized changepoint estimates and firing rate parameters
5. **Model Validation**: Compared predicted vs observed firing rates

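The model infers changepoints in Poisson spike count data for a single experimental condition. How closely the inferred changepoints match the expected ones on this synthetic data can be read from the posterior and overlay plots above. Key features: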

**Strengths:**
- Appropriate for count data (spikes)
- Handles trial-to-trial variability
- Provides uncertainty estimates for changepoints
- Infers firing rate parameters for each state

**Use Cases:**
- Single-condition neural recordings
- Spike train analysis
- Detecting state transitions in neural activity
- Time-resolved analysis of neural responses