In [None]:
import joblib
import numpy as np
import os
import random
import torch
import wandb
from mclatte.test_data.diabetes import generate_data
from test_utils import (
    test_skimmed_mclatte,
    test_semi_skimmed_mclatte,
    test_mclatte,
    test_rnn,
    test_losses,
)

In [None]:
# Seed every RNG this notebook touches (Python, NumPy, PyTorch) from one constant so
# re-runs are reproducible — assuming generate_data and the test_* helpers draw from
# these global generators (TODO confirm).
SEED = 509
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the cell output

## Data Preparation

In [None]:
# Unpack the cached diabetes data (presumably saved during the hyperparameter search).
# In the cells shown here only the dimension constants are used further down.
(N, M, H, R, D, K, C,
 X, M_, Y_pre, Y_post, A, T) = joblib.load(
    os.path.join(os.getcwd(), "data/diabetes/hp_search.joblib")
)
# Dimension constants handed to every McLatte test helper in run_tests.
constants = {"m": M, "h": H, "r": R, "d": D, "k": K, "c": C}

## Modelling

In [None]:
# Start the Weights & Biases run for this evaluation. The entity defaults to the
# original author's account; set the WANDB_ENTITY environment variable to log elsewhere.
wandb.init(project="mclatte-test", entity=os.environ.get("WANDB_ENTITY", "jasonyz"))

### McLatte

#### Vanilla

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [None]:
# Vanilla McLatte hyperparameters — presumably the lowest-valid_loss row of
# results/mclatte_hp.csv (see the commented-out lookup in the previous cell).
mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=8,
    batch_size=64,
    epochs=100,
    lr=0.021089,
    gamma=0.541449,
    lambda_r=0.814086,
    lambda_d=0.185784,
    lambda_p=0.081336,
)

#### Semi-Skimmed

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/semi_skimmed_mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [None]:
# Semi-skimmed McLatte hyperparameters — presumably the lowest-valid_loss row of
# results/semi_skimmed_mclatte_hp.csv (see the commented-out lookup in the previous cell).
semi_skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=4,
    batch_size=64,
    epochs=100,
    lr=0.006606,
    gamma=0.860694,
    lambda_r=79.016676,
    lambda_d=1.2907,
    lambda_p=11.112241,
)

#### Skimmed

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/skimmed_mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [None]:
# Skimmed McLatte hyperparameters — presumably the lowest-valid_loss row of
# results/skimmed_mclatte_hp.csv. Unlike the other McLatte variants, there is no lambda_d here.
skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=16,
    batch_size=64,
    epochs=100,
    lr=0.000928,
    gamma=0.728492,
    lambda_r=1.100493,
    lambda_p=2.108935,
)

### Baseline RNN

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/baseline_rnn_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [None]:
# Baseline RNN hyperparameters — presumably the lowest-valid_loss row of
# results/baseline_rnn_hp.csv (see the commented-out lookup in the previous cell).
rnn_config = dict(
    rnn_class="gru",
    hidden_dim=64,
    seq_len=2,
    batch_size=64,
    epochs=100,
    lr=0.006321,
    gamma=0.543008,
)

### SyncTwin

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/synctwin_hp.csv')).sort_values(by='valid_loss').iloc[0])
# SyncTwin hyperparameters — presumably the lowest-valid_loss row of results/synctwin_hp.csv.
# NOTE(review): nothing in this notebook consumes this config; run_tests does not evaluate SyncTwin.
synctwin_config = dict(
    hidden_dim=128,
    reg_B=0.522652,
    lam_express=0.163847,
    lam_recon=0.39882,
    lam_prognostic=0.837303,
    tau=0.813696,
    batch_size=32,
    epochs=100,
    lr=0.001476,
    gamma=0.912894,
)

## Test Models

In [None]:
N_TEST: int = 5  # number of repetitions in run_tests, each on a freshly generated train/test split

In [None]:
def run_tests(n_runs=None, results_path="results/test/diabetes.joblib"):
    """Repeatedly train and evaluate every model on freshly generated diabetes data.

    Each run draws a new train/test split via ``generate_data`` and evaluates the
    skimmed, semi-skimmed and vanilla McLatte variants plus the baseline RNN on it.
    After every run, the losses accumulated so far are dumped to ``results_path``.
    An interrupted job therefore keeps its partial results.

    Args:
        n_runs: Number of repetitions. Defaults to the notebook-level ``N_TEST``.
            It is read at call time, so editing ``N_TEST`` takes effect without
            re-running this cell.
        results_path: File the loss tuple is written to with ``joblib.dump``. Its
            parent directory is created if it does not exist.

    Returns:
        The tuple ``(mclatte_losses, semi_skimmed_mclatte_losses,
        skimmed_mclatte_losses, rnn_losses)``. Each element is a list with one
        entry per run; this is the same object that is saved to disk.
    """
    if n_runs is None:
        n_runs = N_TEST

    # joblib.dump does not create missing directories. Without this, a fresh checkout
    # would crash only after the first (expensive) round of training.
    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    mclatte_losses = []
    semi_skimmed_mclatte_losses = []
    skimmed_mclatte_losses = []
    rnn_losses = []
    # Order matches LOSS_NAMES / print_losses: McLatte, semi-skimmed, skimmed, RNN.
    all_losses = (
        mclatte_losses,
        semi_skimmed_mclatte_losses,
        skimmed_mclatte_losses,
        rnn_losses,
    )

    for i in range(1, n_runs + 1):
        _, train_data, test_data = generate_data(return_raw=False)

        # The evaluation order below is unchanged on purpose: the helpers may consume
        # the global RNGs, so reordering could change results vs. earlier runs.
        skimmed_mclatte_losses.append(
            test_skimmed_mclatte(
                skimmed_mclatte_config,
                constants,
                train_data,
                test_data,
                run_idx=i,
            )
        )
        semi_skimmed_mclatte_losses.append(
            test_semi_skimmed_mclatte(
                semi_skimmed_mclatte_config,
                constants,
                train_data,
                test_data,
                run_idx=i,
            )
        )
        mclatte_losses.append(
            test_mclatte(
                mclatte_config,
                constants,
                train_data,
                test_data,
                run_idx=i,
            )
        )
        rnn_losses.append(
            test_rnn(
                rnn_config,
                train_data,
                test_data,
                run_idx=i,
            )
        )

        # Checkpoint after every run (the lists are shared, so this saves all results so far).
        joblib.dump(all_losses, results_path)

    return all_losses

In [None]:
# Runs every model N_TEST times, each time on freshly generated data. Losses are dumped to
# results/test/diabetes.joblib after each run, so an interrupted job keeps partial results.
run_tests()

#### Check results of finished runs

In [None]:
def print_losses(
    results_path="results/test/diabetes.joblib",
    names=("McLatte", "Semi-Skimmed McLatte", "Skimmed McLatte", "RNN"),
):
    """Print ``mean (std)`` of each model's per-run test loss, one labelled line per model.

    Args:
        results_path: joblib file written by ``run_tests``. It holds one sequence of
            per-run losses per model.
        names: Labels for the saved loss groups, in the order ``run_tests`` dumps them.
            Pass ``None`` to label them generically by position.

    Raises:
        ValueError: If ``names`` is given and its length differs from the number of
            saved loss groups. Without this check, labels would be attached to the
            wrong models.
    """
    all_losses = joblib.load(results_path)
    if names is None:
        names = [f"model {idx}" for idx in range(len(all_losses))]
    elif len(names) != len(all_losses):
        raise ValueError(
            f"got {len(names)} names for {len(all_losses)} saved loss groups"
        )

    width = max((len(name) for name in names), default=0)
    for name, losses in zip(names, all_losses):
        # np.std uses ddof=0 (population std), as in the original output.
        print(f"{name:<{width}}  {np.mean(losses):.3f} ({np.std(losses):.3f})")

In [None]:
# Summarise results/test/diabetes.joblib: mean (std) of the per-run test losses,
# one line per saved model, in the order run_tests dumps them.
print_losses()

### Statistical Testing

In [None]:
# Display names for the loss groups saved by run_tests, in the same order as the dumped tuple.
# SyncTwin is not evaluated in this notebook, so listing it would leave test_losses with
# one more name than there are loss groups.
LOSS_NAMES = ["McLatte", "Semi-Skimmed McLatte", "Skimmed McLatte", "RNN"]

In [None]:
# Reload the per-run losses checkpointed by run_tests and hand them, with their
# display names, to test_losses for the statistical comparison.
losses = joblib.load("results/test/diabetes.joblib")
test_losses(losses, LOSS_NAMES)