In [None]:
import datetime
import joblib
import numpy as np
import os
import pandas as pd
import random
import torch
import wandb
from .test_utils import (
    test_skimmed_mclatte,
    test_semi_skimmed_mclatte,
    test_mclatte,
    test_rnn,
    test_losses,
)
from sklearn.preprocessing import scale

In [None]:
# Seed the global random number generators of the standard library, NumPy
# and PyTorch with the same fixed value before anything else runs.
random.seed(509)
np.random.seed(509)
torch.manual_seed(509)

## Data Preparation

In [None]:
# Number of subject files read by load_data (subject indices 1..N_SUBJECTS).
N_SUBJECTS = 70
# Width of Y_pre: outcome steps kept immediately before the held-out horizon.
M = 5
# Width of Y_post: trailing outcome steps held out per subject.
H = 5
# The feature/mask window built in load_data spans M * R timestamps.
R = 5

### Data Codes

In [None]:
# Raw event code -> name of the column that event is written to.
# Note that codes 48 and 57 both map to "unspecified_bg", so
# list(DATA_CODES.values()) contains that name twice.
DATA_CODES = {
    # treatment
    33: "reg_insulin",
    34: "nph_insulin",
    35: "ult_insulin",
    # outcome
    48: "unspecified_bg",
    57: "unspecified_bg",
    58: "pre_breakfast_bg",
    59: "post_breakfast_bg",
    60: "pre_lunch_bg",
    61: "post_lunch_bg",
    62: "pre_supper_bg",
    63: "post_supper_bg",
    64: "pre_snack_bg",
    # covariate
    65: "hypo_symptoms",
    66: "typical_meal",
    67: "more_meal",
    68: "less_meal",
    69: "typical_exercise",
    70: "more_exercise",
    71: "less_exercise",
    72: "unspecified_event",
}
# Treatment columns, in code order.
TREATMENT_COLS = [DATA_CODES[code] for code in (33, 34, 35)]
# Outcome columns, in code order; dict.fromkeys drops the repeated
# "unspecified_bg" while keeping first-seen order.
OUTCOME_COLS = list(
    dict.fromkeys(DATA_CODES[code] for code in range(48, 65) if code in DATA_CODES)
)

### Load Data

In [None]:
def float_or_na(v):
    """Convert a raw reading to a number.

    Returns ``float(v)`` when that conversion succeeds. Otherwise the
    markers ``"0Hi"`` and ``"0Lo"`` are translated to ``1`` and ``-1``; any
    other value has its conversion error printed and yields ``np.nan``.
    """
    try:
        number = float(v)
    except Exception as err:
        # Two non-numeric markers get a fixed numeric stand-in.
        for marker, replacement in (("0Hi", 1), ("0Lo", -1)):
            if v == marker:
                return replacement
        print(err)
        return np.nan
    return number

In [None]:
def combine(values):
    """Return the median of the non-missing entries of ``values``.

    ``values`` must support boolean-mask indexing (e.g. a pandas Series or a
    NumPy array). When every entry is missing, or there are no entries at
    all, ``np.nan`` is returned.
    """
    observed = values[pd.notna(values)]
    return np.median(observed) if observed.shape[0] != 0 else np.nan

In [None]:
def try_to_date(v):
    """Parse ``v`` as a ``%m-%d-%Y`` date, with one repair attempt.

    If the first parse fails, the error is printed and the character at
    index 4 (the last digit of the day) is replaced by ``"0"`` before parsing
    again, which turns e.g. day 31 into day 30. If that fails too, the second
    error is printed and ``np.nan`` is returned.
    """
    date_format = "%m-%d-%Y"
    try:
        return datetime.datetime.strptime(v, date_format)
    except Exception as first_error:
        print(f"{first_error}: {v}")

    candidate = v
    try:
        # handle date mis-input (e.g. 6-31)
        candidate = v[:4] + "0" + v[5:]
        return datetime.datetime.strptime(candidate, date_format)
    except Exception as second_error:
        print(f"{second_error}: {candidate}")
        return np.nan

In [None]:
def try_to_time(v):
    """Parse ``v`` as an ``%H:%M`` time of day.

    Returns a ``datetime.time``; on any parsing error the error is printed
    together with the offending value and ``np.nan`` is returned.
    """
    try:
        parsed = datetime.datetime.strptime(v, "%H:%M")
    except Exception as err:
        print(f"{err}: {v}")
        return np.nan
    return parsed.time()

In [None]:
def try_to_combine(date, time):
    """Merge a parsed date and a parsed time into one ``datetime``.

    When the two cannot be combined the error is printed; the function then
    falls back to ``date`` alone if it is a ``datetime.datetime`` instance,
    and to ``np.nan`` otherwise.
    """
    try:
        return datetime.datetime.combine(date, time)
    except Exception as err:
        print(f"{err}: {date} {time}")
    # Keep the day even when the time of day is unusable.
    return date if isinstance(date, datetime.datetime) else np.nan

In [None]:
def load_subject_i(subject_idx):
    """Read one subject's raw event log and pivot it to one row per timestamp.

    The tab-separated file ``data/diabetes/data-<subject_idx:02d>`` (relative
    to the current working directory) is read as (date, time, code, value)
    events. Events that share a timestamp are merged into a single row whose
    columns are the names in ``DATA_CODES``; events whose code is not in
    ``DATA_CODES`` still contribute their timestamp but no value.

    Args:
        subject_idx: Integer subject number used to build the file name.

    Returns:
        A 4-tuple ``(features, mask, outcomes, treatment)``:

        * ``features`` - 2-D array with one row per distinct valid timestamp
          and the columns selected by ``TREATMENT_COLS + OUTCOME_COLS``;
          missing cells are replaced by 0.
        * ``mask`` - boolean array of the same shape, True where a value was
          actually observed.
        * ``outcomes`` - 1-D array with, for every timestamp that has at least
          one outcome reading, the median of those readings.
        * ``treatment`` - 1-D array with the median of each treatment column
          over all timestamps (0 for a column that was never observed).
    """
    raw_df = pd.read_csv(
        os.path.join(os.getcwd(), f"data/diabetes/data-{subject_idx:02d}"),
        sep="\t",
        names=["date", "time", "code", "value"],
    )
    raw_df["date"] = raw_df["date"].apply(try_to_date)
    raw_df["time"] = raw_df["time"].apply(try_to_time)
    raw_df["datetime"] = raw_df.apply(
        lambda row: try_to_combine(row["date"], row["time"]), axis=1
    )
    raw_df = raw_df.drop(columns=["date", "time"]).sort_values(by=["datetime"])

    # Pull the columns out once; building a row Series per event via
    # raw_df.iloc[idx] is needlessly slow.
    all_datetimes = raw_df["datetime"].values
    codes = raw_df["code"].values
    values = raw_df["value"].values
    n_events = raw_df.shape[0]

    # Column name -> every position carrying that name. Codes 48 and 57 share
    # "unspecified_bg", so that name occupies two positions and both receive
    # the value (this matches what label-based assignment on a row did).
    # NOTE(review): the duplicated column is kept so that the feature layout
    # seen downstream does not change - confirm whether it is intended.
    columns = list(DATA_CODES.values())
    column_positions = {}
    for position, name in enumerate(columns):
        column_positions.setdefault(name, []).append(position)

    # One row per distinct *valid* timestamp. nunique() ignores missing
    # timestamps, whereas len(set(ndarray)) counts every NaT separately and
    # would append spurious empty rows at the end of the sequence.
    n_timestamps = raw_df["datetime"].nunique()

    # Fill a plain object array by position instead of using chained
    # indexing (frame.iloc[i][col] = v), which writes to a temporary and
    # leaves the frame untouched when pandas Copy-on-Write is active.
    cells = np.full((n_timestamps, len(columns)), np.nan, dtype=object)

    begin_idx = 0
    converted_idx = 0
    while begin_idx < n_events:
        # Skip events without a usable timestamp (sorted to the end).
        while begin_idx < n_events and pd.isna(all_datetimes[begin_idx]):
            begin_idx += 1
        if begin_idx >= n_events:
            break

        # Consume the run of events that share this timestamp.
        end_idx = begin_idx
        while (
            end_idx < n_events
            and all_datetimes[end_idx] == all_datetimes[begin_idx]
        ):
            code = codes[end_idx]
            if code in DATA_CODES:
                cells[converted_idx, column_positions[DATA_CODES[code]]] = (
                    float_or_na(values[end_idx])
                )
            end_idx += 1
        begin_idx = end_idx
        converted_idx += 1

    converted_df = pd.DataFrame(cells, columns=columns)

    outcomes = converted_df.apply(lambda row: combine(row[OUTCOME_COLS]), axis=1)
    treatment = converted_df[TREATMENT_COLS].apply(combine, axis=0)

    # Work on an explicit writable copy so the zero-fill never touches (or
    # fails on) a view of the frame.
    features = converted_df[TREATMENT_COLS + OUTCOME_COLS].to_numpy(copy=True)
    mask = pd.notna(features)
    features[~mask] = 0
    treatment[pd.isna(treatment)] = 0
    return (
        features,
        mask,
        outcomes[pd.notna(outcomes)].to_numpy(),
        treatment.to_numpy(),
    )

In [None]:
def _last_rows(arr, n_rows):
    """Return the final ``n_rows`` rows of ``arr``, zero-padded at the front if short."""
    missing = n_rows - arr.shape[0]
    if missing > 0:
        padding = np.zeros((missing,) + arr.shape[1:])
        return np.concatenate((padding, arr))
    return arr[arr.shape[0] - n_rows :]


def load_data():
    """Load every subject and assemble fixed-size, standardised arrays.

    For each of the ``N_SUBJECTS`` subjects the last ``H`` timestamps are
    dropped from the features and the final ``M * R`` remaining rows are kept
    (zero-padded at the front when a subject has fewer). The outcome series
    is split into the ``M`` steps before the last ``H`` and the last ``H``.

    Returns:
        ``(X, M_, Y_pre, Y_post, A, T)`` where, with N = N_SUBJECTS and D the
        number of feature columns:

        * ``X`` - (N, M * R, D) features, standardised per column.
        * ``M_`` - (N, M * R, D) observation mask for the same rows as ``X``.
        * ``Y_pre`` - (N, M) and ``Y_post`` - (N, H) outcomes, standardised
          jointly with a single mean / standard deviation.
        * ``A`` - (N, K) per-subject treatment summary, standardised per column.
        * ``T`` - (N, M * R, D) time index running from ``-M * R`` to ``-1``.

    Note: all standardisation statistics are computed over every subject,
    i.e. before generate_data() splits subjects into train and test.
    """
    # Initialisation
    X = []
    M_ = []
    Y_pre = []
    Y_post = []
    A = []
    window = M * R

    # Reading
    for subject_idx in range(1, N_SUBJECTS + 1):
        X_i, M_i, Y_i, A_i = load_subject_i(subject_idx)

        # Trim the last H timestamps from the features AND from their mask.
        # Previously only the features were trimmed, so the mask window was
        # shifted by H rows relative to the feature window.
        n_history = max(X_i.shape[0] - H, 0)
        X.append(_last_rows(X_i[:n_history], window))
        M_.append(_last_rows(M_i[:n_history], window))

        # Pad the outcome series to M + H entries first and split afterwards,
        # so that subjects with fewer than H outcomes still produce rows of
        # the right length instead of ragged ones.
        Y_window = _last_rows(Y_i, M + H)
        Y_pre.append(Y_window[:M])
        Y_post.append(Y_window[M:])

        A.append(A_i)

    # Aggregation
    X = np.stack(X)
    M_ = np.stack(M_)
    Y_pre = np.array(Y_pre)
    Y_post = np.array(Y_post)
    A = np.array(A)
    T = np.transpose(
        np.tile(np.arange(-M * R, 0), (N_SUBJECTS, X.shape[2], 1)), (0, 2, 1)
    )

    # Scaling
    X_to_scale = X.reshape((-1, X.shape[2]))  # (N, T, D) -> (N * T, D)
    X_scaled = scale(X_to_scale, axis=0)
    X = X_scaled.reshape(X.shape)

    Y = np.concatenate((Y_pre, Y_post), axis=1)
    Y_to_scale = Y.reshape((-1, 1))  # (N, M) + (N, H) -> (N * T, 1)
    Y_scaled = scale(Y_to_scale, axis=0)
    Y = Y_scaled.reshape(Y.shape)
    Y_pre, Y_post = Y[:, :M], Y[:, M:]

    A = scale(A, axis=0)  # [N, K]

    return X, M_, Y_pre, Y_post, A, T

In [None]:
def generate_and_write_data():
    """Run load_data() and persist its result as two joblib files.

    ``data/diabetes/processed.joblib`` holds the six arrays only;
    ``data/diabetes/hp_search.joblib`` holds the same arrays preceded by the
    scalar settings ``(N, M, H, R, D, K, C)``. Both paths are relative to the
    current working directory.
    """
    X, M_, Y_pre, Y_post, A, T = load_data()
    arrays = (X, M_, Y_pre, Y_post, A, T)
    processed_path = os.path.join(os.getcwd(), "data/diabetes/processed.joblib")
    joblib.dump(arrays, processed_path)

    # D and K are read off the arrays; C is a fixed setting whose meaning is
    # not evident from this notebook.
    settings = (N_SUBJECTS, M, H, R, X.shape[2], A.shape[1], 4)
    hp_search_path = os.path.join(os.getcwd(), "data/diabetes/hp_search.joblib")
    joblib.dump(settings + arrays, hp_search_path)

### Data Generation

In [None]:
def generate_data(return_raw=True):
    """Load the processed arrays and split them into train and test subjects.

    The first ``round(N_SUBJECTS * 0.8)`` subjects form the training set and
    the remaining ones the test set.

    Args:
        return_raw: When True (default) return one flat tuple
            ``(N_SUBJECTS, N_train, N_test, X_train, X_test, M_train, M_test,
            Y_pre_train, Y_pre_test, Y_post_train, Y_post_test, A_train,
            A_test, T_train, T_test)``. When False return
            ``(N_SUBJECTS, train_data, test_data)`` where each ``*_data`` is a
            dict with keys ``n, x, m, y_pre, y_post, a, t``.
    """
    N_train = round(N_SUBJECTS * 0.8)
    N_test = round(N_SUBJECTS * 0.2)
    X, M_, Y_pre, Y_post, A, T = joblib.load(
        os.path.join(os.getcwd(), "data/diabetes/processed.joblib")
    )

    arrays = (X, M_, Y_pre, Y_post, A, T)
    train_parts = [arr[:N_train] for arr in arrays]
    test_parts = [arr[N_train:] for arr in arrays]

    if return_raw:
        # Flat layout alternates train / test for each array in turn.
        interleaved = [
            part for pair in zip(train_parts, test_parts) for part in pair
        ]
        return (N_SUBJECTS, N_train, N_test, *interleaved)

    keys = ("x", "m", "y_pre", "y_post", "a", "t")
    train_data = {"n": N_train, **dict(zip(keys, train_parts))}
    test_data = {"n": N_test, **dict(zip(keys, test_parts))}
    return N_SUBJECTS, train_data, test_data

In [None]:
# Reload the bundle written by generate_and_write_data(). This rebinds the
# module-level M, H and R (among others), and requires the file to exist
# already.
N, M, H, R, D, K, C, X, M_, Y_pre, Y_post, A, T = joblib.load(
    os.path.join(os.getcwd(), "data/diabetes/hp_search.joblib")
)
# Scalar settings handed to the McLatte test helpers.
constants = {"m": M, "h": H, "r": R, "d": D, "k": K, "c": C}

## Modelling

In [None]:
# Initialise Weights & Biases logging; the project and entity are hard-coded.
wandb.init(project="mclatte-test", entity="jasonyz")

### McLatte

#### Vanilla

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [None]:
# Hyper-parameters passed to test_mclatte in run_tests.
mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=8,
    batch_size=64,
    epochs=100,
    lr=0.021089,
    gamma=0.541449,
    lambda_r=0.814086,
    lambda_d=0.185784,
    lambda_p=0.081336,
)

#### Semi-Skimmed

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/semi_skimmed_mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [None]:
# Hyper-parameters passed to test_semi_skimmed_mclatte in run_tests.
semi_skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=4,
    batch_size=64,
    epochs=100,
    lr=0.006606,
    gamma=0.860694,
    lambda_r=79.016676,
    lambda_d=1.2907,
    lambda_p=11.112241,
)

#### Skimmed

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/skimmed_mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [None]:
# Hyper-parameters passed to test_skimmed_mclatte in run_tests
# (this variant has no lambda_d entry).
skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=16,
    batch_size=64,
    epochs=100,
    lr=0.000928,
    gamma=0.728492,
    lambda_r=1.100493,
    lambda_p=2.108935,
)

### Baseline RNN

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/baseline_rnn_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [None]:
# Hyper-parameters passed to test_rnn in run_tests.
rnn_config = dict(
    rnn_class="gru",
    hidden_dim=64,
    seq_len=2,
    batch_size=64,
    epochs=100,
    lr=0.006321,
    gamma=0.543008,
)

### SyncTwin

In [None]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/synctwin_hp.csv')).sort_values(by='valid_loss').iloc[0])
# SyncTwin hyper-parameters. NOTE(review): not referenced by any other cell
# visible in this notebook - confirm whether a SyncTwin run is still expected.
synctwin_config = dict(
    hidden_dim=128,
    reg_B=0.522652,
    lam_express=0.163847,
    lam_recon=0.39882,
    lam_prognostic=0.837303,
    tau=0.813696,
    batch_size=32,
    epochs=100,
    lr=0.001476,
    gamma=0.912894,
)

## Test Models

In [None]:
# Number of repeated train/evaluate rounds performed by run_tests.
N_TEST = 5

In [None]:
def run_tests():
    """Train and evaluate every model ``N_TEST`` times, saving losses as it goes.

    In each round the data is reloaded with generate_data(), then the
    skimmed, semi-skimmed and vanilla McLatte models and the baseline RNN are
    evaluated in that order. After every round the accumulated losses are
    written to ``results/test/diabetes.joblib`` as the tuple
    ``(mclatte, semi_skimmed_mclatte, skimmed_mclatte, rnn)``, so a partially
    finished run still leaves results on disk.
    """
    mclatte_losses = []
    semi_skimmed_mclatte_losses = []
    skimmed_mclatte_losses = []
    rnn_losses = []
    # Same list objects every round; only their contents grow.
    all_losses = (
        mclatte_losses,
        semi_skimmed_mclatte_losses,
        skimmed_mclatte_losses,
        rnn_losses,
    )

    for run_idx in range(1, N_TEST + 1):
        _, train_data, test_data = generate_data(return_raw=False)

        skimmed_loss = test_skimmed_mclatte(
            skimmed_mclatte_config, constants, train_data, test_data, run_idx=run_idx
        )
        skimmed_mclatte_losses.append(skimmed_loss)

        semi_skimmed_loss = test_semi_skimmed_mclatte(
            semi_skimmed_mclatte_config,
            constants,
            train_data,
            test_data,
            run_idx=run_idx,
        )
        semi_skimmed_mclatte_losses.append(semi_skimmed_loss)

        mclatte_loss = test_mclatte(
            mclatte_config, constants, train_data, test_data, run_idx=run_idx
        )
        mclatte_losses.append(mclatte_loss)

        # The RNN baseline does not take the `constants` dict.
        rnn_loss = test_rnn(rnn_config, train_data, test_data, run_idx=run_idx)
        rnn_losses.append(rnn_loss)

        # Checkpoint after every round.
        joblib.dump(all_losses, "results/test/diabetes.joblib")

#### Check results of finished runs

In [None]:
def print_losses():
    """Print ``mean (std)`` with three decimals for each saved list of losses.

    The lists are read from ``results/test/diabetes.joblib`` and printed in
    the order in which they were stored.
    """
    for model_losses in joblib.load("results/test/diabetes.joblib"):
        mean_loss = np.mean(model_losses)
        std_loss = np.std(model_losses)
        print(f"{mean_loss:.3f} ({std_loss:.3f})")

### Statistical Testing

In [None]:
# Display names handed to test_losses together with the saved losses.
# NOTE(review): five names are listed, but run_tests in this notebook stores
# four loss lists (no SyncTwin) - confirm the saved file matches this list.
LOSS_NAMES = ["McLatte", "Semi-Skimmed McLatte", "Skimmed McLatte", "RNN", "SyncTwin"]

In [None]:
# Load the losses saved by run_tests and hand them to the statistical tests.
losses = joblib.load("results/test/diabetes.joblib")
test_losses(losses, LOSS_NAMES)