In [None]:
#Importing the necessary libraries
import os, sys, pickle
import jax, jaxlib
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy
from scipy.io import loadmat
import numpy as np

import flax
from flax import linen as nn
import optax
from sklearn.model_selection import train_test_split
from typing import Callable, Sequence

from tqdm import tqdm

import time

In [None]:
# Read the MATLAB file with the Burgers' equation solutions; scipy returns a dict of arrays
burgers = loadmat(file_name="Burger.mat")
# The solution array used throughout the notebook is stored under the "output" key
output = burgers["output"]
# Display the available keys as this cell's output
burgers.keys()

In [None]:
'''
Create training pairs for one-step prediction: the input snapshots are
[(u0, u1, ..., u49)] and the targets are [(u1, u2, ..., u50)],
i.e. every input snapshot u_t is paired with the next snapshot u_{t+1}.
'''

#Creating the input and output training data
init_timestep = 0
end_timestep = 50

# Rows are ordered time-major: (t=init, all samples), (t=init+1, all samples), ...
# This is the same ordering the previous loop of repeated jnp.vstack calls produced, but built
# in one pass instead of re-copying the growing array on every iteration (quadratic cost).
num_spatial_points = output.shape[-1]
input_data_NN = jnp.asarray(
    output[:, init_timestep:end_timestep, :].transpose(1, 0, 2).reshape(-1, num_spatial_points))
output_data_NN = jnp.asarray(
    output[:, init_timestep + 1:end_timestep + 1, :].transpose(1, 0, 2).reshape(-1, num_spatial_points))

In [None]:
# Hold out 20% of the (u_t, u_{t+1}) pairs for testing, with a fixed seed for reproducibility.
# NOTE(review): the split is over individual snapshot pairs, so each trajectory contributes
# pairs to both the train and test sets.
(input_data_NN_train, input_data_NN_test,
 output_data_NN_train, output_data_NN_test) = train_test_split(
    input_data_NN, output_data_NN, test_size=0.2, random_state=42)
input_data_NN_train.shape, input_data_NN_test.shape, output_data_NN_train.shape, output_data_NN_test.shape

In [None]:
# Branch network of the DeepONet: encodes each sampled input function into a latent vector
class branch_net(nn.Module):
    """Fully connected branch network.

    Every entry of ``branch_layer_config`` except the last becomes a Dense layer followed by
    ``activation``. This includes the first entry, which is therefore a hidden-layer width,
    not the input size. The last entry is the width of a final, linear Dense layer.
    """
    branch_layer_config: Sequence[int]  # Dense layer widths, in order
    activation: Callable                # nonlinearity applied after each hidden layer

    @nn.compact
    def __call__(self, x):
        kernel_init = nn.initializers.glorot_normal()
        hidden_widths = self.branch_layer_config[:-1]
        latent_width = self.branch_layer_config[-1]

        h = x
        for width in hidden_widths:
            h = self.activation(nn.Dense(width, kernel_init=kernel_init)(h))
        # Output layer has no activation
        return nn.Dense(latent_width, kernel_init=kernel_init)(h)

In [None]:
# Trunk network of the DeepONet: encodes each query coordinate into a latent vector
class trunk_net(nn.Module):
    """Fully connected trunk network.

    Every entry of ``trunk_layer_config`` becomes a Dense layer followed by ``activation``,
    so, unlike the branch network, the final (latent) layer is activated too.
    """
    trunk_layer_config: Sequence[int]  # Dense layer widths, in order; last is the latent size
    activation: Callable               # nonlinearity applied after every layer

    @nn.compact
    def __call__(self, x):
        kernel_init = nn.initializers.glorot_normal()
        h = x
        for width in self.trunk_layer_config:
            dense_out = nn.Dense(width, kernel_init=kernel_init)(h)
            h = self.activation(dense_out)
        return h

In [None]:
# DeepONet: combines branch (input-function) and trunk (query-point) embeddings
class DeepONet(nn.Module):
    """DeepONet whose output is the inner product of branch and trunk latent vectors.

    Both sub-networks use tanh activations. For ``x_branch`` with a leading axis of
    ``num_functions`` and ``x_trunk`` with a leading axis of ``num_points``, the result has
    shape (num_functions, num_points).
    """
    branch_net_config: Sequence[int]  # layer widths for branch_net
    trunk_net_config: Sequence[int]   # layer widths for trunk_net

    def setup(self):
        # Attribute names determine the parameter-tree keys, so they must stay as they are
        self.branch_net = branch_net(self.branch_net_config, nn.activation.tanh)
        self.trunk_net = trunk_net(self.trunk_net_config, nn.activation.tanh)

    def __call__(self, x_branch, x_trunk):
        # One latent vector per input function (mapped over the leading axis)
        branch_latent = jax.vmap(self.branch_net, in_axes=0)(x_branch)
        # One latent vector per query point (mapped over the leading axis)
        trunk_latent = jax.vmap(self.trunk_net, in_axes=0)(x_trunk)
        # Dot product over the shared latent axis for every (function, point) pair
        return jnp.einsum('ik,jk->ij', branch_latent, trunk_latent)

In [None]:
# Auxiliary network that produces the learnable RK4 stage weights
class LearnableRK4(nn.Module):
    """Small MLP mapping a state to four RK4 stage weights.

    Two tanh hidden layers of width ``hidden_dim`` feed a 4-unit Dense layer. A softmax over
    the last axis makes the four weights positive and sum to one.
    """
    hidden_dim: int = 32  # width of both hidden layers

    @nn.compact
    def __call__(self, u_curr):
        h = u_curr
        for _ in range(2):
            h = nn.activation.tanh(nn.Dense(self.hidden_dim)(h))
        logits = nn.Dense(4)(h)
        return nn.activation.softmax(logits)

In [None]:
#Perform one adaptive rk4 step
def dynamic_rk4_step(u_curr, model_fn, params, model_rk_fn, rk_params, trunk_inputs, dt):
    """Advance ``u_curr`` by one RK4 step of size ``dt`` using learned stage weights.

    Args:
        u_curr: batch of current states; the leading axis indexes samples.
        model_fn: ``model_fn(params, u, trunk_inputs)`` giving the time derivative of ``u``.
        params: parameters passed to ``model_fn``.
        model_rk_fn: ``model_rk_fn(rk_params, u_single)`` giving the stage weights for one state.
        rk_params: parameters passed to ``model_rk_fn``.
        trunk_inputs: query points forwarded unchanged to ``model_fn``.
        dt: step size.

    Returns:
        ``(u_next, alpha)``, where ``alpha`` has one row of stage weights per sample.
    """
    # Stage weights per sample: only u_curr is mapped, rk_params is shared
    alpha = jax.vmap(model_rk_fn, in_axes=(None, 0))(rk_params, u_curr)
    # Column slices keep a trailing axis of 1 so they broadcast against the slopes
    w1, w2, w3, w4 = alpha[:, 0:1], alpha[:, 1:2], alpha[:, 2:3], alpha[:, 3:]

    half_dt = 0.5 * dt
    s1 = model_fn(params, u_curr, trunk_inputs)
    s2 = model_fn(params, u_curr + half_dt * s1, trunk_inputs)
    s3 = model_fn(params, u_curr + half_dt * s2, trunk_inputs)
    s4 = model_fn(params, u_curr + dt * s3, trunk_inputs)

    weighted_slope = w1 * s1 + w2 * s2 + w3 * s3 + w4 * s4
    return u_curr + dt * weighted_slope, alpha

In [None]:
@jax.jit
def loss_fn(params, rk_params, branch_inputs, trunk_inputs, gt_outputs, dt=0.01):
    """One-step MSE between the learned-RK4 prediction and the ground-truth next state.

    Args:
        params: DeepONet parameters (used via the global ``model_fn``).
        rk_params: stage-weight network parameters (used via the global ``model_rk_fn``).
        branch_inputs: current states u(t).
        trunk_inputs: spatial query points.
        gt_outputs: ground-truth next states u(t + dt).
        dt: step size.

    Returns:
        ``(mse_loss, alpha)``; ``alpha`` is returned as auxiliary data (stage weights per sample).
    """
    prediction, alpha = dynamic_rk4_step(branch_inputs, model_fn, params, model_rk_fn,
                                         rk_params, trunk_inputs, dt)
    squared_error = jnp.square(prediction - gt_outputs)
    return jnp.mean(squared_error), alpha

In [None]:
@jax.jit
def update(params, rk_params, branch_inputs, trunk_inputs, gt_outputs, opt_state, opt_state_rk):
    """Alternating optimisation step: DeepONet parameters first, then the RK4-weight network.

    The RK4-weight gradient is evaluated with the already-updated DeepONet parameters.

    Returns:
        ``(params, rk_params, opt_state, opt_state_rk, loss, alpha)``. ``loss`` is the loss
        before either update; ``alpha`` comes from the second loss evaluation.
    """
    batch = (branch_inputs, trunk_inputs, gt_outputs)

    # Stage 1: gradient with respect to the DeepONet parameters (argument 0)
    grad_wrt_params = jax.value_and_grad(loss_fn, argnums=0, has_aux=True)
    (loss, _), param_grads = grad_wrt_params(params, rk_params, *batch)
    param_updates, new_opt_state = optimizer.update(param_grads, opt_state)
    new_params = optax.apply_updates(params, param_updates)

    # Stage 2: gradient with respect to the RK4-weight parameters (argument 1),
    # evaluated at the DeepONet parameters produced in stage 1
    grad_wrt_rk = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)
    (_, alpha), rk_grads = grad_wrt_rk(new_params, rk_params, *batch)
    rk_updates, new_opt_state_rk = optimizer_rk.update(rk_grads, opt_state_rk)
    new_rk_params = optax.apply_updates(rk_params, rk_updates)

    return new_params, new_rk_params, new_opt_state, new_opt_state_rk, loss, alpha

In [None]:
# Training inputs. The trunk net sees only spatial coordinates: 101 evenly spaced
# points in [0, 1], shaped as a column vector (101, 1)
grid = jnp.linspace(0, 1, 101).reshape(-1, 1)
trunk_inputs_train = grid
branch_inputs_train = input_data_NN_train
outputs_train = output_data_NN_train

In [None]:
# Test-set counterparts; the trunk net reuses the same spatial grid as in training
branch_inputs_test, trunk_inputs_test, outputs_test = input_data_NN_test, grid, output_data_NN_test

In [None]:
#DeepONet settings

# Width of the latent vector emitted by both the branch and trunk networks
latent_vector_size = 60

# Layer widths. branch_net turns every entry except the last into a hidden Dense layer,
# so the leading 101 is a hidden-layer width (it matches the 101-point spatial grid).
branch_network_layer_sizes = [101, *([100] * 6), latent_vector_size]
trunk_network_layer_sizes = [*([100] * 6), latent_vector_size]

# DeepONet model and its jitted forward function
model = DeepONet(branch_network_layer_sizes, trunk_network_layer_sizes)
model_fn = jax.jit(model.apply)

# Learnable RK4 stage-weight network and its jitted forward function
model_rk = LearnableRK4()
model_rk_fn = jax.jit(model_rk.apply)

In [None]:
#Utility function for saving the model params
def save_model_params(params, path, filename):
    """Pickle ``params`` to ``os.path.join(path, filename)``, creating ``path`` if needed.

    Args:
        params: any picklable object (in this notebook, a dict of parameter pytrees).
        path: output directory, created with any missing parents. An empty string means the
            current working directory.
        filename: file name inside ``path``.
    """
    # exist_ok avoids the check-then-create race. os.makedirs("") raises, so skip it
    # for the current directory.
    if path:
        os.makedirs(path, exist_ok=True)

    save_path = os.path.join(path, filename)
    with open(save_path, 'wb') as f:
        pickle.dump(params, f)

#Utility function for loading the model params
def load_model_params(path, filename):
    """Load and return the object pickled at ``os.path.join(path, filename)``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file. Only load
    checkpoints you produced yourself.
    """
    load_path = os.path.join(path, filename)
    with open(load_path, 'rb') as f:
        return pickle.load(f)

In [None]:
# Initialize model parameters
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)

# DeepONet params are initialised from `key`, the RK4 stage-weight network from `subkey`
params = model.init(key, branch_inputs_train, trunk_inputs_train)
rk_params = model_rk.init(subkey, branch_inputs_train)

# Optimizer setup: both networks use Adam whose learning rate starts at 1e-3 and decays by
# a factor of 0.95 per 5000 steps. The schedule is a pure function of the step count,
# so a single schedule object is shared by both optimizers.
lr_scheduler = optax.schedules.exponential_decay(init_value=1e-3, transition_steps=5000, decay_rate=0.95)

optimizer = optax.adam(learning_rate=lr_scheduler)       # DeepONet
opt_state = optimizer.init(params)

optimizer_rk = optax.adam(learning_rate=lr_scheduler)    # learnable RK4 network
opt_state_rk = optimizer_rk.init(rk_params)

# Training-loop bookkeeping
training_loss_history, test_loss_history = [], []
num_epochs = 100_000
batch_size = 256
min_test_loss = jnp.inf

# Output directory for checkpoints, plots and saved arrays
filepath = 'TI-DON_Burgers_learnableRK'

In [None]:
alpha_lst = []  # stage weights returned by `update`, one array per epoch

# Mini-batches must be drawn from the *training* rows. Permuting range(input_data_NN.shape[0])
# (train + test rows) put roughly a fifth of every batch's indices past the end of
# branch_inputs_train. Those indices are clamped to the last row for JAX arrays (NumPy would
# raise IndexError), which silently over-weights a single sample.
num_train_samples = branch_inputs_train.shape[0]

for epoch in tqdm(range(num_epochs)):

    #Perform mini-batching: fresh permutation per epoch (seeded by the epoch number), keep the first batch_size
    shuffled_indices = jax.random.permutation(jax.random.PRNGKey(epoch), num_train_samples)
    batch_indices = shuffled_indices[:batch_size]

    branch_inputs_train_batch = branch_inputs_train[batch_indices]
    outputs_train_batch = outputs_train[batch_indices]

    # Update the parameters and optimizer state
    params, rk_params, opt_state, opt_state_rk, loss, alpha = update(
        params=params,
        rk_params=rk_params,
        branch_inputs=branch_inputs_train_batch,
        trunk_inputs=trunk_inputs_train,
        gt_outputs=outputs_train_batch,
        opt_state=opt_state,
        opt_state_rk=opt_state_rk
    )
    #Keep track of the training loss and the stage weights.
    # NOTE(review): this keeps one per-batch alpha array per epoch for the whole run; consider
    # storing a summary (e.g. the batch mean) if memory becomes a problem.
    training_loss_history.append(loss)
    alpha_lst.append(alpha)

    #Evaluate on the full test split every epoch
    test_mse_loss, _ = loss_fn(params = params,
                            rk_params = rk_params,
                            branch_inputs = branch_inputs_test,
                            trunk_inputs = trunk_inputs_test,
                            gt_outputs = outputs_test)
    test_loss_history.append(test_mse_loss)

    #Save the params of the best model encountered till now
    if test_mse_loss < min_test_loss:
        best_params = {"deeponet_params": params, "rk_params": rk_params}
        save_model_params(best_params, path = filepath, filename = 'model_params_best.pkl')
        min_test_loss = test_mse_loss

    #Print the train and test loss every 1000 epochs (implicit string concatenation keeps
    #the message on one line without embedding the source indentation)
    if epoch % 1000 == 0:
        print(f"Epoch: {epoch}, training_loss_MSE: {loss}, test_loss_MSE: {test_mse_loss}, "
              f"best_test_loss_MSE: {min_test_loss}")

In [None]:
#Visualize the train and test loss histories
# Each curve's x-axis comes from its own history length. After an interrupted run the two lists
# can differ by one entry, and `epoch + 1` may match neither, which made semilogy fail.
plt.figure(dpi = 130)
plt.semilogy(np.arange(len(training_loss_history)), training_loss_history, label = "Train loss")
plt.semilogy(np.arange(len(test_loss_history)), test_loss_history, label = "Test loss")

plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.tick_params(which = 'major', axis = 'both', direction = 'in', length = 6)
plt.tick_params(which = 'minor', axis = 'both', direction = 'in', length = 3.5)
plt.minorticks_on()

plt.grid(alpha = 0.3)
plt.legend(loc = 'best')
# The directory normally exists from checkpointing, but not if no checkpoint was ever written
# (e.g. every test loss was NaN)
os.makedirs(filepath, exist_ok = True)
plt.savefig(filepath + "/loss_plot.jpeg", dpi = 800)
plt.show()

In [None]:
# Save the loss histories. NOTE(review): these go to the working directory, not to `filepath`
# like the other artifacts.
for loss_file, loss_history in (("Train_loss.npy", training_loss_history),
                                ("Test_loss.npy", test_loss_history)):
    np.save(loss_file, loss_history)

In [None]:
# Stack the per-epoch stage weights (one entry per epoch along the leading axis) and save them
# for plotting how the learned RK4 weights evolve during training
alpha_epoch_wise_arr = jnp.array(alpha_lst)
np.save(f"{filepath}/alpha_epoch_wise_arr.npy", alpha_epoch_wise_arr)

In [None]:
##Inference below uses the learnt RK4 stage weights and steps with adaptive RK4,
#instead of the AB-AM predictor-corrector scheme

In [None]:
#Perform one step of inferencing using adaptive RK4
@jax.jit
def inference(u_curr, trunk_inputs_test, dt=0.01):
    """Advance ``u_curr`` by one learned-RK4 step with the best saved parameters.

    ``best_params`` and ``best_rk_params`` are read from the notebook globals when the function is
    traced, so they must be loaded before the first call. Later reassignments are not seen by an
    already-compiled trace.

    Returns:
        The next state only (the stage weights are discarded).
    """
    # dynamic_rk4_step returns (u_next, alpha). Returning the whole tuple made run_inference
    # store a tuple into its state array and feed a tuple back in as the next state.
    u_next, _ = dynamic_rk4_step(u_curr, model_fn, best_params, model_rk_fn,
                                 best_rk_params, trunk_inputs_test, dt)
    return u_next

In [None]:
#Utility function for doing inference over all timesteps
def run_inference(initial_u, trunk_inputs_test, n_steps, dt=0.01, step_fn=None):
    """Roll the learned integrator forward autoregressively from ``initial_u``.

    Args:
        initial_u: initial states, shape (num_samples, num_points).
        trunk_inputs_test: query points forwarded to the step function.
        n_steps: number of time levels to return, including the initial one.
        dt: step size forwarded to the step function.
        step_fn: ``step_fn(u, trunk_inputs, dt) -> u_next``. Defaults to the notebook's
            ``inference`` function.

    Returns:
        NumPy array of shape (num_samples, n_steps, num_points) whose ``[:, 0, :]`` slice is
        ``initial_u``.

    Raises:
        ValueError: if ``n_steps < 1``.
    """
    if step_fn is None:
        step_fn = inference
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    initial_u = np.asarray(initial_u)
    # Size the trajectory from the inputs rather than the global dataset, so subsets of samples
    # and other horizons work. The dtype follows the initial state.
    num_samples, num_points = initial_u.shape
    u_states = np.zeros((num_samples, n_steps, num_points), dtype=initial_u.dtype)
    u_states[:, 0, :] = initial_u

    u_curr = initial_u
    for i in range(1, n_steps):
        # Each prediction becomes the input of the next step (autoregressive rollout)
        u_curr = step_fn(u_curr, trunk_inputs_test, dt)
        u_states[:, i, :] = u_curr

    return u_states

In [None]:
# Load the best model parameters; `inference` reads best_params / best_rk_params as globals,
# so they must be set before it is first called
best_full_params = load_model_params(path=filepath, filename='model_params_best.pkl')
best_params = best_full_params['deeponet_params']
best_rk_params = best_full_params['rk_params']

#Start with u(t=0, x) for every trajectory in the dataset
u_curr = output[:, 0, :]

#Roll out over as many time levels as the ground truth has, instead of a hard-coded 101
u_pred = run_inference(u_curr, trunk_inputs_test, n_steps=output.shape[1])

In [None]:
#Randomly select 3 trajectories to plot. They are drawn from the whole dataset (`output`), not
#from a held-out test set, and with replacement, so a trajectory can appear more than once.
indices = np.random.choice(np.arange(output.shape[0]), size = 3, replace = True)

# Plot grids sized from the prediction array: axis 1 is time, axis 2 is space
x_test = jnp.linspace(0, 1, u_pred.shape[2])
t_test = jnp.linspace(0, 1, u_pred.shape[1])


def _contour_panel(position, field, title, **contour_kwargs):
    """Draw one filled-contour panel of ``field`` (rows = t, columns = x) in subplot ``position``."""
    plt.subplot(1, 3, position)
    plt.contourf(x_test, t_test, field, **contour_kwargs)
    plt.xlabel("x", fontsize = 14)
    plt.ylabel("t", fontsize = 14)
    plt.yticks(fontsize = 12)
    plt.xticks(fontsize = 12)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=12)
    plt.title(title, fontsize = 16)


for idx in indices:
    plt.figure(figsize = (12,3))
    _contour_panel(1, u_pred[idx, :, :], "Predicted", levels = 20, cmap = 'jet')
    _contour_panel(2, output[idx, :, :], "Actual", levels = 20, cmap = 'jet')
    _contour_panel(3, jnp.abs(u_pred[idx, :, :] - output[idx, :, :]), "Error", cmap = 'Wistia')
    plt.tight_layout()

    plt.savefig(filepath + f"/Contour_plots_sidx_{idx}.jpeg", dpi = 800)
    plt.show()

In [None]:
#Plot the relative L2 error at every timestep to show the accumulation of autoregressive error.
#Error at step i over all trajectories: ||u_pred[:, i, :] - output[:, i, :]|| / ||output[:, i, :]||

auto_reg_error = []
num_time_steps = u_pred.shape[1]   # one entry per predicted time level (was hard-coded 101)

for i in range(num_time_steps):
    l2_error = jnp.linalg.norm(u_pred[:,i,:] - output[:,i,:])/jnp.linalg.norm(output[:,i,:])
    auto_reg_error.append(l2_error)

plt.plot(jnp.arange(num_time_steps), auto_reg_error)
plt.xlabel("Timesteps")
plt.ylabel("Relative L2 error")
plt.grid()
plt.show()

In [None]:
# Persist the per-timestep relative L2 error so it can be compared with the TI approach
np.save(f"{filepath}/Auto_reg_error_with_TI-DON_learnableRK4.npy", auto_reg_error)

In [None]:
# Save the rolled-out predictions and the ground-truth trajectories side by side
for artifact_name, artifact in (("u_pred", u_pred), ("u_actual", output)):
    np.save(f"{filepath}/{artifact_name}.npy", artifact)