# AllTastePoisson Model Demo

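This notebook demonstrates how to set up and test the AllTastePoisson changepoint model. It uses synthetic (dummy) Poisson count data generated for several tastes, trials, and neurons, with known state transitions.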

In [None]:
import numpy as np
# Try to import matplotlib; fall back to a no-op stand-in so the notebook can
# still run end to end (without figures) where matplotlib is not installed.
class DummyPlt:
    """No-op stand-in for ``matplotlib.pyplot`` and the objects it returns.

    Any public attribute access, call, or index returns the same dummy, so
    chained calls such as ``axes.flatten()`` or ``axes[0].set_title(...)``
    silently do nothing. ``subplots`` returns a ``(fig, axes)`` pair so tuple
    unpacking works, iterating yields nothing, and ``show`` returns ``None`` so
    a cell ending in ``plt.show()`` displays no output.
    """

    def __getattr__(self, name):
        # Refuse private/dunder lookups (copy, pickle, IPython display probes)
        # so the dummy does not masquerade as implementing those protocols.
        if name.startswith('_'):
            raise AttributeError(name)
        return self

    def __call__(self, *args, **kwargs):
        return self

    def __getitem__(self, key):
        return self

    def __iter__(self):
        return iter(())

    def subplots(self, *args, **kwargs):
        # Real plt.subplots returns (Figure, Axes-or-array): mirror the 2-tuple.
        return self, self

    def show(self, *args, **kwargs):
        return None


try:
    import matplotlib.pyplot as plt
except ImportError:
    print('matplotlib not available - creating dummy plt object')
    plt = DummyPlt()
import os
import sys

# Make the pytau package importable when running from inside the repository.
# The appended entry is three directory levels above the current working
# directory: the parent of cwd, followed by two more parent hops.
# NOTE(review): assumes the notebook is launched from its own folder — confirm.
_cwd_parent = os.path.dirname(os.getcwd())
_pytau_search_path = os.path.join(_cwd_parent, os.pardir, os.pardir)
sys.path.append(_pytau_search_path)

from pytau.changepoint_model import AllTastePoisson, gen_test_array

## Generate Test Data

In [None]:
# Generate dummy data with obvious changepoints for multi-taste Poisson model
np.random.seed(42)
n_states = 3
n_tastes = 4  # Multiple tastes
n_trials = 10
n_neurons = 5
n_time = 100

# Create data with obvious state transitions in firing rates
# Different base rates for different tastes
taste_base_rates = [3.0, 4.0, 2.5, 5.0]  # Base rate for each taste
# State multipliers (same pattern across tastes)
state_multipliers = [0.5, 2.0, 1.2]  # Low, high, medium relative to base
transition_points = [33, 66]  # Clear transition points
transition_jitter = 3  # Each trial shifts every transition by up to +/- this many bins
edge_margin = 5  # Jittered transitions are clamped to [edge_margin, n_time - edge_margin]

# The generator below builds one time segment per state, so the parameter lists
# must agree with n_states / n_tastes; otherwise the data would silently disagree
# with the number of states the model is later told to look for.
assert len(taste_base_rates) == n_tastes, "need one base rate per taste"
assert len(state_multipliers) == n_states, "need one rate multiplier per state"
assert len(transition_points) == n_states - 1, "need n_states - 1 transition points"

test_data = np.zeros((n_tastes, n_trials, n_neurons, n_time), dtype=int)
for taste in range(n_tastes):
    base_rate = taste_base_rates[taste]
    state_rates = [base_rate * mult for mult in state_multipliers]

    for trial in range(n_trials):
        # Add some trial-to-trial variability in transition points
        trial_transitions = [
            t + np.random.randint(-transition_jitter, transition_jitter + 1)
            for t in transition_points
        ]
        # Keep within bounds
        trial_transitions = [max(edge_margin, min(n_time - edge_margin, t)) for t in trial_transitions]
        # Segment edges: state k occupies time bins [edges[k], edges[k + 1])
        edges = [0, *trial_transitions, n_time]
        assert all(lo < hi for lo, hi in zip(edges[:-1], edges[1:])), (
            f"transitions overlap after jitter: {trial_transitions}"
        )

        for neuron in range(n_neurons):
            # Add some neuron-specific variability to rates
            neuron_rates = [r * (0.7 + 0.6 * np.random.random()) for r in state_rates]

            # Generate data for each state, in time order (keeps the RNG draw
            # sequence identical to filling the segments one by one)
            for state_idx in range(n_states):
                start, stop = edges[state_idx], edges[state_idx + 1]
                test_data[taste, trial, neuron, start:stop] = np.random.poisson(
                    neuron_rates[state_idx], stop - start
                )

print(f"Generated test data shape: {test_data.shape}")
print(f"Data range: [{test_data.min()}, {test_data.max()}]")
print(f"Data type: {test_data.dtype}")
print(f"Taste base rates: {taste_base_rates}")
print(f"State multipliers: {state_multipliers}")
print(f"Transition points (approx): {transition_points}")

## Initialize Model and Generate the PyMC Model

In [None]:
# Build an AllTastePoisson model over the synthetic data.
# fit_type='vi' presumably selects variational inference — confirm in pytau.
model_kwargs = dict(data_array=test_data, n_states=n_states, fit_type='vi')
model = AllTastePoisson(**model_kwargs)

# Construct the underlying PyMC model before testing/fitting
model.generate_model()
print("Model generated successfully")

## Run Model Test

In [None]:
# Run the model's built-in test. A failure is reported but does not stop the
# notebook, so the data-visualization cells below still run.
import traceback

try:
    model.test()
    print("Model test completed successfully")
except Exception as e:
    # Show the exception type and full traceback: str(e) alone is often empty
    # or ambiguous (e.g. a KeyError prints only the missing key).
    print(f"Model test failed: {type(e).__name__}: {e}")
    traceback.print_exc()

## Visualize Input Data

In [None]:
# Plot the trial-averaged input data for every taste: one panel per taste,
# one line per neuron, with the nominal transition points marked.
n_cols = 2
n_rows = (n_tastes + n_cols - 1) // n_cols  # ceil division: a panel for every taste
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for taste_idx in range(n_tastes):
    ax = axes[taste_idx]
    # Average across trials -> (n_neurons, n_time). Transpose so each neuron is
    # one line over time: plot() draws every *column* of a 2-D array as a line.
    taste_data = test_data[taste_idx].mean(axis=0)
    ax.plot(taste_data.T, 'o-', markersize=3)
    # Nominal transitions (each trial is jittered around these)
    for tp in transition_points:
        ax.axvline(tp, color='k', linestyle='--', alpha=0.5)
    ax.set_title(f'Taste {taste_idx + 1} - Average Response')
    ax.set_xlabel('Time')
    ax.set_ylabel('Mean count (per neuron)')
    ax.grid(True, alpha=0.3)

# Hide leftover panels when n_tastes does not fill the grid
for ax in axes[n_tastes:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

# Plot overall count distribution, pooled over tastes, trials, neurons and time
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(test_data.flatten(), bins=range(int(test_data.max()) + 2), alpha=0.7, edgecolor='black')
ax.set_title('Overall Count Distribution Across All Tastes')
ax.set_xlabel('Count')
ax.set_ylabel('Frequency')
ax.grid(True, alpha=0.3)
plt.show()

## Summary

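This notebook demonstrated:
1. Generating dummy Poisson count data for multiple tastes, with known state transitions
2. Initializing the AllTastePoisson model and generating its underlying PyMC model
3. Running the model's built-in test
4. Visualizing the input data across different tastes

Note that the model is not fit here: no inference is run, so no changepoint estimates are produced.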

The AllTastePoisson model is designed for detecting changepoints in Poisson count data across multiple taste conditions simultaneously.