# AllTastePoisson Model Demo

This notebook demonstrates the `AllTastePoisson` model for detecting changepoints in Poisson-distributed data across multiple experimental conditions (tastes).

## Model Overview
- **Data Type**: Poisson distributed data (spike counts)
- **Input Shape**: 4D array (tastes × neurons × trials × time)
- **Scope**: Multiple tastes/conditions with hierarchical structure
- **Use Case**: Multi-condition neural spike train changepoint detection
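
Conceptually, each trial is a sequence of discrete states separated by changepoints, and each taste, neuron, and state has its own constant Poisson rate. The NumPy-only sketch below shows that mapping. `state_sequence` and `expected_rates` are hypothetical helpers, not part of pytau, and the single set of changepoints shared by all tastes follows this notebook's description of the model rather than pytau's source.

```python
import numpy as np


def state_sequence(changepoints, n_timepoints):
    """Map each time bin to a state index given sorted changepoints.

    A bin at time t belongs to state s, where s is the number of
    changepoints that are less than or equal to t.
    """
    t = np.arange(n_timepoints)
    return np.searchsorted(np.asarray(changepoints), t, side="right")


def expected_rates(lam, changepoints, n_timepoints):
    """Expand per-state rates into a time-resolved rate array.

    lam has shape (tastes, neurons, states); the result has shape
    (tastes, neurons, time). The same changepoints are used for every taste.
    """
    states = state_sequence(changepoints, n_timepoints)
    return lam[..., states]


# Two tastes, one neuron, three states, changepoints at bins 2 and 4
lam = np.array([[[1.0, 5.0, 2.0]], [[0.5, 2.5, 1.0]]])
print(state_sequence([2, 4], 6).tolist())                # [0, 0, 1, 1, 2, 2]
print(expected_rates(lam, [2, 4], 6)[1, 0].tolist())     # [0.5, 0.5, 2.5, 2.5, 1.0, 1.0]
```

Spike counts for every trial are then Poisson draws around these piecewise-constant rates.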

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Add pytau to path
sys.path.append('../../')

from pytau.changepoint_model import AllTastePoisson, gen_test_array
import pymc as pm

## Generate Synthetic Data

Create synthetic Poisson spike count data for multiple tastes with known changepoints.
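
The cell below uses pytau's `gen_test_array` for convenience. For a version where the changepoints are set explicitly and shared by every taste, a minimal alternative is to draw integer counts directly from piecewise-constant rates. `make_multi_taste_counts` is a hypothetical helper (not part of pytau); it assumes evenly spaced changepoints, reuses this notebook's `0.5 + 0.3 * taste` scaling, and returns data in the (tastes × neurons × trials × time) layout used here.

```python
import numpy as np


def make_multi_taste_counts(n_tastes, n_neurons, n_trials, n_timepoints,
                            n_states, seed=0):
    """Draw Poisson spike counts with the same changepoints for every taste.

    Returns (counts, rates, changepoints): counts has shape
    (tastes, neurons, trials, time) and rates has shape (tastes, neurons, states).
    """
    rng = np.random.default_rng(seed)
    # Evenly spaced changepoints, e.g. [20, 40] for 60 bins and 3 states
    changepoints = np.linspace(0, n_timepoints, n_states + 1)[1:-1].astype(int)
    # State index of each bin = number of changepoints <= t
    states = np.searchsorted(changepoints, np.arange(n_timepoints), side="right")
    # Neuron/state rates shared across tastes, scaled by a per-taste baseline
    base = rng.uniform(0.5, 5.0, size=(n_neurons, n_states))
    taste_scale = 0.5 + 0.3 * np.arange(n_tastes)
    rates = taste_scale[:, None, None] * base[None, :, :]
    # Time-resolved rates with a singleton trial axis: (tastes, neurons, 1, time)
    lam_t = rates[:, :, states][:, :, None, :]
    counts = rng.poisson(lam_t, size=(n_tastes, n_neurons, n_trials, n_timepoints))
    return counts, rates, changepoints


counts, rates, cps = make_multi_taste_counts(3, 4, 10, 60, 3, seed=42)
print(counts.shape, counts.dtype.kind, cps)  # (3, 4, 10, 60) i [20 40]
```

The returned `rates` and `cps` serve as exact ground truth to compare against the fitted firing rates and changepoints.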

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Parameters
n_tastes = 3
n_neurons = 4
n_trials = 10
n_timepoints = 60
n_states = 3

# Generate integer spike-count data for multiple tastes
data_array = np.zeros((n_tastes, n_neurons, n_trials, n_timepoints), dtype=int)

for taste in range(n_tastes):
    # Draw a fresh test array for each taste
    taste_data = gen_test_array((n_neurons, n_trials, n_timepoints),
                                n_states=n_states, type="poisson")
    # Different baseline rates for different tastes
    taste_scaling = 0.5 + taste * 0.3
    # Scaling counts directly would give non-integer values, which are not valid
    # Poisson observations; instead, use the scaled counts as rates and redraw,
    # so each taste's mean is scaled while the data stay integers
    data_array[taste] = np.random.poisson(taste_data * taste_scaling)

print(f"Data shape: {data_array.shape}")
print(f"Data range: [{data_array.min()}, {data_array.max()}]")
print("Mean spike count per bin, by taste:",
      [round(float(data_array[i].mean()), 2) for i in range(n_tastes)])

## Visualize Input Data

Plot the synthetic spike count data to visualize the changepoints across tastes.

In [None]:
# Plot average firing rates across trials for each taste
mean_rates = np.mean(data_array, axis=2)  # Average across trials

# Expected changepoint locations
changepoint_locs = np.linspace(0, n_timepoints, n_states+1)[1:-1]

# squeeze=False keeps `axes` 2D even when n_tastes or n_neurons is 1
fig, axes = plt.subplots(n_tastes, n_neurons, figsize=(15, 3*n_tastes),
                         sharex=True, squeeze=False)

taste_names = ['Sweet', 'Sour', 'Bitter'][:n_tastes]
colors = ['blue', 'orange', 'green'][:n_tastes]

for taste in range(n_tastes):
    for neuron in range(n_neurons):
        axes[taste, neuron].plot(mean_rates[taste, neuron, :], 
                               color=colors[taste], alpha=0.7, linewidth=2)
        axes[taste, neuron].set_ylabel(f'{taste_names[taste]}\nNeuron {neuron+1}')
        axes[taste, neuron].grid(True, alpha=0.3)
        
        # Add vertical lines at expected changepoints
        for cp in changepoint_locs:
            axes[taste, neuron].axvline(cp, color='red', linestyle='--', alpha=0.5)

# Set x-label only for bottom row
for neuron in range(n_neurons):
    axes[-1, neuron].set_xlabel('Time (bins)')

plt.suptitle('Synthetic Multi-Taste Poisson Spike Data (Trial-Averaged)\n(Red dashed lines show true changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Plot comparison of firing rates across tastes for each neuron
fig, axes = plt.subplots(n_neurons, 1, figsize=(12, 2*n_neurons), sharex=True)
if n_neurons == 1:
    axes = [axes]

for neuron in range(n_neurons):
    for taste in range(n_tastes):
        axes[neuron].plot(mean_rates[taste, neuron, :], 
                         color=colors[taste], alpha=0.7, linewidth=2, 
                         label=taste_names[taste])
    
    axes[neuron].set_ylabel(f'Neuron {neuron+1}\nSpikes/bin')
    axes[neuron].grid(True, alpha=0.3)
    axes[neuron].legend()
    
    # Add changepoints
    for cp in changepoint_locs:
        axes[neuron].axvline(cp, color='red', linestyle='--', alpha=0.5)

axes[-1].set_xlabel('Time (bins)')
plt.suptitle('Firing Rate Comparison Across Tastes', fontsize=14)
plt.tight_layout()
plt.show()

## Create and Fit Model

Initialize the AllTastePoisson model and fit it to the data.

In [None]:
# Create model instance
model_instance = AllTastePoisson(data_array, n_states=n_states)

# Generate the PyMC model
model = model_instance.generate_model()

print("Model created successfully!")
print(f"Model variables: {[var.name for var in model.unobserved_RVs]}")

In [None]:
# Fit the model using variational inference
with model:
    # Use ADVI for faster inference
    inference = pm.ADVI()
    approx = pm.fit(n=8000, method=inference)  # More iterations for hierarchical model
    
    # Sample from the approximation
    trace = approx.sample(draws=1000)

print("Model fitting completed!")

## Analyze Results

Extract and visualize the inferred changepoints and firing rate parameters across tastes.
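
As a quick, non-Bayesian cross-check on the ADVI estimates, you can scan every candidate pair of changepoints and keep the pair that maximizes the Poisson log-likelihood, with each taste and neuron given its maximum-likelihood rate (the segment mean) in each state. This exhaustive profile-likelihood search is a swapped-in technique, not how `AllTastePoisson` is fitted, and `segment_loglik` and `scan_two_changepoints` are hypothetical helpers. Summing the log-likelihood over tastes is what sharing the changepoints means in this sketch.

```python
import numpy as np


def segment_loglik(seg_counts, n_obs):
    """Poisson log-likelihood of one segment at its MLE rate, up to a constant.

    seg_counts: total spikes S in the segment (any array shape).
    n_obs: number of (trial, bin) observations in the segment.
    With rate S / n_obs the log-likelihood is S * log(S / n_obs) - S;
    the log(y!) terms are dropped because they do not depend on the changepoints.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        term = np.where(seg_counts > 0, seg_counts * np.log(seg_counts / n_obs), 0.0)
    return term - seg_counts


def scan_two_changepoints(counts):
    """Exhaustive search for two changepoints shared by every taste.

    counts: (tastes, neurons, trials, time). Each taste and neuron keeps its own
    rate in each segment; only the changepoint times are shared. Returns the
    best (a, b), where a and b are the first bins of states 2 and 3.
    """
    n_trials, n_time = counts.shape[2], counts.shape[3]
    per_bin = counts.sum(axis=2)  # (tastes, neurons, time), summed over trials
    # cum[..., t] = total spikes in bins 0..t-1, so segment [lo, hi) is cum[hi] - cum[lo]
    zeros = np.zeros(per_bin.shape[:-1] + (1,))
    cum = np.concatenate([zeros, np.cumsum(per_bin, axis=-1)], axis=-1)

    best, best_ll = None, -np.inf
    for a in range(1, n_time - 1):
        for b in range(a + 1, n_time):
            ll = 0.0
            for lo, hi in ((0, a), (a, b), (b, n_time)):
                seg = cum[..., hi] - cum[..., lo]
                ll += segment_loglik(seg, n_trials * (hi - lo)).sum()
            if ll > best_ll:
                best, best_ll = (a, b), ll
    return best, best_ll


# Toy data: two tastes with different rates but the same changepoints (bins 20 and 40)
rng = np.random.default_rng(0)
state_rates = np.array([1.0, 8.0, 2.0])
states = np.repeat([0, 1, 2], 20)
lam = np.array([1.0, 1.5])[:, None, None, None] * state_rates[states]
toy_counts = rng.poisson(lam, size=(2, 3, 10, 60))
best_cps, _ = scan_two_changepoints(toy_counts)
print(best_cps)
```

Calling `scan_two_changepoints(data_array)` gives a point estimate to set beside `tau_mean`. The scan is written for `n_states = 3` and grows quadratically with the number of time bins, which is fine for 60 bins.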

In [None]:
# Extract changepoint estimates
tau_samples = trace.posterior['tau'].values.reshape(-1, n_states-1)
tau_mean = np.mean(tau_samples, axis=0)
tau_std = np.std(tau_samples, axis=0)

print("Inferred Changepoints (shared across tastes):")
for i, (mean_cp, std_cp) in enumerate(zip(tau_mean, tau_std)):
    print(f"  Changepoint {i+1}: {mean_cp:.2f} ± {std_cp:.2f}")

# True changepoints for comparison
true_changepoints = np.linspace(0, n_timepoints, n_states+1)[1:-1]
print(f"\nTrue Changepoints: {true_changepoints}")

In [None]:
# Plot changepoint posterior distributions
fig, axes = plt.subplots(1, n_states-1, figsize=(4*(n_states-1), 4))
if n_states-1 == 1:
    axes = [axes]

for i in range(n_states-1):
    axes[i].hist(tau_samples[:, i], bins=50, alpha=0.7, density=True)
    axes[i].axvline(tau_mean[i], color='red', linestyle='-', label=f'Mean: {tau_mean[i]:.1f}')
    axes[i].axvline(true_changepoints[i], color='green', linestyle='--', label=f'True: {true_changepoints[i]:.1f}')
    axes[i].set_xlabel('Time (bins)')
    axes[i].set_ylabel('Density')
    axes[i].set_title(f'Changepoint {i+1} Posterior')
    axes[i].legend()
    axes[i].grid(True, alpha=0.3)

plt.suptitle('Shared Changepoints Across All Tastes', fontsize=14)
plt.tight_layout()
plt.show()

## Visualize Model Fit

Show the original data with inferred changepoints overlaid.

In [None]:
# Plot data with inferred changepoints
# squeeze=False keeps `axes` 2D even when n_tastes or n_neurons is 1
fig, axes = plt.subplots(n_tastes, n_neurons, figsize=(15, 3*n_tastes),
                         sharex=True, squeeze=False)

for taste in range(n_tastes):
    for neuron in range(n_neurons):
        axes[taste, neuron].plot(mean_rates[taste, neuron, :], 
                               color=colors[taste], alpha=0.7, linewidth=2, label='Data')
        axes[taste, neuron].set_ylabel(f'{taste_names[taste]}\nNeuron {neuron+1}')
        axes[taste, neuron].grid(True, alpha=0.3)
        
        # Add inferred changepoints
        for j, cp in enumerate(tau_mean):
            axes[taste, neuron].axvline(cp, color='red', linestyle='-', alpha=0.8, linewidth=2,
                                      label='Inferred CP' if taste == 0 and neuron == 0 and j == 0 else '')
        
        # Add true changepoints
        for j, cp in enumerate(true_changepoints):
            axes[taste, neuron].axvline(cp, color='green', linestyle='--', alpha=0.5,
                                      label='True CP' if taste == 0 and neuron == 0 and j == 0 else '')

axes[0, 0].legend()
for neuron in range(n_neurons):
    axes[-1, neuron].set_xlabel('Time (bins)')

plt.suptitle('All Taste Poisson Model Results\n(Red: Inferred, Green: True Changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

## Model Parameters

Examine the inferred firing rate parameters for each taste, neuron, and state.
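
Once the changepoints are fixed, the maximum-likelihood rate for each taste, neuron, and state is simply the mean count per bin within that segment, which gives a model-free reference for the posterior means of `lambda`. `empirical_state_rates` is a hypothetical helper; it rounds changepoints to whole bins and assumes the (tastes, neurons, trials, time) layout used in this notebook.

```python
import numpy as np


def empirical_state_rates(counts, changepoints):
    """Mean spike count per bin in each segment between changepoints.

    counts: (tastes, neurons, trials, time); changepoints: increasing bin
    positions (rounded to int). Returns (tastes, neurons, n_changepoints + 1).
    """
    n_time = counts.shape[-1]
    edges = [0, *np.round(changepoints).astype(int).tolist(), n_time]
    # Average over trials and time bins within each segment
    return np.stack(
        [counts[..., lo:hi].mean(axis=(-2, -1)) for lo, hi in zip(edges[:-1], edges[1:])],
        axis=-1,
    )


# Toy check: one taste, one neuron, two trials, rate 1 then 3
toy = np.array([[[[1, 1, 3, 3], [1, 1, 3, 3]]]])
print(empirical_state_rates(toy, [2]).tolist())  # [[[1.0, 3.0]]]
```

With the fitted results, `empirical_state_rates(data_array, tau_mean)` can be compared with `lambda_mean`, provided the latter's axes are (tastes, neurons, states) as assumed in the cell below.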

In [None]:
# Extract firing rate parameters
lambda_samples = trace.posterior['lambda'].values.reshape(-1, n_tastes, n_neurons, n_states)
lambda_mean = np.mean(lambda_samples, axis=0)
lambda_std = np.std(lambda_samples, axis=0)

print("Inferred Firing Rates by Taste, Neuron, and State:")
for taste in range(n_tastes):
    print(f"\n{taste_names[taste]}:")
    for state in range(n_states):
        print(f"  State {state + 1}:")
        for neuron in range(n_neurons):
            print(f"    Neuron {neuron + 1}: λ = {lambda_mean[taste, neuron, state]:.3f} ± {lambda_std[taste, neuron, state]:.3f}")

In [None]:
# Plot firing rate heatmaps for each taste
fig, axes = plt.subplots(1, n_tastes, figsize=(5*n_tastes, 6))
if n_tastes == 1:
    axes = [axes]

for taste in range(n_tastes):
    im = axes[taste].imshow(lambda_mean[taste], aspect='auto', cmap='viridis', interpolation='nearest')
    axes[taste].set_xlabel('State')
    axes[taste].set_ylabel('Neuron')
    axes[taste].set_title(f'{taste_names[taste]} - Firing Rates')
    axes[taste].set_xticks(range(n_states))
    axes[taste].set_xticklabels([f'State {i+1}' for i in range(n_states)])
    axes[taste].set_yticks(range(n_neurons))
    axes[taste].set_yticklabels([f'Neuron {i+1}' for i in range(n_neurons)])
    
    # Add text annotations
    for neuron in range(n_neurons):
        for state in range(n_states):
            text = axes[taste].text(state, neuron, f'{lambda_mean[taste, neuron, state]:.2f}',
                                  ha="center", va="center", color="white", fontweight='bold')
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=axes[taste])
    cbar.set_label('Firing Rate (spikes/bin)')

plt.tight_layout()
plt.show()

In [None]:
# Plot firing rate trajectories across states for each neuron and taste
fig, axes = plt.subplots(n_neurons, 1, figsize=(10, 2*n_neurons), sharex=True)
if n_neurons == 1:
    axes = [axes]

for neuron in range(n_neurons):
    for taste in range(n_tastes):
        axes[neuron].plot(range(1, n_states+1), lambda_mean[taste, neuron, :], 
                         'o-', color=colors[taste], label=taste_names[taste], 
                         linewidth=2, markersize=8)
    
    axes[neuron].set_ylabel(f'Neuron {neuron+1}\nFiring Rate')
    axes[neuron].legend()
    axes[neuron].grid(True, alpha=0.3)
    axes[neuron].set_xticks(range(1, n_states+1))

axes[-1].set_xlabel('State')
plt.suptitle('Firing Rate Evolution Across States by Taste', fontsize=14)
plt.tight_layout()
plt.show()

## Hierarchical Structure Analysis

Compare the taste-specific firing rate estimates directly to see how each taste departs from the pattern shared across tastes.
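
Hierarchical priors pull each taste's estimate toward a shared value, and the pull is stronger when a taste contributes little data. The conjugate Gamma-Poisson case makes this concrete: with a Gamma(α, β) prior (shape, rate) on a rate and counts summing to S over n bins, the posterior mean is (α + S) / (β + n). This is an illustrative textbook model, not the prior structure used inside `AllTastePoisson`, and `shrunk_rate` is a hypothetical helper.

```python
def shrunk_rate(total_counts, n_bins, prior_mean, prior_strength):
    """Posterior mean of a Poisson rate under a Gamma prior.

    The prior is Gamma(shape=alpha, rate=beta) with
    alpha = prior_mean * prior_strength and beta = prior_strength, so
    prior_strength acts like a number of pseudo-bins observed at prior_mean.
    """
    alpha = prior_mean * prior_strength
    beta = prior_strength
    return (alpha + total_counts) / (beta + n_bins)


# Same raw rate (4 spikes/bin) for two tastes, but very different amounts of data
shared_mean = 2.0
print(shrunk_rate(8, 2, shared_mean, 10))      # few bins: pulled toward 2.0
print(shrunk_rate(800, 200, shared_mean, 10))  # many bins: stays near 4.0
```

The same logic explains why a hierarchical model can stabilize rates for tastes or states with sparse spiking while leaving well-sampled estimates largely unchanged.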

In [None]:
# Compare firing rates across tastes for each state
fig, axes = plt.subplots(1, n_states, figsize=(5*n_states, 6))
if n_states == 1:
    axes = [axes]

for state in range(n_states):
    # Create a grouped bar plot
    x = np.arange(n_neurons)
    width = 0.25
    
    for taste in range(n_tastes):
        offset = (taste - n_tastes/2 + 0.5) * width
        axes[state].bar(x + offset, lambda_mean[taste, :, state], width, 
                       color=colors[taste], alpha=0.7, label=taste_names[taste])
    
    axes[state].set_xlabel('Neuron')
    axes[state].set_ylabel('Firing Rate (spikes/bin)')
    axes[state].set_title(f'State {state+1}')
    axes[state].set_xticks(x)
    axes[state].set_xticklabels([f'N{i+1}' for i in range(n_neurons)])
    axes[state].legend()
    axes[state].grid(True, alpha=0.3)

plt.suptitle('Firing Rate Comparison Across Tastes by State', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Calculate and visualize taste-specific differences
# Compute relative firing rates (normalized by mean across tastes)
mean_across_tastes = np.mean(lambda_mean, axis=0)  # Shape: (n_neurons, n_states)
relative_rates = lambda_mean / mean_across_tastes[np.newaxis, :, :]  # Broadcasting

fig, axes = plt.subplots(1, n_tastes, figsize=(5*n_tastes, 6))
if n_tastes == 1:
    axes = [axes]

for taste in range(n_tastes):
    im = axes[taste].imshow(relative_rates[taste], aspect='auto', cmap='RdBu_r', 
                           interpolation='nearest', vmin=0.5, vmax=1.5)
    axes[taste].set_xlabel('State')
    axes[taste].set_ylabel('Neuron')
    axes[taste].set_title(f'{taste_names[taste]} - Relative Firing Rates')
    axes[taste].set_xticks(range(n_states))
    axes[taste].set_xticklabels([f'State {i+1}' for i in range(n_states)])
    axes[taste].set_yticks(range(n_neurons))
    axes[taste].set_yticklabels([f'Neuron {i+1}' for i in range(n_neurons)])
    
    # Add text annotations
    for neuron in range(n_neurons):
        for state in range(n_states):
            text = axes[taste].text(state, neuron, f'{relative_rates[taste, neuron, state]:.2f}',
                                  ha="center", va="center", color="black", fontweight='bold')
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=axes[taste])
    cbar.set_label('Relative Firing Rate')

plt.suptitle('Taste-Specific Firing Rate Patterns\n(Relative to Cross-Taste Mean)', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated the `AllTastePoisson` model:

1. **Data Generation**: Created synthetic multi-taste Poisson spike count data
2. **Data Visualization**: Showed firing patterns across tastes and neurons
3. **Model Fitting**: Fit the hierarchical Bayesian model to multi-condition data with variational inference (ADVI)
4. **Results Analysis**: Extracted shared changepoints and taste-specific parameters
5. **Hierarchical Analysis**: Examined taste-specific vs. shared patterns

The model estimates changepoints shared across multiple experimental conditions while capturing taste-specific firing rate patterns. Key features:

**Strengths:**
- **Hierarchical Structure**: Shares information across tastes while allowing taste-specific parameters
- **Shared Changepoints**: Estimates a single set of state-transition times for all conditions
- **Condition-Specific Rates**: Allows different firing rates for each taste in each state
- **Statistical Power**: Leverages data from all conditions to improve changepoint detection

**Use Cases:**
- Multi-condition neural recordings (different stimuli, behaviors, etc.)
- Comparative analysis across experimental groups
- Studies where timing of state changes should be consistent but magnitudes may differ
- Population-level analysis with condition-specific effects

**Model Assumptions:**
- Changepoint timing is shared across all tastes/conditions
- Firing rate parameters can differ between tastes
- Taste-specific rates are tied together by the hierarchical structure, which regularizes their estimates

**When to Use:**
- Multiple experimental conditions with expected shared timing
- Need to compare response patterns across conditions
- Want to leverage statistical power from multiple conditions
- Interested in both shared and condition-specific effects