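## Idealised Hypothetical Disease Treatment

Synthetic benchmark on idealised hypothetical disease treatment (IDT) data. We first generate and visualise the data, then compare the McLatte variants (vanilla, semi-skimmed, skimmed) against baseline RNN and SyncTwin models across several data-generation settings.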

In [1]:
import joblib
import numpy as np
import plotly.graph_objects as go
import random
import torch
import wandb
from plot_utils import plot_config_results
from test_utils import (
    test_skimmed_mclatte,
    test_semi_skimmed_mclatte,
    test_mclatte,
    test_rnn,
    test_synctwin,
    test_losses,
)
from mclatte.test_data.idt import (
    generate_data,
    SimDataGenConfig, 
    TreatmentRepr,
)

In [2]:
# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) so runs are reproducible.
SEED = 509
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed returns the generator; the trailing ';' keeps its repr out of the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7ff00084c710>

### Constants for Data Generation

In [3]:
# Dimensions shared by every SimDataGenConfig (via `constants`) and by the plots below.
# From how the plots use them: D is the number of covariate features, covariates span
# R * M time steps, and outcomes presumably span M (pre) + H (post) steps.
# NOTE(review): K and C are only forwarded to SimDataGenConfig here; document their meaning.
M, H, R, D, K, C = 5, 5, 5, 10, 3, 4
constants = {"m": M, "h": H, "r": R, "d": D, "k": K, "c": C}

### Visualizations

In [4]:
data_gen_config = SimDataGenConfig(
    n=250,
    p_0=0.1,
    mode=TreatmentRepr.BINARY,
    **constants
)
(
    N_visual,
    N_train_visual,
    N_test_visual,
    X_train_visual,
    X_test_visual,
    M_train_visual,
    M_test_visual,
    Y_pre_train_visual,
    Y_pre_test_visual,
    Y_post_train_visual,
    Y_post_test_visual,
    A_train_visual,
    A_test_visual,
    T_train_visual,
    T_test_visual,
) = generate_data(data_gen_config, "visual")
# The plots describe the whole generated population, so pool the train and test splits.
X_visual = np.concatenate((X_train_visual, X_test_visual), axis=0)
A_visual = np.concatenate((A_train_visual, A_test_visual), axis=0)
Y_pre_visual = np.concatenate((Y_pre_train_visual, Y_pre_test_visual), axis=0)
Y_post_visual = np.concatenate((Y_post_train_visual, Y_post_test_visual), axis=0)
# Pick *distinct* individuals to plot. `np.random.randint` draws with replacement, so the
# same individual could show up twice under different labels in the plots below.
# NOTE: this consumes the global NumPy RNG differently from randint, so the exact
# individuals shown differ from earlier renderings.
sample_ids = np.random.choice(N_visual, size=min(10, N_visual), replace=False)

#### Covariates

In [5]:
# One line per covariate feature, averaged over the sampled individuals.
covariate_steps = list(range(R * M))
fig = go.Figure()
for feature_idx in range(D):
    mean_feature = X_visual[sample_ids, :, feature_idx].mean(axis=0)
    fig.add_trace(
        go.Scatter(x=covariate_steps, y=mean_feature, name=f"feature {feature_idx}")
    )
fig.update_layout(
    title="Average Covariate Values",
    xaxis_title="t",
    yaxis_title="Feature Value",
)
fig.show()

#### Treatment Causes

In [6]:
# Transpose so each column is one sampled individual and each row one treatment cause.
# NOTE(review): the x axis shows positions within `sample_ids` (0..9), not the ids themselves.
cause_matrix = np.transpose(A_visual[sample_ids])
fig = go.Figure(data=go.Heatmap(z=cause_matrix))
fig.update_layout(
    title="Treatment Causes",
    xaxis_title="Sample ID",
    yaxis_title="Cause",
)
fig.show()

#### Treatment Outcomes

In [7]:
# Keep only the sampled individuals, then join their pre- and post-treatment outcomes in time.
Y_sampled = np.concatenate(
    (Y_pre_visual[sample_ids], Y_post_visual[sample_ids]), axis=1
)

In [8]:
# One outcome trajectory (pre- followed by post-treatment) per sampled individual.
outcome_steps = list(range(M + H))
fig = go.Figure()
for sample_idx, outcome_trajectory in enumerate(Y_sampled):
    fig.add_trace(
        go.Scatter(x=outcome_steps, y=outcome_trajectory, name=f"Sample {sample_idx}")
    )
fig.update_layout(
    title="Sampled Treatment Outcomes",
    xaxis_title="t",
    yaxis_title="Outcome Value",
)
fig.show()

## Modelling

In [9]:
import os

# Start the Weights & Biases run. The project/entity can be overridden through the
# standard WANDB_PROJECT / WANDB_ENTITY environment variables instead of editing the
# notebook; the defaults reproduce the original run.
wandb.init(
    project=os.environ.get("WANDB_PROJECT", "mclatte-test"),
    entity=os.environ.get("WANDB_ENTITY", "jasonyz"),
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjasonyz[0m (use `wandb login --relogin` to force relogin)

The `IPython.html` package has been deprecated since IPython 4.0. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`.

[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


### McLatte

#### Skimmed

In [10]:
# Hyperparameters taken from the row with the lowest `valid_loss` in
# results_idt/skimmed_mclatte_hp.csv.
skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=64,
    batch_size=64,
    epochs=100,
    lr=0.0151,
    gamma=0.986855,
    lambda_r=1.928836,
    lambda_p=0.042385,
)

#### Semi-Skimmed

In [11]:
# Hyperparameters taken from the row with the lowest `valid_loss` in
# results_idt/semi_skimmed_mclatte_hp.csv.
semi_skimmed_mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=16,
    batch_size=64,
    epochs=100,
    lr=0.120062,
    gamma=0.731629,
    lambda_r=0.016767,
    lambda_d=1.83538,
    lambda_p=1.509965,
)

#### Vanilla

In [12]:
# Hyperparameters taken from the row with the lowest `valid_loss` in
# results_idt/mclatte_hp.csv.
mclatte_config = dict(
    encoder_class="lstm",
    decoder_class="lstm",
    hidden_dim=16,
    batch_size=64,
    epochs=100,
    lr=0.024468,
    gamma=0.740409,
    lambda_r=0.040299,
    lambda_d=0.034368,
    lambda_p=0.021351,
)

### Baseline RNN

In [13]:
# Hyperparameters taken from the row with the lowest `valid_loss` in
# results/baseline_rnn_hp.csv.
# NOTE(review): this path is results/ rather than results_idt/ like the McLatte
# configs; confirm these values were tuned on the IDT data.
rnn_config = dict(
    rnn_class="gru",
    hidden_dim=4,
    seq_len=4,
    batch_size=64,
    epochs=100,
    lr=0.048177,
    gamma=0.795612,
)

### SyncTwin

In [14]:
# Hyperparameters taken from the row with the lowest `valid_loss` in
# results/synctwin_hp.csv.
# NOTE(review): this path is results/ rather than results_idt/ like the McLatte
# configs; confirm these values were tuned on the IDT data.
synctwin_config = dict(
    hidden_dim=32,
    reg_B=0.909119,
    lam_express=0.106598,
    lam_recon=0.441844,
    lam_prognostic=0.207286,
    tau=0.311216,
    batch_size=32,
    epochs=100,
    lr=0.000196,
    gamma=0.888244,
)

## Test Models

In [15]:
# Number of independently generated datasets (and model fits) per TEST_CONFIGS entry.
N_TEST: int = 5

In [16]:
# Each entry is [n, p_0, mode] for SimDataGenConfig. An entry's list index becomes the
# `config_idx` used in the results file names, so the order must stay stable.
# A larger sweep used n=1000 over the same mode / p_0 grid; adding 1000 to the `n`
# tuple reproduces it, with its configs placed after the n=200 ones.
TEST_CONFIGS = [
    [n, p_0, mode]
    for n in (200,)
    for mode in (
        TreatmentRepr.BINARY,
        TreatmentRepr.BOUNDED,
        TreatmentRepr.REAL_VALUED,
    )
    for p_0 in (0.1, 0.5)
]

In [17]:
def run_tests():
    """Train and test every model on ``N_TEST`` generated datasets per test config.

    For each ``config = [n, p_0, mode]`` in ``TEST_CONFIGS``, a ``SimDataGenConfig`` is
    built. Runs ``i = N_TEST * config_idx + 1 .. N_TEST * (config_idx + 1)`` each
    generate a fresh train/test split. ``i`` is passed to ``generate_data`` and used as
    ``run_idx``, so run indices never repeat across configs.

    Each model's test loss is appended to its own list. After every run, the tuple
    ``(config, mclatte, semi_skimmed_mclatte, skimmed_mclatte, rnn, synctwin)`` is
    written to ``results/test/config_{config_idx}_idt.joblib``.
    """
    import os

    # joblib.dump does not create missing directories. Create the output directory up
    # front so the first checkpoint cannot fail after a full round of model training.
    os.makedirs("results/test", exist_ok=True)

    for config_idx, config in enumerate(TEST_CONFIGS):
        data_gen_config = SimDataGenConfig(
            n=config[0],
            p_0=config[1],
            mode=config[2],
            **constants
        )
        mclatte_losses = []
        semi_skimmed_mclatte_losses = []
        skimmed_mclatte_losses = []
        rnn_losses = []
        synctwin_losses = []
        first_run_idx = N_TEST * config_idx + 1
        for i in range(first_run_idx, first_run_idx + N_TEST):
            _, train_data, test_data = generate_data(
                data_gen_config, i, return_raw=False
            )

            skimmed_mclatte_losses.append(
                test_skimmed_mclatte(
                    skimmed_mclatte_config,
                    constants,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )
            semi_skimmed_mclatte_losses.append(
                test_semi_skimmed_mclatte(
                    semi_skimmed_mclatte_config,
                    constants,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )
            mclatte_losses.append(
                test_mclatte(
                    mclatte_config,
                    constants,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )

            rnn_losses.append(
                test_rnn(
                    rnn_config,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )

            synctwin_losses.append(
                test_synctwin(
                    synctwin_config,
                    constants,
                    train_data,
                    test_data,
                    run_idx=i,
                )
            )
            # Checkpoint after every run (overwriting the file) so the results gathered
            # so far survive an interrupted sweep.
            joblib.dump(
                (
                    config,
                    mclatte_losses,
                    semi_skimmed_mclatte_losses,
                    skimmed_mclatte_losses,
                    rnn_losses,
                    synctwin_losses,
                ),
                f"results/test/config_{config_idx}_idt.joblib",
            )

In [18]:
# Expensive: trains all five models on N_TEST datasets for every entry of TEST_CONFIGS
# and writes the losses to results/test/config_{config_idx}_idt.joblib.
run_tests()

### Statistical Testing

In [19]:
# Display names, in the same order run_tests saves the per-model losses after `config`:
# mclatte, semi-skimmed, skimmed, rnn, synctwin.
LOSS_NAMES = [
    "McLatte",
    "Semi-Skimmed McLatte",
    "Skimmed McLatte",
    "RNN",
    "SyncTwin",
]

In [20]:
# Pool every run's test losses per model across all configs, then compare the models.
# Each results file holds (config, *per-model loss lists) in LOSS_NAMES order.
all_losses = [[] for _ in LOSS_NAMES]
for config_id, _ in enumerate(TEST_CONFIGS):
    _, *losses = joblib.load(f"results/test/config_{config_id}_idt.joblib")
    for model_idx, model_losses in enumerate(all_losses):
        model_losses.extend(losses[model_idx])
test_losses(all_losses, LOSS_NAMES)

### Plot with trained models

In [23]:
# Rebuild each test-time data configuration and plot the trained models' results for it.
for config_idx, config in enumerate(TEST_CONFIGS):
    data_gen_config = SimDataGenConfig(
        n=config[0], p_0=config[1], mode=config[2], **constants
    )
    plot_config_results("idt", generate_data, config_idx, data_gen_config)