# Compare the performance of continuous-time models

To run this notebook, install `easy_tpp` and `hotpp`, and put [the original ODE-RNN implementation](https://github.com/YuliaRubanova/latent_ode/tree/master) into the `lib` folder. The benchmarks run on a CUDA GPU.

In [1]:
import time
import torch
from easy_tpp.model.torch_model.torch_baselayer import DNN
from easy_tpp.utils import rk4_step_method
from easy_tpp.model.torch_model.torch_nhp import ContTimeLSTMCell as EasyContTimeLSTMCell
from easy_tpp.model.torch_model.torch_ode_tpp import NeuralODE as EasyODEGRUCell
from hotpp.data import PaddedBatch
from hotpp.nn.encoder.rnn.ctlstm import ContTimeLSTM as HotContTimeLSTM
from hotpp.nn.encoder.rnn.ode import ODEGRU as HotODEGRU
from lib.ode_func import ODEFunc
from lib.ode_rnn import ODE_RNN as LatODEGRU
from lib.diffeq_solver import DiffeqSolver
from lib.utils import create_net

def measure(model, n_trials=100, batch_size=64, seq_len=100, dim=64,
            n_warmup=3, device="cuda"):
    """Time the forward pass and the forward + backward pass of a sequence model.

    The model is called as ``model(x, dt)`` with random inputs
    ``x ~ N(0, 1)`` of shape ``(batch_size, seq_len, dim)`` and
    ``dt ~ U[0, 1)`` of shape ``(batch_size, seq_len)``. For the backward
    measurement, the first element of the output (``model(x, dt)[0]``) is
    averaged and back-propagated.

    A few untimed warm-up calls precede each measurement so that one-off
    costs (device/kernel initialisation, allocator growth) do not inflate
    the reported averages.

    Args:
        model: ``torch.nn.Module`` accepting ``(x, dt)``. It is moved to
            ``device`` and switched to eval mode in place.
        n_trials: Number of timed calls per measurement; must be positive.
        batch_size: Batch dimension of the random inputs.
        seq_len: Sequence length of the random inputs.
        dim: Feature width of ``x``.
        n_warmup: Number of untimed calls before each measurement.
        device: Device to benchmark on. CUDA devices are synchronised around
            the timed regions.

    Raises:
        ValueError: If ``n_trials`` is not positive.

    Prints the average wall-clock time per call, in seconds, for the forward
    pass (under ``torch.no_grad``) and for forward + backward.
    """
    if n_trials < 1:
        raise ValueError("n_trials must be positive")
    device = torch.device(device)
    model.to(device)
    model.eval()
    x = torch.randn(batch_size, seq_len, dim).to(device)
    dt = torch.rand(batch_size, seq_len).to(device)

    def synchronize():
        # CUDA launches are asynchronous: wait for queued kernels before
        # reading the clock, otherwise only launch overhead is measured.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    def forward_backward():
        model(x, dt)[0].mean().backward()

    # Forward-only timing.
    with torch.no_grad():
        for _ in range(n_warmup):
            model(x, dt)
        synchronize()
        start = time.perf_counter()
        for _ in range(n_trials):
            model(x, dt)
        synchronize()
        elapsed = time.perf_counter() - start
    print("Forward", elapsed / n_trials)

    # Forward + backward timing. Gradients accumulate across calls, which is
    # irrelevant here because only the elapsed time is reported.
    for _ in range(n_warmup):
        forward_backward()
    synchronize()
    start = time.perf_counter()
    for _ in range(n_trials):
        forward_backward()
    synchronize()
    elapsed = time.perf_counter() - start
    print("FW / BW", elapsed / n_trials)

In [5]:
class EasyContTimeLSTM(torch.nn.Module):
    """Sequence wrapper that unrolls EasyTPP's continuous-time LSTM cell.

    Each step calls the cell on the current input and the previous state, then
    calls ``cell.decay`` with that step's time interval to obtain the state
    passed to the next step.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = EasyContTimeLSTMCell(hidden_size)

    def forward(self, x, dt):
        """Run the cell over every position of the sequence.

        Args:
            x: Input tensor; must be 3-D and is read as ``x[:, step]``
                (presumably (batch, length, features) — TODO confirm).
            dt: Intervals read as ``dt[:, step:step + 1]``.

        Returns:
            Tuple ``(hiddens, decay_states)``: ``hiddens`` stacks the decayed
            hidden state of each step along dim 1; ``decay_states`` stacks the
            per-step ``(cell, cell_bar, decay, output)`` along dim 2.
        """
        # Hidden, cell and cell-bar states all start at zero.
        initial = torch.zeros(len(x), 3 * self.hidden_size, device=x.device)
        h_t, c_t, c_bar = initial.chunk(3, dim=1)
        _, seq_len, _ = x.shape

        cells, cell_bars, decays, outputs, hiddens = [], [], [], [], []
        for step in range(seq_len):
            cell, c_bar, decay, output = self.cell(x[:, step], h_t, c_t, c_bar)
            # The interval slice keeps a trailing dim of size 1.
            c_t, h_t = self.cell.decay(cell, c_bar, decay, output, dt[:, step:step + 1])
            cells.append(cell)
            cell_bars.append(c_bar)
            decays.append(decay)
            outputs.append(output)
            hiddens.append(h_t)

        # Each per-step sequence is stacked along dim 1, then the four are
        # combined along dim 2 (presumably [batch, seq_len, 4, hidden]).
        per_field = [torch.stack(seq, dim=1) for seq in (cells, cell_bars, decays, outputs)]
        hiddens_stack = torch.stack(hiddens, dim=1)
        decay_states = torch.stack(per_field, dim=2)
        return hiddens_stack, decay_states

class ContTimeLSTM(torch.nn.Module):
    """Benchmark adapter exposing HoTPP's CT-LSTM as ``forward(x, dt)``."""

    def __init__(self, hidden_size):
        super().__init__()
        # Both constructor sizes are set to hidden_size (presumably input and
        # hidden width — TODO confirm against HotContTimeLSTM's signature).
        self.model = HotContTimeLSTM(hidden_size, hidden_size)

    def forward(self, x, dt):
        """Run the recurrence directly on dense tensors.

        The model's initial state is tiled over the batch, ``x`` is passed
        through ``cell.preprocess``, and the private ``_forward_loop`` is called.
        NOTE(review): calling the private loop presumably bypasses padded-batch
        handling to time the bare recurrence — confirm.
        """
        rnn = self.model
        batch_size = len(x)
        initial_state = rnn.init_state.repeat(batch_size, 1)
        features = rnn.cell.preprocess(x)
        return rnn._forward_loop(features, dt, initial_state)

# Benchmark both CT-LSTM implementations with hidden size 64, matching the
# feature width (64) of the random inputs generated in `measure`.
for title, build in (("OUR CT-LSTM", ContTimeLSTM), ("EasyTPP CT-LSTM", EasyContTimeLSTM)):
    print(title)
    measure(build(64))

OUR CT-LSTM
Forward 0.0024916505813598632
FW / BW 0.027592058181762694
EesyTPP CT-LSTM
Forward 0.015844099521636963
FW / BW 0.05194777488708496


In [4]:
class OrigODEGRU(LatODEGRU):
    """ODE-RNN from the original latent_ode code, adapted to ``forward(x, dt)``.

    The recurrent ODE function wraps a network built with
    ``create_net(..., n_layers=0, n_units=hidden_size)`` and is solved by a
    ``DiffeqSolver`` using the ``"rk4"`` method. All parts are placed on
    ``"cuda"``.
    """

    def __init__(self, hidden_size):
        ode_net = create_net(hidden_size, hidden_size, n_layers=0, n_units=hidden_size)
        ode_func = ODEFunc(input_dim=hidden_size,
                           latent_dim=hidden_size,
                           ode_func_net=ode_net,
                           device="cuda")
        ode_func = ode_func.to("cuda")
        solver = DiffeqSolver(hidden_size, ode_func, "rk4", hidden_size,
                              odeint_rtol=1e-3, odeint_atol=1e-4, device="cuda")
        super().__init__(hidden_size, hidden_size, n_gru_units=hidden_size,
                         z0_diffeq_solver=solver)

    def forward(self, x, dt):
        """Return the ODE-RNN reconstruction of ``x``.

        A single time grid, ``dt[0]``, is used for the whole batch and is
        passed both as the prediction grid and as the observation grid. Every
        position is marked observed.

        NOTE(review): ``dt`` holds random intervals, not cumulative
        timestamps. If ``get_reconstruction`` expects monotonic absolute
        times, ``dt[0].cumsum(0)`` may be the intended grid — confirm.
        """
        observed = torch.ones_like(x, dtype=torch.bool)
        time_grid = dt[0]
        reconstruction, _ = self.get_reconstruction(time_grid, x, time_grid,
                                                    mask=observed, mode=None)
        return reconstruction


class ODEGRU(torch.nn.Module):
    """Benchmark adapter exposing HoTPP's ODE-GRU as ``forward(x, dt)``."""

    def __init__(self, hidden_size):
        super().__init__()
        self.model = HotODEGRU(hidden_size, hidden_size, diff_hidden_size=hidden_size)

    def forward(self, x, dt):
        """Run the recurrence directly on dense tensors.

        The model's initial state is tiled over the batch and ``x`` is handed
        to the private ``_forward_loop`` as-is (no preprocessing step).
        """
        rnn = self.model
        initial_state = rnn.init_state.repeat(len(x), 1)
        return rnn._forward_loop(x, dt, initial_state)

# Benchmark both ODE-RNN implementations; each model is built right before
# its own measurement.
for title, build in (("Our ODE-RNN", ODEGRU), ("Orig ODE-RNN", OrigODEGRU)):
    print(title)
    measure(build(64))

Our ODE-RNN
Forward 0.009947264194488525
FW / BW 0.04754711389541626
Orig ODE-RNN
Forward 0.08002629280090331
FW / BW 0.16193607568740845
