Fix a possible bug in the encoder-decoder reshaping, and check results on different test cases (with validation and test runs in both the interpolation and the extrapolation regimes).

Future:
- Check extrapolation (BStrength higher than in the training set)

- Loss function

- LR scheduler

- Latent space visualization

In [1]:
import os

# The notebook's working directory; the relative paths used below resolve against it.
cwd = os.getcwd()
cwd

'/home/ajivani/WLROM_new/WLROM_new/WhiteLight'

In [2]:
import os
import sys

# Make the project helper modules (node_utils, data_utils) importable.
# The path is resolved to an absolute one so a later os.chdir cannot change
# what it points to, and it is only appended once even if this cell re-runs.
# NOTE(review): the working directory printed above already ends in
# ".../WhiteLight", so "./WhiteLight/" resolves to ".../WhiteLight/WhiteLight/";
# that may be why `import node_utils` fails below -- confirm where it lives.
_module_dir = os.path.abspath("./WhiteLight/")
if not os.path.isdir(_module_dir):
    print("Warning: module directory {} does not exist".format(_module_dir))
if _module_dir not in sys.path:
    sys.path.append(_module_dir)

In [3]:
import os
import sys
import matplotlib
import matplotlib.pyplot
import matplotlib.pyplot as plt

import time
import datetime
import argparse
import numpy as np
import pandas as pd
from random import SystemRandom
import logging

import torch
import torch.nn as nn
from torch.nn.functional import relu
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset

from rich.progress import track
import tqdm

adjoint=True
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

from sunpy.visualization import colormaps as cm

import torchvision
import torchvision.transforms as T

import node_utils as nut
import data_utils as dut

ModuleNotFoundError: No module named 'node_utils'

In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Reload imported modules (e.g. the edited node_utils / data_utils helpers)
# automatically before each cell is executed (IPython autoreload mode 2).
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-wide figure style: keep the right/top spines, high-resolution wide
# figures, serif fonts, and legends drawn with a frame but no edge colour.
plt.rcParams.update({
    "axes.spines.right": True,
    "axes.spines.top": True,
    "figure.dpi": 300,
    "figure.figsize": (9, 3),
    "font.family": "serif",
    "legend.edgecolor": "none",
    "legend.frameon": True,
})

In [None]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Experiment configuration, kept as a plain dict (argparse-style keys).
args = {
    "niters": 4000,            # number of training iterations
    "lr": 1e-2,                # initial Adamax learning rate
    "save": "experiments/",    # directory where checkpoints are written
    "load": None,              # experiment ID to load instead of training (None = train)
    # "load": 68818,           # 8 latent dims!
    "r": 1991,                 # not read directly in this notebook; may be used by helpers receiving args
    "node-layers": 2,          # passed to PNODE_Conv as n_layers
    "mlp-layers": 1,           # not read directly in this notebook; may be used by helpers receiving args
    # "latents": 8,
    "latents": 20,             # latent dimension (experiment 44467 uses 20; earlier experiments used 8)
    "units": 50,               # passed to PNODE_Conv as n_units
    # Encodes how the simulations are split into training vs. validation/test
    # sets; interpreted by data_utils.parse_datasets. Values seen so far:
    # "t2p1", "t4p1", "t5p4". TODO: document the accepted values.
    "test_mode": "t5p4",
    "resize_dims": (32, 128),  # presumably (height, width) of each frame after resizing
    "param_scaling": [1, 1],   # per-parameter scale factors (BStrength = params[:, 0] * param_scaling[0])
    "warm_start": 38931,       # experiment ID to resume when do_warm_start is True
    "do_warm_start": False,
}

args

In [None]:
validation_dir = "./validation_data"

# Tensor data, simulation IDs and parameter table, all under validation_dir.
validation_file, sim_file, param_file = (
    os.path.join(validation_dir, fname)
    for fname in (
        "CR2161_validation_PolarTensor.npy",
        "CR2161_SimID4edge_validation.npy",
        "params_2161_validation.txt",
    )
)

In [None]:
# Parse the data files into the object holding the train/val/test dataloaders
# and normalization statistics used throughout the notebook.
data_obj = dut.parse_datasets(
    validation_file,
    sim_file,
    param_file,
    args,
    device,
)

In [None]:
# Display the parsed data object (dataloaders and min/max statistics used below).
data_obj

In [None]:
# Build the model, report its size, choose the experiment ID / checkpoint path
# and set up logging.
param_dim = 2  # number of simulation parameters fed to the PNODE
input_dim = int(data_obj["input_dim"])  # flattened pixels per frame

# Frames are flattened from args["resize_dims"]; fail early with a clear message
# if the data's input_dim disagrees, rather than deep inside a reshape.
_img_h, _img_w = args["resize_dims"]
if input_dim != _img_h * _img_w:
    raise ValueError(
        "input_dim ({}) does not match resize_dims {}".format(input_dim, args["resize_dims"])
    )

model = nut.PNODE_Conv(input_dim=input_dim, latent_dim=args["latents"], param_dim=param_dim,
                       device=device,
                       n_layers=args["node-layers"],
                       n_units=args["units"]).to(device)

pEncoder = nut.count_parameters(model.encoder)
pPNODE = nut.count_parameters(model.pnode)
pDecoder = nut.count_parameters(model.decoder_mlp) + nut.count_parameters(model.decoder_conv)
n_params = pEncoder + pPNODE + pDecoder

print("Total Number of Parameters: ")
print(n_params)

logdir = os.path.join(os.getcwd(), "logs")
nut.makedirs(logdir)
nut.makedirs(args["save"])

# Experiment ID: an explicit load takes precedence, then a warm start,
# otherwise a fresh random ID.
if args["load"] is not None:
    experimentID = args["load"]
    print("Loading experiment")
elif args["do_warm_start"]:
    experimentID = args["warm_start"]
    print("Warm starting previous experiment")
else:
    experimentID = int(SystemRandom().random() * 100000)
    print("Starting new experiment")

print("Experiment ID: ")
print(experimentID)

ckpt_path = os.path.join(args["save"], "experiment_" + str(experimentID) + '.ckpt')

print("Checkpoint Path: ")
print(ckpt_path)

logger = dut.get_logger(logpath=os.path.join(logdir, "expt_normalized_pnode_images_2.log"))

logger.info(os.getcwd())
logger.info(model)
logger.info("Number of parameters: {}".format(n_params))
logger.info(args)
logger.info("Checkpoint Path")
logger.info(ckpt_path)
logger.info("Input Dim: ")
logger.info(data_obj["input_dim"])
logger.info("Param Dim: ")
logger.info(param_dim)
logger.info("Latent Dim: ")
logger.info(args["latents"])
logger.info(device)

In [None]:
# Display the experiment ID chosen above (it is part of the checkpoint file name).
experimentID

In [None]:
# cuDNN: make sure it is used, and let its auto-tuner pick the fastest
# convolution algorithms.
cudnn_backend = torch.backends.cudnn
cudnn_backend.enabled = True
cudnn_backend.benchmark = True
# torch.multiprocessing.set_start_method('spawn')

In [None]:
# TODO: save the optimizer state and the epoch, so everything can be reloaded to warm start after the kernel is interrupted.
# Also set ckpt_freq so that the model is not saved at every reduction in validation loss.

In [None]:
# Credit to Hongfan Chen for the stopping and LR scheduling pipeline.
#
# Trains the model (or resumes a warm start), logging train/val loss each
# iteration, stepping the LR scheduler on the validation loss and checking
# early stopping every `ckpt_freq` iterations. If args["load"] is set, the
# checkpoint is loaded instead of training.

optimizer = optim.Adamax(model.parameters(), lr=args["lr"])
val_loss_list = []
train_loss_list = []
n_iters_to_viz = 1
num_batches = 1
early_stopping = nut.EarlyStopping(patience=60,
                                   verbose=True,
                                   path=ckpt_path)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.6, patience=40)
ckpt_freq = 10  # only consult early stopping (which checkpoints) every ckpt_freq iterations
trainMax = data_obj["max_train"].item()
trainMin = data_obj["min_train"].item()
train_range = trainMax - trainMin


def _denormalize(x):
    """Undo the min-max normalization of the training data."""
    return x * train_range + trainMin


def _predict(data_dict):
    """Run the model on one batch; return (prediction, target), both denormalized."""
    raw = model(data_dict["tp_to_predict"], data_dict["observed_data"])
    # NOTE(review): squeeze drops *every* size-1 dim, so a split with a single
    # simulation (or a single time point) would make this permute fail -- TODO confirm.
    pred = _denormalize(torch.permute(torch.squeeze(raw), (1, 0, 2)))
    target = _denormalize(data_dict["data_to_predict"])
    return pred, target


if args["load"] is None:
    if args["do_warm_start"]:
        # map_location lets a checkpoint written on a GPU be resumed on a CPU-only machine.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # NOTE(review): scheduler and early-stopping state are not restored on warm start.
        print("Warm Starting Experiment {} from Epoch {}".format(experimentID, start_epoch))
    else:
        start_epoch = 0
        print("Starting a new experiment {} from Epoch {}".format(experimentID, start_epoch))

    # Resume training from a specific epoch
    for itr in track(range(start_epoch + 1, num_batches * (args["niters"] + 1)), description="Looping over Epochs..."):
        optimizer.zero_grad()
        batch_dict = dut.get_next_batch(data_obj["train_dataloader"], device=device)
        model_pred, target = _predict(batch_dict)

        loss = nut.oden_loss(target, model_pred, loss='smooth_l1')
        loss.backward()
        optimizer.step()

        train_loss_list.append(loss.item())

        # Validation is monitoring only: don't build (and hold) an autograd graph
        # through the encoder, ODE solve and decoder every iteration.
        with torch.no_grad():
            val_dict = dut.get_next_batch(data_obj["val_dataloader"], device=device)
            val_pred, val_target = _predict(val_dict)
            val_loss = nut.oden_loss(val_target, val_pred, loss='smooth_l1')

        val_loss_list.append(val_loss.item())

        message = 'Epoch {:04d}|Val loss {:.6e}|Train loss {:.6e}|LR {:.5f}|'.format(
                        itr // num_batches,
                        val_loss.item(),
                        loss.item(),
                        optimizer.param_groups[0]["lr"])

        logger.info(message)
        scheduler.step(val_loss)
        if itr % ckpt_freq == 0:
            early_stopping(val_loss, model, optimizer, itr, args)

        if early_stopping.early_stop:
            print('Early stopping')
            torch.save(model.state_dict(), os.path.join(args["save"], "experiment_" + str(experimentID) + '_early_stopping.ckpt'))
            break

else:
    nut.get_ckpt_model(ckpt_path, model, device)

In [None]:
# torch.save(model.state_dict(), ckpt_path)
# torch.save(model.state_dict(), os.path.join(args["save"], "experiment_" + str(experimentID) + '_early_stopping.ckpt'))

### Plotting train and val sims

In [None]:
# BStrength (first parameter column) of every simulation in each split,
# rescaled by the first entry of args["param_scaling"].
_bs_scale = args["param_scaling"][0]
BS_train, BS_val, BS_test = (
    data_obj[split_key][:, 0] * _bs_scale
    for split_key in ("train_params", "val_params", "test_params")
)

BS_train, BS_val, BS_test

In [None]:
# Show where the train / val / test simulations sit along the BStrength axis.
# The y-values are sized from each split instead of hard-coded counts (6/4/2),
# so the plot keeps working when args["test_mode"] changes the split sizes.
plt.scatter(BS_train, np.ones(len(BS_train)), s=12, c='red', marker='x', label="Train")
plt.scatter(BS_val, np.ones(len(BS_val)), s=12, c='blue', marker='o', label="Val")
plt.scatter(BS_test, np.ones(len(BS_test)), s=12, c='green', marker='d', label="Test")
plt.xlabel("BStrength")
plt.legend()

In [None]:
# Evaluate the trained model on one batch from each split, rescale predictions
# and targets back to the original data range, and reshape them into frames
# indexed as [sim, time, row, col].
batch_dict = dut.get_next_batch(data_obj["train_dataloader"], device=device)
val_dict = dut.get_next_batch(data_obj["val_dataloader"], device=device)
test_dict = dut.get_next_batch(data_obj["test_dataloader"], device=device)

_data_range = trainMax - trainMin
# Frame size; presumably (height, width), matching the former hard-coded (32, 128).
_img_h, _img_w = args["resize_dims"]

# Inference only: running the forward passes under no_grad avoids building and
# holding autograd graphs for three full ODE solves.
with torch.no_grad():
    raw_train_pred = model(batch_dict["tp_to_predict"], batch_dict["observed_data"])
    raw_val_pred = model(val_dict["tp_to_predict"], val_dict["observed_data"])
    raw_test_pred = model(test_dict["tp_to_predict"], test_dict["observed_data"])

    # Swap the first two axes so time becomes axis 1 (the reshape below relies on it).
    train_pred = torch.permute(torch.squeeze(raw_train_pred), (1, 0, 2)) * _data_range + trainMin
    val_pred = torch.permute(torch.squeeze(raw_val_pred), (1, 0, 2)) * _data_range + trainMin
    test_pred = torch.permute(torch.squeeze(raw_test_pred), (1, 0, 2)) * _data_range + trainMin

train_target = batch_dict["data_to_predict"] * _data_range + trainMin
val_target = val_dict["data_to_predict"] * _data_range + trainMin
test_target = test_dict["data_to_predict"] * _data_range + trainMin

print(train_target.shape, val_target.shape, test_target.shape)

train_pred_rs = train_pred.reshape((-1, len(batch_dict["tp_to_predict"]), _img_h, _img_w)).cpu()
val_pred_rs = val_pred.reshape((-1, len(val_dict["tp_to_predict"]), _img_h, _img_w)).cpu()
test_pred_rs = test_pred.reshape((-1, len(test_dict["tp_to_predict"]), _img_h, _img_w)).cpu()

train_target_rs = train_target.reshape((-1, len(batch_dict["tp_to_predict"]), _img_h, _img_w)).cpu()
val_target_rs = val_target.reshape((-1, len(val_dict["tp_to_predict"]), _img_h, _img_w)).cpu()
test_target_rs = test_target.reshape((-1, len(test_dict["tp_to_predict"]), _img_h, _img_w)).cpu()

In [None]:
# Prediction time points of the training batch. The validation and test plots
# below also use these for their titles (assumes all splits share the same
# prediction times -- TODO confirm).
tpredict = batch_dict["tp_to_predict"]

In [None]:
from sunpy.visualization import colormaps as cm

# SOHO/LASCO C3 colormap, and its reversed version used for every image plot below.
lc3_reg = cm.cmlist['soholasco3']
lc3 = lc3_reg.reversed()
lc3

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 0

# Pointwise error (truth minus prediction) over all time steps of this simulation.
train_err_sim = train_target_rs[sim_idx] - train_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, train_pred_rs[sim_idx], trainMin, trainMax, "Train S1 Predicted"),
    (1, train_target_rs[sim_idx], trainMin, trainMax, "Train S1 True"),
    (2, train_err_sim, train_err_sim.min(), train_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Train Sim 1")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 1

# Pointwise error (truth minus prediction) over all time steps of this simulation.
train_err_sim = train_target_rs[sim_idx] - train_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, train_pred_rs[sim_idx], trainMin, trainMax, "Train S2 Predicted"),
    (1, train_target_rs[sim_idx], trainMin, trainMax, "Train S2 True"),
    (2, train_err_sim, train_err_sim.min(), train_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Train Sim 2")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 2

# Pointwise error (truth minus prediction) over all time steps of this simulation.
train_err_sim = train_target_rs[sim_idx] - train_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, train_pred_rs[sim_idx], trainMin, trainMax, "Train S3 Predicted"),
    (1, train_target_rs[sim_idx], trainMin, trainMax, "Train S3 True"),
    (2, train_err_sim, train_err_sim.min(), train_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Train Sim 3")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 3

# Pointwise error (truth minus prediction) over all time steps of this simulation.
train_err_sim = train_target_rs[sim_idx] - train_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, train_pred_rs[sim_idx], trainMin, trainMax, "Train S4 Predicted"),
    (1, train_target_rs[sim_idx], trainMin, trainMax, "Train S4 True"),
    (2, train_err_sim, train_err_sim.min(), train_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Train Sim 4")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 4

# Pointwise error (truth minus prediction) over all time steps of this simulation.
train_err_sim = train_target_rs[sim_idx] - train_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, train_pred_rs[sim_idx], trainMin, trainMax, "Train S5 Predicted"),
    (1, train_target_rs[sim_idx], trainMin, trainMax, "Train S5 True"),
    (2, train_err_sim, train_err_sim.min(), train_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Train Sim 5")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 5

# Pointwise error (truth minus prediction) over all time steps of this simulation.
train_err_sim = train_target_rs[sim_idx] - train_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, train_pred_rs[sim_idx], trainMin, trainMax, "Train S6 Predicted"),
    (1, train_target_rs[sim_idx], trainMin, trainMax, "Train S6 True"),
    (2, train_err_sim, train_err_sim.min(), train_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Train Sim 6")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 0

# Pointwise error (truth minus prediction) over all time steps of this simulation.
val_err_sim = val_target_rs[sim_idx] - val_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Titles now carry the simulation label ("S1"), consistent with the other
# validation cells (S2-S4).
panel_specs = [
    (0, val_pred_rs[sim_idx], trainMin, trainMax, "Val S1 Predicted"),
    (1, val_target_rs[sim_idx], trainMin, trainMax, "Val S1 True"),
    (2, val_err_sim, val_err_sim.min(), val_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Validation Sim 1")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 1

# Pointwise error (truth minus prediction) over all time steps of this simulation.
val_err_sim = val_target_rs[sim_idx] - val_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, val_pred_rs[sim_idx], trainMin, trainMax, "Val S2 Predicted"),
    (1, val_target_rs[sim_idx], trainMin, trainMax, "Val S2 True"),
    (2, val_err_sim, val_err_sim.min(), val_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Validation Sim 2")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 2

# Pointwise error (truth minus prediction) over all time steps of this simulation.
val_err_sim = val_target_rs[sim_idx] - val_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, val_pred_rs[sim_idx], trainMin, trainMax, "Val S3 Predicted"),
    (1, val_target_rs[sim_idx], trainMin, trainMax, "Val S3 True"),
    (2, val_err_sim, val_err_sim.min(), val_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Validation Sim 3")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 3

# Pointwise error (truth minus prediction) over all time steps of this simulation.
val_err_sim = val_target_rs[sim_idx] - val_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# Prediction and truth share the training data range; the error uses its own.
panel_specs = [
    (0, val_pred_rs[sim_idx], trainMin, trainMax, "Val S4 Predicted"),
    (1, val_target_rs[sim_idx], trainMin, trainMax, "Val S4 True"),
    (2, val_err_sim, val_err_sim.min(), val_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Validation Sim 4")
fig.tight_layout()

### Plot Test Predictions

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 0

# Pointwise error (truth minus prediction) over all time steps of this simulation.
test_err_sim = test_target_rs[sim_idx] - test_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# The error column previously drew val_err_sim (left over from the validation
# cells) with the test error's colour range; it now shows this test
# simulation's own error.
panel_specs = [
    (0, test_pred_rs[sim_idx], trainMin, trainMax, "Test 1 Predicted"),
    (1, test_target_rs[sim_idx], trainMin, trainMax, "Test 1 True"),
    (2, test_err_sim, test_err_sim.min(), test_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Test Sim 1")
fig.tight_layout()

In [None]:
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
sim_idx = 1

# Pointwise error (truth minus prediction) over all time steps of this simulation.
test_err_sim = test_target_rs[sim_idx] - test_pred_rs[sim_idx]

fig, axs = plt.subplots(nrows=7, ncols=3, figsize=(10, 14))

# One column per panel: (column, frames, colour min, colour max, title prefix).
# The error column previously drew val_err_sim (left over from the validation
# cells) with the test error's colour range; it now shows this test
# simulation's own error.
panel_specs = [
    (0, test_pred_rs[sim_idx], trainMin, trainMax, "Test 2 Predicted"),
    (1, test_target_rs[sim_idx], trainMin, trainMax, "Test 2 True"),
    (2, test_err_sim, test_err_sim.min(), test_err_sim.max(), "Err"),
]
for col, frames, lo, hi, prefix in panel_specs:
    for row, t_idx in enumerate(time_idx_to_plot):
        ax = axs[row, col]
        im = ax.imshow(frames[t_idx], origin="lower", cmap=lc3, vmin=lo, vmax=hi)
        plt.colorbar(im, fraction=0.046, pad=0.04, ax=ax)
        ax.set_title("{} t = {:.3f}".format(prefix, tpredict[t_idx]))

fig.suptitle("Scaled data Test Sim 2")
fig.tight_layout()

### Visualize Latent Space

In [None]:
# batch_dict = dut.get_next_batch(data_obj["train_dataloader"], device=device)
# val_dict = dut.get_next_batch(data_obj["val_dataloader"], device=device)
# test_dict = dut.get_next_batch(data_obj["test_dataloader"], device=device)


def getLatentStatesFromModel(batch_data, model, input_dim=None, latent_dim=None,
                             image_dims=None):
    """Encode each trajectory's initial frame and integrate it through the latent ODE.

    Parameters
    ----------
    batch_data : dict
        Batch dictionary providing ``"observed_data"`` (a 3-D tensor whose last
        axis holds the flattened image in its first ``input_dim`` entries,
        followed by extra channels -- presumably the simulation parameters) and
        ``"tp_to_predict"`` (the time points passed to the ODE solver).
    model : torch.nn.Module
        Trained model exposing ``encoder`` (applied to single-channel images)
        and ``pnode`` (the ODE right-hand side given to ``odeint``).
    input_dim : int, optional
        Number of flattened image features. Defaults to ``data_obj["input_dim"]``.
    latent_dim : int, optional
        Size of the encoder's latent vector. Defaults to ``args["latents"]``.
    image_dims : tuple of (int, int), optional
        (height, width) the encoder expects. Defaults to ``args["resize_dims"]``,
        falling back to (32, 128).

    Returns
    -------
    numpy.ndarray
        Latent trajectory for batch index 0 with the appended extra channels
        stripped: ``(n_times, ntraj, latent_dim)`` (torchdiffeq's ``odeint``
        puts the time axis first).
    """
    if input_dim is None:
        input_dim = data_obj["input_dim"]
    if latent_dim is None:
        latent_dim = args["latents"]
    if image_dims is None:
        image_dims = args.get("resize_dims", (32, 128))
    height, width = image_dims

    y = batch_data["observed_data"]
    t = batch_data["tp_to_predict"]

    # Pure inference: don't build an autograd graph through the encoder/ODE solve.
    with torch.no_grad():
        init_state = y[:, :, :input_dim]
        nbatch, ntraj, _ = init_state.shape
        # Fold (batch, traj) into one axis of single-channel images for the conv encoder.
        init_state = init_state.reshape((nbatch * ntraj, 1, height, width))
        init_latent = model.encoder(init_state).reshape((nbatch, ntraj, latent_dim))
        # Re-attach the non-image channels after the latent vector, as the model expects.
        init_latent = torch.cat((init_latent, y[:, :, input_dim:]), -1)
        latent_states = odeint(model.pnode, init_latent, t)

    return latent_states[:, 0, :, :latent_dim].detach().cpu().numpy()

raw_train_latent = getLatentStatesFromModel(batch_dict, model)
raw_val_latent = getLatentStatesFromModel(val_dict, model)
raw_test_latent = getLatentStatesFromModel(test_dict, model)

print(raw_train_latent.shape, raw_val_latent.shape, raw_test_latent.shape)


# raw_train_pred = model(batch_dict["tp_to_predict"], batch_dict["observed_data"])
# raw_val_pred = model(val_dict["tp_to_predict"], val_dict["observed_data"])
# raw_test_pred = model(test_dict["tp_to_predict"], test_dict["observed_data"])


# with torch.no_grad():
#     train_pred = torch.permute(torch.squeeze(raw_train_pred), (1, 0, 2)) * (trainMax - trainMin) + trainMin
#     val_pred = torch.permute(torch.squeeze(raw_val_pred), (1, 0, 2)) * (trainMax - trainMin) + trainMin
#     test_pred = torch.permute(torch.squeeze(raw_test_pred), (1, 0, 2)) * (trainMax - trainMin) + trainMin

# train_target = batch_dict["data_to_predict"] * (trainMax - trainMin) + trainMin
# val_target = val_dict["data_to_predict"] * (trainMax - trainMin) + trainMin
# test_target = test_dict["data_to_predict"] * (trainMax - trainMin) + trainMin

# print(train_target.shape, val_target.shape, test_target.shape)

# train_pred_rs = train_pred.reshape((-1, len(batch_dict["tp_to_predict"]), 32, 128)).cpu()
# val_pred_rs = val_pred.reshape((-1, len(val_dict["tp_to_predict"]), 32, 128)).cpu()
# test_pred_rs = test_pred.reshape((-1, len(test_dict["tp_to_predict"]), 32, 128)).cpu()

# train_target_rs = train_target.reshape((-1, len(batch_dict["tp_to_predict"]), 32, 128)).cpu()
# val_target_rs = val_target.reshape((-1, len(val_dict["tp_to_predict"]), 32, 128)).cpu()
# test_target_rs = test_target.reshape((-1, len(test_dict["tp_to_predict"]), 32, 128)).cpu()

In [None]:
# Train latent space: one panel per training simulation, one curve per
# selected time index, showing the value of each latent dimension.
# NOTE(review): raw_train_latent is indexed as (time, sim, latent) -- this
# follows getLatentStatesFromModel's return slicing; TODO confirm.
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
nt_to_plot = len(time_idx_to_plot)
colors = plt.cm.jet(np.linspace(0, 1, nt_to_plot))
latent_axis = np.arange(1, args["latents"] + 1)  # latent dims numbered 1..latents

fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(6, 8))

for i, ax in enumerate(axs.flat):
    if i >= raw_train_latent.shape[1]:
        # Fewer simulations than panels: hide the spare axes instead of failing.
        ax.axis("off")
        continue
    for j, t_idx in enumerate(time_idx_to_plot):
        # Index by t_idx (not the loop counter j) so each curve matches its time label.
        ax.plot(latent_axis, raw_train_latent[t_idx, i, :],
                linewidth=2,
                color=colors[j],
                label="t={:.3f}".format(data_obj["tpredict"][t_idx])
                )
    ax.set_xlabel("Latent dim")
    ax.set_ylabel("Values")
    ax.set_title("Sim {}".format(i + 1))
    ax.legend(loc="upper right", fontsize=4)

fig.suptitle("Train Latent Space")
fig.tight_layout()

In [None]:
# Validation latent space: one panel per validation simulation, one curve per
# selected time index, showing the value of each latent dimension.
# NOTE(review): raw_val_latent is indexed as (time, sim, latent) -- this
# follows getLatentStatesFromModel's return slicing; TODO confirm.
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
nt_to_plot = len(time_idx_to_plot)
colors = plt.cm.jet(np.linspace(0, 1, nt_to_plot))
latent_axis = np.arange(1, args["latents"] + 1)  # latent dims numbered 1..latents

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(6, 8))

for i, ax in enumerate(axs.flat):
    if i >= raw_val_latent.shape[1]:
        # Fewer simulations than panels: hide the spare axes instead of failing.
        ax.axis("off")
        continue
    for j, t_idx in enumerate(time_idx_to_plot):
        # Index by t_idx (not the loop counter j) so each curve matches its time label.
        ax.plot(latent_axis, raw_val_latent[t_idx, i, :],
                linewidth=2,
                color=colors[j],
                label="t={:.3f}".format(data_obj["tpredict"][t_idx])
                )
    ax.set_xlabel("Latent dim")
    ax.set_ylabel("Values")
    ax.set_title("Val Sim {}".format(i + 1))
    ax.legend(loc="upper right", fontsize=4)

fig.suptitle("Val Latent Space")
fig.tight_layout()

In [None]:
# Test latent space: one panel per test simulation, one curve per selected
# time index, showing the value of each latent dimension.
# NOTE(review): raw_test_latent is indexed as (time, sim, latent) -- this
# follows getLatentStatesFromModel's return slicing; TODO confirm.
time_idx_to_plot = np.array([0, 10, 20, 25, 30, 37, 43])
nt_to_plot = len(time_idx_to_plot)
colors = plt.cm.jet(np.linspace(0, 1, nt_to_plot))
latent_axis = np.arange(1, args["latents"] + 1)  # latent dims numbered 1..latents

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(6, 8))

for i, ax in enumerate(axs.flat):
    if i >= raw_test_latent.shape[1]:
        # Fewer simulations than panels: hide the spare axes instead of failing.
        ax.axis("off")
        continue
    for j, t_idx in enumerate(time_idx_to_plot):
        # Index by t_idx (not the loop counter j) so each curve matches its time label.
        ax.plot(latent_axis, raw_test_latent[t_idx, i, :],
                linewidth=2,
                color=colors[j],
                label="t={:.3f}".format(data_obj["tpredict"][t_idx])
                )
    ax.set_xlabel("Latent dim")
    ax.set_ylabel("Values")
    ax.set_title("Test Sim {}".format(i + 1))
    ax.legend(loc="upper right", fontsize=4)

fig.suptitle("Test Latent Space")
fig.tight_layout()

In [None]:
# # MODEL TESTING (ARCHITECTURE DIMS)
# data_obj = dut.parse_datasets(validation_file, sim_file, param_file, args, torch.device("cpu"))
# batch_dict = dut.get_next_batch(data_obj["train_dataloader"], device=torch.device("cpu"))

# y0 = batch_dict['observed_data']
# t0 = batch_dict['observed_tp']
# yt = batch_dict['data_to_predict']
# tt = batch_dict['tp_to_predict']

# input_dim = data_obj["input_dim"]
# print(input_dim)

# model = nut.PNODE_Conv(input_dim=4096, latent_dim=args["latents"], param_dim=2, 
#                        device=device,
#                       n_layers=args["node-layers"],
#                       n_units=args["units"]).to(torch.device("cpu"))

# init_state = y0[:, :, :input_dim].to(torch.device("cpu"))
# print(init_state.shape)
# nbatch, ntraj, nseq = init_state.shape
# print(nbatch, ntraj, nseq)
# init_state = init_state.reshape((ntraj*nbatch, 1, 32, 128))
# print(init_state.shape)

# print(model.encoder[0](init_state).shape)
# print(model.encoder[0:2](init_state).shape)
# print(model.encoder[0:4](init_state).shape)
# print(model.encoder[0:5](init_state).shape)
# print(model.encoder[0:7](init_state).shape)
# print(model.encoder[0:8](init_state).shape)
# print(model.encoder[0:10](init_state).shape)
# print(model.encoder[0:11](init_state).shape)
# print(model.encoder[0:13](init_state).shape)
# print(model.encoder[0:15](init_state).shape)
# print(model.encoder[0:16](init_state).shape)
# print(model.encoder[0:18](init_state).shape)
# print(model.encoder(init_state).shape)
# print(model.encoder(init_state).reshape((nbatch, ntraj, args["latents"])).shape)

# init_latent = model.encoder(init_state).reshape((nbatch, ntraj, args["latents"]))
# init_latent = torch.cat((init_latent, y0[:, :, input_dim:]), -1)
# print(init_latent.shape)
# latent_states = odeint(model.pnode, init_latent, tt)
# print(latent_states.shape)

# ls_init = latent_states[:, :, :, :args["latents"]]
# print(ls_init.shape)

# ntraj2, nbatch2, nseq2, _ = ls_init.shape
# print(ntraj2, nbatch2, nseq2)

# latent_mlp = model.decoder_mlp(ls_init)
# print(latent_mlp.shape)

# decoder_features = latent_mlp.reshape((-1, 64, 1, 4))
# print(decoder_features.shape)

# print(model.decoder_conv[0](decoder_features).shape)
# print(model.decoder_conv[0:3](decoder_features).shape)
# print(model.decoder_conv[0:5](decoder_features).shape)
# print(model.decoder_conv[0:8](decoder_features).shape)
# print(model.decoder_conv[0:10](decoder_features).shape)

# print(model.decoder_conv[0:10](decoder_features).reshape(ntraj2, nbatch2, nseq2, 4096).shape)

# print(model(tt, y0).shape)

# print(model)