In [None]:
import os
import subprocess
import sys

# Colab is detected by the presence of its `google.colab` module.
IN_COLAB: bool = 'google.colab' in sys.modules
# Directory the repository is cloned into on Colab; also used as an import prefix there.
COLAB_PREFIX = 'comp9417_proj'
# Where NRI checkpoints are written (relative to the repo root; prefixed on Colab below).
MODEL_OUT_DIR: str = 'models/out/nri'
# Set to True to move the model and edge prior to the GPU in the next cell.
CUDA: bool = False

if not IN_COLAB:
    # The project packages (`models`, `experiments`) are imported from the repo root,
    # assumed to be one level above this notebook (TODO confirm). Only step up when they
    # are not already visible, so re-running this cell does not keep climbing directories.
    if not (os.path.isdir('models') and os.path.isdir('experiments')):
        os.chdir('..')

    from models.nri import NRI, NRITrainingParams
    from models.train import CheckpointParameters, train, resume
    from experiments.helper import get_dataset

    print('Notebook running in local env')
else:
    # Clone straight into COLAB_PREFIX (no separate `mv`) and skip it if the checkout
    # already exists; check=True surfaces clone failures instead of continuing silently.
    if not os.path.isdir(COLAB_PREFIX):
        subprocess.run(
            ['git', 'clone', 'https://github.com/D31T4/COMP9417-project.git', COLAB_PREFIX],
            check=True,
        )

    MODEL_OUT_DIR = f'{COLAB_PREFIX}/{MODEL_OUT_DIR}'
    # Creates models/out and the NRI output directory; no error if they already exist.
    os.makedirs(MODEL_OUT_DIR, exist_ok=True)
    print('Notebook running in colab')

    # Avoid stacking duplicate sys.path entries when the cell is re-run.
    if COLAB_PREFIX not in sys.path:
        sys.path.insert(0, COLAB_PREFIX)
    from comp9417_proj.data.preprocess import preprocess, DefaultInputDir, DefaultOutputDir
    from comp9417_proj.models.nri import NRI, NRITrainingParams
    from comp9417_proj.models.train import CheckpointParameters, train, resume
    from comp9417_proj.experiments.helper import get_dataset

    preprocess(f'{COLAB_PREFIX}/data/{DefaultInputDir}', f'{COLAB_PREFIX}/data/{DefaultOutputDir}')


In [None]:
import torch

# Data splits plus graph inputs: `adj_mat` goes to the model, `edge_prior` to training.
train_set, val_set, test_set, adj_mat, edge_prior = get_dataset(IN_COLAB)

# The meaning of the second argument (1) is defined by models.train.CheckpointParameters
# (presumably a checkpoint interval — TODO confirm).
checkpt = CheckpointParameters(MODEL_OUT_DIR, 1)

if IN_COLAB:
    from google.colab import files

    def onCheckpoint(prefix: str) -> None:
        """Back up one checkpoint by downloading its four files, all named `{prefix}.<suffix>`."""
        for suffix in ('loss.npy', 'lr.pt', 'optim.pt', 'model.pt'):
            files.download(f'{prefix}.{suffix}')

    checkpt.onCheckpoint = onCheckpoint  # comment this line to disable backup

# NRITrainingParams' positional argument is defined in models.nri (TODO document what 10 means).
train_params = NRITrainingParams(10)

model = NRI(state_dim=6, prior_steps=50, hid_dim=128, adj_mat=adj_mat, do_prob=0.5)

# Move to the GPU *before* building the optimizer, as the torch.optim docs require,
# so the optimizer is constructed over the parameters used during training.
if CUDA:
    model.cuda()
    edge_prior = edge_prior.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
# StepLR: multiply the learning rate by 0.5 every 200 scheduler.step() calls.
# NOTE(review): neither `optimizer` nor `lr_scheduler` is passed to the `train(...)` call
# below — confirm `train` actually uses them, otherwise this configuration has no effect.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

In [None]:
# Train the NRI model for 30 epochs on the three splits; checkpointing is controlled by
# `checkpt` (on Colab it also carries the download-backup hook set in the previous cell).
# NOTE(review): the `optimizer` / `lr_scheduler` built in the previous cell are not passed
# here — confirm `train` picks them up (or builds its own) before trusting the lr schedule.
splits = (train_set, val_set, test_set)
train(model, n_epoch=30, datasets=splits, edge_prior=edge_prior,
      checkpoint_params=checkpt, train_params=train_params)

In [None]:
# Inspect the loss history saved by the `grand_colab` run. A dedicated constant is used
# instead of reassigning MODEL_OUT_DIR, which would otherwise redirect NRI checkpoints
# if the training cells above were re-run after this one.
import numpy as np

GRAND_RUN_DIR = 'models/out/grand_colab'
GRAND_CHECKPOINT = 15

# The .loss.npy file holds a pickled dict-like object (hence allow_pickle=True and
# .item()); only load checkpoint files you produced yourself.
stats = np.load(f'{GRAND_RUN_DIR}/checkpt_{GRAND_CHECKPOINT}.loss.npy', allow_pickle=True).item()
for key in ('valid_nll', 'train_nll'):
    print(f'{key}: {stats[key]}')

In [None]:
# Inspect the loss history saved by the `nri_colab` run. A dedicated constant is used
# instead of reassigning MODEL_OUT_DIR, which would otherwise redirect NRI checkpoints
# if the training cells above were re-run after this one.
import numpy as np

NRI_COLAB_RUN_DIR = 'models/out/nri_colab'
NRI_COLAB_CHECKPOINT = 20

# The .loss.npy file holds a pickled dict-like object (hence allow_pickle=True and
# .item()); only load checkpoint files you produced yourself.
stats = np.load(f'{NRI_COLAB_RUN_DIR}/checkpt_{NRI_COLAB_CHECKPOINT}.loss.npy', allow_pickle=True).item()
for key in ('valid_nll', 'train_nll', 'valid_mse', 'valid_kl', 'train_kl'):
    print(f'{key}: {stats[key]}')