In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# reloaded before each cell executes.
%load_ext autoreload
%autoreload 2
# Draw matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger used by the cells below; messages at INFO and above are emitted.
logger = getLogger()
logger.setLevel(INFO)

# Attach the stdout handler only when no handler exists yet, so re-running
# this cell does not add a second handler (which would duplicate every line).
if logger.hasHandlers() is False:
    stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

# Import

In [None]:
import copy
import os
import pathlib

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from IPython.display import display
from sklearn.linear_model import LinearRegression
from src.analysis.plot_helper import get_symmetrized_vmin_vmax
from src.analysis.time_series_preprocessor import preprocess
from src.data.io_matlab import read_matlab_time_series
from src.data.io_pickles import read_pickle, write_pickle
from src.information_theory.corr_helper import calc_lag_rhos
from src.information_theory.loos_klapp_2020 import (
    Config,
    calc_I_x_to_y,
    calc_I_y_to_x,
    calc_Qx,
    calc_Qy,
    calc_Sigmaxx,
    calc_Sigmayy,
)
from src.simulation.sde_coeff_estimator import (
    estimate_sde_coeffs_for_bcs,
    get_Xs_and_ys_for_bcs,
)
from src.utils.random_seed_helper import set_all_seeds
from statsmodels.tsa import stattools
from statsmodels.tsa.api import VAR
from tqdm.notebook import tqdm

# Notebook-wide figure styling: serif fonts plus matplotlib's built-in
# "tableau-colorblind10" style sheet.
plt.rcParams.update({"font.family": "serif"})
plt.style.use("tableau-colorblind10")

# Define constants

In [None]:
ROOT_DIR = str((pathlib.Path(os.environ["PYTHONPATH"]) / "..").resolve())
BCS_DATA_DIR = f"{ROOT_DIR}/data"

OBS_FILE_PATH = f"{BCS_DATA_DIR}/observations.mat"
GFDL_FILE_PATH = f"{BCS_DATA_DIR}/simulation_GFDLCM4C192.mat"

In [None]:
# Resolution passed to savefig for the JPG/EPS/PDF exports (the WEBP export
# is written without an explicit dpi).
DPI = 300

# Switches selecting which file formats each figure cell writes.
WRITE_WEBP = True
WRITE_EPS = False
WRITE_JPG = False
WRITE_PDF = False

# Output directory for saved figures, relative to the notebook's working
# directory; created here so later savefig calls cannot fail on a missing path.
FIG_DIR = "./fig"  # plain string: the original f-string had no placeholders
os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
# Display names for the two time-series columns.
LABELS_FOR_WBCS = dict(gulfstrm="Gulf Stream", kuroshio="Kuroshio")

# LaTeX labels, keyed either by time-series column or by the model-parameter
# names used with Config (Tx, Ty, a, b, rx, ry) and derived quantities.
MATH_LABELS_FOR_WBCS = dict(
    gulfstrm=r"$T_G$",
    kuroshio=r"$T_K$",
    Tx=r"$D_G$",
    Ty=r"$D_K$",
    a=r"$c_{G \leftarrow K}$",
    b=r"$c_{K \leftarrow G}$",
    Qx=r"$\dot{\sigma}_G$",
    Qy=r"$\dot{\sigma}_K$",
    Ix=r"$\dot{I}_{G \leftarrow K}$",
    Iy=r"$\dot{I}_{K \leftarrow G}$",
    rx=r"$r_G$",
    ry=r"$r_K$",
)

# Display names and line styles per data set.
LABELS_FOR_DATA = dict(
    [
        ("Observation", "OISST"),
        ("Simulation-GFDL-CM4C192", "GFDL-CM4C192"),
    ]
)

LINE_STYLES = dict(
    [
        ("Observation", "-"),
        ("Simulation-GFDL-CM4C192", "--"),
    ]
)

# tableau-colorblind10
COLOR_LIST = [
    "#006BA4", "#FF800E", "#ABABAB", "#595959", "#5F9ED1",
    "#C85200", "#898989", "#A2C8EC", "#FFBC79", "#CFCFCF",
]  # fmt: skip

# Colour per plotted item; each name is paired with an index into COLOR_LIST.
COLORS = dict(
    zip(
        (
            "gulfstrm",
            "kuroshio",
            "Observation",
            "Simulation-GFDL-CM4C192",
            "AR1",
            "VAR1",
            "CrossCorr",
        ),
        (COLOR_LIST[k] for k in (0, 1, 0, 1, 2, 2, 3)),
    )
)

# Time series

In [None]:
# Figure 2: each panel shows one time-series column (solid) together with the
# fitted values of a no-intercept linear regression on the lagged predictors
# (dashed, labelled "bivariate AR1"). Columns = data sets, rows = currents.
rm = 1
max_lag = 1
target_cols = ["gulfstrm", "kuroshio"]

df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

plt.rcParams["font.size"] = 17
fig, axes = plt.subplots(2, 2, sharey=True, sharex=True, figsize=[14, 5.5])

datasets = {
    "Observation": df_obs_time_series,
    "Simulation-GFDL-CM4C192": df_gfdl_time_series,
}
# y position (data coordinates) of the vertical current name written to the
# left of the first column, per row.
row_label_y = {0: -3.0, 1: -2.2}

for i_data, (data_name, df_time_series) in enumerate(datasets.items()):
    lagged, targets, years = get_Xs_and_ys_for_bcs(
        df_time_series, target_cols=target_cols, max_lag=max_lag, rm=rm
    )
    set_all_seeds()
    ar1_model = LinearRegression(fit_intercept=False)
    ar1_model.fit(lagged, targets)
    fitted = ar1_model.predict(lagged)

    for i_wbc, wbc in enumerate(target_cols):
        panel = axes[i_wbc, i_data]
        panel.yaxis.set_ticks_position("both")

        panel.plot(years, targets[:, i_wbc], "-", lw=1.0, color=COLORS[wbc])
        panel.plot(years, fitted[:, i_wbc], "--", lw=1.0, color=COLORS["AR1"])

        # In-panel legend text for the dashed curve.
        panel.text(
            1951,
            -3.3,
            "--- Estimations by bivariate AR1",
            fontsize="medium",
            color=COLORS["AR1"],
            fontweight="bold",
        )

        is_bottom_row = i_wbc == 1
        if is_bottom_row:
            panel.set_xlabel("Time [year]")
            panel.set_xticks(np.linspace(1950, 2020, 8))

        panel.set_xlim([1950, 2020])
        panel.xaxis.set_minor_locator(
            matplotlib.ticker.FixedLocator(np.linspace(1950, 2020, 15))
        )

        is_top_row = i_wbc == 0
        if is_top_row:
            panel.set_title(f"{LABELS_FOR_DATA[data_name]}, Standardized SST Anomaly")

        panel.set_ylabel(MATH_LABELS_FOR_WBCS[wbc])
        panel.set_ylim([-3.5, 3.5])
        panel.set_yticks(np.linspace(-3.0, 3.0, 5))
        panel.yaxis.set_minor_locator(
            matplotlib.ticker.FixedLocator(np.linspace(-3.5, 3.5, 15))
        )

        # Row caption, only once per row (first column).
        if i_data == 0 and i_wbc in row_label_y:
            panel.text(
                1932,
                row_label_y[i_wbc],
                LABELS_FOR_WBCS[wbc],
                rotation="vertical",
                fontsize="large",
            )

# Vertical black line drawn outside the top-left axes (clip_on=False), in that
# axes' data coordinates.
axes[0, 0].plot([2024, 2024], [5, -13], "k", lw=2, clip_on=False)

plt.tight_layout()
plt.show()

fig_name = "fig02"
# (enabled?, file extension, extra savefig keyword arguments)
export_plan = (
    (WRITE_WEBP, "webp", {}),
    (WRITE_JPG, "jpg", {"dpi": DPI}),
    (WRITE_EPS, "eps", {"dpi": DPI}),
    (WRITE_PDF, "pdf", {"dpi": DPI}),
)
for enabled, ext, extra_kwargs in export_plan:
    if enabled:
        fig.savefig(f"{FIG_DIR}/{fig_name}.{ext}", bbox_inches="tight", **extra_kwargs)

# Table of SDE coefficients

In [None]:
# Table of the estimated SDE coefficients for both data sets, each shown as
# "estimate [lower, upper]" where the bounds are bootstrap percentile limits.
N_BOOTSTRAP = 2000  # number of bootstrap replications expected in each pickle
alpha = 0.05  # two-sided level: bounds are the alpha/2 and 1 - alpha/2 quantiles
all_vars = ["rx", "ry", "a", "b", "Tx", "Ty"]
data_names = ["Observation", "Simulation-GFDL-CM4C192"]

df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

obs_model_params = estimate_sde_coeffs_for_bcs(df_obs_time_series)
gfd_model_params = estimate_sde_coeffs_for_bcs(df_gfdl_time_series)

df_table = pd.DataFrame(
    index=[MATH_LABELS_FOR_WBCS[v] for v in all_vars],
    columns=[LABELS_FOR_DATA[d] for d in data_names],
)

for data_name, model_params in zip(data_names, [obs_model_params, gfd_model_params]):
    dict_sde_params = read_pickle(f"./bootstrap_replications_{data_name}.pickle")
    cfg = copy.deepcopy(model_params["config"])

    all_sde_data = {v: [] for v in all_vars}
    cnt = 0  # replications flagged as non-decaying
    cnt_spiral = 0  # replications with a negative discriminant
    cnt_total = 0
    cnt_spiral_negative = 0  # negative discriminant and a < 0

    for param in dict_sde_params.values():
        cnt_total += 1
        for v in all_vars:
            all_sde_data[v].append(param[v])

        # NOTE(review): the interpretation below presumes the drift matrix is
        # [[-rx, a], [b, -ry]], so that tr and det as defined here decide
        # stability — confirm against the model definition.
        tr = param["rx"] + param["ry"]
        det = param["rx"] * param["ry"] - param["a"] * param["b"]
        dis = tr**2 - 4 * det

        # If dis (discriminant) is negative, the eigenmodes become oscillation
        # modes. Strictly negative: dis == 0 is a repeated real eigenvalue.
        if dis < 0:
            cnt_spiral += 1
            if param["a"] < 0:
                cnt_spiral_negative += 1

        # If det (determinant) is negative, the two eigenvalues have opposite
        # signs; if tr (trace) is negative, the eigenmodes are unstable.
        if tr <= 0 or det <= 0:
            cnt += 1

    logger.info(
        f"\n{data_name}, cnt (for non-decaying solutions) = {cnt}, "
        f"cnt (for spiral solutions) = {cnt_spiral}, total cnt = {cnt_total}"
    )
    logger.info(f"cnt for spiral and negative = {cnt_spiral_negative}\n")

    column = LABELS_FOR_DATA[data_name]
    for v in all_vars:
        row = MATH_LABELS_FOR_WBCS[v]
        samples = all_sde_data[v]
        assert (
            len(samples) == N_BOOTSTRAP
        ), f"{data_name}/{v}: expected {N_BOOTSTRAP} replications, got {len(samples)}"
        lower = np.quantile(samples, alpha / 2.0)
        upper = np.quantile(samples, 1 - alpha / 2.0)
        # All three numbers use the same precision (the lower bound used .5f before).
        df_table.loc[row, column] = f"{getattr(cfg, v):.4f} [{lower:.4f}, {upper:.4f}]"

display(df_table)
print(df_table.to_markdown())

# Noise PDFs

In [None]:
# Histograms of the regression residuals ("noise") for each data set (rows)
# and current (columns), overlaid with a Gaussian of the same mean/std, plus
# Kolmogorov-Smirnov and Shapiro-Wilk normality tests on the standardized
# residuals.
max_lag = 1  # was misspelled "ymax_lag", so the cell relied on a stale global
rm = 1
target_cols = ["gulfstrm", "kuroshio"]

df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 6))

for i, (data_name, df_time_series) in enumerate(
    zip(
        ["Observation", "Simulation-GFDL-CM4C192"],
        [df_obs_time_series, df_gfdl_time_series],
    )
):
    Xs, ys, ts = get_Xs_and_ys_for_bcs(
        df_time_series, target_cols, max_lag=max_lag, rm=rm
    )
    set_all_seeds()
    reg = LinearRegression(fit_intercept=False).fit(Xs, ys)
    residuals = ys - reg.predict(Xs)

    for j, col in enumerate(target_cols):
        ax = axes[i, j]

        # Nothing below mutates the residuals, so a plain view is sufficient.
        data = residuals[:, j]
        mean = np.mean(data)
        std = np.std(data, ddof=0)
        ax.hist(data, density=True, range=(-2.5, 2.5), bins=51)

        # Reference Gaussian with the sample mean and std. Dedicated names are
        # used so the regression targets `ys` are not overwritten in the loop.
        pdf_xs = np.linspace(-2.5, 2.5, 41, endpoint=True)
        pdf_ys = scipy.stats.norm.pdf(pdf_xs, loc=mean, scale=std)
        ax.plot(pdf_xs, pdf_ys, color="k", lw=2)

        ax.set_xlabel("Residual")
        ax.set_ylabel("PDF")

        # NOTE(review): mean and std are estimated from the same sample that is
        # tested, so the KS p-value against a standard normal is only
        # approximate (a Lilliefors-type correction would be stricter).
        standardized = (data - mean) / std
        ks_result = scipy.stats.kstest(standardized, scipy.stats.norm.cdf)
        _, shapiro_pval = scipy.stats.shapiro(standardized)

        print(
            f"\n{LABELS_FOR_WBCS[col]} ({LABELS_FOR_DATA[data_name]})\nKS test P value = {ks_result.pvalue:.5f}\nSW test P value = {shapiro_pval}\n"
        )

        ax.set_title(
            f"{LABELS_FOR_WBCS[col]} ({LABELS_FOR_DATA[data_name]})\nKS test P value = {ks_result.pvalue:.5f}"
        )

plt.tight_layout()
plt.show()

# Noise auto- and cross-correlation coefficients

In [None]:
# Figure 3: auto-correlations (columns 0-1) and cross-correlations (column 2)
# of the regression residuals, one row per data set, with confidence bands
# at level `alpha`.
max_lag = 1
rm = 1
n_months = 48
alpha = 0.05  # == 95% CI
target_cols = ["gulfstrm", "kuroshio"]

df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

plt.rcParams["font.size"] = 18
fig, axes = plt.subplots(2, 3, sharey=True, figsize=(15, 5.5))

for i, (data_name, df_time_series) in enumerate(
    zip(
        ["Observation", "Simulation-GFDL-CM4C192"],
        [df_obs_time_series, df_gfdl_time_series],
    )
):
    # Residuals of the no-intercept linear regression are treated as the noise.
    Xs, ys, ts = get_Xs_and_ys_for_bcs(
        df_time_series, target_cols, max_lag=max_lag, rm=rm
    )
    set_all_seeds()
    reg = LinearRegression(fit_intercept=False).fit(Xs, ys)
    preds = reg.predict(Xs)
    residuals = ys - preds

    # Filled in the loop below; needed afterwards for the cross-correlations.
    res_gulf, res_kuro = None, None
    for j, col in enumerate(target_cols):
        if col == "gulfstrm":
            res_gulf = residuals[:, j]
        if col == "kuroshio":
            res_kuro = residuals[:, j]

        ax = axes[i, j]
        ax.yaxis.set_ticks_position("both")
        # Thin vertical guides every 12 lags.
        for lag in np.arange(12, n_months, 12):
            ax.axvline(lag, ls="-", lw=0.5, color="gray")

        acf, cnf = stattools.acf(
            residuals[:, j],
            alpha=alpha,
            nlags=n_months,
            fft=False,
            adjusted=False,
            bartlett_confint=False,
        )
        # Subtract the ACF so the interval is centred on zero and can be used
        # as a band that the ACF itself is compared against.
        cnf = cnf - acf[..., None]
        lags = np.arange(len(acf))  # month

        # Log every non-zero lag whose ACF falls outside the zero-centred band.
        for _t in range(len(acf)):
            if _t == 0:
                continue
            if acf[_t] > cnf[_t, 1] or acf[_t] < cnf[_t, 0]:
                logger.info(
                    f"{col}: {_t:02}, acf = {acf[_t]:.3f} [{cnf[_t, 0]:.3f}, {cnf[_t, 1]:.3f}]"
                )

        ax.plot(lags, acf, ls="-", color=COLORS[col])
        # The band is drawn from lag 1 onwards (lag 0 is skipped).
        ax.fill_between(lags[1:], cnf[1:, 0], cnf[1:, 1], alpha=0.15, color="gray")

        ax.axhline(0.0, ls="-", lw=0.75, color="k")

        ax.set_ylim([-0.25, 0.25])
        ax.set_yticks(np.linspace(-0.2, 0.2, 5))
        ax.yaxis.set_minor_locator(
            matplotlib.ticker.FixedLocator(np.linspace(-0.25, 0.25, 11))
        )
        if j == 0:
            ax.set_ylabel("Corr.")

        ax.set_xlim([0, n_months])
        # Tick labels are shown only on the bottom row (empty list hides them).
        ax.set_xticks(np.linspace(0, n_months, 5), labels=None if i == 1 else [])
        ax.xaxis.set_minor_locator(
            matplotlib.ticker.FixedLocator(np.linspace(0, n_months, 17))
        )
        if i == 1:
            ax.set_xlabel("Lag [month]")
        if i == 0:
            v = r"$\xi_G$" if col == "gulfstrm" else r"$\xi_K$"
            ax.set_title(f"Auto-Correlations for {v}")

        # Vertical data-set name left of the first column; the y positions are
        # in data coordinates and differ per row.
        if j == 0 and i == 0:
            l = LABELS_FOR_DATA[data_name]
            ax.text(-18, -0.1, l, rotation="vertical", fontsize="large")
        if j == 0 and i == 1:
            l = LABELS_FOR_DATA[data_name]
            ax.text(-18, -0.5, l, rotation="vertical", fontsize="large")

    assert len(res_gulf) == len(res_kuro)

    # Cross-correlations in both argument orders. This one is plotted at
    # non-negative lags (the side annotated "xi_G leads" below); the swapped
    # call is plotted at non-positive lags.
    # NOTE(review): lead/lag direction follows statsmodels' ccf convention — confirm.
    gulf_ccf, gulf_cnf = stattools.ccf(
        x=res_kuro, y=res_gulf, adjusted=False, fft=False, nlags=n_months, alpha=alpha
    )
    # Same zero-centring of the confidence interval as for the ACF.
    gulf_cnf = gulf_cnf - gulf_ccf[..., None]

    kuro_ccf, kuro_cnf = stattools.ccf(
        x=res_gulf, y=res_kuro, adjusted=False, fft=False, nlags=n_months, alpha=alpha
    )
    kuro_cnf = kuro_cnf - kuro_ccf[..., None]

    ax = axes[i, 2]
    ax.yaxis.set_ticks_position("both")

    for lag in np.arange(-n_months, n_months, 12):
        if lag == 0:
            continue
        ax.axvline(lag, ls="-", lw=0.5, color="gray")

    # `ts` is reused here as the lag axis: 0 .. n_months-1 for one direction ...
    ts = +np.arange(n_months)
    ax.plot(ts, gulf_ccf, "-", color=COLORS["CrossCorr"])
    ax.fill_between(ts, gulf_cnf[:, 0], gulf_cnf[:, 1], color="gray", alpha=0.15)

    # ... and 0 .. -(n_months-1) for the other.
    ts = -np.arange(n_months)
    ax.plot(ts, kuro_ccf, "-", color=COLORS["CrossCorr"])
    ax.fill_between(ts, kuro_cnf[:, 0], kuro_cnf[:, 1], color="gray", alpha=0.15)

    ax.axhline(0.0, ls="-", lw=0.75, color="k")
    ax.axvline(0.0, ls="-", lw=0.75, color="k")

    print("Max Cross Corr = ", np.max(np.abs(gulf_ccf)), np.max(np.abs(kuro_ccf)))

    ax.set_ylim([-0.25, 0.25])
    ax.set_yticks(np.linspace(-0.2, 0.2, 5))
    ax.yaxis.set_minor_locator(
        matplotlib.ticker.FixedLocator(np.linspace(-0.25, 0.25, 11))
    )

    # Direction captions on the bottom row; y = -0.5 lies below the y-limits
    # set above.
    if i == 1:
        ax.text(+30, -0.5, r"$\xi_G$ leads", fontsize="medium")
        ax.text(-50, -0.5, r"$\xi_K$ leads", fontsize="medium")

    ax.set_xlim([-n_months, n_months])
    ax.set_xticks(np.linspace(-n_months, n_months, 5), labels=None if i == 1 else [])
    ax.xaxis.set_minor_locator(
        matplotlib.ticker.FixedLocator(np.linspace(-n_months, n_months, 9))
    )
    if i == 1:
        ax.set_xlabel("Lag [month]")
    if i == 0:
        ax.set_title(r"Cross-Correlations for $\xi_G$ and $\xi_K$")


# Horizontal black line in the data coordinates of the bottom-left axes, drawn
# unclipped so it can extend beyond that axes.
# NOTE(review): appears to act as a divider between the two rows — confirm.
ax = axes[1, 0]
xs = [-19, 165]
ys = [0.292, 0.292]
ax.plot(xs, ys, "k", lw=2, clip_on=False)

plt.tight_layout()
plt.show()

# Export in every format enabled by the WRITE_* flags.
fig_name = "fig03"
if WRITE_WEBP:
    fig.savefig(f"{FIG_DIR}/{fig_name}.webp", bbox_inches="tight")
if WRITE_JPG:
    fig.savefig(f"{FIG_DIR}/{fig_name}.jpg", bbox_inches="tight", dpi=DPI)
if WRITE_EPS:
    fig.savefig(f"{FIG_DIR}/{fig_name}.eps", bbox_inches="tight", dpi=DPI)
if WRITE_PDF:
    fig.savefig(f"{FIG_DIR}/{fig_name}.pdf", bbox_inches="tight", dpi=DPI)

# Lag correlation including CI

In [None]:
# Lag-correlation curves for every bootstrap replication, reduced to a mean
# and an alpha-level percentile envelope per data set. The result is cached in
# a pickle next to the notebook.
# NOTE(review): the cache is keyed by file name only; after changing `lags` or
# `alpha`, the pickle has to be deleted by hand to force a recomputation.
lags = np.linspace(-25, 25, 251)
alpha = 0.05
dict_results = {}

cache_path = "./all_rhos_from_bootstrap.pickle"

if not os.path.exists(cache_path):
    lower_q = alpha / 2.0
    upper_q = 1.0 - alpha / 2.0

    for data_name in ("Simulation-GFDL-CM4C192", "Observation"):
        dict_params = read_pickle(f"./bootstrap_replications_{data_name}.pickle")

        # One row per replication; each curve is reversed along the lag axis.
        all_rhos = np.stack(
            [
                calc_lag_rhos(
                    Config(
                        **{k: param[k] for k in ("Tx", "Ty", "a", "b", "rx", "ry")}
                    ),
                    lags,
                )[::-1]
                for param in tqdm(dict_params.values())
            ],
            axis=0,
        )

        dict_results[data_name] = dict(
            mean_rhos=np.mean(all_rhos, axis=0),
            min_rhos=np.quantile(all_rhos, q=lower_q, axis=0),
            max_rhos=np.quantile(all_rhos, q=upper_q, axis=0),
        )

    write_pickle(dict_results, cache_path)
else:
    dict_results = read_pickle(cache_path)

In [None]:
# Figure 4: bootstrap-mean lag correlation per data set (thin dashed), with a
# thick solid overlay where the lower percentile bound is at least 5e-3.
# NOTE(review): only the lower bound is tested, so lags whose whole interval
# is negative are not highlighted — confirm this is intended.
plt.rcParams["font.size"] = 16
fig, ax = plt.subplots(figsize=[5, 4])
ax.yaxis.set_ticks_position("both")

for data_name in ("Simulation-GFDL-CM4C192", "Observation"):
    summary = dict_results[data_name]
    mean_rhos = summary["mean_rhos"]
    min_rhos = summary["min_rhos"]
    max_rhos = summary["max_rhos"]
    significant_rhos = np.where(min_rhos >= 5e-3, mean_rhos, np.nan)

    c = COLOR_LIST[4] if data_name == "Observation" else COLOR_LIST[5]

    ax.plot(lags, mean_rhos, lw=1.5, ls="--", color=c)
    ax.plot(lags, significant_rhos, lw=3, color=c, label=LABELS_FOR_DATA[data_name])

ax.legend(loc="lower left", bbox_to_anchor=(1.05, 0.0), edgecolor="k")
for draw_zero_line in (ax.axhline, ax.axvline):
    draw_zero_line(0.0, ls="-", color="gray", lw=0.8)
ax.set_ylabel("Lag Correlation")

n_months = 24
ax.set_xlim([-n_months, n_months])
ax.set_xticks(np.linspace(-n_months, n_months, 5))
ax.xaxis.set_minor_locator(
    matplotlib.ticker.FixedLocator(np.linspace(-n_months, n_months, 17))
)
ax.set_xlabel("Lag [month]")
ax.set_title(
    r"   $\leftarrow$ Kuroshio leads       Gulf Stream leads$\rightarrow$",
    fontsize="medium",
)

plt.show()

fig_name = "fig04"
# (enabled?, file extension, extra savefig keyword arguments)
export_plan = (
    (WRITE_WEBP, "webp", {}),
    (WRITE_JPG, "jpg", {"dpi": DPI}),
    (WRITE_EPS, "eps", {"dpi": DPI}),
    (WRITE_PDF, "pdf", {"dpi": DPI}),
)
for enabled, ext, extra_kwargs in export_plan:
    if enabled:
        fig.savefig(f"{FIG_DIR}/{fig_name}.{ext}", bbox_inches="tight", **extra_kwargs)

# Regime diagrams without corrs

In [None]:
# Regime diagrams, one figure per current: for each data set (rows) the three
# quantities Q, I and V are evaluated on a grid of the two coupling
# coefficients (x axis: b, y axis: a), with all other parameters fixed at the
# estimated values. V is the difference of sqrt(calc_V) between the full
# configuration and the one with the incoming coupling set to zero.
df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

obs_model_params = estimate_sde_coeffs_for_bcs(df_obs_time_series)
gfd_model_params = estimate_sde_coeffs_for_bcs(df_gfdl_time_series)

# The coupling grid and tick layout are identical for every panel, so they
# are built once instead of once per loop iteration.
lst_a = np.linspace(-0.10, 0.20, num=76, endpoint=True)
lst_b = np.linspace(-0.10, 0.20, num=76, endpoint=True)
grd_b, grd_a = np.meshgrid(lst_b, lst_a, indexing="ij")

max_x = 0.20
ticks = np.linspace(-0.05, max_x, 6, endpoint=True)

# y position (data coordinates) of the vertical data-set name, per row.
row_label_y = {0: 0.05, 1: -0.07}

for target_name in ["gulfstrm", "kuroshio"]:
    if target_name == "gulfstrm":
        lb_Q = r"$\dot{\sigma}_G$"
        lb_I = r"$\dot{I}_{G \leftarrow K}$"
        lb_V = r"$\sqrt{\langle T_G^2 \rangle} - \sqrt{\langle T_G^2 \rangle}_{c_{G \leftarrow K}=0}$"
        calc_Q = calc_Qx
        calc_I = calc_I_y_to_x
        calc_V = calc_Sigmaxx
        is_a_zero = True
    elif target_name == "kuroshio":
        lb_Q = r"$\dot{\sigma}_K$"
        lb_I = r"$\dot{I}_{K \leftarrow G}$"
        lb_V = r"$\sqrt{\langle T_K^2 \rangle} - \sqrt{\langle T_K^2 \rangle}_{c_{K \leftarrow G}=0}$"
        calc_Q = calc_Qy
        calc_I = calc_I_x_to_y
        calc_V = calc_Sigmayy
        is_a_zero = False
    else:
        raise ValueError(f"unknown target_name: {target_name!r}")

    plt.rcParams["font.size"] = 18
    fig, axes = plt.subplots(2, 3, figsize=[12, 7.0], sharex=True, sharey=True)

    for i_row, (data_name, model_params) in enumerate(
        zip(
            ["Observation", "Simulation-GFDL-CM4C192"],
            [obs_model_params, gfd_model_params],
        )
    ):
        # Bootstrap spread of the two couplings, used for the error bars. It is
        # computed once per data set rather than once per axes.
        dict_sde_params = read_pickle(f"./bootstrap_replications_{data_name}.pickle")
        a_std = np.std([param["a"] for param in dict_sde_params.values()], ddof=1)
        b_std = np.std([param["b"] for param in dict_sde_params.values()], ddof=1)

        cfg = copy.deepcopy(model_params["config"])

        grd_Q, grd_I, grd_V = (np.zeros_like(grd_b) for _ in range(3))

        for i in range(grd_b.shape[0]):
            for j in range(grd_b.shape[1]):
                a_ij, b_ij = grd_a[i, j], grd_b[i, j]
                c1 = Config(Tx=cfg.Tx, Ty=cfg.Ty, a=a_ij, b=b_ij, rx=cfg.rx, ry=cfg.ry)
                # Reference configuration with the incoming coupling removed.
                c2 = Config(
                    Tx=cfg.Tx,
                    Ty=cfg.Ty,
                    a=0.0 if is_a_zero else a_ij,
                    b=b_ij if is_a_zero else 0.0,
                    rx=cfg.rx,
                    ry=cfg.ry,
                )

                grd_Q[i, j] = calc_Q(c1)
                grd_I[i, j] = calc_I(c1)
                grd_V[i, j] = np.sqrt(calc_V(c1)) - np.sqrt(calc_V(c2))

        # (values, symmetric colour limit, colormap, title, title kwargs)
        panels = [
            (grd_Q, 0.01, "seismic", lb_Q, {}),
            (grd_I, 0.03, "twilight_shifted", lb_I, {}),
            (grd_V, 0.20, "seismic", lb_V, {"fontsize": "small"}),
        ]
        for j_col, (values, vlim, cmap, title, title_kwargs) in enumerate(panels):
            ax = axes[i_row, j_col]
            vmin, vmax = -vlim, vlim
            mesh = ax.pcolormesh(
                grd_b, grd_a, values, vmin=vmin, vmax=vmax, cmap=cmap
            )
            cbar = fig.colorbar(mesh, ax=ax, extend="both")
            cbar.set_ticks([vmin, 0.0, vmax])
            ax.set_title(title, **title_kwargs)

            if j_col == 0:
                ax.set_ylabel(r"$c_{G \leftarrow K}$")
                if i_row in row_label_y:
                    ax.text(
                        -0.16,
                        row_label_y[i_row],
                        LABELS_FOR_DATA[data_name],
                        rotation="vertical",
                        fontsize="large",
                    )

        for ax in axes[i_row, :]:
            # Dashed reference line a = b Tx/Ty
            ax.plot(lst_b, cfg.Tx / cfg.Ty * lst_b, ls="--", color="gray")
            if i_row == 1:
                ax.set_xlabel(r"$c_{K \leftarrow G}$")
            ax.xaxis.set_ticks_position("both")
            ax.yaxis.set_ticks_position("both")
            ax.axvline(0, color="k")
            ax.axhline(0, color="k")

            # Estimated couplings with +/- one bootstrap standard deviation.
            ax.errorbar(
                [cfg.b],
                [cfg.a],
                xerr=[b_std],
                yerr=[a_std],
                marker="o",
                elinewidth=1.5,
                ecolor="k",
                color="w",
                capsize=4,
                capthick=1.5,
                markerfacecolor="w",
                markeredgecolor="k",
                markersize=8,
            )

            ax.set_xticks(np.linspace(0, max_x, 3, endpoint=True))
            ax.set_yticks(np.linspace(0, max_x, 3, endpoint=True))
            ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(ticks))
            ax.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(ticks))
            ax.set_xlim(0.0, max_x)
            ax.set_ylim(-0.05, max_x)
            ax.set_aspect("equal")

    plt.suptitle(f"{LABELS_FOR_WBCS[target_name]}")
    plt.tight_layout()
    plt.show()

    # Saving is currently disabled for these per-current figures:
    # fig_name = f"regime_diagrams_without_corr_{target_name}"
    # if WRITE_WEBP:
    #     fig.savefig(f"{FIG_DIR}/{fig_name}.webp", bbox_inches="tight")
    # if WRITE_EPS:
    #     fig.savefig(f"{FIG_DIR}/{fig_name}.eps", bbox_inches="tight", dpi=DPI)
    # if WRITE_JPG:
    #     fig.savefig(f"{FIG_DIR}/{fig_name}.jpg", bbox_inches="tight", dpi=DPI)

# Regime diagrams without corr (combined figure).

In [None]:
# Combined regime diagrams (fig05: observation, fig06: simulation): one figure
# per data set with one row per current and one column per quantity (Q, I, V),
# evaluated on a grid of the two coupling coefficients (x axis: b, y axis: a)
# with all other parameters fixed at the estimated values.
df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

obs_model_params = estimate_sde_coeffs_for_bcs(df_obs_time_series)
gfd_model_params = estimate_sde_coeffs_for_bcs(df_gfdl_time_series)

# The coupling grid and tick layout are identical for every panel, so they
# are built once instead of once per loop iteration.
lst_a = np.linspace(-0.10, 0.20, num=76, endpoint=True)
lst_b = np.linspace(-0.10, 0.20, num=76, endpoint=True)
grd_b, grd_a = np.meshgrid(lst_b, lst_a, indexing="ij")

max_x = 0.20
ticks = np.linspace(-0.05, max_x, 6, endpoint=True)

for data_name, model_params in zip(
    ["Observation", "Simulation-GFDL-CM4C192"],
    [obs_model_params, gfd_model_params],
):
    plt.rcParams["font.size"] = 20
    fig, axes = plt.subplots(2, 3, figsize=[12, 8.0], sharex=True, sharey=True)

    # Everything below depends on the data set only, so it is loaded once per
    # figure instead of once per current (and once per axes for the std).
    cfg = copy.deepcopy(model_params["config"])
    dict_sde_params = read_pickle(f"./bootstrap_replications_{data_name}.pickle")
    a_std = np.std([param["a"] for param in dict_sde_params.values()], ddof=1)
    b_std = np.std([param["b"] for param in dict_sde_params.values()], ddof=1)

    for i_row, target_name in enumerate(["gulfstrm", "kuroshio"]):
        if target_name == "gulfstrm":
            lb_Q = r"(a) $\dot{\sigma}_G$"
            lb_I = r"(b) $\dot{I}_{G \leftarrow K}$"
            lb_V = r"(c) $\sqrt{\langle T_G^2 \rangle} - \sqrt{\langle T_G^2 \rangle}\vert_{c_{G \leftarrow K}=0}$"
            calc_Q = calc_Qx
            calc_I = calc_I_y_to_x
            calc_V = calc_Sigmaxx
            is_a_zero = True
        elif target_name == "kuroshio":
            lb_Q = r"(d) $\dot{\sigma}_K$"
            lb_I = r"(e) $\dot{I}_{K \leftarrow G}$"
            lb_V = r"(f) $\sqrt{\langle T_K^2 \rangle} - \sqrt{\langle T_K^2 \rangle}\vert_{c_{K \leftarrow G}=0}$"
            calc_Q = calc_Qy
            calc_I = calc_I_x_to_y
            calc_V = calc_Sigmayy
            is_a_zero = False
        else:
            raise ValueError(f"unknown target_name: {target_name!r}")

        grd_Q, grd_I, grd_V = (np.zeros_like(grd_b) for _ in range(3))

        for i in range(grd_b.shape[0]):
            for j in range(grd_b.shape[1]):
                a_ij, b_ij = grd_a[i, j], grd_b[i, j]
                c1 = Config(Tx=cfg.Tx, Ty=cfg.Ty, a=a_ij, b=b_ij, rx=cfg.rx, ry=cfg.ry)
                # Reference configuration with the incoming coupling removed.
                c2 = Config(
                    Tx=cfg.Tx,
                    Ty=cfg.Ty,
                    a=0.0 if is_a_zero else a_ij,
                    b=b_ij if is_a_zero else 0.0,
                    rx=cfg.rx,
                    ry=cfg.ry,
                )

                grd_Q[i, j] = calc_Q(c1)
                grd_I[i, j] = calc_I(c1)
                grd_V[i, j] = np.sqrt(calc_V(c1)) - np.sqrt(calc_V(c2))

        # (values, symmetric colour limit, colormap, title, title kwargs)
        panels = [
            (grd_Q, 0.01, "seismic", lb_Q, {}),
            (grd_I, 0.03, "twilight_shifted", lb_I, {}),
            (grd_V, 0.10, "bwr", lb_V, {"fontsize": 20}),
        ]
        for j_col, (values, vlim, cmap, title, title_kwargs) in enumerate(panels):
            ax = axes[i_row, j_col]
            vmin, vmax = -vlim, vlim
            mesh = ax.pcolormesh(
                grd_b, grd_a, values, vmin=vmin, vmax=vmax, cmap=cmap
            )
            cbar = fig.colorbar(mesh, ax=ax, extend="both")
            cbar.set_ticks([vmin, 0.0, vmax])
            ax.set_title(title, **title_kwargs)
            if j_col == 0:
                ax.set_ylabel(r"$c_{G \leftarrow K}$")

        for ax in axes[i_row, :]:
            # Dashed reference line a = b Tx/Ty
            ax.plot(lst_b, cfg.Tx / cfg.Ty * lst_b, ls="--", color="gray")
            if i_row == 1:
                ax.set_xlabel(r"$c_{K \leftarrow G}$")
            ax.xaxis.set_ticks_position("both")
            ax.yaxis.set_ticks_position("both")
            ax.axvline(0, color="k")
            ax.axhline(0, color="k")

            # Estimated couplings with +/- one bootstrap standard deviation.
            ax.errorbar(
                [cfg.b],
                [cfg.a],
                xerr=[b_std],
                yerr=[a_std],
                marker="o",
                elinewidth=1.5,
                ecolor="k",
                color="w",
                capsize=4,
                capthick=1.5,
                markerfacecolor="w",
                markeredgecolor="k",
                markersize=8,
            )

            ax.set_xticks(np.linspace(0, max_x, 3, endpoint=True))
            ax.set_yticks(np.linspace(0, max_x, 3, endpoint=True))
            ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(ticks))
            ax.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(ticks))
            ax.set_xlim(0.0, max_x)
            ax.set_ylim(-0.05, max_x)
            ax.set_aspect("equal")

    plt.suptitle(LABELS_FOR_DATA[data_name], fontweight="bold")
    plt.tight_layout()
    plt.show()

    fig_name = "fig05" if data_name == "Observation" else "fig06"
    if WRITE_WEBP:
        fig.savefig(f"{FIG_DIR}/{fig_name}.webp", bbox_inches="tight")
    if WRITE_EPS:
        fig.savefig(f"{FIG_DIR}/{fig_name}.eps", bbox_inches="tight", dpi=DPI)
    if WRITE_JPG:
        fig.savefig(f"{FIG_DIR}/{fig_name}.jpg", bbox_inches="tight", dpi=DPI)
    if WRITE_PDF:
        fig.savefig(f"{FIG_DIR}/{fig_name}.pdf", bbox_inches="tight", dpi=DPI)

# Regime diagrams in the case of $r_K = 0$

In [None]:
# Regime diagrams with ry fixed to 0.
# For each dataset a 2 x 3 figure is drawn on a grid of the coupling
# coefficients a (labelled c_{G<-K}, y axis) and b (labelled c_{K<-G}, x axis):
# one row per target current, one column per plotted quantity (lb_Q, lb_I, lb_V).
df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

obs_model_params = estimate_sde_coeffs_for_bcs(df_obs_time_series)
gfd_model_params = estimate_sde_coeffs_for_bcs(df_gfdl_time_series)

lb_x = "gulfstrm"
lb_y = "kuroshio"

# Coupling-coefficient grid shared by every panel of every figure.
lst_a = np.linspace(-0.10, 0.20, num=76, endpoint=True)
lst_b = np.linspace(-0.10, 0.20, num=76, endpoint=True)
grd_b, grd_a = np.meshgrid(lst_b, lst_a, indexing="ij")

# Visible axis extent and minor-tick positions shared by every panel.
max_x = 0.20
ticks = np.linspace(-0.05, max_x, 6, endpoint=True)

for data_name, model_params in zip(
    ["Observation", "Simulation-GFDL-CM4C192"],
    [obs_model_params, gfd_model_params],
):
    cfg = copy.deepcopy(model_params["config"])

    # The bootstrap replications depend only on the dataset, so they are read
    # once here instead of once per target current.
    pickle_path = f"./bootstrap_replications_{data_name}.pickle"
    dict_sde_params = read_pickle(pickle_path)
    all_sde_data = {
        v: [param[v] for param in dict_sde_params.values()] for v in ["a", "b"]
    }

    # Error bars show +/- one sample standard deviation of the replications.
    # (A 2.5 %-97.5 % percentile interval around cfg.a / cfg.b is a possible
    # alternative; it is not used here.)
    a_std = np.std(all_sde_data["a"], ddof=1)
    b_std = np.std(all_sde_data["b"], ddof=1)

    plt.rcParams["font.size"] = 20
    fig, axes = plt.subplots(2, 3, figsize=[12, 8.0], sharex=True, sharey=True)

    for i_row, target_name in enumerate(["gulfstrm", "kuroshio"]):
        # is_a_zero selects which coupling is switched off in the reference
        # configuration used for the third column (a for the Gulf Stream row,
        # b for the Kuroshio row).
        if target_name == "gulfstrm":
            lb_Q = r"(a) $\dot{\sigma}_G$"
            lb_I = r"(b) $\dot{I}_{G \leftarrow K}$"
            lb_V = r"(c) $\sqrt{\langle T_G^2 \rangle} - \sqrt{\langle T_G^2 \rangle}\vert_{c_{G \leftarrow K}=0}$"
            calc_Q = calc_Qx
            calc_I = calc_I_y_to_x
            calc_V = calc_Sigmaxx
            is_a_zero = True
        elif target_name == "kuroshio":
            lb_Q = r"(d) $\dot{\sigma}_K$"
            lb_I = r"(e) $\dot{I}_{K \leftarrow G}$"
            lb_V = r"(f) $\sqrt{\langle T_K^2 \rangle} - \sqrt{\langle T_K^2 \rangle}\vert_{c_{K \leftarrow G}=0}$"
            calc_Q = calc_Qy
            calc_I = calc_I_x_to_y
            calc_V = calc_Sigmayy
            is_a_zero = False
        else:
            raise ValueError(f"Unknown target_name: {target_name!r}")

        grd_Q = np.zeros_like(grd_b)
        grd_I = np.zeros_like(grd_b)
        grd_V = np.zeros_like(grd_b)

        for i in range(grd_b.shape[0]):
            for j in range(grd_b.shape[1]):
                # Configuration at this grid point; ry is forced to 0.
                c1 = Config(
                    Tx=cfg.Tx,
                    Ty=cfg.Ty,
                    a=grd_a[i, j],
                    b=grd_b[i, j],
                    rx=cfg.rx,
                    ry=0.0,
                )
                # Reference configuration: identical except that the selected
                # coupling coefficient is set to zero.
                c2 = Config(
                    Tx=cfg.Tx,
                    Ty=cfg.Ty,
                    a=0.0 if is_a_zero else grd_a[i, j],
                    b=grd_b[i, j] if is_a_zero else 0.0,
                    rx=cfg.rx,
                    ry=0.0,
                )

                grd_Q[i, j] = calc_Q(c1)
                grd_I[i, j] = calc_I(c1)
                grd_V[i, j] = np.sqrt(calc_V(c1)) - np.sqrt(calc_V(c2))

        # (grid, title, symmetric colour limit, colormap, extra title kwargs)
        panels = [
            (grd_Q, lb_Q, 0.01, "seismic", {}),
            (grd_I, lb_I, 0.03, "twilight_shifted", {}),
            (grd_V, lb_V, 0.10, "bwr", {"fontsize": 20}),
        ]
        for j_col, (grd, title, vlim, cmap, title_kwargs) in enumerate(panels):
            ax = axes[i_row, j_col]
            vmin, vmax = -vlim, vlim
            cnt = ax.pcolormesh(grd_b, grd_a, grd, vmin=vmin, vmax=vmax, cmap=cmap)
            cbar = fig.colorbar(cnt, ax=ax, extend="both")
            cbar.set_ticks([vmin, 0.0, vmax])
            ax.set_title(title, **title_kwargs)
            if j_col == 0:
                ax.set_ylabel(r"$c_{G \leftarrow K}$")

        for ax in np.ravel(axes[i_row, :]):
            # Dashed guide line: a = b * Tx / Ty.
            ax.plot(lst_b, cfg.Tx / cfg.Ty * lst_b, ls="--", color="gray")
            if i_row == 1:
                ax.set_xlabel(r"$c_{K \leftarrow G}$")
            ax.xaxis.set_ticks_position("both")
            ax.yaxis.set_ticks_position("both")
            ax.axvline(0, color="k")
            ax.axhline(0, color="k")

            # Estimated (b, a) of this dataset with bootstrap error bars.
            ax.errorbar(
                [cfg.b],
                [cfg.a],
                xerr=[b_std],
                yerr=[a_std],
                marker="o",
                elinewidth=1.5,
                ecolor="k",
                color="w",
                capsize=4,
                capthick=1.5,
                markerfacecolor="w",
                markeredgecolor="k",
                markersize=8,
            )

            ax.set_xticks(np.linspace(0, max_x, 3, endpoint=True))
            ax.set_yticks(np.linspace(0, max_x, 3, endpoint=True))
            ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(ticks))
            ax.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(ticks))
            ax.set_xlim(0.0, max_x)
            ax.set_ylim(-0.05, max_x)
            ax.set_aspect("equal")

    plt.suptitle(LABELS_FOR_DATA[data_name], fontweight="bold")
    plt.tight_layout()
    plt.show()

    # NOTE: this figure is only displayed; the export below is disabled.
    # fig_name = f"regime_diagrams_without_corr_combined_{LABELS_FOR_DATA[data_name]}"
    # if WRITE_WEBP:
    #     fig.savefig(f"{FIG_DIR}/{fig_name}.webp", bbox_inches="tight")
    # if WRITE_EPS:
    #     fig.savefig(f"{FIG_DIR}/{fig_name}.eps", bbox_inches="tight", dpi=DPI)
    # if WRITE_JPG:
    #     fig.savefig(f"{FIG_DIR}/{fig_name}.jpg", bbox_inches="tight", dpi=DPI)

# Regime diagrams for corrs and lags

## Calc corr

In [None]:
# Compute, for every (dataset, target current) pair, the regime-diagram grids
# (Q, I, V, peak lag correlation and its lag) over the coupling-coefficient
# grid, and cache them in a pickle so that finished pairs are not recomputed.
plt_pickle_path = "./data_regime_diagrams.pickle"

df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

obs_model_params = estimate_sde_coeffs_for_bcs(df_obs_time_series)
gfd_model_params = estimate_sde_coeffs_for_bcs(df_gfdl_time_series)

lb_x = "gulfstrm"
lb_y = "kuroshio"

# Lags at which the lag correlation is evaluated, and the (b, a) grid.
# The plotting cell below reuses lst_b, grd_a and grd_b.
lst_ts = np.linspace(-12, 12, 241)
lst_a = np.linspace(-0.10, 0.20, num=76, endpoint=True)
lst_b = np.linspace(-0.10, 0.20, num=76, endpoint=True)
grd_b, grd_a = np.meshgrid(lst_b, lst_a, indexing="ij")

if os.path.exists(plt_pickle_path):
    dict_plt_data = read_pickle(plt_pickle_path)
    logger.info("Read pickle")
else:
    dict_plt_data = {}
    logger.info("Create an empty dict")

for target_name in ["gulfstrm", "kuroshio"]:
    # Select the x- or y-side functions; is_a_zero selects which coupling is
    # switched off in the reference configuration used for grd_V.
    if target_name == "gulfstrm":
        calc_Q = calc_Qx
        calc_I = calc_I_y_to_x
        calc_V = calc_Sigmaxx
        is_a_zero = True
    elif target_name == "kuroshio":
        calc_Q = calc_Qy
        calc_I = calc_I_x_to_y
        calc_V = calc_Sigmayy
        is_a_zero = False
    else:
        raise ValueError(f"Unknown target_name: {target_name!r}")

    for data_name, model_params in zip(
        ["Observation", "Simulation-GFDL-CM4C192"], [obs_model_params, gfd_model_params]
    ):
        dict_key = f"{data_name}_{target_name}"
        if dict_key in dict_plt_data:
            # Already present in the cache: skip the expensive grid scan.
            continue

        cfg = copy.deepcopy(model_params["config"])

        grd_Q = np.zeros_like(grd_b)
        grd_I = np.zeros_like(grd_b)
        grd_V = np.zeros_like(grd_b)
        # NaN marks grid points where no interior peak of the lag correlation
        # is found.
        grd_lag = np.full_like(grd_b, np.nan)
        grd_rho = np.full_like(grd_b, np.nan)

        for i in tqdm(range(grd_b.shape[0])):
            for j in range(grd_b.shape[1]):
                c1 = Config(
                    Tx=cfg.Tx,
                    Ty=cfg.Ty,
                    a=grd_a[i, j],
                    b=grd_b[i, j],
                    rx=cfg.rx,
                    ry=cfg.ry,
                )

                lst_rhos = calc_lag_rhos(conf=c1, lst_lags=lst_ts)
                peak_indices, _ = scipy.signal.find_peaks(lst_rhos)
                assert (
                    len(peak_indices) <= 1
                ), f"Expected at most one peak, found {len(peak_indices)}"

                # Reference configuration with the selected coupling set to 0.
                c2 = Config(
                    Tx=cfg.Tx,
                    Ty=cfg.Ty,
                    a=0.0 if is_a_zero else grd_a[i, j],
                    b=grd_b[i, j] if is_a_zero else 0.0,
                    rx=cfg.rx,
                    ry=cfg.ry,
                )

                grd_Q[i, j] = calc_Q(c1)
                grd_I[i, j] = calc_I(c1)
                grd_V[i, j] = np.sqrt(calc_V(c1)) - np.sqrt(calc_V(c2))

                if len(peak_indices) == 1:
                    lag_idx = peak_indices[0]
                    grd_rho[i, j] = lst_rhos[lag_idx]
                    # The lag is stored with its sign flipped relative to lst_ts.
                    grd_lag[i, j] = -lst_ts[lag_idx]

        dict_plt_data[dict_key] = {
            "Q": grd_Q,
            "I": grd_I,
            "V": grd_V,
            "rho": grd_rho,
            "lag": grd_lag,
        }
        # Checkpoint after every finished pair so an interrupted run can resume.
        write_pickle(dict_plt_data, plt_pickle_path)

## Plot

In [None]:
# Plot the cached peak lag correlation and its lag on the (b, a) grid:
# one row per dataset, left column = correlation, right column = lag.
# Relies on grd_a, grd_b, lst_b, dict_plt_data and the model parameters
# defined in the "Calc corr" cell above.
for target_name in ["gulfstrm"]:
    if target_name not in ("gulfstrm", "kuroshio"):
        raise ValueError(f"Unknown target_name: {target_name!r}")

    plt.rcParams["font.size"] = 18
    fig, axes = plt.subplots(2, 2, figsize=[7.0, 6], sharex=True, sharey=True)

    # Visible axis extent and minor-tick positions shared by every panel.
    max_x = 0.20
    ticks = np.linspace(-0.05, max_x, 6, endpoint=True)

    for i_row, (data_name, model_params) in enumerate(
        zip(
            ["Observation", "Simulation-GFDL-CM4C192"],
            [obs_model_params, gfd_model_params],
        )
    ):
        cfg = copy.deepcopy(model_params["config"])

        dict_key = f"{data_name}_{target_name}"
        plt_data = dict_plt_data[dict_key]

        dict_sde_params = read_pickle(f"./bootstrap_replications_{data_name}.pickle")
        all_sde_data = {
            v: [param[v] for param in dict_sde_params.values()] for v in ["a", "b"]
        }

        # Error bars show +/- one sample standard deviation of the bootstrap
        # replications; computed once per dataset rather than once per axes.
        a_std = np.std(all_sde_data["a"], ddof=1)
        b_std = np.std(all_sde_data["b"], ddof=1)

        # Left column: value of the lag correlation at its peak. NaN cells
        # (no peak found) are passed through unchanged.
        ax = axes[i_row, 0]
        vmin, vmax = 0.0, 0.60
        cnt = ax.pcolormesh(
            grd_b, grd_a, plt_data["rho"], vmin=vmin, vmax=vmax, cmap="Greens"
        )
        cbar = fig.colorbar(cnt, ax=ax, extend="max")
        cbar.set_ticks([0.0, vmax / 2, vmax])

        ax.set_title("Max Lag Corr.", fontsize="small")
        ax.set_ylabel(r"$c_{G \leftarrow K}$")

        # Vertical dataset label to the left of the row (data coordinates).
        row_label = LABELS_FOR_DATA[data_name]
        if i_row == 0:
            ax.text(-0.16, 0.05, row_label, rotation="vertical", fontsize="large")
        elif i_row == 1:
            ax.text(-0.16, -0.07, row_label, rotation="vertical", fontsize="large")

        # Right column: lag of the peak. Cells without a peak are filled with
        # 3.0 so that, after the sign flip, they fall below vmin and saturate
        # at the low end of the colour scale instead of being left blank.
        ax = axes[i_row, 1]
        vmin, vmax = -2.5, 2.5
        lag_filled = np.where(np.isnan(plt_data["lag"]), 3.0, plt_data["lag"])
        cnt = ax.pcolormesh(
            grd_b, grd_a, -lag_filled, vmin=vmin, vmax=vmax, cmap="PuOr"
        )

        cbar_label = r"$\leftarrow$GS leads / KC leads$\rightarrow$"
        cbar = fig.colorbar(cnt, ax=ax, extend="both")
        cbar.set_ticks([vmin, 0.0, vmax])
        cbar.set_label(label=cbar_label, fontsize="small")

        ax.set_title("Lag [month]", fontsize="small")

        for ax in np.ravel(axes[i_row, :]):
            # Dashed guide line: a = b * Tx / Ty.
            ax.plot(lst_b, cfg.Tx / cfg.Ty * lst_b, ls="--", color="gray")
            if i_row == 1:
                ax.set_xlabel(r"$c_{K \leftarrow G}$")
            ax.xaxis.set_ticks_position("both")
            ax.yaxis.set_ticks_position("both")
            ax.axvline(0, color="k")
            ax.axhline(0, color="k")

            # Estimated (b, a) of this dataset with bootstrap error bars.
            ax.errorbar(
                [cfg.b],
                [cfg.a],
                xerr=[b_std],
                yerr=[a_std],
                marker="o",
                elinewidth=1.5,
                ecolor="k",
                color="w",
                capsize=4,
                capthick=1.5,
                markerfacecolor="w",
                markeredgecolor="k",
                markersize=8,
            )

            ax.set_xticks(np.linspace(0, max_x, 3, endpoint=True))
            ax.set_yticks(np.linspace(0, max_x, 3, endpoint=True))
            ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(ticks))
            ax.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(ticks))
            ax.set_xlim(0.0, max_x)
            ax.set_ylim(-0.05, max_x)
            ax.set_aspect("equal")

    # Horizontal rule drawn outside the axes area (clip_on=False) of the
    # lower-left panel, in that panel's data coordinates.
    ax = axes[1, 0]
    xs = [-0.18, 0.68]
    ys = [0.25, 0.25]
    ax.plot(xs, ys, "k", lw=2, clip_on=False)

    plt.tight_layout()
    plt.show()

    fig_name = "fig08"
    if WRITE_JPG:
        fig.savefig(f"{FIG_DIR}/{fig_name}.jpg", bbox_inches="tight", dpi=DPI)
    if WRITE_WEBP:
        fig.savefig(f"{FIG_DIR}/{fig_name}.webp", bbox_inches="tight")
    if WRITE_EPS:
        fig.savefig(f"{FIG_DIR}/{fig_name}.eps", bbox_inches="tight", dpi=DPI)
    if WRITE_PDF:
        fig.savefig(f"{FIG_DIR}/{fig_name}.pdf", bbox_inches="tight", dpi=DPI)

# Others

## Time series using statsmodels

In [None]:
# Fit a VAR model (statsmodels) to the preprocessed Gulf Stream / Kuroshio
# series of each dataset and overlay its fitted values on the data.
rm = 1
max_lag = 1
target_cols = ["gulfstrm", "kuroshio"]

df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(2, 2, sharey=True, sharex=True, figsize=[14, 4])

for i, (data_name, df_time_series) in enumerate(
    zip(
        ["Observation", "Simulation-GFDL-CM4C192"],
        [df_obs_time_series, df_gfdl_time_series],
    )
):
    # One preprocessed (detrended, standardized) column per current,
    # stacked into an array of shape (time, name).
    data = np.stack(
        [
            preprocess(
                df_time_series[col],
                rm_window=rm,
                apply_detrend=True,
                apply_standardize=True,
            )
            .squeeze()
            .numpy()
            for col in target_cols
        ],
        axis=1,
    )

    assert np.all(~np.isnan(data))
    assert data.ndim == 2  # time and name
    assert data.shape[1] == 2  # num of names (= len(data_columns))

    model = VAR(data)
    model_result = model.fit(maxlags=max_lag, ic=None, trend="n")
    preds = model_result.fittedvalues
    ts = df_time_series["time"].values

    # The fitted values may be shorter than the input (presumably by the
    # number of lags - TODO confirm), so the data and time axis are aligned
    # on their trailing n samples.
    n = preds.shape[0]

    for j, col in enumerate(target_cols):
        ax = axes[i, j]
        ax.yaxis.set_ticks_position("both")

        ax.plot(ts[-n:], data[-n:, j], "-", lw=1.0, color=COLORS[col])
        ax.plot(ts[-n:], preds[:, j], "--", lw=1.0, color=COLORS["AR1"])

        # Panel annotation: dataset and variable label (same position in
        # every panel).
        panel_label = LABELS_FOR_DATA[data_name] + ", " + MATH_LABELS_FOR_WBCS[col]
        ax.text(1969, 2.5, panel_label, fontsize="medium")

        # Text stand-in for a legend entry of the dashed VAR curve.
        ax.text(
            1951,
            -3.3,
            "---- Estimations by VAR1",
            fontsize="medium",
            color=COLORS["AR1"],
            fontweight="bold",
        )

        if i == 1:
            ax.set_xlabel("Time [year]")
            ax.set_xticks(np.linspace(1950, 2020, 8))

        ax.set_xlim([1950, 2020])
        ax.xaxis.set_minor_locator(
            matplotlib.ticker.FixedLocator(np.linspace(1950, 2020, 15))
        )

        if i == 0:
            ax.set_title(f"{LABELS_FOR_WBCS[col]}, Standardized SST")

        ax.set_ylabel(MATH_LABELS_FOR_WBCS[col])
        ax.set_ylim([-3.5, 3.5])
        ax.set_yticks(np.linspace(-3.0, 3.0, 5))
        ax.yaxis.set_minor_locator(
            matplotlib.ticker.FixedLocator(np.linspace(-3.5, 3.5, 15))
        )

plt.tight_layout()
plt.show()

## Stationarity tests (ADF tests)

In [None]:
# Run an augmented Dickey-Fuller test (no constant/trend term, fixed lag of 1)
# on each raw series and tabulate the p-values.
df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

# Rows are collected as records and turned into a DataFrame once, instead of
# enlarging the frame cell by cell.
result_records = []

for data_name, df_time_series in zip(
    ["Observation", "Simulation-GFDL-CM4C192"],
    [df_obs_time_series, df_gfdl_time_series],
):
    for col in ["gulfstrm", "kuroshio"]:
        logger.info(f"\n{data_name}, {col}")
        adf = stattools.adfuller(
            df_time_series[col].values, regression="n", autolag=None, maxlag=1
        )
        # First four entries: test statistic, p-value, lags used, sample size.
        t_value, p_value, used_lag, n_obs = adf[:4]
        logger.info(f"t-value : {t_value:.5f}, p-value : {p_value * 100:.10f}%")
        logger.info(f"Lags used : {used_lag}, data size : {n_obs}")

        result_records.append(
            {"Data name": data_name, "Kind name": col, "P value": p_value}
        )

df_results = pd.DataFrame(
    result_records, columns=["Data name", "Kind name", "P value"]
)
print("")
print(df_results.to_markdown())

## Lagged corr. coeffs. based on SDEs

In [None]:
# Plot the lag correlation returned by calc_lag_rhos for the coefficients
# estimated from each dataset, and mark the single peak of each curve.
df_obs_time_series = read_matlab_time_series(OBS_FILE_PATH)
df_gfdl_time_series = read_matlab_time_series(GFDL_FILE_PATH)

obs_model_params = estimate_sde_coeffs_for_bcs(df_obs_time_series)
gfd_model_params = estimate_sde_coeffs_for_bcs(df_gfdl_time_series)

plt.rcParams["font.size"] = 13
fig, ax = plt.subplots(figsize=[5, 4])
ax.yaxis.set_ticks_position("both")

# Lags (in months, per the x-axis label) at which the curves are evaluated.
lags = np.linspace(-25, 25, 251)

for data_name, params in (
    ("Simulation-GFDL-CM4C192", gfd_model_params),
    ("Observation", obs_model_params),
):
    cfg = copy.deepcopy(params["config"])
    # The curve is reversed before plotting, i.e. the lag axis is mirrored.
    rhos = calc_lag_rhos(cfg, lags)[::-1]

    peak_indices, _ = scipy.signal.find_peaks(rhos)
    assert len(peak_indices) == 1
    lag_idx = peak_indices[0]

    ax.plot(
        lags,
        rhos,
        color="k",
        ls=LINE_STYLES[data_name],
        label=LABELS_FOR_DATA[data_name],
    )
    ax.plot(lags[lag_idx], rhos[lag_idx], "o", color="k")

ax.legend(loc="upper left", ncol=1, edgecolor="k", fontsize="small")

# Thin reference lines through the origin.
ax.axvline(0.0, ls="-", lw=0.5, color="k")
ax.axhline(0.0, ls="-", lw=0.5, color="k")

# y axis
ax.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])
ax.set_ylim(-0.05, 0.45)
y_minor_ticks = np.linspace(-0.05, 0.45, 11)
ax.yaxis.set_minor_locator(matplotlib.ticker.FixedLocator(y_minor_ticks))
ax.set_ylabel("Lag Correlations")

# x axis
n_months = 24
ax.set_xlim([-n_months, n_months])
ax.set_xticks(np.linspace(-n_months, n_months, 9), labels=None)
x_minor_ticks = np.linspace(-n_months, n_months, 17)
ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(x_minor_ticks))
ax.set_xlabel("Lag [month]")

ax.set_title("Theoretical Lag Correlations")

# Direction captions; y = -0.18 lies below the visible y range, so they are
# drawn underneath the axes.
ax.text(+10, -0.18, r"Gulf Stream leads", fontsize="small")
ax.text(-25, -0.18, r"Kuroshio leads", fontsize="small")

plt.tight_layout()
plt.show()

# Export is currently disabled:
# fig_name = "theoretical_lag_corrs"
# fig.savefig(f"{FIG_DIR}/{fig_name}.webp", bbox_inches="tight")
# if WRITE_EPS:
#     fig.savefig(f"{FIG_DIR}/{fig_name}.eps", bbox_inches="tight", dpi=DPI)

## Bootstrap replications' distributions

In [None]:
# Histograms of the bootstrap replications of every estimated quantity:
# rows 0-1 of the figure show the observation, rows 2-3 the simulation.
all_vars = ["rx", "a", "Tx", "Qx", "Ix", "ry", "b", "Ty", "Qy", "Iy"]

# Histogram styling shared by the x/y variants of each quantity:
# (bar colour, x range, major x ticks, minor x ticks).
style_r = (COLOR_LIST[-1], [0.1, 0.7], [0.1, 0.3, 0.5, 0.7], [0.1, 0.3, 0.5, 0.7])
style_coupling = (
    COLOR_LIST[-2],
    [-0.1, 0.3],
    [0.0, 0.2],
    [-0.1, 0.0, 0.1, 0.2, 0.3],
)
style_T = (
    COLOR_LIST[-3],
    [0.1, 0.6],
    [0.10, 0.35, 0.60],
    [0.10, (0.35 + 0.1) / 2, 0.35, (0.6 + 0.35) / 2, 0.60],
)
style_Q = (
    COLOR_LIST[-4],
    [-0.03, 0.03],
    [-0.03, 0.0, 0.03],
    [-0.03, -0.015, 0.0, 0.015, 0.03],
)
style_I = (
    COLOR_LIST[-5],
    [-0.04, 0.04],
    [-0.04, 0.0, 0.04],
    [-0.04, -0.02, 0.0, 0.02, 0.04],
)
hist_styles = {
    "rx": style_r,
    "ry": style_r,
    "a": style_coupling,
    "b": style_coupling,
    "Tx": style_T,
    "Ty": style_T,
    "Qx": style_Q,
    "Qy": style_Q,
    "Ix": style_I,
    "Iy": style_I,
}

plt.rcParams["font.size"] = 18
fig, all_axes = plt.subplots(4, 5, figsize=[14.0, 11])

for data_name, row_block in (
    ("Observation", all_axes[0:2]),
    ("Simulation-GFDL-CM4C192", all_axes[2:4]),
):
    pickle_path = f"./bootstrap_replications_{data_name}.pickle"
    dict_params = read_pickle(pickle_path)

    # Ten axes (2 rows x 5 columns), flattened in the order of all_vars.
    axes = np.ravel(row_block)

    all_data = {v: [param[v] for param in dict_params.values()] for v in all_vars}

    for ax, v in zip(axes, all_vars):
        if v not in hist_styles:
            raise Exception()
        c, r, xticks, minor_xticks = hist_styles[v]

        # The simulation's Qx/Qy histograms use a narrower range.
        if data_name == "Simulation-GFDL-CM4C192" and v in ("Qx", "Qy"):
            r = [-0.01, 0.01]
            xticks = [-0.01, 0.0, 0.01]
            minor_xticks = [-0.01, -0.005, 0.0, 0.005, 0.01]

        data = all_data[v]

        cnt = ax.hist(data, bins=40, density=False, color=c, range=r)

        # Upper y limit: 110 % of the tallest bin, rounded to a multiple of 10.
        ymax = np.round(np.max(cnt[0]) * 1.1 / 100, decimals=1) * 100
        ax.set_ylim(0, ymax)
        ax.set_yticks([0, ymax // 2, ymax])
        ax.yaxis.set_ticks_position("both")
        if v in ("rx", "ry"):
            # rx / ry sit in the first column of their rows.
            ax.set_ylabel("Frequency")

        ax.xaxis.set_ticks_position("both")
        ax.set_xlabel(MATH_LABELS_FOR_WBCS[v])
        ax.set_xticks(xticks)
        ax.set_xlim(r[0], r[1])
        ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(minor_xticks))

        # Mark zero whenever it lies strictly inside the plotted range.
        if r[0] < 0 < r[1]:
            ax.axvline(0, ls="-", color="k", lw=1.5)

plt.tight_layout()

# Vertical dataset labels, positioned in the data coordinates of the first
# axes of each two-row block (left of and below its visible range).
for ax, text_y, label_key in (
    (all_axes[0, 0], -170, "Observation"),
    (all_axes[2, 0], -400, "Simulation-GFDL-CM4C192"),
):
    ax.text(
        -0.5,
        text_y,
        LABELS_FOR_DATA[label_key],
        fontsize="large",
        color="k",
        fontweight="bold",
        rotation="vertical",
    )

# Horizontal rule between the two blocks, drawn unclipped in the data
# coordinates of all_axes[2, 0].
ax = all_axes[2, 0]
xs = [-0.5, 4.9]
ys = [450, 450]
ax.plot(xs, ys, "k", lw=2, clip_on=False)

plt.show()

fig_name = "fig10"

if WRITE_WEBP:
    fig.savefig(f"{FIG_DIR}/{fig_name}.webp", bbox_inches="tight")
if WRITE_EPS:
    fig.savefig(f"{FIG_DIR}/{fig_name}.eps", bbox_inches="tight", dpi=DPI)
if WRITE_JPG:
    fig.savefig(f"{FIG_DIR}/{fig_name}.jpg", bbox_inches="tight", dpi=DPI)
if WRITE_PDF:
    fig.savefig(f"{FIG_DIR}/{fig_name}.pdf", bbox_inches="tight", dpi=DPI)