In [None]:
import os

import joblib

from plot_utils import plot_config_results
from test_utils import (
    test_skimmed_mclatte,
    test_semi_skimmed_mclatte,
    test_mclatte,
    test_rnn,
    test_synctwin,
    test_losses,
)
from mclatte.test_data.pkpd import generate_data, PkpdDataGenConfig

## Data Generation

Constants used for data generation; the same `constants` dictionary is also passed to the model test helpers below.

In [None]:
# Identifiers and the random seed handed to PkpdDataGenConfig.
sim_id, model_id = "0.25_200", ""
seed = 509

# Values bundled into `constants` below, which is forwarded as keyword
# arguments to PkpdDataGenConfig and passed to the McLatte / SyncTwin test helpers.
M, H, R = 5, 5, 5
D, K, C = 3, 1, 3
constants = {"m": M, "h": H, "r": R, "d": D, "k": K, "c": C}

## Modelling

### McLatte

#### Vanilla

In [None]:
# Hyperparameters for the vanilla McLatte model (consumed by test_mclatte).
# Presumably the lowest-valid_loss row of the hyperparameter search; to inspect it:
#   pd.read_csv(os.path.join(os.getcwd(), 'results_pkpd/mclatte_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0]
mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=64,
    batch_size=64,
    epochs=100,
    lr=0.001944,
    gamma=0.957115,
    lambda_r=0.311437,
    lambda_d=0.118073,
    lambda_p=0.49999,
)

#### Semi-Skimmed

In [None]:
# Hyperparameters for the semi-skimmed McLatte model (consumed by test_semi_skimmed_mclatte).
# Presumably the lowest-valid_loss row of the hyperparameter search; to inspect it:
#   pd.read_csv(os.path.join(os.getcwd(), 'results_pkpd/semi_skimmed_mclatte_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0]
# NOTE(review): every value here equals the one in mclatte_config - confirm this is intended.
semi_skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=64,
    batch_size=64,
    epochs=100,
    lr=0.001944,
    gamma=0.957115,
    lambda_r=0.311437,
    lambda_d=0.118073,
    lambda_p=0.49999,
)

#### Skimmed

In [None]:
# Hyperparameters for the skimmed McLatte model (consumed by test_skimmed_mclatte).
# Unlike the other two McLatte configs, this one has no "lambda_d" entry.
# Presumably the lowest-valid_loss row of the hyperparameter search; to inspect it:
#   pd.read_csv(os.path.join(os.getcwd(), 'results_pkpd/skimmed_mclatte_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0]
skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=64,
    batch_size=64,
    epochs=100,
    lr=0.021114,
    gamma=0.980614,
    lambda_r=0.093878,
    lambda_p=0.485204,
)

### Baseline RNN

In [None]:
# Hyperparameters for the baseline RNN (consumed by test_rnn).
# Presumably the lowest-valid_loss row of the hyperparameter search; to inspect it:
#   pd.read_csv(os.path.join(os.getcwd(), 'results/baseline_rnn_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0]
rnn_config = dict(
    rnn_class="gru",
    hidden_dim=64,
    seq_len=2,
    batch_size=64,
    epochs=100,
    lr=0.025182,
    gamma=0.543008,
)

### SyncTwin

In [None]:
# Hyperparameters for SyncTwin (consumed by test_synctwin).
# Presumably the lowest-valid_loss row of the hyperparameter search; to inspect it:
#   pd.read_csv(os.path.join(os.getcwd(), 'results/synctwin_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0]
synctwin_config = dict(
    hidden_dim=128,
    reg_B=0.778155,
    lam_express=0.658256,
    lam_recon=0.086627,
    lam_prognostic=0.631468,
    tau=0.911613,
    batch_size=32,
    epochs=100,
    lr=0.003222,
    gamma=0.572529,
)

## Test Models

In [42]:
# Number of generate-data / train / test repetitions run_tests performs per entry of TEST_CONFIGS.
N_TEST: int = 1

In [43]:
# Each entry is [p_0, n]: entry[0] is passed to PkpdDataGenConfig as `p_0`
# and entry[1] as `n`. The position of an entry in this list is used as the
# config index in the results file name, so reordering entries changes which
# file each configuration maps to. Commented-out entries are currently disabled.
TEST_CONFIGS = [
    ["0.1", 200],
    # ["0.25", 200],
    # ["0.5", 200],
    # ["0.1", 1000],
    # ["0.25", 1000],
    # ["0.5", 1000],
]

In [44]:
def run_tests():
    """Train and evaluate all five models for every entry of ``TEST_CONFIGS``.

    For each configuration a ``PkpdDataGenConfig`` is built from the entry
    (``config[0]`` is passed as ``p_0`` and ``config[1]`` as ``n``), the
    module-level ``seed`` and ``constants``. Data is then generated ``N_TEST``
    times, and on each data set the skimmed, semi-skimmed and vanilla McLatte
    models, the baseline RNN and SyncTwin are tested with their module-level
    config dictionaries.

    Side effects:
        After every run, the config and all losses collected so far for that
        configuration are written to
        ``results/test/config_{config_idx}_pkpd.joblib`` (overwriting the
        previous dump for the same configuration). The ``results/test``
        directory is created if it does not exist.

    Returns:
        None.
    """
    # joblib.dump does not create missing parent directories. Create the
    # output directory before any training starts so a fresh checkout does not
    # fail with FileNotFoundError only after a full round of model training.
    os.makedirs("results/test", exist_ok=True)

    for config_idx, config in enumerate(TEST_CONFIGS):
        data_gen_config = PkpdDataGenConfig(
            n=config[1],
            p_0=config[0],
            seed=seed,
            **constants
        )
        mclatte_losses = []
        semi_skimmed_mclatte_losses = []
        skimmed_mclatte_losses = []
        rnn_losses = []
        synctwin_losses = []

        # Run indices are 1-based and do not overlap between configurations:
        # configuration k uses N_TEST * k + 1 .. N_TEST * (k + 1).
        first_run = N_TEST * config_idx + 1
        for i in range(first_run, first_run + N_TEST):
            # NOTE(review): data_gen_config (including the seed) is identical
            # for every repetition of a configuration - confirm generate_data
            # yields the intended variation between runs.
            _, train_data, test_data = generate_data(data_gen_config, return_raw=False)

            skimmed_mclatte_losses.append(
                test_skimmed_mclatte(
                    skimmed_mclatte_config,
                    constants,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )
            semi_skimmed_mclatte_losses.append(
                test_semi_skimmed_mclatte(
                    semi_skimmed_mclatte_config,
                    constants,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )
            mclatte_losses.append(
                test_mclatte(
                    mclatte_config,
                    constants,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )

            rnn_losses.append(
                test_rnn(
                    rnn_config,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )

            synctwin_losses.append(
                test_synctwin(
                    synctwin_config,
                    constants,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )

            # Dump after every run so results collected so far are on disk.
            # The order of the loss lists after `config` is what the
            # statistical-testing cell relies on when pairing with LOSS_NAMES.
            joblib.dump(
                (
                    config,
                    mclatte_losses,
                    semi_skimmed_mclatte_losses,
                    skimmed_mclatte_losses,
                    rnn_losses,
                    synctwin_losses,
                ),
                f"results/test/config_{config_idx}_pkpd.joblib",
            )

In [45]:
# Train and evaluate every model for each entry of TEST_CONFIGS; results are
# written to results/test/config_{config_idx}_pkpd.joblib.
run_tests()

### Statistical Testing

In [46]:
# Display names for the models. The order must match the order of the loss
# lists in the tuple that run_tests dumps (after the leading config element):
# McLatte, semi-skimmed McLatte, skimmed McLatte, RNN, SyncTwin.
LOSS_NAMES = ["McLatte", "Semi-Skimmed McLatte", "Skimmed McLatte", "RNN", "SyncTwin"]

In [None]:
# Pool each model's test losses across all configurations, then hand the
# pooled lists (ordered like LOSS_NAMES) to test_losses.
all_losses = [[] for _ in LOSS_NAMES]
for config_id in range(len(TEST_CONFIGS)):
    # The first element of the saved tuple is the config; the rest are the
    # per-model loss lists.
    _, *losses = joblib.load(f"results/test/config_{config_id}_pkpd.joblib")
    for model_idx, pooled in enumerate(all_losses):
        pooled.extend(losses[model_idx])
test_losses(all_losses, LOSS_NAMES)

### Plot with trained models

In [None]:
# Call plot_config_results once per configuration, rebuilding the data
# generation config the same way run_tests does ([p_0, n] per entry).
for config_idx, config in enumerate(TEST_CONFIGS):
    p_0, n = config[0], config[1]
    data_gen_config = PkpdDataGenConfig(n=n, p_0=p_0, seed=seed, **constants)
    plot_config_results("pkpd", generate_data, config_idx, data_gen_config)