# CategoricalChangepoint2D Model Demo

This notebook demonstrates the usage of the CategoricalChangepoint2D model with dummy data.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Add pytau to path.
# os.path.dirname(os.getcwd()) is the parent of the current working directory;
# the two '..' components climb two more levels, so the appended entry points
# three levels above the working directory (the path is not normalized).
# NOTE(review): assumes the pytau package lives in that directory — confirm
# against the repository layout and the directory the notebook is run from.
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), '..', '..'))

from pytau.changepoint_model import CategoricalChangepoint2D, gen_test_array

## Generate Test Data

In [None]:
# Reproducible dummy categorical data: seed NumPy's global RNG first.
np.random.seed(42)

# Time bins per trial, number of model states, number of distinct categories.
array_size, n_states, n_categories = 100, 3, 4

# Integer category labels drawn from [0, n_categories); layout is (trials, time).
test_data = np.random.randint(low=0, high=n_categories, size=(10, array_size))

# Report the basic properties of the generated array, one line each.
for report_line in (
    f"Generated test data shape: {test_data.shape}",
    f"Data range: [{test_data.min()}, {test_data.max()}]",
    f"Data type: {test_data.dtype}",
    f"Unique categories: {np.unique(test_data)}",
):
    print(report_line)

## Initialize and Generate Model

In [None]:
# Construct the changepoint model on the dummy data, requesting the 'vi' fit type.
model = CategoricalChangepoint2D(data_array=test_data, n_states=n_states, fit_type='vi')

# Build the underlying PyMC model; the return value is not used in this cell.
model.generate_model()
print("Model generated successfully")

## Run Model Test

In [None]:
# Run the model's test() method. A failure is reported but deliberately not
# re-raised, so the remaining cells of the notebook can still execute.
try:
    model.test()
except Exception as e:
    import traceback  # local import: only needed on the failure path

    print(f"Model test failed: {e}")
    # str(e) alone can be empty or ambiguous; also show the exception type
    # and the full traceback so the failure can actually be diagnosed.
    traceback.print_exc()
else:
    # Only reached when test() returned without raising; kept outside the
    # try body so that only errors from test() itself are reported as failures.
    print("Model test completed successfully")

## Visualize Input Data

In [None]:
# Plot the input data as three panels: per-trial traces, category counts,
# and a heatmap of every trial.
fig = plt.figure(figsize=(15, 8))

# Panel 1: time series for the first few trials (at most 5).
ax_traces = fig.add_subplot(2, 2, 1)
n_shown = min(5, test_data.shape[0])
for trial in range(n_shown):
    ax_traces.plot(test_data[trial], 'o-', markersize=2, alpha=0.7, label=f'Trial {trial+1}')
# Title reports the number of trials actually drawn rather than assuming 5.
ax_traces.set_title(f'Categorical Time Series (First {n_shown} Trials)')
ax_traces.set_xlabel('Time')
ax_traces.set_ylabel('Category')
ax_traces.legend()
ax_traces.grid(True, alpha=0.3)

# Panel 2: category distribution.
# Bin edges sit at k - 0.5 so each bar is centred on its integer category
# code and lines up with the tick label for that category.
ax_hist = fig.add_subplot(2, 2, 2)
bin_edges = np.arange(n_categories + 1) - 0.5
ax_hist.hist(test_data.ravel(), bins=bin_edges, alpha=0.7, edgecolor='black')
ax_hist.set_title('Category Distribution')
ax_hist.set_xlabel('Category')
ax_hist.set_ylabel('Frequency')
ax_hist.set_xticks(range(n_categories))
ax_hist.grid(True, alpha=0.3)

# Panel 3: heatmap of all trials (rows of test_data on the y axis).
# vmin/vmax pin the colour scale to the full category range so colours stay
# comparable even if some category happens to be absent from the data.
ax_heat = fig.add_subplot(2, 1, 2)
heat_image = ax_heat.imshow(
    test_data,
    aspect='auto',
    cmap='viridis',
    interpolation='nearest',
    vmin=0,
    vmax=n_categories - 1,
)
# Categories are integer codes, so label the colorbar at integer positions only.
fig.colorbar(heat_image, ax=ax_heat, label='Category', ticks=list(range(n_categories)))
ax_heat.set_title('Categorical Data Heatmap (All Trials)')
ax_heat.set_xlabel('Time')
ax_heat.set_ylabel('Trial')

fig.tight_layout()
plt.show()

## Summary

This notebook demonstrated:
1. Generating dummy categorical data
2. Initializing the CategoricalChangepoint2D model and generating its PyMC model
3. Running model tests
4. Visualizing the categorical input data

The CategoricalChangepoint2D model is designed for detecting changepoints in categorical/discrete data where observations can take on one of several discrete values.