In [2]:
import gc
import sys
import warnings 
from pathlib import Path
from collections import defaultdict

sys.path.append(str(Path.cwd().parents[1]))

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

from source.features import PPNFeature
from source.evaluation import nll, rmse
from source.models.LaplaceTTKM import LaplaceTTKM
from source.data_functions import load_transform_data
from source.general_functions import update_results_dict, create_dir_if_not_exists

# Datasets evaluated throughout this notebook.
DATA_LIST = [
    'yacht', 'energy', 'boston', 'concrete', 'kin8nm', 'naval', 'protein',
]

# Plot theme and a global warning filter.
# NOTE(review): the filter also hides pandas/JAX FutureWarnings (e.g. chained-assignment ones).
sns.set_theme()
warnings.filterwarnings("ignore")

# Output directory for per-dataset result CSVs and the summary tables.
SAVE_DIR = Path('artifacts', 'results', 'bayes_core')
create_dir_if_not_exists(SAVE_DIR);

### Question 1. Which TT-core should be Bayesian?

- First, train one model per TT-core for each of the rank patterns 1, 2 and 3 below;
- Then, summarize in a table which core gives the lowest test NLL.

In [3]:
def get_res_d_core(data_name, data2params, model_cls, params, x_train, x_test, y_train, y_test, tqdm_disable=True):
    """Fit one model per TT-core index and collect train/test NLL and RMSE for each.

    For core index ``k`` the model is built with ``hess_type=k`` and ``d_core_als=k``.
    Presumably this selects core ``k`` as the Bayesian (Laplace-approximated) one;
    confirm against ``LaplaceTTKM``.

    Args:
        data_name: Key into ``data2params``.
        data2params: Per-dataset model config; must contain ``tt_ranks``. It is
            unpacked as keyword arguments into ``model_cls``.
        model_cls: Model class, constructed as ``model_cls(hess_type=..., d_core_als=..., **config, **params)``.
        params: Shared keyword arguments for ``model_cls``.
        x_train, x_test, y_train, y_test: Train/test split.
        tqdm_disable: Hide the progress bar when True.

    Returns:
        A ``defaultdict(list)`` with one entry per core under the keys
        ``d_core``, ``nll_train``, ``nll_test``, ``rmse_train`` and ``rmse_test``.
    """
    cfg = data2params[data_name]
    # A TT with ranks (r_0, ..., r_D) has D cores.
    n_cores = len(cfg['tt_ranks']) - 1
    results = defaultdict(list)

    for core in tqdm(range(n_cores), disable=tqdm_disable):
        fitted = model_cls(hess_type=core, d_core_als=core, **cfg, **params)
        fitted.fit(x_train, y_train)
        # predict(..., True) returns (mean, std); squared std is passed to nll.
        mu_tr, sd_tr = fitted.predict(x_train, True)
        mu_te, sd_te = fitted.predict(x_test, True)
        update_results_dict(
            results,
            d_core=core,
            nll_train=nll(mu_tr, sd_tr**2, y_train),
            nll_test=nll(mu_te, sd_te**2, y_test),
            rmse_train=rmse(mu_tr, y_train),
            rmse_test=rmse(mu_te, y_test),
        )
        # Drop JAX compilation caches and collect garbage between fits to keep memory bounded.
        jax.clear_caches()
        gc.collect()

    return results

def compute_res_df(data_name: str, data2params: dict, ext_str: str, params: 'dict | None' = None) -> pd.DataFrame:
    """Run the per-core sweep for one dataset and save the metrics as a CSV.

    The output file is ``SAVE_DIR / f'{data_name}_{ext_str}.csv'``.

    Args:
        data_name: Dataset key. Data is read from ``../../data/{data_name}.csv``,
            and the model config is looked up in ``data2params``.
        data2params: Per-dataset model config (``m_order``, ``tt_ranks``, ``n_epoch``, ...).
        ext_str: Suffix for the output CSV, identifying the rank pattern.
        params: Shared model hyperparameters. When omitted, the global ``PARAMS``
            is used. It is looked up at call time, so it may be defined after this function.

    Returns:
        The metrics DataFrame (one row per TT-core) that was written to disk.
    """
    if params is None:
        params = PARAMS
    # NOTE(review): 0.1 / 13 are presumably the test fraction and split seed — confirm in load_transform_data.
    x_train, x_test, y_train, y_test = load_transform_data(
        f'../../data/{data_name}.csv', 0.1, 13, True, False, 'std')
    x_train, x_test, y_train, y_test = map(
        jnp.array, [x_train, x_test, y_train, y_test])
    res_df = pd.DataFrame(
        get_res_d_core(
            data_name, data2params, LaplaceTTKM, params,
            x_train, x_test, y_train, y_test, tqdm_disable=False)
    )
    res_df.to_csv(SAVE_DIR / f'{data_name}_{ext_str}.csv')
    return res_df

# Shared model hyperparameters for the Question-1 sweeps.
# compute_res_df reads this through the global name PARAMS.
PARAMS = {
    'fmap': PPNFeature(),
    'beta_e': 1e-2,
    'gamma_w': 1e-5,
    'pd_mode': 'la',
    'hess_th': None,
    'seed': 13,
    'n_epoch_vi': 1,
    'pd_samples': 30,
    'beta_e_samples': 10,
}

#### Pattern 1: equal inner ranks $[R, R, R]$ with $R > 1$

In [None]:
# Pattern 1: every inner TT-rank has the same value R > 1.
data2params = {
    'yacht': {'m_order': 8, 'tt_ranks': (1, *(2,)*5, 1), 'n_epoch': 20},
    'energy': {'m_order': 16, 'tt_ranks': (1, *(2,)*7, 1), 'n_epoch': 20},
    'boston': {'m_order': 8, 'tt_ranks': (1, *(2,)*12, 1), 'n_epoch': 20},
    'concrete': {'m_order': 12, 'tt_ranks': (1, *(3,)*7, 1), 'n_epoch': 20},
    'kin8nm': {'m_order': 16, 'tt_ranks': (1, *(5,)*7, 1), 'n_epoch': 10},
    'naval': {'m_order': 16, 'tt_ranks': (1, *(6,)*15, 1), 'n_epoch': 5},
    'protein': {'m_order': 16, 'tt_ranks': (1, *(7,)*8, 1), 'n_epoch': 5},
}
for data_name in DATA_LIST:
    compute_res_df(data_name, data2params, 'eq_ranks')

#### Pattern 2: a larger rank in the middle, $[R, P, R]$ with $P > R$

In [None]:
# Pattern 2: larger rank(s) P in the middle of the train, R elsewhere (P > R).
data2params = {
    'yacht': {'m_order': 8, 'tt_ranks': (1, 2, 2, 5, 2, 2, 1), 'n_epoch': 20},
    'energy': {'m_order': 16, 'tt_ranks': (1, 2, 2, 2, 5, 2, 2, 2, 1), 'n_epoch': 20},
    'boston': {'m_order': 4, 'tt_ranks': (1, 2, 2, 2, 2, 2, 5, 5, 2, 2, 2, 2, 2, 1), 'n_epoch': 20},
    'concrete': {'m_order': 8, 'tt_ranks': (1, 3, 3, 3, 5, 3, 3, 3, 1), 'n_epoch': 20},
    'kin8nm': {'m_order': 16, 'tt_ranks': (1, 4, 4, 4, 8, 4, 4, 4, 1), 'n_epoch': 10},
    'naval': {'m_order': 8, 'tt_ranks': (1, 4, 4, 4, 4, 4, 4, 8, 8, 8, 4, 4, 4, 4, 4, 4, 1), 'n_epoch': 5},
    'protein': {'m_order': 16, 'tt_ranks': (1, 4, 4, 4, 8, 8, 4, 4, 4, 1), 'n_epoch': 5},
}
for data_name in DATA_LIST:
    compute_res_df(data_name, data2params, 'diff_ranks')

#### Pattern 3: a smaller rank in the middle, $[R, P, R]$ with $P < R$

In [None]:
# Pattern 3: a single rank-1 bond P inside the train, R elsewhere (P < R).
data2params = {
    'yacht': {'m_order': 8, 'tt_ranks': (1, 2, 2, 1, 2, 2, 1), 'n_epoch': 20},
    'energy': {'m_order': 16, 'tt_ranks': (1, 2, 2, 2, 1, 2, 2, 2, 1), 'n_epoch': 20},
    'boston': {'m_order': 4, 'tt_ranks': (1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1), 'n_epoch': 20},
    'concrete': {'m_order': 8, 'tt_ranks': (1, 3, 3, 3, 1, 3, 3, 3, 1), 'n_epoch': 20},
    'kin8nm': {'m_order': 16, 'tt_ranks': (1, 4, 4, 4, 1, 4, 4, 4, 1), 'n_epoch': 10},
    'naval': {'m_order': 8, 'tt_ranks': (1, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 4, 4, 4, 1), 'n_epoch': 5},
    'protein': {'m_order': 16, 'tt_ranks': (1, 4, 4, 4, 4, 1, 4, 4, 4, 1), 'n_epoch': 5},
}
for data_name in DATA_LIST:
    compute_res_df(data_name, data2params, 'diff_ranks_small')

#### Resulting Table:

In [None]:
# Summary table: for each dataset and rank pattern, the 1-based TT-core with the
# lowest test NLL, written as "best_core/n_cores".
pattern_names = ['Pattern 1', 'Pattern 2', 'Pattern 3']
pattern_exts = ['eq_ranks', 'diff_ranks', 'diff_ranks_small']
# object dtype: the cells hold strings. Shape comes from the labels, not a hardcoded (3, 7).
dd = pd.DataFrame(index=pattern_names, columns=DATA_LIST, dtype=object)

for data_name in DATA_LIST:
    for pn, ext in zip(pattern_names, pattern_exts):
        res_df = pd.read_csv(SAVE_DIR / f'{data_name}_{ext}.csv', index_col=0)
        # Stable sort so the first row wins on ties; NaN NLLs sort last.
        d_core = int(res_df.sort_values(by='nll_test', kind='stable')['d_core'].iloc[0]) + 1
        # Assign on `dd` itself. The chained form `dd[col].loc[row] = ...` may write to a copy.
        dd.loc[pn, data_name] = f'{d_core}/{len(res_df)}'

dd.columns = [c.capitalize() for c in dd.columns]
dd = dd.T.sort_index()
dd.to_csv(SAVE_DIR / 'bayes_tt_core.csv')

caption = "What TT-core should be Bayesian?"
print(
    dd.to_latex(
        column_format='||l||c|c|c||',
        caption=caption,
        label="table:bayes_tt_core",
        multirow=False,
        index=True,
    )
)

### Question 2. Which feature $x_d$ should be Bayesian?

- First, train one model per TT-core on several cyclic shifts of the feature order;
- Then, summarize in a table which core gives the lowest test NLL for each shift.

In [7]:
# Question 2 configs: the Pattern 1 (equal-rank) settings, except for protein.
data2params = {
    'yacht': {'m_order': 8, 'tt_ranks': (1, *(2,)*5, 1), 'n_epoch': 20},
    'energy': {'m_order': 16, 'tt_ranks': (1, *(2,)*7, 1), 'n_epoch': 20},
    'boston': {'m_order': 8, 'tt_ranks': (1, *(2,)*12, 1), 'n_epoch': 20},
    'concrete': {'m_order': 12, 'tt_ranks': (1, *(3,)*7, 1), 'n_epoch': 20},
    'kin8nm': {'m_order': 16, 'tt_ranks': (1, *(5,)*7, 1), 'n_epoch': 10},
    'naval': {'m_order': 16, 'tt_ranks': (1, *(6,)*15, 1), 'n_epoch': 5},
    'protein': {'m_order': 12, 'tt_ranks': (1, *(6,)*8, 1), 'n_epoch': 5},
}
# Model hyperparameters: same values as PARAMS except hess_th=1e-3 (PARAMS uses None).
params = {
    'fmap': PPNFeature(),
    'beta_e': 1e-2,
    'gamma_w': 1e-5,
    'pd_mode': 'la',
    'hess_th': 1e-3,
    'seed': 13,
    'n_epoch_vi': 1,
    'pd_samples': 30,
    'beta_e_samples': 10,
}
N_SHIFTS = 6                 # number of cyclic feature shifts (shift 0 = original order)
TQDM_DISABLE = False         # show per-core progress bars
F_IDX = [*range(N_SHIFTS)]   # shift offsets 0..N_SHIFTS-1; also the keys of the saved per-shift results

#### Model Training:

In [None]:
# For each dataset, cyclically shift the feature columns so each x_d lands on different
# TT-cores. For every shift, fit one model per core (via get_res_d_core) and save all
# results to a single CSV keyed by shift offset.
for data_name in DATA_LIST:
    print(data_name)
    # Load the split once so every shift is evaluated on the same train/test data.
    x_train, x_test, y_train, y_test = load_transform_data(
        f'../../data/{data_name}.csv', 0.2, 13, True, False, 'std')
    x_train, x_test, y_train, y_test = map(
        jnp.array, [x_train, x_test, y_train, y_test])
    n_features = x_train.shape[1]

    perm_res = list()
    for shift_f in range(N_SHIFTS):
        # Roll over ALL feature columns, not just the first N_SHIFTS, so no feature is dropped.
        feat_order = np.roll(np.arange(n_features), shift_f)
        perm_res.append(
            get_res_d_core(
                data_name, data2params, LaplaceTTKM, params,
                x_train[:, feat_order], x_test[:, feat_order], y_train, y_test,
                TQDM_DISABLE,
            )
        )

    res_df = pd.concat([pd.DataFrame(res) for res in perm_res], keys=F_IDX)
    res_df.to_csv(SAVE_DIR / f'{data_name}_f_permutations.csv')

#### Resulting Table:

In [None]:
# Summary table: for each dataset and feature shift, the 1-based TT-core with the
# lowest test NLL, written as "best_core/n_cores".
shift_names = ['No Shift'] + [f'Shift {i + 1}' for i in range(N_SHIFTS - 1)]
# object dtype: the cells hold strings. Rows follow N_SHIFTS instead of a hardcoded 6.
dd = pd.DataFrame(index=shift_names, columns=DATA_LIST, dtype=object)

for data_name in DATA_LIST:
    res_df = pd.read_csv(SAVE_DIR / f'{data_name}_f_permutations.csv', index_col=[0, 1])
    for shift_f, shift_name in enumerate(shift_names):
        shift_df = res_df.loc[shift_f]
        # Stable sort so the first row wins on ties; NaN NLLs sort last.
        d_core = int(shift_df.sort_values(by='nll_test', kind='stable')['d_core'].iloc[0]) + 1
        # Assign by row/column label on `dd` itself. `dd[col][int] = ...` is chained and relies
        # on a deprecated positional fallback.
        dd.loc[shift_name, data_name] = f'{d_core}/{len(shift_df)}'

dd.columns = [c.capitalize() for c in dd.columns]
# Only the first four shifts are reported (matches the four data columns in column_format).
dd = dd.T.sort_index()[['No Shift', 'Shift 1', 'Shift 2', 'Shift 3']]
dd.to_csv(SAVE_DIR / 'bayes_tt_features.csv')

caption = "Which feature $x_d$ should be Bayesian?"
print(
    dd.to_latex(
        column_format='||l||c|c|c|c||',
        caption=caption,
        label="table:bayes_tt_features",
        multirow=False,
        index=True,
    )
)