In [1]:
import datetime
import joblib
import numpy as np
import os
import pandas as pd
import plotly.graph_objects as go
import random
import torch
import wandb
from mclatte.model import (
    train_mclatte, 
    train_semi_skimmed_mclatte, 
    train_skimmed_mclatte, 
    McLatte,
    SemiSkimmedMcLatte,
    SkimmedMcLatte,
)
from rnn.model import (
    train_baseline_rnn,
    BaselineRnn,
)
from scipy.stats import ttest_ind
from sklearn.preprocessing import scale
from synctwin.model import (
    train_synctwin,
    SyncTwinPl,
)

In [2]:
# Seed every RNG the notebook touches (stdlib, numpy, torch) from one constant.
SEED = 509
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fa0b861d8b0>

## Data Preparation

In [3]:
# Raw files data/diabetes/data-01 ... data-70 are loaded, one per subject.
N_SUBJECTS = 70
# Sequence-shape constants used by load_data:
#   M * R -> number of covariate rows kept per subject
#   M     -> number of pre-treatment outcomes kept
#   H     -> number of post-treatment outcomes (prediction horizon)
M = H = R = 5

### Data Codes

In [4]:
# Raw event code -> column name, grouped by the role each column plays.
TREATMENT_CODES = {
    33: 'reg_insulin',
    34: 'nph_insulin',
    35: 'ult_insulin',
}
OUTCOME_CODES = {
    48: 'unspecified_bg',
    57: 'unspecified_bg',       # two raw codes share one outcome name
    58: 'pre_breakfast_bg',
    59: 'post_breakfast_bg',
    60: 'pre_lunch_bg',
    61: 'post_lunch_bg',
    62: 'pre_supper_bg',
    63: 'post_supper_bg',
    64: 'pre_snack_bg',
}
COVARIATE_CODES = {
    65: 'hypo_symptoms',
    66: 'typical_meal',
    67: 'more_meal',
    68: 'less_meal',
    69: 'typical_exercise',
    70: 'more_exercise',
    71: 'less_exercise',
    72: 'unspecified_event',
}
# Insertion order matters: load_subject_i derives its column order from list(DATA_CODES.values()).
DATA_CODES = {**TREATMENT_CODES, **OUTCOME_CODES, **COVARIATE_CODES}
TREATMENT_COLS = list(TREATMENT_CODES.values())
# dict.fromkeys drops the duplicated 'unspecified_bg' while keeping first-seen order.
OUTCOME_COLS = list(dict.fromkeys(OUTCOME_CODES.values()))

### Load Data

In [5]:
def float_or_na(v):
    """Parse a raw reading as a float.

    The sentinel strings '0Hi' and '0Lo' map to 1 and -1 respectively.
    NOTE(review): these look like meter "too high"/"too low" readings; 1/-1 are
    not on the glucose scale of the other values -- confirm this is intended.
    Any other unparseable value prints the parse error and yields NaN.
    """
    try:
        parsed = float(v)
    except Exception as err:
        if v == '0Hi':
            parsed = 1
        elif v == '0Lo':
            parsed = -1
        else:
            print(err)
            parsed = np.nan
    return parsed

In [6]:
def combine(values):
    """Return the median of the non-missing entries of `values`, or NaN if none are present."""
    present = values[pd.notna(values)]
    return np.median(present) if present.shape[0] else np.nan

In [7]:
def try_to_date(v):
    """Parse an 'MM-DD-YYYY' string into a datetime, or return NaN.

    If the first parse fails, the units digit of the day (character 4) is
    replaced with '0' and parsing is retried; this turns impossible dates such
    as 06-31 into 06-30. Each failure prints the error and the value tried.
    """
    date_format = '%m-%d-%Y'
    try:
        return datetime.datetime.strptime(v, date_format)
    except Exception as first_err:
        print(f'{first_err}: {v}')
    try:
        v = v[:4] + '0' + v[5:]  # handle date mis-input (e.g. 6-31)
        return datetime.datetime.strptime(v, date_format)
    except Exception as second_err:
        print(f'{second_err}: {v}')
        return np.nan

In [8]:
def try_to_time(v):
    """Parse an 'HH:MM' string into a `datetime.time`; print the error and return NaN on failure."""
    try:
        parsed = datetime.datetime.strptime(v, '%H:%M')
    except Exception as err:
        print(f'{err}: {v}')
        return np.nan
    return parsed.time()

In [9]:
def try_to_combine(date, time):
    """Merge a date and a time into one datetime.

    If combining fails, the error is printed. The date alone is returned when it
    is a datetime, otherwise NaN.
    """
    try:
        return datetime.datetime.combine(date, time)
    except Exception as err:
        print(f'{err}: {date} {time}')
    return date if isinstance(date, datetime.datetime) else np.nan

In [10]:
def load_subject_i(subject_idx):
    """Load one subject's raw event log and pivot it into one row per distinct timestamp.

    Rows whose date/time could not be parsed are dropped. Events that share a
    timestamp are merged into one row; codes not in DATA_CODES are ignored. A row
    is kept for every valid timestamp, even if none of its codes are known.

    Returns:
        features:  float array (n_timestamps, n_feature_cols) of TREATMENT_COLS +
                   OUTCOME_COLS, missing entries filled with 0. Because two codes
                   map to 'unspecified_bg', that name owns two columns.
        mask:      bool array, same shape, True where a value was observed.
        outcomes:  1-D float array, the median outcome of every timestamp that has
                   at least one outcome reading.
        treatment: 1-D float array, per treatment column the median over the whole
                   record (0 when never observed).
    """
    raw_df = pd.read_csv(
        os.path.join(os.getcwd(), f'data/diabetes/data-{subject_idx:02d}'),
        sep='\t',
        names=['date', 'time', 'code', 'value'],
    )
    raw_df['date'] = raw_df['date'].apply(try_to_date)
    raw_df['time'] = raw_df['time'].apply(try_to_time)
    raw_df['datetime'] = raw_df.apply(lambda row: try_to_combine(row['date'], row['time']), axis=1)
    # Drop unparseable timestamps up front. Sizing by set(values) would count every
    # NaT separately and leave spurious all-empty rows at the end of the record.
    raw_df = raw_df.drop(columns=['date', 'time'])
    raw_df = raw_df[pd.notna(raw_df['datetime'])]

    columns = list(DATA_CODES.values())
    # A name can own several columns (duplicate 'unspecified_bg'); a reading is written to all of them.
    positions = {name: [j for j, col in enumerate(columns) if col == name] for name in columns}

    def _cols(names):
        # Column positions for `names`, expanding duplicated names in order.
        return [j for name in names for j in positions[name]]

    # groupby(sort=True) visits timestamps in chronological order; within a group
    # rows keep file order, so a later reading of the same code wins.
    groups = raw_df.groupby('datetime', sort=True)
    values = np.full((groups.ngroups, len(columns)), np.nan)
    for row_idx, (_, group) in enumerate(groups):
        for code, raw_value in zip(group['code'], group['value']):
            if code in DATA_CODES:
                # Plain ndarray assignment instead of chained `df.iloc[i][col] = v`,
                # which is a silent no-op under pandas Copy-on-Write.
                values[row_idx, positions[DATA_CODES[code]]] = float_or_na(raw_value)

    features = values[:, _cols(TREATMENT_COLS + OUTCOME_COLS)]
    mask = ~np.isnan(features)
    # pandas median skips NaN and yields NaN for all-missing rows/columns (same as `combine`).
    outcomes = pd.DataFrame(values[:, _cols(OUTCOME_COLS)]).median(axis=1).dropna().to_numpy()
    treatment = pd.DataFrame(values[:, _cols(TREATMENT_COLS)]).median(axis=0).fillna(0).to_numpy()
    return (
        np.where(mask, features, 0.0),
        mask,
        outcomes,
        treatment,
    )

In [11]:
def load_data():
    # Initialisation
    X = []
    M_ = []
    Y_pre = []
    Y_post = []
    A = []

    # Reading
    for subject_idx in range(1, N_SUBJECTS + 1):
        X_i, M_i, Y_i, A_i = load_subject_i(subject_idx)
        
        if M * R > M_i.shape[0]:
            M_.append(np.concatenate((np.zeros((M * R - M_i.shape[0], M_i.shape[1])), M_i)))
        else:
            M_.append(M_i[-M * R:])
        
        if H + M > Y_i.shape[0]:
            Y_pre.append(np.concatenate((np.zeros(H + M - Y_i.shape[0]), Y_i[:-H])))
        else:
            Y_pre.append(Y_i[-(H + M):-H])
        
        Y_post.append(Y_i[-H:])
        A.append(A_i)
        X_i = X_i[:-H]
        if M * R > X_i.shape[0]:
            X.append(np.concatenate((np.zeros((M * R - X_i.shape[0], X_i.shape[1])), X_i)))
        else:
            X.append(X_i[-M * R:])

    # Aggregation
    X = np.stack(X)
    M_ = np.stack(M_)
    Y_pre = np.array(Y_pre)
    Y_post = np.array(Y_post)
    A = np.array(A)
    T = np.transpose(np.tile(np.arange(-M * R, 0), (N_SUBJECTS, X.shape[2], 1)), (0, 2, 1))
    
    # Scaling
    X_to_scale = X.reshape((-1, X.shape[2]))  # (N, T, D) -> (N * T, D)
    X_scaled = scale(X_to_scale, axis=0)
    X = X_scaled.reshape(X.shape)
    
    Y = np.concatenate((Y_pre, Y_post), axis=1)
    Y_to_scale = Y.reshape((-1, 1))  # (N, M) + (N, H) -> (N * T, 1)
    Y_scaled = scale(Y_to_scale, axis=0)
    Y = Y_scaled.reshape(Y.shape)
    Y_pre, Y_post = Y[:, :-H], Y[:, -H:]
    
    A = scale(A, axis=0)  # [N, K]
    
    return X, M_, Y_pre, Y_post, A, T

In [12]:
def generate_and_write_data():
    """Run load_data and cache its output to disk in two joblib files.

    processed.joblib holds (X, M_, Y_pre, Y_post, A, T). hp_search.joblib holds
    (N, M, H, R, D, K, C) followed by the same six arrays.
    """
    arrays = load_data()
    X, M_, Y_pre, Y_post, A, T = arrays
    joblib.dump(arrays, os.path.join(os.getcwd(), 'data/diabetes/processed.joblib'))

    N, D, K = N_SUBJECTS, X.shape[2], A.shape[1]
    C = 4  # NOTE(review): meaning of C is not visible here; it is only stored for the HP search
    header = (N, M, H, R, D, K, C)
    joblib.dump(
        header + arrays,
        os.path.join(os.getcwd(), "data/diabetes/hp_search.joblib"),
    )

### Data Generation

In [13]:
def generate_data(train_frac=0.8):
    """Load the processed arrays and split subjects into train and test sets.

    There is no shuffling. The first round(N_SUBJECTS * train_frac) subjects in
    file order are training data and the rest are test data, so every call
    returns the same split.

    Args:
        train_frac: fraction of subjects used for training (default 0.8).

    Returns:
        (N_SUBJECTS, N_train, N_test,
         X_train, X_test, M_train, M_test, Y_pre_train, Y_pre_test,
         Y_post_train, Y_post_test, A_train, A_test, T_train, T_test)
    """
    N_train = round(N_SUBJECTS * train_frac)
    X, M_, Y_pre, Y_post, A, T = joblib.load(os.path.join(os.getcwd(), 'data/diabetes/processed.joblib'))

    splits = []
    for arr in (X, M_, Y_pre, Y_post, A, T):
        splits.extend((arr[:N_train], arr[N_train:]))
    # Take the test size from the actual split. round(N * (1 - frac)) can disagree
    # with it, and downstream code (e.g. torch.arange(0, N_test)) relies on it matching.
    N_test = splits[1].shape[0]
    return (N_SUBJECTS, N_train, N_test, *splits)

In [14]:
HP_SEARCH_PATH = os.path.join(os.getcwd(), "data/diabetes/hp_search.joblib")
# NOTE: this rebinds the notebook-level M, H, R with the values stored in the file.
N, M, H, R, D, K, C, X, M_, Y_pre, Y_post, A, T = joblib.load(HP_SEARCH_PATH)

## Modelling

In [15]:
# The project/entity can be overridden via environment variables instead of
# editing the notebook; the defaults reproduce the original configuration.
wandb.init(
    project=os.environ.get('WANDB_PROJECT', 'mclatte-test'),
    entity=os.environ.get('WANDB_ENTITY', 'jasonyz'),
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjasonyz[0m (use `wandb login --relogin` to force relogin)
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "
[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [16]:
def na_catcher(func):
    """Decorator: call `func`, but print any exception and return NaN instead of raising.

    `functools.wraps` preserves the wrapped function's name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper_na_catcher(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Deliberate best-effort: report the failure, record NaN, keep going.
            print(e)
            return np.nan
    return wrapper_na_catcher

### McLatte

In [17]:
def infer_mcespresso(trained_mcespresso, X_test, A_test, T_test, M_test):
    """Run a trained McLatte-family model on numpy test arrays.

    The model is put in eval mode and each array is converted to a float32
    tensor. The forward pass runs under `torch.no_grad()`, since this is pure
    inference and no autograd graph is needed. Returns whatever the model's
    forward returns: callers unpack three outputs for McLatte/SemiSkimmed and two
    for SkimmedMcLatte.
    """
    trained_mcespresso.eval()
    inputs = [torch.from_numpy(arr).float() for arr in (X_test, A_test, T_test, M_test)]
    with torch.no_grad():
        return trained_mcespresso(*inputs)

#### Vanilla

In [18]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [19]:
# Vanilla McLatte hyper-parameters (presumably the best-valid_loss row of
# results/mclatte_hp.csv, per the commented-out cell above -- confirm).
mclatte_config = dict(
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=8,
    batch_size=64,
    epochs=100,
    lr=0.021089,
    gamma=0.541449,
    lambda_r=0.814086,
    lambda_d=0.185784,
    lambda_p=0.081336,
)

In [20]:
def test_mclatte(
    X_train,
    X_test,
    M_train,
    M_test,
    Y_pre_train,
    Y_post_train,
    Y_post_test,
    A_train,
    A_test,
    T_train,
    T_test,
    run_idx=0,
):
    """Train vanilla McLatte with `mclatte_config` and return its test MAE.

    The model is fit on the training arrays using the notebook-level R, M, H.
    The returned value is the L1 loss between its third output and `Y_post_test`.
    """
    shape_kwargs = dict(input_dim=X_train.shape[2], treatment_dim=A_train.shape[1])
    model = train_mclatte(
        mclatte_config,
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train,
        R, M, H,
        **shape_kwargs,
        test_run=run_idx,
    )
    _, _, y_pred = infer_mcespresso(model, X_test, A_test, T_test, M_test)
    y_true = torch.from_numpy(Y_post_test).float()
    return torch.nn.functional.l1_loss(y_pred, y_true).item()

#### Semi-Skimmed

In [21]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/semi_skimmed_mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [22]:
# Semi-skimmed McLatte hyper-parameters (presumably the best-valid_loss row of
# results/semi_skimmed_mclatte_hp.csv -- confirm).
semi_skimmed_mclatte_config = dict(
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=4,
    batch_size=64,
    epochs=100,
    lr=0.006606,
    gamma=0.860694,
    lambda_r=79.016676,
    lambda_d=1.2907,
    lambda_p=11.112241,
)

In [23]:
def test_semi_skimmed_mclatte(
    X_train,
    X_test,
    M_train,
    M_test,
    Y_pre_train,
    Y_post_train,
    Y_post_test,
    A_train,
    A_test,
    T_train,
    T_test,
    run_idx=0,
):
    """Train semi-skimmed McLatte with `semi_skimmed_mclatte_config` and return its test MAE.

    The returned value is the L1 loss between the model's third output and `Y_post_test`.
    """
    shape_kwargs = dict(input_dim=X_train.shape[2], treatment_dim=A_train.shape[1])
    model = train_semi_skimmed_mclatte(
        semi_skimmed_mclatte_config,
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train,
        R, M, H,
        **shape_kwargs,
        test_run=run_idx,
    )
    _, _, y_pred = infer_mcespresso(model, X_test, A_test, T_test, M_test)
    y_true = torch.from_numpy(Y_post_test).float()
    return torch.nn.functional.l1_loss(y_pred, y_true).item()

#### Skimmed

In [24]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/skimmed_mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [25]:
# Skimmed McLatte hyper-parameters (no lambda_d term); presumably the
# best-valid_loss row of results/skimmed_mclatte_hp.csv -- confirm.
skimmed_mclatte_config = dict(
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=16,
    batch_size=64,
    epochs=100,
    lr=0.000928,
    gamma=0.728492,
    lambda_r=1.100493,
    lambda_p=2.108935,
)

In [26]:
def test_skimmed_mclatte(
    X_train,
    X_test,
    M_train,
    M_test,
    Y_pre_train,
    Y_post_train,
    Y_post_test,
    A_train,
    A_test,
    T_train,
    T_test,
    run_idx=0,
):
    """Train skimmed McLatte with `skimmed_mclatte_config` and return its test MAE.

    The skimmed model's forward yields two outputs; the second is compared to
    `Y_post_test` with an L1 loss.
    """
    shape_kwargs = dict(input_dim=X_train.shape[2], treatment_dim=A_train.shape[1])
    model = train_skimmed_mclatte(
        skimmed_mclatte_config,
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train,
        R, M, H,
        **shape_kwargs,
        test_run=run_idx,
    )
    _, y_pred = infer_mcespresso(model, X_test, A_test, T_test, M_test)
    y_true = torch.from_numpy(Y_post_test).float()
    return torch.nn.functional.l1_loss(y_pred, y_true).item()

### Baseline RNN

In [27]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/baseline_rnn_hp.csv')).sort_values(by='valid_loss').iloc[0])

In [28]:
# Baseline RNN hyper-parameters (presumably the best-valid_loss row of
# results/baseline_rnn_hp.csv -- confirm).
rnn_config = dict(
    rnn_class='gru',
    hidden_dim=64,
    seq_len=2,
    batch_size=64,
    epochs=100,
    lr=0.006321,
    gamma=0.543008,
)

In [29]:
def rnn_predict(trained_rnn, Y_pre, Y_post, return_Y_pred=False):
    """
    Make predictions using results from previous time steps.

    Rolls the RNN forward autoregressively for Y_post.shape[1] steps. At each
    step the model sees the current window of outcomes (shape (N, L, 1)). Its
    last output becomes the next prediction and is appended to the window while
    the oldest value is dropped.

    Args:
        trained_rnn: callable mapping a (N, L, 1) float tensor to per-position
            predictions; assumed (N, L, 1) or (N, L) -- TODO confirm.
        Y_pre: (N, L) numpy array of observed outcomes seeding the window.
        Y_post: (N, H) numpy array of targets.
        return_Y_pred: if True, return the (N, H) tensor of predictions instead
            of the loss.

    Returns:
        The mean per-step L1 loss (float), or the stacked predictions.
    """
    Y_target = torch.from_numpy(Y_post).float()
    window = Y_pre
    losses = 0.0
    Y_pred = []
    with torch.no_grad():
        for i in range(Y_post.shape[1]):
            # Squeeze only the trailing feature axis so a batch of one subject keeps its batch dim.
            Y_tilde = trained_rnn(torch.from_numpy(window).float().unsqueeze(2)).squeeze(-1)
            next_step = Y_tilde[:, -1]

            window = np.concatenate((
                window[:, 1:],
                next_step.cpu().numpy()[:, None],
            ), axis=1)

            losses += torch.nn.functional.l1_loss(next_step, Y_target[:, i]).item()
            Y_pred.append(next_step)
    if return_Y_pred:
        return torch.stack(Y_pred, 1)
    return losses / Y_post.shape[1]

In [30]:
def infer_rnn(trained_rnn, Y_pre_test, Y_post_test, return_Y_pred=False):
    """Switch the RNN to eval mode, then delegate the autoregressive rollout to `rnn_predict`."""
    trained_rnn.eval()
    return rnn_predict(
        trained_rnn,
        Y_pre_test,
        Y_post_test,
        return_Y_pred=return_Y_pred,
    )

In [31]:
def test_rnn(
    Y_pre_train,
    Y_pre_test,
    Y_post_train,
    Y_post_test,
    run_idx=0,
):
    """Train the baseline RNN on full training outcome sequences and return its test MAE.

    Training sees the concatenated pre- and post-treatment outcomes. Evaluation
    rolls the model forward from `Y_pre_test` (via `infer_rnn`) and scores
    against `Y_post_test`.
    """
    full_train_outcomes = np.concatenate((Y_pre_train, Y_post_train), axis=1)
    trained_rnn = train_baseline_rnn(
        rnn_config,
        Y=full_train_outcomes,
        input_dim=1,
        test_run=run_idx,
    )
    return infer_rnn(trained_rnn, Y_pre_test, Y_post_test)

### SyncTwin

In [32]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/synctwin_hp.csv')).sort_values(by='valid_loss').iloc[0])
# SyncTwin hyper-parameters (presumably the best-valid_loss row of
# results/synctwin_hp.csv -- confirm).
synctwin_config = dict(
    hidden_dim=128,
    reg_B=0.522652,
    lam_express=0.163847,
    lam_recon=0.39882,
    lam_prognostic=0.837303,
    tau=0.813696,
    batch_size=32,
    epochs=100,
    lr=0.001476,
    gamma=0.912894,
)

In [33]:
def infer_synctwin(trained_synctwin, N_test, Y_post_test):
    """Put SyncTwin in eval mode and return prognostics for test subjects 0 .. N_test - 1.

    Delegates to the wrapped model's `_sync_twin.get_prognostics`, passing the
    subject indices and the float32 post-treatment outcomes (both on CPU).
    """
    trained_synctwin.eval()
    subject_ids = torch.arange(0, N_test).cpu()
    outcomes = torch.from_numpy(Y_post_test).float().cpu()
    return trained_synctwin._sync_twin.get_prognostics(subject_ids, outcomes)

In [34]:
def test_synctwin(
    N_train,
    N_test,
    X_train,
    X_test,
    M_train,
    M_test,
    Y_post_train,
    Y_post_test,
    A_train,
    A_test,
    T_train,
    T_test,
    run_idx=0,
):
    """Train SyncTwin on the training split and return its test L1 loss.

    Control subjects are those whose treatment vector is identically zero.
    NOTE(review): load_data standardises A with sklearn `scale`, so exact zeros
    are unlikely to survive. The control mask may therefore be (almost) all
    False -- confirm before relying on these results.
    """
    is_control_train = np.all(A_train == 0, axis=1)
    is_control_test = np.all(A_test == 0, axis=1)
    Y_control_train = Y_post_train[is_control_train]
    n_control = Y_control_train.shape[0]

    trained = train_synctwin(
        synctwin_config,
        X=X_train,
        M_=M_train,
        T=T_train,
        Y_batch=Y_post_train,
        Y_control=Y_control_train,
        Y_mask=is_control_train,
        N=N_train,
        D=X_train.shape[2],
        n_treated=N_train - n_control,
        pre_trt_x_len=R * M,
        test_run=run_idx,
    )

    trained.eval()
    test_inputs = (
        torch.from_numpy(X_test).float(),
        torch.from_numpy(T_test).float(),
        torch.from_numpy(M_test).float(),
        torch.arange(0, N_test),
        torch.from_numpy(Y_post_test).float(),
        torch.from_numpy(is_control_test).float(),
    )
    _, l1_loss = trained(*test_inputs)
    return l1_loss.item()

## Test Models

In [35]:
N_TEST: int = 5  # number of repeated train/evaluate runs performed by run_tests

In [36]:
def run_tests(n_runs=None, include_synctwin=False, results_path='results/test/diabetes.joblib'):
    """Train and evaluate every model `n_runs` times and persist the test losses.

    After each run, the five loss lists are written (overwriting the file) to
    `results_path` in the order (McLatte, Semi-Skimmed, Skimmed, RNN, SyncTwin).
    This lets partial results survive an interrupted run.

    Args:
        n_runs: number of repetitions; defaults to N_TEST.
        include_synctwin: also train/evaluate SyncTwin. Off by default, matching
            the previous behaviour where that call was disabled; its list then
            stays empty.
        results_path: joblib file the losses are written to. Its directory is
            created if missing.
    """
    if n_runs is None:
        n_runs = N_TEST

    mclatte_losses = []
    semi_skimmed_mclatte_losses = []
    skimmed_mclatte_losses = []
    rnn_losses = []
    synctwin_losses = []

    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    for i in range(1, n_runs + 1):
        (
            _,
            N_train,
            N_test,
            X_train,
            X_test,
            M_train,
            M_test,
            Y_pre_train,
            Y_pre_test,
            Y_post_train,
            Y_post_test,
            A_train,
            A_test,
            T_train,
            T_test,
        ) = generate_data()

        # The three McLatte variants share one positional signature.
        mclatte_args = (
            X_train, X_test, M_train, M_test, Y_pre_train, Y_post_train,
            Y_post_test, A_train, A_test, T_train, T_test,
        )
        mclatte_losses.append(test_mclatte(*mclatte_args, run_idx=i))
        semi_skimmed_mclatte_losses.append(test_semi_skimmed_mclatte(*mclatte_args, run_idx=i))
        skimmed_mclatte_losses.append(test_skimmed_mclatte(*mclatte_args, run_idx=i))

        rnn_losses.append(test_rnn(
            Y_pre_train,
            Y_pre_test,
            Y_post_train,
            Y_post_test,
            run_idx=i,
        ))

        if include_synctwin:
            synctwin_losses.append(test_synctwin(
                N_train,
                N_test,
                X_train,
                X_test,
                M_train,
                M_test,
                Y_post_train,
                Y_post_test,
                A_train,
                A_test,
                T_train,
                T_test,
                run_idx=i,
            ))

        joblib.dump((
            mclatte_losses,
            semi_skimmed_mclatte_losses,
            skimmed_mclatte_losses,
            rnn_losses,
            synctwin_losses,
        ), results_path)

In [37]:
# Train and evaluate every model N_TEST times, re-dumping the losses collected so far after each run.
run_tests()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 808   
1 | _decoder | LstmDecoder | 684   
-----------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.006     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_1.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 808   
1 | _decoder | LstmDecoder | 684   
-----------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.006     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_2.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 808   
1 | _decoder | LstmDecoder | 684   
-----------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.006     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_3.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 808   
1 | _decoder | LstmDecoder | 684   
-----------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.006     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_4.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 808   
1 | _decoder | LstmDecoder | 684   
-----------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.006     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_5.ckpt'


#### Check results of finished runs

In [40]:
def print_losses(path='results_diabetes/maes/diabetes.joblib'):
    """Print one 'mean (std)' line per model from a stored tuple of loss lists.

    Empty lists (e.g. SyncTwin when it was not run) print as 'nan (nan)'. This
    matches the previous output without numpy's empty-slice RuntimeWarnings.
    """
    all_losses = joblib.load(path)
    for losses in all_losses:
        if len(losses) == 0:
            print('nan (nan)')
            continue
        print(f'{np.mean(losses):.3f} ({np.std(losses):.3f})')

In [41]:
# One "mean (std)" line per model, in the order the loss lists were stored.
print_losses()

0.920 (0.015)
0.903 (0.018)
0.906 (0.001)
0.882 (0.014)
nan (nan)


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean, casting='unsafe',
  ret = ret.dtype.type(ret / rcount)


### Statistical Testing

In [42]:
# Display names, in the same order run_tests stores the loss lists.
LOSS_NAMES = [
    'McLatte',
    'Semi-Skimmed McLatte',
    'Skimmed McLatte',
    'RNN',
    'SyncTwin',
]

In [43]:
def test_losses(losses):
    t_test_results = pd.DataFrame(columns=LOSS_NAMES, index=LOSS_NAMES)

    for i in range(len(losses)):
        for j in range(len(losses)):
            t = ttest_ind(losses[i], losses[j], alternative='less')
            t_test_results[LOSS_NAMES[i]][LOSS_NAMES[j]] = t.pvalue
    return t_test_results

In [44]:
# Only the first four models are compared; the fifth (SyncTwin) loss list is empty.
N_COMPARED_MODELS = 4
losses = joblib.load('results_diabetes/maes/diabetes.joblib')[:N_COMPARED_MODELS]
test_losses(losses)

Unnamed: 0,McLatte,Semi-Skimmed McLatte,Skimmed McLatte,RNN,SyncTwin
McLatte,0.5,0.086596,0.049252,0.002647,
Semi-Skimmed McLatte,0.913404,0.5,0.637164,0.046576,
Skimmed McLatte,0.950748,0.362836,0.5,0.003574,
RNN,0.997353,0.953424,0.996426,0.5,
SyncTwin,,,,,


### Plot with trained models

In [None]:
# Checkpoint/model name -> short legend label used in the prediction plots.
PLOT_NAME = dict(
    skimmed_mclatte='S',
    semi_skimmed_mclatte='SS',
    mclatte='V',
    rnn='RNN',
    synctwin='SyncTwin',
)

In [None]:
def line_model_pred(fig, model, name, infer_model, plot_sub_id, post_t, y_pre_plot, file_suffix, *infer_args):
    """Load a trained checkpoint and add one subject's forecast to `fig` as a dashed line.

    `infer_model(trained_model, *infer_args)` must return a tensor indexed by
    subject. The trace is labelled PLOT_NAME[name] + file_suffix.
    NOTE(review): checkpoints are read from results_idt/trained_models/, not a
    diabetes-specific folder -- confirm this is the intended location.
    """
    checkpoint_path = os.path.join(os.getcwd(), f'results_idt/trained_models/{name}.ckpt')
    trained_model = model.load_from_checkpoint(checkpoint_path)
    predictions = infer_model(trained_model, *infer_args).detach().numpy()
    # Start the forecast at the last observed pre-treatment outcome so the dashed line joins the ground truth.
    y_pred_plot = [y_pre_plot[-1], *predictions[plot_sub_id]]
    fig.add_trace(go.Scatter(
        x=post_t,
        y=y_pred_plot,
        name=f'{PLOT_NAME[name]}{file_suffix}',
        line={'dash': 'dash'},
    ))

In [None]:
def plot_subject(fig, plot_sub_id, N_test, X_test, M_test, post_t, y_pre_plot, A_test, T_test, Y_pre_test, Y_post_test, file_suffix=''):
    """Overlay every trained model's forecast for test subject `plot_sub_id` onto `fig`.

    Each entry names the model class, its checkpoint/legend key, how to pull the
    outcome prediction out of its inference output, and the inference arguments.
    """
    mclatte_args = (X_test, A_test, T_test, M_test)
    model_specs = [
        (SkimmedMcLatte, 'skimmed_mclatte', lambda *args: infer_mcespresso(*args)[1], mclatte_args),
        (SemiSkimmedMcLatte, 'semi_skimmed_mclatte', lambda *args: infer_mcespresso(*args)[2], mclatte_args),
        (McLatte, 'mclatte', lambda *args: infer_mcespresso(*args)[2], mclatte_args),
        (BaselineRnn, 'rnn', lambda *args: infer_rnn(*args, return_Y_pred=True), (Y_pre_test, Y_post_test)),
        (SyncTwinPl, 'synctwin', lambda *args: infer_synctwin(*args), (N_test, Y_post_test)),
    ]
    for model_cls, model_name, infer_fn, infer_args in model_specs:
        line_model_pred(fig, model_cls, model_name, infer_fn,
                        plot_sub_id, post_t, y_pre_plot, file_suffix, *infer_args)
    

In [None]:
def plot_config_results(file_suffix=''):
    (
        _, _, N_test, 
        _, X_test, 
        _, M_test, 
        _, Y_pre_test, 
        _, Y_post_test, 
        _, A_test, 
        _, T_test,
    ) = generate_data()

    for plot_sub_id in range(N_test):
        y_pre_plot = Y_pre_test[plot_sub_id]
        pre_t = list(np.arange(y_pre_plot.shape[0]) - y_pre_plot.shape[0])

        y_post_plot = [y_pre_plot[-1]] + list(Y_post_test[plot_sub_id])
        post_t = np.arange(len(y_post_plot))
        
        trt_str = ', '.join(map(
            lambda x: str(round(x, 2)) if abs(x - round(x)) > 5e-2 else str(int(x)), 
            A_test[plot_sub_id]
        ))

        fig = go.Figure()
        line_pre_trt = go.Scatter(x=pre_t + list(post_t), y=list(y_pre_plot) + y_post_plot, name='ground truth')
        fig.add_trace(line_pre_trt)

        plot_subject(fig, plot_sub_id, N_test, X_test, M_test, post_t, y_pre_plot, A_test, T_test, Y_pre_test, Y_post_test)
        A_test[plot_sub_id] = np.ones_like(A_test[plot_sub_id]) if not (A_test[plot_sub_id] == 0).all() else np.zeros_like(A_test[plot_sub_id])
        plot_subject(fig, plot_sub_id, N_test, X_test, M_test, post_t, y_pre_plot, A_test, T_test, Y_pre_test, Y_post_test, file_suffix=' 01')
        
        fig.update_layout(
            title=f'Outcome for Treatment Vector ({trt_str})', 
            yaxis_title='Outcome', 
            xaxis_title='Time',
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
        )
        fig.write_image(f'plots/diabetes/outcome_pred_{plot_sub_id}{file_suffix}.png')