# GaussianChangepointMeanVar2D Model Demo

This notebook demonstrates how to use the GaussianChangepointMeanVar2D model on synthetic (dummy) data with known, clearly separated state transitions.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Add pytau to path
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), '..', '..'))

from pytau.changepoint_model import GaussianChangepointMeanVar2D, gen_test_array

## Generate Test Data

In [None]:
# Generate dummy data with obvious changepoints for the Gaussian model.
# Each trial is piecewise Gaussian: state k occupies the samples between
# consecutive changepoints and is drawn from N(state_means[k], state_stds[k]**2).
np.random.seed(42)
n_trials = 10
n_time = 100
n_states = 3

# Per-state parameters (one entry per state):
# State 1: low mean, low variance
# State 2: high mean, low variance
# State 3: medium mean, high variance
state_means = [1.0, 5.0, 3.0]
state_stds = [0.5, 0.5, 2.0]
transition_points = [33, 66]  # Nominal changepoints: exactly n_states - 1 of them
max_jitter = 3    # Per-trial changepoint shift is drawn uniformly from [-max_jitter, max_jitter]
edge_margin = 5   # Keep every changepoint at least this many samples from either end

# Fail early if the per-state settings disagree; otherwise the generated data
# would not contain n_states segments.
if not (len(state_means) == len(state_stds) == n_states == len(transition_points) + 1):
    raise ValueError(
        "state_means and state_stds need n_states entries and "
        "transition_points needs n_states - 1 entries"
    )

test_data = np.zeros((n_trials, n_time))
for trial in range(n_trials):
    # Add some trial-to-trial variability in transition points
    trial_transitions = [t + np.random.randint(-max_jitter, max_jitter + 1) for t in transition_points]
    trial_transitions = [max(edge_margin, min(n_time - edge_margin, t)) for t in trial_transitions]  # Keep within bounds

    # Fill each state's segment [start, stop) in order. The random draws happen
    # in the same order and with the same sizes as a hand-unrolled per-state fill.
    boundaries = [0, *trial_transitions, n_time]
    for mean, std, start, stop in zip(state_means, state_stds, boundaries[:-1], boundaries[1:]):
        test_data[trial, start:stop] = np.random.normal(mean, std, stop - start)

print(f"Generated test data shape: {test_data.shape}")
print(f"Data range: [{test_data.min():.2f}, {test_data.max():.2f}]")
print(f"State means: {state_means}")
print(f"State stds: {state_stds}")
print(f"Transition points (approx): {transition_points}")

## Initialize Model and Generate the PyMC Model

In [None]:
# Build the changepoint model over the synthetic trials.
# NOTE(review): fit_type='vi' presumably selects variational inference — confirm in pytau.
model_kwargs = {
    'data_array': test_data,
    'n_states': n_states,
    'fit_type': 'vi',
}
model = GaussianChangepointMeanVar2D(**model_kwargs)

# Construct the underlying PyMC model from the stored data and settings
model.generate_model()
print("Model generated successfully")

## Run Model Test

In [None]:
# Run the model's built-in test. Failures are reported but deliberately not
# re-raised, so the remaining cells (data visualization) can still run.
import traceback

try:
    model.test()
except Exception as e:
    # str(e) is empty for e.g. a bare AssertionError, so report the exception
    # type and the full traceback to keep failures diagnosable.
    print(f"Model test failed: {type(e).__name__}: {e}")
    traceback.print_exc()
else:
    print("Model test completed successfully")

## Visualize Input Data

In [None]:
# Plot the input data to visualize changepoints.
# The red dashed lines mark the *nominal* transition points; each trial's
# actual changepoints were jittered around them when test_data was generated.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
ax_trials, ax_mean, ax_hist, ax_heat = axes.ravel()

# Per-trial changepoint jitter (in samples) applied when generating test_data:
# shifts were drawn from [-3, 3] around each nominal transition point.
jitter = 3

# --- Individual trials (first 5) ---
for i in range(min(5, n_trials)):
    ax_trials.plot(test_data[i], alpha=0.7, label=f'Trial {i+1}')
for tp in transition_points:
    ax_trials.axvline(x=tp, color='red', linestyle='--', alpha=0.7)
ax_trials.set_title('Individual Trials (First 5)')
ax_trials.set_xlabel('Time')
ax_trials.set_ylabel('Value')
ax_trials.legend()
ax_trials.grid(True, alpha=0.3)

# --- Average across trials, with a ±1 standard deviation band ---
mean_data = test_data.mean(axis=0)
std_data = test_data.std(axis=0)
ax_mean.plot(mean_data, 'b-', linewidth=2, label='Mean')
ax_mean.fill_between(range(n_time), mean_data - std_data, mean_data + std_data, alpha=0.3, label='±1 STD')
for k, tp in enumerate(transition_points):
    # Label only the first line so the legend shows a single 'Transition' entry
    ax_mean.axvline(x=tp, color='red', linestyle='--', alpha=0.7,
                    label='Transition' if k == 0 else None)
ax_mean.set_title('Average Across Trials')
ax_mean.set_xlabel('Time')
ax_mean.set_ylabel('Value')
ax_mean.legend()
ax_mean.grid(True, alpha=0.3)

# --- Histogram per state ---
# Drop `jitter` samples on each side of every nominal transition. Slicing only
# at the nominal points would mix samples from neighbouring states, because
# each trial's real changepoint can sit up to `jitter` samples away.
state_data = [
    test_data[:, :transition_points[0] - jitter].flatten(),
    test_data[:, transition_points[0] + jitter:transition_points[1] - jitter].flatten(),
    test_data[:, transition_points[1] + jitter:].flatten(),
]
colors = ['blue', 'orange', 'green']
for i, (data, color) in enumerate(zip(state_data, colors)):
    ax_hist.hist(data, bins=20, alpha=0.6, color=color, label=f'State {i+1}', density=True)
ax_hist.set_title('Distribution by State')
ax_hist.set_xlabel('Value')
ax_hist.set_ylabel('Density')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

# --- Heatmap of all trials ---
im = ax_heat.imshow(test_data, aspect='auto', cmap='viridis', interpolation='nearest')
fig.colorbar(im, ax=ax_heat, label='Value')
for tp in transition_points:
    # imshow centres column j at x=j, so the boundary just before sample tp
    # lies at x = tp - 0.5.
    ax_heat.axvline(x=tp - 0.5, color='red', linestyle='--', alpha=0.8)
ax_heat.set_title('All Trials Heatmap')
ax_heat.set_xlabel('Time')
ax_heat.set_ylabel('Trial')

fig.tight_layout()
plt.show()

## Summary

This notebook demonstrated:
1. Generating dummy Gaussian data
2. Initializing the GaussianChangepointMeanVar2D model
3. Running model tests
4. Visualizing the input data

The GaussianChangepointMeanVar2D model is designed for detecting changepoints in Gaussian data where both mean and variance can change.