In [48]:
import os
import argparse
import logging
import time
import numpy as np
import numpy.random as npr
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint as odeint
from torch.utils.data import DataLoader
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [49]:
samp_trajs_TE = torch.load('samp_trajs_TE_tau4k6_25.pt')
samp_trajs_val_TE = torch.load('samp_trajs_val_TE_tau4k6_25.pt')
tau = 4
k = 6
mesured_dim = 12

trial_num = 16

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch = 1000 #for lstm256

ts_num = 0.33
tot_num = 25

samp_ts = np.linspace(0, ts_num, num=tot_num)
samp_ts = torch.from_numpy(samp_ts).float().to(device)

orig_trajs_TE = np.load('orig_trajs_TE_Stereo_Stim_tau4k6.npy')
samp_trajs_TE_test = orig_trajs_TE[:, :tot_num, :]
samp_trajs_TE_test = torch.from_numpy(samp_trajs_TE_test).float().to(device).reshape(trial_num, tot_num, mesured_dim*(k+1))

#Load to Dataloader
train_loader = DataLoader(dataset = samp_trajs_TE, batch_size = batch, shuffle = True, drop_last = True)
val_loader = DataLoader(dataset = samp_trajs_val_TE, batch_size = batch, shuffle = True, drop_last = True)

In [50]:
if not os.path.exists('model'):
           os.makedirs('model')
        
class LatentODEfunc(nn.Module):

    def __init__(self, latent_dim=8, nhidden=50):
        super(LatentODEfunc, self).__init__()
        #self.tanh = nn.ELU(inplace= True)
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.fc1(x)
        out = self.tanh(out)
        out = self.fc2(out)
        out = self.tanh(out)
        out = self.fc3(out)
        return out

class RecognitionRNN(nn.Module):

    def __init__(self, latent_dim=8, obs_dim=46, nhidden=50, nbatch=1):
        super(RecognitionRNN, self).__init__()
        self.nhidden = nhidden
        self.nbatch = nbatch
        #self.h1o = nn.Linear(obs_dim, 8)
        self.h1o = nn.Linear(obs_dim, 36)
        self.h3o = nn.Linear(36, latent_dim*2)
        self.lstm = nn.LSTMCell(latent_dim*2, nhidden)
        self.tanh = nn.Tanh()
        self.h2o = nn.Linear(nhidden, latent_dim*2)

    def forward(self, x, h, c):
        xo = self.h1o(x)
        xo = self.tanh(xo)
        xxo = self.h3o(xo)
        hn, cn = self.lstm(xxo, (h,c))
        hn = self.tanh(hn)
        out = self.h2o(hn)
        return out, hn, cn
    

    def initHidden(self):
        return torch.zeros(1, self.nbatch, self.nhidden)


class Decoder(nn.Module):

    def __init__(self, latent_dim=8, obs_dim=46, nhidden=50):
        super(Decoder, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden*2)
        self.fc3 = nn.Linear(nhidden*2, obs_dim)

    def forward(self, z):
        out = self.fc1(z)
        out = self.tanh(out)
        out = self.fc2(out)
        out = self.tanh(out)
        out = self.fc3(out)
        return out


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def log_normal_pdf(x, mean, logvar):
    const = torch.from_numpy(np.array([2. * np.pi])).float().to(x.device)
    const = torch.log(const)
    return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar))

def mseloss(x, mean):
    loss = nn.MSELoss()
    return loss(x, mean)

def normal_kl(mu1, lv1, mu2, lv2):
    v1 = torch.exp(lv1)
    v2 = torch.exp(lv2)
    lstd1 = lv1 / 2.
    lstd2 = lv2 / 2.

    kl = lstd2 - lstd1 + ((v1 + (mu1 - mu2) ** 2.) / (2. * v2)) - .5
    return kl

def MSELoss(yhat, y):
    assert type(yhat) == torch.Tensor
    assert type(y) == torch.Tensor
    return torch.mean((yhat - y) ** 2)


def get_args():
    return {'latent_dim': latent_dim,
            'obs_dim': obs_dim,
            'nhidden': nhidden,
            'dec_nhidden' : dec_nhidden,
            'rnn_nhidden': rnn_nhidden,
            'device': device,
            'learning_rate': learning_rate,
            'tau': tau,
            'k': k}

def get_state_dicts():
    return {'odefunc_state_dict': func.state_dict(),
            'encoder_state_dict': rec.state_dict(),
            'decoder_state_dict': dec.state_dict()}

def data_get_dict():
    return {
        'samp_trajs_TE': samp_trajs_TE,
        'samp_trajs_val_TE': samp_trajs_val_TE,
        'samp_ts': samp_ts,
    }

def get_losses():
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_losses_k1': val_losses_k1,
        'val_losses_k2': val_losses_k2,
        'val_losses_k3': val_losses_k3,
        'val_losses_k4': val_losses_k4,
        'val_losses_k5': val_losses_k5,
        'val_losses_k6': val_losses_k6,
        'val_losses_k7': val_losses_k7,
        'val_losses_k8': val_losses_k8,
        'val_losses_k9': val_losses_k9,
    }

def save_model(Training_Trial, rnn_nhidden, tau, k, lr, latent_dim, itr):
    if not os.path.exists('/model'):
        os.makedirs('/model')
    save_dict = {
        'model_args': get_args(),
        'optimizer_state_dict': optimizer.state_dict(),
        #'data': data_get_dict(),
        'train_loss': get_losses()
    }
    
    save_dict.update(get_state_dicts())
    
    torch.save(save_dict, 'model/ODE_Xcoord_Trial{}_TakenEmbedding_rnn2_lstm{}_tau{}k{}_LSTM_lr{}_latent{}_LSTMautoencoder_Dataloader_epoch{}.pth'.format(Training_Trial, rnn_nhidden, tau, k, lr, latent_dim, itr))

    
def data_for_plot_graph(gen_index):
    with torch.no_grad():
        # sample from trajectorys' approx. posterior

        ts_pos = np.linspace(0, ts_num*gen_index, num=tot_num*gen_index)
        ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
        h = torch.zeros(1, samp_trajs_TE_test.shape[0], rnn_nhidden).to(device)
        c = torch.zeros(1, samp_trajs_TE_test.shape[0], rnn_nhidden).to(device)
    
        hn = h[0, :, :]
        cn = c[0, :, :]
    
        for t in reversed(range(samp_trajs_TE_test.size(1))):
            obs = samp_trajs_TE_test[:, t, :]
            out, hn, cn = rec.forward(obs, hn, cn)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

        # forward in time and solve ode for reconstructions
        pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2) #change time and batch with permute
        pred_x = dec(pred_z)
        
        return pred_x, pred_z
    
def plot_graph(gen_index, times_index, dataset_value, deriv_index, pred_x_forgraph, orig_trajs, itr, path):
    with torch.no_grad():
        orig_trajs_forgraph = orig_trajs
        ts_pos_combined = np.linspace(0, ts_num*gen_index, num=tot_num*gen_index) 
        
        fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(15, 9)) #####MAKE SURE ROW COL MATCHES THE NUM OF FEATURES
        axes = axes.flatten()
        
        for i, ax in enumerate(axes):
            ax.scatter(ts_pos_combined[times_index:times_index+tot_num*gen_index], orig_trajs_forgraph[dataset_value,times_index:tot_num*gen_index, i*(k+1)+deriv_index], label='sampled data', s = 5)
            ax.plot(ts_pos_combined[times_index:times_index+tot_num*gen_index], pred_x_forgraph[dataset_value, times_index:times_index+tot_num*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')
            ax.set_ylim(-2.5, 2.5)

        plt.legend()
        plot_name = 'lstm_datasetnum{}_latent{}_gen{}_deriv{}_epoch{}.png'.format(dataset_value, latent_dim, gen_index, deriv_index, itr)
        save_path = os.path.join(path, plot_name)
        plt.savefig(save_path, dpi=500)
        plt.close()
    

def plot_z_graph(gen_index, times_index, dataset_value, deriv_index, pred_z_forgraph, orig_trajs, itr, path):
    with torch.no_grad():
        orig_trajs_forgraph = orig_trajs
        out, hn, cn = rec.forward(orig_trajs)
        qz_mean, qz_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz_mean.size()).to(device)
        z = epsilon * torch.exp(.5 * qz_logvar) + qz_mean
        
        z_forgraph = z.detach().cpu().numpy()
        ts_pos_combined = np.linspace(0, ts_num*gen_index, num=tot_num*gen_index) 
        
        fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(15, 9))
        axes = axes.flatten()
        
        for i, ax in enumerate(axes):
            ax.scatter(ts_pos_combined[times_index:50*gen_index], z_forgraph[dataset_value, 0:50*gen_index, i], label='sampled data', s = 5)
            ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_z_forgraph[dataset_value, times_index:+50*gen_index, i], 'r',
                 label='learned trajectory (t>0)')
            ax.set_ylim(-2.5, 2.5)

        plt.legend()
        plot_name = 'Zgraph_lstm_datasetnum{}_latent{}_gen{}_deriv{}_epoch{}.png'.format(dataset_value, latent_dim, gen_index, deriv_index, itr)
        save_path = os.path.join(path, plot_name)
        plt.savefig(save_path, dpi=500)
        plt.close()

In [51]:
Training_Trial = 1
latent_dim = 8
nhidden = 64 ##Trial1 = 64, Trial2 = 128, Trial3 = 128, Trial4 = 64, Trial5 = 64
dec_nhidden = 32
obs_dim = 12*(k+1)
rnn_nhidden = 256
nitrs = 1000
noise_std = 0.2
learning_rate = 0.008

func = LatentODEfunc(latent_dim, nhidden).to(device)
rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden, batch).to(device)
dec = Decoder(latent_dim, obs_dim, dec_nhidden).to(device)
params = (list(func.parameters()) + list(dec.parameters()) + list(rec.parameters()))
optimizer = optim.Adam(params, lr=learning_rate)
loss_meter = RunningAverageMeter()

train_losses = []
val_losses = []
val_losses_k1 = []
val_losses_k2 = []
val_losses_k3 = []
val_losses_k4 = []
val_losses_k5 = []
val_losses_k6 = []
val_losses_k7 = []
val_losses_k8 = []
val_losses_k9 = []
torch.cuda.empty_cache()

In [52]:
for itr in range(1, nitrs + 1):
    for data in train_loader:
        optimizer.zero_grad()
        h = rec.initHidden().to(device)
        c = rec.initHidden().to(device)
        hn = h[0, :, :]
        cn = c[0, :, :]
        for t in reversed(range(data.size(1))):
            obs = data[:, t, :]
            out, hn, cn = rec.forward(obs, hn, cn)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean   
        
        # forward in time and solve ode for reconstructions
        pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
        pred_x = dec(pred_z)

        # compute loss
        loss = MSELoss(pred_x, data)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        
    with torch.no_grad():
        for data in val_loader:
            h = torch.zeros(1, batch, rnn_nhidden).to(device)
            c = torch.zeros(1, batch, rnn_nhidden).to(device)
            hn = h[0, :, :]
            cn = c[0, :, :]
            
            for t in reversed(range(data.size(1))):
                obs = data[:, t, :]
                out, hn, cn = rec.forward(obs, hn, cn)
            qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
            epsilon = torch.randn(qz0_mean.size()).to(device)
            z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

            #forward in time and solve ode for reconstructions
            pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
            pred_x = dec(pred_z)
        
            #val_loss = MSELoss(pred_x, samp_trajs_val_TE) + torch.mean(-0.5 * torch.sum(1 + qz0_logvar - qz0_mean**2 - torch.exp(qz0_logvar), dim = -1)/(31*(k+1)+2))
            val_loss = MSELoss(pred_x[:,:,::(k+1)], data[:,:,::(k+1)])
            val_loss_k1 = MSELoss(pred_x[:,:,1::(k+1)], data[:,:,1::(k+1)])
            val_loss_k2 = MSELoss(pred_x[:,:,2::(k+1)], data[:,:,2::(k+1)])
            val_loss_k3 = MSELoss(pred_x[:,:,3::(k+1)], data[:,:,3::(k+1)])
            val_loss_k4 = MSELoss(pred_x[:,:,4::(k+1)], data[:,:,4::(k+1)])
            val_loss_k5 = MSELoss(pred_x[:,:,5::(k+1)], data[:,:,5::(k+1)])
            val_loss_k6 = MSELoss(pred_x[:,:,6::(k+1)], data[:,:,6::(k+1)])

            val_losses.append(val_loss)
            val_losses_k1.append(val_loss_k1)
            val_losses_k2.append(val_loss_k2)
            val_losses_k3.append(val_loss_k3)
            val_losses_k4.append(val_loss_k4)
            val_losses_k5.append(val_loss_k5)
            val_losses_k6.append(val_loss_k6)

            V = [val_loss, val_loss_k1, val_loss_k2, val_loss_k3, val_loss_k4, val_loss_k5, val_loss_k6]
            lowest_val_loss = torch.asarray(V).min(0)[0]
            deriv_index = torch.asarray(V).min(0)[1]

    if ((itr > 100) and (itr % 10 == 0)):
        save_model(Training_Trial, rnn_nhidden, tau, k, learning_rate, latent_dim, itr)
        tot_index = 40
        times_index = 0
        deriv_index = deriv_index.numpy()
        
        orig_trajs = orig_trajs_TE[:, 0:0+tot_num*tot_index, :]

        pred_x, pred_z = data_for_plot_graph(tot_index)
        pred_x = pred_x.reshape(trial_num, tot_num*tot_index, mesured_dim*(k+1))
        pred_z = pred_z.reshape(trial_num, tot_num*tot_index, latent_dim)
        pred_x_forgraph = pred_x.detach().cpu().numpy()
        pred_z_forgraph = pred_z.detach().cpu().numpy()

        path = "Results_pic/tau{}k{}/latent{}/data_loader_rnn2layer_lstm{}_lr{}_Trial{}/epoch{}".format(tau, k, latent_dim, rnn_nhidden, learning_rate, Training_Trial, itr)

        if not os.path.exists(path):
           os.makedirs(path)

        plotgraph_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

        gen_index = 40
        
        for i in range(len(plotgraph_index)):
            plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)
        
    print('Iter: {}, running avg mse: {:.4f} lowest val mse: {:.4f} at k {}'.format(itr, loss, lowest_val_loss, deriv_index))

Iter: 1, running avg mse: 1.0043 lowest val mse: 0.9891 at k 5
Iter: 2, running avg mse: 0.9792 lowest val mse: 0.9586 at k 5
Iter: 3, running avg mse: 0.9648 lowest val mse: 0.9479 at k 4
Iter: 4, running avg mse: 0.9475 lowest val mse: 0.9408 at k 4
Iter: 5, running avg mse: 0.9398 lowest val mse: 0.9364 at k 4
Iter: 6, running avg mse: 0.9351 lowest val mse: 0.9368 at k 4
Iter: 7, running avg mse: 0.9410 lowest val mse: 0.9347 at k 4
Iter: 8, running avg mse: 0.9300 lowest val mse: 0.9320 at k 4
Iter: 9, running avg mse: 0.9285 lowest val mse: 0.9242 at k 4
Iter: 10, running avg mse: 0.9073 lowest val mse: 0.9071 at k 4
Iter: 11, running avg mse: 0.9107 lowest val mse: 0.9014 at k 4
Iter: 12, running avg mse: 0.9036 lowest val mse: 0.8997 at k 4
Iter: 13, running avg mse: 0.9131 lowest val mse: 0.8960 at k 4
Iter: 14, running avg mse: 0.8930 lowest val mse: 0.8900 at k 4
Iter: 15, running avg mse: 0.8900 lowest val mse: 0.8846 at k 4
Iter: 16, running avg mse: 0.8909 lowest val mse:

In [53]:
Trial_tot_num = 10

In [None]:
for p in range(3, Trial_tot_num):
    Training_Trial = p
    
    latent_dim = 8
    nhidden = 64 ##Trial1 = 64, Trial2 = 128, Trial3 = 128, Trial4 = 64, Trial5 = 64
    dec_nhidden = 32
    obs_dim = 12*(k+1)
    rnn_nhidden = 256
    nitrs = 600
    noise_std = 0.2
    learning_rate = 0.008

    func = LatentODEfunc(latent_dim, nhidden).to(device)
    rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden, batch).to(device)
    dec = Decoder(latent_dim, obs_dim, dec_nhidden).to(device)
    params = (list(func.parameters()) + list(dec.parameters()) + list(rec.parameters()))
    optimizer = optim.Adam(params, lr=learning_rate)
    loss_meter = RunningAverageMeter()

    train_losses = []
    val_losses = []
    val_losses_k1 = []
    val_losses_k2 = []
    val_losses_k3 = []
    val_losses_k4 = []
    val_losses_k5 = []
    val_losses_k6 = []
    val_losses_k7 = []
    val_losses_k8 = []
    val_losses_k9 = []
    torch.cuda.empty_cache()
    for itr in range(1, nitrs + 1):
        for data in train_loader:
            optimizer.zero_grad()
            h = rec.initHidden().to(device)
            c = rec.initHidden().to(device)
            hn = h[0, :, :]
            cn = c[0, :, :]
            for t in reversed(range(data.size(1))):
                obs = data[:, t, :]
                out, hn, cn = rec.forward(obs, hn, cn)
            qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
            epsilon = torch.randn(qz0_mean.size()).to(device)
            z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean   

            # forward in time and solve ode for reconstructions
            pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
            pred_x = dec(pred_z)

            # compute loss
            loss = MSELoss(pred_x, data)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        with torch.no_grad():
            for data in val_loader:
                h = torch.zeros(1, batch, rnn_nhidden).to(device)
                c = torch.zeros(1, batch, rnn_nhidden).to(device)
                hn = h[0, :, :]
                cn = c[0, :, :]

                for t in reversed(range(data.size(1))):
                    obs = data[:, t, :]
                    out, hn, cn = rec.forward(obs, hn, cn)
                qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
                epsilon = torch.randn(qz0_mean.size()).to(device)
                z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

                #forward in time and solve ode for reconstructions
                pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
                pred_x = dec(pred_z)

                #val_loss = MSELoss(pred_x, samp_trajs_val_TE) + torch.mean(-0.5 * torch.sum(1 + qz0_logvar - qz0_mean**2 - torch.exp(qz0_logvar), dim = -1)/(31*(k+1)+2))
                val_loss = MSELoss(pred_x[:,:,::(k+1)], data[:,:,::(k+1)])
                val_loss_k1 = MSELoss(pred_x[:,:,1::(k+1)], data[:,:,1::(k+1)])
                val_loss_k2 = MSELoss(pred_x[:,:,2::(k+1)], data[:,:,2::(k+1)])
                val_loss_k3 = MSELoss(pred_x[:,:,3::(k+1)], data[:,:,3::(k+1)])
                val_loss_k4 = MSELoss(pred_x[:,:,4::(k+1)], data[:,:,4::(k+1)])
                val_loss_k5 = MSELoss(pred_x[:,:,5::(k+1)], data[:,:,5::(k+1)])
                val_loss_k6 = MSELoss(pred_x[:,:,6::(k+1)], data[:,:,6::(k+1)])

                val_losses.append(val_loss)
                val_losses_k1.append(val_loss_k1)
                val_losses_k2.append(val_loss_k2)
                val_losses_k3.append(val_loss_k3)
                val_losses_k4.append(val_loss_k4)
                val_losses_k5.append(val_loss_k5)
                val_losses_k6.append(val_loss_k6)

                V = [val_loss, val_loss_k1, val_loss_k2, val_loss_k3, val_loss_k4, val_loss_k5, val_loss_k6]
                lowest_val_loss = torch.asarray(V).min(0)[0]
                deriv_index = torch.asarray(V).min(0)[1]

        if ((itr > 100) and (itr % 10 == 0)):
            save_model(Training_Trial, rnn_nhidden, tau, k, learning_rate, latent_dim, itr)
            tot_index = 40
            times_index = 0
            deriv_index = deriv_index.numpy()

            orig_trajs = orig_trajs_TE[:, 0:0+tot_num*tot_index, :]

            pred_x, pred_z = data_for_plot_graph(tot_index)
            pred_x = pred_x.reshape(trial_num, tot_num*tot_index, mesured_dim*(k+1))
            pred_z = pred_z.reshape(trial_num, tot_num*tot_index, latent_dim)
            pred_x_forgraph = pred_x.detach().cpu().numpy()
            pred_z_forgraph = pred_z.detach().cpu().numpy()

            path = "Results_pic/tau{}k{}/latent{}/data_loader_rnn2layer_lstm{}_lr{}_Trial{}/epoch{}".format(tau, k, latent_dim, rnn_nhidden, learning_rate, Training_Trial, itr)

            if not os.path.exists(path):
               os.makedirs(path)

            plotgraph_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

            gen_index = 40

            for i in range(len(plotgraph_index)):
                plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)

        print('Iter: {}, running avg mse: {:.4f} lowest val mse: {:.4f} at k {}'.format(itr, loss, lowest_val_loss, deriv_index))

In [None]:
for i in range(len(plotgraph_index)):
    plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)

In [None]:
orig_trajs.shape

In [None]:
train_loss = np.array(train_losses)

plt.plot(train_loss, 'r')
plt.savefig('C:/Users/shiny/Documents/NeuralODE_RatTreadMill/Results_pic/TrainingLossGraph/tau6k3/trainloss_0.005_nonoise_latent8_lstm256.png')
np.save('C:/Users/shiny/Documents/NeuralODE_RatTreadMill/Results_pic/TrainingLossGraph/tau6k3/trainloss_0.005_nonoise_latent8_lstm256.npy', train_loss)

In [None]:
folder = 'runs/model_'

def save_model():
    folder = 'runs/model'
    folder = os.path.join(folder, 'ckpt')

    ckpt_path = os.path.join(folder, f'ODE_normalized_4_128_2tanh.pth')

    save_dict = {
        'model_args': get_args(),
        'optimizer_state_dict': optimizer.state_dict(),
        'data': data_get_dict(),
        'train_loss': get_losses()
    }
    
    save_dict.update(get_state_dicts())
    
    torch.save(save_dict, 'C:/Users/shiny/Documents/NeuralODE_RatTreadMill/model/All_rodent_ODE_TakenEmbedding_tau6k3_LSTM_lr0.008_latent8_LSTMautoencoder_epoch500.pth')

save_model()

In [None]:
checkpoint = torch.load('model/All_rodent_ODE_TakenEmbedding_tau6k3_LSTM_lr0.008_latent8_LSTMautoencoder_epoch380.pth')
rec.load_state_dict(checkpoint['encoder_state_dict'])
func.load_state_dict(checkpoint['odefunc_state_dict'])
dec.load_state_dict(checkpoint['decoder_state_dict'])

## Long time series generation

In [None]:
gen_index = 20
times_index = 0
deriv_index = 0
itr= 380
orig_trajs_TE = np.load('orig_trajs_TE_tau6k3.npy')
orig_trajs_TE = orig_trajs_TE.reshape(203, 200*34, 31*(k+1)+2)
samp_trajs_TE_test = orig_trajs_TE[:, :50, :]

samp_trajs_TE_test = torch.from_numpy(samp_trajs_TE_test).float().to(device).reshape(203, 50, 31*(k+1)+2)
orig_trajs = orig_trajs_TE[:, 0:0+50*gen_index, :]

pred_x = data_for_plot_graph(gen_index)

path = "Results_pic/tau6k3/longtimeseries/epoch{}".format(itr)

if not os.path.exists(path):
   os.makedirs(path)

plot_graph(gen_index, times_index, 0, deriv_index, pred_x, orig_trajs, itr, path)
plot_graph(gen_index, times_index, 4, deriv_index, pred_x, orig_trajs, itr, path)
plot_graph(gen_index, times_index, 20, deriv_index, pred_x, orig_trajs, itr, path)

In [None]:
with torch.no_grad():
    # sample from trajectorys' approx. posterior

    ts_pos = np.linspace(0, 0.25*gen_index, num=50*gen_index)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
    h = torch.zeros(1, samp_trajs_TE.shape[0], rnn_nhidden).to(device)
    c = torch.zeros(1, samp_trajs_TE.shape[0], rnn_nhidden).to(device)
    
    hn = h[0, :, :]
    cn = c[0, :, :]
    
    for t in reversed(range(samp_trajs_TE.size(1))):
        obs = samp_trajs_TE[:, t, :]
        out, hn, cn = rec.forward(obs, hn, cn)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

    # forward in time and solve ode for reconstructions
    pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2) #change time and batch with permute
    pred_x = dec(pred_z)

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 0
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_Tied_latent8_gen10_deriv0_50.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 4
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_split_latent8_gen10_deriv0_50.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 20
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_Tied_latent8_gen10_deriv0_50_wash.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 31*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=2500) 
    
    times_index = 0
    positional_value = 0
    fig, axes = plt.subplots(nrows=4, ncols=1, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[0:50*gen_index], orig_trajs_forgraph[i,0:50*gen_index, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[i, times_index:+50*gen_index, positional_value*(k+1)], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/Allrodent_Tau6k3_takenembedding_longepochgeneration_position0.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

## Predicting longer timescales with combining minibatches

In [None]:
with torch.no_grad():
    # sample from trajectorys' approx. posterior

    ts_pos = np.linspace(0, np.pi*2, num=50)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    #ts_neg = np.linspace(-np.pi*20, 0., num=400)[::-1].copy()
    #ts_neg = torch.from_numpy(ts_neg).float().to(device)
    
    h = rec.initHidden().to(device)
    for t in reversed(range(samp_trajs.size(1))):
        obs = samp_trajs[:, t, :]
        out, h = rec.forward(obs, h)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

    # forward in time and solve ode for reconstructions
    pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2)
    pred_x = dec(pred_z)

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(64, 1500, 46)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    samp_trajs = samp_trajs.reshape(64, 300, 46)
    samp_trajs_forgraph = samp_trajs.detach().cpu().numpy()
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    samp_ts_combined = np.linspace(0, 299, num=300)
    ts_pos_combined = np.linspace(0, 1499, num=1500) 
    
    times_index = 0
    positional_value = 3
    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(samp_ts_combined[times_index:+300], samp_trajs_forgraph[i,times_index:+300, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+1200], pred_x_forgraph[i, times_index:+1200, positional_value], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./1200.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

## Predicting longer timescales: looking at each one

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    samp_trajs_forgraph = samp_trajs.detach().cpu().numpy()
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    
    times_index = 0
    positional_value = 0
    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(samp_ts_combined[times_index:+50], samp_trajs_forgraph[i,times_index:+50, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+250], pred_x_forgraph[i, times_index:+250, positional_value], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./test.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

In [None]:
z0.shape

## PCA for z0

In [None]:
from sklearn.decomposition import PCA
import pandas as pd

In [None]:
with torch.no_grad():
    ts_pos = np.linspace(0, np.pi*2*5, num=250)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
    h = torch.zeros(samp_trajs.shape[0], rnn_nhidden).to(device)
    
    for t in reversed(range(samp_trajs.size(1))):
        obs = samp_trajs[:, t, :]
        out, h = rec.forward(obs, h)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
    z0 = z0.cpu()
    
    pca = PCA(n_components=2)
    pca.fit(z0.cpu())
    print("Explained variance:", pca.explained_variance_ratio_)
    z0_red = pca.fit_transform(z0)
    
    print(z0_red[:, 0].shape)

In [None]:
z0 = z0.cpu()

In [None]:
    pca = PCA(n_components=2)
    pca.fit(z0.cpu())
    print("Explained variance:", pca.explained_variance_ratio_)

In [None]:
pca_z = PCA(n_components=2)
pca.fit(z0)

z0_red = pca.fit_transform(z0)

d = {'PC1': z0_red[:, 0], 'PC2': z0_red[:, 1]}
df = pd.DataFrame(d)

plt.figure()
plt.plot(z0_red[:, 0], z0_red[:, 1], 'o', label='z0 samples in 2D', linewidth=2, zorder=1)
plt.legend()
plt.savefig('./PCAgraph.png', dpi=250)

In [None]:
from sklearn.manifold import TSNE

In [None]:
time_start = time.time()
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
tsne_results = tsne.fit_transform(data_subset)