## Idealised Hypothetical Disease Treatment

In [42]:
import joblib
import numpy as np
# import os
import plotly.graph_objects as go
import random
import torch
import wandb
from mclatte.model import train_mclatte
from mclatte.simulation_data import generate_simulation_data, TreatmentRepr
from rnn.model import train_baseline_rnn
from synctwin.model import train_synctwin

In [43]:
random.seed(509)
np.random.seed(509)
torch.manual_seed(509)

<torch._C.Generator at 0x2981fa863f0>

Data Generation

In [44]:
M = 5
H = 5
R = 5
D = 10
K = 3
C = 4

In [45]:
def generate_data(N, p_0, mode):
    N_train = round(N * 0.8)
    N_test = round(N * 0.2)
    X, M_, Y_pre, Y_post, A, T = generate_simulation_data(N, M, H, R, D, K, C, mode, p_0)
    X_train, X_test = X[:N_train], X[N_train:]
    M_train, M_test = M_[:N_train], M_[N_train:]
    Y_pre_train, Y_pre_test = Y_pre[:N_train], Y_pre[N_train:]
    Y_post_train, Y_post_test = Y_post[:N_train], Y_post[N_train:]
    A_train, A_test = A[:N_train], A[N_train:]
    T_train, T_test = T[:N_train], T[N_train:]
    return N, N_train, N_test, X_train, X_test, M_train, M_test, Y_pre_train, Y_pre_test, Y_post_train, Y_post_test, A_train, A_test, T_train, T_test

In [46]:
# joblib.dump((N_train, M, H, R, D, K, C, X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train), 'data/simulation/data_uniform_200.joblib')
# joblib.dump((N_test, M, H, R, D, K, C, X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test), 'data/simulation/test_data_uniform_200.joblib')

### Visualizations

In [47]:
(
    N_visual, 
    N_train_visual, 
    N_test_visual, 
    X_train_visual, 
    X_test_visual, 
    M_train_visual, 
    M_test_visual, 
    Y_pre_train_visual, 
    Y_pre_test_visual, 
    Y_post_train_visual, 
    Y_post_test_visual, 
    A_train_visual, 
    A_test_visual, 
    T_train_visual, 
    T_test_visual,
) = generate_data(250, 0, TreatmentRepr.BINARY)
X_visual = np.concatenate((X_train_visual, X_test_visual), axis=0)
A_visual = np.concatenate((A_train_visual, A_test_visual), axis=0)
Y_pre_visual = np.concatenate((Y_pre_train_visual, Y_pre_test_visual), axis=0)
Y_post_visual = np.concatenate((Y_post_train_visual, Y_post_test_visual), axis=0)
sample_ids = np.random.randint(N_visual, size=10)

Covariates

In [48]:
fig = go.Figure()
for feature_idx in range(D):
    values = np.mean(X_visual[sample_ids, :, feature_idx], axis=0)
    fig.add_trace(go.Scatter(x=list(range(R * M)), y=values, name=f'feature {feature_idx}'))
fig.update_layout(
    title='Average Covariate Values',
    xaxis_title='t',
    yaxis_title='Feature Value'
)
fig.show()

Treatment Causes

In [49]:
fig = go.Figure(data=go.Heatmap(z=A_visual[sample_ids].T))
fig.update_layout(
    title='Treatment Causes',
    xaxis_title='Sample ID',
    yaxis_title='Cause'
)
fig.show()

Treatment Outcomes

In [50]:
Y_sampled = np.concatenate((Y_pre_visual, Y_post_visual), axis=1)[sample_ids, :]

In [51]:
fig = go.Figure()
for sample_idx in range(len(sample_ids)):
    values = Y_sampled[sample_idx, :]
    fig.add_trace(go.Scatter(x=list(range(M + H)), y=values, name=f'Sample {sample_idx}'))
fig.update_layout(
    title='Sampled Treatment Outcomes',
    xaxis_title='t',
    yaxis_title='Outcome Value'
)
fig.show()

## Modelling

In [52]:
wandb.init(project='mclatte-test', entity='jasonyz')

0,1
epoch,▁▂▂▃▄▅▁▂▂▃▁▁▂▃▄▅▆▇█▁▁▂▃▁▁▂▃▃▄▅▆▂▁▂▃▄▅▆▇█
ptl/loss,██▆▅▆▆▁▁▁▃▃▂▂▂▃▂▂▂▂▂▁▁▁▅▄▅▅▄▅▄▁▂▂▁▁▂▂▁▂▁
ptl/valid_loss,▆▆▆▆▆▆██▂▂▂▄▃▃▃▃▃▃▃▇▇▁▁▁▂▆▄▄▄▄▄▆▄▂▂▂▂▂▂▂
trainer/global_step,▁▁▂▃▄▅▁▁▂▃▁▁▂▃▄▅▆▇█▁▁▂▃▁▁▂▃▃▄▅▆▁▁▂▃▄▅▆▇█

0,1
epoch,54.0
ptl/loss,2.90286
ptl/valid_loss,1.52671
trainer/global_step,659.0


In [53]:
# N_train, M, H, R, D, K, C, X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train = joblib.load(
#     os.path.join(os.getcwd(), f'data/simulation/data.joblib')
# )
# N_test, M, H, R, D, K, C, X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test = joblib.load(
#     os.path.join(os.getcwd(), 'data/simulation/test_data.joblib')
# )

In [54]:
def na_catcher(func):
    def wrapper_na_catcher(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(e)
            return np.nan
    return wrapper_na_catcher

### McLatte

In [55]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/mclatte_hp.csv')).sort_values(by='valid_loss').iloc[0])
mclatte_config = {
    'encoder_class': 'lstm',
    'decoder_class': 'lstm',
    'hidden_dim': 64,
    'batch_size': 64,
    'epochs': 100,
    'lr': 0.0151,
    'gamma': 0.986855,
    'lambda_r': 1.928836,
    'lambda_s': 0.042385,
}

In [56]:
@na_catcher
def test_mclatte(
    X_train, 
    X_test, 
    M_train, 
    M_test, 
    Y_pre_train, 
    Y_post_train, 
    Y_post_test, 
    A_train, 
    A_test, 
    T_train, 
    T_test,
    run_idx=0,
):
    trained_mclatte = train_mclatte(
        mclatte_config,
        X_train,
        M_train,
        Y_pre_train,
        Y_post_train,
        A_train, 
        T_train,
        R,
        M,
        H,
        input_dim=D, 
        treatment_dim=K, 
        test_run=run_idx,
    )

    _, y_tilde = trained_mclatte(
        torch.from_numpy(X_test).float(),
        torch.from_numpy(A_test).float(),
        torch.from_numpy(T_test).float(),
        torch.from_numpy(M_test).float(),
    )
    
    return torch.nn.functional.l1_loss(
        y_tilde, 
        torch.from_numpy(Y_post_test).float()
    ).item()

### Baseline RNN

In [57]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/baseline_rnn_hp.csv')).sort_values(by='valid_loss').iloc[0])
rnn_config = {
    'rnn_class': 'gru',
    'hidden_dim': 4,
    'seq_len': 4,
    'batch_size': 64,
    'epochs': 100,
    'lr': 0.048177,
    'gamma': 0.795612,
}

In [58]:
def rnn_predict(trained_rnn, Y_pre, Y_post):
    """
    Make predictions using results from previous time steps.
    """
    Y = Y_pre
    losses = 0.0
    for i in range(Y_post.shape[1]):
        Y_tilde = trained_rnn(
            torch.from_numpy(Y).float().unsqueeze(2)
        ).squeeze()

        Y = np.concatenate((
            Y[:, 1:], 
            Y_tilde.cpu().detach().numpy()[:, [-1]]
        ), axis=1)
        
        losses += torch.nn.functional.l1_loss(
            Y_tilde[:, -1], 
            torch.from_numpy(Y_post).float()[:, i]
        ).item()
    return losses / Y_post.shape[1]

In [59]:
@na_catcher
def test_rnn(
    Y_pre_train, 
    Y_pre_test, 
    Y_post_train, 
    Y_post_test, 
    run_idx=0,
):
    trained_rnn = train_baseline_rnn(
        rnn_config,
        Y=np.concatenate((Y_pre_train, Y_post_train), axis=1),
        input_dim=1, 
        test_run=run_idx,
    )
    return rnn_predict(trained_rnn, Y_pre_test, Y_post_test)

### SyncTwin

In [60]:
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/synctwin_hp.csv')).sort_values(by='valid_loss').iloc[0])
synctwin_config = {
    'hidden_dim': 32,
    'reg_B': 0.909119,
    'lam_express': 0.106598,
    'lam_recon': 0.441844,
    'lam_prognostic': 0.207286,
    'tau': 0.311216,
    'batch_size': 32,
    'epochs': 100,
    'lr': 0.000196,
    'gamma': 0.888244,
}

In [61]:
@na_catcher
def test_synctwin(
    N, 
    N_test, 
    X_train, 
    X_test, 
    M_train, 
    M_test, 
    Y_post_train, 
    Y_post_test, 
    A_train, 
    A_test, 
    T_train, 
    T_test,
    run_idx=0,
):
    Y_mask_train = np.all(A_train == 0, axis=1)
    Y_mask_test = np.all(A_test == 0, axis=1)
    Y_control_train = Y_post_train[Y_mask_train]

    trained_synctwin = train_synctwin(
        synctwin_config,
        X=X_train,
        M_=M_train,
        T=T_train,
        Y_batch=Y_post_train,
        Y_control=Y_control_train,
        Y_mask=Y_mask_train, 
        N=N,
        D=D,
        n_treated=N - Y_control_train.shape[0],
        pre_trt_x_len=R * M,
        test_run=run_idx,
    ).cuda()

    _, l1_loss = trained_synctwin(
        torch.from_numpy(X_test).float().cuda(),
        torch.from_numpy(T_test).float().cuda(),
        torch.from_numpy(M_test).float().cuda(),
        torch.arange(0, N_test).cuda(),
        torch.from_numpy(Y_post_test).float().cuda(),
        torch.from_numpy(Y_mask_test).float().cuda(),
    )
    return l1_loss.item()

### Test Models

In [62]:
N_TEST = 5

In [63]:
TEST_CONFIGS = [
    [200, 0.1, TreatmentRepr.BINARY],
    [200, 0.5, TreatmentRepr.BINARY],
    [200, 0.1, TreatmentRepr.BOUNDED],
    [200, 0.5, TreatmentRepr.BOUNDED],
    [200, 0.1, TreatmentRepr.REAL_VALUED],
    [200, 0.5, TreatmentRepr.REAL_VALUED],
    [1000, 0.1, TreatmentRepr.BINARY],
    [1000, 0.5, TreatmentRepr.BINARY],
    [1000, 0.1, TreatmentRepr.BOUNDED],
    [1000, 0.5, TreatmentRepr.BOUNDED],
    [1000, 0.1, TreatmentRepr.REAL_VALUED],
    [1000, 0.5, TreatmentRepr.REAL_VALUED],
]

In [64]:
mclatte_losses_all = []
rnn_losses_all = []
synctwin_losses_all = []
for config_idx in range(1, len(TEST_CONFIGS)):
    config = TEST_CONFIGS[config_idx]
    mclatte_losses = []
    rnn_losses = []
    synctwin_losses = []
    for i in range(N_TEST * config_idx + 1, N_TEST * (1 + config_idx) + 1):
        (
            N, 
            N_train, 
            N_test, 
            X_train, 
            X_test, 
            M_train, 
            M_test, 
            Y_pre_train, 
            Y_pre_test, 
            Y_post_train, 
            Y_post_test, 
            A_train, 
            A_test, 
            T_train, 
            T_test,
        ) = generate_data(config[0], config[1], config[2])

        mclatte_losses.append(test_mclatte(
            X_train, 
            X_test, 
            M_train, 
            M_test, 
            Y_pre_train, 
            Y_post_train, 
            Y_post_test, 
            A_train, 
            A_test, 
            T_train, 
            T_test,
            run_idx=i,
        ))
        rnn_losses.append(test_rnn(Y_pre_train, Y_pre_test, Y_post_train, Y_post_test, i))
        synctwin_losses.append(test_synctwin(
            N, 
            N_test, 
            X_train, 
            X_test, 
            M_train, 
            M_test, 
            Y_post_train, 
            Y_post_test, 
            A_train, 
            A_test, 
            T_train, 
            T_test,
            run_idx=i,
        ))
        joblib.dump((mclatte_losses, rnn_losses, synctwin_losses), f'results/test/config_{config_idx}.joblib')
    mclatte_losses_all.append(mclatte_losses)
    rnn_losses_all.append(rnn_losses)
    synctwin_losses_all.append(synctwin_losses)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 23.9 K
1 | _decoder | LstmDecoder | 33.9 K
-----------------------------------------
58.1 K    Trainable params
0         Non-trainable params
58.1 K    Total params
0.232     Total estimated model params size (MB)


__init__() missing 8 required positional arguments: 'encoder', 'decoder', 'lambda_r', 'lambda_s', 'lr', 'gamma', 'post_trt_seq_len', and 'hidden_dim'



`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5.


The number of training samples (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Detected KeyboardInterrupt, attempting graceful shutdown...

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x0000029820AB0040>
Traceback (most recent call last):
  File "c:\Users\Jason\Projects\McLatte\.venv.mclatte\lib\site-packages\torch\utils\data\dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "c:\Users\Jason\Projects\McLatte\.venv.mclatte\lib\site-packages\torch\utils\data\dataloader.py", line 1295, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'
GPU available: True, used: True
TPU available: False, us

__init__() missing 5 required positional arguments: 'rnn', 'hidden_dim', 'output_dim', 'lr', and 'gamma'



The number of training samples (12) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



In [None]:
np.mean(mclatte_losses), np.std(mclatte_losses)

(1.848403310775757, 0.4963271472037229)

In [None]:
mclatte_losses

[2.5393459796905518,
 1.2903814315795898,
 1.9774020910263062,
 1.2747408151626587,
 2.1601462364196777]

In [None]:
np.mean(rnn_losses), np.std(rnn_losses)

(2.476710448265076, 0.9416777155576778)

In [None]:
rnn_losses

[3.6125144004821776,
 1.5295768976211548,
 2.462286901473999,
 1.333347702026367,
 3.4458263397216795]

In [None]:
np.mean(np.nan_to_num(synctwin_losses, nan=4)), np.std(np.nan_to_num(synctwin_losses, nan=4))

(3.549545073509216, 0.9009098529815673)