All the results presented were obtained as follows:
1. The gradients in the physics-informed loss terms were estimated using forward-mode automatic differentiation (AD).
2. The output field values at the given grid points were computed in a single forward pass of the network using the einsum function.

In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchsummary import summary
import torch.distributions as td
import math
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import argparse
import random
import os
import time 
import pickle

import sys
sys.path.append("../..")
from utils.networks import *
from utils.deeponet_networks_1d import *
from utils.visualizer_misc import *
from utils.visualizer_1d import *
from utils.forward_autodiff import *
from utils.misc import *

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Tag this cell with 'parameters'
# parameters
# Run-level settings read by the cells below: they name the results directory,
# seed the run and the train/test split, and size the labelled subset.
# n_used = 0 skips the autoencoder cells and zeroes the data-driven loss term.
seed = 0 # Seed number.
n_used = 200 # Number of full training fields used for estimating the data-driven loss term in the PI-Latent-NO
save = True # Save results.

In [None]:
# Directory that receives the artifacts written by this run (loss histories,
# model weights, test dictionary). It stays None when saving is disabled.
if save:
    resultdir = os.path.join(os.getcwd(), 'results', 'c_PI-Latent-NO_with-AE',
                             'seed=' + str(seed) + '_n_used=' + str(n_used))
    # exist_ok=True replaces the exists()-then-makedirs() pair, which can fail
    # if the directory appears between the check and the creation.
    os.makedirs(resultdir, exist_ok=True)
else:
    resultdir = None

In [None]:
# Seed the run with the `seed` parameter. set_seed comes from the star-imported
# project utilities; presumably it seeds the Python/NumPy/PyTorch RNGs -- TODO confirm.
set_seed(seed)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Load the dataset archive (.npz file)
data = np.load(os.path.join('..','..','data/Diffusion-reaction_dynamics_t=0to1/Diffusion-reaction_dynamics.npz'))
print(data)

# Print the shape of every array used by the notebook, in this order:
#   t_span           : time grid
#   x_span           : spatial grid
#   input_s_samples  : random source fields (Gaussian random fields), Nsamples x 100,
#                      each sample is (1 x 100)
#   output_u_samples : time evolution of the solution field, Nsamples x 101 x 100.
#                      Each field is 101 x 100; rows correspond to time and columns
#                      correspond to location. The first row is the solution at t=0
#                      (1st time step), the next row is the solution at t=0.01
#                      (2nd time step), and so on; the last row is the solution
#                      at t=1 (101st time step).
for array_name in ('t_span', 'x_span', 'input_s_samples', 'output_u_samples'):
    print(data[array_name].shape)

In [None]:
# Convert NumPy arrays to float32 PyTorch tensors on the compute device
inputs = torch.from_numpy(data['input_s_samples']).float().to(device)
outputs = torch.from_numpy(data['output_u_samples']).float().to(device)

t_span = torch.from_numpy(data['t_span']).float().to(device)
x_span = torch.from_numpy(data['x_span']).float().to(device)
nt, nx = len(t_span), len(x_span) # number of discretizations in time and location.
print("nt =",nt, ", nx =",nx)
print("Shape of t-span and x-span:",t_span.shape, x_span.shape)
print("t-span:", t_span)
print("x-span:", x_span)

# Estimating grid points.
# indexing='ij' is spelled out: it is what the legacy call without the argument
# does today, but omitting it is deprecated and the default is announced to
# change. With 'ij', T and X have time along the first axis and location along
# the second.
T, X = torch.meshgrid(t_span, x_span, indexing='ij')
# print(T)
# print(X)

# Split the data into training and testing samples (500 samples are held out;
# the split is reproducible through random_state=seed)
inputs_train, inputs_test, outputs_train, outputs_test = train_test_split(inputs, outputs, test_size=500, random_state=seed)

# Check the shapes of the subsets
print("Shape of inputs_train:", inputs_train.shape)
print("Shape of inputs_test:", inputs_test.shape)
print("Shape of outputs_train:", outputs_train.shape)
print("Shape of outputs_test:", outputs_test.shape)
print('#'*100)

In [None]:
# Only the first n_used training fields (sources and their solutions) are used
# for estimating the data-driven loss term in the PI-Latent-NO.
inputs_train_used, outputs_train_used = inputs_train[:n_used], outputs_train[:n_used]
print("Shape of inputs_train_used:", inputs_train_used.shape)
print("Shape of outputs_train_used:", outputs_train_used.shape)

In [None]:
# Size of the latent representation: it is the output width of the encoder,
# the input width of the decoder, and is also passed to Latent_NO_model.
latent_dim = 9 # d_z  

# The autoencoder is only built and trained when labelled fields are used.
if n_used > 0:
    # Learning latent fields using Autoencoder

    class Autoencoder_MLP(nn.Module):
        """Autoencoder assembled from a supplied encoder and decoder network.

        The two sub-networks are kept as the attributes ``encoder_net`` and
        ``decoder_net`` so that each can also be called on its own.
        """

        def __init__(self, encoder_net, decoder_net):
            super().__init__()
            self.encoder_net, self.decoder_net = encoder_net, decoder_net

        def forward(self, x):
            """Return the reconstruction ``decoder_net(encoder_net(x))``."""
            return self.decoder_net(self.encoder_net(x))


    # The autoencoder acts on single snapshots, i.e. vectors of nx values.
    input_dim = nx

    # Encoder layer sizes: input_dim -> 80 -> 60 -> 40 -> latent_dim
    encoder_net = DenseNet(layersizes=[input_dim] + [80, 60, 40] + [latent_dim], activation=nn.SiLU()) #nn.LeakyReLU() #nn.Tanh()
    encoder_net.to(device)
    # print(encoder_net)
    # print('ENCODER-NET SUMMARY:')
    # summary(encoder_net, input_size=(input_dim,))  
    # print('#'*100)

    # Decoder layer sizes mirror the encoder: latent_dim -> 40 -> 60 -> 80 -> input_dim
    decoder_net = DenseNet(layersizes=[latent_dim] + [40, 60, 80] + [input_dim], activation=nn.SiLU()) #nn.LeakyReLU() #nn.Tanh()
    decoder_net.to(device)
    # print(decoder_net)
    # print('DECODER-NET SUMMARY:')
    # summary(decoder_net, input_size=(latent_dim,))
    # print('#'*100)

    model_AE = Autoencoder_MLP(encoder_net, decoder_net)
    model_AE.to(device);

    # Sub-directory of the run's results directory for the autoencoder artifacts.
    if save:
        resultdir_ = os.path.join(resultdir, 'pretrained-AE_latent_dim='+str(latent_dim))
        # exist_ok=True replaces the exists()-then-makedirs() pair, which can
        # fail if the directory appears between the check and the creation.
        os.makedirs(resultdir_, exist_ok=True)
    else:
        resultdir_ = None

    # Every time snapshot of every used field becomes one training sample of size nx.
    data_used_for_AE = outputs_train_used.reshape(-1, nx)
    print("Shape of data_used_for_AE:", data_used_for_AE.shape)

    n_iterations_AE = 60000
    data_train = data_used_for_AE
    num_samples = len(data_train)

    print(colored('LATENT DIMENSION = '+str(latent_dim), 'red'))
    print ("------STARTED-TRAINING------")
    start_time = time.time()

    bs = 512 # Batch size

    # Training
    optimizer = torch.optim.Adam(model_AE.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30000, gamma=0.1) # gamma=0.8
    reconstruction_criterion = nn.MSELoss()  # built once instead of once per iteration

    iteration_list, loss_list, learningrates_list = [], [], []

    for iteration in range(n_iterations_AE):

        # Random mini-batch: the first bs entries of a fresh permutation
        indices = torch.randperm(num_samples).to(device)
        data_batch = data_train[indices[0:bs]] # (bs, nx)

        optimizer.zero_grad()
        # Named reconstructed_batch so that the module-level dataset tensor
        # `outputs` is not overwritten by the training loop.
        reconstructed_batch = model_AE(data_batch)
        loss = reconstruction_criterion(reconstructed_batch, data_batch)
        loss.backward()
        # torch.nn.utils.clip_grad_value_(model_AE.parameters(), clip_value=1.0)
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        # Read the learning rate straight from the optimizer; state_dict()
        # would rebuild the full optimizer state just to fetch one number.
        current_lr = optimizer.param_groups[0]['lr']

        if iteration % 1000 == 0:
            print('Iteration %s -' % iteration, 'loss = %f,' % loss_value,
                  'learning rate = %f' % current_lr)

        iteration_list.append(iteration)
        loss_list.append(loss_value)
        learningrates_list.append(current_lr)
    print ("------ENDED-TRAINING------")

    # Stop the timer right after the loop so that the reported time covers the
    # optimisation only, not the saving and plotting below.
    end_time = time.time()
    training_time = end_time - start_time # time for AE network to train

    if save:
        np.save(os.path.join(resultdir_,'iteration_list.npy'), np.asarray(iteration_list))
        np.save(os.path.join(resultdir_,'loss_list.npy'), np.asarray(loss_list))
        np.save(os.path.join(resultdir_,'learningrates_list.npy'), np.asarray(learningrates_list))

    plot_training_loss(resultdir_, iteration_list, loss_list, save)

    plot_learningrates(resultdir_, iteration_list, learningrates_list, save)

    print("Time (sec) to complete:\n" +str(training_time)) # time for AE network to train
    print('*'*10)

    if save:
        torch.save(model_AE.state_dict(), os.path.join(resultdir_,'model_state_dict.pt'))

    # Evaluate on test data.
    # This is inference only, so autograd is switched off: no graph is built or
    # kept alive for the full test set. The result is stored under its own name
    # so that the module-level dataset tensor `outputs` is not overwritten.
    with torch.no_grad():
        data_test = outputs_test.reshape(-1, nx) #(test_size*nt, nx)
        reconstructed_test = model_AE(data_test) #(test_size*nt, nx)
        reconstruction_loss_test = nn.MSELoss()(reconstructed_test, data_test)
    print(f"TEST DATA RECONSTRUCTION ERROR FOR LATENT DIMENSION {latent_dim}: {reconstruction_loss_test.item():.2e}")
    print('*'*10)

    # Plot every 50th test field (i = 49, 99, ...), i.e. the samples for which
    # (i+1) % 50 == 0, without iterating over the samples in between.
    for i in range(49, outputs_test.shape[0], 50):
        data_i = outputs_test[i] #(nt, nx)
        with torch.no_grad():
            reconstructed_i = model_AE(data_i) #(nt, nx)

        plot_AE_reconstructions(i, resultdir_, data_i, reconstructed_i, X, T, 'seismic', save)

    print(colored('*'*115, 'red'))

In [None]:
if n_used > 0:
    # Load the pretrained autoencoder
    pretrained_autoencoder_model = Autoencoder_MLP(encoder_net, decoder_net)
    pretrained_autoencoder_model.to(device)
    if save:
        # The pretraining cell wrote the weights to disk; reload them from there.
        pretrained_autoencoder_model.load_state_dict(torch.load(os.path.join(resultdir,
                                                                         'pretrained-AE_latent_dim='+str(latent_dim),
                                                                         'model_state_dict.pt'), map_location=device))
    else:
        # Nothing was written to disk; copy the weights from the in-memory model.
        pretrained_autoencoder_model.load_state_dict(model_AE.state_dict())

    def _encode_fields(fields):
        """Encode each field of `fields` with the pretrained encoder.

        `fields` is indexed along its first dimension and each entry is passed
        to the encoder on its own (entries are (nt, nx) for the tensors used
        here). Returns a tensor of shape
        (fields.shape[0], fields.shape[1], latent_dim) on `device` that carries
        no autograd graph: the latent targets are constants for the operator
        training, and retaining the graph would only cost memory and time.
        """
        latent = torch.zeros((fields.shape[0], fields.shape[1], latent_dim)).to(device)
        with torch.no_grad():
            for k in range(fields.shape[0]):
                latent[k] = pretrained_autoencoder_model.encoder_net(fields[k])
        return latent

    # Latent targets for all training fields, the labelled subset and the test fields
    latent_outputs_train = _encode_fields(outputs_train)
    latent_outputs_train_used = _encode_fields(outputs_train_used)
    latent_outputs_test = _encode_fields(outputs_test)

    # Check the shapes of the subsets
    print("Shape of latent_outputs_train:", latent_outputs_train.shape)
    print("Shape of latent_outputs_train_used:", latent_outputs_train_used.shape)
    print("Shape of latent_outputs_test:", latent_outputs_test.shape)
    print('#'*100)

In [None]:
"""
input_neurons_latent_branch: Number of input neurons in the latent_branch net.
input_neurons_latent_trunk: Number of input neurons in the latent_trunk net.
latent_p: Number of output neurons in both the latent_branch and latent_trunk net.
"""
def _make_subnet(title, n_inputs, n_outputs):
    """Create one sub-network, move it to `device` and print its summary.

    The network is a DenseNet with layer sizes [n_inputs, 64, 64, 64, n_outputs]
    and SiLU activation. `title` is printed before the torchsummary report and
    a line of '#' characters after it.
    """
    subnet = DenseNet(layersizes=[n_inputs] + [64]*3 + [n_outputs], activation=nn.SiLU()) #nn.LeakyReLU() #nn.Tanh()
    subnet.to(device)
    # print(subnet)
    print(title)
    summary(subnet, input_size=(n_inputs,))
    print('#'*100)
    return subnet


# --- Latent part of the operator -------------------------------------------
latent_p = latent_dim*16 # Number of output neurons in both the latent_branch and latent_trunk net.

input_neurons_latent_branch = nx # m
latent_branch_net = _make_subnet('LATENT BRANCH-NET SUMMARY:', input_neurons_latent_branch, latent_p)

input_neurons_latent_trunk = 1 # 1 corresponds to t
latent_trunk_net = _make_subnet('LATENT TRUNK-NET SUMMARY:', input_neurons_latent_trunk, latent_p)

# --- Reconstruction part of the operator -----------------------------------
# input_neurons_reconstruction_branch: Number of input neurons in the reconstruction_branch net.
# input_neurons_reconstruction_trunk: Number of input neurons in the reconstruction_trunk net.
# reconstruction_q: Number of output neurons in both the reconstruction_branch and reconstruction_trunk net.
reconstruction_q = 128 # Number of output neurons in both the reconstruction_branch and reconstruction_trunk net.

input_neurons_reconstruction_branch = latent_dim # d_z
reconstruction_branch_net = _make_subnet('RECONSTRUCTION BRANCH-NET SUMMARY:', input_neurons_reconstruction_branch, reconstruction_q)

input_neurons_reconstruction_trunk = 1 # 1 corresponds to x
reconstruction_trunk_net = _make_subnet('RECONSTRUCTION TRUNK-NET SUMMARY:', input_neurons_reconstruction_trunk, reconstruction_q)

# Assemble the four sub-networks into the latent neural operator
model = Latent_NO_model(latent_branch_net, latent_trunk_net, latent_dim, reconstruction_branch_net, reconstruction_trunk_net)
model.to(device);

In [None]:
# Learnable parameters of the operator = sum over its four sub-networks.
num_learnable_parameters = sum(
    count_learnable_parameters(subnet)
    for subnet in (latent_branch_net,
                   latent_trunk_net,
                   reconstruction_branch_net,
                   reconstruction_trunk_net)
)
print("Total number of learnable parameters:", num_learnable_parameters)

In [None]:
def loss_pde_residual(net, source_fields, t, x):
    """Mean squared residual of the PDE at the given collocation points.

    The residual is ``u_t - 0.01*u_xx - 0.01*u**2 - s(x)``, where ``u`` is the
    second output of ``net`` and ``s`` is the source field interpolated at ``x``.

    Args:
        net: callable ``net(source_fields, t, x)`` returning a pair whose second
            element is the field ``u``.
        source_fields: batch of source fields; the first dimension is the batch.
            Each field is interpolated from the module-level grid ``x_span``.
        t: collocation times; ``t.shape[0]`` is the number of time points.
        x: collocation locations; ``x.shape[0]`` is the number of locations.

    Returns:
        Scalar tensor with the mean of the squared residual.
    """
    _, u = net(source_fields, t, x) # u is (bs, neval_t, neval_x)

    # Using forward automatic differentiation to estimate derivatives in the physics informed loss
    tangent_t, tangent_x = torch.ones_like(t), torch.ones_like(x)
    ut  = FWDAD_first_order_derivative(lambda t: net(source_fields, t, x)[1], t, tangent_t) # (bs, neval_t, neval_x)
    uxx = FWDAD_second_order_derivative(lambda x: net(source_fields, t, x)[1], x, tangent_x) # (bs, neval_t, neval_x)

    # The point counts are taken from the arguments themselves; the function
    # does not depend on the notebook-level neval_t/device variables.
    bs_, n_eval_t, n_eval_x = source_fields.shape[0], t.shape[0], x.shape[0]

    sf_values_ = torch.zeros((bs_, n_eval_x, 1), device=x.device)
    for j in range(bs_):
        sf_values_[j] = linear_interpolation(x, x_span, source_fields[j]) # source function: s(x) values
    # s(x) does not depend on time: view it as (bs, 1, neval_x) and expand it
    # (without copying) along the time axis to (bs, neval_t, neval_x).
    sf_values = sf_values_.reshape(-1, 1, n_eval_x).expand(-1, n_eval_t, -1)

    pde_residual = (ut - (0.01*uxx) - (0.01*(u**2)) - sf_values)**2

    return torch.mean(pde_residual)

In [None]:
def loss_pde_bcs(net, source_fields, t, x):
    """Penalty for deviating from the zero boundary values on both boundaries.

    ``t`` and ``x`` are indexable: entry 0 holds the times and location of the
    first boundary, entry 1 those of the second. For each boundary the field
    (second output of ``net``) is evaluated and its mean squared deviation from
    the prescribed value 0 is computed; the two means are added.
    """
    boundary_points = [(t[0], x[0]), (t[1], x[1])]
    boundary_values = [0., 0.]

    # Field on each boundary: u is (bs, neval_t, 1)
    boundary_fields = [net(source_fields, t_b, x_b)[1] for t_b, x_b in boundary_points]

    penalty_first, penalty_second = [
        torch.mean((u_b - value)**2)
        for u_b, value in zip(boundary_fields, boundary_values)
    ]
    return penalty_first + penalty_second

In [None]:
def loss_pde_ic(net, source_fields, t, x):
    """Mean squared deviation of the field from the initial value 0.

    The field is the second output of ``net(source_fields, t, x)``; the caller
    supplies the initial time through ``t``.
    """
    ic_value = 0.
    u_ic = net(source_fields, t, x)[1] # u is (bs, 1, neval_x)
    return torch.square(u_ic - ic_value).mean()

In [None]:
start_time = time.time()

bs = 64 # Batch size
neval_t = 16  # Number of time points at which latent output field is evaluated for a given input source field sample
neval_x = 16 
# neval_loc = neval_x  # Number of locations at which output field is evaluated at each time point.

neval_c = {'t': neval_t, 'loc': neval_x}  # Number of collocation points within the domain.
neval_b = {'t': neval_t, 'loc': 1}        # Number of collocation points on each boundary.
neval_i = {'t': 1, 'loc': neval_x}        # Number of collocation points at t=0.
        
# Training
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25000, gamma=0.1) # gamma=0.8

iteration_list, loss_list, learningrates_list = [], [], []
datadriven_loss_list, pinn_loss_list = [], []

# Objects that do not change between iterations are created once.
mse_criterion = nn.MSELoss()
unit_uniform = td.uniform.Uniform(0., 1.)
t_grid, x_grid = t_span.reshape(-1, 1), x_span.reshape(-1, 1) # (nt, 1), (nx, 1)
n_train = len(inputs_train)
# boundary locations of the 2 boundaries (hard-coded) and the initial time
xb = [torch.tensor([[0.]]).to(device), 
      torch.tensor([[1.]]).to(device)]
ti = torch.zeros((1, 1)).to(device)

n_iterations = 50000
for iteration in range(n_iterations):
    
    if n_used > 0:
        # Data-driven term: random mini-batch from the labelled fields
        indices_datadriven = torch.randperm(n_used).to(device) # Generate random permutation of indices
        batch_datadriven = indices_datadriven[0:bs]
        inputs_train_used_batch = inputs_train_used[batch_datadriven] # (bs, nx)
        latent_outputs_train_used_batch = latent_outputs_train_used[batch_datadriven] # (bs, nt, latent_dim)
        outputs_train_used_batch = outputs_train_used[batch_datadriven] # (bs, nt, nx)

        latent_predicted_values, reconstruction_predicted_values = model(inputs_train_used_batch, t_grid, x_grid) # (bs, nt, latent_dim), (bs, nt, nx)
        latent_target_values = latent_outputs_train_used_batch # (bs, nt, latent_dim)
        reconstruction_target_values = outputs_train_used_batch # (bs, nt, nx)
        datadriven_loss = (mse_criterion(latent_predicted_values, latent_target_values)
                           + mse_criterion(reconstruction_predicted_values, reconstruction_target_values))
    else:
        # No labelled fields: the data-driven term is zero. A plain else is used
        # so that datadriven_loss is always defined.
        datadriven_loss = torch.tensor([0.]).to(device)
    
    # Physics-informed term: random mini-batch of source fields from the full training set
    indices_pinn = torch.randperm(n_train).to(device) # Generate random permutation of indices
    inputs_batch = inputs_train[indices_pinn[0:bs]] # (bs, nx)
    
    # points within the domain
    tc = unit_uniform.sample((neval_c['t'], 1)).to(device)
    xc = unit_uniform.sample((neval_c['loc'], 1)).to(device)

    # boundary times for the 2 boundaries
    tb = [unit_uniform.sample((neval_b['t'], 1)).to(device),
          unit_uniform.sample((neval_b['t'], 1)).to(device)]

    # initial points
    xi = unit_uniform.sample((neval_i['loc'], 1)).to(device)

    pinn_loss = (loss_pde_residual(model, inputs_batch, tc, xc) 
               + loss_pde_bcs(model, inputs_batch, tb, xb) 
               + loss_pde_ic(model, inputs_batch, ti, xi))

    optimizer.zero_grad()
    loss = datadriven_loss + pinn_loss
    loss.backward()
    # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
    optimizer.step()
    scheduler.step()

    loss_value = loss.item()
    datadriven_loss_value = datadriven_loss.item()
    pinn_loss_value = pinn_loss.item()
    # Read the learning rate straight from the optimizer; state_dict() would
    # rebuild the full optimizer state just to fetch one number.
    current_lr = optimizer.param_groups[0]['lr']

    if iteration % 500 == 0:
        print('Iteration %s -' % iteration, 'loss = %f,' % loss_value,
              'data-driven loss = %f,' % datadriven_loss_value,'pinn loss = %f,' % pinn_loss_value,
              'learning rate = %f' % current_lr) 

    iteration_list.append(iteration)
    loss_list.append(loss_value)
    datadriven_loss_list.append(datadriven_loss_value)
    pinn_loss_list.append(pinn_loss_value)
    learningrates_list.append(current_lr)

# end timer
# Stopped right after the loop so that training_time and runtime_per_iter
# measure the optimisation only, not the saving and plotting below.
end_time = time.time()
training_time = end_time - start_time

runtime_per_iter = training_time/n_iterations # in sec/iter
    
if save:
    np.save(os.path.join(resultdir,'iteration_list.npy'), np.asarray(iteration_list))
    np.save(os.path.join(resultdir,'loss_list.npy'), np.asarray(loss_list))
    np.save(os.path.join(resultdir, 'datadriven_loss_list.npy'), np.asarray(datadriven_loss_list))
    np.save(os.path.join(resultdir, 'pinn_loss_list.npy'), np.asarray(pinn_loss_list))
    np.save(os.path.join(resultdir,'learningrates_list.npy'), np.asarray(learningrates_list))

plot_loss_terms(resultdir, iteration_list, loss_list, datadriven_loss_list, pinn_loss_list, save)  
    
plot_training_loss(resultdir, iteration_list, loss_list, save)  

plot_learningrates(resultdir, iteration_list, learningrates_list, save)  

In [None]:
# Write the trained operator weights into the results directory (only when
# saving is enabled; resultdir is None otherwise).
if save == True:
    model_weights_path = os.path.join(resultdir,'model_state_dict.pt')
    torch.save(model.state_dict(), model_weights_path)
# To restore the weights later:
# model.load_state_dict(torch.load(os.path.join(resultdir,'model_state_dict.pt'), map_location=device))

In [None]:
# Predictions
latent_branch_inputs = inputs_test # (bs, m) = (bs, nx) 
latent_trunk_inputs = t_span.reshape(-1, 1) # (nt, 1)
reconstruction_trunk_inputs = x_span.reshape(-1, 1) # (nx, 1)
# Inference only: with autograd disabled no graph is built for the whole test
# set and the predictions are plain tensors that can be converted and saved.
with torch.no_grad():
    latent_predictions_test, reconstruction_predictions_test = model(latent_branch_inputs, latent_trunk_inputs, reconstruction_trunk_inputs)# (bs, nt, latent_dim), (bs, nt, nx)
# print(latent_predictions_test.shape, reconstruction_predictions_test.shape)

mse_list, latent_mse_list, reconstruction_mse_list, r2score_list, relerror_list = [], [], [], [], []

for i in range(inputs_test.shape[0]):

    latent_prediction_i = latent_predictions_test[i].unsqueeze(0) # (1, nt, latent_dim)
    reconstruction_prediction_i = reconstruction_predictions_test[i].unsqueeze(0) # (1, nt, nx)
    reconstruction_target_i = outputs_test[i].unsqueeze(0) # (1, nt, nx)

    # Latent targets exist only when the autoencoder was trained (n_used > 0);
    # otherwise the latent error is reported as zero.
    if n_used > 0:
        latent_target_i = latent_outputs_test[i].unsqueeze(0) # (1, nt, latent_dim)
        latent_mse_i = F.mse_loss(latent_prediction_i.cpu(), latent_target_i.cpu())
    else:
        latent_mse_i = torch.tensor([0.])
    reconstruction_mse_i = F.mse_loss(reconstruction_prediction_i.cpu(), reconstruction_target_i.cpu())
    mse_i = latent_mse_i + reconstruction_mse_i

    # Convert to flat NumPy arrays once and reuse them for both metrics
    target_np = reconstruction_target_i.flatten().cpu().detach().numpy()
    prediction_np = reconstruction_prediction_i.flatten().cpu().detach().numpy()

    r2score_i = metrics.r2_score(target_np, prediction_np)
    relerror_i = np.linalg.norm(target_np - prediction_np) / np.linalg.norm(target_np)

    latent_mse_list.append(latent_mse_i.item())
    reconstruction_mse_list.append(reconstruction_mse_i.item())
    mse_list.append(mse_i.item())
    # float() accepts both NumPy scalars and plain Python floats, whereas
    # .item() fails when the metric is returned as a built-in float.
    r2score_list.append(float(r2score_i))
    relerror_list.append(float(relerror_i))

    if (i+1) % 10 == 0:
        plot_predictions(i, resultdir, reconstruction_target_i, reconstruction_prediction_i, x_span, inputs_test, X, T, nt, nx, r'$s(x)$', 'Source field', 'seismic', save)

# Averages over the test samples
latent_mse = sum(latent_mse_list) / len(latent_mse_list)
print("Latent Mean Squared Error Test:\n", latent_mse)
reconstruction_mse = sum(reconstruction_mse_list) / len(reconstruction_mse_list)
print("Reconstruction Mean Squared Error Test:\n", reconstruction_mse)
mse = sum(mse_list) / len(mse_list)
print("Mean Squared Error Test:\n", mse)
r2score = sum(r2score_list) / len(r2score_list)
print("R2 score Test:\n", r2score)
relerror = sum(relerror_list) / len(relerror_list)
print("Rel. L2 Error Test:\n", relerror)

In [None]:
if n_used > 0:
    # Plotting learned latent fields

    # Latent components are numbered 1..latent_dim along the plotting axis
    z_span = torch.arange(1, latent_dim+1, 1).float().to(device)
    # Estimating grid points.
    # indexing='ij' is spelled out: it is what the legacy call without the
    # argument does today, but omitting it is deprecated and the default is
    # announced to change.
    T_z, Z = torch.meshgrid(t_span, z_span, indexing='ij')
    # print(Z.shape, T_z.shape)
    # print(T_z)
    # print(Z)

    # Plot every 50th test sample (i = 49, 99, ...), i.e. the samples for which
    # (i+1) % 50 == 0, without preparing tensors for the samples in between.
    for i in range(49, inputs_test.shape[0], 50):

        latent_prediction_i = latent_predictions_test[i].unsqueeze(0) # (1, nt, latent_dim)
        latent_target_i = latent_outputs_test[i].unsqueeze(0) # (1, nt, latent_dim)
        reconstruction_target_i = outputs_test[i].unsqueeze(0) # (1, nt, nx)

        plot_latentfields(i, resultdir, reconstruction_target_i, latent_target_i, latent_prediction_i, x_span, inputs_test, z_span, X, T, Z, T_z, nt, nx, latent_dim, r'$s(x)$', 'Source field', 'True latent field from AE', 'seismic', save)


In [None]:
# Test inputs, targets and predictions, moved to the CPU for storage.
# The predictions come out of the model, so they are detached first: the saved
# tensor then carries neither an autograd graph nor a requires_grad flag.
test_dict = {
    "inputs_test": inputs_test.cpu(),
    "outputs_test": outputs_test.cpu(),
    "predictions_test": reconstruction_predictions_test.detach().cpu()
}
for key, value in test_dict.items():
    print(f"Shape of {key}: {value.shape}")
print(colored('#'*230, 'green'))

if save:
    torch.save(test_dict, os.path.join(resultdir,'test_dict.pth'))

In [None]:
# Hand the averaged test metrics and the timing figures to the project helper
# performance_metrics (star-imported); presumably it reports them and, when
# `save` is set, writes them under resultdir -- TODO confirm.
performance_metrics(reconstruction_mse, r2score, relerror, training_time, runtime_per_iter, resultdir, save)