# SLDS example: Continuous vs Discrete Energy Landscape

This notebook simulates a simple switching vector autoregressive (VAR) process, similar to a switching linear dynamical system (SLDS), and then:

1. Fits a **discrete energy landscape (DEL)** baseline using a maximum entropy (Ising) model on binarized data.
2. Fits the **continuous energy landscape (CEL)** GCN model.
3. Uses k-means on simple features from each method to recover discrete states.
4. Evaluates both using BR (ARI), TMA, and SDA.
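For readers unfamiliar with the generative model, here is a minimal NumPy sketch of a switching VAR(1) process. It is a hypothetical stand-in, not the implementation of `simulate_slds_example`. The function name `simulate_switching_var`, the "sticky" transition matrix derived from `dwell`, the per-state dynamics rescaled to spectral radius 0.9, the process-noise scale 0.1, and the additive Gaussian observation noise `sigma` are all assumptions chosen for illustration. In particular, it uses a linear observation model rather than the `"nonlinear_obs"` scenario used below.

```python
import numpy as np

def simulate_switching_var(K=3, d=8, T=4000, dwell=100, sigma=0.05, seed=0):
    """Hypothetical switching VAR(1) sketch (not simulate_slds_example).

    Returns observations Y (T, d), true states z (T,), transition matrix P (K, K).
    """
    rng = np.random.default_rng(seed)

    # Sticky Markov chain: stay with prob 1 - 1/dwell, so the expected dwell time is `dwell` steps;
    # otherwise jump uniformly to one of the other K - 1 states.
    p_stay = 1.0 - 1.0 / dwell
    P = np.full((K, K), (1.0 - p_stay) / (K - 1))
    np.fill_diagonal(P, p_stay)

    # Per-state dynamics A_k rescaled to spectral radius 0.9 (stable), plus per-state offsets b_k.
    A = np.empty((K, d, d))
    for k in range(K):
        M = rng.standard_normal((d, d))
        A[k] = 0.9 * M / np.max(np.abs(np.linalg.eigvals(M)))
    b = rng.standard_normal((K, d))

    z = np.zeros(T, dtype=int)   # start in state 0 (assumption)
    x = np.zeros((T, d))         # latent continuous state
    for t in range(1, T):
        z[t] = rng.choice(K, p=P[z[t - 1]])                                  # discrete switch
        x[t] = A[z[t]] @ x[t - 1] + b[z[t]] + 0.1 * rng.standard_normal(d)  # VAR(1) step
    Y = x + sigma * rng.standard_normal((T, d))                              # observation noise
    return Y, z, P

Y_demo, z_demo, P_demo = simulate_switching_var()
print(Y_demo.shape, z_demo.shape, P_demo.shape)  # (4000, 8) (4000,) (3, 3)
```

Longer `dwell` values make the chain switch less often, which gives any clustering method longer homogeneous segments to work with.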


In [1]:
import os, sys

# Make modules in the parent directory importable from this notebook.
sys.path.append(os.path.abspath(".."))

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

from continuous_energy_landscape import ContinuousEnergyLandscape
from simulations.slds_sim_simple import simulate_slds_example
from metrics_energy_landscape import compute_metrics_all
from discrete_energy import binarize, ener_calculate_pca_bin




## 1. Simulate SLDS data


In [2]:
K = 3          # number of discrete states
d = 8          # dimensionality
T = 4000       # number of time points
seed = 0

sim, X, g_true = simulate_slds_example(
    K=K,
    d=d,
    T=T,
    scenario="nonlinear_obs",  # nonlinear observation scenario
    dwell=100,                 # dwell-time parameter (longer dwell = fewer state switches)
    sigma=0.05,                # observation noise level
    seed=seed,
)

P_true = sim.P  # ground-truth transition matrix, passed to the metrics below
print("X shape:", X.shape)
print("g_true shape:", g_true.shape)


X shape: (4000, 8)
g_true shape: (4000,)


## 2. Discrete energy landscape (DEL) baseline

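We binarize each ROI (here, each of the d = 8 simulated channels) at its mean and fit an exact Ising (maximum entropy, MEM) model; with d = 8 there are only 2^8 = 256 binary patterns, so the model's moments can be computed by exact enumeration.
We then compute the Ising energy of the binary pattern at each time point and cluster this one-dimensional energy time series with k-means (K clusters)
to obtain discrete state labels.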
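The binarization and energy computation can be sketched as follows. This is an assumed formulation, not necessarily the one in `discrete_energy`: it maps each channel to ±1 spins (the module may use 0/1 instead) and uses the common Ising energy $E(\mathbf{s}) = -\mathbf{h}^\top\mathbf{s} - \tfrac{1}{2}\mathbf{s}^\top W \mathbf{s}$ with symmetric, zero-diagonal couplings $W$. The helper names and the random toy parameters are hypothetical.

```python
import itertools
import numpy as np

def binarize_at_mean(X_dT):
    """Map a (d, T) array to +/-1 spins, thresholding each channel (row) at its own mean."""
    return np.where(X_dT > X_dT.mean(axis=1, keepdims=True), 1, -1)

def ising_energy(S_dT, h, W):
    """E(s) = -h.s - 0.5 * s^T W s for every column s of a (d, N) spin array."""
    return -(h @ S_dT) - 0.5 * np.einsum("it,ij,jt->t", S_dT, W, S_dT)

def all_patterns(d):
    """All 2^d spin patterns as a (d, 2^d) array (256 columns for d = 8)."""
    return np.array(list(itertools.product([-1, 1], repeat=d))).T

# Toy usage with random (hypothetical) parameters.
rng = np.random.default_rng(0)
d_demo, T_demo = 8, 500
S_demo = binarize_at_mean(rng.standard_normal((d_demo, T_demo)))
h_demo = 0.1 * rng.standard_normal(d_demo)
W_demo = 0.1 * rng.standard_normal((d_demo, d_demo))
W_demo = (W_demo + W_demo.T) / 2       # symmetric couplings
np.fill_diagonal(W_demo, 0.0)          # no self-coupling
E_t = ising_energy(S_demo, h_demo, W_demo)          # one energy per time point, shape (T_demo,)
E_all = ising_energy(all_patterns(d_demo), h_demo, W_demo)
p_all = np.exp(-(E_all - E_all.min()))              # shift by the minimum for numerical stability
p_all /= p_all.sum()                                # exact Boltzmann distribution p(s) ∝ exp(-E(s))
```

Because the energy time series is one-dimensional, k-means on it effectively splits the energy axis into K intervals, so time points are grouped by energy level rather than by which binary patterns they visit.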


In [3]:
# X has shape (T, d); discrete_energy expects (d, T)
X_bin = binarize(X.T)
h_del, W_del, E_del = ener_calculate_pca_bin(X_bin)  # fields h, couplings W, energy per time point

print("Discrete energy shape:", E_del.shape)

# k-means on one-dimensional energy feature
E_feat = E_del.reshape(-1, 1)
km_del = KMeans(n_clusters=K, random_state=seed)
g_del = km_del.fit_predict(E_feat)
print("DEL labels shape:", g_del.shape)


[MEM] converged at it=83 with max|mom-res|=9.135e-06
Discrete energy shape: (4000,)
DEL labels shape: (4000,)
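The `[MEM] converged at it=... with max|mom-res|=...` log line suggests that the Ising parameters are fitted iteratively until a moment residual drops below a tolerance. A minimal sketch of such a moment-matching fit is gradient ascent on the exact log-likelihood, whose gradient is the difference between data and model moments. It assumes ±1 spins and exact enumeration over all 2^d patterns; the function name, learning rate and tolerance are hypothetical, and `ener_calculate_pca_bin` may differ in its details.

```python
import itertools
import numpy as np

def fit_ising_exact(S_dT, lr=0.2, tol=1e-5, max_iter=20000):
    """Fit h, W by exact moment matching (gradient ascent on the log-likelihood).

    S_dT: (d, T) array of +/-1 spins. Only feasible for small d (2^d patterns).
    Energy convention: E(s) = -h.s - 0.5 * s^T W s, p(s) ∝ exp(-E(s)).
    """
    d, T = S_dT.shape
    patterns = np.array(list(itertools.product([-1, 1], repeat=d)), dtype=float)  # (2^d, d)
    m_data = S_dT.mean(axis=1)              # empirical <s_i>
    C_data = (S_dT @ S_dT.T) / T            # empirical <s_i s_j>
    h = np.zeros(d)
    W = np.zeros((d, d))
    for it in range(max_iter):
        E = -(patterns @ h) - 0.5 * np.einsum("pi,ij,pj->p", patterns, W, patterns)
        p = np.exp(-(E - E.min()))          # shift for numerical stability
        p /= p.sum()
        m_model = p @ patterns                            # model <s_i>
        C_model = patterns.T @ (patterns * p[:, None])    # model <s_i s_j>
        res_h = m_data - m_model
        res_W = C_data - C_model
        np.fill_diagonal(res_W, 0.0)        # s_i^2 = 1, so the diagonal carries no information
        max_res = max(np.abs(res_h).max(), np.abs(res_W).max())
        if max_res < tol:
            break
        h += lr * res_h                     # ascent step on the fields
        W += lr * res_W                     # symmetric update keeps W symmetric
    return h, W, it, max_res

# Toy usage on independent random spins (hypothetical data).
rng = np.random.default_rng(1)
S_fit = np.where(rng.standard_normal((5, 2000)) > 0, 1, -1)
h_fit, W_fit, n_iter, max_res = fit_ising_exact(S_fit)
print(f"converged at it={n_iter} with max|mom-res|={max_res:.3e}")
```

Each iteration enumerates all 2^d patterns, so this approach is only practical for small d such as the d = 8 used in this notebook.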





## 3. Continuous energy landscape (CEL) model

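We fit the CEL model on the same continuous (non-binarized) data. For clustering we use two simple features per time point:
the CEL energy and the projection of x_t onto the top eigenvector (largest eigenvalue) of the learned matrix S (the fitted attribute `cel.S_`).
k-means with K clusters on these two features gives the CEL state labels.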
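The two CEL features may live on very different scales (an energy value versus a projection of X), and k-means with Euclidean distance is then dominated by the column with the larger variance. A hedged variant, which is not what the notebook's CEL cell does, standardizes both columns before clustering. The helper names and the synthetic stand-ins for `cel.S_` and the CEL energy are hypothetical; `np.linalg.eigh` assumes S is symmetric, as the notebook's own use of `eigh` does.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

def top_eigvec_projection(X, S):
    """Project rows of X (T, d) onto the eigenvector of symmetric S with the largest eigenvalue."""
    eigvals, eigvecs = np.linalg.eigh(S)     # eigh returns eigenvalues in ascending order
    v_top = eigvecs[:, np.argmax(eigvals)]
    return X @ v_top                         # shape (T,)

def cluster_energy_and_projection(E, proj, K, seed=0):
    """k-means on [energy, projection] after scaling each column to zero mean, unit variance."""
    F = np.column_stack([E, proj])
    F_std = StandardScaler().fit_transform(F)
    return KMeans(n_clusters=K, n_init=10, random_state=seed).fit_predict(F_std)

# Toy usage with synthetic stand-ins (hypothetical data).
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((300, 8))
A_demo = rng.standard_normal((8, 8))
S_demo = A_demo @ A_demo.T                         # symmetric stand-in for cel.S_
proj_demo = top_eigvec_projection(X_demo, S_demo)
E_demo = 100.0 * (X_demo ** 2).sum(axis=1)         # stand-in energy on a much larger scale
labels_demo = cluster_energy_and_projection(E_demo, proj_demo, K=3)
```

Whether the unscaled features in the notebook are actually dominated by the energy depends on the scales of `E_cel` and `z_cel`, which the notebook does not print.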


In [4]:
cel = ContinuousEnergyLandscape(
    hidden_channels=64,
    rank=32,
    delta=0.10,
    eps=1e-2,
    lambda_reg=1e-2,
    lr=1e-3,
    weight_decay=0.0,
    max_epochs=300,
    clip_grad=1.0,
    verbose=True,
    device="cpu",
    seed=seed,
)

out = cel.fit(X)
E_cel = cel.predict_energy(X)
print("CEL energy shape:", E_cel.shape)

# 1D embedding from top eigenvector of S
S = cel.S_
eigvals, eigvecs = np.linalg.eigh(S)
v_top = eigvecs[:, np.argmax(eigvals)]
z_cel = (X @ v_top).reshape(-1, 1)

F_cel = np.hstack([E_cel.reshape(-1, 1), z_cel])
km_cel = KMeans(n_clusters=K, random_state=seed)
g_cel = km_cel.fit_predict(F_cel)
print("CEL labels shape:", g_cel.shape)


[epoch 0000] loss=118397.945312
[epoch 0100] loss=18051.705078
[epoch 0200] loss=15833.098633
[epoch 0299] loss=15143.773438
CEL energy shape: (4000,)
CEL labels shape: (4000,)


## 4. Metrics: BR, TMA, SDA


In [5]:
metrics_del = compute_metrics_all(g_true, g_del, K=K, P_true=P_true, X=X)
metrics_cel = compute_metrics_all(g_true, g_cel, K=K, P_true=P_true, X=X)

print("DEL metrics:", metrics_del)
print("CEL metrics:", metrics_cel)

# Cross-check: BR_ARI above should equal sklearn's adjusted_rand_score
print("DEL ARI:", adjusted_rand_score(g_true, g_del))
print("CEL ARI:", adjusted_rand_score(g_true, g_cel))


DEL metrics: {'BR_ARI': 0.03559247765492828, 'TMA': 0.1762474950549533, 'SDA': 0.8581755744537454}
CEL metrics: {'BR_ARI': 0.13029980298750513, 'TMA': 0.3191520490704616, 'SDA': 0.9406671776489169}
DEL ARI: 0.03559247765492828
CEL ARI: 0.13029980298750513
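ARI is invariant to permuting label names, but comparisons involving transition matrices or per-state statistics generally require predicted labels to be matched to the true labels first. The sketch below shows one common way to do this (Hungarian matching on the contingency table via `scipy.optimize.linear_sum_assignment`) and how to estimate an empirical transition matrix from a label sequence. It is illustrative only; it is not necessarily how `compute_metrics_all` defines TMA or SDA, and the helper names and toy label arrays are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def align_labels(g_true, g_pred, K):
    """Relabel g_pred to best match g_true via a maximum-overlap one-to-one matching."""
    C = np.zeros((K, K), dtype=int)
    np.add.at(C, (g_true, g_pred), 1)        # contingency table: C[i, j] = #{t: true=i, pred=j}
    row, col = linear_sum_assignment(-C)     # negate to maximize total overlap
    mapping = np.empty(K, dtype=int)
    mapping[col] = row                       # predicted label col[k] -> true label row[k]
    return mapping[g_pred]

def empirical_transition_matrix(g, K):
    """Row-normalized counts of transitions g[t] -> g[t + 1]; rows with no visits stay zero."""
    counts = np.zeros((K, K))
    np.add.at(counts, (g[:-1], g[1:]), 1)
    rows = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, rows, out=np.zeros_like(counts), where=rows > 0)

# Toy usage: same partition with permuted label names.
g_true_demo = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
g_pred_demo = np.array([2, 2, 2, 0, 0, 0, 1, 1, 1])
print(align_labels(g_true_demo, g_pred_demo, K=3))   # [0 0 0 1 1 1 2 2 2]
print(empirical_transition_matrix(g_true_demo, K=3).round(2))
```

After alignment, an estimated transition matrix can be compared entry by entry with `P_true`; without alignment, the rows and columns would refer to arbitrarily named clusters.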
