# In [None]:
# %load ANOVA_Analysis_plots3.py
#!/usr/bin/env python3


#!/usr/bin/env python3
# -*- coding: utf-8 -*-


# Combined figure panels:
# (1) ω² panels with Explained SD annotations (sqrt(AEV)) for two-way and three-way (T, SSR)
# (2) Explained SD-only panels (sqrt(AEV)) for two-way and three-way (T, SSR)
#
# The code reads the provided Excel coupling tables and .mat time series,
# computes |ΔIF|, bins by SM/VPD/T/SSR quartiles, runs ANOVA,
# bootstraps CIs, and saves 400 dpi multi-panel PNGs. It also displays the figures inline.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from pathlib import Path
import scipy.io as sio

# -------------------------
# Global font size increase
# -------------------------
# Applied process-wide: every figure created after this point inherits these
# sizes unless a call overrides them explicitly (e.g. fontsize=... on a title).
mpl.rcParams.update({
    "font.size": 14,          # default/base text
    "axes.titlesize": 14,     # subplot titles (unless overridden)
    "axes.labelsize": 14,     # axis labels
    "xtick.labelsize": 12,    # tick labels
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
})

# -------------------------
# Paths
# -------------------------
# All paths are relative, so they resolve against the current working directory.
# Inputs: Excel coupling tables and the .mat files holding the regime series.
P_LAI_XLSX = Path("IF_Couplings_China_LAI.xlsx")
P_GPP_XLSX = Path("IF_Couplings_China_GPP.xlsx")
P_LAI_MAT  = Path("China_LAI.mat")
P_GPP_MAT  = Path("China_GPP.mat")

# Outputs: one PNG per combined figure.
OUT_COMBO_W_SD = Path("Combined_ANOVA_omega2_with_SD.png")
OUT_COMBO_SD   = Path("Combined_ANOVA_SD_only.png")

# Bootstrap settings
B = 400              # number of bootstrap replicates (passed as B=... below)
QS = (2.5, 97.5)     # lower/upper percentiles used for the reported intervals
RSEED = 20251108     # seed handed to np.random.default_rng in the bootstrap helpers

# -------------------------
# Helpers
# -------------------------
def load_if_tables(p_lai, p_gpp):
    """Read the LAI and GPP coupling tables and normalise their date column.

    Each workbook is read with ``pd.read_excel``. The first column whose name
    is one of the known aliases is renamed to ``"date"`` and parsed with
    ``pd.to_datetime``.

    Parameters
    ----------
    p_lai, p_gpp : path-like
        Locations of the LAI and GPP Excel files.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(lai, gpp)``, each with a datetime ``"date"`` column.

    Raises
    ------
    ValueError
        If a table has none of the recognised date column names.
    """
    date_aliases = ("date", "Date", "DATE", "Unnamed: 0", "time", "Time")

    def _with_parsed_dates(table):
        # Aliases are tried in priority order; the first one present wins.
        found = next((name for name in date_aliases if name in table.columns), None)
        if found is None:
            raise ValueError("Could not find a date column in IF table.")
        table = table.rename(columns={found: "date"})
        table["date"] = pd.to_datetime(table["date"])
        return table

    lai_table = pd.read_excel(p_lai)
    gpp_table = pd.read_excel(p_gpp)
    return _with_parsed_dates(lai_table), _with_parsed_dates(gpp_table)


def find_col(cols, driver, target, conds=None):
    """Locate the column named ``DRIVER-TARGET`` or ``DRIVER-TARGET|C1,C2,...``.

    Matching is case-insensitive and ignores surrounding whitespace in the
    column names.

    Parameters
    ----------
    cols : iterable of str
        Candidate column names.
    driver, target : str
        Joined as ``f"{driver}-{target}"`` to form the required prefix.
    conds : iterable of str, optional
        Conditioning variables. When empty or ``None`` the first column with
        the prefix and no ``"|"`` is returned.

    Returns
    -------
    str or None
        The original (un-normalised) column name, or ``None`` if nothing
        matches.
    """
    prefix = f"{driver}-{target}".upper()
    candidates = [(name, name.strip().upper()) for name in cols]

    if not conds:
        return next(
            (name for name, key in candidates
             if key.startswith(prefix) and "|" not in key),
            None,
        )

    wanted = {c.strip().upper() for c in conds}
    conditioned = [(name, key) for name, key in candidates
                   if key.startswith(prefix) and "|" in key]

    # Pass 1: the comma-separated tokens after "|" must equal the wanted set.
    for name, key in conditioned:
        rhs = key.split("|", 1)[1].replace(" ", "")
        tokens = {tok.strip().upper() for tok in rhs.split(",") if tok}
        if tokens == wanted:
            return name

    # Pass 2 (looser): every wanted token appears somewhere in the whole name.
    for name, key in conditioned:
        if all(tok in key for tok in wanted):
            return name
    return None


def load_mat_series(p_mat, var_tokens, start_year):
    """Load one variable from a .mat file as a month-start indexed Series.

    The variable is the first non-dunder key whose lower-cased name contains
    every token in ``var_tokens`` (substring match, case-insensitive).

    Parameters
    ----------
    p_mat : pathlib.Path
        MATLAB file to read (``.name`` is used in the error message).
    var_tokens : iterable of str
        Substrings that must all occur in the variable name.
    start_year : int
        Year of the first sample; the index starts at January of that year.

    Returns
    -------
    pandas.Series
        Values indexed by a ``freq="MS"`` date range, named after the matched key.

    Raises
    ------
    ValueError
        If no variable name contains all tokens.
    """
    contents = sio.loadmat(p_mat)
    var_names = [name for name in contents if not name.startswith("__")]
    wanted = [tok.lower() for tok in var_tokens]

    match = next(
        (name for name in var_names
         if all(tok in name.lower() for tok in wanted)
         and isinstance(contents[name], np.ndarray)),
        None,
    )
    if match is None:
        raise ValueError(f"Could not find variable containing tokens {var_tokens} in {p_mat.name}. Keys sample: {var_names[:12]}")

    values = np.asarray(contents[match]).squeeze()
    if values.ndim == 2:
        # Average over the second axis, ignoring NaNs.
        # NOTE(review): assumes axis 0 is time and axis 1 is space - confirm.
        values = np.nanmean(values, axis=1)

    index = pd.date_range(f"{start_year}-01-01", periods=values.shape[0], freq="MS")
    return pd.Series(values, index=index, name=match)


def quartile_bins(x):
    """Assign each value to a quantile bin numbered from 1.

    Bins come from ``pd.qcut(..., 4, duplicates='drop')``, so fewer than four
    bins are produced when quantile edges coincide.

    Parameters
    ----------
    x : array-like
        Values to bin; converted to float.

    Returns
    -------
    numpy.ndarray
        Integer bin numbers when every input is finite (unchanged behaviour).
        When the input contains NaN, a float array is returned with NaN kept
        at those positions, so the caller can drop the rows afterwards.
    """
    s = pd.Series(x).astype(float)
    bins = pd.qcut(s, 4, labels=False, duplicates='drop') + 1
    if bins.isna().any():
        # qcut leaves NaN for missing inputs; casting those to int would raise,
        # so hand back floats and let the caller's dropna() remove them.
        return bins.astype(float).to_numpy()
    return bins.astype(int).to_numpy()


def build_if_frame(df_if, group_label, driver, target, conds_all,
                   sm_series, vpd_series, T_series, SSR_series):
    """Assemble the analysis table for one driver->target coupling.

    Two columns are pulled from ``df_if``: the unconditioned
    ``driver-target`` column (stored as ``T_bi``) and the one conditioned on
    ``conds_all`` (stored as ``T_mv``). Their absolute difference becomes
    ``abs_dIF``. The four regime series are aligned to the table's dates and
    binned with ``quartile_bins``. Rows with any missing value in the kept
    columns are dropped.

    Parameters
    ----------
    df_if : pandas.DataFrame
        Coupling table with a ``"date"`` column.
    group_label : str
        Written to the ``"group"`` column and used in the error message.
    driver, target : str
        Identify the coupling columns via ``find_col``.
    conds_all : list of str
        Conditioning variables of the conditioned column.
    sm_series, vpd_series, T_series, SSR_series : pandas.Series
        Regime series, reindexed onto the table's dates.

    Returns
    -------
    pandas.DataFrame
        Columns ``date, T_bi, T_mv, abs_dIF, SM_q, VPD_q, T_q, SSR_q, group``.

    Raises
    ------
    ValueError
        If either coupling column cannot be found.
    """
    col_bi = find_col(df_if.columns, driver, target, conds=None)
    col_mv = find_col(df_if.columns, driver, target, conds=conds_all)
    if col_bi is None or col_mv is None:
        raise ValueError(f"[{group_label}] Missing columns for {driver}-{target}|{','.join(conds_all)}")

    frame = pd.DataFrame({
        "date": df_if["date"],
        "T_bi": df_if[col_bi].astype(float).to_numpy(),
        "T_mv": df_if[col_mv].astype(float).to_numpy(),
    })
    frame = frame.set_index("date")

    # Align every regime series to the coupling dates first, then bin them.
    regimes = (("SM", sm_series), ("VPD", vpd_series),
               ("T", T_series), ("SSR", SSR_series))
    for name, series in regimes:
        frame[name] = series.reindex(frame.index).to_numpy()
    for name, _ in regimes:
        frame[f"{name}_q"] = quartile_bins(frame[name])

    frame["abs_dIF"] = np.abs(frame["T_bi"] - frame["T_mv"])

    keep = ["T_bi", "T_mv", "abs_dIF", "SM_q", "VPD_q", "T_q", "SSR_q"]
    frame = frame[keep].dropna()
    frame["group"] = group_label
    return frame.reset_index(drop=False)


# ANOVA SS/df
def ss_df_two_way(y, A, B):
    """Sums of squares and degrees of freedom for a two-factor layout.

    Effects are computed from marginal and cell means:
    ``SS_A``/``SS_B`` from factor-level means, ``SS_AB`` from
    ``cell - level_A - level_B + grand`` per non-empty cell, and
    ``SS_err`` as the remainder of ``SS_tot``.

    Rows where ``y`` or either factor is non-finite are dropped.

    Parameters
    ----------
    y : array-like of float
        Response values.
    A, B : array-like
        Factor labels; truncated to int after the finite check.

    Returns
    -------
    tuple
        ``(SS_tot, SS_A, SS_B, SS_AB, SS_err, df_A, df_B, df_AB, df_err)``.
        ``df_err`` is floored at 1.
    """
    y = np.asarray(y, float).reshape(-1)
    A_f = np.asarray(A, float).reshape(-1)
    B_f = np.asarray(B, float).reshape(-1)
    # The finite check must happen on floats: once a factor is cast to int a
    # NaN is no longer detectable and would show up as a bogus extra level.
    mask = np.isfinite(y) & np.isfinite(A_f) & np.isfinite(B_f)
    y = y[mask]
    A = A_f[mask].astype(int)
    B = B_f[mask].astype(int)

    ybar = np.mean(y)
    N = len(y)
    SS_tot = np.sum((y - ybar)**2)

    a_levels = np.unique(A); na = len(a_levels)
    b_levels = np.unique(B); nb = len(b_levels)

    # Level means are reused for the interaction term below.
    mean_a = {a: np.mean(y[A == a]) for a in a_levels}
    mean_b = {b: np.mean(y[B == b]) for b in b_levels}

    SS_A = sum((A == a).sum() * (mean_a[a] - ybar)**2 for a in a_levels)
    SS_B = sum((B == b).sum() * (mean_b[b] - ybar)**2 for b in b_levels)

    SS_AB = 0.0
    for a in a_levels:
        for b in b_levels:
            idx = (A == a) & (B == b)
            if idx.any():
                y_ab = np.mean(y[idx])
                SS_AB += idx.sum() * (y_ab - mean_a[a] - mean_b[b] + ybar)**2

    df_A  = na - 1
    df_B  = nb - 1
    df_AB = (na - 1) * (nb - 1)
    df_tot = N - 1
    # Floor at 1 so a saturated layout still yields a usable divisor downstream.
    df_err = max(df_tot - (df_A + df_B + df_AB), 1)

    # NOTE(review): with unequal cell counts these marginal-mean SS terms need
    # not add up to SS_tot, so this remainder can go negative - confirm intended.
    SS_err = SS_tot - SS_A - SS_B - SS_AB
    return (SS_tot, SS_A, SS_B, SS_AB, SS_err,
            df_A, df_B, df_AB, df_err)


def ss_df_three_way(y, A, B, C):
    """Sums of squares and degrees of freedom for a three-factor layout.

    Main effects use factor-level means, two-way interactions use
    ``pair - level_1 - level_2 + grand`` over non-empty pairs, and the
    three-way term uses, per non-empty cell,
    ``cell - (ab + ac + bc) + (a + b + c) - grand``.
    ``SS_err`` is the remainder of ``SS_tot``.

    Rows where ``y`` or any factor is non-finite are dropped.

    Parameters
    ----------
    y : array-like of float
        Response values.
    A, B, C : array-like
        Factor labels; truncated to int after the finite check.

    Returns
    -------
    tuple
        ``(SS_tot, SS_A, SS_B, SS_C, SS_AB, SS_AC, SS_BC, SS_ABC, SS_err,
        df_A, df_B, df_C, df_AB, df_AC, df_BC, df_ABC, df_err)``.
        ``df_err`` is floored at 1.
    """
    y = np.asarray(y, float).reshape(-1)
    A_f = np.asarray(A, float).reshape(-1)
    B_f = np.asarray(B, float).reshape(-1)
    C_f = np.asarray(C, float).reshape(-1)
    # The finite check must happen on floats: once a factor is cast to int a
    # NaN is no longer detectable and would show up as a bogus extra level.
    mask = np.isfinite(y) & np.isfinite(A_f) & np.isfinite(B_f) & np.isfinite(C_f)
    y = y[mask]
    A = A_f[mask].astype(int)
    B = B_f[mask].astype(int)
    C = C_f[mask].astype(int)

    ybar = np.mean(y); N = len(y)
    SS_tot = np.sum((y - ybar)**2)

    a_levels = np.unique(A); na = len(a_levels)
    b_levels = np.unique(B); nb = len(b_levels)
    c_levels = np.unique(C); nc = len(c_levels)

    mean_a = {a: np.mean(y[A == a]) for a in a_levels}
    mean_b = {b: np.mean(y[B == b]) for b in b_levels}
    mean_c = {c: np.mean(y[C == c]) for c in c_levels}
    SS_A = sum((A == a).sum() * (mean_a[a] - ybar)**2 for a in a_levels)
    SS_B = sum((B == b).sum() * (mean_b[b] - ybar)**2 for b in b_levels)
    SS_C = sum((C == c).sum() * (mean_c[c] - ybar)**2 for c in c_levels)

    def _pair_cells(F, f_levels, G, g_levels):
        # (count, mean) for every non-empty combination of two factors.
        cells = {}
        for f in f_levels:
            for g in g_levels:
                idx = (F == f) & (G == g)
                if idx.any():
                    cells[(f, g)] = (idx.sum(), np.mean(y[idx]))
        return cells

    def _pair_ss(cells, mean_f, mean_g):
        # Count-weighted squared two-way interaction deviations.
        return sum(n * (m - mean_f[f] - mean_g[g] + ybar)**2
                   for (f, g), (n, m) in cells.items())

    cells_ab = _pair_cells(A, a_levels, B, b_levels)
    cells_ac = _pair_cells(A, a_levels, C, c_levels)
    cells_bc = _pair_cells(B, b_levels, C, c_levels)

    SS_AB = _pair_ss(cells_ab, mean_a, mean_b)
    SS_AC = _pair_ss(cells_ac, mean_a, mean_c)
    SS_BC = _pair_ss(cells_bc, mean_b, mean_c)

    SS_ABC = 0.0
    for a in a_levels:
        for b in b_levels:
            for c in c_levels:
                idx = (A == a) & (B == b) & (C == c)
                n = idx.sum()
                if n > 0:
                    y_abc = np.mean(y[idx])
                    # Grand mean + three main effects + three two-way
                    # interactions collapses to this expression. A non-empty
                    # (a, b, c) cell guarantees all three pair cells exist.
                    pred = (cells_ab[(a, b)][1] + cells_ac[(a, c)][1] + cells_bc[(b, c)][1]
                            - mean_a[a] - mean_b[b] - mean_c[c]
                            + ybar)
                    resid = y_abc - pred
                    SS_ABC += n * (resid**2)

    df_A  = na - 1
    df_B  = nb - 1
    df_C  = nc - 1
    df_AB = (na - 1) * (nb - 1)
    df_AC = (na - 1) * (nc - 1)
    df_BC = (nb - 1) * (nc - 1)
    df_ABC= (na - 1) * (nb - 1) * (nc - 1)
    df_tot= N - 1
    # Floor at 1 so a saturated layout still yields a usable divisor downstream.
    df_err= max(df_tot - (df_A + df_B + df_C + df_AB + df_AC + df_BC + df_ABC), 1)

    # NOTE(review): with unequal cell counts these marginal-mean SS terms need
    # not add up to SS_tot, so this remainder can go negative - confirm intended.
    SS_err = SS_tot - (SS_A + SS_B + SS_C + SS_AB + SS_AC + SS_BC + SS_ABC)
    return (SS_tot, SS_A, SS_B, SS_C, SS_AB, SS_AC, SS_BC, SS_ABC, SS_err,
            df_A, df_B, df_C, df_AB, df_AC, df_BC, df_ABC, df_err)


def omega_sq(SS_eff, df_eff, SS_err, df_err, SS_tot):
    """Return the ω² effect size for one effect, clipped below at zero.

    Computed as ``(SS_eff - df_eff * MS_err) / (SS_tot + MS_err)`` with
    ``MS_err = SS_err / df_err``. Returns NaN when ``df_err`` is not positive
    or when the denominator is not positive.
    """
    if df_err <= 0:
        return np.nan

    mean_sq_error = SS_err / df_err
    denominator = SS_tot + mean_sq_error
    if not denominator > 0:
        return np.nan

    estimate = (SS_eff - df_eff * mean_sq_error) / denominator
    # Negative estimates are reported as zero.
    return max(0.0, estimate)


def within_cell_bootstrap_indices(groups, rng):
    """Resample row indices with replacement inside each group.

    Every non-empty group contributes as many draws as it has members, so
    group sizes are preserved. Draws are concatenated in group order.

    Parameters
    ----------
    groups : iterable of array-like of int
        Row indices belonging to each cell; empty cells are skipped.
    rng : numpy.random.Generator
        Source of randomness (``rng.choice`` is called once per non-empty group).

    Returns
    -------
    numpy.ndarray
        Resampled indices, or an empty int array when every group is empty.
    """
    cells = (np.asarray(members, int) for members in groups)
    resampled = [
        rng.choice(cell, size=cell.size, replace=True)
        for cell in cells
        if cell.size != 0
    ]
    return np.concatenate(resampled) if resampled else np.array([], dtype=int)


def bootstrap_two_way(y, A, Bf, B=400, qs=(2.5,97.5), seed=0):
    """Two-way ω² and explained SD with within-cell bootstrap intervals.

    Point estimates come from ``ss_df_two_way`` + ``omega_sq`` on the full
    sample. Explained SD is ``sqrt(ω² * var(y))`` with the sample variance
    ``SS_tot / (n - 1)``. Each replicate resamples rows with replacement
    inside every (A, Bf) cell and repeats the same computation.

    Parameters
    ----------
    y : array-like of float
        Response values.
    A, Bf : array-like
        Factor labels; rows with a non-finite ``y`` or factor are dropped.
    B : int
        Number of bootstrap replicates.
    qs : tuple of float
        Lower and upper percentiles passed to ``np.nanpercentile``.
    seed : int
        Seed for ``np.random.default_rng``.

    Returns
    -------
    tuple
        ``(omegas, omega_cis, sds, sd_cis)``; each is a list ordered
        ``[A, Bf, A x Bf]`` and each CI is a ``(low, high)`` tuple.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y, float).reshape(-1)
    A_f = np.asarray(A, float).reshape(-1)
    B_f = np.asarray(Bf, float).reshape(-1)
    # Check finiteness on floats; after an int cast a NaN factor is undetectable.
    mask = np.isfinite(y) & np.isfinite(A_f) & np.isfinite(B_f)
    y = y[mask]
    A = A_f[mask].astype(int)
    Bf = B_f[mask].astype(int)

    def _effects(y_s, a_s, b_s):
        # ω² and explained SD for the three effects of one (re)sample.
        (ss_tot, ss_a, ss_b, ss_ab, ss_err,
         df_a, df_b, df_ab, df_err) = ss_df_two_way(y_s, a_s, b_s)
        var_y = ss_tot / max(1, (len(y_s) - 1))
        omegas = [omega_sq(ss, df, ss_err, df_err, ss_tot)
                  for ss, df in ((ss_a, df_a), (ss_b, df_b), (ss_ab, df_ab))]
        sds = [np.sqrt(max(w, 0) * var_y) if np.isfinite(w) else np.nan
               for w in omegas]
        return omegas, sds

    # Cells
    a_levels = np.unique(A); b_levels = np.unique(Bf)
    groups = [np.where((A == a) & (Bf == b))[0] for a in a_levels for b in b_levels]

    # Point estimates
    ws, SDs = _effects(y, A, Bf)

    # One row per effect, one column per replicate; skipped replicates stay NaN.
    boot_w = np.full((3, B), np.nan)
    boot_sd = np.full((3, B), np.nan)
    for rep in range(B):
        idx = within_cell_bootstrap_indices(groups, rng)
        if idx.size < 4:
            continue
        w_rep, sd_rep = _effects(y[idx], A[idx], Bf[idx])
        boot_w[:, rep] = w_rep
        boot_sd[:, rep] = sd_rep

    def ci(arr):
        return (np.nanpercentile(arr, qs[0]), np.nanpercentile(arr, qs[1]))

    return (
        ws, [ci(row) for row in boot_w],
        SDs, [ci(row) for row in boot_sd]
    )


def bootstrap_three_way(y, A, Bf, C, B=400, qs=(2.5,97.5), seed=0):
    """Three-way ω² and explained SD with within-cell bootstrap intervals.

    Point estimates come from ``ss_df_three_way`` + ``omega_sq`` on the full
    sample. Explained SD is ``sqrt(ω² * var(y))`` with the sample variance
    ``SS_tot / (n - 1)``. Each replicate resamples rows with replacement
    inside every (A, Bf, C) cell and repeats the same computation.

    Parameters
    ----------
    y : array-like of float
        Response values.
    A, Bf, C : array-like
        Factor labels; rows with a non-finite ``y`` or factor are dropped.
    B : int
        Number of bootstrap replicates.
    qs : tuple of float
        Lower and upper percentiles passed to ``np.nanpercentile``.
    seed : int
        Seed for ``np.random.default_rng``.

    Returns
    -------
    tuple
        ``(omegas, omega_cis, sds, sd_cis)``; each is a list ordered
        ``[A, Bf, C, AxBf, AxC, BfxC, AxBfxC]`` and each CI is a
        ``(low, high)`` tuple.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y, float).reshape(-1)
    A_f = np.asarray(A, float).reshape(-1)
    B_f = np.asarray(Bf, float).reshape(-1)
    C_f = np.asarray(C, float).reshape(-1)
    # Check finiteness on floats; after an int cast a NaN factor is undetectable.
    mask = np.isfinite(y) & np.isfinite(A_f) & np.isfinite(B_f) & np.isfinite(C_f)
    y = y[mask]
    A = A_f[mask].astype(int)
    Bf = B_f[mask].astype(int)
    C = C_f[mask].astype(int)

    def _effects(y_s, a_s, b_s, c_s):
        # ω² and explained SD for the seven effects of one (re)sample.
        (ss_tot, ss_a, ss_b, ss_c, ss_ab, ss_ac, ss_bc, ss_abc, ss_err,
         df_a, df_b, df_c, df_ab, df_ac, df_bc, df_abc, df_err) = ss_df_three_way(
             y_s, a_s, b_s, c_s)
        var_y = ss_tot / max(1, (len(y_s) - 1))
        pairs = ((ss_a, df_a), (ss_b, df_b), (ss_c, df_c),
                 (ss_ab, df_ab), (ss_ac, df_ac), (ss_bc, df_bc), (ss_abc, df_abc))
        omegas = [omega_sq(ss, df, ss_err, df_err, ss_tot) for ss, df in pairs]
        sds = [np.sqrt(max(w, 0) * var_y) if np.isfinite(w) else np.nan
               for w in omegas]
        return omegas, sds

    # Cells
    a_levels = np.unique(A); b_levels = np.unique(Bf); c_levels = np.unique(C)
    groups = [np.where((A == a) & (Bf == b) & (C == c))[0]
              for a in a_levels for b in b_levels for c in c_levels]

    # Point estimates
    ws, SDs = _effects(y, A, Bf, C)

    # One row per effect, one column per replicate; skipped replicates stay NaN.
    boot_w = np.full((7, B), np.nan)
    boot_sd = np.full((7, B), np.nan)
    for rep in range(B):
        idx = within_cell_bootstrap_indices(groups, rng)
        if idx.size < 8:
            continue
        w_rep, sd_rep = _effects(y[idx], A[idx], Bf[idx], C[idx])
        boot_w[:, rep] = w_rep
        boot_sd[:, rep] = sd_rep

    def ci(arr):
        # Use the caller's percentiles, not a module-level default.
        return (np.nanpercentile(arr, qs[0]), np.nanpercentile(arr, qs[1]))

    ws_ci  = [ci(row) for row in boot_w]
    SDs_ci = [ci(row) for row in boot_sd]
    return ws, ws_ci, SDs, SDs_ci


def bar_ci(ax, labels, means, cis, color, ylabel, title, ylim=None, annotate=None):
    """Draw a bar chart with asymmetric error bars on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    labels : sequence of str
        One tick label per bar (rotated 25 degrees, right-aligned).
    means : sequence of float
        Bar heights.
    cis : sequence of (low, high)
        Interval bounds per bar. Error-bar lengths are ``mean - low`` and
        ``high - mean``, each clipped at zero.
    color : str
        Bar fill colour.
    ylabel, title : str
        Axis label and title (the title is drawn at fontsize 10).
    ylim : tuple, optional
        ``(bottom, top)`` passed to ``ax.set_ylim`` when given.
    annotate : sequence of str, optional
        Accepted for backward compatibility but not drawn: per-bar text labels
        are currently disabled. To show them, add an ``ax.text`` call per bar
        just above its height.
    """
    x = np.arange(len(labels))
    m = np.array(means, float)
    lo = np.array([c[0] for c in cis])
    hi = np.array([c[1] for c in cis])

    ax.bar(x, means, color=color)
    # Clip at zero so an interval that excludes the point estimate cannot
    # produce a negative error-bar length.
    yerr = np.vstack([np.maximum(m - lo, 0.0), np.maximum(hi - m, 0.0)])
    ax.errorbar(x, m, yerr=yerr, fmt='none', capsize=3, color="k", lw=1)

    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=25, ha='right')
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontsize=10)


# -------------------------
# Load data
# -------------------------
lai_if, gpp_if = load_if_tables(P_LAI_XLSX, P_GPP_XLSX)

# Regime series for each product. Variables are picked by case-insensitive
# substring match on (name, "China"); the LAI file is dated from 1981 and the
# GPP file from 1982.
# NOTE(review): the single-letter token "T" matches any variable name that
# contains "t" and "china" - confirm the .mat files hold no other such key.
SM_LAI, VPD_LAI, T_LAI, SSR_LAI = [
    load_mat_series(P_LAI_MAT, (var, "China"), 1981)
    for var in ("SM", "VPD", "T", "SSR")
]
SM_GPP, VPD_GPP, T_GPP, SSR_GPP = [
    load_mat_series(P_GPP_MAT, (var, "China"), 1982)
    for var in ("SM", "VPD", "T", "SSR")
]

# Conditioning set of the conditioned coupling column.
conds_full = ["VPD","T","SSR"]

df_SM_LAI = build_if_frame(
    lai_if, "SM→LAI", "SM", "LAI", conds_full,
    sm_series=SM_LAI, vpd_series=VPD_LAI, T_series=T_LAI, SSR_series=SSR_LAI,
)
df_VPD_LAI = build_if_frame(
    lai_if, "VPD→LAI", "VPD", "LAI", conds_full,
    sm_series=SM_LAI, vpd_series=VPD_LAI, T_series=T_LAI, SSR_series=SSR_LAI,
)
df_SM_GPP = build_if_frame(
    gpp_if, "SM→GPP", "SM", "GPP", conds_full,
    sm_series=SM_GPP, vpd_series=VPD_GPP, T_series=T_GPP, SSR_series=SSR_GPP,
)
df_VPD_GPP = build_if_frame(
    gpp_if, "VPD→GPP", "VPD", "GPP", conds_full,
    sm_series=SM_GPP, vpd_series=VPD_GPP, T_series=T_GPP, SSR_series=SSR_GPP,
)

# (column title, analysis frame) in the left-to-right order of the figure columns.
datasets = list(zip(
    ["SM→LAI", "VPD→LAI", "SM→GPP", "VPD→GPP"],
    [df_SM_LAI, df_VPD_LAI, df_SM_GPP, df_VPD_GPP],
))

# -------------------------
# Build combined figures
# -------------------------
row_panel_labels = ["(a)", "(b)", "(c)"]
col_titles = [d[0] for d in datasets]  # ["SM→LAI", "VPD→LAI", "SM→GPP", "VPD→GPP"]

# One entry per figure row: (bin column of the third factor, or None for the
# two-way model; bar labels in the order the bootstrap helpers return effects).
panel_specs = [
    (None,    ["SM", "VPD", "SM×VPD"]),
    ("T_q",   ["SM", "VPD", "T", "SM×VPD", "SM×T", "VPD×T", "SM×VPD×T"]),
    ("SSR_q", ["SM", "VPD", "SSR", "SM×VPD", "SM×SSR", "VPD×SSR", "SM×VPD×SSR"]),
]
n_rows, n_cols = len(panel_specs), len(datasets)


def _bootstrap_panel(df, third_col):
    """Run the two-way (``third_col is None``) or three-way bootstrap for one panel."""
    y = df["abs_dIF"].to_numpy()
    sm_q = df["SM_q"].to_numpy()
    vpd_q = df["VPD_q"].to_numpy()
    if third_col is None:
        return bootstrap_two_way(y, sm_q, vpd_q, B=B, qs=QS, seed=RSEED)
    return bootstrap_three_way(y, sm_q, vpd_q, df[third_col].to_numpy(),
                               B=B, qs=QS, seed=RSEED)


# Both figures plot the same estimates (same data, same seed), so each
# bootstrap is run once here and reused instead of being recomputed per figure.
# panel_results[col][row] == (w_means, w_cis, sd_means, sd_cis)
panel_results = [
    [_bootstrap_panel(df, third_col) for third_col, _ in panel_specs]
    for _, df in datasets
]

# (1) ω² panels; the SD strings are passed to bar_ci as annotations
fig_wsd, axes_wsd = plt.subplots(n_rows, n_cols, figsize=(16, 10), constrained_layout=True)

for col, results in enumerate(panel_results):
    for row, ((_, labels), (w_means, w_cis, sd_means, _)) in enumerate(zip(panel_specs, results)):
        bar_ci(axes_wsd[row, col], labels, w_means, w_cis, color="#4C72B0",
               ylabel="ω² (95% CI)", title="", ylim=(0,1),
               annotate=[f"SD={v:.3f}" for v in sd_means])
    # Column header goes on the top row only.
    axes_wsd[0, col].set_title(col_titles[col], fontsize=14)

# Add row labels (a), (b), (c) at top-left of each row
for r in range(n_rows):
    ax_row = axes_wsd[r, 0]
    ax_row.text(-0.15, 1.05, row_panel_labels[r],
                transform=ax_row.transAxes,
                fontsize=16, fontweight='bold',
                ha='right', va='bottom')

fig_wsd.savefig(OUT_COMBO_W_SD, dpi=400, bbox_inches="tight")
plt.show()

# (2) SD-only (sqrt(AEV)) — y-axis label only on the leftmost subplot of each row
YLAB_SD = "Explained SD of |ΔIF| (95% CI)"
SD_COLOR = "#ea841e"

fig_sd, axes_sd = plt.subplots(n_rows, n_cols, figsize=(16, 10), constrained_layout=True)

for col, results in enumerate(panel_results):
    for row, ((_, labels), (_, _, sd_means, sd_cis)) in enumerate(zip(panel_specs, results)):
        bar_ci(axes_sd[row, col], labels, sd_means, sd_cis, color=SD_COLOR,
               ylabel=(YLAB_SD if col == 0 else ""), title="")
    axes_sd[0, col].set_title(col_titles[col], fontsize=14)

# Row labels for the SD-only figure; placed slightly higher than in figure (1)
# and drawn unclipped so they stay visible outside the axes area.
for r in range(n_rows):
    ax_row = axes_sd[r, 0]
    ax_row.text(-0.07, 1.08, row_panel_labels[r],
                transform=ax_row.transAxes,
                fontsize=16, fontweight='bold',
                ha='right', va='bottom', clip_on=False)

fig_sd.savefig(OUT_COMBO_SD, dpi=400, bbox_inches="tight")
plt.show()

print("Saved 400 dpi:")
print(f" - {OUT_COMBO_W_SD}")
print(f" - {OUT_COMBO_SD}")
