# AllTastePoisson Model Demo

This notebook demonstrates how to use the `AllTastePoisson` changepoint model on synthetic (dummy) Poisson count data.

In [None]:
import os
import sys
import traceback

import matplotlib.pyplot as plt
import numpy as np

# Make the local `pytau` package importable without installing it.
# The search path added is three levels above the current working directory
# (the parent of cwd, then two more '..' hops).
# NOTE(review): presumes the notebook is launched from a directory at that
# depth below the folder containing `pytau` — TODO confirm against repo layout.
cwd_parent = os.path.dirname(os.getcwd())
sys.path.append(os.path.join(cwd_parent, '..', '..'))

from pytau.changepoint_model import AllTastePoisson, gen_test_array

## Generate Test Data

In [None]:
# Generate dummy data for the multi-taste Poisson changepoint model.
# Each (taste, neuron) pair gets a piecewise-constant firing rate with
# `n_states` segments, so the data contain real changepoints for the model
# to recover (a single constant rate everywhere would leave nothing to detect).
np.random.seed(42)
n_states = 3
n_tastes = 4  # Multiple tastes
n_trials = 10
n_neurons = 5
n_time = 100
rate_range = (1.0, 10.0)  # (low, high) bounds for the per-state Poisson rates

# Evenly spaced ground-truth changepoints, e.g. [33, 66] for 3 states over 100 bins
true_changepoints = np.linspace(0, n_time, n_states + 1).astype(int)[1:-1]
# State index of every time bin, shape (n_time,)
state_per_bin = np.searchsorted(true_changepoints, np.arange(n_time), side='right')
# Independent rate for every taste / neuron / state, shape (tastes, neurons, states)
state_rates = np.random.uniform(*rate_range, size=(n_tastes, n_neurons, n_states))
# Rate of every time bin, shape (tastes, neurons, time_bins)
rate_array = state_rates[:, :, state_per_bin]

# Sample Poisson counts; all trials of a taste share the same rate profile.
# Shape: (tastes, trials, neurons, time_bins)
test_data = np.random.poisson(
    rate_array[:, None, :, :], size=(n_tastes, n_trials, n_neurons, n_time)
)
print(f"Generated test data shape: {test_data.shape}")
print(f"True changepoints (time bins): {true_changepoints}")
print(f"Data range: [{test_data.min()}, {test_data.max()}]")
print(f"Data type: {test_data.dtype}")

## Initialize and Fit Model

In [None]:
# Configure the AllTastePoisson model on the synthetic data, using
# variational inference ('vi') as the fit type.
model_kwargs = dict(
    data_array=test_data,
    n_states=n_states,
    fit_type='vi',
)
model = AllTastePoisson(**model_kwargs)

# Build the underlying PyMC model before testing or fitting it.
model.generate_model()
print("Model generated successfully")

## Run Model Test

In [None]:
# Run the model's built-in test. A failure is reported together with its
# exception type and full traceback (instead of only the message) so it can
# be diagnosed, but it is not re-raised, so the remaining demo cells still run.
try:
    model.test()
except Exception as e:
    print(f"Model test failed: {type(e).__name__}: {e}")
    traceback.print_exc()
else:
    # Only reached when model.test() raised nothing.
    print("Model test completed successfully")

## Visualize Input Data

In [None]:
# Plot the trial-averaged input data for every taste.
# The grid has two columns and as many rows as needed for `n_tastes` panels,
# so no taste is silently dropped and no empty panels are left showing.
n_cols = 2
n_rows = int(np.ceil(n_tastes / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for taste_idx in range(n_tastes):
    ax = axes[taste_idx]
    # Average across trials -> shape (neurons, time_bins)
    taste_data = test_data[taste_idx].mean(axis=0)
    # Transpose so each neuron is one line over time. Passing the
    # (neurons, time) array directly would make matplotlib plot one line
    # per time bin with neurons on the x-axis.
    ax.plot(taste_data.T, 'o-', markersize=3)
    ax.set_title(f'Taste {taste_idx + 1} - Average Response')
    ax.set_xlabel('Time')
    ax.set_ylabel('Mean count')
    ax.legend([f'Neuron {i + 1}' for i in range(taste_data.shape[0])], fontsize='small')
    ax.grid(True, alpha=0.3)

# Hide panels left over when n_tastes does not fill the grid.
for ax in axes[n_tastes:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

# Plot the overall count distribution. Integer bin edges give one bar per
# count value; align='left' centres each bar on its integer count.
max_count = int(test_data.max())
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(test_data.ravel(), bins=np.arange(max_count + 2), align='left',
        alpha=0.7, edgecolor='black')
ax.set_title('Overall Count Distribution Across All Tastes')
ax.set_xlabel('Count')
ax.set_ylabel('Frequency')
ax.grid(True, alpha=0.3)
plt.show()

## Summary

This notebook demonstrated:
1. Generating dummy Poisson count data for multiple tastes
2. Initializing the AllTastePoisson model
3. Running model tests
4. Visualizing the input data across different tastes

The AllTastePoisson model is designed for detecting changepoints in Poisson count data across multiple taste conditions simultaneously.