## Analysis of the neural ODE on the CCT and CED benchmarks

In [1]:
import enum
import os
import argparse
import logging
import time
import numpy as np
import numpy.random as npr
import matplotlib

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.auto import tqdm

from torchdiffeq import odeint #includes exogenous inputs

class LatentODEfunc(nn.Module):
    """Three-layer ELU network used as the right-hand side of the latent ODE.

    The network maps the latent state (optionally concatenated with an
    exogenous input ``u``) to a vector with ``latent_dim`` entries.

    Args:
        latent_dim: size of the latent state and of the returned vector.
        nhidden: width of the two hidden layers.
        udim: number of exogenous input channels; ``None`` means no input.
    """

    def __init__(self, latent_dim=4, nhidden=20, udim=None):
        super().__init__()
        self.elu = nn.ELU(inplace=True)
        # The first layer sees the state plus (if present) the input channels.
        n_in = latent_dim if udim is None else latent_dim + udim
        self.fc1 = nn.Linear(n_in, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        self.nfe = 0  # number of forward evaluations so far

    def forward(self, t, x, u=None):
        """Evaluate the network at state ``x`` (``t`` is accepted but unused).

        Args:
            t: time argument supplied by the ODE solver; not read here.
            x: state tensor; concatenation happens along ``dim=1``.
            u: optional input; a 1-D tensor is treated as a single column.

        Returns:
            Output of the last linear layer.
        """
        self.nfe += 1
        if u is not None:
            # Promote a flat input vector to a column before appending it.
            if u.ndim == 1:
                u = u[:, None]
            x = torch.cat([x, u], dim=1)
        hidden = self.elu(self.fc1(x))
        hidden = self.elu(self.fc2(hidden))
        return self.fc3(hidden)


class RecognitionRNN(nn.Module):
    """Single-layer tanh recurrent cell with a linear read-out.

    The read-out has ``2 * latent_dim`` entries (the evaluation code in this
    file splits it into a mean half and a log-variance half).

    Args:
        latent_dim: half the size of the read-out vector.
        obs_dim: number of features per observation.
        nhidden: size of the hidden state.
        nbatch: batch size used by :meth:`initHidden`.
    """

    def __init__(self, latent_dim=4, obs_dim=2, nhidden=25, nbatch=1):
        super().__init__()
        self.nhidden = nhidden
        self.nbatch = nbatch
        self.i2h = nn.Linear(obs_dim + nhidden, nhidden)
        self.h2o = nn.Linear(nhidden, latent_dim * 2)

    def forward(self, x, h):
        """Advance the cell one step.

        Args:
            x: current observation, concatenated with ``h`` along ``dim=1``.
            h: previous hidden state.

        Returns:
            ``(out, h_new)``: the read-out and the updated hidden state.
        """
        h_new = torch.tanh(self.i2h(torch.cat((x, h), dim=1)))
        return self.h2o(h_new), h_new

    def initHidden(self):
        """Return an all-zero hidden state of shape ``(nbatch, nhidden)``."""
        return torch.zeros((self.nbatch, self.nhidden))


class Decoder(nn.Module):
    """Two-layer ReLU network mapping latent states to observations.

    Args:
        latent_dim: size of the latent input.
        obs_dim: size of the decoded observation.
        nhidden: width of the hidden layer.
    """

    def __init__(self, latent_dim=4, obs_dim=2, nhidden=20):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, obs_dim)

    def forward(self, z):
        """Decode ``z``: linear -> ReLU -> linear."""
        hidden = self.relu(self.fc1(z))
        return self.fc2(hidden)


class RunningAverageMeter(object):
    """Tracks the latest value and an exponentially weighted running average.

    ``avg`` is seeded with the first value passed to :meth:`update`; after
    that it is blended as ``avg * momentum + val * (1 - momentum)``.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history: ``val`` becomes ``None`` and ``avg`` becomes 0."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Record ``val`` and fold it into the running average."""
        is_first = self.val is None
        if is_first:
            # Nothing to blend with yet, so the first sample seeds the average.
            self.avg = val
        else:
            keep = self.momentum
            self.avg = self.avg * keep + val * (1 - keep)
        self.val = val


def log_normal_pdf(x, mean, logvar):
    """Element-wise log-density of ``x`` under a normal distribution.

    Args:
        x: tensor of evaluation points; its device hosts the constant term.
        mean: mean of the distribution (broadcast against ``x``).
        logvar: log-variance of the distribution (broadcast against ``x``).

    Returns:
        ``-0.5 * (log(2*pi) + logvar + (x - mean)**2 / exp(logvar))``.
    """
    # log(2*pi) kept as a float32 tensor of shape (1,) on x's device.
    log_two_pi = torch.log(
        torch.tensor([2. * np.pi], dtype=torch.float32, device=x.device)
    )
    squared_error = (x - mean) ** 2.
    return -.5 * (log_two_pi + logvar + squared_error / torch.exp(logvar))


def normal_kl(mu1, lv1, mu2, lv2):
    """Element-wise KL divergence KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))).

    Args:
        mu1, lv1: mean and log-variance of the first distribution.
        mu2, lv2: mean and log-variance of the second distribution.

    Returns:
        ``log(std2/std1) + (var1 + (mu1 - mu2)**2) / (2 * var2) - 0.5``.
    """
    var1, var2 = torch.exp(lv1), torch.exp(lv2)
    # Half the log-variance is the log standard deviation.
    log_std_ratio = lv2 / 2. - lv1 / 2.
    scaled_spread = (var1 + (mu1 - mu2) ** 2.) / (2. * var2)
    return log_std_ratio + scaled_spread - .5

def get_train_val_test(dataset):
    """Load a deepSI benchmark and split it into train / validation / test.

    Args:
        dataset: ``'CED'`` or ``'CCT'``.

    Returns:
        ``(train, val, test)`` as deepSI data objects.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Validate first: previously an unknown name fell through both branches
    # and failed with an UnboundLocalError on the return statement.
    if dataset not in ('CED', 'CCT'):
        raise ValueError(
            f"unknown dataset {dataset!r}; expected 'CED' or 'CCT'")

    # Imported lazily so the module can be loaded without deepSI installed.
    import deepSI
    from deepSI import System_data_list

    if dataset == 'CED':
        data_full = deepSI.datasets.CED()
        # Each record: first 300 samples for training, the rest for testing.
        train = System_data_list([data_i[:300] for data_i in data_full])
        test = System_data_list([data_i[300:] for data_i in data_full])
        # NOTE(review): validation is the first 100 samples of each test
        # record, so it overlaps the test data.
        val = System_data_list([t[:100] for t in test])
    else:  # 'CCT'
        train, test = deepSI.datasets.Cascaded_Tanks()
        # NOTE(review): validation is the first half of the test record, so
        # it overlaps the test data.
        val = test[:len(test)//2]
    return train, val, test

def get_torch_data(dataset, device=torch.device('cpu')):
    """Load a benchmark, normalize it and return it as float32 torch tensors.

    The normalization is fitted on the training split only and then applied
    to all three splits.

    Args:
        dataset: ``'CCT'`` or ``'CED'``.
        device: torch device the returned tensors are moved to.

    Returns:
        A 9-tuple ``(orig_trajs, u_orig_trajs, samp_trajs, u_samp_trajs,
        samp_ts, u_val_trajs, val_trajs, sample_time, norm)`` where

        * ``orig_trajs`` / ``u_orig_trajs`` are the test outputs / inputs,
        * ``samp_trajs`` / ``u_samp_trajs`` are the training outputs / inputs,
        * ``samp_ts`` holds the training time stamps (index * sample_time),
        * ``val_trajs`` / ``u_val_trajs`` are the validation outputs / inputs,
        * ``sample_time`` is the sampling period in seconds,
        * ``norm`` is the fitted deepSI normalization object.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Fail fast with a clear message instead of an UnboundLocalError later on.
    if dataset not in ('CCT', 'CED'):
        raise ValueError(
            f"unknown dataset {dataset!r}; expected 'CCT' or 'CED'")

    train, val, test = get_train_val_test(dataset)
    from deepSI.system_data import System_data_norm
    norm = System_data_norm() #normalization
    norm.fit(train)
    train, val, test = [norm.transform(t) for t in [train, val, test]]

    if dataset == 'CCT':
        # Single record: add a leading and a trailing axis
        # (presumably giving (1, time, 1) -- TODO confirm y and u are 1-D).
        samp_trajs = train.y[None,:,None]
        u_samp_trajs = train.u[None,:,None]
        orig_trajs = test.y[None,:,None]
        u_orig_trajs = test.u[None,:,None]
        val_trajs = val.y[None,:,None]
        u_val_trajs = val.u[None,:,None]

        sample_time = 4. #seconds
        samp_ts = np.arange(len(train.y))*sample_time

    else:  # 'CED'
        # Several records: stack them along a new leading axis.
        samp_trajs = np.array([ti.y[:,None] for ti in train])
        # NOTE(review): the train/test inputs are stacked without a trailing
        # axis while the validation inputs get one; kept as in the original
        # because downstream code may rely on it -- confirm.
        u_samp_trajs = np.array([ti.u for ti in train])
        orig_trajs = np.array([ti.y[:,None] for ti in test])
        u_orig_trajs = np.array([ti.u for ti in test])
        val_trajs = np.array([ti.y[:,None] for ti in val])
        u_val_trajs = np.array([ti.u[:,None] for ti in val])

        sample_time = 1/50 #seconds
        samp_ts = np.arange(len(train[0]))*sample_time

    def _to_tensor(array):
        # numpy array -> float32 torch tensor on the requested device
        return torch.from_numpy(array).float().to(device)

    orig_trajs = _to_tensor(orig_trajs)
    u_orig_trajs = _to_tensor(u_orig_trajs)
    samp_trajs = _to_tensor(samp_trajs)
    u_samp_trajs = _to_tensor(u_samp_trajs)
    samp_ts = _to_tensor(samp_ts)
    u_val_trajs = _to_tensor(u_val_trajs)
    val_trajs = _to_tensor(val_trajs)

    return orig_trajs, u_orig_trajs, samp_trajs, u_samp_trajs, samp_ts, \
            u_val_trajs, val_trajs, sample_time, norm

In [4]:
# dataset='CED'
# # tau_1 = True
# Evaluate every saved checkpoint on the test data and summarize the RMS
# simulation error, once per (dataset, tau_1) configuration.
for dataset, tau_1 in [('CCT',False),('CED',False),('CED',True)]:
    method = 'rk4' #exogenous inputs only implemented for Euler, RK4 and midpoint.
    device = torch.device('cpu')
    train_dir = './models-neural-ode/'
    niters = 20000

    #given by latent_ODE (not read in this cell; kept as notebook globals):
    nhidden = 20
    rnn_nhidden = 25
    obs_dim = 1
    noise_std = 0.1

    orig_trajs, u_orig_trajs, samp_trajs, u_samp_trajs, samp_ts, u_val_trajs, val_trajs, sample_time, norm = \
            get_torch_data(dataset, device=device)

    if dataset=='CCT':
        dttau = 0.032
        latent_dim = 2

    elif dataset=='CED':
        dttau = 0.12
        latent_dim = 3

    # Rescale time so that one sample step equals dttau (unless tau_1 is set,
    # in which case the physical time stamps are used unchanged).
    tau = 1 if tau_1 else sample_time/dttau
    samp_ts = samp_ts / tau #dt /dt * dttau = dttau

    # Checkpoints of tau_1 runs carry an extra suffix in their file name.
    suffix = '-tau_1' if tau_1 else ''

    rmslist = []
    for I in range(1,400):
        ckpt_name = f'ckpt-best-{dataset}-{I}{suffix}.pth'
        ckpt_path = os.path.join(train_dir, ckpt_name)
        try:
            # NOTE(review): the checkpoint stores whole pickled modules, so
            # torch.load unpickles arbitrary objects -- only load trusted
            # files. Recent PyTorch releases may need weights_only=False.
            ckpt = torch.load(ckpt_path, map_location=device)
        except FileNotFoundError:
            continue
        # Print the file that was actually loaded (including the suffix).
        print(ckpt_name)
        func = ckpt['func']
        rec = ckpt['rec']
        dec = ckpt['dec']

        # Pure inference: no autograd graph is needed for the simulation.
        with torch.no_grad():
            # Run the recognition RNN backwards in time over the test data.
            h = rec.initHidden().to(device)
            for t in reversed(range(orig_trajs.size(1))):
                obs = orig_trajs[:, t, :]
                rec_out, h = rec(obs, h)
            # Only the mean half of the read-out is used as initial state; the
            # log-variance half is ignored (no sampling at evaluation time).
            z0 = rec_out[:, :latent_dim]

            # forward in time and solve ode for reconstructions
            pred_z = odeint(func, z0, samp_ts[:orig_trajs.shape[1]], u=u_orig_trajs, method=method).permute(1, 0, 2)
            pred_x = dec(pred_z)
        orig_trajs_p = pred_x.detach()

        # Plotting orig_trajs[k,:,0] against orig_trajs_p[k,:,0] is a quick
        # visual check of the simulation quality.
        # RMS errors are scaled by norm.ystd to undo the output normalization.
        if dataset=='CCT':
            rms = torch.mean((orig_trajs - orig_trajs_p)**2).item()**0.5*norm.ystd
            rmslist.append(rms)
            print('RMS=',rms)
        if dataset=='CED':
            # The error is reported separately for the first two test records.
            rms1 = torch.mean((orig_trajs[0] - orig_trajs_p[0])**2).item()**0.5*norm.ystd
            rms2 = torch.mean((orig_trajs[1] - orig_trajs_p[1])**2).item()**0.5*norm.ystd
            print('RMS set1=',rms1)
            print('RMS set2=',rms2)
            rmslist.append((rms1,rms2))

    rmslist = np.array(rmslist)

    print()
    print('###########################')
    print('###########################')
    if len(rmslist) == 0:
        # Without this guard the min/mean (or the column indexing) below would
        # raise when no checkpoint file exists for this configuration.
        print(f'No checkpoints found for {dataset}{suffix} in {train_dir}')
    elif dataset=='CED':
        print("RMS results CED tau_1" if tau_1 else "RMS results CED")
        print('set 1 min=',np.min(rmslist[:,0],axis=0), 'mean=',np.mean(rmslist[:,0],axis=0), 'len=',len(rmslist))
        print('set 2 min=',np.min(rmslist[:,1],axis=0), 'mean=',np.mean(rmslist[:,1],axis=0), 'len=',len(rmslist))
    elif dataset=='CCT':
        print("RMS results CCT")
        print('min=',np.min(rmslist,axis=0), 'mean=',np.mean(rmslist,axis=0), 'len=',len(rmslist))
    print()
    print('###########################')
    print('###########################')

ckpt-best-CCT-1.pth
RMS= 0.28936187906794736
ckpt-best-CCT-2.pth
RMS= 0.31364721722852457
ckpt-best-CCT-3.pth
RMS= 0.2067198689884109
ckpt-best-CCT-4.pth
RMS= 0.6541493821647515
ckpt-best-CCT-5.pth
RMS= 0.3882140760448425
ckpt-best-CCT-6.pth
RMS= 0.23907775720886332
ckpt-best-CCT-7.pth
RMS= 0.2273063909022871
ckpt-best-CCT-8.pth
RMS= 0.2549915661111742
ckpt-best-CCT-9.pth
RMS= 0.3344552356823532
ckpt-best-CCT-10.pth
RMS= 0.30742240312827346
ckpt-best-CCT-11.pth
RMS= 0.2454621261155722
ckpt-best-CCT-12.pth
RMS= 0.31688610899904324
ckpt-best-CCT-13.pth
RMS= 0.3166067180697689
ckpt-best-CCT-14.pth
RMS= 0.4938528896709381
ckpt-best-CCT-15.pth
RMS= 0.22484235246757336
ckpt-best-CCT-101.pth
RMS= 0.23857751407904534
ckpt-best-CCT-102.pth
RMS= 0.9365915276210175
ckpt-best-CCT-103.pth
RMS= 0.30130051087175175
ckpt-best-CCT-104.pth
RMS= 0.3585485406642151
ckpt-best-CCT-106.pth
RMS= 0.17933939666348714
ckpt-best-CCT-107.pth
RMS= 0.2810473376241484
ckpt-best-CCT-108.pth
RMS= 0.2886307006668135
ckp