<a href="https://colab.research.google.com/github/Cauch-BS/cscg-hippo/blob/main/notebooks/Modeling_the_Hippocampus_with_Causal_Graphs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set Up the Colab Environment

In [None]:
#@title Install dependencies, including original CSCG code
# igraph is used for the transition-graph layout, umap-learn for the (optional)
# UMAP refit, and the cscg package provides CHMM / forwardE used further below.
# NOTE(review): versions are unpinned, so a later re-run may pick up different
# releases; consider pinning them for reproducibility.

!pip install igraph umap-learn
!pip install git+https://github.com/NABI-SNU/cscg.git

#@markdown - *Modified from reference* : https://www.nature.com/articles/s41467-021-22559-5

In [None]:
# @title Mount Drive
# @markdown 1. Accept the requested permissions from Google Drive
# @markdown 2. Download the calcium imaging data and metadata (time, position)
# @markdown 3. Load the calcium imaging data with `numpy.load()`
# @markdown 4. Load the metadata with pandas.

# NOTE(review): nothing below reads from /content/drive; the data files are
# fetched with gdown into ./CSCG_example_NABI. The mount (and its permission
# prompt) therefore looks unnecessary — confirm before removing it.
from google.colab import drive
drive.mount('/content/drive')


# Local working directory. Later cells address it as
# /content/CSCG_example_NABI/first_run, i.e. they assume the cwd is /content.
!mkdir -p ./CSCG_example_NABI
!mkdir -p ./CSCG_example_NABI/first_run
# Raw spike matrix: only needed by the commented-out UMAP fit, so the download is disabled.
# !gdown -O ./CSCG_example_NABI/first_run/spk_data.npy 1PZmzB3e2hS5xTaUzLan9nZ5Glc6Etj5X
# Day/session label array (used later to split the embedding by session).
!gdown -O ./CSCG_example_NABI/first_run/day_data.npy 1lx9ChkwyxNDXQpGZ2ohLex13fM5aBwyT
# Archive of the per-session selected_pos_*.parquet position tables.
!gdown -O ./CSCG_example_NABI/first_run/selected_pos.tar.gz 1VH6rZkyteCh5XeMaeoGZojtNiDJLggbH


import numpy as np

# Absolute form of ./CSCG_example_NABI/first_run (Colab's default working directory is /content).
base = "/content/CSCG_example_NABI/first_run"

# Raw spikes are only needed to refit UMAP (see the commented-out code in the
# UMAP cell), so they are not loaded here:
# spk_data = np.load(f"{base}/spk_data.npy")
day_ind_array = np.load(base + "/day_data.npy")  # compared with day numbers to pick embedding rows per session

import tarfile
from pathlib import Path
import pandas as pd

# Extract the archive of per-session position tables.
import re

parquet_root = Path(base) / "selected_pos_parquet"
parquet_root.mkdir(parents=True, exist_ok=True)

# filter="data" (PEP 706) refuses absolute paths, links escaping the target
# directory, device files, etc.
with tarfile.open(f"{base}/selected_pos.tar.gz", "r:gz") as tar:
    tar.extractall(parquet_root, filter="data")

parquet_dir = parquet_root / "dataset_parquet"


def _session_sort_key(path):
    """Order selected_pos_<n>.parquet by the integer <n>; files without a numeric suffix go last, by name."""
    match = re.search(r"(\d+)$", path.stem)
    return (int(match.group(1)) if match else float("inf"), path.name)


# One DataFrame per session. Later cells index this list by day number
# (selected_pos_big[day_num]), so the files must be ordered numerically:
# a plain lexicographic sort would place selected_pos_10 before selected_pos_2.
selected_pos_big = [
    pd.read_parquet(p)
    for p in sorted(parquet_dir.glob("selected_pos_*.parquet"), key=_session_sort_key)
]
if not selected_pos_big:
    raise FileNotFoundError(f"No selected_pos_*.parquet files found in {parquet_dir}")

In [None]:
# @title Perform UMAP embedding for high-dimensional calcium imaging data
# @markdown Objective
# @markdown - Analyze neuronal spiking data obtained from calcium imaging
# @markdown - Reduce dimensionality using UMAP (3D embedding)
# @markdown - A sample UMAP is shown below.
# @markdown - Full UMAP is computationally expensive
# @markdown - Therefore, we download a precomputed embedding for demonstration

# import umap

# umap_data= umap.UMAP(
#    n_neighbors = 100,
#    n_components = 3,
#    min_dist = 0.1,
#    metric = 'correlation',
#    random_state = seed,
#    verbose = True,
#).fit(
#    X = spk_data.T
#)
# umap_embedding_example = umap_data.embedding_

# We download the actual UMAP data from google drive
!gdown -qO ./CSCG_example_NABI/first_run/embedding_42.npy 1DO3l6ZTFyoShWvI3T34GEyh1MXfCJ1PO
umap_embedding = np.load(f"{base}/embedding_42.npy")

# Display the Hippocampal Spiking Data in a 3D Plot

In [None]:
#@title Enable Widgets
# Colab only renders third-party widgets (the ipywidgets controls and the
# plotly FigureWidget used below) once its custom widget manager is enabled.

from google.colab import output
output.enable_custom_widget_manager()

In [None]:
#@title Load 3D UMAP
#@markdown - Across days, the mouse learned more "orthogonal" representations for each context
#@markdown - Tick "Show only selected trial" and move the trial slider to inspect a single trial; with the "Trial Type - Areas" colors, near trials are red and far trials are green

# =========================
# Hippocampal Neural Space Explorer (Colab-safe, Full Redraw)
# =========================

import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, clear_output
from tqdm import tqdm
import matplotlib
from matplotlib import colors

# -------------------------
# Marker definitions
# -------------------------
# One entry per track region: get_session_data paints every frame whose
# `position_marker` equals `name` with `color`.
# NOTE(review): the `position` field is not read by get_session_data.
markers = [
    dict(name=region, color=hex_color, position=slot)
    for region, hex_color, slot in (
        ('Track', '#808080', 0.9),
        ('Indicator-Near', '#FBB4B9', 0.85),
        ('R1-Near', '#F768A1', 0.8),
        ('R2-Near', '#C51B8A', 0.75),
        ('Indicator-Far', '#A8D8A7', 0.7),
        ('R1-Far', '#41AE76', 0.65),
        ('R2-Far', '#3a5a40', 0.6),
        ('Teleportation', '#FFFFFF', 0.55),
    )
]

# -------------------------
# Data loading per session (NaN-safe, no isna() filtering)
# -------------------------
import warnings


def _align_session_rows(day_num, embedding_day, pos_df):
    """Truncate embedding rows and position rows to a common length (paired 1:1 by order).

    A length mismatch drops the tail of the longer input. It is reported with a
    warning because it usually means the two inputs are not aligned.
    """
    n_emb, n_pos = embedding_day.shape[0], len(pos_df)
    n = min(n_emb, n_pos)
    if n_emb != n_pos:
        warnings.warn(
            f"Day {day_num}: {n_emb} embedding rows vs {n_pos} position rows; "
            f"truncating both to {n}."
        )
    return embedding_day[:n], pos_df.iloc[:n].reset_index(drop=True)


def _reward_gradient_colors(pos_df, value_col, vmax):
    """Hex colors for `value_col` scaled to [0, vmax]: Blues for reward_id 1, YlOrBr for reward_id 2.

    Rows with any other reward_id keep their 'area-color'. NaN values map to 0.
    """
    out = pos_df['area-color'].copy()
    norm = colors.Normalize(vmin=0, vmax=vmax)
    reward_ids = pos_df['reward_id'].to_numpy()
    for reward_id, cmap_name in ((1, 'Blues'), (2, 'YlOrBr')):
        ind = reward_ids == reward_id
        if ind.any():
            vals = np.nan_to_num(pos_df.loc[ind, value_col].to_numpy(dtype=float), nan=0.0)
            rgba = matplotlib.colormaps[cmap_name](norm(vals))
            out.loc[ind] = [colors.rgb2hex(c) for c in rgba]
    return out


def _nan_to_zero(values):
    """Float array with NaN replaced by 0 (keeps Plotly customdata JSON-serializable)."""
    return np.nan_to_num(np.asarray(values, dtype=float), nan=0.0)


def get_session_data(day_num):
    """Load, align and color-annotate one session (day) for the 3D explorer.

    Args:
        day_num: day label. Selects the rows of umap_embedding where
            day_ind_array == day_num, and the DataFrame selected_pos_big[day_num].

    Returns:
        dict with
          'embedding'     float64 UMAP coordinates for the session (NaN -> 0),
          'position'      aligned position DataFrame with added 'area-color',
                          'position-color', 'set-color' and 'trial-color' columns,
          'customdata'    (n, 5) hover array: marker, trial number, reward id,
                          position, cue set,
          'unique_trials' sorted list of int trial numbers ([0] if none),
          'trial_numbers' per-row int trial numbers (NaN -> 0).
    """
    day_mask = np.asarray(day_ind_array).reshape(-1) == day_num
    embedding_day = np.nan_to_num(umap_embedding[day_mask].astype(np.float64), nan=0.0)

    # NOTE(review): indexes the per-session list by the day label itself, which
    # assumes day labels are 0..N-1 in the same order as selected_pos_big.
    pos_df = selected_pos_big[day_num].copy()
    embedding_day, pos_df = _align_session_rows(day_num, embedding_day, pos_df)

    # Area colors: grey by default, overridden where position_marker matches a marker name.
    pos_df['area-color'] = '#808080'
    for m in markers:
        pos_df.loc[pos_df['position_marker'] == m['name'], 'area-color'] = m['color']

    # vmax=230 presumably spans the track length in position units — TODO confirm.
    pos_df['position-color'] = _reward_gradient_colors(pos_df, 'position', vmax=230)
    pos_df['set-color'] = np.where(pos_df['set'].to_numpy() == 'Cue Set A', '#000000', '#808080')
    # Trial numbers above 100 saturate the colormap.
    pos_df['trial-color'] = _reward_gradient_colors(pos_df, 'trial_number', vmax=100)

    trial_numbers = np.nan_to_num(pos_df['trial_number'].to_numpy(dtype=float), nan=0.0).astype(int)
    unique_trials = np.unique(trial_numbers).tolist() or [0]

    customdata = np.column_stack([
        pos_df['position_marker'].astype(str).to_numpy(),
        _nan_to_zero(pos_df['trial_number']),
        _nan_to_zero(pos_df['reward_id']),
        _nan_to_zero(pos_df['position']),
        pos_df['set'].astype(str).to_numpy(),
    ])

    return {
        'embedding': embedding_day,
        'position': pos_df,
        'customdata': customdata,
        'unique_trials': unique_trials,
        'trial_numbers': trial_numbers,
    }


# -------------------------
# Unique sessions
# -------------------------
# Every distinct day label in day_ind_array, as plain ints; [0] if there are none.
unique_days = list(map(int, np.unique(np.ravel(day_ind_array)).tolist())) or [0]

# -------------------------
# Precompute all sessions (once)
# -------------------------
# Widget callbacks only read from this cache, so switching days never recomputes colors.
print("Precomputing session data...\n")
SESSION_CACHE = {
    day: get_session_data(day)
    for day in tqdm(unique_days, desc="Preprocessing sessions")
}
print("\nDone precomputing sessions.")


# -------------------------
# Camera (plain dict; redraw() re-applies it to every new figure)
# -------------------------
initial_camera_position = {
    'up': {'x': 0, 'y': 0, 'z': 1},
    'center': {'x': 0, 'y': 0, 'z': 0},
    'eye': {'x': -1.5, 'y': 1.5, 'z': 1.5},
}

# -------------------------
# Widgets
# -------------------------
# Output area that redraw() clears and refills with a freshly built figure.
fig_output = widgets.Output()

# Day picker. continuous_update=False: fire only on release, since each change
# triggers a full redraw. (`indent` is a Checkbox-only trait, so it is not
# passed to the sliders/dropdown, where it would be ignored with a warning.)
session_selector = widgets.SelectionSlider(
    options=unique_days,
    value=unique_days[0],
    description='Day',
    continuous_update=False,
    layout=widgets.Layout(width='350px'),
)

trial_use_selected = widgets.Checkbox(
    value=False,
    description='Show only selected trial',
    indent=False,
    layout=widgets.Layout(width='200px'),
)

# Trial picker: placeholder options, replaced for each session by on_session_change.
trial_selector = widgets.SelectionSlider(
    options=[0],
    value=0,
    description='Trial Number',
    continuous_update=False,
    layout=widgets.Layout(width='350px'),
)

color_options = ['Trial Type - Areas', 'Trial Type - Position', 'Trial Type - Trial Number', 'Cue Sets']
color_scheme_selector = widgets.Dropdown(
    options=color_options,
    value=color_options[0],
    description='Color:',
    layout=widgets.Layout(width='250px'),
)


# -------------------------
# Redraw (full recreate, Colab-safe)
# -------------------------
# Color-scheme dropdown label -> precomputed color column of the session DataFrame.
_COLOR_COLUMNS = {
    'Trial Type - Areas': 'area-color',
    'Trial Type - Position': 'position-color',
    'Trial Type - Trial Number': 'trial-color',
    'Cue Sets': 'set-color',
}


def redraw():
    """Rebuild the 3D UMAP scatter from the current widget state and show it in fig_output.

    Reads the selected day, the optional single-trial filter and the color
    scheme from the widgets. The figure is recreated from scratch on every call,
    and the camera is reset to initial_camera_position.
    """
    session_data = SESSION_CACHE[int(session_selector.value)]
    emb = session_data['embedding']
    pos = session_data['position']
    cd = session_data['customdata']

    # Optional single-trial filter; an unknown trial yields an empty plot.
    if trial_use_selected.value:
        hit = session_data['trial_numbers'] == int(trial_selector.value)
        if hit.any():
            emb = emb[hit]
            pos = pos.loc[hit].reset_index(drop=True)
            cd = cd[hit]
        else:
            emb = np.zeros((0, 3))
            pos = pos.iloc[:0]
            cd = np.zeros((0, 5))

    # Unrecognized dropdown values fall back to the area colors.
    color_column = _COLOR_COLUMNS.get(color_scheme_selector.value, 'area-color')
    color_list = pos[color_column].astype(str).to_list()

    scatter = go.Scatter3d(
        x=emb[:, 0],
        y=emb[:, 1],
        z=emb[:, 2],
        mode='markers',
        marker=dict(size=1.6, color=color_list, opacity=0.8),
        customdata=cd,
        hovertemplate="<br>".join([
            "Trial Type: %{customdata[2]}",
            "Position: %{customdata[3]}",
            "Area: %{customdata[0]}",
            "Trial Number: %{customdata[1]}",
            "Set: %{customdata[4]}",
            "<extra></extra>",  # hide the secondary "trace 0" hover box, as in the graph view
        ]),
    )

    axis_white = dict(
        showbackground=False,
        showline=True,
        linecolor='white',
        zeroline=True,
        zerolinecolor='white',
        showgrid=True,
        gridcolor='rgba(255,255,255,0.3)',
        showticklabels=True,
        tickfont=dict(color='white'),
        title_font=dict(color='white'),
    )

    fig = go.Figure(data=[scatter])
    fig.update_layout(
        margin=dict(t=40),
        template='plotly_dark',
        paper_bgcolor='black',
        plot_bgcolor='black',
        font_color='white',
        scene=dict(
            xaxis_showspikes=False,
            yaxis_showspikes=False,
            zaxis_showspikes=False,
            bgcolor='black',
            camera=initial_camera_position,
            xaxis={**axis_white, 'title': 'UMAP 1'},
            yaxis={**axis_white, 'title': 'UMAP 2'},
            zaxis={**axis_white, 'title': 'UMAP 3'},
        ),
        width=750,
        height=750,
    )

    with fig_output:
        clear_output(wait=True)
        display(fig)


# -------------------------
# Session change handler
# -------------------------
# Status line shown above the controls while a session is being redrawn.
loading_label = widgets.HTML(value='', layout=widgets.Layout(margin='0 0 8px 0'))


def on_session_change(change=None):
    """Reset the trial controls for the newly selected day, then redraw.

    Args:
        change: traitlets change dict from session_selector.observe. When None
            (the initial call), the slider's current value is used.
    """
    # Prefer change['new'], the value that actually triggered the event.
    if change is not None and 'new' in change:
        day_num = int(change['new'])
    else:
        day_num = int(session_selector.value)

    loading_label.value = (
        '<span style="color:#888;font-size:14px;">⏳ Loading session '
        + str(day_num) + '...</span>'
    )
    try:
        unique_trials = SESSION_CACHE[day_num]['unique_trials']

        # Turn trial filtering off and point the trial slider at the new day's
        # trials. Each assignment may itself fire the trial observers (and so an
        # extra redraw) when the value actually changes.
        trial_use_selected.value = False
        trial_selector.options = unique_trials
        trial_selector.value = unique_trials[0]

        redraw()
    finally:
        # Clear the status even if redraw raised, so the UI never stays "loading".
        loading_label.value = ''


# -------------------------
# Camera buttons
# -------------------------
# NOTE(review): redraw() always re-applies initial_camera_position, so both
# buttons currently do the same thing: rebuild the figure with the default camera.
def update_view(_):
    """Button callback: rebuild the figure (the clicked-button argument is ignored)."""
    return redraw()


def reset_view(_):
    """Button callback: rebuild the figure with the default camera."""
    return redraw()


update_view_button, reset_view_button = (
    widgets.Button(description='Update View'),
    widgets.Button(description='Reset View'),
)
update_view_button.on_click(update_view)
reset_view_button.on_click(reset_view)

# -------------------------
# Wire up observers
# -------------------------
def on_trial_or_color_change(change=None):
    """Any change to the trial filter, the trial number or the color scheme triggers a full redraw."""
    redraw()


for _control in (trial_use_selected, trial_selector, color_scheme_selector):
    _control.observe(on_trial_or_color_change, names='value')
del _control
session_selector.observe(on_session_change, names='value')

# Populate the trial slider for the first day and draw the initial figure.
on_session_change()

# -------------------------
# Layout and display
# -------------------------
session_container = widgets.VBox([
    widgets.HTML('<b>Day</b>'),
    session_selector,
])

trial_container = widgets.VBox([
    widgets.HTML('<b>Trial</b>'),
    trial_use_selected,
    trial_selector,
])

view_params_container = widgets.VBox([
    widgets.HTML('<b>Camera</b>'),
    update_view_button,
    reset_view_button,
])

control_panel = widgets.VBox([
    widgets.HTML('<h2>🚀 Hippocampal Neural Space Explorer</h2>'),
    widgets.HTML('<hr>'),
    loading_label,  # status line updated by on_session_change
    color_scheme_selector,
    session_container,
    trial_container,
    view_params_container,
])

# Controls on the left, the figure output on the right.
ui = widgets.HBox([control_panel, fig_output])
display(ui)


# Simulating the Hippocampus with Causal Graphs

In [None]:
# @title Train a Clone-HMM on the 2ACDC task
# @markdown ## Goal
# @markdown - See how internal states transition over learning.
# @markdown - Intuition:
# @markdown   - Each hidden state ≈ latent "context" / "memory state"
# @markdown   - Edges ≈ non-zero probability of switching between contexts

import sys
import numpy as np
from cscg import CHMM, forwardE, backwardE, updateCE
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
import plotly.graph_objects as go

import igraph
import os
import scipy
import random
import string
import copy
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

# -------------------------
# Forward messages helper
# -------------------------
def get_mess_fwd(chmm, x, pseudocount=0.0, pseudocount_E=0.0):
    """Run forwardE on `chmm` for the observation sequence `x` and return the stored forward messages.

    Args:
        chmm: CHMM whose `C`, `Pi_x` and `n_clones` are read (not modified).
        x: integer observation sequence.
        pseudocount: added to the transition counts before row-normalizing.
        pseudocount_E: added to the clone/observation indicator matrix before row-normalizing.

    Returns:
        The `mess_fwd` output of forwardE (the log-likelihood is discarded).
    """
    def _row_normalize(mat, axis):
        # Divide by the sum along `axis`, leaving all-zero slices at zero.
        totals = mat.sum(axis, keepdims=True)
        totals[totals == 0] = 1
        return mat / totals

    n_clones = chmm.n_clones
    n_states = n_clones.sum()

    # Clone -> observation indicator: clones are laid out contiguously,
    # n_clones[k] of them per observation k.
    clone_obs = np.repeat(np.arange(len(n_clones)), n_clones)
    E = np.zeros((n_states, len(n_clones)))
    E[np.arange(n_states), clone_obs] = 1
    E += pseudocount_E
    E = _row_normalize(E, 1)

    # Row-normalized transitions, averaged over the action axis into one slice.
    T = _row_normalize(chmm.C + pseudocount, 2).mean(0, keepdims=True)

    _, mess_fwd = forwardE(
        T.transpose(0, 2, 1),
        E,
        chmm.Pi_x,
        chmm.n_clones,
        x,
        x * 0,  # one (averaged) action slice, so every action index is 0
        store_messages=True,
    )
    return mess_fwd



# Transition-graph node colors per observation letter A..H (same order as the
# letter table below). Near and Far palettes differ only in D/E/F.
COLOR_ORDER_TRIAL1 = [
    "#808080",  # A - Track / Grey corridor
    "#FBB4B9",  # B - Indicator-Near
    "#A8D8A7",  # C - Indicator-Far
    "#F768A1",  # D - R1
    "#C51B8A",  # E - R2
    "#F768A1",  # F - Reward (Near trials are rewarded at R1)
    "#808080",  # G - Pre-teleport
    "#404040",  # H - Teleportation
]

COLOR_ORDER_TRIAL2 = [
    "#808080",  # A - Track / Grey corridor
    "#FBB4B9",  # B - Indicator-Near
    "#A8D8A7",  # C - Indicator-Far
    "#41AE76",  # D - R1
    "#3A5A40",  # E - R2
    "#3A5A40",  # F - Reward (Far trials are rewarded at R2)
    "#808080",  # G - Pre-teleport
    "#404040",  # H - Teleportation
]

# (letter, "Near"/"Far") -> hover label for graph nodes. Unlisted pairs fall
# back to "<letter>-<trial>" in transition_graph_plotter. The corridor (A) and
# pre-teleport (G) segments occur in both trial types, so both variants are listed.
SEMANTIC_MAP = {
    ("A","Near"): "Track",
    ("A","Far"): "Track",
    ("B","Near"): "Indicator-Near",
    ("D","Near"): "R1-Near",
    ("E","Near"): "R2-Near",
    ("C","Far"): "Indicator-Far",
    ("D","Far"): "R1-Far",
    ("E","Far"): "R2-Far",
    ("G","Near"): "Pre-teleport",
    ("G","Far"): "Pre-teleport",
    ("H","Near"): "Teleport",
    ("H","Far"): "Teleport",
    ("F","Near"): "Reward-Near",
    ("F","Far"):  "Reward-Far",
}


def transition_graph_plotter(iter_val):
    """Build the Plotly transition-graph figure for the CHMM cached at `iter_val`.

    A private copy of CHMM_CACHE[iter_val] is refined with Viterbi training on
    x_curr (one Near trial followed by one Far trial), decoded, and the visited
    clones are drawn as a directed graph. Each node is labelled Near or Far by
    the trial it was visited in more often (ties count as Near).

    Args:
        iter_val: a key of CHMM_CACHE (one of ITER_OPTIONS).

    Returns:
        plotly.graph_objects.Figure with an edge trace and a node trace.
    """
    # Work on a copy. learn_viterbi_T retrains the object it is called on
    # (chmm.C is read below as the re-learned counts), so running it on the
    # cached model would alter CHMM_CACHE[iter_val] on every call and make the
    # plot depend on how many times the slider has visited this iteration.
    chmm = copy.deepcopy(CHMM_CACHE[iter_val])
    chmm.pseudocount = 0.0
    chmm.learn_viterbi_T(x_curr, a_curr, n_iter=100)
    states = chmm.decode(x_curr, a_curr)[1]

    n_near = len(trial1x)
    near_states, far_states = states[:n_near], states[n_near:]
    visited = np.unique(states)
    visited_idx = np.asarray(visited, dtype=np.intp)

    # Observation index of every clone. Clones are laid out contiguously,
    # n_clones[k] per observation k. This holds for any clone counts, unlike
    # floor(state / 100), which assumed exactly 100 clones per observation.
    clone_obs = np.repeat(np.arange(len(chmm.n_clones)), chmm.n_clones)
    visited_obs = clone_obs[visited_idx]

    num_to_letter = {int(num): letter for letter, num in letter_num_dict.items()}
    letters = string.ascii_uppercase[:len(COLOR_ORDER_TRIAL1)]
    palettes = {
        "Near": {int(letter_num_dict[l]): c for l, c in zip(letters, COLOR_ORDER_TRIAL1)},
        "Far": {int(letter_num_dict[l]): c for l, c in zip(letters, COLOR_ORDER_TRIAL2)},
    }

    node_letters, node_trials, node_colors = [], [], []
    for state, obs in zip(visited, visited_obs):
        trial = "Near" if np.sum(near_states == state) >= np.sum(far_states == state) else "Far"
        node_letters.append(num_to_letter.get(int(obs), "?"))
        node_trials.append(trial)
        node_colors.append(palettes[trial].get(int(obs), "#808080"))

    semantic_labels = [
        SEMANTIC_MAP.get((letter, trial), f"{letter}-{trial}")
        for letter, trial in zip(node_letters, node_trials)
    ]

    # Directed edge wherever the re-learned counts (summed over the action axis)
    # are non-zero. Row-normalizing first would not change which entries are > 0.
    adjacency = chmm.C[:, visited_idx][:, :, visited_idx].sum(0) > 0
    g = igraph.Graph.Adjacency(adjacency.tolist())
    coords = np.array(g.layout_fruchterman_reingold().coords)
    n_nodes = len(visited)

    # Line segments separated by None so a single trace draws every edge.
    edge_x, edge_y = [], []
    for src, dst in g.get_edgelist():
        edge_x += [coords[src, 0], coords[dst, 0], None]
        edge_y += [coords[src, 1], coords[dst, 1], None]

    if n_nodes > 0:
        node_x = coords[:, 0].tolist()
        node_y = coords[:, 1].tolist()
        x_min, x_max = coords[:, 0].min(), coords[:, 0].max()
        y_min, y_max = coords[:, 1].min(), coords[:, 1].max()
        # Explicit padded range (10% of the wider span, at least 0.5) so the plot never looks empty.
        pad = max((x_max - x_min) * 0.1, (y_max - y_min) * 0.1, 0.5)
        x_range = [x_min - pad, x_max + pad]
        y_range = [y_min - pad, y_max + pad]
        customdata = np.stack([node_letters, node_trials, semantic_labels], axis=1)
    else:
        node_x, node_y = [], []
        x_range, y_range = [-1, 1], [-1, 1]
        customdata = []

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(width=1, color="rgba(255,255,255,0.25)"),
        hoverinfo="none",
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers",
        marker=dict(
            size=10,
            color=node_colors,
            line=dict(width=1, color="black"),
        ),
        customdata=customdata,
        hovertemplate="<br>".join([
            "<b>%{customdata[2]}</b>",
            "Observation: %{customdata[0]}",
            "Trial: %{customdata[1]}",
            "<extra></extra>",
        ]),
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor="black",
        plot_bgcolor="black",
        xaxis=dict(visible=False, range=x_range),
        yaxis=dict(visible=False, range=y_range),
        title=f"Transition graph @ iter {iter_val}",
        width=750,
        height=650,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig


# -------------------------
# Main
# -------------------------
np.random.seed(0)  # fixes the random Near/Far trial order drawn below

#  | Letter | Meaning in Task    | Biological Interpretation |
#  | ------ | ------------------ | ------------------------- |
#  | A      | Grey corridor      | Sensory-ambiguous region  |
#  | B      | Near indicator     | Cue → Reward at R1        |
#  | C      | Far indicator      | Cue → Reward at R2        |
#  | D      | R1 zone            | Signal for R1             |
#  | E      | R2 zone            | Signal for R2             |
#  | F      | Reward             |                           |
#  | G      | Pre-teleport       | Before reset              |
#  | H      | Teleport / end     | Reset region              |

# Time steps each letter is held for (1 = one step per letter; increase to
# stretch every trial in time).
STEPS_PER_LETTER = 1

# Near trial: the reward (F) follows the R1 zone (D).
trial1x_let = np.repeat(
    np.array(['A','A','A','A','A','A','B','B','B','B',
              'A','A','A','D','F','A','A','A','E','E',
              'A','A','G','H','H','H']),
    STEPS_PER_LETTER,
)
# Far trial: the reward (F) follows the R2 zone (E).
trial2x_let = np.repeat(
    np.array(['A','A','A','A','A','A','C','C','C','C',
              'A','A','A','D','D','A','A','A','E','F',
              'A','A','G','H','H','H']),
    STEPS_PER_LETTER,
)

# num_trials - 2 random trial types, then one Near (0) and one Far (1) trial,
# so both types are always present in the training stream.
num_trials = 5
trials = np.random.choice(2, num_trials - 2)
trials = np.concatenate((trials, np.array([0, 1])))

ln = np.linspace(0, 7, 8)
letter_num_dict = {
    "A": ln[0], "B": ln[1], "C": ln[2], "D": ln[3],
    "E": ln[4], "F": ln[5], "G": ln[6], "H": ln[7],
}

# Integer-coded observation sequences for one Near and one Far trial.
trial1x = np.array([int(letter_num_dict[letter]) for letter in trial1x_let], dtype=np.int64)
trial2x = np.array([int(letter_num_dict[letter]) for letter in trial2x_let], dtype=np.int64)

# Concatenated observation stream: trial t is Near (trial1x) when trials[t] == 0, else Far.
tr_len = len(trial1x)
x = np.concatenate([trial1x if t == 0 else trial2x for t in trials]).astype(np.int64)

a = np.zeros(len(x), dtype=np.int64)  # a single action (0) at every step
OBS = len(np.unique(x))

# 100 clones per observation. Five observation slots beyond the OBS symbols
# that actually occur are allocated but never observed.
n_clones = np.ones(OBS + 5, dtype=np.int64) * 100
chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a, seed=0)

# -------------------------
# Train in rounds and snapshot the model after selected rounds
# -------------------------
# CHMM_CACHE[k] holds the model after round k, i.e. after
# (k + 1) * EM_ITERS_PER_ROUND EM iterations, so "iteration 0" is already
# trained for one round.
ITER_OPTIONS = [0, 5, 10, 15, 20, 25]
EM_ITERS_PER_ROUND = 10
CHMM_CACHE = {}

print(f"Training and caching CHMM at iterations {', '.join(map(str, ITER_OPTIONS))}...")
for tot_iter in tqdm(range(max(ITER_OPTIONS) + 1), desc="CHMM learning"):
    chmm.learn_em_T(x, a, n_iter=EM_ITERS_PER_ROUND, term_early=False)

    if tot_iter in ITER_OPTIONS:
        # Training continues on `chmm`, so each snapshot needs its own copy.
        CHMM_CACHE[tot_iter] = copy.deepcopy(chmm)

print("\n Done.")

# Evaluation sequence for the graph view: one Near trial followed by one Far trial.
x_curr = np.concatenate((trial1x, trial2x)).astype(np.int64)
a_curr = np.zeros(len(x_curr), dtype=np.int64)

In [None]:
#@title Display the HMM using plotly

import plotly.io as pio

# Ensure Plotly uses the notebook/colab renderer
pio.renderers.default = "colab"

# -------------------------
# Clone-HMM Transition Graph (interactive)
# -------------------------
# The names in this cell are deliberately distinct (graph_*, redraw_transition_graph).
# The UMAP explorer cell defines module-level `redraw`, `loading_label`,
# `control_panel` and `ui`, and its widget callbacks look them up at call time.
# Redefining them here would make the UMAP controls redraw this graph instead.

# Persistent FigureWidget. The initial redraw below fills it, so the expensive
# figure is built once at startup rather than twice.
fig_widget = go.FigureWidget()
fig_widget.update_layout(width=750, height=650)

graph_loading_label = widgets.HTML(
    value='',
    layout=widgets.Layout(margin='0 0 8px 0'),
)

iteration_selector = widgets.SelectionSlider(
    options=ITER_OPTIONS,
    value=ITER_OPTIONS[0],
    description='Iteration',
    continuous_update=False,
    layout=widgets.Layout(width='350px'),
)


def redraw_transition_graph(change=None):
    """Rebuild the transition graph for the selected iteration inside fig_widget.

    Args:
        change: traitlets change dict from iteration_selector.observe (unused;
            the slider's current value is read directly).
    """
    iter_val = int(iteration_selector.value)
    graph_loading_label.value = (
        '<span style="color:#888;font-size:14px;">'
        f'⏳ Loading iteration {iter_val}...</span>'
    )
    try:
        new_fig = transition_graph_plotter(iter_val)
        new_fig.update_layout(showlegend=False)
        new_fig.update_traces(showlegend=False)

        # Swap the widget's traces and layout for the freshly built ones.
        fig_widget.data = ()
        fig_widget.add_traces(list(new_fig.data))
        fig_widget.layout = new_fig.layout
    finally:
        # Clear the status even if building the figure raised.
        graph_loading_label.value = ''


iteration_selector.observe(redraw_transition_graph, names="value")

graph_control_panel = widgets.VBox([
    widgets.HTML('<h2>Clone-HMM Transition Graph</h2>'),
    widgets.HTML('<hr>'),
    graph_loading_label,
    widgets.HTML('<b>Iteration</b>'),
    iteration_selector,
])

# Controls on the left, the graph on the right.
graph_ui = widgets.HBox([graph_control_panel, fig_widget])
display(graph_ui)

# Initial draw
redraw_transition_graph()
