In [1]:
import functools
import os

import joblib
import numpy as np
import pandas as pd
import torch

from mclatte.model import train_mclatte, train_semi_skimmed_mclatte, train_skimmed_mclatte
from rnn.model import train_baseline_rnn
from synctwin import io_utils
from synctwin.model import train_synctwin

## Data Generation

Constants used for generation

In [2]:
# sim_id only appears in the commented-out joblib.dump paths of this notebook;
# model_id is not referenced by any visible cell.
sim_id, model_id = '0.25_200', ""

# Part of the data directory name read by generate_data (".../{p_0}_{N}-seed-{seed}").
seed = 509

# M, H and R are forwarded positionally to the McLatte training functions;
# R * M is also used as SyncTwin's `pre_trt_x_len`.
M = H = R = 5

# D -> `input_dim`, K -> `treatment_dim` of the models; C is only carried along
# in the tuples returned by generate_data.
D, K, C = 3, 1, 3

In [3]:
def generate_data(p_0, N, rng=None):
    """
    Load one simulated PKPD data set from disk and split it 80/20 into train/test.

    Files are read from ``data/pkpd/{p_0}_{N}-seed-{seed}`` (``seed`` is the
    notebook-level constant) through ``synctwin.io_utils``. The "train" and
    "val" tensors are concatenated along axis 1, that axis is moved to the
    front, all arrays are shuffled with one shared permutation, and the first
    80% of the rows become the training split.

    Args:
        p_0: first component of the data directory name (e.g. ``'0.25'``).
        N: second component of the data directory name. It is only used to
            build the path; the number of samples is read from the data.
        rng: optional ``numpy.random.Generator`` / ``RandomState`` used for the
            shuffle. ``None`` (default) draws from NumPy's global random state,
            exactly as before.

    Returns:
        tuple: ``(train, test)`` where each element is
        ``(n, M, H, R, D, K, C, X, M_, Y_pre, Y_post, A, T)`` and ``n`` is the
        number of rows in that split.

    Raises:
        ValueError: if the number of loaded samples is not a multiple of 4,
            because the treatment indicator is laid out in four equal blocks.
    """
    base_path_data = f"data/pkpd/{p_0}_{N}-seed-{seed}"
    data_path = base_path_data + "/{}-{}.{}"

    # loading config and data
    # The return value of load_config is not used here; the call is kept in case
    # it has side effects inside io_utils — TODO confirm.
    io_utils.load_config(
        data_path, "train"
    )
    x_full, t_full, mask_full, _, y_full, _, _, _, _, _ = io_utils.load_tensor(data_path, "train", device='cpu')
    x_full_val, t_full_val, mask_full_val, _, y_full_val, _, _, _, _, _ = io_utils.load_tensor(data_path, "val", device='cpu')

    def _train_plus_val(train_tensor, val_tensor):
        # Join the train and validation parts along axis 1 (the sample axis here).
        return np.concatenate((train_tensor.cpu().numpy(), val_tensor.cpu().numpy()), axis=1)

    x = _train_plus_val(x_full, x_full_val)
    t = _train_plus_val(t_full, t_full_val)
    mask = _train_plus_val(mask_full, mask_full_val)
    y = _train_plus_val(y_full, y_full_val).squeeze()

    # Samples first from here on.
    X = x.transpose((1, 0, 2))
    n_samples = X.shape[0]
    print(n_samples)

    quarter, remainder = divmod(n_samples, 4)
    if remainder:
        # Previously this surfaced as an IndexError while indexing A below.
        raise ValueError(
            f"Expected a sample count divisible by 4 to build the treatment indicator, got {n_samples}"
        )

    rand_index = np.random.permutation(n_samples) if rng is None else rng.permutation(n_samples)
    X = X[rand_index]
    M_ = mask.transpose((1, 0, 2))[rand_index]
    # NOTE(review): Y_pre and Y_post are built from the same array, so they are
    # identical. Kept as in the original notebook — confirm this is intended.
    Y_pre = y.T[rand_index]
    Y_post = y.T[rand_index]
    # Treatment indicator in the pre-shuffle order: four equal blocks 0 / 1 / 0 / 1.
    # Presumably this mirrors how the simulator orders control and treated units
    # within the train and val files — verify against the data generator.
    A = np.concatenate(
        (
            np.zeros((quarter, 1)),
            np.ones((quarter, 1)),
            np.zeros((quarter, 1)),
            np.ones((quarter, 1)),
        ),
        axis=0,
    )[rand_index]
    T = t.transpose((1, 0, 2))[rand_index]

    N_train = round(n_samples * 0.8)
    # Derive the test size from the actual split so it always matches len(X_test).
    N_test = n_samples - N_train
    X_train, X_test = X[:N_train], X[N_train:]
    M_train, M_test = M_[:N_train], M_[N_train:]
    Y_pre_train, Y_pre_test = Y_pre[:N_train], Y_pre[N_train:]
    Y_post_train, Y_post_test = Y_post[:N_train], Y_post[N_train:]
    A_train, A_test = A[:N_train], A[N_train:]
    T_train, T_test = T[:N_train], T[N_train:]

    return (
        (N_train, M, H, R, D, K, C, X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train),
        (N_test, M, H, R, D, K, C, X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test),
    )

In [4]:
# joblib.dump((N_train, M, H, R, D, K, C, X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train), f'data/pkpd/data_{sim_id}.joblib')
# joblib.dump((N_test, M, H, R, D, K, C, X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test), f'data/pkpd/test_data_{sim_id}.joblib')

In [5]:
def na_catcher(func):
    """
    Decorator that turns any exception raised by ``func`` into a ``nan`` result.

    The exception is printed and ``np.nan`` is returned instead of propagating
    the error, so a failing run does not abort a whole batch of experiments.
    ``KeyboardInterrupt`` and ``SystemExit`` are not caught because they do not
    derive from ``Exception``.

    Args:
        func: the callable to protect.

    Returns:
        A wrapper with the same signature, name and docstring as ``func``.
    """
    @functools.wraps(func)  # keep __name__/__doc__ of the wrapped function
    def wrapper_na_catcher(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Deliberate best effort: report the problem and mark the result as missing.
            print(e)
            return np.nan
    return wrapper_na_catcher

## Modelling

### McLatte

#### Vanilla

In [6]:
# Hyper-parameters for the vanilla McLatte model. The commented-out line prints
# the row with the smallest `valid_loss` from the search results; the values
# below presumably come from that row — TODO confirm.
# print(pd.read_csv(os.path.join(os.getcwd(), 'results_pkpd/mclatte_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0])
mclatte_config = dict(
    # architecture
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=64,
    # optimisation
    batch_size=64,
    epochs=100,
    lr=0.001944,
    gamma=0.957115,
    # loss weights
    lambda_r=0.311437,
    lambda_d=0.118073,
    lambda_p=0.49999,
)

In [7]:
def test_mclatte(
    X_train,
    X_test,
    M_train,
    M_test,
    Y_pre_train,
    Y_post_train,
    Y_post_test,
    A_train,
    A_test,
    T_train,
    T_test,
    run_idx=0,
):
    """
    Train the vanilla McLatte model and return its L1 error on the test split.

    The model is fitted by ``train_mclatte`` with ``mclatte_config`` and the
    notebook-level constants ``R``, ``M``, ``H``, ``D`` (``input_dim``) and
    ``K`` (``treatment_dim``).

    Args:
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train:
            training arrays, forwarded unchanged to ``train_mclatte``.
        X_test, A_test, T_test, M_test: NumPy arrays fed to the trained model
            (in this order) after conversion to float tensors.
        Y_post_test: NumPy array of targets the prediction is compared with.
        run_idx: forwarded to ``train_mclatte`` as ``test_run``.

    Returns:
        float: ``torch.nn.functional.l1_loss`` (default reduction) between the
        third output of the model and ``Y_post_test``.
    """
    trained_mclatte = train_mclatte(
        mclatte_config,
        X_train,
        M_train,
        Y_pre_train,
        Y_post_train,
        A_train,
        T_train,
        R,
        M,
        H,
        input_dim=D,
        treatment_dim=K,
        test_run=run_idx,
    )

    trained_mclatte.eval()
    # Pure inference: no autograd graph is needed to compute the test loss.
    with torch.no_grad():
        _, _, y_tilde = trained_mclatte(
            torch.from_numpy(X_test).float(),
            torch.from_numpy(A_test).float(),
            torch.from_numpy(T_test).float(),
            torch.from_numpy(M_test).float(),
        )

        return torch.nn.functional.l1_loss(
            y_tilde,
            torch.from_numpy(Y_post_test).float()
        ).item()

#### Semi-Skimmed

In [8]:
# Hyper-parameters for the semi-skimmed McLatte model. The commented-out line
# prints the row with the smallest `valid_loss` from the search results; the
# values below presumably come from that row — TODO confirm.
# print(pd.read_csv(os.path.join(os.getcwd(), 'results_pkpd/semi_skimmed_mclatte_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0])
semi_skimmed_mclatte_config = dict(
    # architecture
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=64,
    # optimisation
    batch_size=64,
    epochs=100,
    lr=0.001944,
    gamma=0.957115,
    # loss weights
    lambda_r=0.311437,
    lambda_d=0.118073,
    lambda_p=0.49999,
)

In [9]:
def test_semi_skimmed_mclatte(
    X_train,
    X_test,
    M_train,
    M_test,
    Y_pre_train,
    Y_post_train,
    Y_post_test,
    A_train,
    A_test,
    T_train,
    T_test,
    run_idx=0,
):
    """
    Train the semi-skimmed McLatte model and return its L1 error on the test split.

    The model is fitted by ``train_semi_skimmed_mclatte`` with
    ``semi_skimmed_mclatte_config`` and the notebook-level constants ``R``,
    ``M``, ``H``, ``D`` (``input_dim``) and ``K`` (``treatment_dim``).

    Args:
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train:
            training arrays, forwarded unchanged to the training function.
        X_test, A_test, T_test, M_test: NumPy arrays fed to the trained model
            (in this order) after conversion to float tensors.
        Y_post_test: NumPy array of targets the prediction is compared with.
        run_idx: forwarded to the training function as ``test_run``.

    Returns:
        float: ``torch.nn.functional.l1_loss`` (default reduction) between the
        third output of the model and ``Y_post_test``.
    """
    trained_semi_skimmed_mclatte = train_semi_skimmed_mclatte(
        semi_skimmed_mclatte_config,
        X_train,
        M_train,
        Y_pre_train,
        Y_post_train,
        A_train,
        T_train,
        R,
        M,
        H,
        input_dim=D,
        treatment_dim=K,
        test_run=run_idx,
    )

    trained_semi_skimmed_mclatte.eval()
    # Pure inference: no autograd graph is needed to compute the test loss.
    with torch.no_grad():
        _, _, y_tilde = trained_semi_skimmed_mclatte(
            torch.from_numpy(X_test).float(),
            torch.from_numpy(A_test).float(),
            torch.from_numpy(T_test).float(),
            torch.from_numpy(M_test).float(),
        )

        return torch.nn.functional.l1_loss(
            y_tilde,
            torch.from_numpy(Y_post_test).float()
        ).item()

#### Skimmed

In [10]:
# Hyper-parameters for the skimmed McLatte model. The commented-out line prints
# the row with the smallest `valid_loss` from the search results; the values
# below presumably come from that row — TODO confirm.
# print(pd.read_csv(os.path.join(os.getcwd(), 'results_pkpd/skimmed_mclatte_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0])
skimmed_mclatte_config = dict(
    # architecture
    encoder_class='lstm',
    decoder_class='lstm',
    hidden_dim=64,
    # optimisation
    batch_size=64,
    epochs=100,
    lr=0.021114,
    gamma=0.980614,
    # loss weights
    lambda_r=0.093878,
    lambda_s=0.485204,
)

In [11]:
def test_skimmed_mclatte(
    X_train,
    X_test,
    M_train,
    M_test,
    Y_pre_train,
    Y_post_train,
    Y_post_test,
    A_train,
    A_test,
    T_train,
    T_test,
    run_idx=0,
):
    """
    Train the skimmed McLatte model and return its L1 error on the test split.

    The model is fitted by ``train_skimmed_mclatte`` with
    ``skimmed_mclatte_config`` and the notebook-level constants ``R``, ``M``,
    ``H``, ``D`` (``input_dim``) and ``K`` (``treatment_dim``).

    Args:
        X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train:
            training arrays, forwarded unchanged to the training function.
        X_test, A_test, T_test, M_test: NumPy arrays fed to the trained model
            (in this order) after conversion to float tensors.
        Y_post_test: NumPy array of targets the prediction is compared with.
        run_idx: forwarded to the training function as ``test_run``.

    Returns:
        float: ``torch.nn.functional.l1_loss`` (default reduction) between the
        second output of the model and ``Y_post_test``.
    """
    trained_skimmed_mclatte = train_skimmed_mclatte(
        skimmed_mclatte_config,
        X_train,
        M_train,
        Y_pre_train,
        Y_post_train,
        A_train,
        T_train,
        R,
        M,
        H,
        input_dim=D,
        treatment_dim=K,
        test_run=run_idx,
    )

    trained_skimmed_mclatte.eval()
    # Pure inference: no autograd graph is needed to compute the test loss.
    # Unlike the other two McLatte variants, this model returns two values.
    with torch.no_grad():
        _, y_tilde = trained_skimmed_mclatte(
            torch.from_numpy(X_test).float(),
            torch.from_numpy(A_test).float(),
            torch.from_numpy(T_test).float(),
            torch.from_numpy(M_test).float(),
        )

        return torch.nn.functional.l1_loss(
            y_tilde,
            torch.from_numpy(Y_post_test).float()
        ).item()

### Baseline RNN

In [12]:
# Hyper-parameters for the baseline RNN. The commented-out line prints the row
# with the smallest `valid_loss` from the search results; the values below
# presumably come from that row — TODO confirm.
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/baseline_rnn_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0])
rnn_config = dict(
    # architecture
    rnn_class='gru',
    hidden_dim=64,
    seq_len=2,
    # optimisation
    batch_size=64,
    epochs=100,
    lr=0.025182,
    gamma=0.543008,
)

In [13]:
def rnn_predict(trained_rnn, Y_pre, Y_post):
    """
    Make predictions using results from previous time steps.

    Starting from ``Y_pre``, the model is applied once per column of ``Y_post``.
    After each step the oldest column of the input window is dropped and the
    model's last prediction is appended, so later steps are fed the model's own
    outputs. The L1 loss of each step's last prediction against the matching
    column of ``Y_post`` is accumulated.

    Args:
        trained_rnn: callable taking a float tensor of shape
            ``(rows, window, 1)`` and returning one prediction series per row.
        Y_pre: 2-D NumPy array ``(rows, window)`` used as the initial input.
        Y_post: 2-D NumPy array ``(rows, steps)`` of targets.

    Returns:
        float: L1 loss averaged over the ``steps`` columns of ``Y_post``.
    """
    n_steps = Y_post.shape[1]
    # Loop-invariant: convert the targets once instead of on every step.
    targets = torch.from_numpy(Y_post).float()

    Y = Y_pre
    losses = 0.0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for i in range(n_steps):
            # Flatten to (rows, -1) rather than calling a bare squeeze(): squeeze()
            # would also drop the row axis for a single-row batch (or the time
            # axis for a length-1 window) and break the 2-D indexing below.
            Y_tilde = trained_rnn(
                torch.from_numpy(Y).float().unsqueeze(2)
            ).reshape(Y.shape[0], -1)

            # Slide the window: drop the oldest step, append the newest prediction.
            Y = np.concatenate((
                Y[:, 1:],
                Y_tilde.cpu().detach().numpy()[:, [-1]]
            ), axis=1)

            losses += torch.nn.functional.l1_loss(
                Y_tilde[:, -1],
                targets[:, i]
            ).item()
    return losses / n_steps

In [14]:
def test_rnn(
    Y_pre_train,
    Y_pre_test,
    Y_post_train,
    Y_post_test,
    run_idx=0,
):
    """
    Fit the baseline RNN and score it on the test split.

    The training series is ``Y_pre_train`` followed by ``Y_post_train`` along
    axis 1. The fitted model is switched to eval mode and evaluated with
    ``rnn_predict`` on ``Y_pre_test`` / ``Y_post_test``.

    Args:
        Y_pre_train, Y_post_train: arrays joined to form the training series.
        Y_pre_test, Y_post_test: arrays passed to ``rnn_predict``.
        run_idx: forwarded to ``train_baseline_rnn`` as ``test_run``.

    Returns:
        The value returned by ``rnn_predict``.
    """
    full_train_series = np.concatenate((Y_pre_train, Y_post_train), axis=1)

    baseline = train_baseline_rnn(
        rnn_config,
        Y=full_train_series,
        input_dim=1,
        test_run=run_idx,
    )
    baseline.eval()

    test_loss = rnn_predict(baseline, Y_pre_test, Y_post_test)
    return test_loss

### SyncTwin

In [15]:
# Hyper-parameters for SyncTwin. The commented-out line prints the row with the
# smallest `valid_loss` from the search results; the values below presumably
# come from that row — TODO confirm.
# print(pd.read_csv(os.path.join(os.getcwd(), 'results/synctwin_hp_pkpd.csv')).sort_values(by='valid_loss').iloc[0])
synctwin_config = dict(
    # model
    hidden_dim=128,
    reg_B=0.778155,
    lam_express=0.658256,
    lam_recon=0.086627,
    lam_prognostic=0.631468,
    tau=0.911613,
    # optimisation
    batch_size=32,
    epochs=100,
    lr=0.003222,
    gamma=0.572529,
)

In [16]:
def test_synctwin(
    N_train,
    N_test,
    X_train,
    X_test,
    M_train,
    M_test,
    Y_post_train,
    Y_post_test,
    A_train,
    A_test,
    T_train,
    T_test,
    run_idx=0,
):
    """
    Train SyncTwin and return the L1 loss it reports on the test split.

    Rows whose treatment indicator ``A`` is all zeros are treated as controls:
    they define ``Y_mask`` and ``Y_control`` for training, and the number of
    treated units is ``N_train`` minus the number of control rows.

    The model and the test tensors are placed on the GPU when CUDA is
    available and on the CPU otherwise, so the cell also runs on machines
    without a GPU.

    Args:
        N_train, N_test: number of rows in the training / test split.
        X_train, M_train, T_train, Y_post_train, A_train: training arrays.
        X_test, M_test, T_test, Y_post_test, A_test: NumPy test arrays.
        run_idx: forwarded to ``train_synctwin`` as ``test_run``.

    Returns:
        float: the second value returned by the trained model on the test data.
    """
    # Previously hard-coded to CUDA, which crashes on CPU-only machines.
    # NOTE(review): this assumes the SyncTwin model itself has no CUDA-only
    # code paths — confirm against synctwin.model.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    Y_mask_train = np.all(A_train == 0, axis=1)
    Y_mask_test = np.all(A_test == 0, axis=1)
    Y_control_train = Y_post_train[Y_mask_train]

    trained_synctwin = train_synctwin(
        synctwin_config,
        X=X_train,
        M_=M_train,
        T=T_train,
        Y_batch=Y_post_train,
        Y_control=Y_control_train,
        Y_mask=Y_mask_train,
        N=N_train,
        D=D,
        n_treated=N_train - Y_control_train.shape[0],
        pre_trt_x_len=R * M,
        test_run=run_idx,
    ).to(device)

    trained_synctwin.eval()
    _, l1_loss = trained_synctwin(
        torch.from_numpy(X_test).float().to(device),
        torch.from_numpy(T_test).float().to(device),
        torch.from_numpy(M_test).float().to(device),
        torch.arange(0, N_test, device=device),
        torch.from_numpy(Y_post_test).float().to(device),
        torch.from_numpy(Y_mask_test).float().to(device),
    )
    return l1_loss.item()

## Test Models

In [17]:
# Number of repeated runs performed for each entry of TEST_CONFIGS.
N_TEST = 5

In [18]:
# One [N, p_0] pair per experiment, passed to generate_data(p_0, N) to pick the
# data directory: every N combined with every p_0, N varying slowest.
TEST_CONFIGS = [
    [n_units, p_0]
    for n_units in (200, 1000)
    for p_0 in ('0.1', '0.25', '0.5')
]

In [20]:
# Run every model N_TEST times on each configuration and store the test losses.
#
# Run indices are unique across configurations: configuration `config_idx` uses
# N_TEST * config_idx + 1 ... N_TEST * (config_idx + 1). The results file of a
# configuration is rewritten after every run, so an interrupted experiment
# keeps the losses collected so far.

# Create the output directory up front so a missing folder cannot make
# joblib.dump fail after a full (expensive) training run.
RESULTS_DIR = 'results/test'
os.makedirs(RESULTS_DIR, exist_ok=True)

for config_idx, config in enumerate(TEST_CONFIGS):
    mclatte_losses = []
    semi_skimmed_mclatte_losses = []
    skimmed_mclatte_losses = []
    rnn_losses = []
    synctwin_losses = []
    for i in range(N_TEST * config_idx + 1, N_TEST * (1 + config_idx) + 1):
        # config is [N, p_0]; generate_data expects (p_0, N).
        (
            (N_train, M, H, R, D, K, C, X_train, M_train, Y_pre_train, Y_post_train, A_train, T_train),
            (N_test, M, H, R, D, K, C, X_test, M_test, Y_pre_test, Y_post_test, A_test, T_test),
        ) = generate_data(config[1], config[0])

        # The three McLatte variants take exactly the same positional inputs.
        mclatte_inputs = (
            X_train,
            X_test,
            M_train,
            M_test,
            Y_pre_train,
            Y_post_train,
            Y_post_test,
            A_train,
            A_test,
            T_train,
            T_test,
        )
        mclatte_losses.append(test_mclatte(*mclatte_inputs, run_idx=i))
        semi_skimmed_mclatte_losses.append(test_semi_skimmed_mclatte(*mclatte_inputs, run_idx=i))
        skimmed_mclatte_losses.append(test_skimmed_mclatte(*mclatte_inputs, run_idx=i))

        rnn_losses.append(test_rnn(
            Y_pre_train,
            Y_pre_test,
            Y_post_train,
            Y_post_test,
            run_idx=i,
        ))

        synctwin_losses.append(test_synctwin(
            N_train,
            N_test,
            X_train,
            X_test,
            M_train,
            M_test,
            Y_post_train,
            Y_post_test,
            A_train,
            A_test,
            T_train,
            T_test,
            run_idx=i,
        ))

        # Checkpoint the losses gathered so far for this configuration.
        joblib.dump((
            config,
            mclatte_losses,
            semi_skimmed_mclatte_losses,
            skimmed_mclatte_losses,
            rnn_losses,
            synctwin_losses
        ), f'{RESULTS_DIR}/config_{config_idx}_pkpd.joblib')

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_1_pkpd.ckpt'


Session not detected. You should not be calling `report` outside `tune.run` or while using the class API. 
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/jasonyz/Documents/McLatte/.venv.mclatte/lib/python3.9/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/jasonyz/Documents/McLatte/.venv.mclatte/lib/python3.9/site-packages/traitlets/config/application.py", line 846, in launch_instance
    app.start()
  File "/Users/jasonyz/Documents/McLatte/.venv.mclatte/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 677, in start
    self.io_loop.start()
  File "/Users/jasonyz/Documents/McLatte/.venv.mclatte/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 199, in

[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/semi_skimmed_mclatte_1_pkpd.ckpt'


Session not detected. You should not be calling `report` outside `tune.run` or while using the class API. 
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/jasonyz/Documents/McLatte/.venv.mclatte/lib/python3.9/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/jasonyz/Documents/McLatte/.venv.mclatte/lib/python3.9/site-packages/traitlets/config/application.py", line 846, in launch_instance
    app.start()
  File "/Users/jasonyz/Documents/McLatte/.venv.mclatte/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 677, in start
    self.io_loop.start()
  File "/Users/jasonyz/Documents/McLatte/.venv.mclatte/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 199, in

800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_2_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/semi_skimmed_mclatte_2_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_3_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/semi_skimmed_mclatte_3_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_4_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/semi_skimmed_mclatte_4_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_5_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/semi_skimmed_mclatte_5_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_6_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/semi_skimmed_mclatte_6_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_7_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/semi_skimmed_mclatte_7_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_8_pkpd.ckpt'


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/semi_skimmed_mclatte_8_pkpd.ckpt'


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name     | Type        | Params
-----------------------------------------
0 | _encoder | LstmEncoder | 22.0 K
1 | _decoder | LstmDecoder | 33.5 K
-----------------------------------------
56.1 K    Trainable params
0         Non-trainable params
56.1 K    Total params
0.224     Total estimated model params size (MB)


800
[Errno 2] No such file or directory: '/Users/jasonyz/Documents/McLatte/results/mclatte_9_pkpd.ckpt'
