# GaussianChangepointMeanDirichlet Model Demo

This notebook demonstrates how to set up the GaussianChangepointMeanDirichlet model using synthetic (dummy) data whose mean changes at known changepoints.

In [None]:
import os
import sys

import numpy as np

# matplotlib is optional: fall back to a no-op stand-in so the notebook still
# runs top to bottom in environments without it (plots are then skipped).
try:
    import matplotlib.pyplot as plt
except ImportError:
    print('matplotlib not available - creating dummy plt object')

    class DummyPlt:
        """Stand-in for ``matplotlib.pyplot`` whose every attribute is a no-op callable.

        Any ``plt.<name>(...)`` call accepts arbitrary arguments and returns None.
        Code that uses a return value (e.g. ``fig, ax = plt.subplots()``) will
        therefore fail, so plotting cells should stick to the pyplot
        state-machine calls (``plt.figure``, ``plt.plot``, ...).
        """

        def __getattr__(self, name):
            def dummy_func(*args, **kwargs):
                pass
            return dummy_func

    plt = DummyPlt()

# Make the local pytau checkout importable. The path is three levels above the
# current working directory. NOTE(review): this assumes the kernel was started
# in the notebook's own directory; confirm against the repository layout.
pytau_root = os.path.abspath(os.path.join(os.path.dirname(os.getcwd()), '..', '..'))
# Guard so re-running this cell does not append the same entry again.
if pytau_root not in sys.path:
    sys.path.append(pytau_root)

from pytau.changepoint_model import GaussianChangepointMeanDirichlet, gen_test_array

## Generate Test Data

In [None]:
# Generate dummy data with obvious changepoints for the Gaussian model.
# Each trial is a sequence of constant-mean segments (one per state) plus
# Gaussian noise with a shared standard deviation.
np.random.seed(42)
n_trials = 10
n_time = 100

# One mean per state, in temporal order: low -> high -> medium.
# Only the mean changes between states; the noise std is shared.
state_means = [1.0, 4.0, 2.5]
state_std = 0.8  # Fixed variance across states
transition_points = [33, 66]  # Nominal changepoints, one between each pair of consecutive states
# Derive the state count from the data definition so the model below always
# matches the number of segments actually generated.
n_states = len(state_means)

transition_jitter = 3  # per-trial changepoints shift by a random integer in [-3, 3]
edge_margin = 5        # changepoints are clamped to [edge_margin, n_time - edge_margin]

assert len(transition_points) == n_states - 1, \
    "need exactly one transition point between consecutive states"

test_data = np.zeros((n_trials, n_time))
for trial in range(n_trials):
    # Trial-to-trial variability in changepoint locations, clamped to stay inside the trial
    trial_transitions = [t + np.random.randint(-transition_jitter, transition_jitter + 1)
                         for t in transition_points]
    trial_transitions = [max(edge_margin, min(n_time - edge_margin, t)) for t in trial_transitions]

    # Segment boundaries [0, t1, ..., n_time]. Every segment must be non-empty,
    # otherwise a state would silently vanish from this trial.
    boundaries = [0] + trial_transitions + [n_time]
    assert all(start < stop for start, stop in zip(boundaries, boundaries[1:])), \
        "changepoints must be strictly increasing"

    # Fill each segment with noise around its state's mean, drawn in temporal order
    for mean, start, stop in zip(state_means, boundaries[:-1], boundaries[1:]):
        test_data[trial, start:stop] = np.random.normal(mean, state_std, stop - start)

print(f"Generated test data shape: {test_data.shape}")
print(f"Data range: [{test_data.min():.2f}, {test_data.max():.2f}]")
print(f"State means: {state_means}")
print(f"Fixed std: {state_std}")
print(f"Transition points (approx): {transition_points}")

## Initialize Model and Generate the PyMC Model

In [None]:
# Build the changepoint model on the synthetic data. The keyword arguments are
# forwarded unchanged to the constructor; 'vi' presumably selects variational
# inference as the fit method (confirm against pytau's documentation).
model_kwargs = {
    'data_array': test_data,
    'n_states': n_states,
    'fit_type': 'vi',
}
model = GaussianChangepointMeanDirichlet(**model_kwargs)

# Ask the wrapper to construct its underlying PyMC model
model.generate_model()
print("Model generated successfully")

## Run Model Test

In [None]:
# Run the model's built-in test. A failure is reported (including the full
# traceback) rather than raised, so the remaining cells still execute.
import traceback

try:
    model.test()
    print("Model test completed successfully")
except Exception as e:
    print(f"Model test failed: {e}")
    # The one-line message alone hides where the failure happened; show the stack too
    traceback.print_exc()

## Visualize Input Data

In [None]:
# Plot the input data: per-trial traces on a shared time axis, plus the pooled value distribution.
# Only pyplot state-machine calls are used so the no-op matplotlib fallback keeps working.
plt.figure(figsize=(12, 6))

# Left panel: one trace per trial against within-trial time. Flattening the
# (trials x time) array would concatenate trials, making every trial boundary
# look like a changepoint and mislabeling the x-axis as "Time".
plt.subplot(1, 2, 1)
time_axis = np.arange(test_data.shape[1])
plt.plot(time_axis, test_data.T, color='gray', alpha=0.3, linewidth=0.8)
plt.plot(time_axis, test_data.mean(axis=0), color='black', linewidth=2, label='Mean across trials')
# Mark the nominal (pre-jitter) transition points used to generate the data
for i, tp in enumerate(transition_points):
    plt.axvline(tp, color='red', linestyle='--', alpha=0.7,
                label='Nominal transition' if i == 0 else None)
plt.title('Input Time Series Data')
plt.xlabel('Time')
plt.ylabel('Value')
plt.legend()
plt.grid(True, alpha=0.3)

# Right panel: histogram of all values pooled across trials and time
plt.subplot(1, 2, 2)
plt.hist(test_data.flatten(), bins=30, alpha=0.7, edgecolor='black')
plt.title('Data Distribution')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated:
1. Generating synthetic Gaussian data with known mean changepoints
2. Initializing the GaussianChangepointMeanDirichlet model and generating its PyMC model
3. Running the model's built-in test
4. Visualizing the input data

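The GaussianChangepointMeanDirichlet model uses a Dirichlet prior for changepoint probabilities in Gaussian data. Note that this notebook only builds and tests the model; it does not fit it to `test_data` or inspect the inferred changepoints.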