In [None]:
#importing all the necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.io

import torch
import jax
import jax.numpy as jnp
from jax import random
from jax import jit, vmap, pmap, grad, value_and_grad

from tqdm import tqdm

import flax
import flax.linen as nn
import optax

from typing import Callable, Tuple, List, Dict, Optional, Any, Sequence

from sklearn.model_selection import train_test_split

from functools import partial

import os
import sys
import pickle

In [None]:
#Fix a single seed for every source of randomness used in this notebook
seed = 42

#JAX has no global RNG state: randomness is driven by explicit keys
key = jax.random.PRNGKey(seed)

#Seed NumPy's legacy global generator as well
np.random.seed(seed)

In [None]:
#Load the 2D Burgers' dataset
dataset = torch.load("Burgers_equation_2D_scalar.pt")

#Consider only 1000 samples for analysis (due to memory constraints).
#The slice is taken BEFORE the conversion to JAX arrays so that only the
#retained samples are copied, instead of converting the full dataset first
#and discarding most of it afterwards.

#Input initial conditions
inputs = jnp.array(dataset['input_samples'][:1000])

#Ground truth output solution fields - full dataset: Ns = 5000, Nx = 32, Ny = 32, Nt = 101
outputs = jnp.array(dataset['output_samples'][:1000])

#Defining t-domain, x-domain, y-domain
tspan = jnp.linspace(0, 1, 101)
xspan = jnp.linspace(0, 1, 32)
yspan = jnp.linspace(0, 1, 32)

In [None]:
#Free memory by deleting the raw loaded dataset; the JAX copies `inputs` and
#`outputs` created from it are kept and used from here on
del dataset

In [None]:
#Only consider (roughly) one-third of the data, i.e. the first 33 timesteps.
#The count is kept in one place so the output slice and the time axis cannot drift apart.
nt_train = 33
outputs = outputs[:, :nt_train]

ns, nt, nx, ny = outputs.shape
print(f"ns: {ns}, nt: {nt}, nx: {nx}, ny: {ny}")

#Take only the matching part of the temporal domain
tspan = tspan[:nt_train]

#Create the (t, x, y) query points for the trunk network: one row per grid point,
#ordered so that it matches the row-major flattening of a (nt, nx, ny) field
[t,x,y] = jnp.meshgrid(tspan, xspan, yspan, indexing = 'ij')
grid = jnp.stack([t.flatten(), x.flatten(), y.flatten()], axis = -1)


# Split the data into training (80%) and testing (20%) samples
inputs_train, inputs_test, outputs_train, outputs_test = \
                    train_test_split(inputs, outputs, test_size=0.2, random_state=seed)

#Reshape the ground truth output solution field from (ns, nt, nx, ny) to (ns, nt*nx*ny)
outputs_train = outputs_train.reshape(outputs_train.shape[0], nt*nx*ny)
outputs_test = outputs_test.reshape(outputs_test.shape[0], nt*nx*ny)

# Check the shapes of the subsets
print("Shape of inputs_train:", inputs_train.shape)
print("Shape of inputs_test:", inputs_test.shape)
print("Shape of outputs_train:", outputs_train.shape)
print("Shape of outputs_test:", outputs_test.shape)

In [None]:
#Network Inputs - train
#The trunk sees the same query grid for every sample; the branch sees the initial conditions
trunk_inputs_train = grid
branch_inputs_train = inputs_train

#Inspecting the shapes
for label, array in (
    ("Shape of train branch inputs: ", branch_inputs_train),
    ("Shape of train trunk inputs: ", trunk_inputs_train),
    ("Shape of train output: ", outputs_train),
):
    print(label, array.shape)

In [None]:
#Network Inputs - test
#The trunk sees the same query grid for every sample; the branch sees the initial conditions
trunk_inputs_test = grid
branch_inputs_test = inputs_test

#Inspecting the shapes
for label, array in (
    ("Shape of test branch inputs: ", branch_inputs_test),
    ("Shape of test trunk inputs: ", trunk_inputs_test),
    ("Shape of test output: ", outputs_test),
):
    print(label, array.shape)

In [None]:
#Utility class for defining the branch network
class branch_net(nn.Module):
    """Convolutional encoder followed by an MLP head.

    The input gets a trailing channel axis, passes through two
    conv -> relu -> pool stages, is flattened into a single vector and is
    then mapped through dense layers whose widths come from `layer_sizes`.

    Because the flattening step collapses every axis, the module is meant to
    be called on one sample at a time (DeepONet wraps it in `vmap`).
    """

    # Widths of the dense layers; the last entry is the output width
    layer_sizes: Sequence[int]
    # Non-linearity applied after every dense layer except the last one
    activation: Callable

    @nn.compact
    def __call__(self, x):
        dense_init = nn.initializers.glorot_normal()

        # Add a trailing channel dimension so the convolutions see a single-channel field
        features = x[..., jnp.newaxis]

        # Stage 1: 3x3 convolution, relu, 2x2 max pooling
        features = nn.Conv(features = 64, kernel_size = (3,3), strides = 1, padding = "SAME")(features)
        features = nn.max_pool(nn.relu(features), window_shape=(2, 2), strides = (2, 2), padding = "SAME")

        # Stage 2: 2x2 convolution, relu, 2x2 average pooling
        features = nn.Conv(features = 64, kernel_size = (2, 2), strides = 1, padding = "SAME")(features)
        features = nn.avg_pool(nn.relu(features), window_shape = (2,2), strides = (2,2), padding = "SAME")

        # Collapse everything into one vector for the dense head
        features = features.flatten()

        # Hidden dense layers (with activation) ...
        for width in self.layer_sizes[:-1]:
            features = self.activation(nn.Dense(width, kernel_init = dense_init)(features))

        # ... and a final linear layer without activation
        return nn.Dense(self.layer_sizes[-1], kernel_init = dense_init)(features)

In [None]:
#Utility class for defining the trunk network
class trunk_net(nn.Module):
    """Plain MLP: one dense layer per entry of `layer_sizes`.

    Unlike the branch network, the activation is applied after every layer,
    including the last one.
    """

    # Width of each dense layer, in order
    layer_sizes: Sequence[int]
    # Non-linearity applied after every dense layer
    activation: Callable

    @nn.compact
    def __call__(self, x):
        dense_init = nn.initializers.glorot_normal()

        hidden = x
        for width in self.layer_sizes:
            hidden = self.activation(nn.Dense(width, kernel_init = dense_init)(hidden))
        return hidden

In [None]:
#Defining the DeepONet model
class DeepONet(nn.Module):
    """Combines a branch and a trunk network through an inner product.

    The branch network is mapped over the leading axis of `x_branch`, the
    trunk network over the leading axis of `x_trunk`; the result holds one
    row per branch sample and one column per trunk query point.
    """

    # Dense layer widths handed to the branch network
    branch_net_config: Sequence[int]
    # Dense layer widths handed to the trunk network
    trunk_net_config: Sequence[int]

    def setup(self):
        # Both sub-networks share the same activation function
        silu = nn.activation.silu
        self.branch_net = branch_net(self.branch_net_config, silu)
        self.trunk_net = trunk_net(self.trunk_net_config, silu)

    def __call__(self, x_branch, x_trunk):
        # One latent vector per input function sample
        branch_latent = vmap(self.branch_net, in_axes = 0)(x_branch)

        # One latent vector per query point
        trunk_latent = vmap(self.trunk_net, in_axes = 0)(x_trunk)

        # Contract the shared latent axis: (samples, latent) x (points, latent) -> (samples, points)
        return jnp.einsum('sp,qp->sq', branch_latent, trunk_latent)

In [None]:
#Define the latent dimension at the output of branch/trunk net
latent_vector_size = 100

#Dense layer widths of both networks; each one ends in the shared latent dimension.
#Note that the config for the 2D Conv and pooling layers is hard coded into the branch net class
branch_network_layer_sizes = [256, 128, latent_vector_size]
trunk_network_layer_sizes = [128, 128, 128, latent_vector_size]

#Instantiate the DeepONet model
model = DeepONet(
    branch_net_config = branch_network_layer_sizes,
    trunk_net_config = trunk_network_layer_sizes,
)

In [None]:
#Utility function to save model params
def save_model_params(params, path, filename):
    """Pickle `params` to `<path>/<filename>`, creating `path` if needed.

    Args:
        params: Any picklable object (here: the model parameter tree).
        path: Output directory; missing parent directories are created too.
        filename: Name of the file written inside `path`.

    Raises:
        OSError: If the directory cannot be created or the file cannot be written.
    """
    #exist_ok avoids the check-then-create race of a separate os.path.exists() test
    os.makedirs(path, exist_ok = True)

    save_path = os.path.join(path, filename)
    with open(save_path, 'wb') as f:
        pickle.dump(params, f)

#Utility function to load model params
def load_model_params(path, filename):
    """Return the object unpickled from `<path>/<filename>`.

    NOTE: unpickling can execute arbitrary code, so only load files that
    were written by a trusted source.
    """
    with open(os.path.join(path, filename), 'rb') as handle:
        return pickle.load(handle)

In [None]:
# Define the training process from here
@jax.jit
def loss_fn(params, branch_inputs, trunk_inputs, gt_outputs):
    """Evaluate the model and compare its output with the ground truth.

    Returns:
        A pair `(mse_loss, l2_error)`: the mean squared error and the
        relative L2 error (norm of the residual over norm of the ground truth).
        The first entry is the quantity differentiated during training.
    """
    residual = model.apply(params, branch_inputs,trunk_inputs) - gt_outputs
    mse_loss = jnp.mean(jnp.square(residual))
    relative_l2 = jnp.linalg.norm(residual)/jnp.linalg.norm(gt_outputs)
    return mse_loss, relative_l2

@jax.jit
def update(params, branch_inputs, trunk_inputs, gt_outputs, opt_state):
    """Run one optimizer step on the given batch.

    Returns:
        `(params, opt_state, loss, l2_error)`: the updated parameters and
        optimizer state, plus the MSE loss and relative L2 error computed
        with the parameters that were passed in.
    """
    # The relative L2 error is carried along as auxiliary output and is not differentiated
    loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, l2_error), grads = loss_and_grad(params, branch_inputs, trunk_inputs, gt_outputs)

    param_updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, param_updates)
    return new_params, new_opt_state, loss, l2_error

# Training hyperparameters
num_epochs = int(1e5)
batch_size = 128

# Directory that receives the saved parameters, plots and arrays
filepath = 'DeepONet_full_rollout'

# Initialize model parameters
params = model.init(key, branch_inputs_train, trunk_inputs_train)

# Optimizer setup: Adam driven by an exponentially decaying learning rate
lr_scheduler = optax.schedules.exponential_decay(init_value = 1e-3, transition_steps = 5000, decay_rate = 0.95)
optimizer = optax.adam(learning_rate=lr_scheduler)
opt_state = optimizer.init(params)

# Bookkeeping filled in by the training loop
training_loss_history, test_loss_history = [], []
min_test_mse_loss = jnp.inf

In [None]:
#Freeing memory by deleting inputs and outputs.
#Training only uses the train/test splits derived from them; the full arrays
#are reloaded from disk later for inference.

del inputs, outputs

In [None]:
#Relative L2 error of the best model seen so far. It has to exist before the
#first progress message is printed, otherwise the log line raises a NameError.
min_test_l2_error = jnp.inf

for epoch in tqdm(range(num_epochs), desc="Training Progress"):

    #Perform mini-batching: draw one random batch per iteration.
    #The per-iteration key is derived from the notebook key, so the batches
    #follow the chosen seed and never reuse the key consumed by model.init.
    batch_key = jax.random.fold_in(key, epoch)
    shuffled_indices = jax.random.permutation(batch_key, branch_inputs_train.shape[0])
    batch_indices = shuffled_indices[:batch_size]

    branch_inputs_train_batch = branch_inputs_train[batch_indices]
    outputs_train_batch = outputs_train[batch_indices]

    # Update the parameters and optimizer state
    params, opt_state, loss, l2_error = update(
        params=params,
        branch_inputs=branch_inputs_train_batch,
        trunk_inputs=trunk_inputs_train,
        gt_outputs=outputs_train_batch,
        opt_state=opt_state
    )

    #Keep a track of the training loss
    training_loss_history.append(loss)
    
    #Do predictions on the test data simultaneously
    test_mse_loss, test_l2_error = loss_fn(params = params, 
                            branch_inputs = branch_inputs_test, 
                            trunk_inputs = trunk_inputs_test, 
                            gt_outputs = outputs_test)
    test_loss_history.append(test_mse_loss)
    
    #Save the params of the best model encountered till now,
    #and remember both of its error measures
    if test_mse_loss < min_test_mse_loss:
        best_params = params
        save_model_params(best_params, path = filepath, filename = 'model_params_best.pkl')
        min_test_mse_loss = test_mse_loss
        min_test_l2_error = test_l2_error
        
    
    #Print the train and test loss history every 1000 epochs
    if epoch % 1000 == 0:
        print(f"Epoch: {epoch}, training_loss_MSE: {loss}, test_loss: {test_mse_loss}, \
              min_test_loss: {min_test_mse_loss}, min_test_l2_error: {min_test_l2_error}")

In [None]:
#Visualize the train and test loss histories
plt.figure(dpi = 130)

#The x-axes are sized from the histories themselves rather than from the
#leaked loop variable, so the plot still works if training stopped early
plt.semilogy(np.arange(len(training_loss_history)), training_loss_history, label = "Train loss")
plt.semilogy(np.arange(len(test_loss_history)), test_loss_history, label = "Test loss")

plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.tick_params(which = 'major', axis = 'both', direction = 'in', length = 6)
plt.tick_params(which = 'minor', axis = 'both', direction = 'in', length = 3.5)
plt.minorticks_on()

plt.grid(alpha = 0.3)
plt.legend(loc = 'best')
plt.savefig(filepath + "/loss_plot.jpeg", dpi = 800)
plt.show()

In [None]:
#Save the loss arrays next to the other training artefacts
for suffix, history in (
    ("/train_loss.npy", training_loss_history),
    ("/test_loss.npy", test_loss_history),
):
    np.save(filepath + suffix, history)

In [None]:
#Reload the required datasets afresh for performing inferencing
dataset = torch.load("Burgers_equation_2D_scalar.pt")

#Keep the first 1000 samples. Slicing happens before the conversion to JAX
#arrays so that only the retained samples are copied.
inputs = jnp.array(dataset['input_samples'][:1000])
outputs = jnp.array(dataset['output_samples'][:1000])

#Free memory by deleting dataset
del dataset

#Full temporal domain this time (all 101 timesteps), not only the part used for training
tspan = jnp.linspace(0, 1, 101)
xspan = jnp.linspace(0, 1, 32)
yspan = jnp.linspace(0, 1, 32)

#Create the (t, x, y) query points for the trunk network
[t,x,y] = jnp.meshgrid(tspan, xspan, yspan, indexing = 'ij')
grid = jnp.transpose(jnp.array([t.flatten(), x.flatten(), y.flatten()]))

#Creating grid for branch inputs new and trunk_inputs_new
branch_inputs_new = inputs
trunk_inputs_new = grid

#Show only the shape; displaying the array itself adds nothing but noise
trunk_inputs_new.shape

In [None]:
# Predictions
# NOTE(review): these two imports are not used in this cell; if nothing else
# needs them they belong in (or can be dropped from) the import cell at the top.
import sklearn
from sklearn import metrics
filepath = 'DeepONet_full_rollout'


def _plot_field_panel(position, field, title, cmap):
    """Draw one filled-contour panel with colorbar at `position` of a 1x3 subplot row."""
    plt.subplot(1, 3, position)
    # NOTE(review): contourf expects Z shaped (len(y), len(x)), while this notebook
    # treats fields as (nx, ny); with nx == ny nothing fails, but the panels may be
    # transposed - confirm the axis order of the dataset.
    plt.contourf(xspan, yspan, field, levels = 20, cmap = cmap)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize = 12)
    plt.xlabel("x", fontsize = 14)
    plt.ylabel("y", fontsize = 14)
    plt.xticks(fontsize = 12)
    plt.yticks(fontsize = 12)
    plt.title(title, fontsize = 16)


#Import the best model saved after full training
best_params = load_model_params(path = filepath, filename = 'model_params_best.pkl')


#Perform inferencing on every loaded sample over the full space-time grid.
#The flat query axis is unfolded using the grid axes it was built from
#instead of hard-coded sizes.
predictions_outputs_new = model.apply(best_params, branch_inputs_new, trunk_inputs_new)
predictions_outputs_new = predictions_outputs_new.reshape(
    predictions_outputs_new.shape[0], len(tspan), len(xspan), len(yspan))

#Randomly selecting "size" number of distinct samples out of all loaded samples
#(these include samples seen during training). `replace` must be a bool: the
#previous string value 'True' was merely truthy and allowed duplicate picks,
#which would overwrite each other's figure files.
random_samples = np.random.choice(np.arange(outputs.shape[0]), size = 3, replace = False)

#Timestep indices to visualise (-1 is the last timestep)
t_query = [0, 20, 25, 50, -1]

for i in random_samples:

    for t_idx in t_query:
        prediction_i = predictions_outputs_new[i, t_idx, :, :]
        target_i = outputs[i, t_idx, :, :]

        #Point-wise absolute error
        error_i = np.abs(prediction_i - target_i)

        plt.figure(figsize = (12,3))

        _plot_field_panel(1, prediction_i, "Predicted", 'jet')
        _plot_field_panel(2, target_i, "Actual", 'jet')
        _plot_field_panel(3, error_i, "Error", 'Wistia')

        plt.suptitle(f"Idx: {i}, timestep: {t_idx}")

        plt.tight_layout()
        plt.savefig(filepath + f"/Contour_plots_sidx_{i}_timestep_{t_idx}.jpeg", dpi = 800)
        plt.show()

In [None]:
#Compute the relative L2 error incurred at every timestep

num_time_steps = 101
auto_reg_error = []

for step in range(num_time_steps):
    #All samples and all grid points of this timestep enter a single norm
    predicted_step = predictions_outputs_new[:,step,:,:]
    actual_step = outputs[:,step,:,:]
    auto_reg_error.append(jnp.linalg.norm(predicted_step - actual_step)/jnp.linalg.norm(actual_step))

In [None]:
#Save the auto_reg_error array (one relative L2 error per timestep) for comparing with NODE approach
np.save(filepath + "/Auto_reg_error_full_rollout.npy", auto_reg_error)

In [None]:
#Saving the predicted and ground truth output arrays for separate postprocessing

save = True
if save:
    #The predictions live in `predictions_outputs_new`; no variable named
    #`u_pred` is defined in this notebook, so referencing it raised a NameError
    np.save(filepath + "/u_pred.npy", predictions_outputs_new)
    np.save(filepath + "/u_actual.npy", outputs)