In [None]:
#Import the necessary libraries
import os, sys, pickle
import jax, jaxlib
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy
from scipy.io import loadmat
import numpy as np

import torch
import flax
from flax import linen as nn
import optax
from sklearn.model_selection import train_test_split
from typing import Callable, Sequence

from tqdm import tqdm

In [None]:
#Set the random seed and create a JAX PRNG Key
# `seed` is used for NumPy's global RNG and for the explicit JAX PRNG key below.
seed = 42
np.random.seed(seed)
# JAX has no global RNG state; this key is later passed to model.init.
key = jax.random.PRNGKey(seed)

In [None]:
#Load the 2D Burgers' dataset
# NOTE(review): torch.load unpickles the file, so only load dataset files from a trusted source.
dataset = torch.load("Burgers_equation_2D_scalar.pt")

#Consider only the first 1000 samples due to memory constraints.
#The slice is taken BEFORE the conversion to JAX arrays so that the full dataset
#is never copied into a JAX array just to be thrown away again.
num_samples = 1000
inputs = jnp.array(dataset['input_samples'][:num_samples])
#Full dataset (per the original notebook): Ns = 5000, Nt = 101, Nx = 32, Ny = 32
outputs = jnp.array(dataset['output_samples'][:num_samples])

In [None]:
#Delete the dataset and free up memory
# `inputs` and `outputs` were already extracted from the loaded object, so it is no longer needed here.
del dataset

In [None]:
ns, nt, nx, ny = outputs.shape
print(f"ns: {ns}, nt: {nt}, nx: {nx}, ny: {ny}")

#Build one-step-ahead training pairs from the trajectories:
#   input  snapshots: u_k      for k = init_timestep, ..., end_timestep - 1
#   output snapshots: u_(k+1)  (the state one timestep later)
#With init_timestep = 0 and end_timestep = 33 the inputs are (u0, ..., u32)
#and the outputs are (u1, ..., u33).
init_timestep = 0
end_timestep = 33

#JAX clamps out-of-range indices instead of raising, so check the window explicitly.
assert 0 <= init_timestep < end_timestep < nt, "timestep window must lie inside the trajectory"

#Concatenate all per-timestep blocks in ONE call (growing the array with vstack inside
#a loop re-copies everything accumulated so far on every iteration).
#Row ordering: all samples at the first timestep, then all samples at the next, and so on.
timesteps = range(init_timestep, end_timestep)
input_data_NN = jnp.concatenate([outputs[:, k, :, :] for k in timesteps], axis = 0)
output_data_NN = jnp.concatenate([outputs[:, k+1, :, :] for k in timesteps], axis = 0)

In [None]:
#Flatten the two spatial axes of the targets: (N, nx, ny) -> (N, nx*ny)
#The branch inputs (input_data_NN) keep their (N, nx, ny) layout.
output_data_NN = jnp.reshape(output_data_NN, (output_data_NN.shape[0], -1))

In [None]:
#Split the (input, output) pairs into train and test sets.
#test_size = 0.2 holds out 20% of the pairs; random_state is fixed so the split is repeatable.
(input_data_NN_train, input_data_NN_test,
 output_data_NN_train, output_data_NN_test) = train_test_split(
    input_data_NN,
    output_data_NN,
    test_size = 0.2,
    random_state = 42,
)

In [None]:
#Freeing memory by deleting input_data_NN and output_data_NN
# The rest of the notebook only uses the *_train / *_test splits created from them.
del input_data_NN, output_data_NN

In [None]:
#Utility class for defining the branch network
class branch_net(nn.Module):
    """Branch encoder: two conv + pooling stages followed by a Dense stack.

    The input is a field without a channel axis; a trailing channel axis is
    added before the convolutions. The result is a flat vector whose length
    is the last entry of ``layer_sizes``.
    """

    # Widths of the Dense layers applied after the conv stack; the last entry is the output width.
    layer_sizes: Sequence[int]
    # Nonlinearity applied after every Dense layer except the last one.
    activation: Callable

    @nn.compact
    def __call__(self, x):
        kernel_init = nn.initializers.glorot_normal()

        # Add a trailing channel axis: (..., nx, ny) -> (..., nx, ny, 1)
        features = x[..., jnp.newaxis]

        # Stage 1: 3x3 convolution -> ReLU -> 2x2 max pooling
        features = nn.Conv(features=64, kernel_size=(3, 3), strides=1, padding="SAME")(features)
        features = nn.max_pool(nn.relu(features), window_shape=(2, 2), strides=(2, 2), padding="SAME")

        # Stage 2: 2x2 convolution -> ReLU -> 2x2 average pooling
        features = nn.Conv(features=64, kernel_size=(2, 2), strides=1, padding="SAME")(features)
        features = nn.avg_pool(nn.relu(features), window_shape=(2, 2), strides=(2, 2), padding="SAME")

        # flatten() collapses EVERY axis into one, so the module is meant to be
        # applied to a single sample at a time (DeepONet wraps it in jax.vmap).
        features = features.flatten()

        # Dense stack: activation after each hidden layer, none after the output layer
        hidden_widths = self.layer_sizes[:-1]
        output_width = self.layer_sizes[-1]
        for width in hidden_widths:
            features = self.activation(nn.Dense(width, kernel_init=kernel_init)(features))
        return nn.Dense(output_width, kernel_init=kernel_init)(features)

In [None]:
#Utility class for defining the trunk network
class trunk_net(nn.Module):
    """Trunk encoder: a plain Dense stack applied to query coordinates.

    The activation is applied after every layer, including the last one.
    """

    # Width of each Dense layer, in order; the last entry is the output width.
    trunk_layer_config: Sequence[int]
    # Nonlinearity applied after every Dense layer.
    activation: Callable

    @nn.compact
    def __call__(self, x):
        kernel_init = nn.initializers.glorot_normal()

        hidden = x
        for width in self.trunk_layer_config:
            hidden = self.activation(nn.Dense(width, kernel_init=kernel_init)(hidden))
        return hidden

In [None]:
#Define the DeepONet model
class DeepONet(nn.Module):
    """DeepONet combining a branch network and a trunk network.

    The output is the inner product of the branch and trunk latent vectors:
    for branch latents of shape (i, k) and trunk latents of shape (j, k) the
    result has shape (i, j), i.e. one value per (input sample, query point).
    Both networks must therefore end in the same latent width k.
    """

    # Dense layer widths passed to branch_net (last entry = latent width)
    branch_net_config: Sequence[int]
    # Dense layer widths passed to trunk_net (last entry = latent width)
    trunk_net_config: Sequence[int]

    def setup(self):
        """Create the branch and trunk sub-networks, both with tanh activations."""

        self.branch_net = branch_net(self.branch_net_config, nn.activation.tanh)
        self.trunk_net = trunk_net(self.trunk_net_config, nn.activation.tanh)


    def __call__(self, x_branch, x_trunk):
        """Evaluate the operator for every (branch sample, trunk query point) pair.

        Args:
            x_branch: batch of input functions; axis 0 indexes the samples.
            x_trunk: batch of query points; axis 0 indexes the points.

        Returns:
            Array of shape (x_branch.shape[0], x_trunk.shape[0]).
        """
        
        #Vectorize over multiple samples of input functions
        # (branch_net flattens all axes of its input, so it must see one sample at a time)
        # NOTE(review): this is raw jax.vmap over a bound Flax submodule; Flax also provides
        # the lifted nn.vmap for modules - confirm this pattern is supported by the installed version.
        branch_outputs = jax.vmap(self.branch_net, in_axes = 0)(x_branch)
        
        #Vectorize over multiple query points
        trunk_outputs = jax.vmap(self.trunk_net, in_axes = 0)(x_trunk)       
        
        # Contract the shared latent axis k: (i, k) x (j, k) -> (i, j)
        inner_product = jnp.einsum('ik,jk->ij', branch_outputs, trunk_outputs)

        return inner_product

In [None]:
#Spatial query points for the trunk network: nx and ny evenly spaced coordinates on [0, 1]
xspan = jnp.linspace(0, 1, nx)
yspan = jnp.linspace(0, 1, ny)

#Tensor-product grid ('ij' indexing) of spatial coordinates only,
#flattened into an (nx*ny, 2) array whose columns are (x, y)
x, y = jnp.meshgrid(xspan, yspan, indexing = 'ij')
grid = jnp.stack((x.flatten(), y.flatten()), axis = 1)

In [None]:
#Training tensors for the DeepONet; the trunk input is the same spatial grid for every sample
branch_inputs_train = input_data_NN_train
trunk_inputs_train = grid
outputs_train = output_data_NN_train

#Report the shapes as a quick sanity check
for _label, _arr in (
    ("Shape of branch inputs train: ", branch_inputs_train),
    ("Shape of trunk inputs train: ", trunk_inputs_train),
    ("Shape of outputs train: ", outputs_train),
    ("Shape of grid: ", grid),
):
    print(_label, _arr.shape)

In [None]:
#Test tensors for the DeepONet; the trunk input is the same spatial grid as for training
branch_inputs_test = input_data_NN_test
trunk_inputs_test = grid
outputs_test = output_data_NN_test

#Report the shapes as a quick sanity check
for _label, _arr in (
    ("Shape of branch inputs test: ", branch_inputs_test),
    ("Shape of trunk inputs test: ", trunk_inputs_test),
    ("Shape of outputs test: ", outputs_test),
    ("Shape of grid: ", grid),
):
    print(_label, _arr.shape)

In [None]:
#DeepONet settings

#Width of the latent vector produced by both the branch and the trunk network
latent_vector_size = 100

#Dense layer widths; both configurations end in the shared latent width
branch_network_layer_sizes = [256, 128, latent_vector_size]
trunk_network_layer_sizes = [128, 128, 128, 128, latent_vector_size]

#Instantiate the DeepONet model
model = DeepONet(
    branch_net_config = branch_network_layer_sizes,
    trunk_net_config = trunk_network_layer_sizes,
)

#Jitted forward pass: model_fn(params, x_branch, x_trunk)
model_fn = jax.jit(model.apply)

In [None]:
#Utility function for saving the model params
def save_model_params(params, path, filename):
    """Pickle ``params`` to ``<path>/<filename>``, creating ``path`` if needed.

    Args:
        params: any picklable object (here: the model parameter tree).
        path: output directory; an empty string means the current directory.
        filename: name of the pickle file inside ``path``.
    """
    # exist_ok=True makes the creation idempotent and removes the race between
    # checking for the directory and creating it. os.makedirs("") raises, so an
    # empty path (current directory) is skipped.
    if path:
        os.makedirs(path, exist_ok=True)

    save_path = os.path.join(path, filename)
    with open(save_path, 'wb') as f:
        pickle.dump(params, f)

#Utility function for loading the model params
def load_model_params(path, filename):
    """Return the object unpickled from ``<path>/<filename>``.

    NOTE: pickle.load can execute arbitrary code, so only load files that
    this notebook (or another trusted source) produced.
    """
    with open(os.path.join(path, filename), 'rb') as handle:
        return pickle.load(handle)

In [None]:
@jax.jit
def loss_fn(params, branch_inputs, trunk_inputs, gt_outputs, dt=0.01):
    """One-step RK4 rollout loss.

    The model output is treated as the time derivative of the state. The
    current state is advanced by one classical RK4 step of size ``dt`` and the
    result is compared with the ground-truth next state.

    Args:
        params: model parameters passed to ``model_fn``.
        branch_inputs: current states, shape (batch, nx, ny).
        trunk_inputs: query points passed to the trunk network.
        gt_outputs: ground-truth next states, shape (batch, nx*ny).
        dt: time-step size of the RK4 update.

    Returns:
        Scalar mean squared error between predicted and true next states.
    """
    def rate_of_change(state):
        # Model output is (batch, nx*ny); bring it back to the (batch, nx, ny) layout of the state
        flat = model_fn(params, state, trunk_inputs)
        return flat.reshape(flat.shape[0], nx, ny)

    # Four RK4 stage evaluations
    stage1 = rate_of_change(branch_inputs)
    stage2 = rate_of_change(branch_inputs + 0.5 * dt * stage1)
    stage3 = rate_of_change(branch_inputs + 0.5 * dt * stage2)
    stage4 = rate_of_change(branch_inputs + dt * stage3)

    # Weighted RK4 combination gives the predicted next state, (batch, nx, ny)
    advanced = branch_inputs + (dt / 6) * (stage1 + 2 * stage2 + 2 * stage3 + stage4)

    # Flatten to (batch, nx*ny) so it lines up with gt_outputs
    residual = advanced.reshape(advanced.shape[0], nx*ny) - gt_outputs
    return jnp.mean(jnp.square(residual))

In [None]:
@jax.jit
def update(params, branch_inputs, trunk_inputs, gt_outputs, opt_state):
    """Run one optimizer step and return (new params, new optimizer state, loss).

    The loss and its gradient w.r.t. ``params`` come from ``loss_fn``; the
    gradient is turned into parameter updates by the global ``optimizer``.
    """
    loss_and_grad = jax.value_and_grad(loss_fn)
    loss, grads = loss_and_grad(params, branch_inputs, trunk_inputs, gt_outputs)

    param_updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, param_updates)

    return new_params, new_opt_state, loss

In [None]:
#Freeing memory by deleting inputs and outputs
# ns, nt, nx, ny were already read from outputs.shape; the arrays themselves are reloaded before inference.
del inputs, outputs

In [None]:
# Initialize model parameters from one branch sample and one trunk query point
params = model.init(key, branch_inputs_train[:1], trunk_inputs_train[:1])

#Adam optimizer for DeepONet with an exponentially decaying learning rate
lr_scheduler = optax.schedules.exponential_decay(
    init_value=1e-3,
    transition_steps=5000,
    decay_rate=0.95,
)
optimizer = optax.adam(learning_rate=lr_scheduler)
opt_state = optimizer.init(params)

#Training configuration
num_epochs = 120_000
batch_size = 64

#Directory that receives the checkpoints, plots and saved arrays
filepath = 'TI-DON_2D_Burgers'

#Bookkeeping for the training loop
training_loss_history = []
test_loss_history = []
min_test_loss = jnp.inf

In [None]:
# NOTE: each "epoch" below is ONE optimizer step on one random mini-batch,
# not a full pass over the training set.
for epoch in tqdm(range(num_epochs), desc="Training Progress"):

    #Perform mini-batching: shuffle all training indices with a key derived from the
    #step number (so the batch sequence is repeatable) and keep the first batch_size
    shuffled_indices = jax.random.permutation(jax.random.PRNGKey(epoch), branch_inputs_train.shape[0])
    batch_indices = shuffled_indices[:batch_size]

    branch_inputs_train_batch = branch_inputs_train[batch_indices]
    outputs_train_batch = outputs_train[batch_indices]

    # Update the parameters and optimizer state
    params, opt_state, loss = update(
        params=params,
        branch_inputs=branch_inputs_train_batch,
        trunk_inputs=trunk_inputs_train,
        gt_outputs=outputs_train_batch,
        opt_state=opt_state
    )

    #Keep a track of the train loss
    training_loss_history.append(loss)

    #Evaluate on the full test set after every step (used for model selection below)
    test_mse_loss = loss_fn(params = params,
                            branch_inputs = branch_inputs_test,
                            trunk_inputs = trunk_inputs_test,
                            gt_outputs = outputs_test)
    test_loss_history.append(test_mse_loss)

    #Save the params of the best model encountered till now
    #(a NaN test loss never compares as smaller, so it never overwrites the checkpoint)
    if test_mse_loss < min_test_loss:
        best_params = params
        save_model_params(best_params, path = filepath, filename = 'model_params_best.pkl')
        min_test_loss = test_mse_loss

    #Print the train and test loss history every 1000 epochs.
    #Adjacent f-strings are used instead of a backslash continuation inside the literal,
    #which would embed the next line's indentation into the printed message.
    if epoch % 1000 == 0:
        print(f"Epoch: {epoch}, training_loss_MSE: {loss}, test_loss_MSE: {test_mse_loss}, "
              f"best_test_loss_MSE: {min_test_loss}")

In [None]:
#Visualize the train and test loss histories
#The x-axes are derived from the recorded histories themselves rather than from the
#leaked loop variable `epoch`, so the cell also works if training was interrupted.
plt.figure(dpi = 130)
plt.semilogy(np.arange(len(training_loss_history)), training_loss_history, label = "Train loss")
plt.semilogy(np.arange(len(test_loss_history)), test_loss_history, label = "Test loss")

plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.tick_params(which = 'major', axis = 'both', direction = 'in', length = 6)
plt.tick_params(which = 'minor', axis = 'both', direction = 'in', length = 3.5)
plt.minorticks_on()

plt.grid(alpha = 0.3)
plt.legend(loc = 'best')

#Make sure the output directory exists before writing the figure
os.makedirs(filepath, exist_ok = True)
plt.savefig(filepath + "/loss_plot.jpeg", dpi = 800)
plt.show()

In [None]:
#Persist both loss curves as .npy files inside the results directory
for _suffix, _history in (
    ("/Train_loss.npy", training_loss_history),
    ("/Test_loss.npy", test_loss_history),
):
    np.save(filepath + _suffix, _history)

In [None]:
#Do one AB2/AM3 inference step
@jax.jit
def inference_ab(u_curr, u_prev, trunk_inputs_test, dt=0.01):
    """Advance the state by one Adams-Bashforth / Adams-Moulton predictor-corrector step.

    The model (evaluated with the global ``best_params``) supplies the rate of
    change of the state.

    Args:
        u_curr: current states, shape (batch, nx, ny).
        u_prev: states one step earlier, same shape as ``u_curr``.
        trunk_inputs_test: query points passed to the trunk network.
        dt: time-step size.

    Returns:
        Corrected next states, same shape as ``u_curr``.
    """
    def rate_of_change(state):
        # Model output is (batch, nx*ny); reshape so it broadcasts against the state
        flat = model_fn(best_params, state, trunk_inputs_test)
        return flat.reshape(flat.shape[0], nx, ny)

    rate_curr = rate_of_change(u_curr)
    rate_prev = rate_of_change(u_prev)

    # Predictor: Adams-Bashforth extrapolation from the two most recent rates
    u_predicted = u_curr + dt * (1.5 * rate_curr - 0.5 * rate_prev)

    # Corrector: Adams-Moulton update using the rate at the predicted state
    rate_predicted = rate_of_change(u_predicted)
    return u_curr + dt * (5/12 * rate_predicted + 8/12 * rate_curr - 1/12 * rate_prev)

In [None]:
#Do one RK4 inference step
@jax.jit
def inference_rk(u_curr, trunk_inputs_test, dt = 0.01):
    """Advance the state by one classical 4th-order Runge-Kutta step.

    The model supplies the rate of change of the state. It is evaluated with
    the global ``best_params`` (the best checkpoint), consistent with
    ``inference_ab``, rather than with the last training iterate ``params``.

    Args:
        u_curr: current states, shape (batch, nx, ny).
        trunk_inputs_test: query points passed to the trunk network.
        dt: time-step size.

    Returns:
        Predicted next states, same shape as ``u_curr``.
    """
    def rate_of_change(state):
        # Model output is (batch, nx*ny); reshape so it broadcasts against the state
        flat = model_fn(best_params, state, trunk_inputs_test)
        return flat.reshape(flat.shape[0], nx, ny)

    # Four RK4 stage evaluations
    k1 = rate_of_change(u_curr)
    k2 = rate_of_change(u_curr + 0.5 * dt * k1)
    k3 = rate_of_change(u_curr + 0.5 * dt * k2)
    k4 = rate_of_change(u_curr + dt * k3)

    # Weighted RK4 combination gives the next state, (batch, nx, ny)
    u_pred_next = u_curr + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

    return u_pred_next

In [None]:
#Utility function for performing inference over all timesteps
def run_inference(initial_u, trunk_inputs_test, n_steps, method, dt=0.01):
    """Roll the model out autoregressively from ``initial_u``.

    Args:
        initial_u: initial states, shape (batch, nx, ny).
        trunk_inputs_test: query points passed to the trunk network.
        n_steps: number of stored time levels, including the initial one.
        method: "AB" (inference_ab predictor-corrector) or "RK" (inference_rk).
        dt: time-step size forwarded to the integrator.

    Returns:
        NumPy array of shape (batch, n_steps, nx, ny); index 0 along axis 1
        holds ``initial_u``.

    Raises:
        ValueError: if ``method`` is neither "AB" nor "RK".
    """
    if method not in ("AB", "RK"):
        # Previously an unknown method silently returned an all-zero rollout
        raise ValueError(f"Unknown method {method!r}; expected 'AB' or 'RK'")

    # Size the buffer from the arguments (not the notebook globals ns/nt) so the
    # function works for any batch size and rollout length. At least one time
    # level is kept so the initial state can always be stored.
    num_levels = max(int(n_steps), 1)
    u_states = np.zeros(shape = (initial_u.shape[0], num_levels) + tuple(initial_u.shape[1:]))
    u_states[:, 0, :, :] = initial_u

    # The multistep method needs two past states; the first step is bootstrapped
    # by using the initial state for both.
    u_prev = initial_u
    u_curr = initial_u

    for i in range(1, n_steps):
        if method == "AB":
            u_next = inference_ab(u_curr, u_prev, trunk_inputs_test, dt)
        else:
            u_next = inference_rk(u_curr, trunk_inputs_test, dt)

        # Store the predicted state and shift the history window
        u_states[:, i, :, :] = u_next
        u_prev, u_curr = u_curr, u_next

    return u_states

In [None]:
# Load the best model parameters
best_params = load_model_params(path=filepath, filename='model_params_best.pkl')

#Reload the relevant datasets afresh for performing inference
# NOTE(review): torch.load unpickles the file, so only load dataset files from a trusted source.
dataset = torch.load("Burgers_equation_2D_scalar.pt")

#Consider first 1000 samples due to memory constraints.
#The slice is taken before the conversion so the full dataset is never turned into a JAX array.
inputs = jnp.array(dataset['input_samples'][:1000])
outputs = jnp.array(dataset['output_samples'][:1000])

del dataset

#Time integrator used for the rollout: "AB" or "RK"
method = "AB"

#Start with u(t=0, x, y)
u_curr = outputs[:, 0, :, :]

#Perform inference
u_pred = run_inference(u_curr, trunk_inputs_test, n_steps=nt, method = method)

In [None]:
#Randomly select distinct samples to visualise.
#NOTE: the indices are drawn from every trajectory in u_pred, not only from held-out data.
#`replace` must be the boolean False: the string 'False' is truthy and would sample WITH replacement.
num_plot_samples = min(3, u_pred.shape[0])
indices = np.random.choice(np.arange(u_pred.shape[0]), size = num_plot_samples, replace = False)

x_test = jnp.linspace(0, 1, nx)
y_test = jnp.linspace(0, 1, ny)
t_test = jnp.linspace(0, 1, nt)

#Timestep indices to plot (-1 = last stored timestep)
t_query = [25, 50, -1]

for idx in indices:

    for t in t_query:
        predicted = u_pred[idx, t, :, :]
        actual = outputs[idx, t, :, :]

        #(title, field, contourf keyword arguments) for the three panels
        panels = [
            ("Predicted", predicted, dict(levels = 20, cmap = 'jet')),
            ("Actual", actual, dict(levels = 20, cmap = 'jet')),
            ("Error", jnp.abs(predicted - actual), dict(cmap = 'Wistia')),
        ]

        plt.figure(figsize = (12,3))
        for panel_number, (title, field, contour_kwargs) in enumerate(panels, start = 1):
            plt.subplot(1, 3, panel_number)
            # NOTE(review): contourf(X, Y, Z) expects Z with shape (len(Y), len(X)) while the
            # fields are stored as (nx, ny); confirm whether a transpose is needed.
            plt.contourf(x_test, y_test, field, **contour_kwargs)
            plt.xlabel("x", fontsize = 14)
            #The vertical axis is the second spatial coordinate (y_test), not time
            plt.ylabel("y", fontsize = 14)
            plt.xticks(fontsize = 12)
            plt.yticks(fontsize = 12)
            cbar = plt.colorbar()
            cbar.ax.tick_params(labelsize=12)
            plt.title(title, fontsize = 16)

        plt.suptitle(f"Sample Idx: {idx}, Timestep: {t}")

        plt.tight_layout()

        #Include the (non-negative) timestep in the file name; otherwise the figures for
        #the different timesteps of one sample overwrite each other.
        t_label = t % u_pred.shape[1]
        plt.savefig(filepath + f"/Contour_plots_sidx_{idx}_t_{t_label}_{method}.jpeg", dpi = 800)
        plt.show()

In [None]:
#Relative L2 error at every timestep, showing how the autoregressive error accumulates.
#Each entry is ||u_pred - outputs|| / ||outputs|| taken over all samples and grid points of that timestep.
num_time_steps = nt
auto_reg_error = [
    jnp.linalg.norm(u_pred[:,step,:,:] - outputs[:,step,:,:])/jnp.linalg.norm(outputs[:,step,:,:])
    for step in range(num_time_steps)
]

plt.plot(jnp.arange(num_time_steps), auto_reg_error)
plt.xlabel("Timesteps")
plt.ylabel("Relative L2 error")
plt.grid()
plt.show()

In [None]:
#Save the auto_reg_error array for comparing with TI approach
# The file name embeds `method`, so each integrator's error curve is written to its own file.
np.save(filepath + f"/Auto_reg_error_with_TI-DON_{method}.npy", auto_reg_error)

In [None]:
#Saving the u_pred and ground truth output arrays for separate postprocessing

# Set `save` to False to skip writing the (large) prediction and ground-truth arrays.
save = True
if save:
    # Both file names embed `method`; `outputs` is the ground truth loaded for inference.
    np.save(filepath + f"/u_pred_{method}.npy", u_pred)
    np.save(filepath + f"/actual_{method}.npy", outputs)