In [1]:
import joblib
import numpy as np
from .plot_utils import plot_config_results
from .test_utils import (
    test_skimmed_mclatte,
    test_semi_skimmed_mclatte,
    test_mclatte,
    test_rnn,
    test_synctwin,
)
from mclatte.synctwin import io_utils

## Data Generation

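Constants used for data generation. The dimension constants are also collected in the `constants` dict that is passed to the model test helpers.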

In [None]:
# Dataset identifiers. `seed` is interpolated into the data directory name
# inside generate_data; `sim_id` and `model_id` are not referenced by the
# cells visible here.
sim_id, seed, model_id = "0.25_200", 509, ""

# Dimension constants shared by all models; they are exposed both as
# module-level names (returned in generate_data's raw tuples) and as the
# lower-cased `constants` dict handed to the test_* helpers.
M, H, R, D, K, C = 5, 5, 5, 3, 1, 3
constants = {"m": M, "h": H, "r": R, "d": D, "k": K, "c": C}

In [None]:
def _load_pkpd_fold(data_path, fold):
    """Load one fold through io_utils and return (x, t, mask, y) as NumPy arrays."""
    x, t, mask, _, y, _, _, _, _, _ = io_utils.load_tensor(
        data_path, fold, device="cpu"
    )
    return tuple(tensor.cpu().numpy() for tensor in (x, t, mask, y))


def generate_data(p_0, N, return_raw=True):
    """Load a stored PKPD simulation, shuffle its units and split them 80/20.

    The "train" and "val" folds found under
    ``data/pkpd/{p_0}_{N}-seed-{seed}`` are concatenated along axis 1,
    transposed so that the unit axis comes first, shuffled with
    ``np.random.permutation`` (NumPy's global RNG, which is not seeded in this
    function) and split into a training part (first 80%) and a test part.

    Reads the module-level names ``seed``, ``M``, ``H``, ``R``, ``D``, ``K``
    and ``C``.

    Args:
        p_0: First component of the dataset directory name (e.g. ``"0.25"``).
        N: Second component of the dataset directory name (e.g. ``200``). The
            number of units actually returned is taken from the loaded arrays,
            not from this argument.
        return_raw: If true, return two positional tuples; otherwise return
            the total unit count and two dicts.

    Returns:
        If ``return_raw`` is true, ``(train, test)`` where each element is
        ``(n, M, H, R, D, K, C, X, M_, Y_pre, Y_post, A, T)``.
        Otherwise ``(n_total, train_data, test_data)`` where the dicts have
        keys ``n, x, m, y_pre, y_post, a, t``.

    Raises:
        ValueError: If the number of loaded units is not a multiple of 4, in
            which case the treatment indicator cannot be built (see below).
    """
    data_path = f"data/pkpd/{p_0}_{N}-seed-{seed}" + "/{}-{}.{}"

    # loading config and data (the return value of load_config is unused)
    io_utils.load_config(data_path, "train")
    train_fold = _load_pkpd_fold(data_path, "train")
    val_fold = _load_pkpd_fold(data_path, "val")
    x, t, mask, y = (
        np.concatenate(pair, axis=1) for pair in zip(train_fold, val_fold)
    )
    y = y.squeeze()

    X = x.transpose((1, 0, 2))
    n_total = X.shape[0]

    # The treatment indicator is laid out as four equal blocks
    # (0s, 1s, 0s, 1s). Presumably each stored fold lists its control units
    # first and its treated units second, in equal numbers -- TODO confirm
    # against the simulation code. Without divisibility by 4 the indicator is
    # shorter than the data and the shuffle below would fail with an opaque
    # IndexError, so fail early with a clear message instead.
    if n_total % 4 != 0:
        raise ValueError(
            f"expected the number of units to be a multiple of 4, got {n_total}"
        )
    quarter = n_total // 4

    rand_index = np.random.permutation(n_total)
    X = X[rand_index]
    M_ = mask.transpose((1, 0, 2))[rand_index]
    # NOTE(review): Y_pre and Y_post are both taken from the same loaded
    # outcome array, exactly as in the original code -- confirm this is
    # intended.
    Y_pre = y.T[rand_index]
    Y_post = y.T[rand_index]
    A = np.concatenate(
        (
            np.zeros((quarter, 1)),
            np.ones((quarter, 1)),
            np.zeros((quarter, 1)),
            np.ones((quarter, 1)),
        ),
        axis=0,
    )[rand_index]
    T = t.transpose((1, 0, 2))[rand_index]

    N_train = round(n_total * 0.8)
    # Derive the test size from the split point so that it always equals the
    # length of the X[N_train:] slices below.
    N_test = n_total - N_train
    X_train, X_test = X[:N_train], X[N_train:]
    M_train, M_test = M_[:N_train], M_[N_train:]
    Y_pre_train, Y_pre_test = Y_pre[:N_train], Y_pre[N_train:]
    Y_post_train, Y_post_test = Y_post[:N_train], Y_post[N_train:]
    A_train, A_test = A[:N_train], A[N_train:]
    T_train, T_test = T[:N_train], T[N_train:]

    if return_raw:
        return (
            (
                N_train,
                M,
                H,
                R,
                D,
                K,
                C,
                X_train,
                M_train,
                Y_pre_train,
                Y_post_train,
                A_train,
                T_train,
            ),
            (
                N_test,
                M,
                H,
                R,
                D,
                K,
                C,
                X_test,
                M_test,
                Y_pre_test,
                Y_post_test,
                A_test,
                T_test,
            ),
        )

    train_data = dict(
        n=N_train,
        x=X_train,
        m=M_train,
        y_pre=Y_pre_train,
        y_post=Y_post_train,
        a=A_train,
        t=T_train,
    )
    test_data = dict(
        n=N_test,
        x=X_test,
        m=M_test,
        y_pre=Y_pre_test,
        y_post=Y_post_test,
        a=A_test,
        t=T_test,
    )
    return n_total, train_data, test_data

## Modelling

### McLatte

#### Vanilla

In [None]:
# Hyperparameters for vanilla McLatte.
# Presumably copied from the row with the lowest `valid_loss` in
# results_pkpd/mclatte_hp_pkpd.csv (the original cell carried a commented-out
# print of that row) -- verify before reuse.
mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=64,
    batch_size=64,
    epochs=100,
    lr=0.001944,
    gamma=0.957115,
    lambda_r=0.311437,
    lambda_d=0.118073,
    lambda_p=0.49999,
)

#### Semi-Skimmed

In [None]:
# Hyperparameters for semi-skimmed McLatte.
# Presumably copied from the row with the lowest `valid_loss` in
# results_pkpd/semi_skimmed_mclatte_hp_pkpd.csv (the original cell carried a
# commented-out print of that row) -- verify before reuse.
# NOTE(review): every value here equals the vanilla McLatte configuration;
# confirm this is the search result and not a copy-paste leftover.
semi_skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=64,
    batch_size=64,
    epochs=100,
    lr=0.001944,
    gamma=0.957115,
    lambda_r=0.311437,
    lambda_d=0.118073,
    lambda_p=0.49999,
)

#### Skimmed

In [None]:
# Hyperparameters for skimmed McLatte (note: no `lambda_d` entry, unlike the
# other two McLatte configurations).
# Presumably copied from the row with the lowest `valid_loss` in
# results_pkpd/skimmed_mclatte_hp_pkpd.csv (the original cell carried a
# commented-out print of that row) -- verify before reuse.
skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=64,
    batch_size=64,
    epochs=100,
    lr=0.021114,
    gamma=0.980614,
    lambda_r=0.093878,
    lambda_p=0.485204,
)

### Baseline RNN

In [None]:
# Hyperparameters for the baseline RNN.
# Presumably copied from the row with the lowest `valid_loss` in
# results/baseline_rnn_hp_pkpd.csv (the original cell carried a commented-out
# print of that row) -- verify before reuse.
rnn_config = dict(
    rnn_class="gru",
    hidden_dim=64,
    seq_len=2,
    batch_size=64,
    epochs=100,
    lr=0.025182,
    gamma=0.543008,
)

### SyncTwin

In [None]:
# Hyperparameters for SyncTwin.
# Presumably copied from the row with the lowest `valid_loss` in
# results/synctwin_hp_pkpd.csv (the original cell carried a commented-out
# print of that row) -- verify before reuse.
synctwin_config = dict(
    hidden_dim=128,
    reg_B=0.778155,
    lam_express=0.658256,
    lam_recon=0.086627,
    lam_prognostic=0.631468,
    tau=0.911613,
    batch_size=32,
    epochs=100,
    lr=0.003222,
    gamma=0.572529,
)

## Test Models

In [None]:
# Number of train/evaluate runs performed for each entry of TEST_CONFIGS;
# it also determines the contiguous range of run indices used per config.
N_TEST = 5

In [5]:
# Grid of datasets to evaluate: every p_0 value (kept as a string because it
# is interpolated into the data directory name) for each dataset size.
# Entries are [p_0, N] pairs, ordered by size first and p_0 second.
TEST_CONFIGS = [
    [p_0, n_units]
    for n_units in (200, 1000)
    for p_0 in ("0.1", "0.25", "0.5")
]

In [None]:
# Train and evaluate every model N_TEST times per dataset configuration.
#
# NOTE(review): `range(0)` is empty, so this cell is a no-op as written.
# Presumably it was disabled on purpose because results were already saved;
# iterate over range(len(TEST_CONFIGS)) to re-run the experiments.
# NOTE(review): results are dumped under results/test/, whereas the
# statistical-testing cell loads from results_pkpd/maes/ -- confirm the files
# are moved or the paths are aligned.
for config_idx in range(0):
    config = TEST_CONFIGS[config_idx]
    (
        mclatte_losses,
        semi_skimmed_mclatte_losses,
        skimmed_mclatte_losses,
        rnn_losses,
        synctwin_losses,
    ) = ([], [], [], [], [])

    # Run indices are globally unique: config 0 uses 1..N_TEST, config 1 uses
    # N_TEST+1..2*N_TEST, and so on.
    first_run = N_TEST * config_idx + 1
    for run_idx in range(first_run, first_run + N_TEST):
        p_0, n_units = config[0], config[1]
        _, train_data, test_data = generate_data(p_0, n_units, return_raw=False)

        # The three McLatte variants, from the most reduced to the full model.
        skimmed_loss = test_skimmed_mclatte(
            skimmed_mclatte_config, constants, train_data, test_data, run_idx=run_idx
        )
        skimmed_mclatte_losses.append(skimmed_loss)

        semi_skimmed_loss = test_semi_skimmed_mclatte(
            semi_skimmed_mclatte_config,
            constants,
            train_data,
            test_data,
            run_idx=run_idx,
        )
        semi_skimmed_mclatte_losses.append(semi_skimmed_loss)

        mclatte_loss = test_mclatte(
            mclatte_config, constants, train_data, test_data, run_idx=run_idx
        )
        mclatte_losses.append(mclatte_loss)

        # Baselines (the RNN helper takes no `constants` argument).
        rnn_loss = test_rnn(rnn_config, train_data, test_data, run_idx=run_idx)
        rnn_losses.append(rnn_loss)

        synctwin_loss = test_synctwin(
            synctwin_config, constants, train_data, test_data, run_idx=run_idx
        )
        synctwin_losses.append(synctwin_loss)

        # Rewritten after every run, so the file always holds the losses
        # collected so far for this configuration.
        results_so_far = (
            config,
            mclatte_losses,
            semi_skimmed_mclatte_losses,
            skimmed_mclatte_losses,
            rnn_losses,
            synctwin_losses,
        )
        joblib.dump(results_so_far, f"results/test/config_{config_idx}_pkpd.joblib")

### Statistical Testing

In [2]:
# Display names of the compared models, in the same order as the loss lists
# stored in each saved results tuple (after the leading config entry).
LOSS_NAMES = [
    "McLatte",
    "Semi-Skimmed McLatte",
    "Skimmed McLatte",
    "RNN",
    "SyncTwin",
]

In [12]:
# Pool the saved losses of every configuration into one flat list per model;
# all_losses[i] belongs to LOSS_NAMES[i].
# (The original cell also carried a disabled per-config
# `print(test_losses(losses))` call; `test_losses` is not imported in the
# cells visible here.)
all_losses = [[] for _ in LOSS_NAMES]
for config_id in range(len(TEST_CONFIGS)):
    # The first element of each saved tuple is the config itself; the rest are
    # the per-model loss lists.
    _, *losses = joblib.load(f"results_pkpd/maes/config_{config_id}_pkpd.joblib")
    for i in range(len(LOSS_NAMES)):
        all_losses[i].extend(losses[i])

McLatte & .5000 & .3377 & .0003 & .0001 & .0000 \\
Semi-Skimmed McLatte & .6623 & .5000 & .0003 & .0001 & .0000 \\
Skimmed McLatte & .9997 & .9997 & .5000 & .0019 & .0000 \\
RNN & .9999 & .9999 & .9981 & .5000 & .1806 \\
SyncTwin & .0000 & .0000 & .0000 & .8194 & .5000 \\


### Plot with trained models

In [None]:
# Produce the result plot for every dataset configuration; generate_data is
# passed as a callback so the plotting helper can load the data itself.
for config_id, _config in enumerate(TEST_CONFIGS):
    plot_config_results("pkpd", TEST_CONFIGS, generate_data, config_id)