## Install all dependencies

In [None]:
!pip install jax
!pip install tqdm
!pip install numpy

## Import All Packages

In [None]:
#!/usr/bin/env python
# coding: utf-8


# In[ ]:
# DeepONet JAX code solving the Poisson equation with the spatial source term as its parameter
# Developed by Seid Koric, NCSA, University of Illinois, from the 1D code by Wang et al., DOI: 10.1126/sciadv.abi8605
# An explicit iterative Jacobi scheme is used for data generation


from __future__ import print_function    
import jax
import jax.numpy as np
from jax import random, grad, vmap, jit, hessian, lax
#from jax.experimental import optimizers
from jax.example_libraries import optimizers
from jax.nn import relu
##from jax.ops import index_update, index
from jax.flatten_util import ravel_pytree

import itertools
from functools import partial
from torch.utils import data
from tqdm import trange, tqdm

import matplotlib.pyplot as plt
from scipy.interpolate import griddata
#get_ipython().run_line_magic('matplotlib', 'inline')

import time
import math
import numpy as onp
import pylab as py
#import scipy.sparse as sp                 # import sparse matrix library
#from jax.scipy.sparse.linalg import spsolve
#from jax.scipy.sparse.linalg import bicgstab
#import scipy.fftpack

#from scipy.interpolate import RectBivariateSpline

from matplotlib import pyplot, cm
from mpl_toolkits.mplot3d import Axes3D

## Create Helper Functions

In [None]:


# Use a larger default font (size 20) for every figure produced below.
py.rcParams.update({'font.size': 20})

# import the file where the differentiation matrix operators are defined
#from diff_matrices import Diff_mat_1D, Diff_mat_2D   



# In[ ]:


# Data generator
class DataGenerator(data.Dataset):
    """Endless random mini-batch sampler over DeepONet training triples.

    Every item access draws ``batch_size`` distinct row indices at random and
    returns ``((u, y), s)`` for those rows. The integer passed to
    ``__getitem__`` is ignored, so iterating over the object never stops.
    """

    def __init__(self, u, u_map, y, s,
                 batch_size=64, rng_key=random.PRNGKey(1234)):
        """Store the data arrays and the sampling configuration.

        Args:
            u: input samples, stored once each (no repeats).
            u_map: for every row of ``y``/``s``, the row of ``u`` it uses.
            y: query locations.
            s: labels evaluated at ``y``.
            batch_size: number of rows drawn per batch.
            rng_key: JAX PRNG key that seeds the sampling stream.
        """
        self.u = u
        self.u_map = u_map
        self.y = y
        self.s = s

        self.N = u_map.shape[0]  # total number of (y, s) rows
        self.batch_size = batch_size
        self.key = rng_key

    def __getitem__(self, index):
        """Advance the PRNG state and return one freshly sampled batch."""
        next_key, batch_key = random.split(self.key)
        self.key = next_key
        return self.__data_generation(batch_key)

    @partial(jit, static_argnums=(0,))
    def __data_generation(self, key):
        """Draw ``batch_size`` distinct rows and gather ``((u, y), s)``."""
        picks = random.choice(key, self.N, (self.batch_size,), replace=False)
        # Look up which stored input sample each picked row belongs to.
        branch_rows = self.u_map[picks]
        return (self.u[branch_rows, :], self.y[picks, :]), self.s[picks, :]


# In[ ]:


# Define the neural net
def MLP(layers, activation=relu):
    """Build a plain fully connected network.

    Args:
        layers: list of layer widths, input width first and output width last.
        activation: nonlinearity applied after every layer except the last.

    Returns:
        ``(init, apply)``: ``init(rng_key)`` creates a list of ``(W, b)``
        pairs and ``apply(params, inputs)`` evaluates the network.
    """
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            # Only the first half of the split seeds the weights.
            w_key, _ = random.split(key)
            scale = 1. / np.sqrt((d_in + d_out) / 2.)  # Glorot-style std
            return scale * random.normal(w_key, (d_in, d_out)), np.zeros(d_out)

        # The first split key is discarded; one key per weight matrix remains.
        _, *layer_keys = random.split(rng_key, len(layers))
        return [init_layer(k, n_in, n_out)
                for k, n_in, n_out in zip(layer_keys, layers[:-1], layers[1:])]

    def apply(params, inputs):
        *hidden, (W_out, b_out) = params
        h = inputs
        for W, b in hidden:
            h = activation(np.dot(h, W) + b)
        # The final layer is linear (no activation).
        return np.dot(h, W_out) + b_out

    return init, apply


# In[ ]:


# Define the model
class DeepONet:
    """DeepONet with tanh MLP branch and trunk networks, trained with Adam.

    The operator output for one sample is the sum over the element-wise
    product of the branch output (fed the sensor values ``u``) and the trunk
    output (fed the two query coordinates).
    """

    def __init__(self, branch_layers, trunk_layers):
        """Create both sub-networks, their parameters and the optimizer.

        Args:
            branch_layers: layer widths of the branch MLP.
            trunk_layers: layer widths of the trunk MLP.
        """
        # Network initialisation / evaluation functions (tanh activations).
        self.branch_init, self.branch_apply = MLP(branch_layers, activation=np.tanh)
        self.trunk_init, self.trunk_apply = MLP(trunk_layers, activation=np.tanh)

        # Fixed seeds for the two networks.
        params = (
            self.branch_init(rng_key = random.PRNGKey(1234)),
            self.trunk_init(rng_key = random.PRNGKey(4321)),
        )

        # Adam with an exponentially decaying learning rate.
        lr_schedule = optimizers.exponential_decay(1e-3,
                                                   decay_steps=2000,
                                                   decay_rate=0.9)
        self.opt_init, self.opt_update, self.get_params = optimizers.adam(lr_schedule)
        self.opt_state = self.opt_init(params)

        # Kept so that flattened parameters can be restored to this structure.
        _, self.unravel_params = ravel_pytree(params)

        self.itercount = itertools.count()

        # Loss values recorded during training.
        self.loss_log = []

    def operator_net(self, params, u, x, t):
        """Evaluate the operator for one input function at one (x, t) point."""
        branch_params, trunk_params = params
        branch_out = self.branch_apply(branch_params, u)
        trunk_out = self.trunk_apply(trunk_params, np.stack([x, t]))
        return np.sum(branch_out * trunk_out)

    def loss_operator(self, params, batch):
        """Mean squared error between predictions and labels of a batch."""
        (u, y), targets = batch
        batched_net = vmap(self.operator_net, (None, 0, 0, 0))
        s_pred = batched_net(params, u, y[:,0], y[:,1])
        residual = targets.flatten() - s_pred.flatten()
        return np.mean(residual**2)

    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch):
        """Compiled optimizer update for iteration ``i``."""
        current = self.get_params(opt_state)
        gradients = grad(self.loss_operator)(current, batch)
        return self.opt_update(i, gradients, opt_state)

    def train(self, dataset, nIter = 10000):
        """Run ``nIter`` optimizer steps, logging the loss every 100 steps."""
        batches = iter(dataset)

        pbar = trange(nIter)
        for it in pbar:
            batch = next(batches)

            self.opt_state = self.step(next(self.itercount), self.opt_state, batch)

            if it % 100 != 0:
                continue

            # Re-evaluate the loss on the batch just used, with updated weights.
            current = self.get_params(self.opt_state)
            loss_value = self.loss_operator(current, batch)
            self.loss_log.append(loss_value)
            pbar.set_postfix({'Loss': loss_value})

    @partial(jit, static_argnums=(0,))
    def predict_s(self, params, U_star, Y_star):
        """Evaluate the operator row-wise on ``U_star`` and ``Y_star``."""
        batched_net = vmap(self.operator_net, (None, 0, 0, 0))
        return batched_net(params, U_star, Y_star[:,0], Y_star[:,1])


# In[ ]:


# Defining custom plotting functions
def my_contourf(x,y,F,ttl):
    """Draw a 12-level filled 'jet' contour plot of F with a colorbar.

    Labels the axes 'x' and 'y', uses ``ttl`` as the title, and always
    returns 0.
    """
    py.contourf(x, y, F, 12, cmap='jet')
    py.colorbar()
    py.xlabel('x')
    py.ylabel('y')
    py.title(ttl)
    return 0
    

def RBF(x1, x2, params):
    """Squared-exponential kernel matrix between two point sets.

    Args:
        x1, x2: point arrays; the reduction over ``axis=2`` below means each
            must be 2-D (one point per row).
        params: ``(output_scale, lengthscales)``.

    Returns:
        ``output_scale * exp(-0.5 * r2)`` where ``r2`` holds the squared
        length-scaled distances between every row of x1 and every row of x2.
    """
    output_scale, lengthscales = params
    scaled_1 = np.expand_dims(x1 / lengthscales, 1)
    scaled_2 = np.expand_dims(x2 / lengthscales, 0)
    diffs = scaled_1 - scaled_2
    print ("diffs.shape = ", diffs.shape)
    sq_dist = np.sum(diffs**2, axis=2)
    return output_scale * np.exp(-0.5 * sq_dist)

#RBF source function 
def func(x, y, length_scale):
    """Draw Gaussian-process samples with an RBF covariance.

    Builds the kernel matrix ``K = RBF(x, y, (1.0, length_scale))``, takes
    the Cholesky factor of ``K`` plus a small diagonal jitter, and multiplies
    it by an ``N x N`` block of standard-normal noise, where ``N = len(x)``.

    Args:
        x, y: point sets handed to ``RBF``.
            NOTE(review): ``RBF`` reduces over axis 2, so both presumably
            need to be 2-D with ``N`` rows - confirm against callers.
        length_scale: kernel length scale.

    Returns:
        Array of shape ``(N, N)``.
    """
    n_points = len(x)
    gp_params = (1.0, length_scale)
    jitter = 1e-10
    K = RBF(x, y, gp_params)
    L = np.linalg.cholesky(K + jitter * np.eye(n_points))
    # `np` is jax.numpy here, which has no `random` submodule, so the
    # original `np.random.normal` raised AttributeError. Draw the white
    # noise with plain NumPy (`onp`) instead.
    white_noise = onp.random.normal(size=[n_points, n_points])
    return np.dot(L, white_noise)

def my_fftshift(x, axes=None):
    """
    Shift the zero-frequency component to the center of the spectrum.
    This function swaps half-spaces for all axes listed (defaults to all).
    Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.
    Parameters
    ----------
    x : array_like
        Input array.
    axes : int or shape tuple, optional
        Axes over which to shift.  Default is None, which shifts all axes.
    Returns
    -------
    y : ndarray
        The shifted array.
    See Also
    --------
    ifftshift : The inverse of `fftshift`.
    Examples
    --------
    >>> freqs = np.fft.fftfreq(10, 0.1)
    >>> freqs
    array([ 0.,  1.,  2., ..., -3., -2., -1.])
    >>> np.fft.fftshift(freqs)
    array([-5., -4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])
    Shift the zero-frequency component only along the second axis:
    >>> freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3)
    >>> freqs
    array([[ 0.,  1.,  2.],
           [ 3.,  4., -4.],
           [-3., -2., -1.]])
    >>> np.fft.fftshift(freqs, axes=(1,))
    array([[ 2.,  0.,  1.],
           [-4.,  3.,  4.],
           [-1., -3., -2.]])
    """
    x = np.asarray(x)
    if axes is None:
        axes = tuple(range(x.ndim))
        shift = [dim // 2 for dim in x.shape]
    # `integer_types` was never defined in this file, so passing a single
    # integer axis raised NameError. Test against the built-in and NumPy
    # integer types directly.
    elif isinstance(axes, (int, onp.integer)):
        shift = x.shape[axes] // 2
    else:
        shift = [x.shape[ax] // 2 for ax in axes]

    return np.roll(x, shift, axes)

def fftind(size):
    """Return the shifted Fourier index grids k_x, k_y.

    Input args:
        size (integer): edge length of the square index grid.
    Returns:
        Array of shape (2, size, size) with
            k_ind[0,:,:]:  k_x components
            k_ind[1,:,:]:  k_y components

    Example:

        print(fftind(5))

        [[[ 0  1 -3 -2 -1]
        [ 0  1 -3 -2 -1]
        [ 0  1 -3 -2 -1]
        [ 0  1 -3 -2 -1]
        [ 0  1 -3 -2 -1]]
        [[ 0  0  0  0  0]
        [ 1  1  1  1  1]
        [-3 -3 -3 -3 -3]
        [-2 -2 -2 -2 -2]
        [-1 -1 -1 -1 -1]]]

    """
    # Centre the raw 0..size-1 indices, then shift every axis by half its length.
    half = int( (size + 1)/2 )
    centered = np.mgrid[:size, :size] - half
    return my_fftshift(centered)


def gaussian_random_field(key, alpha = 3.0,
                          size = 128, 
                          flag_normalize = True):
    """Sample a square 2-D Gaussian random field with a power-law spectrum.

    Complex Gaussian noise is multiplied by the amplitude ``|k|^(-alpha/2)``
    (with the zero mode removed) and transformed to real space with an
    inverse 2-D FFT; the real part is returned.

    Input args:
        key: indexable collection of JAX PRNG keys. Only ``key[0]`` is used;
            it is split into two subkeys, one for the real and one for the
            imaginary part of the noise.
        alpha (double, default = 3.0):
            The power of the power-law momentum distribution
        size (integer, default = 128):
            The size of the square output Gaussian Random Fields
        flag_normalize (boolean, default = True):
            Normalizes the Gaussian Field:
                - to have an average of 0.0
                - to have a standard deviation of 1.0
    Returns:
        gfield (array of shape (size, size)):
            The random gaussian random field

    Example:
        import matplotlib.pyplot as plt
        keys = random.split(random.PRNGKey(0), 2)
        example = gaussian_random_field(keys)
        plt.imshow(example)
    """
    real_key, imag_key = random.split(key[0], num=2)

    # Momentum indices.
    k_idx = fftind(size)

    # Amplitude as a power law 1/|k|^(alpha/2); the tiny offset avoids 0**negative.
    amplitude = np.power( k_idx[0]**2 + k_idx[1]**2 + 1e-10, -alpha/4.0 )
    # Remove the zero-frequency mode.
    amplitude = amplitude.at[0,0].set(0)

    # Complex Gaussian noise with a circular distribution. The real and
    # imaginary parts must come from different keys: the original used the
    # same subkey for both, which made them identical.
    noise = random.normal(real_key, (size, size)) \
        + 1j * random.normal(imag_key, (size, size))

    # To real space.
    gfield = np.fft.ifft2(noise * amplitude).real

    # Zero mean, unit standard deviation.
    if flag_normalize:
        gfield = gfield - np.mean(gfield)
        gfield = gfield/np.std(gfield)

    return gfield


def interp2d(
    x: np.ndarray,
    y: np.ndarray,
    xp: np.ndarray,
    yp: np.ndarray,
    zp: np.ndarray,
    fill_value: np.ndarray = None,
) -> np.ndarray:
    """
    Bilinearly interpolate gridded values at scattered points.

    Args:
        x, y: 1D arrays with the query coordinates. Queries outside the grid
            use the nearest boundary cell unless `fill_value` is given.
        xp, yp: 1D arrays with the grid coordinates along each axis.
        zp: 2D array of grid values with `zp[i, j] = f(xp[i], yp[j])`.
        fill_value: if not None, the value returned for queries that fall
            outside `[xp[0], xp[-1]] x [yp[0], yp[-1]]`.
    Returns:
        1D array `z` with `z[i] = f(x[i], y[i])`.
    Raises:
        ValueError: if `xp`/`yp` are not 1D or `zp` does not have shape
            `xp.shape + yp.shape`.
    """
    if xp.ndim != 1 or yp.ndim != 1:
        raise ValueError("xp and yp must be 1D arrays")
    if zp.shape != (xp.shape + yp.shape):
        raise ValueError("zp must be a 2D array with shape xp.shape + yp.shape")

    # Index of the upper grid node of the cell containing each query point,
    # clamped so that both neighbours exist.
    ix = np.clip(np.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
    iy = np.clip(np.searchsorted(yp, y, side="right"), 1, len(yp) - 1)

    x_lo, x_hi = xp[ix - 1], xp[ix]
    y_lo, y_hi = yp[iy - 1], yp[iy]
    span_x = x_hi - x_lo
    span_y = y_hi - y_lo

    # Interpolate along x on the lower and upper y edges of the cell ...
    z_xy1 = (x_hi - x) / span_x * zp[ix - 1, iy - 1] + (x - x_lo) / span_x * zp[ix, iy - 1]
    z_xy2 = (x_hi - x) / span_x * zp[ix - 1, iy] + (x - x_lo) / span_x * zp[ix, iy]

    # ... then along y between those two edge values.
    z = (y_hi - y) / span_y * z_xy1 + (y - y_lo) / span_y * z_xy2

    if fill_value is None:
        return z

    oob = (x < xp[0]) | (x > xp[-1]) | (y < yp[0]) | (y > yp[-1])
    return np.where(oob, fill_value, z)





#analytical quadratic source function 
def func_org(x,y):
    """Analytical source term: product of two shifted quadratics.

    Evaluates ``q(x - 0.25) * q(y - 0.25)`` with ``q(a) = -4*a**2 + 4*a``,
    i.e. the unshifted form ``(-4x^2 + 4x) * (-4y^2 + 4y)`` moved by 0.25
    along each axis.

    The previous expression had its closing parenthesis misplaced, so it
    evaluated ``-4*a**2 + 4*a*q(b)`` instead of the product ``q(a) * q(b)``.
    """
    xs = x - 0.25
    ys = y - 0.25
    value = (-4 * xs**2 + 4 * xs) * (-4 * ys**2 + 4 * ys)
    return value

#https://scipython.com/book/chapter-8-scipy/examples/two-dimensional-interpolation-with-scipyinterpolaterectbivariatespline/
def f_fn_f(xp, yp, zp, xc, yc):
    """Spline-interpolate gridded values onto another rectangular grid.

    Fits a ``scipy.interpolate.RectBivariateSpline`` to the values ``zp``
    given on the grid ``xp`` x ``yp`` and evaluates it on the grid spanned
    by ``xc`` and ``yc``. All arguments are converted to NumPy arrays first.

    Args:
        xp, yp: 1-D grid coordinates of the data.
        zp: 2-D grid values, one row per entry of ``xp``.
        xc, yc: 1-D coordinates of the evaluation grid.

    Returns:
        2-D NumPy array of shape ``(len(xc), len(yc))``.
    """
    # Imported here because the module-level import of RectBivariateSpline
    # is commented out, which made the original call raise NameError.
    from scipy.interpolate import RectBivariateSpline

    xp = onp.array(xp)
    yp = onp.array(yp)
    zp = onp.array(zp)
    xc = onp.array(xc)
    yc = onp.array(yc)

    # The spline constructor expects flat coordinate vectors.
    f_fn = RectBivariateSpline(xp.ravel(), yp.ravel(), zp)
    z_test_tmp = f_fn(xc, yc)
    return z_test_tmp

def plot2D(x, y, p):
    """Show ``p`` as a 3-D surface over the grid spanned by ``x`` and ``y``.

    Args:
        x, y: 1-D coordinate arrays; combined with ``np.meshgrid``.
        p: 2-D array of surface heights matching the meshgrid shape.
    """
    fig = pyplot.figure(figsize=(11, 7), dpi=100)
    # Figure.gca() no longer accepts keyword arguments in current Matplotlib;
    # create the 3-D axes explicitly instead.
    ax = fig.add_subplot(projection='3d')
    X, Y = np.meshgrid(x, y)
    ax.plot_surface(X, Y, p[:], rstride=1, cstride=1, cmap=cm.viridis,
                    linewidth=0, antialiased=False)
    ax.view_init(30, 225)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_zlabel('$u$')
    pyplot.show()
    
@jax.jit   
def explicit_update(pd,dx,dy,z_test):
    """Perform one explicit Jacobi sweep for the 2-D Poisson problem.

    Interior values are recomputed from the four neighbours in ``pd`` and
    the source ``z_test`` (scaled by ``1/k`` with ``k = 0.01``); all four
    boundaries are set to zero.

    Args:
        pd: 2-D array with the previous iterate.
        dx: grid spacing used for the second-axis neighbours.
        dy: grid spacing used for the first-axis neighbours.
        z_test: 2-D source term with the same shape as ``pd``.

    Returns:
        New iterate with the same shape as ``pd``.
    """
    k = 0.01

    interior = ((pd[1:-1, 2:] + pd[1:-1, :-2]) * dy**2 +
                (pd[2:, 1:-1] + pd[:-2, 1:-1]) * dx**2 +
                z_test[1:-1, 1:-1]/k * dx**2 * dy**2) / (2 * (dx**2 + dy**2))

    # Size the output from the input instead of the module globals Ny/Nx,
    # so the sweep works for any grid and does not depend on definition order.
    UU_loc = np.zeros(pd.shape)
    UU_loc = UU_loc.at[1:-1,1:-1].set(interior)

    # Homogeneous Dirichlet boundary on all four edges.
    UU_loc = UU_loc.at[0, :].set(0)
    UU_loc = UU_loc.at[-1, :].set(0)
    UU_loc = UU_loc.at[:, 0].set(0)
    UU_loc = UU_loc.at[:, -1].set(0)

    return(UU_loc)


# In[ ]:


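# Numerical solver used for data generation (iterative Jacobi updates via explicit_update).
# The name solve_ADR appears to be inherited from the 1D diffusion-reaction code this file was adapted from.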
def solve_ADR(key, Nx, Ny, P):
    """Generate one random source field and its iteratively computed solution.

    Steps performed:
      1. Split ``key`` into two subkeys.
      2. Draw a Gaussian random field ``bf`` of size ``Nx`` and evaluate it
         with ``interp2d`` on the (x, y) grid to obtain ``z_test``.
      3. Apply ``explicit_update`` 30000 times, starting from zeros.
      4. Evaluate ``bf`` on an ``m_x`` x ``m_x`` sensor grid, where
         ``m_x = sqrt(m)`` and ``m`` is read from module scope.
      5. Pick ``P`` random index pairs and read coordinates and solution there.

    Args:
        key: JAX PRNG key.
        Nx, Ny: number of grid points along x and y on [0, 1].
        P: number of random output locations.

    Returns:
        ``(x, y, UU, z_test), (u, y_s, s)`` where ``x``/``y`` are the 1-D
        grids, ``UU`` the final iterate, ``z_test`` the source reshaped to
        ``(Nx, Ny)``, ``u`` the flattened sensor values, ``y_s`` the ``(P, 2)``
        sampled coordinates and ``s`` the solution values at those indices.

    NOTE(review): ``bf`` is created with size ``Nx`` only, while ``interp2d``
    checks ``zp`` against ``(len(x), len(y))``, so this appears to require
    ``Nx == Ny`` - confirm.
    """
    # Generate subkeys
    subkeys = random.split(key, 2)

    x = np.linspace(0,1,Nx)        # x variables in 1D
    y = np.linspace(0,1,Ny)        # y variable in 1D

    # Initialization (UU is overwritten by the first solver iteration)
    UU  = np.zeros((Ny, Nx))
    pd = np.zeros((Ny, Nx))
    
    
    dx = x[1] - x[0]                # grid spacing along x direction
    dy = y[1] - y[0]                # grid spacing along y direction

    X,Y = np.meshgrid(x,y)          # 2D meshgrid

    # 1D indexing (Xu/Yu are not used further below)
    Xu = X.ravel()                  # Unravel 2D meshgrid to 1D array
    Yu = Y.ravel()


    # Random source field. The whole subkey array is passed on;
    # gaussian_random_field indexes key[0] itself.
    # (Analytical alternative: bf = func_org(x[:,None], y[None,:]))
    bf = gaussian_random_field(subkeys, 3, Nx)
    
    print("type of bf ", type(bf))
    
    # Evaluate the source on the solver grid. "ij" indexing makes the
    # flattened points match interp2d's zp[i, j] = f(xp[i], yp[j]) layout.
    X_ij, Y_ij = np.meshgrid(x,y,indexing="ij")
    Xu_ij = X_ij.ravel()
    Yu_ij = Y_ij.ravel()
    u = interp2d(Xu_ij, Yu_ij, x, y, bf)
    z_test = u.reshape(Nx,Ny)
    
    
    # Fixed number of solver sweeps; each sweep feeds the previous iterate back in.
    nt = 30000
    for it in range(nt):
        UU = explicit_update(pd,dx,dy,z_test)
        pd = UU
    

    # m (defined at module level, outside this function) is the number of
    # input sensors where u is evaluated; the sensors form an m_x x m_x grid,
    # so m must be a perfect square.
    m_x = np.sqrt(m)
    m_x = np.array(m_x.astype(np.int32))
    
    xx = np.linspace(0, 1, m_x)
    yy = np.linspace(0, 1, m_x)
    XX_ij, YY_ij = np.meshgrid(xx,yy,indexing="ij")
    XXu_ij = XX_ij.ravel()
    YYu_ij = YY_ij.ravel()
    # u is rebound here: it now holds the source sampled at the sensors.
    u = interp2d(XXu_ij, YYu_ij, x, y, bf)

    # P is the number of locations at which s is evaluated.
    # Column 0 of idx indexes x, column 1 indexes y; both are drawn from
    # [0, max(Nx, Ny)).
    # NOTE(review): UU is allocated as (Ny, Nx) but read as UU[idx_x, idx_y];
    # confirm the intended axis convention when Nx != Ny.
    idx = random.randint(subkeys[1], (P, 2), 0, max(Nx, Ny))
    y_s = np.concatenate([x[idx[:,0]][:,None], y[idx[:,1]][:,None]], axis = 1)
    s = UU[idx[:,0], idx[:,1]]

    return (x, y, UU, z_test), (u, y_s, s)

# Generate training data corresponding to one input sample
def generate_one_training_data(key, P):
    """Solve one random problem and return its training triple.

    The grid resolution comes from the module-level ``Nx`` and ``Nt``.

    Args:
        key: JAX PRNG key for this sample.
        P: number of random output locations.

    Returns:
        ``(u, y, s)`` exactly as produced by ``solve_ADR``.
    """
    # Only the second result group (sensor values, locations, labels) is needed.
    _, samples = solve_ADR(key, Nx, Nt, P)
    u, y, s = samples
    return u, y, s

# Generate test data corresponding to one input sample
def generate_one_test_data(key, P):
    """Solve one random problem on a P x P grid and return it in full.

    Args:
        key: JAX PRNG key for this sample.
        P: grid resolution in both directions; also passed to ``solve_ADR``
            as the number of random output locations, which are discarded.

    Returns:
        ``(u_test, y_test, s_test, z_test)``: the sensor values, the grid
        coordinates as rows of ``(x, t)``, the flattened transposed solution,
        and the source field.
    """
    (x, t, UU, z_test), (u_test, _, _) = solve_ADR(key, P, P, P)

    XX, TT = np.meshgrid(x, t)
    coords_x = XX.flatten()[:, None]
    coords_t = TT.flatten()[:, None]
    y_test = np.hstack([coords_x, coords_t])

    # The solution is transposed before flattening.
    s_test = UU.T.flatten()

    return u_test, y_test, s_test, z_test


# Generate training data corresponding to N input samples
def generate_training_data(key, N, P):
    """Generate ``N`` training samples with ``P`` output locations each.

    64-bit mode is switched on for the generation and switched off again
    afterwards, even if generation fails. Results are cast to float32.

    Args:
        key: JAX PRNG key; split into one key per sample.
        N: number of input samples.
        P: number of output locations per sample.

    Returns:
        ``(u_train, y_train, s_train)`` reshaped to ``(N, -1)``,
        ``(N * P, -1)`` and ``(N * P, -1)`` respectively.
    """
    jax.config.update("jax_enable_x64", True)
    try:
        keys = random.split(key, N)
        u_train, y_train, s_train = vmap(generate_one_training_data, (0, None))(keys, P)

        # One row per input sample for u; one row per (sample, location) for y and s.
        u_train = np.float32(u_train.reshape(N, -1))
        y_train = np.float32(y_train.reshape(N * P, -1))
        s_train = np.float32(s_train.reshape(N * P, -1))
    finally:
        # Always restore 32-bit mode so a failure here does not leave the
        # rest of the program running in x64.
        jax.config.update("jax_enable_x64", False)
    return u_train, y_train, s_train

# Generate test data corresponding to N input samples
def generate_test_data(key, N, P):
    """Generate ``N`` full-grid samples at resolution ``P`` x ``P``.

    64-bit mode is switched on for the generation and switched off again
    afterwards, even if generation fails.

    Args:
        key: JAX PRNG key; split into one key per sample.
        N: number of input samples.
        P: grid resolution passed to ``generate_one_test_data``.

    Returns:
        ``(u_test, y_test, s_test, z_test)`` stacked along a leading sample
        axis; only ``z_test`` is cast to float32.
    """
    jax.config.update("jax_enable_x64", True)
    try:
        keys = random.split(key, N)

        u_test, y_test, s_test, z_test = vmap(generate_one_test_data, (0, None))(keys, P)

        z_test = np.float32(z_test)
    finally:
        # The original called the undefined name `config` here (NameError);
        # the flag lives on `jax.config`.
        jax.config.update("jax_enable_x64", False)
    return u_test, y_test, s_test, z_test

# Compute the relative l2 error for one freshly generated test sample.
def compute_error(key, P):
    """Return the relative L2 error of the model on one generated sample.

    Uses the module-level ``model`` and ``params``.

    Args:
        key: JAX PRNG key for the test sample.
        P: grid resolution passed to ``generate_test_data``.

    Returns:
        ``||s_test - s_pred|| / ||s_test||``.
    """
    u_test, y_test, s_test, _ = generate_test_data(key, 1, P)
    # Add a trailing axis to the prediction before comparing with the reference.
    s_pred = model.predict_s(params, u_test, y_test)[:,None]
    reference_norm = np.linalg.norm(s_test)
    return np.linalg.norm(s_test - s_pred) / reference_norm

## Generate Data

Generates the grid coordinates with `generate_one_test_data`, then generates the full dataset (source fields and reference solutions) with `generate_test_data`, and saves the results as `.npy` files.

In [None]:
# Base PRNG key for all data generated in this cell.
key = random.PRNGKey(0)

# Data is generated on all grid points for DeepXDE (no random choice of P points).
# m is also taken on all grid points, i.e. m = Nx x Ny.

# Resolution of the solution. These module-level names are read by the
# helper functions defined above, so they must be set before those are called.
Nx = 128
Nt = 128
Ny = 128

N = 5500 # number of input samples
m = Nx * Ny   # number of input sensors for each u sample 

# One sample is generated only to obtain the grid coordinates (y_test).
u_test, y_test, s_test, z_test =  generate_one_test_data(key, Nx)

xy_train_test_ht = y_test
print("xy_train_test_ht.shape = ", xy_train_test_ht.shape)
np.save('xy_train_test_ht.npy', xy_train_test_ht)

# Full dataset of N samples; this rebinds u_test, y_test, s_test and z_test.
u_test, y_test, s_test, z_test = generate_test_data(key, N, Nx)


#### debug for figure ####
"""
u_test, y_test, s_test, z_test = generate_test_data(key, 2, Nx)

print("z_test.shape = ", z_test.shape)
print("z_test[0].shape = ", z_test[0].shape)
print("z_test[1].shape = ", z_test[1].shape)
#print("z_test[0] = ", z_test[0])
#print("z_test[1] = ", z_test[1])

#z_test =z_test.reshape(z_test.shape[0],z_test.shape[1]*z_test.shape[2])
#print("z_test.shape = ", z_test.shape)
#print("z_test[0].shape = ", z_test[0].shape)
#print("z_test[1].shape = ", z_test[1].shape)
#print("z_test[0] = ", z_test[0])
#print("z_test[1] = ", z_test[1])

print("u_test.shape = ", u_test.shape)
u_test_0 = u_test[0] 
print("u_test_0.shape = ", u_test_0.shape)
#print("u_test[0] = ", u_test[0])
#print("u_test[1] = ", u_test[1])

z_test_0 = z_test[0]
z_test_1 = z_test[1]

print("z_test_0.shape = ", z_test_0.shape)
print("z_test_1.shape = ", z_test_1.shape)

s_test_0 = s_test[0]
print("s_test_0.shape = ", s_test_0.shape)

s_test_0_nx_ny = s_test_0.reshape(Nx,Ny)
print("s_test_0_nx_ny.shape = ", s_test_0_nx_ny.shape)

u_test_0_nx_ny = u_test_0.reshape(Nx,Ny)
print("u_test_0_nx_ny.shape + ", u_test_0_nx_ny.shape)

u_test_1_nx_ny = u_test[1].reshape(Nx,Ny)
s_test_1_nx_ny = s_test[1].reshape(Nx,Ny)

x = np.linspace(0, 1, Nx)
y = np.linspace(0, 1, Ny)


fig = plt.figure(figsize=(18,5))
plt.subplot(1,2,1)
#py.figure(figsize = (14,7))
#my_contourf(x,y,z_test_0.T,r'Source Distrubution')
my_contourf(x,y,u_test_1_nx_ny.T,r'Source Distrubution')
plt.tight_layout()
plt.subplot(1,2,2)
#py.figure(figsize = (14,7))
my_contourf(x,y,s_test_1_nx_ny,r'Reference Solution')
plt.tight_layout()
plt.savefig("temperature_sample1_u.jpg", dpi=300)
#plt.show()

"""
##print("y_test.shape = ", y_test.shape)
#print("s_test.shape = ", s_test.shape)

#with np.printoptions(threshold=np.inf):
#    print("s_test[0] = ", s_test[0])
#    print("s_test[1] = ", s_test[1])
#print("z_test.shape = ", z_test.shape)

# Split the solutions along the sample axis: the first 5000 samples are
# saved as training data, the remainder as testing data.
data_s_train_ht, data_s_testing_ht = np.split(s_test, [5000], axis=0)
print("data_s_train_ht.shape = ", data_s_train_ht.shape)
print("data_s_testing_ht.shape = ", data_s_testing_ht.shape)
np.save("data_s_train_ht_3.npy", data_s_train_ht)
np.save("data_s_testing_ht_3.npy", data_s_testing_ht)


# Same split for the sensor values of the source fields.
data_u0_train_ht, data_u0_testing_ht = np.split(u_test, [5000], axis=0)
print("data_u0_train_ht.shape = ", data_u0_train_ht.shape)
print("data_u0_testing_ht.shape = ", data_u0_testing_ht.shape)
np.save("data_u0_train_ht_3.npy", data_u0_train_ht)
np.save("data_u0_testing_ht_3.npy", data_u0_testing_ht)


#This version just generates data for DeepXDE DeepONet, it can still train with Jax if below is uncommented 

"""


# Plot solution
#py.figure(figsize = (14,7))
#my_contourf(x,t,UU,r'$\nabla^2 s + b$ = 0')

#u_train1, y_train1, s_train1 = generate_one_training_data(key, P_train)
#at P (128 output) locations, evaluete u_train with m (100 points) 
#print("u_train1.shape = ", u_train1.shape)
#print("y_train1.shape = ", y_train1.shape)
#print("s_train1.shape = ", s_train1.shape)


#print("u_train1 = ", u_train1)
#print("s_train1 = ", s_train1)

#N(5000) u samples, each at P output locations(128) , evaluate each u_train with m (100 points)

P_train = 800 # number of output locations (sensors)

u_train, y_train, s_train = generate_training_data(key, N, P_train)
u_map_train = np.repeat( np.arange(N) , P_train )
print("u_train.shape = ", u_train.shape)
print("u_map_train.shape = ", u_map_train.shape)
print("y_train.shape = ", y_train.shape)
print("s_train.shape = ", s_train.shape)

np.save('u_train_N5000_P400_M144', u_train, allow_pickle=True)
np.save('u_map_train_N5000_P400_M144', u_map_train, allow_pickle=True)
np.save('y_train_N5000_P400_M144', y_train, allow_pickle=True)
np.save('s_train_N5000_P400_M144', s_train, allow_pickle=True)


# In[ ]:


# Initialize model
#branch_layers = [m, 50, 50, 50, 50, 50]
branch_layers = [m, 100, 100, 100, 100, 100, 100]
#trunk_layers =  [2, 50, 50, 50, 50, 50]
trunk_layers =  [2, 100, 100, 100, 100, 100, 100]
model = DeepONet(branch_layers, trunk_layers)


# In[ ]:


# Create data set
batch_size = 10000
dataset = DataGenerator(u_train, u_map_train , y_train, s_train, batch_size)


# In[ ]:

start_time = time.process_time()
# Train
model.train(dataset, nIter=200000)
print("Training time = %6.2f" % (time.process_time()-start_time))


# In[ ]:


# Test data
N_test = 100 # number of input samples 
P_test = m # m number of input sensors for each u 
             # P number of output sensors
key_test = random.PRNGKey(1234567)
keys_test = random.split(key_test, N_test)

# Predict
params = model.get_params(model.opt_state)

# Compute error
error_s = vmap(compute_error, (0, None))(keys_test,P_test) 

print('mean of relative L2 error of s: {:.2e}'.format(error_s.mean()))
print('std of relative L2 error of s: {:.2e}'.format(error_s.std()))


# In[ ]:


#Plot for loss function
#import matplotlib.pyplot as plt
plt.figure(figsize = (6,5))
plt.plot(model.loss_log, lw=2)

plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale('log')
plt.legend()
plt.tight_layout()
plt.show()


# In[ ]:


# Generate one test sample
#from scipy.interpolate import griddata
for i in range(1):
    key = random.PRNGKey(4511236)
    P_test = m
    Nx = m
    u_test, y_test, s_test, z_test = generate_test_data(key, 1, P_test)
#u_test, y_test, s_test = generate_one_test_data(key,P_test)

# Predict
    params = model.get_params(model.opt_state)
#params_tmp = onp.asarray(params)

    start_time = time.time()
    s_pred = model.predict_s(params, u_test, y_test)
    s_pred_tmp = onp.asarray(s_pred)
    print("Inference time = %1.6s" % (time.time()-start_time))

# Generate an uniform mesh
    x = np.linspace(0, 1, Nx)
    t = np.linspace(0, 1, Nt)
    XX, TT = np.meshgrid(x, t)

# Grid data
    S_pred = griddata(y_test, s_pred.flatten(), (XX,TT), method='cubic')
    S_test = griddata(y_test, s_test.flatten(), (XX,TT), method='cubic')

# Compute the relative l2 error 
    error = np.linalg.norm(S_pred - S_test, 2) / np.linalg.norm(S_test, 2) 
    print('Relative l2 errpr: {:.3e}'.format(error))
    z_test = np.squeeze(z_test, axis=0)
#z_test = np.float32(z_test)
    print(z_test.shape)
    print(type(z_test))


# In[ ]:


    fig = plt.figure(figsize=(18,5))
    plt.subplot(1,3,1)
#py.figure(figsize = (14,7))
    my_contourf(x,t,z_test.T,r'Source Distrubution')
    plt.tight_layout()
    plt.subplot(1,3,2)
    s_test_nx_nt = s_test.reshape(Nx,Nt)
#py.figure(figsize = (14,7))
    my_contourf(x,t,s_test_nx_nt,r'Reference Solution')
    plt.tight_layout()
    plt.subplot(1,3,3)
    s_pred_nx_nt = s_pred.reshape(Nx,Nt)
#py.figure(figsize = (14,7))
    my_contourf(x,t,s_pred_nx_nt,r'Predicted Solution')
    plt.tight_layout()
    #plt.savefig("seventh_2_sample_alpha5.jpg", dpi=300)
    plt.savefig("temperature_sample{}_jet12.jpg".format(i+1), dpi=300)
    plt.show()

"""