In [None]:
# Enable the autoreload extension so that edits to imported modules are
# picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import sys
from logging import INFO, WARNING, StreamHandler, getLogger

# Root logger used throughout the notebook: INFO level, echoing to stdout.
logger = getLogger()
logger.setLevel(INFO)

# Attach a stdout handler only when none is registered yet, so re-running
# this cell does not stack up duplicate handlers (and duplicate messages).
if not logger.hasHandlers():
    stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

# Import

In [None]:
import copy
import os
import pathlib

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from src.data.io_matlab import read_matlab_time_series
from src.data.io_pickles import read_pickle, write_pickle
from src.information_theory.corr_helper import calc_lag_rhos
from src.information_theory.loos_klapp_2020 import (
    Config,
    calc_I_x_to_y,
    calc_I_y_to_x,
    calc_Qx,
    calc_Qy,
)
from src.simulation.block_bootstrap import BlockBootstrap
from src.simulation.sde_coeff_estimator import estimate_sde_coeffs_for_bcs
from src.utils.random_seed_helper import set_all_seeds
from tqdm.notebook import tqdm

# Global figure styling: serif fonts plus the colorblind-friendly Tableau
# palette for every figure drawn in this notebook.
plt.rcParams.update({"font.family": "serif"})
plt.style.use("tableau-colorblind10")

# Define constants

In [None]:
def _resolve_root_dir(pythonpath: str) -> str:
    """Return the absolute parent directory of the first entry of ``pythonpath``.

    ``PYTHONPATH`` is a list of directories separated by ``os.pathsep``, so the
    raw value cannot be treated as a single path: with more than one entry the
    joined string would resolve to a directory that does not exist.

    Args:
        pythonpath: Raw value of the ``PYTHONPATH`` environment variable.

    Returns:
        The resolved parent directory of the first non-empty entry, as a string.

    Raises:
        RuntimeError: If ``pythonpath`` contains no non-empty entry.
    """
    entries = [entry for entry in pythonpath.split(os.pathsep) if entry]
    if not entries:
        raise RuntimeError(
            "The PYTHONPATH environment variable is empty or unset; "
            "ROOT_DIR is derived from its first entry."
        )
    # NOTE(review): assumes the project's entry comes first in PYTHONPATH —
    # confirm if several entries are ever configured.
    return str((pathlib.Path(entries[0]) / "..").resolve())


# Project root is taken to be the parent of the PYTHONPATH directory.
ROOT_DIR = _resolve_root_dir(os.environ.get("PYTHONPATH", ""))
BCS_DATA_DIR = f"{ROOT_DIR}/data"

# MATLAB files holding the observed and the simulated time series.
OBS_FILE_PATH = f"{BCS_DATA_DIR}/observations.mat"
GFDL_FILE_PATH = f"{BCS_DATA_DIR}/simulation_GFDLCM4C192.mat"

# Number of bootstrap replications generated per data set.
N_BOOTSTRAP = 2000

# Read data

In [None]:
# Load the observed and the simulated time series from their MATLAB files.
df_obs_time_series, df_gfdl_time_series = (
    read_matlab_time_series(mat_path)
    for mat_path in (OBS_FILE_PATH, GFDL_FILE_PATH)
)

# Estimate VAR(1) coefficients using bootstrap resamples

In [None]:
# For each data set, draw N_BOOTSTRAP block-bootstrap resamples of the two
# target columns, estimate the model coefficients on every resample, and cache
# the results in a per-data-set pickle so an interrupted run can be resumed.
target_cols = ["gulfstrm", "kuroshio"]

for data_name, df_time_series in zip(
    ["Observation", "Simulation-GFDL-CM4C192"],
    [df_obs_time_series, df_gfdl_time_series],
):
    logger.info(f"\nData name = {data_name}")
    pickle_path = f"./bootstrap_replications_{data_name}.pickle"

    data = df_time_series.loc[:, target_cols].values

    # The initial block length is only a placeholder: it is replaced right
    # away by the value computed by calc_block_length_using_Sherman98().
    bootstrap = BlockBootstrap(time_series=data, block_length=10)
    block_length = bootstrap.calc_block_length_using_Sherman98()
    bootstrap.set_new_block_length(block_length)

    # Resume from previously saved replications when a cache file exists.
    # NOTE: unpickling can execute arbitrary code; only load cache files that
    # were written by this notebook.
    if os.path.exists(pickle_path):
        dict_results = read_pickle(pickle_path)
        logger.info("Read pickle file")
    else:
        dict_results = {}
        logger.info("Create empty dict")

    set_all_seeds()

    for i in tqdm(range(N_BOOTSTRAP)):
        # Draw the resample BEFORE the cache check: one resample is generated
        # per iteration even when the result is already cached, so replication
        # i is always built from the i-th draw after seeding.
        resample = bootstrap.generature_a_resample()

        if i in dict_results:
            continue

        df = pd.DataFrame(resample, columns=target_cols)
        # NOTE(review): this assignment aligns on the index; it presumably
        # relies on df_time_series having a default 0..n-1 index — confirm.
        df["time"] = df_time_series["time"].copy()

        # Silence INFO output while estimating. The level is restored in a
        # `finally` so that an exception or a keyboard interrupt cannot leave
        # the root logger stuck at WARNING for the rest of the session.
        previous_level = logger.level
        logger.setLevel(WARNING)
        try:
            config = estimate_sde_coeffs_for_bcs(df)["config"]
        finally:
            logger.setLevel(previous_level)

        # Store the estimated coefficients together with the derived
        # quantities computed from them.
        params = copy.deepcopy(vars(config))
        params["Qx"] = calc_Qx(config)
        params["Qy"] = calc_Qy(config)
        params["Ix"] = calc_I_y_to_x(config)
        params["Iy"] = calc_I_x_to_y(config)

        dict_results[i] = params

        # Checkpoint every 10 replications and after the last one.
        if (i + 1) % 10 == 0 or (i + 1) == N_BOOTSTRAP:
            write_pickle(dict_results, pickle_path)

# Plot the distributions of the bootstrap replications

In [None]:
# Histogram of every bootstrapped quantity, one 2x5 panel grid per data set.
all_vars = ["Tx", "rx", "a", "Qx", "Ix", "Ty", "ry", "b", "Qy", "Iy"]

plt.rcParams["font.size"] = 15

for data_name in ["Observation", "Simulation-GFDL-CM4C192"]:
    pickle_path = f"./bootstrap_replications_{data_name}.pickle"
    dict_params = read_pickle(pickle_path)

    # Collect, for each variable, its value in every replication.
    all_data = {
        name: [replication[name] for replication in dict_params.values()]
        for name in all_vars
    }

    fig, axes = plt.subplots(2, 5, figsize=[14, 5])

    # Panels are filled row by row in the order given by all_vars.
    for ax, name in zip(axes.ravel(), all_vars):
        ax.hist(all_data[name], bins=31, density=True)
        ax.set_xlabel(name)
        ax.set_ylabel("PDF")
        ax.set_title(name)

    fig.suptitle(data_name)
    fig.tight_layout()
    plt.show()

In [None]:
# Scatter plot of Qx against Qy over all replications, one figure per data
# set, preceded by two printed lines of replication counts.
all_vars = ["Tx", "rx", "a", "Qx", "Ix", "Ty", "ry", "b", "Qy", "Iy"]

plt.rcParams["font.size"] = 15

for data_name in ["Observation", "Simulation-GFDL-CM4C192"]:
    pickle_path = f"./bootstrap_replications_{data_name}.pickle"
    dict_params = read_pickle(pickle_path)

    # Collect, for each variable, its value in every replication.
    all_data = {
        name: [replication[name] for replication in dict_params.values()]
        for name in all_vars
    }
    Qx, Qy = all_data["Qx"], all_data["Qy"]
    a, b = all_data["a"], all_data["b"]

    # Replications with both a > 0 and b > 0 are skipped entirely. Among the
    # remaining ones, count all of them, and separately those whose product
    # Qx * Qy is not negative.
    n_total = len(Qx)
    n_not_both_positive = 0
    n_product_not_negative = 0
    for a_i, b_i, qx_i, qy_i in zip(a, b, Qx, Qy):
        if a_i > 0 and b_i > 0:
            continue
        n_not_both_positive += 1
        if not (qx_i * qy_i < 0):
            n_product_not_negative += 1

    # First line: complements with respect to the total; second: raw counts.
    print(n_total - n_product_not_negative, n_total - n_not_both_positive)
    print(n_product_not_negative, n_not_both_positive)

    fig, ax = plt.subplots(figsize=[5, 4])
    ax.scatter(Qx, Qy, marker=".")

    # Dashed guides through the origin separate the sign quadrants.
    ax.axvline(0, color="k", ls="--")
    ax.axhline(0, color="k", ls="--")

    ax.set_xlabel("Qx")
    ax.set_ylabel("Qy")
    ax.set_title(data_name)

    # Fixed axis limits so both data sets are drawn on the same scale.
    ax.set_xlim(-0.01, 0.01)
    ax.set_ylim(-0.01, 0.04)

    fig.tight_layout()
    plt.show()

# Lag correlation

In [None]:
# Lag-correlation curve of every bootstrap replication, summarized per data
# set by its mean and by the (alpha / 2, 1 - alpha / 2) quantile band.
lags = np.linspace(-25, 25, 251)
alpha = 0.05
dict_results = {}

for data_name in ["Observation", "Simulation-GFDL-CM4C192"]:
    pickle_path = f"./bootstrap_replications_{data_name}.pickle"
    dict_params = read_pickle(pickle_path)

    # One curve per replication. Each curve is reversed along its first axis;
    # presumably this matches the sign convention of `lags` — TODO confirm.
    list_lag_corrs = [
        calc_lag_rhos(
            Config(
                Tx=param["Tx"],
                Ty=param["Ty"],
                a=param["a"],
                b=param["b"],
                rx=param["rx"],
                ry=param["ry"],
            ),
            lags,
        )[::-1]
        for param in tqdm(dict_params.values())
    ]
    all_rhos = np.stack(list_lag_corrs, axis=0)

    # Statistics are taken across replications (axis 0), separately per lag.
    lower_q = alpha / 2.0
    upper_q = 1.0 - alpha / 2.0
    dict_results[data_name] = {
        "mean_rhos": np.mean(all_rhos, axis=0),
        "min_rhos": np.quantile(all_rhos, q=lower_q, axis=0),
        "max_rhos": np.quantile(all_rhos, q=upper_q, axis=0),
    }

In [None]:
# Plot the mean lag-correlation curve with its quantile band, one figure per
# data set. Reads `lags` and `dict_results` from the previous cell.
n_months = 24
major_ticks = np.linspace(-n_months, n_months, 9)
minor_ticks = np.linspace(-n_months, n_months, 17)

plt.rcParams["font.size"] = 13

for data_name in ["Observation", "Simulation-GFDL-CM4C192"]:
    summary = dict_results[data_name]

    fig, ax = plt.subplots(figsize=[5, 4])

    # Mean curve on top of the shaded quantile band.
    ax.plot(lags, summary["mean_rhos"], lw=2, color="k")
    ax.fill_between(
        lags,
        y1=summary["min_rhos"],
        y2=summary["max_rhos"],
        alpha=0.25,
        color="gray",
    )

    # Dashed guides at zero correlation and at zero lag.
    ax.axhline(0.0, ls="--", color="k")
    ax.axvline(0.0, ls="--", color="k")

    # Show +/- n_months of lag, with major and minor ticks at fixed positions.
    ax.set_xlim([-n_months, n_months])
    ax.set_xticks(major_ticks, labels=None)
    ax.xaxis.set_minor_locator(matplotlib.ticker.FixedLocator(minor_ticks))

    ax.set_xlabel("Lag [month]")
    ax.set_ylabel("Lag Correlations")
    ax.set_title(data_name)

    fig.tight_layout()
    plt.show()