# In [None]:

#!/usr/bin/env python3

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from numpy.linalg import inv, LinAlgError
from matplotlib.patches import Patch

# ======================
# ---- CONFIG ----------
# ======================
N_SAMPLES        = 3000      # time steps kept per realization (after burn-in is discarded)
BURN_IN          = 200       # initial simulated steps dropped from every realization
NUM_REALIZATIONS = 80        # independent simulations per gamma; increase for smoother bands
K_DIFF           = 1         # finite-difference step (in samples) for the derivative estimate
DT               = 1.0       # sampling interval used by the finite-difference derivative
RNG_BASE         = 54000     # offset added to every per-realization RNG seed

# Bootstrap ribbons
BOOTSTRAP_B      = 6000         # bootstrap resamples of the realization mean
RIBBON_QS        = (2.5, 97.5)  # lower/upper percentiles of the ribbon (~95%)

# ---- VAR dynamics (tunable) ----
# Entries of the VAR(1) transition matrix assembled in build_A().
PHI  = 0.25        # AR(1) diagonal persistence (same for every node)
C21  = 0.05        # x2 -> x1 (weak back edge to avoid cycles dominating)
C31  = 0.55        # x3 -> x1  (mediated path leg, strong)
C12  = 0.45        # x1 -> x2  (mediated path leg, strong)
C32  = 0.12        # x3 -> x2  (direct path, small)

# Hidden-layer design
# x4 and x5 are simulated but excluded from the analysed columns.
GAMMA_VALS   = np.linspace(0.0, 0.8, 9)  # swept coupling x4→{x1,x2}
ETA_X5_TO_X4 = 0.70                      # coupling x5→x4

# Innovation scales (std dev per node): [x1, x2, x3, x4, x5]
B_DIAG = np.array([0.6, 0.5, 0.5, 0.6, 0.6], dtype=float)

# Small positive constant used to regularise covariances and guard divisions.
EPS = 1e-12

# ======================
# ---- PLOT STYLE ------
# ======================
# Global matplotlib font sizes applied to every figure created after this point.
mpl.rcParams.update({
    "font.size": 16,         # base font size
    "axes.titlesize": 18,    # panel titles
    "axes.labelsize": 16,    # axis labels
    "xtick.labelsize": 14,   # x tick labels
    "ytick.labelsize": 14,   # y tick labels
    "legend.fontsize": 15,   # legend entries
})

# Requested palette
BLUE_HEX   = "#256ea2"  # used for the x1→x2 series
ORANGE_HEX = "#ea841e"  # used for the x1→x3 series

# ======================
# ---- LKIF helpers ----
# ======================
def _center(X):
    """Return X with the mean along axis 0 subtracted from every row."""
    column_means = X.mean(axis=0, keepdims=True)
    return X - column_means

def cov_mat(X):
    """
    Sample covariance matrix of the columns of X (rows are observations).
    Normalised by (n_rows - 1), floored at 1 so a single row does not divide by zero.
    """
    centered = _center(X)
    dof = max(1, centered.shape[0] - 1)
    return (centered.T @ centered) / dof

def forward_diff(x, k=1, dt=1.0):
    """
    Forward finite difference of x over k samples: (x[n+k] - x[n]) / (k * dt).
    The result is k entries shorter than x along the first axis.
    """
    ahead = x[k:]
    behind = x[:-k]
    step = k * dt
    return (ahead - behind) / step

def _var(x):
    """Sample variance of a 1-D series, normalised by max(1, n - 1); returns a Python float."""
    deviations = x - np.mean(x)
    dof = max(1, len(deviations) - 1)
    sum_sq = np.dot(deviations, deviations)
    return float(sum_sq / dof)

def stable_inv(C, eps=EPS):
    """
    Invert the square matrix C.
    If the plain inverse raises LinAlgError, retry once with eps added to the diagonal.
    """
    try:
        result = inv(C)
    except LinAlgError:
        ridge = eps * np.eye(C.shape[0])
        result = inv(C + ridge)
    return result

def _cov_blocks_for_target(X, i, k=1, dt=1.0):
    """
    Covariance ingredients for the Liang row of target column i:
        xdot_i ≈ f_i + Σ_j a_{ij} x_j + noise

    Returns (C, Cxdot, vardot):
      C      — covariance matrix of the lagged states X[:-k]
      Cxdot  — covariance of each lagged column with the forward difference of column i
      vardot — sample variance of that forward difference
    Raises ValueError if the lagged states and the derivative differ in length.
    """
    lagged = X[:-k, :]
    rate = forward_diff(X[:, i], k=k, dt=dt)

    n_rows = lagged.shape[0]
    if n_rows != rate.shape[0]:
        raise ValueError("Length mismatch between Xlag and xdot.")

    # Same (n - 1) normalisation, floored at 1, as cov_mat uses.
    dof = max(1, n_rows - 1)
    cross = (_center(lagged).T @ (rate - rate.mean())) / dof
    return cov_mat(lagged), cross, _var(rate)

def if_and_tau_multivar_cov(X, i, j, k=1, dt=1.0, eps=EPS):
    """
    Multivariate information flow within the subspace X (column indices are local).

    Parameters
    ----------
    X   : 2-D array, rows are time samples, columns are the variables conditioned on.
    i   : column index of the target variable.
    j   : column index of the source variable.
    k   : finite-difference step (in samples) for the derivative estimate.
    dt  : sampling interval.
    eps : small constant added to C_ii and used as a floor for the normaliser.

    Returns
    -------
    (T, tau) : T is the flow T_{j→i} = a_ij · C_ij / C_ii; tau is T divided by
               Z = |H_self| + Σ_m |T_{m→i}| + |H_noise|.
    """
    C, Cxdot, vardot = _cov_blocks_for_target(X, i, k=k, dt=dt)

    # Drift coefficients a_{i,·} of the target row: solve C a = Cxdot.
    a   = stable_inv(C, eps=eps) @ Cxdot
    Cii = C[i, i] + eps

    # Contributions to i from every other column in X; the self term is excluded.
    T_all = a * (C[i, :] / Cii)
    T_all[i] = 0.0

    Hself = a[i]

    # Residual variance of the regression of xdot_i on the lagged states.
    var_resid = vardot - 2.0 * (a @ Cxdot) + float(a.T @ C @ a)
    # Noise intensity b^2 = dt * Var(residual of xdot): a finite-difference
    # derivative of a diffusion has residual variance b^2 / dt, so the variance
    # is multiplied by dt. Clamped at zero against round-off.
    g = max(var_resid * dt, 0.0)
    Hnoise = g / (2.0 * Cii)

    Z = abs(Hself) + np.sum(np.abs(T_all)) + abs(Hnoise)
    Z = max(Z, eps)
    T = float(T_all[j])
    tau = float(T / Z)
    return T, tau

def if_and_tau_bivariate_cov(X2, k=1, dt=1.0, eps=EPS):
    """
    Bivariate information flow on X2 = [Xi, Xj] (columns = [target, parent]).

    Parameters
    ----------
    X2  : 2-D array with exactly two columns, target first and parent second.
    k   : finite-difference step (in samples) for the derivative estimate.
    dt  : sampling interval.
    eps : threshold/regulariser for a near-singular covariance, also added to
          C_ii and used as a floor for the normaliser.

    Returns
    -------
    (T, tau) : T is T_{j→i}; tau is T divided by Z = |H_self| + |H_noise| + |T|.
    """
    C, Cxdot, vardot = _cov_blocks_for_target(X2, i=0, k=k, dt=dt)
    C11, C12_, C22 = C[0, 0], C[0, 1], C[1, 1]
    detC = C11 * C22 - C12_ * C12_
    if abs(detC) < eps:
        # Near-singular covariance: add a small ridge and recompute the entries.
        C = C + eps * np.eye(2)
        C11, C12_, C22 = C[0, 0], C[0, 1], C[1, 1]
        detC = C11 * C22 - C12_ * C12_

    # Closed-form 2x2 solve of C a = Cxdot (Cramer's rule).
    a_self = ( C22 * Cxdot[0] - C12_ * Cxdot[1]) / detC
    a_par  = (-C12_ * Cxdot[0] + C11 * Cxdot[1]) / detC

    Cii = C11 + eps
    T   = float(a_par * (C12_ / Cii))

    a_vec = np.array([a_self, a_par])
    # Residual variance of the regression of xdot_i on the lagged pair.
    var_resid = vardot - 2.0 * (a_vec @ Cxdot) + float(a_vec.T @ C @ a_vec)
    # Noise intensity b^2 = dt * Var(residual of xdot): a finite-difference
    # derivative of a diffusion has residual variance b^2 / dt, so the variance
    # is multiplied by dt. Clamped at zero against round-off.
    g = max(var_resid * dt, 0.0)
    Hself  = a_self
    Hnoise = g / (2.0 * Cii)
    Z = abs(Hself) + abs(Hnoise) + abs(T)
    Z = max(Z, eps)
    tau = float(T / Z)
    return T, tau

# ======================
# ---- SIMULATOR -------
# ======================
def simulate_var(A, b_diag, N=4000, burn=800, seed=0):
    """
    Simulate the VAR(1) recursion X[t+1] = A @ X[t] + noise from a zero initial state.

    The noise at each step is a standard-normal draw per node scaled element-wise
    by b_diag. N + burn steps are generated with a generator seeded by `seed`,
    and the first `burn` rows are discarded.

    Returns an array of shape (N, d) with d = A.shape[0].
    """
    rng = np.random.default_rng(seed)
    n_nodes = A.shape[0]
    total_steps = N + burn
    states = np.zeros((total_steps, n_nodes))
    for step in range(1, total_steps):
        innovation = rng.normal(0.0, 1.0, size=n_nodes) * b_diag
        states[step] = A @ states[step - 1] + innovation
    return states[burn:]

def build_A(gamma, eta=ETA_X5_TO_X4):
    """
    Assemble the 5x5 VAR(1) transition matrix; A[i, j] is the weight of x_{j+1} → x_{i+1}.

    Nodes: [x1, x2, x3, x4, x5]
     - PHI on the diagonal
     - observed edges: x2→x1 (C21), x3→x1 (C31), x1→x2 (C12), x3→x2 (C32)
     - hidden layer:   x4→x1 and x4→x2 (gamma), x5→x4 (eta)
    """
    # Columns are the source nodes x1..x5; each row lists the inputs of one target.
    rows = [
        [PHI, C21, C31, gamma, 0.0],   # into x1
        [C12, PHI, C32, gamma, 0.0],   # into x2
        [0.0, 0.0, PHI, 0.0,   0.0],   # into x3
        [0.0, 0.0, 0.0, PHI,   eta],   # into x4
        [0.0, 0.0, 0.0, 0.0,   PHI],   # into x5
    ]
    return np.array(rows, dtype=float)

# ======================
# ---- BOOTSTRAP -------
# ======================
def bootstrap_mean_bands(samples_2d, B=BOOTSTRAP_B, qs=RIBBON_QS, seed=12345):
    """
    Percentile-bootstrap band for the across-realization mean at each gamma.

    samples_2d: (n_gamma, R) — one row of R realizations per gamma
    B:          number of bootstrap resamples per row
    qs:         (lower, upper) percentiles of the resampled means
    seed:       seed of the generator that draws the resampling indices

    Returns a dict with arrays 'mean', 'lo', 'hi', each of length n_gamma.
    With R <= 1 no resampling is done and lo == hi == mean.
    """
    rng = np.random.default_rng(seed)
    n_gamma, R = samples_2d.shape
    means = samples_2d.mean(axis=1)
    lower = np.empty(n_gamma)
    upper = np.empty(n_gamma)

    for row_idx, row in enumerate(samples_2d):
        if R > 1:
            # B resamples of size R, drawn with replacement from this row.
            picks = rng.integers(0, R, size=(B, R))
            resampled_means = row[picks].mean(axis=1)
            lower[row_idx], upper[row_idx] = np.percentile(resampled_means, qs)
        else:
            lower[row_idx] = upper[row_idx] = means[row_idx]

    return {"mean": means, "lo": lower, "hi": upper}

def signif_mask(band):
    """Boolean mask that is True wherever the [lo, hi] ribbon excludes zero."""
    entirely_above = band['lo'] > 0
    entirely_below = band['hi'] < 0
    return entirely_above | entirely_below

# ======================
# ---- EXPERIMENT -------
# ======================
def run_sweep_all_paths(gammas=GAMMA_VALS, eta=ETA_X5_TO_X4,
                        N=N_SAMPLES, R=NUM_REALIZATIONS):
    """
    Sweep γ, simulate R realizations per value, and estimate information flows
    on the observed columns [x1, x2, x3] only (x4 and x5 are left out).

    Parameters
    ----------
    gammas : sequence of confounder strengths passed to build_A.
    eta    : x5→x4 coupling passed to build_A.
    N      : samples kept per realization.
    R      : realizations per gamma.

    Returns
    -------
    dict with 'gammas', 'meta' (= {'eta': eta}) and (n_gamma, R) arrays:
      T12_bi ≡ T_{1→2}   on [x2,x1]
      T12_mi ≡ T_{1→2|3} on [x1,x2,x3]
      T13_bi ≡ T_{1→3}   on [x3,x1]
      T13_mi ≡ T_{1→3|2} on [x1,x2,x3]
      T32_bi ≡ T_{3→2}   on [x2,x3]
      T32_mi ≡ T_{3→2|1} on [x1,x2,x3]
    """
    # (result prefix, target column, source column) within Xobs = [x1, x2, x3].
    paths = (("T12", 1, 0), ("T13", 2, 0), ("T32", 1, 2))

    nG = len(gammas)
    flows = {f"{name}_{kind}": np.zeros((nG, R))
             for name, _, _ in paths
             for kind in ("bi", "mi")}

    for gi, gamma in enumerate(gammas):
        A = build_A(gamma, eta=eta)
        # Round to the nearest integer: plain int() truncates float round-off
        # (100 * 0.29 == 28.999...), which would reuse the seeds of gamma = 0.28.
        gamma_tag = int(round(100 * gamma))
        for r in range(R):
            seed = RNG_BASE + 10000 * gamma_tag + 17 * r
            X = simulate_var(A, b_diag=B_DIAG, N=N, burn=BURN_IN, seed=seed)
            Xobs = X[:, [0, 1, 2]]  # [x1,x2,x3]

            for name, tgt, src in paths:
                # Bivariate estimate sees only [target, source]; the
                # multivariate one conditions on all three observed columns.
                T_bi, _ = if_and_tau_bivariate_cov(Xobs[:, [tgt, src]], k=K_DIFF, dt=DT)
                T_mi, _ = if_and_tau_multivar_cov(Xobs, i=tgt, j=src, k=K_DIFF, dt=DT)
                flows[f"{name}_bi"][gi, r] = T_bi
                flows[f"{name}_mi"][gi, r] = T_mi

    return {
        "gammas": np.array(gammas),
        **flows,
        "meta": dict(eta=eta),
    }

# ======================
# ---- PLOTTING --------
# ======================
def _draw_band_series(ax, g, bands, series):
    """Draw ribbons, then mean lines, then markers for each series, plus a zero line."""
    # Three separate passes keep the artist insertion order: all ribbons first,
    # then all lines, then all markers.
    for key, color, _ls, _marker, alpha, _label in series:
        ax.fill_between(g, bands[key]['lo'], bands[key]['hi'], color=color, alpha=alpha)
    for key, color, ls, _marker, _alpha, label in series:
        ax.plot(g, bands[key]['mean'], ls, color=color, lw=1.9, label=label)
    for key, color, _ls, marker, _alpha, _label in series:
        ax.scatter(g, bands[key]['mean'], s=52, marker=marker,
                   facecolor=color, edgecolor='white', linewidth=1.0)
    ax.axhline(0, lw=0.8, color='k', alpha=0.6)


def make_two_panel_overlay(res, savepath="if_two_panel_overlay.png", dpi=400,
                           ylim_top=None, ylim_bottom_left=None, ylim_bottom_right=(0,1)):
    """
    Plot mean information flows versus γ with bootstrap ribbons, save and show the figure.

    Top panel: T1→2 and T1→3 (bivariate and multivariate).
    Bottom panel: T3→2 (bivariate and multivariate) on the left axis and the
    ΔIF share 1 - multi/bivar on a twin right axis.

    Parameters
    ----------
    res               : dict with 'gammas' and the six (n_gamma, R) arrays
                        'T12_bi', 'T12_mi', 'T13_bi', 'T13_mi', 'T32_bi', 'T32_mi'.
    savepath          : output image path.
    dpi               : resolution used when saving.
    ylim_top          : y-limits of the top panel, or None to autoscale.
    ylim_bottom_left  : y-limits of the bottom-left axis, or None to autoscale.
    ylim_bottom_right : (lo, hi) limits of the ΔIF-share axis, or None to autoscale.
    """
    g = res['gammas']

    # Bootstrap bands
    band_keys = ('T12_bi', 'T12_mi', 'T13_bi', 'T13_mi', 'T32_bi', 'T32_mi')
    bands = {key: bootstrap_mean_bands(res[key]) for key in band_keys}

    # ΔIF share for 3→2 (NaN where the bivariate mean is exactly zero)
    mean_32_bi = bands['T32_bi']['mean']
    mean_32_mi = bands['T32_mi']['mean']
    with np.errstate(divide='ignore', invalid='ignore'):
        share = 1.0 - np.where(mean_32_bi != 0, mean_32_mi / mean_32_bi, np.nan)

    colors = {
        '12': BLUE_HEX,    # x1→x2
        '13': ORANGE_HEX,  # x1→x3
        '32': 'C2',        # x3→x2
    }

    # (band key, color, line style, marker, ribbon alpha, legend label)
    top_series = (
        ('T12_bi', colors['12'], '-',  'o', 0.15, r'$T_{1\to2}$ (bivar)'),
        ('T12_mi', colors['12'], '--', 's', 0.08, r'$T_{1\to2\,|\,3}$ (multi)'),
        ('T13_bi', colors['13'], '-',  '^', 0.15, r'$T_{1\to3}$ (bivar)'),
        ('T13_mi', colors['13'], '--', 'v', 0.08, r'$T_{1\to3\,|\,2}$ (multi)'),
    )
    bottom_series = (
        ('T32_bi', colors['32'], '-',  'o', 0.15, r'$T_{3\to2}$ (bivar)'),
        ('T32_mi', colors['32'], '--', 's', 0.08, r'$T_{3\to2\,|\,1}$ (multi)'),
    )

    fig, (ax_top, ax_bot_left) = plt.subplots(2, 1, figsize=(10.8, 9.6), sharex=True)

    # ---------- TOP: T1→2 & T1→3 ----------
    _draw_band_series(ax_top, g, bands, top_series)
    ax_top.set_ylabel('Information flow [nats/unit time]', fontsize=16)
    ax_top.set_title('(b)', fontsize=18)
    ax_top.grid(True, which='major', axis='y', alpha=0.2)
    if ylim_top is not None:
        ax_top.set_ylim(ylim_top)
    ax_top.legend(ncol=2, frameon=True, framealpha=0.95, fontsize=15)

    # ---------- BOTTOM: T3→2 with ΔIF-share twin axis ----------
    _draw_band_series(ax_bot_left, g, bands, bottom_series)
    ax_bot_left.set_xlabel('Confounder strength $\\gamma$ (x4→{x1,x2})', fontsize=16)
    ax_bot_left.set_ylabel('Information flow [nats/unit time]', fontsize=16)
    ax_bot_left.set_title('(c)', fontsize=18)
    ax_bot_left.grid(True, which='major', axis='y', alpha=0.2)
    if ylim_bottom_left is not None:
        ax_bot_left.set_ylim(ylim_bottom_left)

    # twin axis for share
    ax_bot_right = ax_bot_left.twinx()
    ax_bot_right.plot(g, share, color='C5', lw=1.7, marker='^',
                      markerfacecolor='C5', markeredgecolor='white',
                      markeredgewidth=1.1, label='ΔIF share (1 - multi/bivar)')
    ax_bot_right.set_ylabel('ΔIF share', fontsize=16)
    if ylim_bottom_right is not None:
        ax_bot_right.set_ylim(*ylim_bottom_right)

    # Legends: merge both bottom axes into one box, with a grey proxy patch
    # standing in for the ribbons.
    band_proxy = Patch(facecolor='0.5', alpha=0.15, edgecolor='none', label='95% bootstrap band')
    h_left, l_left = ax_bot_left.get_legend_handles_labels()
    h_right, l_right = ax_bot_right.get_legend_handles_labels()
    ax_bot_left.legend(h_left + [band_proxy] + h_right,
                       l_left + ['95% band'] + l_right,
                       loc='upper right', frameon=True, framealpha=0.95, fontsize=15)

    plt.tight_layout()
    plt.savefig(savepath, dpi=dpi, bbox_inches='tight')
    plt.show()
    print(f"Saved figure at {dpi} dpi: {savepath}")

# ======================
# ---- MAIN ------------
# ======================
if __name__ == "__main__":
    # Run the full γ sweep with the module-level configuration.
    res = run_sweep_all_paths(GAMMA_VALS, ETA_X5_TO_X4, N_SAMPLES, NUM_REALIZATIONS)

    # Render and save the two-panel figure. Passing None for a y-limit lets
    # matplotlib autoscale; example overrides: ylim_top=(-0.05, 0.8),
    # ylim_bottom_left=(-0.05, 0.6).
    make_two_panel_overlay(
        res,
        savepath="if_two_panel_overlay.png",
        dpi=400,
        ylim_top=None,
        ylim_bottom_left=None,
        ylim_bottom_right=(0.0, 1.0),  # ΔIF share axis
    )