In [None]:
#importing all the necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.io

import jax
import jax.numpy as jnp
from jax import random
from jax import jit, vmap, pmap, grad, value_and_grad

from tqdm import tqdm

import flax
import flax.linen as nn
import optax

from typing import Callable, Tuple, List, Dict, Optional, Any, Sequence

from sklearn.model_selection import train_test_split

from functools import partial

import os
import sys
import pickle

In [None]:
# Single seed shared by NumPy's legacy global RNG and the JAX PRNG key.
seed = 42
key = random.PRNGKey(seed)  # JAX: explicit PRNG key (used for model.init)
np.random.seed(seed)        # NumPy: seeds the legacy global RNG (np.random.*)

In [None]:
# Load the KdV dataset. For a .npz archive np.load returns a lazy NpzFile
# (jnp.load would return the same object), so the arrays below are NumPy
# arrays; the context manager closes the file handle once they are read.
with np.load("data/data_kdv.npz") as data:
    u = data['u']      # Initial condition
    xt = data['xt']    # grid stored with the dataset (not used below)
    g_u = data['g_u']  # Output

# Derive the dataset sizes from the arrays instead of hard-coding them
# (the shipped dataset gives ns = 1000, nx = 100, nt = 200).
# The next cell indexes u as (ns, nx) and reshapes g_u to (ns, nt, nx).
ns, nx = u.shape
nt = g_u.size // (ns * nx)
assert ns * nt * nx == g_u.size, "g_u cannot be reshaped to (ns, nt, nx)"

In [None]:
# Prepend the initial condition u as time step 0 of every trajectory, so the
# target tensor covers t = 0 as well.
g_u = g_u.reshape(ns, nt, nx)                 # (ns, nt, nx)
u_ = u[:, None, :]                            # (ns, 1, nx); None == jnp.newaxis
g_u_new = jnp.concatenate((u_, g_u), axis=1)  # (ns, nt + 1, nx)

In [None]:
# Inputs = u (initial conditions), outputs = g_u_new (trajectories incl. t = 0)
inputs = u
outputs = g_u_new

# Train only on the first part of each trajectory (time steps 0 .. n_train_steps-1)
n_train_steps = 100
outputs = outputs[:, :n_train_steps, :]

ns, nt, nx = outputs.shape
print(f"ns = {ns}, nt = {nt}, nx = {nx}")

# Create the grid with coordinate pairs (t, x) to feed into the trunk network.
# Axis 1 of g_u_new is the time axis, so the full time span gets one point per
# stored time step; slicing it then keeps exactly the time coordinates of the
# training window. Any inference grid must use the same t-range (0 to 5) and
# spacing for the trunk-net coordinates to line up with these time steps.
tspan = jnp.linspace(0, 5, g_u_new.shape[1])
xspan = jnp.linspace(0, 1, nx)

# Take only the training window of the temporal domain
tspan = tspan[:n_train_steps]

# Create meshgrid for trunk net; 'ij' indexing orders rows time-major, matching
# the (nt, nx) -> nt*nx flattening of the outputs below.
t, x = jnp.meshgrid(tspan, xspan, indexing='ij')
grid = jnp.stack([t.ravel(), x.ravel()], axis=1)  # (nt*nx, 2)
assert grid.shape == (nt * nx, 2), grid.shape
print(grid.shape)
print(grid[:5])

# Split the data into training (80%) and testing (20%) samples
inputs_train, inputs_test, outputs_train, outputs_test, train_indices, test_indices = \
            train_test_split(inputs, outputs, jnp.arange(ns), test_size=0.2, random_state=seed)

# Reshape outputs from (ns, nt, nx) to (ns, nt*nx)
outputs_train = outputs_train.reshape(outputs_train.shape[0], nt * nx)
outputs_test = outputs_test.reshape(outputs_test.shape[0], nt * nx)

# Check the shapes of the subsets
print("Shape of inputs_train:", inputs_train.shape)
print("Shape of inputs_test:", inputs_test.shape)
print("Shape of outputs_train:", outputs_train.shape)
print("Shape of outputs_test:", outputs_test.shape)

In [None]:
# Network inputs (train): the branch net receives the initial conditions and
# the trunk net receives the shared (t, x) query grid.
branch_inputs_train, trunk_inputs_train = inputs_train, grid

# Inspecting the shapes
print("Shape of train branch inputs: ", branch_inputs_train.shape)
print("Shape of train trunk inputs: ", trunk_inputs_train.shape)
print("Shape of train output: ", outputs_train.shape)

In [None]:
# Network inputs (test): same query grid as training, held-out initial conditions.
branch_inputs_test, trunk_inputs_test = inputs_test, grid

# Inspecting the shapes
print("Shape of test branch inputs: ", branch_inputs_test.shape)
print("Shape of test trunk inputs: ", trunk_inputs_test.shape)
print("Shape of test output: ", outputs_test.shape)

In [None]:
#Utility class for defining the branch network
class branch_net(nn.Module):
    """Fully connected branch network.

    Every entry of ``layer_sizes`` except the last is a hidden Dense layer
    followed by ``activation``; the last entry is a linear output layer (no
    activation). All kernels use Glorot-normal initialization.
    """

    layer_sizes: Sequence[int]  # Dense widths; the last one is the output size
    activation: Callable        # applied after every hidden layer

    @nn.compact
    def __call__(self, x):
        kernel_init = nn.initializers.glorot_normal()
        hidden_widths, output_width = self.layer_sizes[:-1], self.layer_sizes[-1]

        h = x
        for width in hidden_widths:
            dense = nn.Dense(width, kernel_init=kernel_init)
            h = self.activation(dense(h))

        # Linear read-out layer
        return nn.Dense(output_width, kernel_init=kernel_init)(h)

In [None]:
#Utility class for defining the trunk network
class trunk_net(nn.Module):
    """Fully connected trunk network.

    Unlike ``branch_net``, ``activation`` is applied after EVERY Dense layer,
    including the last one, so the trunk output is activated as well.
    All kernels use Glorot-normal initialization.
    """

    layer_sizes: Sequence[int]  # Dense widths; the last one is the output size
    activation: Callable        # applied after every layer

    @nn.compact
    def __call__(self, x):
        kernel_init = nn.initializers.glorot_normal()

        h = x
        for width in self.layer_sizes:
            dense = nn.Dense(width, kernel_init=kernel_init)
            h = self.activation(dense(h))
        return h

In [None]:
#Defining the DeepONet model
class DeepONet(nn.Module):
    """Unstacked DeepONet: prediction[i, j] = <branch(x_branch[i]), trunk(x_trunk[j])>.

    Attributes:
        branch_net_config: layer widths of the branch net (last entry = latent size).
        trunk_net_config: layer widths of the trunk net; its last entry must equal
            the branch net's last entry so that the inner product is defined.
    """

    branch_net_config: Sequence[int]
    trunk_net_config: Sequence[int]

    def setup(self):

        self.branch_net = branch_net(self.branch_net_config, nn.swish)
        self.trunk_net = trunk_net(self.trunk_net_config, nn.swish)


    def __call__(self, x_branch, x_trunk):
        """Evaluate the operator for every (input function, query point) pair.

        Args:
            x_branch: (n_samples, n_sensors) sampled input functions.
            x_trunk: (n_points, d) query coordinates.

        Returns:
            (n_samples, n_points) array of predictions.
        """
        # nn.Dense contracts only the last axis and the activations are
        # elementwise, so both sub-networks already map over the leading batch
        # axis. Calling them directly avoids wrapping bound Flax submodules in a
        # raw jax.vmap (Flax expects its lifted nn.vmap for that); the parameter
        # tree and results are unchanged.
        branch_outputs = self.branch_net(x_branch)  # (n_samples, latent)
        trunk_outputs = self.trunk_net(x_trunk)     # (n_points, latent)

        inner_product = jnp.einsum('ik,jk->ij', branch_outputs, trunk_outputs)

        return inner_product

In [None]:
# Latent dimension produced by both the branch and trunk nets; the two must
# agree for DeepONet's inner product.
latent_vector_size = 300

# Hidden widths followed by the shared latent width
branch_network_layer_sizes = [150, 250, 450, 380, 320, latent_vector_size]
trunk_network_layer_sizes = [200, 220, 240, 250, 260, 280, latent_vector_size]

# Instantiate the DeepONet model
model = DeepONet(
    branch_net_config=branch_network_layer_sizes,
    trunk_net_config=trunk_network_layer_sizes,
)

In [None]:
#Utility function to save model params
def save_model_params(params, path, filename):
    """Pickle ``params`` to ``<path>/<filename>``, creating ``path`` if needed.

    Args:
        params: any picklable object (here the Flax parameter pytree).
        path: output directory, created with parents when missing. An empty
            string means the current working directory.
        filename: file name inside ``path``; an existing file is overwritten.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs;
    # os.makedirs("") would raise, so skip it when writing to the cwd.
    if path:
        os.makedirs(path, exist_ok=True)

    save_path = os.path.join(path, filename)
    with open(save_path, 'wb') as f:
        pickle.dump(params, f)

#Utility function to load model params
def load_model_params(path, filename):
    """Unpickle and return the object stored at ``<path>/<filename>``.

    Security: pickle.load can execute arbitrary code. Only load files this
    notebook wrote itself (e.g. via save_model_params), never untrusted ones.
    """
    with open(os.path.join(path, filename), 'rb') as handle:
        return pickle.load(handle)

In [None]:
# Define the training process from here
@jax.jit
def loss_fn(params, branch_inputs, trunk_inputs, gt_outputs):
    """Mean squared error of the global ``model``'s predictions against ``gt_outputs``.

    ``gt_outputs`` must match the model output shape, i.e. (n_samples, n_points)
    for the DeepONet defined in this notebook.
    """
    residual = model.apply(params, branch_inputs, trunk_inputs) - gt_outputs
    return jnp.mean(jnp.square(residual))

@jax.jit
def update(params, branch_inputs, trunk_inputs, gt_outputs, opt_state):
    """Take one optimizer step on a mini-batch.

    Uses the global ``optimizer``, which is captured when jit first traces
    this function.

    Returns:
        (new_params, new_opt_state, loss), where ``loss`` is the MSE computed
        with the parameters *before* the update.
    """
    loss_and_grad = jax.value_and_grad(loss_fn)
    loss, grads = loss_and_grad(params, branch_inputs, trunk_inputs, gt_outputs)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

# Initialize model parameters (the input shapes fix the Dense kernel shapes)
params = model.init(key, branch_inputs_train, trunk_inputs_train)

# Optimizer setup: Adam whose learning rate decays smoothly by a factor of
# 0.95 every 1000 steps
lr_scheduler = optax.schedules.exponential_decay(init_value=1e-3, transition_steps=1000, decay_rate=0.95)
optimizer = optax.adam(learning_rate=lr_scheduler)
opt_state = optimizer.init(params)

training_loss_history = []
test_loss_history = []
num_epochs = int(1e5)  # number of optimizer steps (one random mini-batch each)
batch_size = 256

min_test_loss = jnp.inf
# Fallback so best_params is always defined, even if the test loss never
# improves during training (e.g. it is NaN from the first step).
best_params = params

filepath = 'DeepONet_full_rollout'
# Create the output directory up front: loss plots and arrays are written here
# even if the training loop never saves a checkpoint.
os.makedirs(filepath, exist_ok=True)

In [None]:
for epoch in tqdm(range(num_epochs), desc="Training Progress"):

    # Mini-batching: each "epoch" is ONE optimizer step on a random subset of
    # batch_size training samples, not a full pass over the training set.
    shuffled_indices = jax.random.permutation(random.PRNGKey(epoch), branch_inputs_train.shape[0])
    batch_indices = shuffled_indices[:batch_size]

    branch_inputs_train_batch = branch_inputs_train[batch_indices]
    outputs_train_batch = outputs_train[batch_indices]

    # Update the parameters and optimizer state
    params, opt_state, loss = update(
        params=params,
        branch_inputs=branch_inputs_train_batch,
        trunk_inputs=trunk_inputs_train,
        gt_outputs=outputs_train_batch,
        opt_state=opt_state
    )

    # Evaluate on the full test set after every step
    test_mse_loss = loss_fn(params=params,
                            branch_inputs=branch_inputs_test,
                            trunk_inputs=trunk_inputs_test,
                            gt_outputs=outputs_test)

    # Record plain Python floats instead of keeping ~2 * num_epochs device
    # arrays alive in the history lists.
    loss = float(loss)
    test_mse_loss = float(test_mse_loss)
    training_loss_history.append(loss)
    test_loss_history.append(test_mse_loss)

    # Save the params of the best model encountered till now
    if test_mse_loss < min_test_loss:
        best_params = params
        save_model_params(best_params, path=filepath, filename='model_params_best.pkl')
        min_test_loss = test_mse_loss

    # Print the train and test loss every 1000 steps. Implicit string
    # concatenation keeps the message on one line; a backslash continuation
    # inside the f-string would embed the next line's indentation.
    if epoch % 1000 == 0:
        print(f"Epoch: {epoch}, training_loss_MSE: {loss}, test_loss_MSE: {test_mse_loss}, "
              f"best_test_loss_MSE: {min_test_loss}")

In [None]:
# Visualize the train and test loss histories. The x-axes are derived from the
# recorded histories, so the plot also works after an interrupted training run.
train_steps = np.arange(len(training_loss_history))
test_steps = np.arange(len(test_loss_history))

fig, ax = plt.subplots(dpi=130)
ax.semilogy(train_steps, training_loss_history, label="Train loss")
ax.semilogy(test_steps, test_loss_history, label="Test loss")

ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")

ax.tick_params(which='major', axis='both', direction='in', length=6)
ax.tick_params(which='minor', axis='both', direction='in', length=3.5)
ax.minorticks_on()

ax.grid(alpha=0.3)
ax.legend(loc='best')

# The directory only exists otherwise if a checkpoint was saved during training
os.makedirs(filepath, exist_ok=True)
fig.savefig(filepath + "/loss_plot.jpeg", dpi=800)
plt.show()

In [None]:
# Persist both loss curves as .npy arrays inside the run directory
np.save(file=filepath + "/Train_loss.npy", arr=training_loss_history)
np.save(file=filepath + "/Test_loss.npy", arr=test_loss_history)

In [None]:
# Predictions
import sklearn
from sklearn import metrics
filepath = 'DeepONet_full_rollout'
# Import the best model saved after full training
best_params = load_model_params(path=filepath, filename='model_params_best.pkl')

# Ground truth output: full trajectories (training window + extrapolation)
outputs = g_u_new
n_samples, n_time_full, n_space = outputs.shape

# Branch input is the initial condition (time step 0)
branch_input_new = outputs[:, 0, :]

# Query grid over the FULL time horizon. It must use the same t-range (0 to 5)
# and per-step spacing as the training grid; otherwise the trunk net is
# evaluated at coordinates that do not correspond to the time steps it was
# trained on.
tspan = jnp.linspace(0, 5, n_time_full)
xspan = jnp.linspace(0, 1, n_space)

# Create meshgrid (t, x) for trunk net, time-major like the training grid
t_new, x_new = jnp.meshgrid(tspan, xspan, indexing='ij')
grid = jnp.stack([t_new.ravel(), x_new.ravel()], axis=1)  # (n_time_full*n_space, 2)

# Perform inferencing
predictions_outputs_new = model.apply(best_params, branch_input_new, grid)
predictions_outputs_new = predictions_outputs_new.reshape(n_samples, n_time_full, n_space)


def _plot_field(ax, xs, ts, field, title, cmap):
    """Draw one filled-contour panel of a (t, x) field with colorbar and labels."""
    contour = ax.contourf(xs, ts, field, levels=20, cmap=cmap)
    cbar = ax.figure.colorbar(contour, ax=ax)
    cbar.ax.tick_params(labelsize=12)
    ax.set_xlabel("x", fontsize=14)
    ax.set_ylabel("t", fontsize=14)
    ax.tick_params(axis='both', labelsize=12)
    ax.set_title(title, fontsize=16)


# Randomly select a few distinct samples from the TEST split (test_indices is
# produced by the train/test split cell) so the plots show held-out data.
num_plot_samples = min(3, len(test_indices))
random_samples = np.random.choice(np.asarray(test_indices), size=num_plot_samples, replace=False)

for i in random_samples:

    prediction_i = predictions_outputs_new[i, :, :]
    target_i = outputs[i, :, :]

    error_i = np.abs(prediction_i - target_i)

    fig, axes = plt.subplots(1, 3, figsize=(12, 3))
    panels = (
        (prediction_i, "Predicted", 'jet'),
        (target_i, "Actual", 'jet'),
        (error_i, "Error", 'Wistia'),
    )
    for ax, (field, title, cmap) in zip(axes, panels):
        _plot_field(ax, xspan, tspan, field, title, cmap)

    fig.tight_layout()
    fig.savefig(filepath + f"/Contour_plots_sidx_{i}.jpeg", dpi=800)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures in the loop

In [None]:
# Relative L2 error at every time step, pooled over ALL samples (train and test):
#   err[k] = ||pred[:, k, :] - true[:, k, :]||_F / ||true[:, k, :]||_F
# The number of time steps comes from the prediction array instead of a
# hard-coded 201.
num_time_steps = predictions_outputs_new.shape[1]
diff_norms = jnp.linalg.norm(predictions_outputs_new - outputs, axis=(0, 2))
ref_norms = jnp.linalg.norm(outputs, axis=(0, 2))
auto_reg_error = list(diff_norms / ref_norms)  # one entry per time step
assert len(auto_reg_error) == num_time_steps

In [None]:
# Persist the per-time-step relative L2 errors for comparison with the TI-DON approach
np.save(file=filepath + "/Auto_reg_error_full_rollout.npy", arr=auto_reg_error)

In [None]:
# Persist predictions and ground truth side by side (same sample/time/space layout)
np.save(file=filepath + "/u_pred.npy", arr=predictions_outputs_new)
np.save(file=filepath + "/u_actual.npy", arr=outputs)