# GaussianChangepointMeanDirichlet Model Demo

This notebook demonstrates how to use the `GaussianChangepointMeanDirichlet` model with synthetic (dummy) Gaussian data.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Add the pytau package root to the import path so `pytau` can be imported
# without installing it. The appended directory is three levels above the
# current working directory: the parent of cwd, then '..' twice more.
# NOTE(review): this only works if the notebook is launched from the expected
# directory; confirm the relative location (or install pytau, e.g. `pip install -e`).
# Because this appends rather than prepends, an already-installed pytau would
# take precedence over the local checkout.
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), '..', '..'))

from pytau.changepoint_model import GaussianChangepointMeanDirichlet, gen_test_array

## Generate Test Data

In [None]:
# Synthetic Gaussian input for the demo: 3 states across a (trials, time) grid.
n_states = 3
array_size = (10, 100)  # (trials, time)

# Seed NumPy's global RNG so reruns of the notebook give the same data
# (assuming gen_test_array draws from the global NumPy RNG — confirm).
np.random.seed(42)

# type="normal" asks the pytau helper for Gaussian-valued test data.
test_data = gen_test_array(array_size, n_states, type="normal")

# Quick sanity report on what was generated.
print(f"Generated test data shape: {test_data.shape}")
print(f"Data range: [{test_data.min():.2f}, {test_data.max():.2f}]")

## Initialize Model and Build the PyMC Graph

In [None]:
# Construct the model around the synthetic data; fit_type='vi' selects the
# fitting method by name. This cell only builds the model and its PyMC graph;
# no fit is run here.
# NOTE(review): Dirichlet-prior changepoint models often infer the number of
# states themselves — confirm the constructor actually uses `n_states`
# rather than silently accepting and ignoring it.
model = GaussianChangepointMeanDirichlet(
    data_array=test_data, n_states=n_states, fit_type='vi',
)

# Build the underlying PyMC model, then report success.
model.generate_model()
print("Model generated successfully")

## Run Model Test

In [None]:
# Run the model's built-in self-test. A failure is reported rather than
# raised, so later cells (e.g. the data plots) still execute under "Run All".
# The full traceback is printed as well, because str(e) alone can be empty
# and never shows where the failure happened.
import traceback

try:
    model.test()
except Exception as e:
    print(f"Model test failed: {e}")
    traceback.print_exc()
else:
    # Only reached when model.test() did not raise.
    print("Model test completed successfully")

## Visualize Input Data

In [None]:
# Plot the input data: each trial's trace against time (left) and the
# distribution of all values pooled across trials and time (right).
fig, (ax_series, ax_hist) = plt.subplots(1, 2, figsize=(12, 6))

# Time is the last axis (array_size is (trials, time)). Collapse any leading
# axes into a single trial axis and transpose, so matplotlib draws one line
# per trial over the time index. Plotting test_data.flatten() instead would
# chain the trials end to end, and the x-axis would no longer be time.
traces = test_data.reshape(-1, test_data.shape[-1]).T
ax_series.plot(traces, alpha=0.5, linewidth=0.8)
ax_series.set_title('Input Time Series Data (one line per trial)')
ax_series.set_xlabel('Time')
ax_series.set_ylabel('Value')
ax_series.grid(True, alpha=0.3)

# Histogram of every value, ignoring trial and time structure.
ax_hist.hist(test_data.ravel(), bins=30, alpha=0.7, edgecolor='black')
ax_hist.set_title('Data Distribution')
ax_hist.set_xlabel('Value')
ax_hist.set_ylabel('Frequency')
ax_hist.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

## Summary

This notebook demonstrated:
1. Generating dummy Gaussian data
2. Initializing the GaussianChangepointMeanDirichlet model and building its PyMC model (no fit is run here)
3. Running the model's built-in test
4. Visualizing the input data

The `GaussianChangepointMeanDirichlet` model places a Dirichlet prior on the changepoint probabilities for Gaussian-distributed data.