# GaussianChangepointMeanVar2D Model Demo

This notebook demonstrates the `GaussianChangepointMeanVar2D` model for detecting changepoints in Gaussian data where both mean and variance can change.

## Model Overview
- **Data Type**: Gaussian/Normal distributed data
- **Input Shape**: 2D array (dimensions × time)
- **Detects Changes In**: Both mean and variance
- **Use Case**: Neural data with varying firing rates and variability
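
As a rough sketch of what such a model expresses (an assumption based on the model name and on the `tau`, `mu`, and `sigma` variables used later in this notebook, not a statement of pytau's exact parameterization), a piecewise-constant Gaussian changepoint model for dimension $d$ at time $t$ can be written as:

$$
x_{d,t} \sim \mathcal{N}\left(\mu_{d,s(t)},\ \sigma_{d,s(t)}^{2}\right),
\qquad
s(t) = 1 + \sum_{j=1}^{K-1} \mathbb{1}\left[t \ge \tau_j\right]
$$

Here $K$ corresponds to `n_states`, $\tau_1 < \dots < \tau_{K-1}$ are the changepoints shared across dimensions, and each dimension has its own mean $\mu_{d,k}$ and standard deviation $\sigma_{d,k}$ in state $k$. The symbols $s(t)$ and $K$ are introduced here for illustration only; the actual implementation may, for example, smooth the transitions between states rather than use a hard indicator.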

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys

# Add pytau to path
sys.path.append('../../')

from pytau.changepoint_model import GaussianChangepointMeanVar2D, gen_test_array
import pymc as pm

## Generate Synthetic Data

Create synthetic Gaussian data with known changepoints in both mean and variance.
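
For intuition about what this kind of data looks like, here is a minimal NumPy-only sketch that draws a 2D array whose mean and standard deviation switch at explicitly chosen changepoints. It is not how `gen_test_array` works internally; the helper `make_piecewise_gaussian` and all values below are hypothetical and chosen only for illustration.

```python
import numpy as np


def make_piecewise_gaussian(n_time, changepoints, mus, sigmas, seed=0):
    """Hypothetical helper: draw a (n_dims, n_time) array whose mean and
    standard deviation switch at the given changepoints."""
    rng = np.random.default_rng(seed)
    mus = np.asarray(mus, dtype=float)        # shape (n_dims, n_states)
    sigmas = np.asarray(sigmas, dtype=float)  # shape (n_dims, n_states)
    # State index per timepoint: 0 before the first changepoint, 1 after it, ...
    states = np.searchsorted(changepoints, np.arange(n_time), side="right")
    # Index the per-state parameters by state to get per-timepoint parameters
    return rng.normal(mus[:, states], sigmas[:, states]), states


# Illustrative values (assumptions, not taken from pytau)
demo_changepoints = np.array([30, 70])
demo_mus = np.array([[0.0, 3.0, -2.0], [1.0, -1.0, 4.0]])
demo_sigmas = np.array([[0.5, 2.0, 1.0], [1.0, 0.3, 2.5]])

demo_data, demo_states = make_piecewise_gaussian(
    100, demo_changepoints, demo_mus, demo_sigmas
)
print(demo_data.shape)            # (n_dims, n_time)
print(np.bincount(demo_states))   # number of timepoints in each state
```

Because the changepoints are set by hand here, an array like `demo_data` gives a ground truth that does not depend on any assumption about how the built-in generator places its transitions.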

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Parameters
n_dimensions = 5
n_timepoints = 100
n_states = 3

# Generate test data using the built-in function
data_array = gen_test_array((n_dimensions, n_timepoints), n_states=n_states, type="normal")

print(f"Data shape: {data_array.shape}")
print(f"Data range: [{data_array.min():.2f}, {data_array.max():.2f}]")

## Visualize Input Data

Plot the synthetic data to visualize the changepoints.

In [None]:
# Plot the synthetic data
fig, axes = plt.subplots(n_dimensions, 1, figsize=(12, 2*n_dimensions), sharex=True)
if n_dimensions == 1:
    axes = [axes]

# Expected changepoints, assuming the states split the time axis evenly
changepoint_locs = np.linspace(0, n_timepoints, n_states+1)[1:-1]

for i in range(n_dimensions):
    axes[i].plot(data_array[i, :], 'b-', alpha=0.7)
    axes[i].set_ylabel(f'Dimension {i+1}')
    axes[i].grid(True, alpha=0.3)

    # Add vertical lines at expected changepoints
    for cp in changepoint_locs:
        axes[i].axvline(cp, color='red', linestyle='--', alpha=0.5)

axes[-1].set_xlabel('Time')
plt.suptitle('Synthetic Gaussian Data with Changepoints\n(Red dashed lines show true changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

## Create and Fit Model

Initialize the GaussianChangepointMeanVar2D model and fit it to the data.

In [None]:
# Create model instance
model_instance = GaussianChangepointMeanVar2D(data_array, n_states=n_states)

# Generate the PyMC model
model = model_instance.generate_model()

print("Model created successfully!")
print(f"Model variables: {[var.name for var in model.unobserved_RVs]}")

In [None]:
# Fit the model using variational inference
with model:
    # Use ADVI, which is typically faster than MCMC sampling but yields an approximate posterior
    inference = pm.ADVI()
    approx = pm.fit(n=5000, method=inference)
    
    # Sample from the approximation
    trace = approx.sample(draws=1000)

print("Model fitting completed!")
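
A fixed number of ADVI iterations does not guarantee that the optimization has converged. The sketch below is one simple, NumPy-only way to check whether a loss curve has flattened; the helper `loss_has_plateaued`, its `window` and `rel_tol` defaults, and the toy loss curve are all hypothetical. PyMC approximations expose the loss history as `approx.hist`, which could be passed to a check like this (worth confirming for the installed version).

```python
import numpy as np


def loss_has_plateaued(loss_history, window=500, rel_tol=0.01):
    """Hypothetical helper: True if the mean loss over the last `window`
    iterations differs from the preceding window by at most `rel_tol`
    (relative to the preceding window)."""
    loss = np.asarray(loss_history, dtype=float)
    if loss.size < 2 * window:
        return False  # not enough iterations to compare two windows
    recent = loss[-window:].mean()
    previous = loss[-2 * window:-window].mean()
    return bool(abs(previous - recent) <= rel_tol * abs(previous))


# Toy loss curve standing in for a real optimization history
steps = np.arange(5000)
toy_loss = 1000 * np.exp(-steps / 400) + 200

print(loss_has_plateaued(toy_loss))         # full curve
print(loss_has_plateaued(toy_loss[:1000]))  # stopped early
```

If the check fails on a real run, increasing `n` in `pm.fit` and plotting the loss history are reasonable next steps.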

## Analyze Results

Extract and visualize the inferred changepoints and parameters.
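
A mean and standard deviation can hide skewed or multimodal changepoint posteriors, so a median with a central credible interval is often a useful complement. The following is a minimal sketch on stand-in samples; `summarize_changepoints` is a hypothetical helper and `fake_tau` is fabricated purely to make the example self-contained.

```python
import numpy as np


def summarize_changepoints(samples, interval=0.95):
    """Hypothetical helper: mean, median, and central credible interval
    for each changepoint. `samples` has shape (n_draws, n_changepoints)."""
    samples = np.asarray(samples, dtype=float)
    lower_q = 100 * (1 - interval) / 2
    upper_q = 100 - lower_q
    return {
        "mean": samples.mean(axis=0),
        "median": np.median(samples, axis=0),
        "lower": np.percentile(samples, lower_q, axis=0),
        "upper": np.percentile(samples, upper_q, axis=0),
    }


# Stand-in posterior draws for two changepoints (illustrative only)
rng = np.random.default_rng(0)
fake_tau = np.column_stack([rng.normal(33, 2, 1000), rng.normal(66, 3, 1000)])

summary = summarize_changepoints(fake_tau)
for k in range(fake_tau.shape[1]):
    print(
        f"Changepoint {k + 1}: median {summary['median'][k]:.1f}, "
        f"95% interval [{summary['lower'][k]:.1f}, {summary['upper'][k]:.1f}]"
    )
```

In this notebook the same function could be applied to `tau_samples` once it has been extracted in the cell below.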

In [None]:
# Extract changepoint estimates
tau_samples = trace.posterior['tau'].values.reshape(-1, n_states-1)
tau_mean = np.mean(tau_samples, axis=0)
tau_std = np.std(tau_samples, axis=0)

print("Inferred Changepoints:")
for i, (mean_cp, std_cp) in enumerate(zip(tau_mean, tau_std)):
    print(f"  Changepoint {i+1}: {mean_cp:.2f} ± {std_cp:.2f}")

# Reference changepoints for comparison (assumes evenly spaced transitions in the generated data)
true_changepoints = np.linspace(0, n_timepoints, n_states+1)[1:-1]
print(f"\nTrue Changepoints: {true_changepoints}")

In [None]:
# Plot changepoint posterior distributions
fig, axes = plt.subplots(1, n_states-1, figsize=(4*(n_states-1), 4))
if n_states-1 == 1:
    axes = [axes]

for i in range(n_states-1):
    axes[i].hist(tau_samples[:, i], bins=50, alpha=0.7, density=True)
    axes[i].axvline(tau_mean[i], color='red', linestyle='-', label=f'Mean: {tau_mean[i]:.1f}')
    axes[i].axvline(true_changepoints[i], color='green', linestyle='--', label=f'True: {true_changepoints[i]:.1f}')
    axes[i].set_xlabel('Time')
    axes[i].set_ylabel('Density')
    axes[i].set_title(f'Changepoint {i+1} Posterior')
    axes[i].legend()
    axes[i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Visualize Model Fit

Show the original data with inferred changepoints overlaid.

In [None]:
# Plot data with inferred changepoints
fig, axes = plt.subplots(n_dimensions, 1, figsize=(12, 2*n_dimensions), sharex=True)
if n_dimensions == 1:
    axes = [axes]

for i in range(n_dimensions):
    axes[i].plot(data_array[i, :], 'b-', alpha=0.7, label='Data')
    axes[i].set_ylabel(f'Dimension {i+1}')
    axes[i].grid(True, alpha=0.3)
    
    # Add inferred changepoints
    for j, cp in enumerate(tau_mean):
        axes[i].axvline(cp, color='red', linestyle='-', alpha=0.8, 
                       label='Inferred CP' if i == 0 and j == 0 else '')
    
    # Add true changepoints
    for j, cp in enumerate(true_changepoints):
        axes[i].axvline(cp, color='green', linestyle='--', alpha=0.5,
                       label='True CP' if i == 0 and j == 0 else '')

axes[0].legend()
axes[-1].set_xlabel('Time')
plt.suptitle('Model Fit Results\n(Red: Inferred, Green: True Changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

## Model Parameters

Examine the inferred mean (μ) and standard deviation (σ) parameters for each state.
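
One way to sanity-check the inferred parameters is to compare them with plain empirical statistics computed on each segment of the data. The sketch below splits an array at given changepoints and returns the per-dimension mean and standard deviation of each segment; `segment_stats` is a hypothetical helper and `toy` is made-up data.

```python
import numpy as np


def segment_stats(data, changepoints):
    """Hypothetical helper: per-dimension mean and standard deviation within
    each segment. `data` has shape (n_dims, n_time); `changepoints` are sorted
    positions along the time axis. Returns two (n_dims, n_segments) arrays."""
    data = np.asarray(data, dtype=float)
    # Round (possibly fractional) changepoints to integer split indices
    cuts = np.round(np.asarray(changepoints, dtype=float)).astype(int)
    segments = np.split(data, cuts, axis=1)
    means = np.stack([seg.mean(axis=1) for seg in segments], axis=1)
    stds = np.stack([seg.std(axis=1) for seg in segments], axis=1)
    return means, stds


# Toy data: two dimensions, one changepoint at t = 50 (illustrative only)
rng = np.random.default_rng(1)
toy = np.concatenate(
    [rng.normal(0, 1, (2, 50)), rng.normal(5, 3, (2, 50))], axis=1
)

emp_mu, emp_sigma = segment_stats(toy, [50.2])
print(emp_mu.shape)
print(np.round(emp_mu, 2))
print(np.round(emp_sigma, 2))
```

Applied as `segment_stats(data_array, tau_mean)`, this would give arrays with one column per state that can be compared against `mu_mean` and `sigma_mean` from the cell below. Note that an empty segment (two changepoints rounding to the same index) would produce NaN values.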

In [None]:
# Extract mean and sigma parameters
mu_samples = trace.posterior['mu'].values.reshape(-1, n_dimensions, n_states)
sigma_samples = trace.posterior['sigma'].values.reshape(-1, n_dimensions, n_states)

mu_mean = np.mean(mu_samples, axis=0)
sigma_mean = np.mean(sigma_samples, axis=0)

print("Inferred Parameters by State:")
for state in range(n_states):
    print(f"\nState {state + 1}:")
    for dim in range(n_dimensions):
        print(f"  Dimension {dim + 1}: μ = {mu_mean[dim, state]:.3f}, σ = {sigma_mean[dim, state]:.3f}")

In [None]:
# Plot parameter evolution across states
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Mean parameters
for dim in range(n_dimensions):
    ax1.plot(range(1, n_states+1), mu_mean[dim, :], 'o-', label=f'Dim {dim+1}')
ax1.set_xlabel('State')
ax1.set_ylabel('Mean (μ)')
ax1.set_title('Mean Parameters by State')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Standard deviation parameters
for dim in range(n_dimensions):
    ax2.plot(range(1, n_states+1), sigma_mean[dim, :], 's-', label=f'Dim {dim+1}')
ax2.set_xlabel('State')
ax2.set_ylabel('Standard Deviation (σ)')
ax2.set_title('Standard Deviation Parameters by State')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated the `GaussianChangepointMeanVar2D` model:

1. **Data Generation**: Created synthetic Gaussian data with known changepoints
2. **Model Fitting**: Used ADVI variational inference (5,000 iterations, 1,000 posterior draws) to fit the model
3. **Results Analysis**: Extracted and visualized changepoint estimates
4. **Parameter Inspection**: Examined mean and standard deviation parameters for each state

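The model is designed to detect changepoints in Gaussian data where both mean and variance change across states. How well it recovered them in this run is best judged by comparing the inferred changepoints and their uncertainty with the reference changepoints shown above.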