## Exploratory Data Analysis

In [1]:
import joblib
import numpy as np
import os
import pandas as pd
import plotly.graph_objects as go
import ray
import torch
import wandb
from mclatte.model import train_mclatte
from mclatte.simulation_data import generate_simulation_data, TreatmentRepr
from ray import tune
from rnn.model import train_baseline_rnn
from synctwin.model import train_synctwin

In [2]:
# Seed both RNG sources this notebook imports. Previously only NumPy was
# seeded, so anything drawing from torch's generator was not repeatable
# between kernel restarts.
SEED = 509
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# Open the Weights & Biases run used by this notebook session.
wandb.init(
    entity='jasonyz',
    project='mclatte-test',
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjasonyz[0m (use `wandb login --relogin` to force relogin)
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "


In [3]:
# Initialise Ray for the tuning runs further down; no cluster address is
# supplied (address=None).
ray.init(
    address=None,
)

{'node_ip_address': '127.0.0.1',
 'raylet_ip_address': '127.0.0.1',
 'redis_address': '127.0.0.1:6379',
 'object_store_address': 'tcp://127.0.0.1:64028',
 'raylet_socket_name': 'tcp://127.0.0.1:59951',
 'webui_url': None,
 'session_dir': 'C:\\Users\\Jason\\AppData\\Local\\Temp\\ray\\session_2021-12-10_14-32-27_890431_9528',
 'metrics_export_port': 61737,
 'node_id': 'ab1cee49e6887541ad2776ecd8517aeda619e137bfa0a8cc81b61a59'}

Experiment Constants

In [4]:
# Load the cached training bundle: sizes/dimensions first, then the arrays.
# The path literal had a spurious f-prefix (no placeholders); it is now a plain
# string.
# NOTE: joblib.load unpickles the file, so only point it at bundles produced by
# this project.
train_bundle_path = os.path.join(os.getcwd(), 'data/simulation/data.joblib')
(
    N_train, M, H, R, D, K, C,
    X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train,
) = joblib.load(train_bundle_path)

In [5]:
# Load the cached test bundle. It has the same layout as the training bundle,
# so M, H, R, D, K and C are re-assigned from this file.
test_bundle_path = os.path.join(os.getcwd(), 'data/simulation/test_data.joblib')
(
    N_test, M, H, R, D, K, C,
    X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test,
) = joblib.load(test_bundle_path)

In [10]:
# Sample counts: the first N_train simulated units are used for training and
# the remaining N_test for testing.
N_train, N_test = 200, 50
N = N_train + N_test

# Simulation dimensions. The plots below use R * M covariate time steps and
# M + H outcome time steps; D is the number of covariate features.
M = H = R = 5
D = 10
# K is passed to McLatte as treatment_dim; C is used as the baselines' hidden_dim.
K, C = 3, 4

In [14]:
# Simulate the full dataset (N units) with the BOUNDED treatment representation.
simulated = generate_simulation_data(
    N, M, H, R, D, K, C,
    TreatmentRepr.BOUNDED,
)
X, M_, Y_pre, Y_post, A, T = simulated

In [15]:
# Deterministic split along the first axis: rows [0, N_train) go to the training
# set, the rest to the test set. The same two slices are applied to every array.
train_rows = slice(None, N_train)
test_rows = slice(N_train, None)

X_train, X_test = X[train_rows], X[test_rows]
M_train, M_test = M_[train_rows], M_[test_rows]
Y_pre_train, Y_pre_test = Y_pre[train_rows], Y_pre[test_rows]
Y_post_train, Y_post_test = Y_post[train_rows], Y_post[test_rows]
A_train, A_test = A[train_rows], A[test_rows]
T_train, T_test = T[train_rows], T[test_rows]

In [16]:
# Persist both splits using the same tuple layout the loading cells unpack:
# (n_samples, M, H, R, D, K, C, X, mask, Y_pre, Y_post, A, T).
shared_dims = (M, H, R, D, K, C)
train_bundle = (N_train, *shared_dims, X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train)
test_bundle = (N_test, *shared_dims, X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test)

joblib.dump(train_bundle, 'data/simulation/data_uniform_200.joblib')
joblib.dump(test_bundle, 'data/simulation/test_data_uniform_200.joblib')

['data/simulation/test_data_uniform_200.joblib']

### Visualizations

In [None]:
# Pick up to 10 units to visualise. They are drawn WITHOUT replacement so the
# same unit cannot appear twice in the plots below (np.random.randint samples
# with replacement and could return duplicates).
sample_ids = np.random.choice(N, size=min(10, N), replace=False)

Covariates

In [None]:
# One line per covariate feature: the feature's value averaged over the sampled
# units, plotted against the R * M covariate time steps.
time_axis = list(range(R * M))
covariate_traces = [
    go.Scatter(
        x=time_axis,
        y=X[sample_ids, :, feature_idx].mean(axis=0),
        name=f'feature {feature_idx}',
    )
    for feature_idx in range(D)
]
fig = go.Figure(data=covariate_traces)
fig.update_layout(
    title='Average Covariate Values',
    xaxis_title='t',
    yaxis_title='Feature Value',
)
fig.show()

Treatment Causes

In [None]:
# Transpose so that each row is a cause and each column is one sampled unit
# (columns follow the positional order of sample_ids).
cause_matrix = A[sample_ids].T
fig = go.Figure(data=go.Heatmap(z=cause_matrix))
heatmap_layout = {
    'title': 'Treatment Causes',
    'xaxis_title': 'Sample ID',
    'yaxis_title': 'Cause',
}
fig.update_layout(**heatmap_layout)
fig.show()

Treatment Outcomes

In [None]:
# Join Y_pre and Y_post along the time axis, then keep only the sampled units.
Y_full = np.concatenate([Y_pre, Y_post], axis=1)
Y_sampled = Y_full[sample_ids]

In [None]:
# One line per sampled unit over the M + H outcome time steps. Legend labels
# are the position within sample_ids, not the underlying unit index.
outcome_steps = list(range(M + H))
fig = go.Figure()
for sample_idx, values in enumerate(Y_sampled):
    trace = go.Scatter(x=outcome_steps, y=values, name=f'Sample {sample_idx}')
    fig.add_trace(trace)
fig.update_layout(
    title='Sampled Treatment Outcomes',
    xaxis_title='t',
    yaxis_title='Outcome Value',
)
fig.show()

## Modelling

Train McLatte and the benchmark models, then compare their performance.

### McLatte

In [6]:
# Show the hyper-parameter search trial with the lowest validation loss.
hp_results_path = os.path.join(os.getcwd(), 'results/mclatte_hp.csv')
hp_results = pd.read_csv(hp_results_path)
hp_results.sort_values(by='valid_loss').iloc[0]

trial_id                                                          2897e_00014
loss                                                               119.389359
valid_loss                                                           1.869723
time_this_iter_s                                                    84.108797
done                                                                     True
timesteps_total                                                           NaN
episodes_total                                                            NaN
training_iteration                                                         13
experiment_id                                a84191b222894e398ba479af36b8b602
date                                                      2021-12-10_09-12-45
timestamp                                                          1639127565
time_total_s                                                       1118.42367
pid                                                             

In [7]:
# Fixed McLatte configuration passed to train_mclatte below.
# NOTE(review): presumably copied from the best trial in results/mclatte_hp.csv
# -- confirm against that file.
mclatte_config = dict(
    # architecture
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=64,
    # optimisation
    batch_size=64,
    epochs=100,
    lr=0.0151,
    gamma=0.986855,
    # loss weights
    lambda_r=1.928836,
    lambda_s=0.042385,
)

In [8]:
# Train McLatte several times and record the test-set L1 error of the predicted
# post-treatment outcomes for each run.
N_RUNS = 5  # number of independent train/evaluate repetitions

# The test tensors do not change between runs, so convert them once instead of
# once per iteration.
X_test_t = torch.from_numpy(X_test).float()
A_test_t = torch.from_numpy(A_test).float()
T_test_t = torch.from_numpy(T_test).float()
M_test_t = torch.from_numpy(M_test).float()
Y_post_test_t = torch.from_numpy(Y_post_test).float()

mclatte_losses = []
for i in range(N_RUNS):
    trained_mclatte = train_mclatte(
        mclatte_config,
        X_train,
        M_train,
        Y_pre_train,
        Y_post_train,
        A_train,
        T_train,
        R,
        M,
        H,
        input_dim=D,
        treatment_dim=K,
        test_run=i,
    )
    # Evaluate in inference mode: eval() disables train-only layer behaviour and
    # no_grad() avoids building an autograd graph for the test forward pass.
    # Assumes the returned model is a torch.nn.Module (the Lightning model
    # summary printed by this cell suggests so).
    trained_mclatte.eval()
    with torch.no_grad():
        x_tilde, y_tilde = trained_mclatte(X_test_t, A_test_t, T_test_t, M_test_t)
    test_l1 = torch.nn.functional.l1_loss(y_tilde, Y_post_test_t).item()
    mclatte_losses.append(test_l1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 23.9 K
1 | _decoder | LstmDecoder | 33.9 K
-----------------------------------------
58.1 K    Trainable params
0         Non-trainable params
58.1 K    Total params
0.232     Total estimated model params size (MB)
  rank_zero_deprecation(
  rank_zero_warn(
    runpy.run_module(module, run_name="__main__", alter_sys=False)
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.9_3.9.2544.0_x64__qbz5n2kfra8p0\lib\runpy.py", line 213, in run_module
    return _run_code(code, {}, init_globals, run_name, mod_spec)
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.9_3.9.2544.0_x64__qbz5n2kfra8p0\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "c:\Users\Jason\Projects\McLatte\.

In [9]:
# Mean and (population, ddof=0) standard deviation of the per-run test L1 losses.
loss_array = np.asarray(mclatte_losses)
loss_array.mean(), loss_array.std()

(1.4011790037155152, 0.06934993200345907)

[2m[36m(pid=None)[0m [2021-12-10 15:03:04,059 C 9900 36424] redis_client.cc:87:  Check failed: under_retry_limit Expected 1 Redis shard addresses, found 8104303535147543913
[2m[36m(pid=None)[0m *** StackTrace Information ***
[2m[36m(pid=None)[0m     configthreadlocale
[2m[36m(pid=None)[0m     BaseThreadInitThunk
[2m[36m(pid=None)[0m     RtlUserThreadStart
[2m[36m(pid=None)[0m 
[2m[36m(pid=None)[0m [2021-12-10 15:03:07,879 E 34864 5284] local_object_manager.cc:32: Plasma object ffffffffffffffffffffffffffffffffffffffff0100000001000000 was evicted before the raylet could pin it.
[2m[36m(pid=None)[0m [2021-12-10 15:03:07,880 E 34864 5284] local_object_manager.cc:32: Plasma object ffffffffffffffffffffffffffffffffffffffff0100000002000000 was evicted before the raylet could pin it.
 pid=15252)[0m GPU available: True, used: True
 pid=15252)[0m TPU available: False, using: 0 TPU cores
 pid=15252)[0m IPU available: False, using: 0 IPUs
 pid=15252)[0m 2021-12-10 15:03:

### Baseline RNN

In [None]:
# Search space for the baseline RNN. Settings with a single candidate are
# wrapped in tune.choice; only lr and gamma are actually searched.
rnn_fixed_settings = {
    'rnn_class': 'rnn',
    'hidden_dim': C,
    'seq_len': 32,
    'batch_size': 32,
    'epochs': 100,
}
hp_config = {name: tune.choice([value]) for name, value in rnn_fixed_settings.items()}
hp_config['lr'] = tune.loguniform(1e-4, 1e-1)
hp_config['gamma'] = tune.uniform(0.5, 0.99)

sync_config = tune.SyncConfig()

In [None]:
# Bind the outcome series (Y_pre followed by Y_post along the time axis) to the
# RNN training function.
# NOTE(review): Y is built from the full dataset, not only the training rows --
# confirm this is intended for the baseline.
rnn_trainable = tune.with_parameters(
    train_baseline_rnn,
    input_dim=1,
    Y=np.concatenate([Y_pre, Y_post], axis=1),
)

In [None]:
# Hyper-parameter search for the baseline RNN: 10 sampled configurations,
# selected by minimum validation loss; results are written under ./data.
rnn_tune_kwargs = dict(
    name='tune_pl_baseline_rnn',
    local_dir=os.path.join(os.getcwd(), 'data'),
    sync_config=sync_config,
    resources_per_trial={"cpu": 8, "gpu": 1},
    metric='valid_loss',
    mode='min',
    checkpoint_score_attr='valid_loss',
    keep_checkpoints_num=5,
    config=hp_config,
    num_samples=10,
    verbose=1,
    resume='AUTO',
)
analysis = tune.run(rnn_trainable, **rnn_tune_kwargs)

### SyncTwin

In [15]:
# Search space for SyncTwin. Settings with a single candidate are wrapped in
# tune.choice; only lr and gamma are actually searched.
synctwin_fixed_settings = {
    'hidden_dim': C,
    'reg_B': 1,
    'lam_express': 1,
    'lam_recon': 1,
    'lam_prognostic': 1,
    'tau': 1,
    'batch_size': 32,
    'epochs': 100,
}
hp_config = {name: tune.choice([value]) for name, value in synctwin_fixed_settings.items()}
hp_config['lr'] = tune.loguniform(1e-4, 1e-1)
hp_config['gamma'] = tune.uniform(0.5, 0.99)

sync_config = tune.SyncConfig()

In [16]:
# A unit counts as a control when every entry of its row in A is zero.
Y_mask = (A == 0).all(axis=1)
# Y_post rows of the control units only.
Y_control = Y_post[Y_mask]

In [17]:
# Sanity check: (number of control units, outcome steps).
np.shape(Y_control)

(28, 5)

In [18]:
# Bind the full simulated dataset and the control-unit information to the
# SyncTwin training function. Every unit that is not a control is counted as
# treated.
n_control = Y_control.shape[0]
synctwin_inputs = dict(
    X=X,
    M_=M_,
    T=T,
    Y_batch=Y_post,
    Y_control=Y_control,
    Y_mask=Y_mask,
    N=N,
    D=D,
    n_treated=N - n_control,
    pre_trt_x_len=R * M,
)
st_trainable = tune.with_parameters(train_synctwin, **synctwin_inputs)

In [19]:
# Hyper-parameter search for SyncTwin: 10 sampled configurations, selected by
# minimum validation loss; results are written under ./data.
synctwin_tune_kwargs = dict(
    name='tune_pl_sync_twin',
    local_dir=os.path.join(os.getcwd(), 'data'),
    sync_config=sync_config,
    resources_per_trial={"cpu": 8, "gpu": 1},
    metric='valid_loss',
    mode='min',
    checkpoint_score_attr='valid_loss',
    keep_checkpoints_num=5,
    config=hp_config,
    num_samples=10,
    verbose=1,
    resume='AUTO',
)
analysis = tune.run(st_trainable, **synctwin_tune_kwargs)

2021-12-10 01:13:29,446	ERROR tune.py:622 -- Trials did not complete: [train_synctwin_917c5_00000, train_synctwin_917c5_00001, train_synctwin_917c5_00002, train_synctwin_917c5_00003, train_synctwin_917c5_00004, train_synctwin_917c5_00005, train_synctwin_917c5_00006, train_synctwin_917c5_00007, train_synctwin_917c5_00008, train_synctwin_917c5_00009]
2021-12-10 01:13:29,447	INFO tune.py:626 -- Total run time: 350.50 seconds (350.22 seconds for the tuning loop).
 pid=33736)[0m Traceback (most recent call last):
 pid=33736)[0m   File "<string>", line 1, in <module>
 pid=33736)[0m   File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.9_3.9.2544.0_x64__qbz5n2kfra8p0\lib\multiprocessing\spawn.py", line 116, in spawn_main
 pid=33736)[0m     exitcode = _main(fd, parent_sentinel)
 pid=33736)[0m   File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.9_3.9.2544.0_x64__qbz5n2kfra8p0\lib\multiprocessing\spawn.py", line 126, in _main
 pid=33736)[0m     self = re