## Idealised Hypothetical Disease Treatment

In [105]:
import functools
import joblib
import numpy as np
import os
import pandas as pd
import plotly.graph_objects as go
import random
import torch
import wandb
from mclatte.model import (
    train_mclatte, 
    train_semi_skimmed_mclatte, 
    train_skimmed_mclatte, 
    McLatte,
    SemiSkimmedMcLatte,
    SkimmedMcLatte,
)
from mclatte.simulation_data import generate_simulation_data, TreatmentRepr
from rnn.model import (
    train_baseline_rnn,
    BaselineRnn,
)
from scipy.stats import ttest_ind
from synctwin.model import (
    train_synctwin,
    SyncTwinPl,
)

In [4]:
# One seed shared by every RNG the notebook touches (stdlib, NumPy, PyTorch) for reproducibility.
SEED = 509
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

<torch._C.Generator at 0x1b4ddc834f0>

Data Generation

In [5]:
# Simulation dimensions. Meanings below are inferred from how the names are used in this
# notebook — confirm against mclatte.simulation_data.generate_simulation_data:
#   M — presumably the pre-treatment outcome length (outcome plots span M + H steps)
#   H — presumably the post-treatment outcome horizon
#   R — covariate steps per outcome step (covariate plots span R * M steps)
#   D — number of covariate features (passed as input_dim)
#   K — treatment vector size (passed as treatment_dim)
#   C — forwarded to generate_simulation_data; meaning not visible here
M, H, R, D, K, C = 5, 5, 5, 10, 3, 4

In [6]:
def generate_data(N, p_0, mode, run_idx):
    """Load a cached simulated dataset split 80/20 into train/test, generating it on a miss.

    The cache file is ``data/test/idt_{N}_{p_0}_{mode}_{run_idx}.joblib``. If it cannot be
    loaded (missing or unreadable), the error is printed, a fresh dataset is simulated with
    ``generate_simulation_data`` using the notebook-level dimensions, and the result is
    written back to the cache (creating the directory if needed).

    Returns
    -------
    tuple
        ``(N, N_train, N_test, X_train, X_test, M_train, M_test, Y_pre_train, Y_pre_test,
        Y_post_train, Y_post_test, A_train, A_test, T_train, T_test)``.
    """
    data_path = f'data/test/idt_{N}_{p_0}_{mode}_{run_idx}.joblib'
    try:
        return joblib.load(data_path)
    except Exception as e:
        # Best-effort cache: any load failure falls through to regeneration below.
        print(e)

    X, M_, Y_pre, Y_post, A, T = generate_simulation_data(N, M, H, R, D, K, C, mode, p_0)
    N_train = round(N * 0.8)
    # Derive the test size from the split point so it always matches the sliced arrays.
    N_test = N - N_train

    # (train, test) pairs in the fixed order callers unpack.
    splits = []
    for arr in (X, M_, Y_pre, Y_post, A, T):
        splits.extend((arr[:N_train], arr[N_train:]))
    all_data = (N, N_train, N_test, *splits)

    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    joblib.dump(all_data, data_path)
    return all_data

In [7]:
# joblib.dump((N_train, M, H, R, D, K, C, X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train), 'data/simulation/data_uniform_200.joblib')
# joblib.dump((N_test, M, H, R, D, K, C, X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test), 'data/simulation/test_data_uniform_200.joblib')

### Visualizations

In [8]:
# Fixed 250-subject, binary-treatment dataset used only for the exploratory plots below.
(
    N_visual, 
    N_train_visual, 
    N_test_visual, 
    X_train_visual, 
    X_test_visual, 
    M_train_visual, 
    M_test_visual, 
    Y_pre_train_visual, 
    Y_pre_test_visual, 
    Y_post_train_visual, 
    Y_post_test_visual, 
    A_train_visual, 
    A_test_visual, 
    T_train_visual, 
    T_test_visual,
) = generate_data(250, 0, TreatmentRepr.BINARY, 'visual')
# Re-join the train/test split: the plots describe the whole simulated population.
X_visual = np.concatenate((X_train_visual, X_test_visual), axis=0)
A_visual = np.concatenate((A_train_visual, A_test_visual), axis=0)
Y_pre_visual = np.concatenate((Y_pre_train_visual, Y_pre_test_visual), axis=0)
Y_post_visual = np.concatenate((Y_post_train_visual, Y_post_test_visual), axis=0)
# Draw up to 10 *distinct* subjects so no subject appears twice in the figures.
sample_ids = np.random.choice(N_visual, size=min(10, N_visual), replace=False)

[Errno 2] No such file or directory: 'data/test/idt_250_0_TreatmentRepr.BINARY_visual.joblib'


Covariates

In [9]:
# Mean trajectory of each covariate feature across the sampled subjects.
time_axis = list(range(R * M))
mean_covariates = X_visual[sample_ids].mean(axis=0)  # (time, feature)
fig = go.Figure()
for feature_idx in range(D):
    fig.add_trace(
        go.Scatter(
            x=time_axis,
            y=mean_covariates[:, feature_idx],
            name=f'feature {feature_idx}',
        )
    )
fig.update_layout(title='Average Covariate Values', xaxis_title='t', yaxis_title='Feature Value')
fig.show()

Treatment Causes

In [10]:
# One column per sampled subject, one row per treatment component.
treatment_heatmap = go.Heatmap(z=A_visual[sample_ids].T)
fig = go.Figure(data=treatment_heatmap)
fig.update_layout(title='Treatment Causes', xaxis_title='Sample ID', yaxis_title='Cause')
fig.show()

Treatment Outcomes

In [11]:
Y_sampled = np.concatenate((Y_pre_visual[sample_ids], Y_post_visual[sample_ids]), axis=1)  # full pre+post outcome path per sampled subject

In [12]:
# Full (pre + post) outcome trajectory of each sampled subject.
outcome_t = list(range(M + H))
fig = go.Figure()
for row_idx, outcome_row in enumerate(Y_sampled):
    fig.add_trace(go.Scatter(x=outcome_t, y=outcome_row, name=f'Sample {row_idx}'))
fig.update_layout(title='Sampled Treatment Outcomes', xaxis_title='t', yaxis_title='Outcome Value')
fig.show()

## Modelling

In [13]:
# Defaults keep the original project/account; set WANDB_PROJECT / WANDB_ENTITY to log elsewhere.
wandb.init(
    project=os.environ.get('WANDB_PROJECT', 'mclatte-test'),
    entity=os.environ.get('WANDB_ENTITY', 'jasonyz'),
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjasonyz[0m (use `wandb login --relogin` to force relogin)

The `IPython.html` package has been deprecated since IPython 4.0. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`.



In [14]:
# N_train, M, H, R, D, K, C, X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train = joblib.load(
#     os.path.join(os.getcwd(), f'data/simulation/data.joblib')
# )
# N_test, M, H, R, D, K, C, X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test = joblib.load(
#     os.path.join(os.getcwd(), 'data/simulation/test_data.joblib')
# )

In [15]:
def na_catcher(func):
    """Decorator: call ``func``; on any exception, print it and return ``np.nan``.

    The wrapped function keeps ``func``'s name and docstring (``functools.wraps``), so
    logs and introspection still identify the original function.
    """
    func_name = getattr(func, '__name__', repr(func))

    @functools.wraps(func)
    def wrapper_na_catcher(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Deliberately best-effort: report which call failed, then return a NaN sentinel.
            print(f'{func_name} failed: {e!r}')
            return np.nan
    return wrapper_na_catcher

### McLatte

In [16]:
def infer_mcespresso(trained_mcespresso, X_test, A_test, T_test, M_test):
    """Run a trained McLatte-family model on numpy inputs in evaluation mode.

    Each array is converted to a float32 tensor and passed to the model as
    ``(X, A, T, M)``. The model is switched to ``eval()`` (a persistent side effect),
    and the forward pass runs under ``torch.no_grad()`` because this is pure inference.

    Returns
    -------
    Whatever the model's forward returns. Callers in this notebook unpack a tuple whose
    last element is the outcome prediction.
    """
    trained_mcespresso.eval()
    inputs = [torch.from_numpy(arr).float() for arr in (X_test, A_test, T_test, M_test)]
    with torch.no_grad():
        return trained_mcespresso(*inputs)

#### Skimmed

In [17]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results_idt/skimmed_mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])
# Best Skimmed McLatte hyper-parameters (lowest valid_loss in the HP search results).
skimmed_mclatte_config = dict(
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=64,
    batch_size=64,
    epochs=100,
    lr=0.0151,
    gamma=0.986855,
    lambda_r=1.928836,
    lambda_p=0.042385,
)

In [18]:
def test_skimmed_mclatte(
    X_train, X_test,
    M_train, M_test,
    Y_pre_train,
    Y_post_train, Y_post_test,
    A_train, A_test,
    T_train, T_test,
    run_idx=0,
):
    """Train a Skimmed McLatte on the training split and return its test MAE.

    ``run_idx`` is forwarded to the trainer as ``test_run``.

    Returns
    -------
    float
        L1 loss between predicted and observed post-treatment outcomes on the test split.
    """
    model = train_skimmed_mclatte(
        skimmed_mclatte_config,
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train,
        R, M, H,
        input_dim=D,
        treatment_dim=K,
        test_run=run_idx,
    )
    # The skimmed variant's forward yields a pair; the prediction is the second item.
    outputs = infer_mcespresso(model, X_test, A_test, T_test, M_test)
    _, y_tilde = outputs

    target = torch.from_numpy(Y_post_test).float()
    mae = torch.nn.functional.l1_loss(y_tilde, target)
    return mae.item()

#### Semi-Skimmed

In [19]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results_idt/semi_skimmed_mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])
# Best Semi-Skimmed McLatte hyper-parameters (lowest valid_loss in the HP search results).
semi_skimmed_mclatte_config = dict(
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=16,
    batch_size=64,
    epochs=100,
    lr=0.120062,
    gamma=0.731629,
    lambda_r=0.016767,
    lambda_d=1.83538,
    lambda_p=1.509965,
)

In [20]:
def test_semi_skimmed_mclatte(
    X_train, X_test,
    M_train, M_test,
    Y_pre_train,
    Y_post_train, Y_post_test,
    A_train, A_test,
    T_train, T_test,
    run_idx=0,
):
    """Train a Semi-Skimmed McLatte on the training split and return its test MAE.

    ``run_idx`` is forwarded to the trainer as ``test_run``.

    Returns
    -------
    float
        L1 loss between predicted and observed post-treatment outcomes on the test split.
    """
    model = train_semi_skimmed_mclatte(
        semi_skimmed_mclatte_config,
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train,
        R, M, H,
        input_dim=D,
        treatment_dim=K,
        test_run=run_idx,
    )
    # This variant's forward yields a triple; the prediction is the third item.
    outputs = infer_mcespresso(model, X_test, A_test, T_test, M_test)
    _, _, y_tilde = outputs

    target = torch.from_numpy(Y_post_test).float()
    mae = torch.nn.functional.l1_loss(y_tilde, target)
    return mae.item()

#### Vanilla

In [21]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results_idt/mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])
# Best vanilla McLatte hyper-parameters (lowest valid_loss in the HP search results).
mclatte_config = dict(
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=16,
    batch_size=64,
    epochs=100,
    lr=0.024468,
    gamma=0.740409,
    lambda_r=0.040299,
    lambda_d=0.034368,
    lambda_p=0.021351,
)

In [22]:
def test_mclatte(
    X_train, X_test,
    M_train, M_test,
    Y_pre_train,
    Y_post_train, Y_post_test,
    A_train, A_test,
    T_train, T_test,
    run_idx=0,
):
    """Train a vanilla McLatte on the training split and return its test MAE.

    ``run_idx`` is forwarded to the trainer as ``test_run``.

    Returns
    -------
    float
        L1 loss between predicted and observed post-treatment outcomes on the test split.
    """
    model = train_mclatte(
        mclatte_config,
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train,
        R, M, H,
        input_dim=D,
        treatment_dim=K,
        test_run=run_idx,
    )
    # The vanilla forward yields a triple; the prediction is the third item.
    outputs = infer_mcespresso(model, X_test, A_test, T_test, M_test)
    _, _, y_tilde = outputs

    target = torch.from_numpy(Y_post_test).float()
    mae = torch.nn.functional.l1_loss(y_tilde, target)
    return mae.item()

### Baseline RNN

In [23]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/baseline_rnn_hp.csv')).sort_values(by='valid_loss').iloc[0])
# Best baseline-RNN hyper-parameters (lowest valid_loss in the HP search results).
rnn_config = dict(
    rnn_class='gru',
    hidden_dim=4,
    seq_len=4,
    batch_size=64,
    epochs=100,
    lr=0.048177,
    gamma=0.795612,
)

In [24]:
def rnn_predict(trained_rnn, Y_pre, Y_post, return_Y_pred=False):
    """Autoregressively forecast post-treatment outcomes with a one-step-ahead RNN.

    Starting from the window ``Y_pre``, each step feeds the current window to the model
    as an ``(N, L, 1)`` float tensor. The prediction at the last position becomes the next
    outcome, and the window slides forward by one (the oldest value is dropped).

    Parameters
    ----------
    trained_rnn : callable
        Model applied to the ``(N, L, 1)`` window; its output is flattened to ``(N, -1)``.
        Presumably ``(N, L, 1)`` — TODO confirm against BaselineRnn.
    Y_pre : np.ndarray, shape (N, L)
        Initial window.
    Y_post : np.ndarray, shape (N, H)
        Observed outcomes to forecast and score against.
    return_Y_pred : bool
        If True, return the stacked ``(N, H)`` prediction tensor instead of the loss.

    Returns
    -------
    float or torch.Tensor
        Mean over the H steps of each step's batch L1 loss, or the predictions.
    """
    window = Y_pre
    Y_post_t = torch.from_numpy(Y_post).float()  # loop-invariant
    n_steps = Y_post.shape[1]
    losses = 0.0
    Y_pred = []
    with torch.no_grad():  # pure inference
        for i in range(n_steps):
            out = trained_rnn(torch.from_numpy(window).float().unsqueeze(2))
            # reshape rather than squeeze(): squeeze() also drops the batch axis when N == 1.
            Y_tilde = out.reshape(window.shape[0], -1)
            next_step = Y_tilde[:, -1]

            window = np.concatenate(
                (window[:, 1:], next_step.detach().cpu().numpy()[:, None]),
                axis=1,
            )
            losses += torch.nn.functional.l1_loss(next_step, Y_post_t[:, i]).item()
            Y_pred.append(next_step)
    if return_Y_pred:
        return torch.stack(Y_pred, 1)
    return losses / n_steps

In [25]:
def infer_rnn(trained_rnn, Y_pre_test, Y_post_test, return_Y_pred=False):
    """Put the RNN in eval mode and run the autoregressive test-set forecast.

    Returns the mean L1 loss, or the predictions when ``return_Y_pred`` is True.
    """
    model = trained_rnn
    model.eval()
    return rnn_predict(model, Y_pre_test, Y_post_test, return_Y_pred=return_Y_pred)

In [26]:
def test_rnn(Y_pre_train, Y_pre_test, Y_post_train, Y_post_test, run_idx=0):
    """Train the baseline RNN on full training trajectories and return its test MAE.

    Only outcomes are used (no covariates or treatments); ``run_idx`` is forwarded as
    ``test_run``.
    """
    full_train_trajectories = np.concatenate((Y_pre_train, Y_post_train), axis=1)
    model = train_baseline_rnn(
        rnn_config,
        Y=full_train_trajectories,
        input_dim=1,
        test_run=run_idx,
    )
    return infer_rnn(model, Y_pre_test, Y_post_test)

### SyncTwin

In [27]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/synctwin_hp.csv')).sort_values(by='valid_loss').iloc[0])
# Best SyncTwin hyper-parameters (lowest valid_loss in the HP search results).
synctwin_config = dict(
    hidden_dim=32,
    reg_B=0.909119,
    lam_express=0.106598,
    lam_recon=0.441844,
    lam_prognostic=0.207286,
    tau=0.311216,
    batch_size=32,
    epochs=100,
    lr=0.000196,
    gamma=0.888244,
)

In [28]:
def infer_synctwin(trained_synctwin, N_test, Y_post_test):
    """Return SyncTwin prognostics for test units ``0 .. N_test - 1``, with the model in eval mode.

    Reaches into the wrapper's private ``_sync_twin`` to call ``get_prognostics`` with
    CPU tensors of unit indices and float32 post-treatment outcomes.
    """
    unit_ids = torch.arange(0, N_test).cpu()
    y_post = torch.from_numpy(Y_post_test).float().cpu()
    trained_synctwin.eval()
    return trained_synctwin._sync_twin.get_prognostics(unit_ids, y_post)

In [29]:
def test_synctwin(
    N,
    N_test,
    X_train, X_test,
    M_train, M_test,
    Y_post_train, Y_post_test,
    A_train, A_test,
    T_train, T_test,
    run_idx=0,
):
    """Train SyncTwin and return its L1 loss on the test split.

    Units whose treatment vector is all zeros form the control mask. ``n_treated`` is
    computed as ``N`` minus the number of training controls, so callers pass the
    training-set size (``N_train``) as ``N``.
    """
    is_control_train = (A_train == 0).all(axis=1)
    is_control_test = (A_test == 0).all(axis=1)
    Y_control_train = Y_post_train[is_control_train]
    n_treated = N - Y_control_train.shape[0]

    model = train_synctwin(
        synctwin_config,
        X=X_train,
        M_=M_train,
        T=T_train,
        Y_batch=Y_post_train,
        Y_control=Y_control_train,
        Y_mask=is_control_train,
        N=N,
        D=D,
        n_treated=n_treated,
        pre_trt_x_len=R * M,
        test_run=run_idx,
    ).cpu()
    model.eval()

    def as_cpu_float(arr):
        return torch.from_numpy(arr).float().cpu()

    _, l1_loss = model(
        as_cpu_float(X_test),
        as_cpu_float(T_test),
        as_cpu_float(M_test),
        torch.arange(0, N_test).cpu(),
        as_cpu_float(Y_post_test),
        as_cpu_float(is_control_test),
    )
    return l1_loss.item()

## Test Models

In [30]:
N_TEST: int = 5  # simulated datasets (independent runs) per test configuration

In [31]:
# Every [N, p_0, treatment representation] combination evaluated by run_tests().
# Position matters: results and plots are keyed by index, e.g.
#   0/1 = (200, 0.1/0.5, BINARY), 2/3 = BOUNDED, 4/5 = REAL_VALUED.
# Add 1000 to the sample-size tuple to include the larger-N runs (appended after all N=200 ones).
TEST_CONFIGS = [
    [n_samples, p_0, treatment_repr]
    for n_samples in (200,)
    for treatment_repr in (TreatmentRepr.BINARY, TreatmentRepr.BOUNDED, TreatmentRepr.REAL_VALUED)
    for p_0 in (0.1, 0.5)
]

In [32]:
def run_tests(config_ids=None, results_dir='results/test'):
    """Train and evaluate every model on N_TEST simulated datasets per configuration.

    For configuration index ``k``, datasets ``N_TEST*k + 1 .. N_TEST*(k + 1)`` are loaded
    or generated. Each of the five models is trained and scored on the held-out split.
    After every run, the accumulated loss lists are dumped to
    ``{results_dir}/config_{k}_idt.joblib`` as ``(config, mclatte, semi_skimmed, skimmed,
    rnn, synctwin)``, so partial progress survives an interruption.

    Parameters
    ----------
    config_ids : iterable of int, optional
        Indices into TEST_CONFIGS to run; defaults to all of them. Pass an empty
        iterable to run nothing.
    results_dir : str
        Output directory; created if missing.
    """
    if config_ids is None:
        config_ids = range(len(TEST_CONFIGS))
    os.makedirs(results_dir, exist_ok=True)

    for config_idx in config_ids:
        config = TEST_CONFIGS[config_idx]
        mclatte_losses = []
        semi_skimmed_mclatte_losses = []
        skimmed_mclatte_losses = []
        rnn_losses = []
        synctwin_losses = []
        for i in range(N_TEST * config_idx + 1, N_TEST * (1 + config_idx) + 1):
            (
                _,
                N_train,
                N_test,
                X_train,
                X_test,
                M_train,
                M_test,
                Y_pre_train,
                Y_pre_test,
                Y_post_train,
                Y_post_test,
                A_train,
                A_test,
                T_train,
                T_test,
            ) = generate_data(config[0], config[1], config[2], i)

            # The three McLatte variants share one positional signature.
            mclatte_args = (
                X_train, X_test,
                M_train, M_test,
                Y_pre_train,
                Y_post_train, Y_post_test,
                A_train, A_test,
                T_train, T_test,
            )
            mclatte_losses.append(test_mclatte(*mclatte_args, run_idx=i))
            semi_skimmed_mclatte_losses.append(test_semi_skimmed_mclatte(*mclatte_args, run_idx=i))
            skimmed_mclatte_losses.append(test_skimmed_mclatte(*mclatte_args, run_idx=i))

            rnn_losses.append(test_rnn(
                Y_pre_train, Y_pre_test, Y_post_train, Y_post_test, run_idx=i,
            ))

            # SyncTwin derives n_treated from its N argument, so it gets the training size.
            synctwin_losses.append(test_synctwin(
                N_train,
                N_test,
                X_train, X_test,
                M_train, M_test,
                Y_post_train, Y_post_test,
                A_train, A_test,
                T_train, T_test,
                run_idx=i,
            ))

            # Checkpoint after every run so partial results survive an interruption.
            joblib.dump((
                config,
                mclatte_losses,
                semi_skimmed_mclatte_losses,
                skimmed_mclatte_losses,
                rnn_losses,
                synctwin_losses,
            ), os.path.join(results_dir, f'config_{config_idx}_idt.joblib'))

#### Check results of finished runs

In [33]:
def print_losses(config_id=0, results_dir='results_idt/maes'):
    """Print mean (std) test MAE of every model for one finished configuration.

    Reads ``{results_dir}/config_{config_id}_idt.joblib``, which holds
    ``(config, mclatte, semi_skimmed, skimmed, rnn, synctwin)`` loss lists, and prints
    one labelled line per model.
    """
    # Same order in which run_tests() stores the per-model loss lists.
    model_names = ('McLatte', 'Semi-Skimmed McLatte', 'Skimmed McLatte', 'RNN', 'SyncTwin')
    config, *losses_per_model = joblib.load(
        os.path.join(results_dir, f'config_{config_id}_idt.joblib')
    )
    print(f'config {config_id}: {config}')
    for name, losses in zip(model_names, losses_per_model):
        print(f'{name}: {np.mean(losses):.3f} ({np.std(losses):.3f})')

### Statistical Testing

In [108]:
# Display names, in the same order run_tests() stores the per-model loss lists.
LOSS_NAMES = [
    'McLatte',
    'Semi-Skimmed McLatte',
    'Skimmed McLatte',
    'RNN',
    'SyncTwin',
]

In [130]:
def test_losses(losses):
    t_test_results = pd.DataFrame(columns=LOSS_NAMES, index=LOSS_NAMES)

    for i in range(len(LOSS_NAMES)):
        for j in range(len(LOSS_NAMES)):
            t = ttest_ind(losses[i], losses[j], alternative='less')
            t_test_results[LOSS_NAMES[i]][LOSS_NAMES[j]] = t.pvalue
    return t_test_results

In [133]:
# Pool every configuration's per-run losses into one sample per model, then test pairwise.
per_config_losses = [
    joblib.load(f'results_idt/maes/config_{config_id}_idt.joblib')[1:]
    for config_id in range(len(TEST_CONFIGS))
]
# per-configuration variant: test_losses(per_config_losses[k])
all_losses = [
    [loss for config_losses in per_config_losses for loss in config_losses[model_idx]]
    for model_idx in range(len(LOSS_NAMES))
]
test_losses(all_losses)

Unnamed: 0,McLatte,Semi-Skimmed McLatte,Skimmed McLatte,RNN,SyncTwin
McLatte,0.5,0.998033,0.95358,0.981433,0.958817
Semi-Skimmed McLatte,0.001967,0.5,0.077473,0.219616,0.117342
Skimmed McLatte,0.04642,0.922527,0.5,0.721141,0.56895
RNN,0.018567,0.780384,0.278859,0.5,0.346143
SyncTwin,0.041183,0.882658,0.43105,0.653857,0.5


### Plot with trained models

In [98]:
# Short legend labels for each model's checkpoint name.
PLOT_NAME = dict(
    skimmed_mclatte='S',
    semi_skimmed_mclatte='SS',
    mclatte='V',
    rnn='RNN',
    synctwin='SyncTwin',
)

In [99]:
@functools.lru_cache(maxsize=None)
def _load_trained_model(model_cls, ckpt_path):
    """Load a checkpoint via ``model_cls.load_from_checkpoint`` once, then reuse it."""
    return model_cls.load_from_checkpoint(ckpt_path)


def line_model_pred(fig, model, name, infer_model, config_id, plot_sub_id, post_t, y_pre_plot, file_suffix, *infer_args):
    """Add a dashed trace of one trained model's predicted outcome for one subject to ``fig``.

    Parameters
    ----------
    model : class
        Model class providing ``load_from_checkpoint``.
    name : str
        Checkpoint base name; also the key into PLOT_NAME for the legend label.
    infer_model : callable
        ``infer_model(trained_model, *infer_args)`` returning a tensor of predictions,
        indexed by subject on the first axis.
    post_t : sequence
        x values; the first point repeats the last observed outcome.
    y_pre_plot : sequence
        The subject's observed pre-treatment outcomes (only the last value is used).
    file_suffix : str
        Appended to the legend label.
    """
    # NOTE(review): the checkpoint index is config_id + 1, while run_tests() trains config k
    # with test_run = N_TEST*k + 1 .. N_TEST*(k + 1). Confirm the checkpoint naming matches.
    ckpt_path = os.path.join(os.getcwd(), f'results_idt/trained_models/{name}_{config_id + 1}.ckpt')
    # Cached: plotting calls this for every subject, so avoid reloading the same checkpoint.
    trained_model = _load_trained_model(model, ckpt_path)
    y_tilde = infer_model(trained_model, *infer_args)
    # Prepend the last observed value so the dashed prediction connects to the ground truth.
    y_pred_plot = [y_pre_plot[-1]] + list(y_tilde.detach().cpu().numpy()[plot_sub_id])
    line_pred_model = go.Scatter(x=post_t, y=y_pred_plot, name=f'{PLOT_NAME[name]}{file_suffix}', line={'dash': 'dash'})
    fig.add_trace(line_pred_model)

In [100]:
def plot_subject(fig, config_id, plot_sub_id, N_test, X_test, M_test, post_t, y_pre_plot, A_test, T_test, file_suffix=''):
    """Overlay the Skimmed, Semi-Skimmed and vanilla McLatte predictions for one subject.

    ``N_test`` is used only by the disabled SyncTwin overlay.
    """
    shared = (config_id, plot_sub_id, post_t, y_pre_plot, file_suffix)
    mclatte_inputs = (X_test, A_test, T_test, M_test)
    # (model class, checkpoint name, position of the outcome prediction in the forward output)
    variants = (
        (SkimmedMcLatte, 'skimmed_mclatte', 1),
        (SemiSkimmedMcLatte, 'semi_skimmed_mclatte', 2),
        (McLatte, 'mclatte', 2),
    )
    for model_cls, model_name, out_idx in variants:
        line_model_pred(
            fig, model_cls, model_name,
            lambda *args, _idx=out_idx: infer_mcespresso(*args)[_idx],
            *shared, *mclatte_inputs,
        )
    # Disabled overlays. Re-enabling them needs Y_pre_test / Y_post_test, which are not
    # parameters of this function:
    # line_model_pred(fig, BaselineRnn, 'rnn', lambda *args: infer_rnn(*args, return_Y_pred=True), 
    #                 config_id, plot_sub_id, post_t, y_pre_plot, file_suffix, Y_pre_test, Y_post_test)
    # line_model_pred(fig, SyncTwinPl, 'synctwin', lambda *args: infer_synctwin(*args), 
    #                 config_id, plot_sub_id, post_t, y_pre_plot, file_suffix, N_test, Y_post_test)
    

In [101]:
def plot_config_results(config_id, file_suffix=''):
    """Save one PNG per test subject comparing ground truth with McLatte predictions.

    Uses dataset run 0 of ``TEST_CONFIGS[config_id]``. Each figure shows the ground truth,
    predictions under the observed treatment, and predictions under a counterfactual
    treatment (legend suffix ' 01'). The counterfactual is all ones if the subject received
    any treatment, otherwise all zeros. Images are written to
    ``plots/idt/outcome_pred_{config_id}_{subject}{file_suffix}.png``.
    """
    config = TEST_CONFIGS[config_id]
    (
        _, _, N_test, 
        _, X_test, 
        _, M_test, 
        _, Y_pre_test, 
        _, Y_post_test, 
        _, A_test, 
        _, T_test,
    ) = generate_data(config[0], config[1], config[2], 0)
    os.makedirs('plots/idt', exist_ok=True)

    for plot_sub_id in range(N_test):
        y_pre_plot = Y_pre_test[plot_sub_id]
        # Pre-treatment steps get negative times; post-treatment starts at 0.
        pre_t = list(np.arange(y_pre_plot.shape[0]) - y_pre_plot.shape[0])

        y_post_plot = [y_pre_plot[-1]] + list(Y_post_test[plot_sub_id])
        post_t = np.arange(len(y_post_plot))

        trt_str = ', '.join(map(
            lambda x: str(round(x, 2)) if abs(x - round(x)) > 5e-2 else str(int(x)), 
            A_test[plot_sub_id]
        ))

        fig = go.Figure()
        line_pre_trt = go.Scatter(x=pre_t + list(post_t), y=list(y_pre_plot) + y_post_plot, name='ground truth')
        fig.add_trace(line_pre_trt)

        plot_subject(fig, config_id, plot_sub_id, N_test, X_test, M_test, post_t, y_pre_plot, A_test, T_test)

        # Build the counterfactual on a copy so the loaded treatments stay unmodified
        # for later subjects.
        subject_trt = A_test[plot_sub_id]
        treated = not (subject_trt == 0).all()
        A_counterfactual = A_test.copy()
        A_counterfactual[plot_sub_id] = np.ones_like(subject_trt) if treated else np.zeros_like(subject_trt)
        plot_subject(fig, config_id, plot_sub_id, N_test, X_test, M_test, post_t, y_pre_plot, A_counterfactual, T_test, file_suffix=' 01')

        fig.update_layout(
            title=f'Outcome for Treatment Vector ({trt_str})', 
            yaxis_title='Outcome', 
            xaxis_title='Time',
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
        )
        fig.write_image(f'plots/idt/outcome_pred_{config_id}_{plot_sub_id}{file_suffix}.png')

In [102]:
# Render prediction plots for every test configuration.
for config_id, _ in enumerate(TEST_CONFIGS):
    plot_config_results(config_id)

In [66]:
# Spot check: configuration 4, test subject 34, with its treatment forced to all ones.
check_config_id, check_subject_id = 4, 34
config = TEST_CONFIGS[check_config_id]
(
    _, _, N_test, 
    _, X_test, 
    _, M_test, 
    _, Y_pre_test, 
    _, Y_post_test, 
    _, A_test, 
    _, T_test,
) = generate_data(config[0], config[1], config[2], 0)
A_test[check_subject_id] = np.ones_like(A_test[check_subject_id])

y_pre_plot = Y_pre_test[check_subject_id]
pre_t = list(np.arange(y_pre_plot.shape[0]) - y_pre_plot.shape[0])
y_post_plot = [y_pre_plot[-1]] + list(Y_post_test[check_subject_id])
post_t = np.arange(len(y_post_plot))

fig = go.Figure()
fig.add_trace(go.Scatter(x=pre_t + list(post_t), y=list(y_pre_plot) + y_post_plot, name='ground truth'))
# Match plot_subject's signature: (fig, config_id, plot_sub_id, N_test, X_test, M_test,
# post_t, y_pre_plot, A_test, T_test).
plot_subject(fig, check_config_id, check_subject_id, N_test, X_test, M_test, post_t, y_pre_plot, A_test, T_test)
fig.show()