# Install (Important)

In [None]:
# Clone and install (in editable mode) the latest torchdiffeq, the Neural ODE implementation from the original authors' repository
!git clone https://github.com/rtqichen/torchdiffeq.git
!cd torchdiffeq && pip install -e .
!ls torchdiffeq/torchdiffeq

# Latent ODE on Multi-variate data

# Rough

In [None]:
import os
import time
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import pandas as pd 
from pandas import concat
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
%matplotlib inline
from torchdiffeq.torchdiffeq import odeint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lookup table from a company/asset display name to its market ticker symbol.
# NOTE(review): the symbols look like Yahoo Finance style (e.g. 'BTC-USD',
# 'BRK-B') -- confirm against whatever downloads the raw CSV files.
dict_tickers = {
    'Apple': 'AAPL',
    'Microsoft': 'MSFT',
    'Google': 'GOOG',
    'Bitcoin': 'BTC-USD',
    'Facebook': 'FB',
    'Walmart': 'WMT',
    'Amazon': 'AMZN',
    'CVS': 'CVS',
    'Berkshire': 'BRK-B',
    'ExxonMobil': 'XOM',
    'AtandT': 'T',
    'Costco': 'COST',
    'Walgreens': 'WBA',
    'Kroger': 'KR',
    'JPMorgan': 'JPM',
    'Verizon': 'VZ',
    'FordMotor': 'F',
    'GeneralMotors': 'GM',
    'Dell': 'DELL',
    'BankOfAmerica': 'BAC',
    'Target': 'TGT',
    'GeneralElectric': 'GE',
    'JohnsonandJohnson': 'JNJ',
    'Nvidia': 'NVDA',
    'Intel': 'INTC',
}

def stockDataTransformer(filepath):
    """
    Load a stock CSV and pack its Open/Close prices into rows of 10 values.
    Arguments:
        filepath: Path of a CSV file with 'Date', 'Open' and 'Close' columns.
    Returns:
        2-D NumPy array with 10 columns. The Open/Close pairs are flattened
        in row order, so each output row holds 5 consecutive CSV rows;
        trailing CSV rows that do not fill a group of 10 are dropped.
    """
    frame = pd.read_csv(filepath).set_index('Date')
    prices = frame[['Open', 'Close']].to_numpy()
    # Keep only the largest multiple of 10 CSV rows.
    usable_rows = (prices.shape[0] // 10) * 10
    n_rows_out = (usable_rows * prices.shape[1]) // 10
    return prices[:usable_rows].reshape((n_rows_out, 10))

def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):
    """
    Turn a time series into a sliding-window (supervised learning) table.
    Arguments:
        data: Observations as a list (univariate) or a 2-D NumPy array
            (one column per variable).
        n_in: How many past steps (t-n_in ... t-1) to place in each row.
        n_out: How many steps starting at t (t ... t+n_out-1) to place in each row.
        dropnan: When True, rows made incomplete by the shifting are removed.
    Returns:
        Pandas DataFrame whose columns are named 'varK(t-i)', 'varK(t)' and
        'varK(t+i)', ordered oldest step first and, within a step, by variable.
    """
    n_vars = 1 if type(data) is list else data.shape[1]
    frame = pd.DataFrame(data)
    var_ids = range(1, n_vars + 1)
    shifted, labels = [], []

    # Past steps, oldest first: t-n_in, ..., t-1.
    for back in range(n_in, 0, -1):
        shifted.append(frame.shift(back))
        labels.extend('var%d(t-%d)' % (v, back) for v in var_ids)

    # Current and future steps: t, t+1, ..., t+n_out-1.
    for ahead in range(n_out):
        shifted.append(frame.shift(-ahead))
        if ahead:
            labels.extend('var%d(t+%d)' % (v, ahead) for v in var_ids)
        else:
            labels.extend('var%d(t)' % v for v in var_ids)

    table = pd.concat(shifted, axis=1)
    table.columns = labels
    return table.dropna() if dropnan else table

def get_median(array, axis = 1):
    """
    Median of `array` along `axis`, returned as a single-column 2-D array.
    Arguments:
        array: NumPy array (or array-like) of numeric values.
        axis: Axis along which the median is taken (default 1: one median per row).
    Returns:
        NumPy array of shape (n, 1) where n is the number of medians computed.
    """
    # https://numpy.org/doc/stable/reference/generated/numpy.median.html
    # reshape(-1, 1) sizes the column from the result itself, so the function
    # no longer depends on a module-level `data_size` matching the input.
    return np.median(array, axis=axis).reshape(-1, 1)

def split_data(perc_train, perc_valid, lag, data_orig, data_m1, n_features_orig,
               n_features_median, forecast=None):
    """
    Split the windowed feature table and the windowed median table into
    train/test inputs and targets.
    Arguments:
        perc_train: Fraction (0..1) of rows used for training; the rest is test.
        perc_valid: Unused; kept so existing callers keep working.
        lag: Number of past steps in each window.
        data_orig: 2-D array of windowed input features, columns ordered step
            by step with all variables of a step adjacent (the layout
            produced by series_to_supervised).
        data_m1: 2-D array of windowed targets in the same layout.
        n_features_orig: Number of variables per step in data_orig.
        n_features_median: Number of variables per step in data_m1.
        forecast: Number of future steps in each target window. When None it
            is inferred from the width of data_m1.
    Returns:
        dict with the raw splits, 'train_X'/'test_X' shaped
        (samples, n_features_orig, lag), 'train_y'/'test_y', their
        reconstruction ('*_y_recon') and extrapolation ('*_y_extrapol')
        parts, and the derived sizes 'n_obs'/'n_obs_median'.
    """
    if forecast is None:
        # A target window is (lag + forecast) steps wide.
        forecast = data_m1.shape[1] // n_features_median - lag

    len_train = int(perc_train * len(data_m1))
    train_data_orig = data_orig[:len_train, :]
    test_data_orig = data_orig[len_train:, :]
    train_data_ml = data_m1[:len_train, :]
    test_data_ml = data_m1[len_train:, :]

    # Inputs use only the lagged steps; targets use lagged + forecast steps.
    n_obs = lag * n_features_orig
    n_obs_median = (lag + forecast) * n_features_median
    train_X, train_y = train_data_orig[:, :n_obs], train_data_ml[:, :n_obs_median]
    test_X, test_y = test_data_orig[:, :n_obs], test_data_ml[:, :n_obs_median]
    print(train_X.shape, len(train_X), train_y.shape)

    def _to_features_by_lag(flat):
        # Columns are step-major (all variables of step 0, then step 1, ...),
        # so unfold to (samples, lag, features) first and then swap axes.
        # Reshaping straight to (samples, features, lag) would mix variables
        # and time steps.
        cube = flat.reshape((flat.shape[0], lag, n_features_orig))
        return np.ascontiguousarray(cube.transpose(0, 2, 1))

    # 3-D input: [samples, features, lag]
    train_X = _to_features_by_lag(train_X)
    test_X = _to_features_by_lag(test_X)
    print(train_X.shape, train_y.shape, test_X.shape, test_y.shape)

    # First `lag` steps are reconstructed, the remaining steps are extrapolated.
    n_recon = lag * n_features_median
    train_y_recon, train_y_extrapol = train_y[:, :n_recon], train_y[:, n_recon:]
    test_y_recon, test_y_extrapol = test_y[:, :n_recon], test_y[:, n_recon:]

    return {
        'train_data_orig': train_data_orig,
        'test_data_orig': test_data_orig,
        'train_data_ml': train_data_ml,
        'test_data_ml': test_data_ml,
        'train_X': train_X,
        'train_y': train_y,
        'test_X': test_X,
        'test_y': test_y,
        'n_features_orig': n_features_orig,
        'n_features_median': n_features_median,
        'n_obs': n_obs,
        'n_obs_median': n_obs_median,
        'train_y_recon': train_y_recon,
        'train_y_extrapol': train_y_extrapol,
        'test_y_recon': test_y_recon,
        'test_y_extrapol': test_y_extrapol,
    }

# https://discuss.pytorch.org/t/rmse-loss-function/16540/3
class RMSELoss(nn.Module):
    """Root-mean-squared-error loss: sqrt(MSE(yhat, y) + eps).

    `eps` is added under the square root so the value (and its gradient)
    stays finite when the mean squared error is exactly zero.
    """

    def __init__(self, eps=1e-6):
        super(RMSELoss, self).__init__()
        self.eps = eps
        self.mse = nn.MSELoss()

    def forward(self, yhat, y):
        mean_sq_err = self.mse(yhat, y)
        return (mean_sq_err + self.eps).sqrt()
class LatentODEfunc(nn.Module):
    """Three-layer MLP giving the time derivative of the latent state.

    Arguments:
        latent_dim: Size of the latent state (input and output width).
        nhidden: Width of the two hidden layers.

    `nfe` counts how many times `forward` has been called (number of
    function evaluations requested by the ODE solver).
    """

    def __init__(self, latent_dim=4, nhidden=20):
        super(LatentODEfunc, self).__init__()
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        self.nfe = 0

    def forward(self, t, x):
        """Return dz/dt for latent state `x`; `t` is accepted but not used."""
        self.nfe += 1
        # Outputs stay on whatever device the module and `x` already share;
        # no per-call transfer to a module-level `device` is needed.
        out = self.elu(self.fc1(x))
        out = self.elu(self.fc2(out))
        return self.fc3(out)

class RecognitionRNN(nn.Module):
    """Single-step recurrent cell used as the recognition (encoder) network.

    Arguments:
        latent_dim: Latent size; the output has 2 * latent_dim values
            (callers split it into a mean half and a log-variance half).
        obs_dim: Number of observed values fed in per step.
        nhidden: Size of the recurrent hidden state.
        nbatch: Batch size used by `initHidden`.
    """

    def __init__(self, latent_dim=4, obs_dim=5, nhidden=25, nbatch=1):
        super(RecognitionRNN, self).__init__()
        self.nhidden = nhidden
        self.nbatch = nbatch
        self.i2h = nn.Linear(obs_dim + nhidden, nhidden).float()
        self.h2o = nn.Linear(nhidden, latent_dim * 2).float()

    def forward(self, x, h):
        """Consume one observation `x` with hidden state `h`.

        Returns:
            (out, h): the 2 * latent_dim output and the updated hidden state.
        """
        # Inputs may arrive as float64 (e.g. built from NumPy), so cast
        # before the float32 linear layers. Tensors stay on the device the
        # caller placed them on.
        combined = torch.cat((x, h), dim=1)
        h = torch.tanh(self.i2h(combined.float()))
        out = self.h2o(h)
        return out, h

    def initHidden(self):
        """Return an all-zero hidden state of shape (nbatch, nhidden)."""
        return torch.zeros(self.nbatch, self.nhidden)

class Decoder(nn.Module):
    """Two-layer MLP mapping a latent state to observation space.

    Arguments:
        latent_dim: Size of the latent state.
        obs_dim: Number of values produced per latent state.
        nhidden: Width of the hidden layer.
    """

    def __init__(self, latent_dim=4, obs_dim=5, nhidden=20):
        super(Decoder, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, obs_dim)

    def forward(self, z):
        """Decode `z` (last dimension latent_dim) to obs_dim values."""
        # The result lives on the same device as `z` and the module's
        # parameters, so no explicit device transfer is required.
        out = self.relu(self.fc1(z))
        return self.fc2(out)

class RunningAverageMeter(object):
    """Tracks the latest value and an exponential moving average of it."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history: no last value, average back to 0."""
        self.val, self.avg = None, 0

    def update(self, val):
        """Record `val`; the first value after a reset seeds the average."""
        is_first = self.val is None
        if is_first:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def log_normal_pdf(x, mean, logvar):
    """
    Element-wise log-density of a diagonal Gaussian.
    Arguments:
        x: Tensor of observed values.
        mean: Tensor of Gaussian means (broadcastable against x).
        logvar: Tensor of log-variances (broadcastable against x).
    Returns:
        Tensor of log N(x; mean, exp(logvar)) values, on the device of `x`.
    """
    # Build the log(2*pi) constant directly on x's device instead of going
    # through NumPy and a module-level `device`.
    const = torch.log(torch.tensor([2. * np.pi], dtype=torch.float32, device=x.device))
    return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar))

def normal_kl(mu1, lv1, mu2, lv2):
    """
    Element-wise KL divergence KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))).
    Arguments:
        mu1, lv1: Mean and log-variance of the first Gaussian.
        mu2, lv2: Mean and log-variance of the second Gaussian.
    Returns:
        Tensor of per-element KL values, on the device of the inputs.
    """
    v1 = torch.exp(lv1)
    v2 = torch.exp(lv2)
    # Log standard deviations are half the log-variances.
    lstd1 = lv1 / 2.
    lstd2 = lv2 / 2.

    # The result inherits the inputs' device; no module-level `device` needed.
    return lstd2 - lstd1 + ((v1 + (mu1 - mu2) ** 2.) / (2. * v2)) - .5

def train(loss_str, niters):
    """
    Train the recognition RNN, latent ODE function and decoder.

    Uses module-level state that must exist before the call: `optimizer`,
    `rec`, `func`, `dec`, `train_X`, `train_y`, `t`, `latent_dim`, `lag`,
    `forecast`, `noise_std` and `device`.
    Arguments:
        loss_str: 'mse' for mean squared error against the targets, or
            'elbo' for negative log-likelihood plus KL to a standard normal.
        niters: Number of optimisation steps (full-batch).
    Returns:
        List with one detached CPU loss tensor per iteration.
    Raises:
        ValueError: If loss_str is not 'mse' or 'elbo'.
    """
    if loss_str not in ('mse', 'elbo'):
        # Fail before any work instead of hitting an unbound `loss` later.
        raise ValueError(f"Unsupported loss_str {loss_str!r}; expected 'mse' or 'elbo'")

    loss_list = []
    times = t.to(device)
    # Targets are stored as (samples, 1, lag+forecast); drop the middle axis
    # so they line up with pred_x.
    target = train_y[:, 0, :].to(device)

    for itr in range(1, niters + 1):
        optimizer.zero_grad()

        # Run the recognition RNN backward in time to infer q(z_0).
        h = rec.initHidden().to(device)
        for t_r in reversed(range(train_X.shape[2])):
            obs = train_X[:, :, t_r].to(device)
            out, h = rec(obs, h)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        # Reparameterisation trick: z0 = mean + std * eps.
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

        # Solve the ODE forward in time and decode every latent state.
        pred_z = odeint(func, z0, times).permute(1, 0, 2)
        pred_x = dec(pred_z)
        pred_x = torch.reshape(pred_x, (train_X.shape[0], lag + forecast))

        if loss_str == 'mse':
            # Compare against the same targets the ELBO branch uses; the
            # previous `train_X[i, :, :]` referenced an undefined index and a
            # tensor whose shape does not match pred_x.
            loss = torch.nn.MSELoss()(pred_x.float(), target.float())
        else:
            noise_std_ = torch.zeros(pred_x.size()).to(device) + noise_std
            noise_logvar = 2. * torch.log(noise_std_)
            logpx = log_normal_pdf(target, pred_x, noise_logvar).sum(-1)
            # Prior p(z_0) is a standard normal: zero mean, zero log-variance.
            pz0_mean = pz0_logvar = torch.zeros(z0.size()).to(device)
            analytic_kl = normal_kl(qz0_mean, qz0_logvar,
                                    pz0_mean, pz0_logvar).sum(-1)
            loss = torch.mean(-logpx + analytic_kl, dim=0)

        loss.backward()
        optimizer.step()
        # Store a detached CPU copy so the history does not keep every
        # iteration's autograd graph alive and can be plotted directly.
        loss_list.append(loss.detach().cpu())
        if itr % 10 == 0:
            print('Iter: {}, running: {:.4f}'.format(itr, loss.item()))
    return loss_list

def train_loss(h):
    """
    Evaluate the trained model on the training set and print RMSE values.

    Uses module-level state: `rec`, `func`, `dec`, `train_X`, `t`,
    `latent_dim`, `lag`, `forecast`, `train_y_recon`, `train_y_extrapol`
    and `device`.
    Arguments:
        h: Initial hidden state for the recognition RNN.
    Returns:
        (pred_x, pred_x_recon, pred_x_extrapol): all predictions shaped
        (samples, 1, lag+forecast), and their first `lag` / remaining steps.
        The tensors carry no autograd history.
    """
    # Pure evaluation: skip graph construction so the returned predictions
    # do not require grad and can be converted to NumPy by the plot helpers.
    with torch.no_grad():
        for t_r in reversed(range(train_X.shape[2])):
            obs = train_X[:, :, t_r].to(device)
            out, h = rec(obs, h)

        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

        # Solve the ODE forward in time and decode every latent state.
        pred_z = odeint(func, z0, t.to(device)).permute(1, 0, 2)
        pred_x = dec(pred_z)
        pred_x = torch.reshape(pred_x, (pred_x.shape[0], 1, lag + forecast))
        pred_x_recon = pred_x[:, :, :lag]
        pred_x_extrapol = pred_x[:, :, lag:]

        rmse = RMSELoss()
        loss_recon = rmse(pred_x_recon, train_y_recon)
        loss_extrapol = rmse(pred_x_extrapol, train_y_extrapol)

    print('Train: Reconstruction Loss')
    print('Total Train Loss {:.6f}'.format(loss_recon.item()))
    print('Train: Extrapolation Loss')
    print('Total Train Extrapolation Loss {:.6f}'.format(loss_extrapol.item()))
    return pred_x, pred_x_recon, pred_x_extrapol

def test_loss(h, t_test):
    """
    Evaluate the trained model on the test set and print RMSE values.

    Uses module-level state: `rec`, `func`, `dec`, `test_X`, `latent_dim`,
    `lag`, `forecast`, `test_y_recon`, `test_y_extrapol` and `device`.
    Arguments:
        h: Initial hidden state for the recognition RNN.
        t_test: 1-D tensor of time points at which the ODE is solved.
    Returns:
        (pred_x, pred_x_recon, pred_x_extrapol): all predictions shaped
        (samples, 1, lag+forecast), and their first `lag` / remaining steps.
        The tensors carry no autograd history.
    """
    # Pure evaluation: skip graph construction so the returned predictions
    # do not require grad and can be converted to NumPy by the plot helpers.
    with torch.no_grad():
        for t_r in reversed(range(test_X.shape[2])):
            obs = test_X[:, :, t_r].to(device)
            out, h = rec(obs, h)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

        # Move the time grid to the model's device before solving, matching
        # what the training-set evaluation does.
        pred_z = odeint(func, z0, t_test.to(device)).permute(1, 0, 2)
        pred_x = dec(pred_z)
        pred_x = torch.reshape(pred_x, (pred_x.shape[0], 1, lag + forecast))
        pred_x_recon = pred_x[:, :, :lag]
        pred_x_extrapol = pred_x[:, :, lag:]

        rmse = RMSELoss()
        loss_recon = rmse(pred_x_recon, test_y_recon)
        loss_extrapol = rmse(pred_x_extrapol, test_y_extrapol)

    print('Test: Reconstruction Loss')
    print('Total Loss {:.6f}'.format(loss_recon.item()))
    print('Test: Extrapolation Loss')
    print('Total Loss {:.6f}'.format(loss_extrapol.item()))
    return pred_x, pred_x_recon, pred_x_extrapol

def _plot_week_series(week, truth, pred, col, split, kind, rmse_in_label=False, gap=" "):
    """Plot one week's true vs. predicted series, save it as PDF and show it.

    Arguments:
        week: Week number used in labels, title and file name.
        truth, pred: Tensors indexed as [:, :, col].
        col: Index into the last axis selecting the week to plot.
        split: 'Train' or 'Test' (title/file name text).
        kind: 'Recon' or 'Extrapol' (title/file name text; lower-cased in
            the prediction's legend label).
        rmse_in_label: Append the RMSE to the prediction's legend label.
        gap: Text between "'s" and "Median" in the title and file name.

    Reads the module-level `trial` and `tickerName`.
    """
    # detach() so tensors that still require grad can be converted to NumPy.
    truth_np = truth.detach().cpu().numpy()[:, :, col]
    pred_np = pred.detach().cpu().numpy()[:, :, col]
    rmse = np.sqrt(((truth_np - pred_np) ** 2).mean())
    t_ = torch.linspace(1., truth.shape[0], truth.shape[0])

    pred_label = f"{kind.lower()}_week{week}"
    if rmse_in_label:
        pred_label += f" : RMSE {rmse}"
    name = f"Trial No. {trial}: {split} {kind}: {tickerName}'s{gap}Median Stock price for week{week}"

    # savefig does not create missing directories.
    os.makedirs("plots-latentode", exist_ok=True)
    plt.figure()
    plt.plot(t_.numpy(), truth_np, 'g', label=f"orig_week{week}")
    plt.plot(t_.numpy(), pred_np, '--', label=pred_label)
    plt.title(f"{name}: RMSE {rmse}")
    plt.legend(framealpha=1, frameon=True)
    plt.savefig(f"plots-latentode/{name}.pdf", dpi=150)
    plt.show()


def plot_train_recon(i, train_y_recon, pred_y_recon, fig, axes):
    """Plot training reconstruction for week `i` (1-based). `fig`/`axes` are unused."""
    _plot_week_series(i, train_y_recon, pred_y_recon, i - 1, "Train", "Recon")


def plot_train_extrapol(i, train_y_extrapol, pred_y_extrapol, fig, axes):
    """Plot training extrapolation for week `i` (counted from 1 over lag+forecast
    weeks, offset by the module-level `lag`). `fig`/`axes` are unused."""
    _plot_week_series(i, train_y_extrapol, pred_y_extrapol, i - 1 - lag, "Train", "Extrapol")


def plot_test_recon(i, test_y_recon, pred_y_recon, fig, axes):
    """Plot test reconstruction for week `i` (1-based). `fig`/`axes` are unused."""
    _plot_week_series(i, test_y_recon, pred_y_recon, i - 1, "Test", "Recon",
                      rmse_in_label=True)


def plot_test_extrapol(i, test_y_extrapol, pred_y_extrapol, fig, axes):
    """Plot test extrapolation for week `i` (counted from 1 over lag+forecast
    weeks, offset by the module-level `lag`). `fig`/`axes` are unused."""
    # The double space keeps the existing title and output file name unchanged.
    _plot_week_series(i, test_y_extrapol, pred_y_extrapol, i - 1 - lag, "Test", "Extrapol",
                      rmse_in_label=True, gap="  ")

def plot_loss(loss_list):
    """
    Plot the training-loss history, save it as PDF and show it.

    Reads the module-level `tickerName` and `trial`.
    Arguments:
        loss_list: Per-iteration losses, as numbers or scalar tensors.
    """
    # float() accepts plain numbers as well as scalar tensors (including
    # ones on the GPU or still attached to an autograd graph), which
    # matplotlib cannot convert by itself.
    values = [float(v) for v in loss_list]
    # savefig does not create missing directories.
    os.makedirs("plots-latentode", exist_ok=True)
    plt.plot(values)
    plt.title(f'{tickerName}: Train Loss (Recon+Extrapol)')
    plt.ylabel('ELBO Loss')
    plt.xlabel('Epochs')
    plt.savefig(f"plots-latentode/{tickerName}: Trial No. {trial}: ELBO Train Loss.pdf", dpi = 150)
    plt.show()

def rmse_table(recon_rmse_filepath, extrapol_rmse_filepath, train_y_recon, pred_train_recon, train_y_extrapol, pred_train_extrapol, test_y_recon, pred_test_recon, test_y_extrapol, pred_test_extrapol):
    """
    Compute per-week train/test RMSE and write them to two CSV files.
    Arguments:
        recon_rmse_filepath: Output CSV path for the reconstruction table.
        extrapol_rmse_filepath: Output CSV path for the extrapolation table.
        train_y_recon, pred_train_recon: True/predicted tensors indexed as
            [:, :, week] over the reconstructed weeks (training set).
        train_y_extrapol, pred_train_extrapol: Same for the extrapolated weeks.
        test_y_recon, pred_test_recon, test_y_extrapol, pred_test_extrapol:
            The test-set counterparts.
    Returns:
        (recon_rmse_data, extrapol_rmse_data): dicts of lists with keys
        'week' plus the train/test RMSE columns. Reconstruction weeks are
        numbered from 1; extrapolation weeks continue after them.
    """
    def _weekly_rmse(truth, pred, n_weeks):
        # detach() so tensors that still require grad convert to NumPy.
        truth_np = truth.detach().cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        return [np.sqrt(((truth_np[:, :, k] - pred_np[:, :, k]) ** 2).mean())
                for k in range(n_weeks)]

    # Week counts come from the targets' last axis rather than from
    # module-level `lag` / `forecast` values.
    n_recon = train_y_recon.shape[2]
    n_extrapol = train_y_extrapol.shape[2]

    recon_rmse_data = {
        'week': list(range(1, n_recon + 1)),
        'train_recon_rmse': _weekly_rmse(train_y_recon, pred_train_recon, n_recon),
        'test_recon_rmse': _weekly_rmse(test_y_recon, pred_test_recon, n_recon),
    }
    extrapol_rmse_data = {
        'week': list(range(n_recon + 1, n_recon + n_extrapol + 1)),
        'train_extrapol_rmse': _weekly_rmse(train_y_extrapol, pred_train_extrapol, n_extrapol),
        'test_extrapol_rmse': _weekly_rmse(test_y_extrapol, pred_test_extrapol, n_extrapol),
    }

    pd.DataFrame(recon_rmse_data).to_csv(recon_rmse_filepath)
    pd.DataFrame(extrapol_rmse_data).to_csv(extrapol_rmse_filepath)

    return recon_rmse_data, extrapol_rmse_data


In [None]:

# Driver cell: load one ticker's prices, build the windowed train/test
# tensors and define the hyper-parameters read as globals by the functions
# defined earlier.

# Name of the company whose CSV is used (also appears in plot titles)
tickerName = 'ExxonMobil'
# Location of the raw price CSV for that company
filepath = f"raw-stock-data/{tickerName}.csv"
# Pack Open/Close prices into rows of 10 values (5 CSV rows per output row;
# presumably one trading week per row -- confirm against the raw data)
data = stockDataTransformer(filepath)
print('Data after Data Transformation')
# Total number of rows in the transformed data
data_size = data.shape[0]

print(pd.DataFrame(data))
print('\n')

# Each window has `lag` past rows and `forecast` rows starting at the current one
lag = 7
forecast = 3
data_orig = series_to_supervised(data, lag, forecast).values
print('Data Original after series to supervised on data')
print(data_orig.shape)
print(pd.DataFrame(data_orig))
print('\n')

# One median per row of `data`, as a single column
median_data = get_median(data)
print('Median data')
# Median data for each week
print(median_data.shape)
print(pd.DataFrame(median_data, columns = ['median_stockprice_week']).head(10))
print('\n')

# Window the medians into an (n_samples, lag+forecast) matrix
data_m1 = series_to_supervised(median_data, lag, forecast).values
print('Median data after series to supervised')
print(data_m1.shape)
print(pd.DataFrame(data_m1, columns = [f"week i+{i}" for i in range(1, lag+forecast+1)]))
print('\n')

# 80% of the windows for training, the rest for testing; perc_valid (0) is
# not used by split_data
dataload = split_data(0.8, 0, lag, data_orig, data_m1, data.shape[1], 1)

print('Get Train and Test data')
train_X = torch.from_numpy(dataload['train_X']).to(device)
print(f"train_X shape: {train_X.shape}")  # (n_train, n_features_orig, lag)

train_y = torch.from_numpy(dataload['train_y']).to(device)
print(f"train_y shape: {train_y.shape}")  # (n_train, lag+forecast)
# Insert a singleton middle axis so targets are (n_train, 1, lag+forecast)
train_y = torch.reshape(train_y, (train_X.shape[0], 1, train_y.shape[1])).to(device)
print(f"train_y shape: {train_y.shape}")  # (n_train, 1, lag+forecast)
test_X = torch.from_numpy(dataload['test_X']).to(device)
print(f"test_X.shape : {test_X.shape}")  # (n_test, n_features_orig, lag)
test_y = torch.from_numpy(dataload['test_y']).to(device)
print(f"test_y.shape : {test_y.shape}")  # (n_test, lag+forecast)
test_y = torch.reshape(test_y, (test_X.shape[0], 1, test_y.shape[1])).to(device)
print(f"test_y.shape : {test_y.shape}")

# Reconstruction (first `lag` steps) and extrapolation (remaining steps)
# targets, each reshaped to (n_samples, 1, n_steps)
train_y_recon = torch.from_numpy(dataload['train_y_recon']).to(device)
train_y_recon = torch.reshape(train_y_recon, (train_y_recon.shape[0], 1,train_y_recon.shape[1]))
train_y_extrapol = torch.from_numpy(dataload['train_y_extrapol']).to(device)
train_y_extrapol = torch.reshape(train_y_extrapol, (train_y_extrapol.shape[0], 1, train_y_extrapol.shape[1]))
test_y_recon = torch.from_numpy(dataload['test_y_recon']).to(device)
test_y_recon = torch.reshape(test_y_recon, (test_y_recon.shape[0], 1, test_y_recon.shape[1]))
test_y_extrapol = torch.from_numpy(dataload['test_y_extrapol']).to(device)
test_y_extrapol = torch.reshape(test_y_extrapol, (test_y_extrapol.shape[0], 1, test_y_extrapol.shape[1]))
print(f"train_y_recon.shape : {train_y_recon.shape}")
print(f"train_y_extrapol.shape : {train_y_extrapol.shape}")
print(f"test_y_recon.shape : {test_y_recon.shape}")
print(f"test_y_extrapol.shape : {test_y_extrapol.shape}")
print('\n')

# Experiment number; appears in plot titles and output file names
trial = 2
# Time grid for the ODE solver: lag+forecast evenly spaced points
# 0, 1, ..., lag+forecast-1
t = torch.linspace(0, lag+forecast-1, lag+forecast).to(device)
print(f"t.shape : {t.shape}")

# Model sizes and noise level.
# NOTE(review): nhidden, rnn_nhidden, obs_dim and out_dim are presumably
# passed to the model constructors in a later cell -- confirm.
latent_dim = 50
nhidden = 100
rnn_nhidden = 125
obs_dim = data.shape[1]
out_dim = 1
noise_std = .3

# Optimiser learning rate and loss selection ('elbo' or 'mse', see train())
lr = 0.001
loss_str = 'elbo'
# all_values = True
niters = 2000  # number of training iterations