In [None]:
import sys
IN_COLAB: bool = 'google.colab' in sys.modules
COLAB_PREFIX = 'comp9417_proj'
MODEL_OUT_DIR: str = 'models/out/nri'

if not IN_COLAB:
    %cd ..
    
    from models.nri import NRI, NRITrainingParams
    from models.train import CheckpointParameters, train, resume
    from experiments.helper import get_dataset

    print('Notebook running in local env')
else:
    !git clone https://github.com/D31T4/COMP9417-project.git
    !mv COMP9417-project {COLAB_PREFIX}

    MODEL_OUT_DIR = f'{COLAB_PREFIX}/{MODEL_OUT_DIR}'
    !mkdir {MODEL_OUT_DIR}
    print('Notebook running in colab')

    from comp9417_proj.data.preprocess import preprocess, DefaultInputDir, DefaultOutputDir
    from comp9417_proj.models.nri import NRI, NRITrainingParams
    from comp9417_proj.models.train import CheckpointParameters, train, resume
    from comp9417_proj.experiments.helper import get_dataset

    preprocess(f'{COLAB_PREFIX}/data/{DefaultInputDir}', f'{COLAB_PREFIX}/data/{DefaultOutputDir}')


In [2]:
import torch

# Fetch the three dataset splits together with the adjacency matrix and the
# edge prior that the model and the training call consume.
(train_set, val_set, test_set, adj_mat, edge_prior) = get_dataset(IN_COLAB)

# Checkpoint configuration targeting MODEL_OUT_DIR.
# NOTE(review): the meaning of the second argument (1) and of the 10 passed to
# NRITrainingParams is not visible here — confirm against their definitions.
checkpt = CheckpointParameters(MODEL_OUT_DIR, 1)
train_params = NRITrainingParams(10)

model = NRI(
    state_dim=6,
    prior_steps=50,
    hid_dim=64,
    adj_mat=adj_mat,
    do_prob=0.5,
)

# NOTE(review): `optimizer` and `lr_scheduler` are built here but are not passed
# to the train(...) call in the following cell — confirm whether train() is
# expected to use them or builds its own.
optimizer = torch.optim.Adam(
    list(model.parameters()),
    lr=5e-3,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=200,
    gamma=0.5,
)

In [3]:
# Train the NRI model for two epochs on the (train, val, test) splits, using the
# checkpoint and training parameter objects configured in the previous cell.
train(
    model,
    datasets=(train_set, val_set, test_set),
    edge_prior=edge_prior,
    train_params=train_params,
    checkpoint_params=checkpt,
    n_epoch=2,
)

[train] Epoch 0; NLL: 1.24E+02; KL: 1.90E+00: 100%|██████████| 538/538 [36:55<00:00,  4.12s/it]


[train] Epoch: 0000
          NLL: 3334.0653662274
           KL: 2.1264766348
          MSE: 0.0011340358
      Elapsed: 2215.5761s


[train] Epoch 0; NLL: 1.47E+03; KL: 1.24E+00: 100%|██████████| 176/176 [04:54<00:00,  1.68s/it]


[valid] Epoch: 0000
          NLL: 1631.0477735346
           KL: 1.8722662872
          MSE: 0.0005547781
      Elapsed: 2510.4371s


[train] Epoch 1; NLL: 9.22E+01; KL: 1.24E+00: 100%|██████████| 538/538 [36:17<00:00,  4.05s/it]


[train] Epoch: 0001
          NLL: 117.4287809918
           KL: 1.4961004423
          MSE: 0.0000399418
      Elapsed: 2177.2075s


[train] Epoch 1; NLL: 1.53E+03; KL: 1.12E+00: 100%|██████████| 176/176 [05:09<00:00,  1.76s/it]


[valid] Epoch: 0001
          NLL: 1324.8493204984
           KL: 1.0237590108
          MSE: 0.0004506290
      Elapsed: 2486.6350s
