
# Neural ODE Training

This notebook trains latent Neural ODE (NODE) models on sine-wave data: first with the `miw` objective, then with the `piw` objective (the PIWAE section below).

In [None]:
# Reload imported modules automatically before each cell executes, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys, os
sys.path.append(os.path.abspath('../src'))
sys.path.append(os.path.abspath('..'))

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR

# Use the first CUDA device when one is available; otherwise fall back to the CPU
# so that `.to(device)` calls do not fail on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm

# Custom lib
from model import LatentNeuralODEBuilder
from utils import gpu, asnp
from train import TrainingLoop
from train_piw import PIWTrainingLoop
from visualize_sine import plot_training_sine

## Data Generation: Sine Family

This data provides a "library" of possible sine waves to train the latent NODE model with. The waves are not generated in this notebook; a previously saved generator object is loaded from disk.

In [None]:
# Load the saved object stored under the 'generator' key of the data file.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only point this at a trusted file.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which may refuse
# to unpickle a custom generator object — confirm against the installed version.
data_path = "../data/sine_data_2021-04-09 00:13:41.249505"
generator = torch.load(data_path)['generator']

In [None]:
# Fetch the (time, data) pair for each split from the loaded generator.
train_time, train_data = generator.get_train_set()
val_time, val_data = generator.get_val_set()

# Keep the first axis, flatten the rest, and append a trailing axis of size 1.
train_data, val_data = (
    split.reshape(len(split), -1, 1) for split in (train_data, val_data)
)

# Pass every array through the project's `gpu` helper (train first, then validation).
train_data_tt, train_time_tt = gpu(train_data), gpu(train_time)
val_data_tt, val_time_tt = gpu(val_data), gpu(val_time)

In [None]:
class SineSet(Dataset):
    """Map-style dataset pairing each stored sample with one shared time object.

    `data` is indexed per sample; `time` is stored once and returned unchanged
    alongside every sample.
    """

    def __init__(self, data, time):
        self.data, self.time = data, time

    def __len__(self):
        """Number of samples, i.e. ``len(data)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(data[idx], time)`` for the requested sample."""
        sample = self.data[idx]
        return sample, self.time
    
# Wrap the training and validation tensors in datasets.
train_dataset = SineSet(train_data_tt, train_time_tt)
val_dataset = SineSet(val_data_tt, val_time_tt)

# Training batches are shuffled mini-batches of 256 samples.
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
# The validation loader must iterate the validation split (it previously wrapped
# train_dataset); the whole split is served as a single batch.
val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))

# Training loop

In [None]:
# Architecture hyper-parameters forwarded to LatentNeuralODEBuilder.
model_args = dict(
    obs_dim=1,
    rec_latent_dim=8,
    node_latent_dim=4,
    # rec_* settings
    rec_gru_unit=100,
    rec_node_hidden=100,
    rec_node_layer=2,
    rec_node_act='Tanh',
    # latent_node_* settings
    latent_node_hidden=100,
    latent_node_layer=2,
    latent_node_act='Tanh',
    # dec_* settings
    dec_type='NN',
    dec_hidden=100,
    dec_layer=1,
    dec_act='ReLU',
)

# Objective identifier handed to build_latent_node.
elbo_type = 'miw'

# Build the latent NODE for the selected objective, then move it to the target device.
base_model = LatentNeuralODEBuilder(**model_args)
model = base_model.build_latent_node(elbo_type)
model = model.to(device)

In [None]:
# Create the training driver from the model, both data loaders and the sine plotting
# function (presumably invoked during training — confirm in train.py).
main = TrainingLoop(model, train_loader, val_loader, plot_training_sine)

In [None]:
# Adamax over every model parameter; the scheduler multiplies the learning rate
# by 0.995 each time scheduler.step() is called.
lr = 1e-3

parameters = model.parameters()
optimizer = optim.Adamax(params=parameters, lr=lr)
scheduler = ExponentialLR(optimizer, gamma=0.995)

In [None]:
# Training hyper-parameters passed to main.train.
train_args = dict(
    max_epoch=500,
    l_std=0.1,
    clip_norm=5,
    model_atol=1e-4,
    model_rtol=1e-3,
    plt_args=dict(n_plot=2),
    M=3,
    K=4,
    # beta=0.5,  # currently disabled
    method='dopri5',
)

In [None]:
# Run training with the optimizer, hyper-parameters and LR scheduler defined above.
# plt_traj / plt_loss presumably switch on trajectory and loss plots — confirm in TrainingLoop.train.
main.train(optimizer, train_args, scheduler, plt_traj=True, plt_loss=True)

In [None]:
# Make sure the output directory exists so a finished training run is not lost
# to a missing-directory error at save time.
os.makedirs('../models', exist_ok=True)

# Persist the trained weights, optimizer state and the configuration used to produce them.
# NOTE(review): 'train_obj' pickles the whole training-loop object; the file name embeds
# str(datetime.now()), which contains spaces and colons.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'data_path': data_path,
    'model_args': model_args,
    'train_args': train_args,
    'train_obj': main,
}, '../models/sine_iwlatode_{}'.format(datetime.now()))

---
## Train PIWAE

In [None]:
# Architecture hyper-parameters forwarded to LatentNeuralODEBuilder for the PIW model.
model_args = dict(
    obs_dim=1,
    rec_latent_dim=8,
    node_latent_dim=4,
    # rec_* settings
    rec_gru_unit=100,
    rec_node_hidden=100,
    rec_node_layer=2,
    rec_node_act='Tanh',
    # latent_node_* settings
    latent_node_hidden=100,
    latent_node_layer=2,
    latent_node_act='Tanh',
    # dec_* settings
    dec_type='NN',
    dec_hidden=100,
    dec_layer=1,
    dec_act='ReLU',
)

# Objective identifier handed to build_latent_node.
elbo_type = 'piw'

# Build the latent NODE for the selected objective, then move it to the target device.
base_model = LatentNeuralODEBuilder(**model_args)
model = base_model.build_latent_node(elbo_type)
model = model.to(device)

In [None]:
# Create the PIW training driver from the model, both data loaders and the sine plotting
# function (presumably invoked during training — confirm in train_piw.py).
main = PIWTrainingLoop(model, train_loader, val_loader, plot_training_sine)

In [None]:
lr = 1e-3

inf_params = (model.enc.parameters())
inf_optim = optim.Adamax(inf_params, lr=lr)
inf_sched = ExponentialLR(inf_optim, 0.99)

node_params = list(model.nodef.parameters())
dec_params = list(model.dec.parameters())
gen_params = node_params + dec_params

gen_optim = optim.Adamax(gen_params, lr=lr)
gen_sched = ExponentialLR(gen_optim, 0.99)

In [None]:
# Training hyper-parameters passed to main.train for the PIW run.
train_args = dict(
    max_epoch=500,
    l_std=0.1,
    clip_norm=5,
    model_atol=1e-4,
    model_rtol=1e-3,
    plt_args=dict(n_plot=2),
    M=3,
    K=2,
    method='dopri5',
)

In [None]:
# Run PIW training with separate optimizers and schedulers for the two parameter groups.
# NOTE(review): the three trailing positional True flags are opaque from here; confirm their
# meaning against PIWTrainingLoop.train and consider passing them by keyword.
main.train(inf_optim, gen_optim, train_args, inf_sched, gen_sched, True, True, True)