# CategoricalChangepoint2D Model Demo

This notebook demonstrates the `CategoricalChangepoint2D` model for detecting changepoints in categorical data.

## Model Overview
- **Data Type**: Categorical/discrete data
- **Input Shape**: 2D array (dimensions × time)
- **Use Case**: Discrete state sequences, behavioral data, or discretized neural states

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Add pytau to path
sys.path.append('../../')

from pytau.changepoint_model import CategoricalChangepoint2D
import pymc as pm

## Generate Synthetic Categorical Data

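Create synthetic categorical data with known changepoints. The time axis is split into `n_states` equal-length segments. Within each segment, every dimension samples from its own category distribution, drawn from a Dirichlet prior.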

In [None]:
# Seed NumPy's global generator so the synthetic dataset is identical on every run
np.random.seed(42)

# Problem size: dimensions x timepoints, split into n_states segments
n_dimensions, n_timepoints = 4, 100
n_states = 3
n_categories = 5  # observed labels are the integers 0 .. n_categories - 1

# Generate synthetic categorical data with changepoints
def generate_categorical_data(shape, n_states, n_categories, *, rng=None):
    """Generate categorical data with equally spaced changepoints.

    The time axis is split into ``n_states`` contiguous segments of (nearly)
    equal length. For every segment, each dimension gets its own category
    distribution drawn from a symmetric Dirichlet(2) prior. The segment is
    then filled with i.i.d. draws from that distribution.

    Parameters
    ----------
    shape : tuple of int
        ``(n_dims, n_time)`` shape of the output array.
    n_states : int
        Number of segments (states). Yields ``n_states - 1`` internal changepoints.
    n_categories : int
        Number of categories. Sampled values lie in ``0 .. n_categories - 1``.
    rng : numpy.random.Generator, optional
        Source of randomness. Defaults to the legacy global ``np.random``
        state, so results still follow ``np.random.seed``.

    Returns
    -------
    data : ndarray of int, shape (n_dims, n_time)
        Sampled category labels.
    changepoints : ndarray of int, shape (n_states - 1,)
        Start indices of segments 2..n_states (the internal changepoints).
    """
    if rng is None:
        # Module-level functions draw from the global RandomState
        rng = np.random
    n_dims, n_time = shape
    data = np.zeros((n_dims, n_time), dtype=int)

    # Segment boundaries, including 0 and n_time at the two ends
    changepoints = np.linspace(0, n_time, n_states + 1).astype(int)

    # One (n_dims, n_categories) probability table per state. All tables are
    # drawn before any sampling so the random stream order is fixed.
    state_probs = [
        rng.dirichlet(np.ones(n_categories) * 2, size=n_dims)
        for _ in range(n_states)
    ]

    # Fill each segment, dimension by dimension, from that state's table
    segments = zip(changepoints[:-1], changepoints[1:])
    for state, (start_idx, end_idx) in enumerate(segments):
        for dim in range(n_dims):
            data[dim, start_idx:end_idx] = rng.choice(
                n_categories,
                size=end_idx - start_idx,
                p=state_probs[state][dim],
            )

    # Drop the outer boundaries (0 and n_time); keep only internal changepoints
    return data, changepoints[1:-1]

# Sample the dataset together with its ground-truth changepoints
data_array, true_changepoints = generate_categorical_data(
    (n_dimensions, n_timepoints),
    n_states,
    n_categories,
)

# Quick sanity summary of what was generated
for summary_line in (
    f"Data shape: {data_array.shape}",
    f"Data range: [{data_array.min()}, {data_array.max()}]",
    f"Number of categories: {n_categories}",
    f"True changepoints: {true_changepoints}",
):
    print(summary_line)

## Visualize Input Data

Plot the synthetic categorical data to visualize the changepoints.

In [None]:
# Draw each dimension as a strip of unit-width bars coloured by category,
# with the ground-truth changepoints marked as red dashed lines.
fig, axes = plt.subplots(n_dimensions, 1, figsize=(12, 2*n_dimensions),
                         sharex=True, squeeze=False)
axes = axes[:, 0]  # always a 1-D array of Axes, even when n_dimensions == 1

# Evenly spaced colours from the qualitative Set1 colormap, one per category
colors = plt.cm.Set1(np.linspace(0, 1, n_categories))

for i, ax in enumerate(axes):
    for t in range(n_timepoints):
        ax.bar(t, 1, color=colors[data_array[i, t]], width=1, alpha=0.8)

    ax.set_ylabel(f'Dimension {i+1}')
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)

    for cp in true_changepoints:
        ax.axvline(cp, color='red', linestyle='--', alpha=0.7, linewidth=2)

axes[-1].set_xlabel('Time')
plt.suptitle('Synthetic Categorical Data\n(Red dashed lines show true changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

# Separate small figure acting as a colour key for the categories
fig, ax = plt.subplots(1, 1, figsize=(8, 2))
for cat, swatch in enumerate(colors):
    ax.bar(cat, 1, color=swatch, alpha=0.8, label=f'Category {cat}')
ax.set_xlabel('Category')
ax.set_ylabel('Legend')
ax.set_title('Category Color Legend')
ax.legend(ncol=n_categories, loc='upper center', bbox_to_anchor=(0.5, -0.1))
plt.tight_layout()
plt.show()

In [None]:
# Plot category proportions over time, smoothed with a centred sliding window
# that is truncated at both ends of the recording.
window_size = 10
half_width = window_size // 2
time_axis = np.arange(n_timepoints)
win_lo = np.maximum(0, time_axis - half_width)
win_hi = np.minimum(n_timepoints, time_axis + half_width + 1)
win_len = win_hi - win_lo
category_ids = np.arange(n_categories)

fig, axes = plt.subplots(n_dimensions, 1, figsize=(12, 2*n_dimensions),
                         sharex=True, squeeze=False)
axes = axes[:, 0]  # 1-D array of Axes regardless of n_dimensions

for dim, ax in enumerate(axes):
    # hits[c, t] is True where category c occurs at time t
    hits = data_array[dim, :n_timepoints][np.newaxis, :] == category_ids[:, np.newaxis]
    # running[c, k] = number of occurrences of category c in the first k timepoints
    running = np.zeros((n_categories, n_timepoints + 1), dtype=int)
    running[:, 1:] = np.cumsum(hits, axis=1)
    # Fraction of each window [win_lo, win_hi) occupied by each category
    smoothed_props = (running[:, win_hi] - running[:, win_lo]) / win_len

    # Stacked area chart: each category is layered on top of the previous ones
    lower = np.zeros(n_timepoints)
    for cat in range(n_categories):
        upper = lower + smoothed_props[cat]
        ax.fill_between(range(n_timepoints), lower, upper,
                        color=colors[cat], alpha=0.7,
                        label=f'Cat {cat}' if dim == 0 else '')
        lower = upper

    ax.set_ylabel(f'Dimension {dim+1}\nProportion')
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)

    for cp in true_changepoints:
        ax.axvline(cp, color='red', linestyle='--', alpha=0.7, linewidth=2)

axes[0].legend(ncol=n_categories, loc='upper right')
axes[-1].set_xlabel('Time')
plt.suptitle(f'Category Proportions Over Time (Window Size: {window_size})', fontsize=14)
plt.tight_layout()
plt.show()

## Create and Fit Model

Initialize the CategoricalChangepoint2D model and fit it to the data.

In [None]:
# Wrap the data in the pytau model class, then build the underlying PyMC model
model_instance = CategoricalChangepoint2D(data_array, n_states=n_states)
model = model_instance.generate_model()

# Names of the latent (unobserved) random variables in the PyMC model
free_rv_names = [rv.name for rv in model.unobserved_RVs]
print("Model created successfully!")
print(f"Model variables: {free_rv_names}")

In [None]:
# Approximate the posterior with ADVI (faster than MCMC sampling), then draw
# posterior samples from the fitted approximation.
n_fit_iterations = 6000
n_posterior_draws = 1000

with model:
    inference = pm.ADVI()
    approx = pm.fit(n=n_fit_iterations, method=inference)
    trace = approx.sample(draws=n_posterior_draws)

print("Model fitting completed!")

## Analyze Results

Extract and visualize the inferred changepoints and category probabilities.

In [None]:
# Extract changepoint estimates.
# Posterior arrays in an InferenceData `posterior` group are laid out as
# (chain, draw, *var_shape). Collapse only the chain/draw axes. Forcing a
# hard-coded trailing shape with reshape(-1, ...) would silently fold any
# extra model axis into the sample axis and corrupt the statistics below.
tau_posterior = trace.posterior['tau'].values
tau_samples = tau_posterior.reshape(-1, *tau_posterior.shape[2:])
expected_tau_shape = (n_states - 1,)
if tau_samples.shape[1:] != expected_tau_shape:
    raise ValueError(
        f"Expected 'tau' to have per-draw shape {expected_tau_shape}, "
        f"got {tau_samples.shape[1:]}; update this cell to match the model."
    )
tau_mean = np.mean(tau_samples, axis=0)
tau_std = np.std(tau_samples, axis=0)

print("Inferred Changepoints:")
for i, (mean_cp, std_cp) in enumerate(zip(tau_mean, tau_std)):
    print(f"  Changepoint {i+1}: {mean_cp:.2f} ± {std_cp:.2f}")

print(f"\nTrue Changepoints: {true_changepoints}")

In [None]:
# Histogram of each changepoint's posterior. The posterior mean is marked in
# red and the ground truth as a green dashed line.
n_changepoints = n_states - 1
fig, axes = plt.subplots(1, n_changepoints, figsize=(4*n_changepoints, 4), squeeze=False)
axes = axes[0]  # 1-D array of Axes, even for a single changepoint

for i, ax in enumerate(axes):
    est_cp = tau_mean[i]
    true_cp = true_changepoints[i]
    ax.hist(tau_samples[:, i], bins=50, alpha=0.7, density=True)
    ax.axvline(est_cp, color='red', linestyle='-', label=f'Mean: {est_cp:.1f}')
    ax.axvline(true_cp, color='green', linestyle='--', label=f'True: {true_cp:.1f}')
    ax.set_xlabel('Time')
    ax.set_ylabel('Density')
    ax.set_title(f'Changepoint {i+1} Posterior')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Visualize Model Fit

Show the original data with inferred changepoints overlaid.

In [None]:
# Re-draw the categorical strips, overlaying inferred changepoints (solid red)
# and true changepoints (dashed black). Each is labelled once for the legend.
fig, axes = plt.subplots(n_dimensions, 1, figsize=(12, 2*n_dimensions),
                         sharex=True, squeeze=False)
axes = axes[:, 0]  # 1-D array of Axes regardless of n_dimensions

for i, ax in enumerate(axes):
    for t in range(n_timepoints):
        ax.bar(t, 1, color=colors[data_array[i, t]], width=1, alpha=0.8)

    ax.set_ylabel(f'Dimension {i+1}')
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)

    first_row = i == 0
    for j, cp in enumerate(tau_mean):
        ax.axvline(cp, color='red', linestyle='-', alpha=0.8, linewidth=3,
                   label='Inferred CP' if first_row and j == 0 else '')
    for j, cp in enumerate(true_changepoints):
        ax.axvline(cp, color='black', linestyle='--', alpha=0.7, linewidth=2,
                   label='True CP' if first_row and j == 0 else '')

axes[0].legend()
axes[-1].set_xlabel('Time')
plt.suptitle('Categorical Changepoint Model Results\n(Red: Inferred, Black: True Changepoints)', fontsize=14)
plt.tight_layout()
plt.show()

## Model Parameters

Examine the inferred category probability parameters for each state.

In [None]:
# Extract probability parameters.
# Posterior arrays are laid out as (chain, draw, *var_shape). Collapse only the
# chain/draw axes. A hard-coded reshape(-1, n_dimensions, n_states, n_categories)
# silently reinterprets (scrambles) the samples whenever the model stores 'p'
# in a different layout but with a total size that still divides evenly.
p_posterior = trace.posterior['p'].values
p_samples = p_posterior.reshape(-1, *p_posterior.shape[2:])
p_mean = np.mean(p_samples, axis=0)

expected_p_shape = (n_dimensions, n_states, n_categories)
shared_p_shape = (n_states, n_categories)
if p_mean.shape == shared_p_shape:
    # NOTE(review): a (states, categories) 'p' presumably means one table is
    # shared by every dimension -- confirm against the model definition.
    # Repeat it per dimension so the per-dimension printing/plots still work.
    p_mean = np.broadcast_to(p_mean, expected_p_shape)
elif p_mean.shape != expected_p_shape:
    raise ValueError(
        f"Unexpected per-draw shape for 'p': {p_mean.shape}; expected "
        f"{expected_p_shape} or {shared_p_shape}. Update this cell to match the model."
    )

print("Inferred Category Probabilities by State:")
for state in range(n_states):
    print(f"\nState {state + 1}:")
    for dim in range(n_dimensions):
        print(f"  Dimension {dim + 1}:")
        for cat in range(n_categories):
            print(f"    Category {cat}: {p_mean[dim, state, cat]:.3f}")

In [None]:
# Plot probability parameters as heatmaps for each dimension.
# In each heatmap, columns are states and rows are categories.
fig, axes = plt.subplots(n_dimensions, 1, figsize=(10, 3*n_dimensions))
if n_dimensions == 1:
    axes = [axes]

state_labels = [f'State {i+1}' for i in range(n_states)]
category_labels = [f'Cat {i}' for i in range(n_categories)]
# viridis runs from dark purple (0) to bright yellow (1). White text is hard
# to read on the bright end, so annotations switch to black above this value.
text_color_threshold = 0.6

for dim in range(n_dimensions):
    ax = axes[dim]
    im = ax.imshow(p_mean[dim].T, aspect='auto', cmap='viridis',
                   interpolation='nearest', vmin=0, vmax=1)
    ax.set_xlabel('State')
    ax.set_ylabel('Category')
    ax.set_title(f'Category Probabilities - Dimension {dim+1}')
    ax.set_xticks(range(n_states))
    ax.set_xticklabels(state_labels)
    ax.set_yticks(range(n_categories))
    ax.set_yticklabels(category_labels)

    # Annotate every cell with its probability, in a colour that contrasts
    # with the cell's background
    for state in range(n_states):
        for cat in range(n_categories):
            prob = p_mean[dim, state, cat]
            ax.text(state, cat, f'{prob:.2f}', ha="center", va="center",
                    color="black" if prob > text_color_threshold else "white",
                    fontweight='bold')

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Probability')

plt.tight_layout()
plt.show()

In [None]:
# One panel per dimension: how each category's probability changes from one
# state to the next (states numbered from 1 on the x-axis).
fig, axes = plt.subplots(1, n_dimensions, figsize=(4*n_dimensions, 6), squeeze=False)
axes = axes[0]  # 1-D array of Axes regardless of n_dimensions
state_axis = range(1, n_states + 1)

for dim, ax in enumerate(axes):
    for cat in range(n_categories):
        ax.plot(state_axis, p_mean[dim, :, cat], 'o-', color=colors[cat],
                label=f'Cat {cat}', linewidth=2, markersize=8)

    ax.set_xlabel('State')
    ax.set_ylabel('Probability')
    ax.set_title(f'Dimension {dim+1}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xticks(state_axis)
    ax.set_ylim(0, 1)

plt.suptitle('Category Probability Evolution Across States', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated the `CategoricalChangepoint2D` model:

1. **Data Generation**: Created synthetic categorical data with known changepoints
2. **Data Visualization**: Showed categorical sequences and smoothed proportions
3. **Model Fitting**: Used variational inference to fit the changepoint model
4. **Results Analysis**: Extracted changepoint estimates and category probabilities
5. **Parameter Inspection**: Visualized probability parameters for each state and category

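To judge how well the model recovered the segment boundaries on a given run, compare the inferred changepoints (red) with the true changepoints (black) in the plots above. Key features: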

**Strengths:**
- Handles discrete/categorical observations
- Infers probability distributions for each category in each state
- Provides uncertainty estimates for changepoints (approximate here, since they come from an ADVI fit)
- Flexible for different numbers of categories

**Use Cases:**
- Behavioral state analysis (e.g., animal behaviors)
- Discrete neural state transitions
- Sequence analysis with categorical outcomes
- Market regime detection with discrete states
- Any time series with categorical/discrete observations

**Model Interpretation:**
- Each state has its own probability distribution over categories
- Changepoints mark transitions between different probability regimes
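- Each changepoint's posterior spread shows how precisely the transition time is localized. Gradual drifts in category preferences are only approximated by switches between discrete states.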