In [1]:
from pathlib import Path

import numpy as np
import xarray as xr
import pandas as pd
from scipy.spatial import distance
from scipy.stats import linregress
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display

def _build_shadow_manifold(X, E, tau):
    """Time-delay embed a 1D series.

    Returns ``(M, t_steps)`` where row ``i`` of ``M`` is
    ``[X[t], X[t - tau], ..., X[t - (E-1)*tau]]`` for ``t = t_steps[i]``,
    and ``t_steps`` runs from ``(E-1)*tau`` to ``len(X) - 1``.

    Raises ValueError if ``X`` has no more than ``(E-1)*tau`` points.
    """
    X = np.asarray(X)
    T, t0 = len(X), (E - 1) * tau
    if T <= t0:
        raise ValueError(
            f"series of length {T} is too short to embed with E={E}, tau={tau}"
        )
    t_steps = np.arange(t0, T)
    # Column j is X lagged by j*tau, sliced so that row i lines up with time t_steps[i].
    M = np.column_stack([X[t0 - j * tau:T - j * tau] for j in range(E)]).astype(float)
    return M, t_steps


def _simplex_prediction(nbr_dists, nbr_values):
    """Weighted mean of ``nbr_values`` with weights exp(-d / d_min), normalised to sum to 1.

    ``d_min`` is floored at 1e-6 so a zero-distance neighbour does not divide by zero.
    """
    nbr_dists = np.asarray(nbr_dists, dtype=float)
    scale = max(1e-6, nbr_dists.min())
    w = np.exp(-nbr_dists / scale)
    w /= w.sum()
    return float(np.dot(w, nbr_values))


def interactive_ccm(ages, series1, series2, E=2, tau=1):
    """Step-through visualisation of convergent cross mapping (CCM).

    The shadow manifold of ``series1`` is used to cross-map (predict)
    ``series2``, one library point per button click. Each click:

    * takes the next point on the series1 manifold,
    * finds its E+1 nearest neighbours on that manifold (excluding itself),
    * predicts series2 at that time as the exponentially weighted mean of
      series2 at the neighbours' times,
    * highlights the target (red) and neighbours on the four left panels and
      appends the prediction to the two right panels.

    Parameters
    ----------
    ages : array-like, 1D
        Time points; same length as ``series1`` and ``series2``.
    series1 : array-like, 1D
        Series whose shadow manifold drives the prediction (e.g. sat_diff).
    series2 : array-like, 1D
        Series being cross-mapped (e.g. pre).
    E : int, default 2
        Embedding dimension; must be >= 2 because the phase panels plot the
        first two embedding coordinates.
    tau : int, default 1
        Embedding delay in samples; must be >= 1.

    Returns
    -------
    plotly.graph_objects.FigureWidget
        The live figure. The "next CCM step" button and its output area are
        displayed as a side effect.

    Raises
    ------
    ValueError
        If the inputs differ in length, ``E < 2``, ``tau < 1``, or the series
        are too short to yield at least two embedded points.
    """
    ages = np.asarray(ages)
    series1 = np.asarray(series1)
    series2 = np.asarray(series2)
    if not (len(ages) == len(series1) == len(series2)):
        raise ValueError(
            "ages, series1 and series2 must have equal length; got "
            f"{len(ages)}, {len(series1)}, {len(series2)}"
        )
    if E < 2:
        raise ValueError(f"E must be >= 2 (phase panels plot two embedding axes); got {E}")
    if tau < 1:
        raise ValueError(f"tau must be >= 1; got {tau}")

    # 1) Shadow manifolds; neighbours are always searched on the series1 manifold.
    M1, t_steps = _build_shadow_manifold(series1, E, tau)
    M2, _ = _build_shadow_manifold(series2, E, tau)
    if len(t_steps) < 2:
        raise ValueError("need at least two embedded points to find neighbours")
    dists = distance.cdist(M1, M1)

    def get_neighbors(idx):
        """Manifold rows of the E+1 points nearest to row ``idx``, excluding ``idx``."""
        # Drop idx explicitly: with duplicate points, argsort need not put idx first.
        order = np.argsort(dists[idx])
        return order[order != idx][:E + 1]

    # 2) 4x2 grid; the right column merges rows 1-2 and rows 3-4.
    specs = [
        [{"rowspan": 1}, {"rowspan": 2}],
        [{"rowspan": 1}, None],
        [{"rowspan": 1}, {"rowspan": 2}],
        [{"rowspan": 1}, None],
    ]
    # make_subplots assigns titles only to non-None cells, in row-major order,
    # so the None cells must not get placeholder entries.
    titles = [
        "1) series1 TS", "5) series2 & CCM",
        "2) series1 phase",
        "3) series2 TS", "6) series2 vs pred",
        "4) series2 phase",
    ]
    fig = go.FigureWidget(make_subplots(
        rows=4, cols=2,
        column_widths=[0.4, 0.6],
        specs=specs,
        subplot_titles=titles,
    ))

    def add_trace_at(trace, row, col):
        """Add ``trace`` to cell (row, col) and return its index in ``fig.data``."""
        fig.add_trace(trace, row=row, col=col)
        return len(fig.data) - 1

    small_sz = 6

    # 3) Gray base traces on the left. The last flag says whether the panel's
    #    points are indexed by time step (time series) or by manifold row (phase).
    left_panels = [
        (1, ages, series1, 'lines+markers', True),
        (2, M1[:, 0], M1[:, 1], 'markers', False),
        (3, ages, series2, 'lines+markers', True),
        (4, M2[:, 0], M2[:, 1], 'markers', False),
    ]
    for row, x, y, mode, _ in left_panels:
        add_trace_at(go.Scatter(
            x=x, y=y, mode=mode,
            marker=dict(color='lightgray', size=small_sz),
            showlegend=False,
        ), row, 1)

    # 4) Top-right merged cell: raw series2 plus the growing CCM prediction.
    add_trace_at(go.Scatter(
        x=ages, y=series2, mode='lines+markers',
        marker=dict(color='lightgray', size=small_sz),
        name='series2 raw',
    ), 1, 2)
    pred_ts_idx = add_trace_at(go.Scatter(
        x=[], y=[], mode='markers',
        marker=dict(color='black', size=8),
        name='CCM pred',
    ), 1, 2)

    # 5) Bottom-right merged cell: observed series2 vs prediction, plus fit line.
    scatter_idx = add_trace_at(go.Scatter(
        x=[], y=[], mode='markers',
        marker=dict(size=8),
        name='series2 vs pred',
    ), 3, 2)
    fit_idx = add_trace_at(go.Scatter(
        x=[], y=[], mode='lines', line=dict(dash='dash'),
        name='fit (ρ=–)',
    ), 3, 2)

    # 6) Empty highlight traces (neighbours, then target) on each left panel;
    #    added last so they draw on top of the gray base traces.
    highlights = []  # (xbase, ybase, time_indexed, nbr_trace_idx, tgt_trace_idx)
    for row, x, y, _, time_indexed in left_panels:
        nbr_idx = add_trace_at(go.Scatter(
            x=[], y=[], mode='markers',
            marker=dict(size=10), showlegend=False,
        ), row, 1)
        tgt_idx = add_trace_at(go.Scatter(
            x=[], y=[], mode='markers',
            marker=dict(color='red', size=14), showlegend=False,
        ), row, 1)
        highlights.append((x, y, time_indexed, nbr_idx, tgt_idx))

    fig.update_layout(
        height=1100, width=800,
        margin=dict(l=40, r=40, t=60, b=40),
        title="Interactive CCM",
    )
    # Share the time axis between the two top-row panels ("x" is the row-1/col-1 axis).
    fig.update_xaxes(matches="x", row=1, col=2)

    # 7) Button callback: one cross-map prediction per click.
    palette = ['blue', 'green', 'orange', 'purple', 'brown', 'pink']
    current = {'i': 0}
    pred_t, pred_y = [], []

    btn = widgets.Button(description="▶️ next CCM step")
    out = widgets.Output()

    def step(_):
        i = current['i']
        if i >= len(t_steps):
            btn.disabled = True
            with out:
                print("done")
            return

        t = t_steps[i]
        nbrs = get_neighbors(i)        # manifold rows
        nbr_times = t_steps[nbrs]      # corresponding time indices
        yhat = _simplex_prediction(dists[i, nbrs], series2[nbr_times])
        pred_t.append(t)
        pred_y.append(yhat)

        observed = series2[pred_t]
        colors = [palette[k % len(palette)] for k in range(len(nbrs))]

        # Send all trace edits to the front end as a single update.
        with fig.batch_update():
            fig.data[pred_ts_idx].update(x=ages[pred_t], y=pred_y)
            fig.data[scatter_idx].update(x=observed, y=pred_y)
            # linregress needs variation in x; skip the fit until there is some.
            if len(pred_t) > 3 and np.ptp(observed) > 0:
                lr = linregress(observed, pred_y)
                x0, x1 = observed.min(), observed.max()
                fig.data[fit_idx].update(
                    x=[x0, x1],
                    y=[lr.intercept + lr.slope * x0, lr.intercept + lr.slope * x1],
                    name=f"fit (ρ={lr.rvalue:.2f})",
                )
            for xbase, ybase, time_indexed, nbr_idx, tgt_idx in highlights:
                # Time-series panels are indexed by time step, phase panels by manifold row.
                pts, tgt = (nbr_times, t) if time_indexed else (nbrs, i)
                fig.data[nbr_idx].update(x=xbase[pts], y=ybase[pts], marker_color=colors)
                fig.data[tgt_idx].update(x=[xbase[tgt]], y=[ybase[tgt]])

        current['i'] += 1

    btn.on_click(step)
    display(btn, out)
    return fig




In [3]:
# read the csv file
df_sq = pd.read_csv(r"D:\VScode\bipolar_seesaw_CCM\test\df_sq.csv")
df_pre = pd.read_csv(r"D:\VScode\bipolar_seesaw_CCM\test\df_pre.csv")

ages = df_pre['age'].values
ts   = df_sq['sq'].values
pre  = df_pre['pre'].values

# call the function
fig = interactive_ccm(ages, ts, pre, E=2, tau=5)
fig

Button(description='▶️ next CCM step', style=ButtonStyle())

Output()

FigureWidget({
    'data': [{'marker': {'color': 'lightgray', 'size': 6},
              'mode': 'lines+markers',
              'showlegend': False,
              'type': 'scatter',
              'uid': 'aa253b2d-5c2e-43b5-9d09-31475c019a78',
              'x': array([     0,   1000,   2000, ..., 638000, 639000, 640000], dtype=int64),
              'xaxis': 'x',
              'y': array([-0.79822398, -0.91942033, -1.07206329, ..., -0.9720985 , -0.88808487,
                           0.99305073]),
              'yaxis': 'y'},
             {'marker': {'color': 'lightgray', 'size': 6},
              'mode': 'markers',
              'showlegend': False,
              'type': 'scatter',
              'uid': 'e66b7d05-d8a4-4186-855d-ddca126b1d5c',
              'x': array([-1.01218643,  0.99950211,  0.90586195, ..., -0.9720985 , -0.88808487,
                           0.99305073]),
              'xaxis': 'x3',
              'y': array([-0.79822398, -0.91942033, -1.07206329, ..., -0.94693906, 