In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import sys
sys.path.append("..")
import time
import warnings
from pathlib import Path
from copy import copy

import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import torch

from falkon.kernels import GaussianKernel
from falkon import FalkonOptions
from nyskoop.data.lorenz import Lorenz63
from experiment_helpers import *

In [5]:
# Apply the project's matplotlib rc styling; `set_matplotlib_rc` is not defined in this
# notebook, so it presumably comes from the `experiment_helpers` star import.
# NOTE(review): the three ints look like font sizes (small/medium/large?) — confirm in experiment_helpers.
set_matplotlib_rc(18, 20, 22)

In [6]:
def gen_l63_fulldata(
    n_train: int,
    n_test: int,
    num_test_sets: int,
    max_q: int,
    ics=(1, 1, 1),
    sigma: float = 10,
    rho: float = 28,
    beta: float = 8 / 3,
    **kwargs,
):
    """Simulate one Lorenz-63 trajectory long enough for a train split and several test splits.

    Args:
        n_train: number of training points to reserve.
        n_test: number of points in each test set.
        num_test_sets: number of test sets.
        max_q: extra points reserved for each of the (num_test_sets + 1) splits;
            presumably the largest lag later used by ``gen_lagged`` — TODO confirm.
        ics: initial condition ``(x0, y0, z0)``; defaults to ``(1, 1, 1)``.
        sigma, rho, beta: Lorenz-63 parameters; default to 10, 28 and 8/3.
        **kwargs: forwarded to ``Lorenz63`` (e.g. ``dt``, ``burnin``, ``l63_version``).

    Returns:
        A ``torch.Tensor`` wrapping (sharing memory with) the array from ``Lorenz63.solve``.
    """
    # Unpacking a numpy array keeps the numpy scalar types the original indexing produced.
    x0, y0, z0 = np.asarray(ics)
    lorenz63_dset = Lorenz63(x0, y0, z0, sigma=sigma, rho=rho, beta=beta, **kwargs)
    n_points = n_train + num_test_sets * n_test + max_q * (num_test_sets + 1)
    full_data = lorenz63_dset.solve(n_points)
    return torch.from_numpy(full_data)

In [7]:
# Output directory for timing results/figures. `parents=True` so a fresh checkout
# without an `outputs/` directory does not fail with FileNotFoundError.
fig_path = Path("outputs/timings")
fig_path.mkdir(parents=True, exist_ok=True)

Generate the shared Lorenz-63 trajectory, then define a few helper functions to run the experiments.

In [16]:
# Split configuration shared by every experiment below. `q` is passed to
# gen_lagged and as `max_q` to the generator — presumably the number of lags (TODO confirm).
q = 5
num_test_sets = 5
n_test = 2_000

# One long trajectory sized for the largest training set used in the grids (10_000)
# plus all test sets.
full_data = gen_l63_fulldata(
    n_train=10_000,
    n_test=n_test,
    num_test_sets=num_test_sets,
    max_q=q,
    dt=0.01,
    burnin=5,
    l63_version="normal",
)

In [17]:
def gen_data_get_error(n_train, n_reps, dtype, **model_hps):
    """Fit an estimator ``n_reps`` times and summarise its test error and fit time.

    Lagged train/test sets are built from the notebook-level ``full_data`` using the
    globals ``n_test``, ``num_test_sets`` and ``q``. Each repetition trains with
    ``train_est`` and scores with ``nrmse`` averaged over the test sets.

    Args:
        n_train: number of training samples passed to ``gen_lagged``.
        n_reps: number of independent fits; must be >= 1.
        dtype: dtype forwarded to ``gen_lagged``.
        **model_hps: forwarded unchanged to ``train_est``.

    Returns:
        Tuple ``(median err, 90th-percentile err, 10th-percentile err, fastest fit time [s])``.

    Raises:
        ValueError: if ``n_reps < 1`` (the statistics would be undefined).
    """
    # Validate before any data generation so a bad grid entry fails fast and clearly.
    if n_reps < 1:
        raise ValueError(f"n_reps must be >= 1, got {n_reps}")

    X_train, Y_train, all_test_sets = gen_lagged(
        full_data, n_train, n_test, num_test_sets, q, dtype=dtype
    )
    errs, times = [], []
    for _ in range(int(n_reps)):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t_start = time.perf_counter()
        est = train_est(X_train, Y_train, **model_hps)
        times.append(time.perf_counter() - t_start)
        err = nrmse(est, all_test_sets, variable='all')[0].mean()
        errs.append(err.item())
    return (
        np.median(errs),
        np.percentile(errs, 90),
        np.percentile(errs, 10),
        # The minimum is the timing least perturbed by background noise.
        np.min(times),
    )

In [18]:
def _hp_search_records(hp_dict, fn, pbar, params):
    """Recursively expand ``hp_dict`` and evaluate ``fn`` at every grid point.

    Keys are consumed in insertion order. A dict value is looked up by
    ``params['kind']`` (so "kind" must already be assigned when a per-kind dict is
    reached). A key whose candidates are empty, or missing for this kind, is skipped
    and never passed to ``fn``.

    Returns:
        One dict per evaluated point: the parameters plus ``err``, ``err_high``,
        ``err_low`` and ``time`` returned by ``fn``.
    """
    remaining = copy(hp_dict)  # shallow copy: popping must not touch the caller's dict
    hp_key, hp_values = None, []
    for key in list(remaining):
        candidates = remaining.pop(key)
        # Allow to specify different HP values for different 'kind' of algorithms
        if isinstance(candidates, dict):
            candidates = candidates.get(params['kind'], [])
        if len(candidates) > 0:
            hp_key, hp_values = key, candidates
            break

    if hp_key is None:
        # All hyper-parameters are fixed: evaluate this grid point.
        err, err_high, err_low, t_elapsed = fn(**params)
        pbar.update(1)
        return [{**params, "err": err, "err_high": err_high, "err_low": err_low, "time": t_elapsed}]

    records = []
    for hp_val in hp_values:
        child_params = {**params, hp_key: hp_val}
        try:
            records.extend(_hp_search_records(remaining, fn, pbar, child_params))
        except Exception:
            warnings.warn(f"{child_params} failed.")
            raise
    return records


def hp_search_inner(hp_dict, fn, out_df: pd.DataFrame, pbar: "tqdm", **kwargs):
    """Grid-search ``fn`` over ``hp_dict`` and append one row per grid point to ``out_df``.

    Args:
        hp_dict: ordered mapping HP name -> candidate values, or -> dict mapping an
            algorithm kind to its candidate values.
        fn: called as ``fn(**params)``; must return ``(err, err_high, err_low, time)``.
        out_df: previous results; its rows appear exactly once in the output.
            ``None`` or an empty frame means there are no previous results.
        pbar: progress bar; ``update(1)`` is called once per evaluated point.
        **kwargs: fixed parameters forwarded to ``fn`` and recorded in every row.

    Returns:
        ``out_df`` followed by the new rows, with a fresh RangeIndex.
    """
    # Collect plain records and build the frame once, instead of concatenating
    # (and re-including out_df) at every recursion level.
    results = pd.DataFrame(_hp_search_records(hp_dict, fn, pbar, kwargs))
    if out_df is None or out_df.empty:
        return results
    return pd.concat([out_df, results], ignore_index=True)

def hp_search(hp_dict, fn, out_df, **kwargs):
    """Seed the RNGs and run ``hp_search_inner`` over ``hp_dict`` with a progress bar.

    ``hp_dict["kind"]`` lists the algorithm kinds. Every other value is either a
    sequence of candidates or a dict mapping kind -> sequence. ``hp_search_inner``
    skips a hyper-parameter whose sequence is empty or missing for a kind, so such
    entries count as a factor of 1 in the progress-bar total.

    Args:
        hp_dict: hyper-parameter grid; insertion order matters and "kind" should come
            first, because per-kind dicts are resolved using the already-chosen kind.
        fn: evaluation function, see ``hp_search_inner``.
        out_df: previous results forwarded to ``hp_search_inner``.
        **kwargs: fixed parameters forwarded to ``fn``.

    Returns:
        The DataFrame produced by ``hp_search_inner``.
    """
    # Fixed seeds so repeated searches see the same random draws.
    np.random.seed(42)
    torch.manual_seed(42)

    def n_candidates(values) -> int:
        # Empty entries are skipped by the search (factor 1); unsized values are
        # not multiplied either, as before.
        try:
            return max(len(values), 1)
        except TypeError:
            return 1

    # Total number of fn evaluations, taking per-kind dicts into account.
    total = 0
    for algo_kind in hp_dict["kind"]:
        subtotal = 1
        for key, values in hp_dict.items():
            if key == "kind":
                continue
            if isinstance(values, dict):
                values = values.get(algo_kind, [])
            subtotal *= n_candidates(values)
        total += subtotal
    with tqdm(desc='HP Search', total=total, mininterval=1.0) as pbar:
        return hp_search_inner(hp_dict, fn, out_df, pbar, **kwargs)

In [21]:
# Hyper-parameter grid for the "vary n_train" timing experiment.
# Per-kind dicts give each algorithm its own candidates; kinds missing from such a
# dict skip that hyper-parameter entirely. "kind" must come first because the
# search resolves per-kind dicts using the already-chosen kind.
pcr_hps_n = dict(
    kind=["full_rrr", "rrr", "full_pcr", "pcr"],
    # Only the Nyström variants take M (presumably the number of Nyström centres).
    M=dict(pcr=[250], rrr=[250]),
    kernel=[GaussianKernel(3.5, opt=FalkonOptions(use_cpu=True))],
    num_components=[25],
    n_train=dict(
        # The three largest sizes are dropped for full RRR — presumably too costly; TODO confirm.
        full_rrr=np.logspace(2, 4, 10, dtype=int)[:-3],
        rrr=np.logspace(2, 4, 10, dtype=int),
        full_pcr=np.logspace(2, 4, 10, dtype=int),
        pcr=np.logspace(2, 4, 10, dtype=int),
    ),
    # Fewer repetitions for the slow full-kernel estimators.
    n_reps=dict(
        full_rrr=[2],
        full_pcr=[2],
        rrr=[20],
        pcr=[20],
    ),
    dtype=[torch.float32],
    # Only the RRR variants receive a penalty.
    penalty=dict(rrr=[1e-6], full_rrr=[1e-6]),
)

pcr_n_df_path = Path("outputs/timings/l63_varyn_m250_all.csv")
Path.mkdir(pcr_n_df_path.parent, parents=True, exist_ok=True)

In [23]:
# Run the full grid (slow) and persist it. The kernel object column is replaced by
# its sigma as a plain float so the CSV holds only simple values.
pcr_n_df = (
    hp_search(pcr_hps_n, gen_data_get_error, pd.DataFrame())
    .assign(sigma=lambda frame: frame["kernel"].apply(lambda kern: kern.sigma.item()))
    .drop(columns=["kernel"])
)
pcr_n_df.to_csv(pcr_n_df_path, index=False)

HP Search:   8%|▊         | 3/37 [00:02<00:26,  1.28it/s]


KeyboardInterrupt: 

In [None]:
# Accuracy (left) and fit time (right) of all four estimators as n_train grows.
df = pd.read_csv(pcr_n_df_path)
fig, ax = plt.subplots(ncols=2, figsize=(8, 4))
models = [
    ("full_rrr", "FullRRR", IBM_COLORS[3], "-"),
    ("rrr", "NysRRR", IBM_COLORS[1], "--"),
    ("full_pcr", "FullPCR", IBM_COLORS[0], "-"),
    ("pcr", "NysPCR", IBM_COLORS[2], "--"),
]

for kind, lbl, c, ls in models:
    subset = df[df["kind"] == kind]
    line_kw = dict(c=c, marker='o', lw=2, linestyle=ls)
    # Median error, with the 10th-90th percentile band shaded.
    ax[0].plot(subset["n_train"], subset["err"], label=lbl, **line_kw)
    ax[0].fill_between(
        subset["n_train"], subset["err_low"], subset["err_high"], color=c, alpha=0.3,
    )
    # Fastest fit time over the repetitions.
    ax[1].plot(subset["n_train"], subset["time"], **line_kw)

for axis in ax:
    axis.set_xlabel('Number of samples')
    axis.set_xscale('log')
    axis.set_xlim(100, 10000)
ax[0].set_ylabel("nRMSE")
ax[1].set_ylabel('Time (s)')
# The scale must be set before the ticks: set_yscale resets the tick locator.
ax[1].set_yscale('log')
ax[1].set_yticks([1e-2, 1e-1, 1, 1e1, 1e2])

# Shared legend anchored above both panels.
fig.legend(ncols=4, loc='upper center', bbox_to_anchor=(0.5, 1.1))
fig.tight_layout()
# fig.savefig(pcr_n_df_path.with_suffix('.png'), dpi=300)

In [None]:
# Accuracy (left) and fit time (right) of the RRR estimators only.
df = pd.read_csv(pcr_n_df_path)
fig, ax = plt.subplots(ncols=2)
models = [
    ("full_rrr", "FullRRR", IBM_COLORS[3], "-"),
    ("rrr", "NysRRR", IBM_COLORS[1], "--"),
#     ("full_pcr", "FullPCR", IBM_COLORS[0], "-"),
#     ("pcr", "NysPCR", IBM_COLORS[2], "--"),
]

for kind, lbl, c, ls in models:
    # Sort so the lines are drawn left-to-right regardless of the CSV row order.
    subset = df[df["kind"] == kind].sort_values("n_train")
    line_kw = dict(c=c, marker='o', lw=2, linestyle=ls)
    # Median error, with the 10th-90th percentile band shaded.
    ax[0].plot(subset["n_train"], subset["err"], label=lbl, **line_kw)
    ax[0].fill_between(
        subset["n_train"], subset["err_low"], subset["err_high"], color=c, alpha=0.3,
    )
    # Fastest fit time over the repetitions.
    ax[1].plot(subset["n_train"], subset["time"], **line_kw)

ax[0].set_xlabel('Number of samples')
ax[0].set_ylabel("nRMSE")
ax[0].set_xscale('log')

ax[1].set_xlabel('Number of samples')
ax[1].set_ylabel('Time (s)')
ax[1].set_xscale('log')
ax[1].set_yscale('log')

# tight_layout ignores figure legends, so anchor the legend above the axes
# (as in the four-model figure) instead of letting it overlap the plots.
fig.legend(ncols=len(models), loc='upper center', bbox_to_anchor=(0.5, 1.1))
fig.tight_layout()
# fig.savefig(pcr_n_df_path.with_suffix('.png'), dpi=300)
