# SingleTastePoisson Model Demo

This notebook demonstrates how to use the `SingleTastePoisson` changepoint model on synthetic (dummy) Poisson count data with known state transitions.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Make a local (uninstalled) pytau checkout importable. This must run before the
# `pytau` import below. The appended path is dirname(cwd) joined with '..', '..',
# i.e. three directories above the current working directory.
# NOTE(review): assumes the notebook is launched from its own folder inside the
# repo — confirm this depth matches the repository layout.
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), '..', '..'))

# NOTE(review): gen_test_array is imported but not used in this notebook's cells;
# the dummy data are generated by hand below.
from pytau.changepoint_model import SingleTastePoisson, gen_test_array

## Generate Test Data

In [None]:
# Generate dummy data with obvious changepoints for the Poisson model.
# The generator mirrors the structure the changepoint model is asked to recover:
#   * every trial passes through the same `n_states` states, in order;
#   * each neuron has one firing rate per state, shared across all trials;
#   * only the changepoint times jitter from trial to trial.
np.random.seed(42)
n_trials = 10
n_neurons = 5
n_time = 100
n_states = 3

# Mean count per time bin for each state, in order: low -> high -> medium.
state_rates = [2.0, 8.0, 5.0]
# Nominal changepoint bins; there is one fewer changepoint than states.
transition_points = [33, 66]
max_jitter = 3   # per-trial changepoint shift, uniform in [-max_jitter, +max_jitter] bins
edge_margin = 5  # keep every changepoint at least this many bins from either end

assert len(state_rates) == n_states, "need exactly one rate per state"
assert len(transition_points) == n_states - 1, "need exactly n_states - 1 transition points"

# Per-neuron scaling of each state rate (0.7x-1.3x). Drawn ONCE, outside the trial
# loop, so a neuron keeps the same rates on every trial (shape: n_neurons x n_states).
neuron_rates = np.asarray(state_rates) * (0.7 + 0.6 * np.random.random((n_neurons, n_states)))

test_data = np.zeros((n_trials, n_neurons, n_time), dtype=int)
for trial in range(n_trials):
    # Trial-to-trial variability lives only in the transition times. Sorting keeps
    # the segments ordered (and non-negative in length) even if clipping collides them.
    jitter = np.random.randint(-max_jitter, max_jitter + 1, size=len(transition_points))
    trial_transitions = np.sort(
        np.clip(np.asarray(transition_points) + jitter, edge_margin, n_time - edge_margin)
    )
    # Segment edges [0, t_1, ..., t_{K-1}, n_time]; state k covers edges[k]:edges[k + 1].
    edges = np.concatenate(([0], trial_transitions, [n_time]))
    for state, (start, stop) in enumerate(zip(edges[:-1], edges[1:])):
        # Column vector of rates broadcasts across the time bins of this segment.
        test_data[trial, :, start:stop] = np.random.poisson(
            neuron_rates[:, state, None], size=(n_neurons, stop - start)
        )

print(f"Generated test data shape: {test_data.shape}")
print(f"Data range: [{test_data.min()}, {test_data.max()}]")
print(f"Data type: {test_data.dtype}")
print(f"State rates: {state_rates}")
print(f"Transition points (approx): {transition_points}")

## Initialize Model and Generate the PyMC Model

In [None]:
# Build the changepoint model on the dummy counts. The number of states matches the
# generator above; fit_type='vi' is passed through to the model (presumably selecting
# variational inference — confirm against pytau docs).
model = SingleTastePoisson(data_array=test_data, n_states=n_states, fit_type='vi')

# Construct the underlying PyMC model before it is tested below.
model.generate_model()
print("Model generated successfully")

## Run Model Test

In [None]:
# Smoke-test the model. A failure is reported rather than raised, so the rest of the
# demo (data visualization) still runs. The exception type is printed because many
# exceptions carry an empty message, which would otherwise print nothing useful.
try:
    model.test()
except Exception as e:
    print(f"Model test failed: {type(e).__name__}: {e}")
else:
    # Only reached when model.test() itself raised nothing.
    print("Model test completed successfully")

## Visualize Input Data

In [None]:
# Plot the input data as a 2x2 overview. Dashed vertical lines on the time-series
# panels mark the nominal (pre-jitter) transition points used by the generator.
n_show_trials = 3  # number of trials drawn in the time-series panel

fig, axes = plt.subplots(2, 2, figsize=(15, 8))
ax_series, ax_avg, ax_hist, ax_heat = axes.ravel()

# Time series for the first neuron on the first few trials.
for trial in range(min(n_show_trials, test_data.shape[0])):
    ax_series.plot(test_data[trial, 0, :], 'o-', markersize=2, alpha=0.7, label=f'Trial {trial+1}')
ax_series.set(title=f'Neuron 1 - Time Series (First {n_show_trials} Trials)', xlabel='Time', ylabel='Count')

# Trial-averaged response of every neuron, as the panel title promises.
for neuron in range(test_data.shape[1]):
    avg_response = test_data[:, neuron, :].mean(axis=0)
    ax_avg.plot(avg_response, 'o-', markersize=2, alpha=0.7, label=f'Neuron {neuron+1}')
ax_avg.set(title='Average Response per Neuron', xlabel='Time', ylabel='Average Count')

# Shared decoration for the two line plots over time.
for ax in (ax_series, ax_avg):
    for t in transition_points:
        ax.axvline(t, color='gray', linestyle='--', linewidth=1)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Distribution of all counts, one bin per integer value.
ax_hist.hist(test_data.flatten(), bins=range(int(test_data.max()) + 2), alpha=0.7, edgecolor='black')
ax_hist.set(title='Count Distribution', xlabel='Count', ylabel='Frequency')
ax_hist.grid(True, alpha=0.3)

# Trial x time heatmap for the first neuron.
im = ax_heat.imshow(test_data[:, 0, :], aspect='auto', cmap='viridis', interpolation='nearest')
fig.colorbar(im, ax=ax_heat, label='Count')
ax_heat.set(title='Neuron 1 - Trial x Time Heatmap', xlabel='Time', ylabel='Trial')

fig.tight_layout()
plt.show()

## Summary

This notebook demonstrated:
1. Generating dummy Poisson count data
2. Initializing the SingleTastePoisson model
3. Running model tests
4. Visualizing the input data

The `SingleTastePoisson` model is designed to detect changepoints in Poisson count data (arranged as trials × neurons × time bins) from single-taste experiments.