# GaussianChangepointMeanDirichlet Model Demo

This notebook demonstrates the `GaussianChangepointMeanDirichlet` model, which detects changepoints in Gaussian data and uses a Dirichlet process prior to infer the number of states automatically.

## Model Overview
- **Data Type**: Gaussian/Normal distributed data
- **Input Shape**: 2D array (dimensions × time)
- **Detects Changes In**: Mean only (constant variance)
- **Special Feature**: Uses a Dirichlet process prior to infer the number of states automatically, up to a user-set maximum (`max_states`)
- **Use Case**: When the number of changepoints is unknown
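To build intuition for how a Dirichlet process prior can leave some states with negligible weight, here is a minimal NumPy sketch of the standard truncated stick-breaking construction. It is an illustration only: the notebook does not show how `GaussianChangepointMeanDirichlet` builds its weights internally, and `stick_breaking_weights` and `alpha` are hypothetical names introduced here.

```python
import numpy as np

def stick_breaking_weights(betas):
    """Convert stick-breaking fractions into state weights.

    betas[k] is the fraction of the remaining stick assigned to state k.
    """
    betas = np.asarray(betas, dtype=float)
    # Stick length left before each break: 1, (1-b0), (1-b0)(1-b1), ...
    remaining = np.concatenate(([1.0], np.cumprod(1.0 - betas[:-1])))
    return betas * remaining

rng = np.random.default_rng(0)
max_states = 10
alpha = 1.0  # hypothetical concentration parameter (not taken from pytau)
betas = rng.beta(1.0, alpha, size=max_states)
betas[-1] = 1.0  # truncation: the last state takes whatever stick is left
weights = stick_breaking_weights(betas)

print(np.round(weights, 3))
print("Sum of weights:", weights.sum())
```

Because each break takes a fraction of what is left, later states tend to receive smaller weights, which is what makes thresholding the weights a reasonable way to count active states.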

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys

# Add pytau to path
sys.path.append('../../')

from pytau.changepoint_model import GaussianChangepointMeanDirichlet, gen_test_array
import pymc as pm

## Generate Synthetic Data

Create synthetic Gaussian data whose number of changepoints is not given to the model.

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Parameters
n_dimensions = 3
n_timepoints = 150
true_n_states = 4  # True number of states (unknown to model)
max_states = 10    # Maximum states to consider

# Generate test data using the built-in function
data_array = gen_test_array((n_dimensions, n_timepoints), n_states=true_n_states, type="normal")

print(f"Data shape: {data_array.shape}")
print(f"Data range: [{data_array.min():.2f}, {data_array.max():.2f}]")
print(f"True number of states: {true_n_states}")
print(f"Maximum states to consider: {max_states}")
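When exact ground-truth changepoints are needed for comparison, one option is to build the test data by hand. The sketch below is not part of pytau; `make_mean_shift_data` is a hypothetical helper that draws one mean per dimension and state and adds constant-variance Gaussian noise, matching the mean-only setting described above.

```python
import numpy as np

def make_mean_shift_data(n_dims, changepoints, n_timepoints, sigma=1.0, seed=0):
    """Piecewise-constant-mean Gaussian data with known changepoints.

    changepoints must be sorted integers strictly between 0 and n_timepoints.
    """
    rng = np.random.default_rng(seed)
    # Segment boundaries: 0, cp_1, ..., cp_K, n_timepoints
    edges = np.concatenate(([0], np.asarray(changepoints, dtype=int), [n_timepoints]))
    n_states = len(edges) - 1
    # One mean per (dimension, state)
    state_means = rng.normal(0.0, 3.0, size=(n_dims, n_states))
    # State label for every time bin
    state_of_t = np.repeat(np.arange(n_states), np.diff(edges))
    noise = rng.normal(0.0, sigma, size=(n_dims, n_timepoints))
    data = state_means[:, state_of_t] + noise
    return data, state_means, state_of_t

known_cps = [30, 75, 110]
toy_data, toy_means, toy_states = make_mean_shift_data(3, known_cps, 150)
print(f"Toy data shape: {toy_data.shape}")
print(f"Number of states: {toy_means.shape[1]}")
```

The returned `toy_data` has the same (dimensions × time) layout as `data_array`, so it could be passed to the model in its place.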

## Visualize Input Data

Plot the synthetic data to visualize the changepoints.

In [None]:
# Plot the synthetic data
fig, axes = plt.subplots(n_dimensions, 1, figsize=(12, 2*n_dimensions), sharex=True)
if n_dimensions == 1:
    axes = [axes]

# Expected changepoint locations, assuming the states are evenly spaced in time
changepoint_locs = np.linspace(0, n_timepoints, true_n_states+1)[1:-1]

for i in range(n_dimensions):
    axes[i].plot(data_array[i, :], 'b-', alpha=0.7)
    axes[i].set_ylabel(f'Dimension {i+1}')
    axes[i].grid(True, alpha=0.3)

    # Add vertical lines at the expected changepoints
    for cp in changepoint_locs:
        axes[i].axvline(cp, color='red', linestyle='--', alpha=0.5)

axes[-1].set_xlabel('Time')
plt.suptitle('Synthetic Gaussian Data with Unknown Number of Changepoints\n(Red dashed lines show true changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

## Create and Fit Model

Initialize the `GaussianChangepointMeanDirichlet` model and fit it to the data.

In [None]:
# Create model instance
model_instance = GaussianChangepointMeanDirichlet(data_array, max_states=max_states)

# Generate the PyMC model
model = model_instance.generate_model()

print("Model created successfully!")
print(f"Model variables: {[var.name for var in model.unobserved_RVs]}")

In [None]:
# Fit the model using variational inference
with model:
    # Use ADVI, which is typically faster than MCMC sampling
    inference = pm.ADVI()
    approx = pm.fit(n=8000, method=inference)  # More iterations for Dirichlet process
    
    # Sample from the approximation
    trace = approx.sample(draws=1000)

print("Model fitting completed!")
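Before trusting the fit, it can help to check that the variational loss has levelled off. The sketch below assumes the loss history is available as a 1D array (in PyMC this is typically exposed as `approx.hist`, but verify that against your installed version). `loss_has_plateaued` is a hypothetical helper, and the synthetic loss curve is made up purely to exercise it.

```python
import numpy as np

def loss_has_plateaued(loss_history, window=500, rel_tol=0.01):
    """Compare the mean loss over the last two windows of iterations."""
    loss = np.asarray(loss_history, dtype=float)
    if loss.size < 2 * window:
        return False  # not enough iterations to compare two windows
    previous = loss[-2 * window:-window].mean()
    latest = loss[-window:].mean()
    return bool(abs(previous - latest) <= rel_tol * abs(previous))

# Synthetic stand-in for a loss curve: exponential decay plus noise
rng = np.random.default_rng(0)
iterations = np.arange(8000)
fake_loss = 1000 * np.exp(-iterations / 800) + 200 + rng.normal(0, 2, size=8000)

print("Plateaued after 1000 iterations:", loss_has_plateaued(fake_loss[:1000]))
print("Plateaued after 8000 iterations:", loss_has_plateaued(fake_loss))
```

If the check fails on a real fit, increasing `n` in `pm.fit` is the usual first step.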

## Analyze Results

Extract and visualize the inferred changepoints and determine the effective number of states.

In [None]:
# Extract changepoint estimates
tau_samples = trace.posterior['tau'].values.reshape(-1, max_states-1)
tau_mean = np.mean(tau_samples, axis=0)
tau_std = np.std(tau_samples, axis=0)

# Extract weights to determine effective number of states
w_samples = trace.posterior['w'].values.reshape(-1, max_states)
w_mean = np.mean(w_samples, axis=0)

# Determine effective number of states (weights > threshold)
weight_threshold = 0.05
effective_states = np.sum(w_mean > weight_threshold)

print(f"Effective number of states (weight > {weight_threshold}): {effective_states}")
print(f"True number of states: {true_n_states}")
print("\nState weights:")
for i, weight in enumerate(w_mean):
    print(f"  State {i+1}: {weight:.4f}")

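print("\nSignificant Changepoints:")
# Assumes the significant states occupy the first indices; guard against a negative slice
significant_changepoints = tau_mean[:max(effective_states - 1, 0)]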
for i, cp in enumerate(significant_changepoints):
    print(f"  Changepoint {i+1}: {cp:.2f} ± {tau_std[i]:.2f}")
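Mean ± standard deviation can be misleading when a changepoint posterior is skewed, so a median with an equal-tailed interval is a useful complement. The sketch below works on any (draws × changepoints) array such as `tau_samples`. `summarize_changepoints` is a hypothetical helper, and the draws used here are simulated rather than taken from a fitted model.

```python
import numpy as np

def summarize_changepoints(tau_draws, interval=0.94):
    """Return one row per changepoint: median, lower bound, upper bound."""
    tau_draws = np.asarray(tau_draws, dtype=float)
    lower_q = 100 * (1 - interval) / 2
    upper_q = 100 - lower_q
    median = np.median(tau_draws, axis=0)
    lower, upper = np.percentile(tau_draws, [lower_q, upper_q], axis=0)
    return np.column_stack([median, lower, upper])

# Simulated posterior draws for three changepoints (illustrative values only)
rng = np.random.default_rng(1)
fake_tau = rng.normal(loc=[40, 80, 115], scale=[1.5, 3.0, 2.0], size=(1000, 3))

summary = summarize_changepoints(fake_tau)
for i, (med, lo, hi) in enumerate(summary):
    print(f"  Changepoint {i+1}: median {med:.1f}, 94% interval [{lo:.1f}, {hi:.1f}]")
```

Note that this is an equal-tailed interval computed from percentiles, not a highest-density interval.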

In [None]:
# Plot state weights
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

bars = ax.bar(range(1, max_states+1), w_mean, alpha=0.7)
ax.axhline(weight_threshold, color='red', linestyle='--', label=f'Threshold ({weight_threshold})')
ax.set_xlabel('State')
ax.set_ylabel('Weight')
ax.set_title('Dirichlet Process State Weights')
ax.legend()
ax.grid(True, alpha=0.3)

# Color significant states differently
for i, bar in enumerate(bars):
    if w_mean[i] > weight_threshold:
        bar.set_color('green')
        bar.set_alpha(0.8)

plt.tight_layout()
plt.show()

In [None]:
# Plot significant changepoint posterior distributions
n_significant_cp = len(significant_changepoints)
if n_significant_cp > 0:
    fig, axes = plt.subplots(1, n_significant_cp, figsize=(4*n_significant_cp, 4))
    if n_significant_cp == 1:
        axes = [axes]

    true_changepoints = np.linspace(0, n_timepoints, true_n_states+1)[1:-1]
    
    for i in range(n_significant_cp):
        axes[i].hist(tau_samples[:, i], bins=50, alpha=0.7, density=True)
        axes[i].axvline(tau_mean[i], color='red', linestyle='-', label=f'Mean: {tau_mean[i]:.1f}')
        if i < len(true_changepoints):
            axes[i].axvline(true_changepoints[i], color='green', linestyle='--', label=f'True: {true_changepoints[i]:.1f}')
        axes[i].set_xlabel('Time')
        axes[i].set_ylabel('Density')
        axes[i].set_title(f'Changepoint {i+1} Posterior')
        axes[i].legend()
        axes[i].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("No significant changepoints detected.")

## Visualize Model Fit

Show the original data with inferred changepoints overlaid.

In [None]:
# Plot data with inferred changepoints
fig, axes = plt.subplots(n_dimensions, 1, figsize=(12, 2*n_dimensions), sharex=True)
if n_dimensions == 1:
    axes = [axes]

true_changepoints = np.linspace(0, n_timepoints, true_n_states+1)[1:-1]

for i in range(n_dimensions):
    axes[i].plot(data_array[i, :], 'b-', alpha=0.7, label='Data')
    axes[i].set_ylabel(f'Dimension {i+1}')
    axes[i].grid(True, alpha=0.3)
    
    # Add inferred changepoints (only significant ones)
    for j, cp in enumerate(significant_changepoints):
        axes[i].axvline(cp, color='red', linestyle='-', alpha=0.8, 
                       label='Inferred CP' if i == 0 and j == 0 else '')
    
    # Add true changepoints
    for j, cp in enumerate(true_changepoints):
        axes[i].axvline(cp, color='green', linestyle='--', alpha=0.5,
                       label='True CP' if i == 0 and j == 0 else '')

axes[0].legend()
axes[-1].set_xlabel('Time')
plt.suptitle(f'Dirichlet Process Model Results\nInferred {effective_states} states (True: {true_n_states})', fontsize=14)
plt.tight_layout()
plt.show()
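A complementary view of the fit is the piecewise-constant mean implied by a set of changepoints. The sketch below computes empirical segment means directly from the data rather than from the model's `mu` posterior. `segment_means` is a hypothetical helper, and the step-function input is a noiseless toy example.

```python
import numpy as np

def segment_means(data, changepoints):
    """Replace each segment of a (dimensions x time) array by its per-dimension mean."""
    data = np.asarray(data, dtype=float)
    n_timepoints = data.shape[1]
    # Round to integer bin indices and keep them strictly inside the recording
    cps = np.unique(np.clip(np.round(changepoints).astype(int), 1, n_timepoints - 1))
    edges = np.concatenate(([0], cps, [n_timepoints]))
    fitted = np.empty_like(data)
    for start, stop in zip(edges[:-1], edges[1:]):
        fitted[:, start:stop] = data[:, start:stop].mean(axis=1, keepdims=True)
    return fitted, edges

# Noiseless step from 0 to 5 at time bin 50, in two dimensions
demo = np.hstack([np.zeros((2, 50)), np.full((2, 50), 5.0)])
fitted, edges = segment_means(demo, [49.7])
print("Segment edges:", edges.tolist())
print("Max reconstruction error:", np.abs(fitted - demo).max())
```

Calling `segment_means(data_array, significant_changepoints)` would give a curve that can be overlaid on each dimension's trace.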

## Model Parameters

Examine the inferred mean parameters for significant states.

In [None]:
# Extract mean parameters for significant states
mu_samples = trace.posterior['mu'].values.reshape(-1, n_dimensions, max_states)
mu_mean = np.mean(mu_samples, axis=0)

# Extract global sigma (constant across states)
sigma_samples = trace.posterior['sigma'].values.reshape(-1, n_dimensions)
sigma_mean = np.mean(sigma_samples, axis=0)

print("Inferred Parameters for Significant States:")
for state in range(effective_states):
    print(f"\nState {state + 1} (weight: {w_mean[state]:.4f}):")
    for dim in range(n_dimensions):
        print(f"  Dimension {dim + 1}: μ = {mu_mean[dim, state]:.3f}")

print("\nConstant Standard Deviations:")
for dim in range(n_dimensions):
    print(f"  Dimension {dim + 1}: σ = {sigma_mean[dim]:.3f}")

In [None]:
# Plot mean parameter evolution across significant states
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Mean parameters for significant states only
for dim in range(n_dimensions):
    ax.plot(range(1, effective_states+1), mu_mean[dim, :effective_states], 
            'o-', label=f'Dim {dim+1}', linewidth=2, markersize=8)

ax.set_xlabel('Significant State')
ax.set_ylabel('Mean (μ)')
ax.set_title(f'Mean Parameters for {effective_states} Significant States\n(Dirichlet Process Automatic State Selection)')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xticks(range(1, effective_states+1))

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated the `GaussianChangepointMeanDirichlet` model:

1. **Data Generation**: Created synthetic Gaussian data whose number of changepoints is not given to the model
2. **Model Fitting**: Used a Dirichlet process prior, fit with ADVI, to infer the number of states
3. **State Selection**: Identified significant states by thresholding the mean state weights
4. **Results Analysis**: Compared the inferred and true number of states and changepoint locations

The Dirichlet process model is particularly useful when:
- The number of changepoints is unknown
- You want to avoid overfitting by automatically selecting model complexity
- You have sufficient data to support the more complex inference procedure

**Key Advantages:**
- Automatic model selection
- Principled uncertainty quantification
- Avoids manual tuning of number of states

**Considerations:**
- Requires more computation than a model with a fixed number of states (this demo runs 8,000 ADVI iterations)
- May need larger datasets for reliable inference
- The weight threshold (0.05 here) determines how many states are counted as significant
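To see how much the threshold matters, the state count can be recomputed over a range of thresholds. The sketch below uses made-up weights rather than output from a fitted model. `effective_states_by_threshold` is a hypothetical helper that could equally be called with `w_mean`.

```python
import numpy as np

def effective_states_by_threshold(mean_weights, thresholds):
    """Count the states whose mean weight exceeds each threshold."""
    mean_weights = np.asarray(mean_weights, dtype=float)
    return {float(t): int(np.sum(mean_weights > t)) for t in thresholds}

# Illustrative weights only (they sum to 1); not produced by the model
example_weights = np.array([0.30, 0.25, 0.20, 0.12, 0.06,
                            0.03, 0.02, 0.01, 0.005, 0.005])

counts = effective_states_by_threshold(example_weights, [0.01, 0.05, 0.10])
for threshold, n_states in counts.items():
    print(f"  threshold {threshold:.2f}: {n_states} effective states")
```

If the count changes sharply across plausible thresholds, the number of states is not well determined by the weights alone.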