In [1]:
import os
import time
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from tqdm import trange
from jax import jvp, value_and_grad
from flax import linen as nn
from typing import Sequence
from functools import partial
import time
import tree_math as tm
import matplotlib.font_manager as font_manager
from copy import copy
from copy import deepcopy
from jax import jit
import numpy as np
import pickle as pkl
import pandas as pd
import pickle
import math

In [2]:
# Plot settings: one serif text style shared by figure text and legends.
font = dict(
    family='serif',
    color='black',
    weight='normal',
    size=10,
)

font_legend = font_manager.FontProperties(style='normal', family='serif')

plt.rcParams['font.family'] = 'serif'


## 1. Anant-Net

In [3]:
class Anant(nn.Module):
    """
    Anant-Net with each coordinate group supplied as three column blocks.

    The x, y and z groups each pass through their own body network (Dense layers with sin
    activations). Every body network returns a (features[-1], batch) embedding. The
    embeddings are multiplied together and summed over that shared features[-1] axis, so the
    prediction is a tensor indexed by the batch entries of the three groups.

    Attributes:
      features: Dense-layer widths of each body network; the last entry is the size of the
        shared axis that is summed out when the embeddings are combined.
    """
    features: Sequence[int]

    def einsum(self, outputs):
        """
        Contract body-network embeddings into a tensor-product prediction.

        Arguments:
          outputs: list of K arrays, outputs[k] of shape (features[-1], n_k), 2 <= K <= 25

        Return:
          pred: array of shape (n_0, ..., n_{K-1}) with
                pred[a, b, ...] = sum_z outputs[0][z, a] * outputs[1][z, b] * ...

        Raises:
          ValueError: if K is outside [2, 25] ('z' labels the shared axis and the remaining
                      lowercase letters label the output axes)
        """
        dim = len(outputs)
        if not 2 <= dim <= 25:
            raise ValueError(f'einsum expects between 2 and 25 body-network outputs, got {dim}')
        labels = [chr(97 + k) for k in range(dim)]

        # Grow the outer product one factor at a time and keep the shared axis 'z' until the
        # last factor, where it is summed out. Handling the last step uniformly also
        # contracts 'z' when there are only two body networks.
        spec = 'z' + labels[0]
        pred = outputs[0]
        for k in range(1, dim):
            out_spec = spec + labels[k]
            if k == dim - 1:
                out_spec = out_spec[1:]
            pred = jnp.einsum(f'{spec},z{labels[k]}->{out_spec}', pred, outputs[k])
            spec = out_spec
        return pred

    def stack_body_networks(self, init, X):
        """
        Apply one body network to X of shape (batch, n_inputs).

        Arguments:
          init: kernel initializer for every Dense layer
          X: input coordinates of one group

        Return:
          one-element list holding the (features[-1], batch) embedding, so the caller can
          collect the per-group results with `+=`
        """
        for width in self.features[:-1]:
            X = jnp.sin(nn.Dense(width, kernel_init=init)(X))
        X = nn.Dense(self.features[-1], kernel_init=init)(X)
        return [jnp.transpose(X, (1, 0))]

    @nn.compact
    def __call__(self, x1, x2, x3, y1, y2, y3, z1, z2, z3):
        """
        Evaluate Anant-Net on three coordinate groups.

        Arguments:
          x1, x2, x3 (likewise y*, z*): column blocks of one group. They are joined with
            jnp.hstack before entering that group's body network, which lets callers
            differentiate with respect to the middle block alone.

        Return:
          tensor-product prediction indexed by the rows of the x, y and z groups
        """
        init = nn.initializers.glorot_normal()
        groups = (jnp.hstack([x1, x2, x3]), jnp.hstack([y1, y2, y3]), jnp.hstack([z1, z2, z3]))

        # Body networks are created in x, y, z order, which fixes the parameter names.
        outputs = []
        for X in groups:
            outputs += self.stack_body_networks(init, X)

        return self.einsum(outputs)

In [4]:
class Anant_test(nn.Module):
    """
    Evaluation variant of Anant-Net: each coordinate group is passed as one array rather
    than as three column blocks.

    The body networks are created in the same x, y, z order and with the same layers as in
    Anant. Presumably this is what lets test_model_random apply parameters trained with
    Anant; confirm if the layer layout ever changes.

    Attributes:
      features: Dense-layer widths of each body network; the last entry is the size of the
        shared axis that is summed out when the embeddings are combined.
    """
    features: Sequence[int]

    def einsum(self, outputs):
        """
        Contract body-network embeddings into a tensor-product prediction.

        Arguments:
          outputs: list of K arrays, outputs[k] of shape (features[-1], n_k), 2 <= K <= 25

        Return:
          pred: array of shape (n_0, ..., n_{K-1}) with
                pred[a, b, ...] = sum_z outputs[0][z, a] * outputs[1][z, b] * ...

        Raises:
          ValueError: if K is outside [2, 25] ('z' labels the shared axis and the remaining
                      lowercase letters label the output axes)
        """
        dim = len(outputs)
        if not 2 <= dim <= 25:
            raise ValueError(f'einsum expects between 2 and 25 body-network outputs, got {dim}')
        labels = [chr(97 + k) for k in range(dim)]

        # Grow the outer product one factor at a time and keep the shared axis 'z' until the
        # last factor, where it is summed out. Handling the last step uniformly also
        # contracts 'z' when there are only two body networks.
        spec = 'z' + labels[0]
        pred = outputs[0]
        for k in range(1, dim):
            out_spec = spec + labels[k]
            if k == dim - 1:
                out_spec = out_spec[1:]
            pred = jnp.einsum(f'{spec},z{labels[k]}->{out_spec}', pred, outputs[k])
            spec = out_spec
        return pred

    def stack_body_networks(self, init, X):
        """
        Apply one body network to X of shape (batch, n_inputs).

        Arguments:
          init: kernel initializer for every Dense layer
          X: input coordinates of one group

        Return:
          one-element list holding the (features[-1], batch) embedding
        """
        for width in self.features[:-1]:
            X = jnp.sin(nn.Dense(width, kernel_init=init)(X))
        X = nn.Dense(self.features[-1], kernel_init=init)(X)
        return [jnp.transpose(X, (1, 0))]

    @nn.compact
    def __call__(self, x, y, z):
        """
        Evaluate Anant-Net on three already-concatenated coordinate groups.

        Arguments:
          x, y, z: inputs of the three body networks, each of shape (batch, n_inputs)

        Return:
          tensor-product prediction indexed by the rows of x, y and z
        """
        init = nn.initializers.glorot_normal()

        outputs = []
        for X in (x, y, z):
            outputs += self.stack_body_networks(init, X)

        return self.einsum(outputs)

In [5]:
# second directional derivative via forward-over-forward autodiff
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """
    Second directional derivative of `f` computed with two nested jvps.

    For direction v = tangents this returns d^2/dt^2 f(primals + t v) at t = 0, i.e. v^T H v
    for each output component of f. Despite the function's name it is not the full
    Hessian-vector product H v.

    Arguments:
      f: function to differentiate; it takes len(primals) positional arguments
      primals: tuple of points at which the derivatives are evaluated
      tangents: tuple of tangent vectors with the same structure as primals
      return_primals: if True, also return the first directional derivative

    Return:
      if return_primals: (first directional derivative J v, second directional derivative)
      otherwise: the second directional derivative
    """
    # g(p) = J_f(p) v ; differentiating g again along v gives v^T H v.
    g = lambda *p: jvp(f, p, tangents)[1]
    first, second = jvp(g, primals, tangents)
    return (first, second) if return_primals else second

In [6]:
@partial(jax.jit, static_argnums=(0,))
def apply_anant_collocation(apply_fn, params, train_data):
    """
    Residual (PDE) loss of Anant-Net on one collocation grid, and its gradient.

    The residual is  -(u_xx + u_yy + u_zz) + u - u^3 - f. Each second derivative is taken
    with respect to the middle (active) column block x2 / y2 / z2 along a unit direction.

    Arguments:
      apply_fn: Anant-Net apply function, called as
                apply_fn(params, x1, x2, x3, y1, y2, y3, z1, z2, z3); static under jit
      params: trainable parameters of Anant-Net
      train_data: (xc, yc, zc, f, _) with xc = (x1, x2, x3), etc.; the fifth entry is ignored

    Return:
      loss: mean squared residual
      gradient: gradient of the loss with respect to params

    NOTE(review): only the three active-dimension second derivatives enter the residual.
    allencahnNd_source_term_ appears to build f from second derivatives summed over all ndims
    dimensions. Confirm this mismatch is intended.
    """
    xc, yc, zc, source, _ = train_data

    def residual_loss(p):
        (x1, x2, x3), (y1, y2, y3), (z1, z2, z3) = xc, yc, zc

        def net(a2, b2, c2):
            return apply_fn(p, x1, a2, x3, y1, b2, y3, z1, c2, z3)

        u = net(x2, y2, z2)
        u_xx = hvp_fwdfwd(lambda t: net(t, y2, z2), (x2,), (jnp.ones(x2.shape),))
        u_yy = hvp_fwdfwd(lambda t: net(x2, t, z2), (y2,), (jnp.ones(y2.shape),))
        u_zz = hvp_fwdfwd(lambda t: net(x2, y2, t), (z2,), (jnp.ones(z2.shape),))

        residual = - u_zz - u_yy - u_xx + u - u**3 - source
        return jnp.mean(residual**2)

    return jax.value_and_grad(residual_loss)(params)

In [7]:
@partial(jax.jit, static_argnums=(0,))
def apply_anant_boundary(apply_fn, params, train_data, lamb=10):
    """
    Weighted data (boundary/initial-condition) loss of Anant-Net, and its gradient.

    Arguments:
      apply_fn: Anant-Net apply function; static under jit
      params: trainable parameters of Anant-Net
      train_data: sequence of grids, each indexable as
                  [(x1, x2, x3), (y1, y2, y3), (z1, z2, z3), u_target, ...]
      lamb: weight multiplying the averaged data loss

    Return:
      loss: lamb * average over grids of the per-grid mean squared error
      gradient: gradient of the loss with respect to params
    """
    n_grids = len(train_data)

    def boundary_loss(p):
        total = 0.
        for grid in train_data:
            xs, ys, zs, target = grid[0], grid[1], grid[2], grid[3]
            pred = apply_fn(p, xs[0], xs[1], xs[2], ys[0], ys[1], ys[2], zs[0], zs[1], zs[2])
            total += (1/n_grids) * jnp.mean((pred - target)**2)
        return total

    return jax.value_and_grad(lambda p: lamb*boundary_loss(p))(params)

In [8]:
# optimizer step function
@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """
    Apply one optimizer step to Anant-Net's parameters (jit-compiled; `optim` is static).

    Arguments:
      optim: optimizer exposing update(gradient, state, params) -> (updates, state)
      gradient: gradient of the loss with respect to params
      params: current trainable parameters
      state: current optimizer state

    Return:
      params: parameters after the update
      state: optimizer state after the step
    """
    step, next_state = optim.update(gradient, state, params)
    new_params = optax.apply_updates(params, step)
    return new_params, next_state

## 2. Data generator

In [9]:
def allencahnNd_exact_u_(index_batch, fixed_batch, X, fX = 0.0, random_test = True):
    """
    Exact solution of the high-dimensional Allen-Cahn test problem,
        u(x) = s^2 + sin(s),   s = (1/ndims) * sum_i x_i,
    where ndims = len(index_batch) + len(fixed_batch).
    It is used for test reference values and is also called by allencahnNd_source_term_
    to assemble the training source term.

    Arguments:
      index_batch: indices of the active dimensions
      fixed_batch: indices of the inactive dimensions (only their count is used)
      X: if random_test is True, a sequence of arrays whose rows, iterated in order, give the
         ndims coordinates; otherwise one array per active dimension, ordered like index_batch
      fX: value shared by every inactive dimension (ignored when random_test is True)
      random_test: True  -> exact solution at randomly sampled points of the full domain
                   False -> exact solution on a structured grid over the active dimensions

    Return:
      exact: exact solution, promoted with jnp.atleast_3d
    """
    ndims = len(index_batch) + len(fixed_batch)

    if random_test:
        # One row per dimension, taken from the per-body-network arrays in order.
        X_tot = [row for block in X for row in block]
    else:
        fX = fX * jnp.ones_like(X[0])
        # First position of each active dimension within index_batch.
        position = {}
        for k, d in enumerate(index_batch):
            position.setdefault(int(d), k)
        X_tot = [X[position[i]] if i in position else fX for i in range(ndims)]

    s = jnp.zeros_like(X_tot[0])
    for i in range(ndims):
        s += X_tot[i]
    s = 1/ndims * s

    return jnp.atleast_3d(s**2 + jnp.sin(s))


def allencahnNd_source_term_(index_batch, fixed_batch, X, fX = 0.0):
    """
    Source term f of the high-dimensional Allen-Cahn test problem on a structured grid,
        f = (1/ndims) * (sin(s) - 2) + u - u^3,   s = (1/ndims) * sum_i x_i,
    where u = s^2 + sin(s) is the exact solution and ndims = len(index_batch) + len(fixed_batch).
    The first term equals -Laplacian(u) taken over all ndims dimensions.

    Arguments:
      index_batch: indices of the active dimensions
      fixed_batch: indices of the inactive dimensions (only their count is used)
      X: one array per active dimension, ordered like index_batch, all with the same shape
      fX: value shared by every inactive dimension

    Return:
      f: source term (same shape as the entries of X when they are 3-D; the exact solution
         is promoted with jnp.atleast_3d before being combined)
    """
    ndims = len(index_batch) + len(fixed_batch)
    fX_full = fX * jnp.ones_like(X[0])

    # First position of each active dimension within index_batch.
    position = {}
    for k, d in enumerate(index_batch):
        position.setdefault(int(d), k)
    X_tot = [X[position[i]] if i in position else fX_full for i in range(ndims)]

    s = jnp.zeros(X_tot[0].shape)
    for i in range(ndims):
        s += X_tot[i]
    s = 1/ndims * s
    f = (1/ndims)*(jnp.sin(s) - 2.)

    u_exact = allencahnNd_exact_u_(index_batch, fixed_batch, X, fX, random_test=False)

    return f + u_exact - u_exact**3

In [10]:
def generate_dims(seed, ndims, niters, bound = False):
  """
  Randomly pick one active dimension in each third of a high-dimensional PDE.

  The dimensions are split into three consecutive groups of idx = int(ndims/3), one per
  body network. The eligible active dimensions are 1..idx-2, idx+1..2*idx-2 and
  2*idx+1..3*idx-2, so the first and last index of each group are never chosen (presumably
  so that the column blocks on both sides of the active column stay non-empty; confirm).
  Each group's eligible dimensions are shuffled and consumed one per iteration without
  repetition. All three pools are reshuffled once they are exhausted.

  Note: when first advanced, the generator calls np.random.seed(seed). It then draws from
  numpy's global RNG, so it both changes and depends on the global random state.

  Arguments:
    seed: seed for numpy's global random number generator
    ndims: total number of dimensions of the PDE (at least 9)
    niters: number of samples to yield
    bound: if True, the inactive dimensions are fixed at a boundary value in {-1, 1}
           if False, they are fixed at a value in [-1.0, 0.9] on a 0.1 grid

  Yield:
    ([d1, d2, d3], factor): the three active dimensions and the value shared by all
    inactive dimensions

  Raises:
    ValueError: if ndims < 9, since a group then has no eligible dimension
  """
  idx = int(ndims/3)
  if idx < 3:
    raise ValueError(f'generate_dims needs ndims >= 9 to have an eligible dimension per group, got {ndims}')

  # [lo, hi) ranges of eligible active dimensions, one per body network.
  ranges = [(1, idx-1), (idx+1, 2*idx-1), (2*idx+1, 3*idx-1)]
  np.random.seed(seed)

  pools = [[], [], []]
  for _ in range(niters):
    # The pools always have equal length, so checking the first is enough. Refilling at the
    # start of each iteration also covers pools of length 1, which were previously left
    # empty after the first draw.
    if not pools[0]:
      pools = [list(np.random.permutation(np.arange(lo, hi, 1))) for lo, hi in ranges]
    dims = [pool.pop(0) for pool in pools]

    if bound:
      factor = np.random.choice(np.arange(-1.,2.,2.))
    else:
      factor = np.random.choice(np.arange(-1.0,1.0,0.1))

    yield [int(d) for d in dims], float(factor.round(2))

In [11]:
# Preview 35 (active dimensions, inactive value) draws for a 99-D problem with boundary values only.
gen = generate_dims(444, 99, 35, 1)
for i in gen:
  print(i)

([2, 56, 96], 1.0)
([5, 38, 89], 1.0)
([6, 54, 88], -1.0)
([20, 60, 97], -1.0)
([14, 51, 72], 1.0)
([25, 34, 93], -1.0)
([7, 48, 83], 1.0)
([19, 58, 94], 1.0)
([26, 40, 70], 1.0)
([27, 39, 95], -1.0)
([1, 44, 74], 1.0)
([18, 35, 85], 1.0)
([12, 46, 69], 1.0)
([15, 45, 75], 1.0)
([30, 61, 92], 1.0)
([3, 50, 86], 1.0)
([22, 52, 84], 1.0)
([11, 57, 73], -1.0)
([29, 55, 78], 1.0)
([10, 43, 71], 1.0)
([28, 49, 90], 1.0)
([16, 36, 81], 1.0)
([23, 47, 80], -1.0)
([13, 53, 76], -1.0)
([8, 63, 68], 1.0)
([21, 42, 91], 1.0)
([31, 37, 79], -1.0)
([9, 59, 82], 1.0)
([24, 64, 87], -1.0)
([17, 41, 67], -1.0)
([4, 62, 77], 1.0)
([11, 54, 84], -1.0)
([30, 34, 86], -1.0)
([6, 42, 82], 1.0)
([4, 57, 67], -1.0)


### Test Data

In [12]:
def _test_generator_allencahnNd_random(key, ninputs):
    '''
    Draw random test points in [-1, 1]^ndims (ndims = 3*ninputs) together with the exact
    solution at each point.

    Arguments:
      key: sequence of PRNG keys; one test point is drawn per key
      ninputs: number of input dimensions per body network

    Return:
      data: list of [c1, c2, c3, u] per key, where c1, c2, c3 have shape (1, ninputs)
            (one coordinate group per body network) and u is the exact solution,
            promoted with jnp.atleast_3d
    '''
    ndims = 3*ninputs
    # Every dimension is "active" for a fully random test point.
    all_dims = list(range(ndims))

    data = []
    for j in range(len(key)):
        k1, k2, k3 = jax.random.split(key[j], 3)

        c1 = jax.random.uniform(k1, (ninputs,1), minval=-1.0, maxval=1.0)
        c2 = jax.random.uniform(k2, (ninputs,1), minval=-1.0, maxval=1.0)
        c3 = jax.random.uniform(k3, (ninputs,1), minval=-1.0, maxval=1.0)

        uc = allencahnNd_exact_u_(all_dims, [], [c1, c2, c3])

        data.append([c1.T, c2.T, c3.T, uc])

    return data

In [13]:
# Preview: three random test points for a 9-D problem (ninputs=3).
_test_generator_allencahnNd_random([jax.random.PRNGKey(445), jax.random.PRNGKey(44), jax.random.PRNGKey(4)], 3)

[[Array([[-0.98189783,  0.8204565 ,  0.4572053 ]], dtype=float32),
  Array([[0.31219506, 0.11097169, 0.5340574 ]], dtype=float32),
  Array([[-0.49980354,  0.6980226 , -0.6882596 ]], dtype=float32),
  Array([[[0.09185674]]], dtype=float32)],
 [Array([[0.39400816, 0.72540736, 0.0255549 ]], dtype=float32),
  Array([[ 0.15588188, -0.5705161 ,  0.7560077 ]], dtype=float32),
  Array([[-0.03737426,  0.9326124 , -0.02509737]], dtype=float32),
  Array([[[0.327406]]], dtype=float32)],
 [Array([[0.09366703, 0.62533736, 0.5107343 ]], dtype=float32),
  Array([[0.27566886, 0.15903115, 0.40404105]], dtype=float32),
  Array([[ 0.6407676 , -0.34005594,  0.77477884]], dtype=float32),
  Array([[[0.46429986]]], dtype=float32)]]

### Collocation Data

In [14]:
def _anant_train_generator_allencahnNd_collocation(nxy, key, ninputs, seed):
    '''
    Build collocation grids and the matching Allen-Cahn source term.

    For each key, generate_dims(seed, ...) picks one active dimension per body network and
    one value `factor` shared by every inactive dimension. factor = -1 or 1 puts the
    inactive dimensions on the boundary; other values give an interior slice. Each active
    dimension gets nxy uniform samples in [-1, 1], and the source term is evaluated on the
    nxy x nxy x nxy tensor grid of those samples.

    Arguments:
      nxy: number of samples along each active dimension
      key: sequence of PRNG keys, one grid per key
      ninputs: number of input dimensions per body network (ndims = 3*ninputs)
      seed: seed forwarded to generate_dims

    Return:
      data: list of [[c11, c12, c13], [c21, c22, c23], [c31, c32, c33], f] per grid.
            ck2 (shape (nxy, 1)) is the active column of body network k; ck1 / ck3 are the
            inactive columns before / after it (filled with factor); f is the source term of
            shape (nxy, nxy, nxy).
      dim_list: active dimensions of each grid

    Note: apply_anant_collocation unpacks five entries per record;
    _anant_train_generator_allencahnNd_collocation_ produces that layout.
    '''
    ndims = 3*ninputs
    data = []
    dim_list = []

    for j, (dims, factor) in enumerate(generate_dims(seed, ndims, len(key))):
        keys = jax.random.split(key[j], 3)
        dim_list.append(dims)

        # nxy samples per active dimension and their tensor grid (for the source term).
        samples = [jax.random.uniform(keys[m], (nxy,), minval=-1.0, maxval=1.0) for m in range(3)]
        grid = jnp.meshgrid(*samples, indexing='ij')

        # Every dimension that is not active is held at `factor`.
        fixed_batch = [d for d in range(ndims) if d not in dims]
        uc = allencahnNd_source_term_(list(dims), fixed_batch, list(grid), factor)

        groups = []
        for m, active_samples in enumerate(samples):
            col = dims[m] - m*ninputs  # position of the active dimension inside body network m
            block = (jnp.ones((ninputs, nxy))*factor).at[col, :].set(active_samples).T
            groups.append([block[:, :col], block[:, col:col+1], block[:, col+1:]])

        data.append([groups[0], groups[1], groups[2], uc])

    return data, dim_list

In [15]:
# Preview: four collocation grids with nxy=4 samples per active dimension for a 21-D problem (ninputs=7).
_anant_train_generator_allencahnNd_collocation(4, [jax.random.PRNGKey(444),jax.random.PRNGKey(44334),jax.random.PRNGKey(4443),jax.random.PRNGKey(333),], 7, 14)

([[[Array([[0.3, 0.3],
           [0.3, 0.3],
           [0.3, 0.3],
           [0.3, 0.3]], dtype=float32),
    Array([[-0.4376738 ],
           [ 0.42755318],
           [ 0.14637709],
           [-0.11809397]], dtype=float32),
    Array([[0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3]], dtype=float32)],
   [Array([[0.3, 0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3, 0.3]], dtype=float32),
    Array([[-0.6865864 ],
           [-0.76568747],
           [-0.6264143 ],
           [ 0.75231457]], dtype=float32),
    Array([[0.3],
           [0.3],
           [0.3],
           [0.3]], dtype=float32)],
   [Array([[0.3, 0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3, 0.3],
           [0.3, 0.3, 0.3, 0.3, 0.3]], dtype=float32),
    Array([[ 0.08064222],
           [-0.8096907 ],
           [ 0.1789

In [16]:
def _anant_train_generator_allencahnNd_collocation_(nxy, key, ninputs, seed):
    '''
    Build collocation grids with the source term and the exact solution.

    Same sampling as _anant_train_generator_allencahnNd_collocation: for each key,
    generate_dims(seed, ...) picks one active dimension per body network and one value
    `factor` shared by every inactive dimension (-1 or 1 puts them on the boundary). Each
    active dimension gets nxy uniform samples in [-1, 1].

    Arguments:
      nxy: number of samples along each active dimension
      key: sequence of PRNG keys, one grid per key
      ninputs: number of input dimensions per body network (ndims = 3*ninputs)
      seed: seed forwarded to generate_dims

    Return:
      data: list of [[c11, c12, c13], [c21, c22, c23], [c31, c32, c33], f, u_gt] per grid.
            ck2 (shape (nxy, 1)) is the active column of body network k; ck1 / ck3 are the
            inactive columns before / after it. f is the source term and u_gt the exact
            solution on the nxy x nxy x nxy grid of active samples. This is the five-entry
            layout unpacked by apply_anant_collocation.
      dim_list: active dimensions of each grid
    '''
    ndims = 3*ninputs
    data = []
    dim_list = []

    for j, (dims, factor) in enumerate(generate_dims(seed, ndims, len(key))):
        keys = jax.random.split(key[j], 3)
        dim_list.append(dims)

        # nxy samples per active dimension and their tensor grid (for f and u_gt).
        samples = [jax.random.uniform(keys[m], (nxy,), minval=-1.0, maxval=1.0) for m in range(3)]
        grid = jnp.meshgrid(*samples, indexing='ij')

        # Every dimension that is not active is held at `factor`.
        fixed_batch = [d for d in range(ndims) if d not in dims]
        uc = allencahnNd_source_term_(list(dims), fixed_batch, list(grid), factor)
        u_gt = allencahnNd_exact_u_(list(dims), fixed_batch, list(grid), factor, random_test=False)

        groups = []
        for m, active_samples in enumerate(samples):
            col = dims[m] - m*ninputs  # position of the active dimension inside body network m
            block = (jnp.ones((ninputs, nxy))*factor).at[col, :].set(active_samples).T
            groups.append([block[:, :col], block[:, col:col+1], block[:, col+1:]])

        data.append([groups[0], groups[1], groups[2], uc, u_gt])

    return data, dim_list

### Boundary Data

In [17]:
def allencahnNd_exact_u_bound_(index_batch, fixed_batch, X, fX):
    """
    Exact Allen-Cahn solution u = s^2 + sin(s), s = mean of all ndims coordinates, when
    every inactive dimension has its own fixed value.

    Arguments:
      index_batch: indices of the active dimensions
      fixed_batch: indices of the inactive dimensions
      X: one array per active dimension, ordered like index_batch
      fX: one value per inactive dimension, ordered like fixed_batch

    Return:
      exact solution, promoted with jnp.atleast_3d

    Raises:
      ValueError: if some index below ndims is neither active nor fixed
    """
    ndims = len(index_batch) + len(fixed_batch)
    ones = jnp.ones_like(X[0])

    # First position of each dimension within its batch (dict lookup instead of argwhere).
    active_pos, fixed_pos = {}, {}
    for k, d in enumerate(index_batch):
        active_pos.setdefault(int(d), k)
    for k, d in enumerate(fixed_batch):
        fixed_pos.setdefault(int(d), k)

    X_tot = []
    for i in range(ndims):
        if i in active_pos:
            X_tot.append(X[active_pos[i]])
        elif i in fixed_pos:
            X_tot.append(fX[fixed_pos[i]] * ones)
        else:
            raise ValueError(f'dimension {i} is neither active nor fixed')

    s = jnp.zeros_like(X_tot[0])
    for i in range(ndims):
        s += X_tot[i]
    s = 1/ndims * s

    return jnp.atleast_3d(s**2 + jnp.sin(s))



def allencahnNd_source_term_bound_(index_batch, fixed_batch, X, fX):
    """
    Laplacian part of the Allen-Cahn source term, (1/ndims) * (sin(s) - 2) with s the mean
    of all ndims coordinates, when every inactive dimension has its own fixed value.

    NOTE(review): unlike allencahnNd_source_term_, this does not add u - u^3, so it is only
    -Laplacian(u), not the full source term. Confirm whether callers expect the full term.

    Arguments:
      index_batch: indices of the active dimensions
      fixed_batch: indices of the inactive dimensions
      X: one array per active dimension, ordered like index_batch
      fX: one value per inactive dimension, ordered like fixed_batch

    Return:
      f: array with the shape of X[0]

    Raises:
      ValueError: if some index below ndims is neither active nor fixed
    """
    ndims = len(index_batch) + len(fixed_batch)
    ones = jnp.ones_like(X[0])

    # First position of each dimension within its batch (dict lookup instead of argwhere).
    active_pos, fixed_pos = {}, {}
    for k, d in enumerate(index_batch):
        active_pos.setdefault(int(d), k)
    for k, d in enumerate(fixed_batch):
        fixed_pos.setdefault(int(d), k)

    X_tot = []
    for i in range(ndims):
        if i in active_pos:
            X_tot.append(X[active_pos[i]])
        elif i in fixed_pos:
            X_tot.append(fX[fixed_pos[i]] * ones)
        else:
            raise ValueError(f'dimension {i} is neither active nor fixed')

    s = jnp.zeros(X_tot[0].shape)
    for i in range(ndims):
        s += X_tot[i]
    s = 1/ndims * s

    return (1/ndims)*(jnp.sin(s) - 2.)

def _anant_train_generator_allencahnNd_boundary(nxy, key, ninputs):
    '''
    Build boundary grids with exact-solution targets for the data (boundary) loss.

    For each key, generate_dims(444, ..., bound = True) picks one active dimension per body
    network. The +/-1 `factor` it yields is not used here; it is overwritten below by the
    list of inactive-dimension values.

    For each body network, ninputs values are first drawn from {-1, 1}. Then ninputs-1
    distinct, randomly chosen positions are replaced by interior values from
    np.arange(-0.8, 0.95, 0.1), rounded to 2 decimals. Exactly one entry per body network
    therefore stays at -1 or 1. The active column is then replaced by nxy uniform samples in
    [-1, 1].

    Arguments:
      nxy: number of samples along each active dimension
      key: sequence of PRNG keys, one grid per key
      ninputs: number of input dimensions per body network (ndims = 3*ninputs)

    Return:
      data: list of [[b11, b12, b13], [b21, b22, b23], [b31, b32, b33], ub, u_gt] per grid.
            bk2 (shape (nxy, 1)) is the active column of body network k; bk1 / bk3 are the
            inactive columns before / after it. ub is the exact solution on the
            nxy x nxy x nxy grid of active samples. u_gt is the output of
            allencahnNd_source_term_bound_ on that grid (a source term, not a solution).
      dim_list: active dimensions of each grid

    NOTE(review): numpy's global RNG is re-seeded inside the loop, interleaved with the draws
    generate_dims makes from the same global RNG. The sampled dimensions and values therefore
    depend on the exact statement order in this function.
    NOTE(review): if a body network's single +/-1 entry falls on its active dimension, it is
    overwritten by the active samples. The grid then lies on the boundary only if another
    body network keeps its +/-1 entry.
    '''

    data = []
    dim_list = []
    ndims = 3*ninputs

    # all dimension indices 0..ndims-1
    lst = list(np.arange(0,ndims,1))

    # NOTE(review): unused
    seeds = np.arange(1,len(key)+1,1)

    # dimension sampler for boundary points (seeds numpy's global RNG with 444 when first advanced)
    gen = generate_dims(444, ndims, len(key), bound = True)

    # boundary grids, one per key
    for j,(dims,factor) in enumerate(gen):
      keys = jax.random.split(key[j], 3)
      data_temp = []  # NOTE(review): unused
      dim_list.append(dims)

      #################################################### Boundary points ###############################################
      # nxy uniform samples in [-1, 1] for each active dimension
      c1_temp = jax.random.uniform(keys[0], (nxy,), minval=-1.0, maxval=1.0)
      c2_temp = jax.random.uniform(keys[1], (nxy,), minval=-1.0, maxval=1.0)
      c3_temp = jax.random.uniform(keys[2], (nxy,), minval=-1.0, maxval=1.0)

      # tensor grid of the active samples, used for the targets
      c1m, c2m, c3m = jnp.meshgrid(c1_temp, c2_temp, c3_temp, indexing='ij')

      # inactive dimensions = every index that is not one of the three active ones
      index_batch = [dims[0], dims[1], dims[2]]
      idx1 = (index_batch[0] != np.array(lst)).ravel()
      idx2 = (index_batch[1] != np.array(lst)).ravel()
      idx3 = (index_batch[2] != np.array(lst)).ravel()
      fixed_batch = np.argwhere(idx1*idx2*idx3 == True).ravel()

      # Body network 1: all +/-1, then ninputs-1 positions set to interior values.
      # The re-seed draws from the current global RNG state (for j = 0 the seed is always 44).
      np.random.seed(44+np.random.randint(1000)*j)
      f1 = list(np.random.choice(np.arange(-1.,2.,2.),ninputs))
      id = np.random.permutation(np.arange(ninputs))[:ninputs-1]  # shadows builtin id()
      for i in id:
        f1[i] = np.random.choice(np.arange(-0.8,0.95,0.1)).round(2)
      bs = [jnp.ones((nxy,))*f1[i] for i in range(ninputs)]
      b1 = jnp.stack(bs)


      # Body network 2: same scheme with a different seed offset.
      np.random.seed(444+np.random.randint(1000)*j)
      f2 = list(np.random.choice(np.arange(-1.,2.,2.),ninputs))
      id = np.random.permutation(np.arange(ninputs))[:ninputs-1]
      for i in id:
        f2[i] = np.random.choice(np.arange(-0.8,0.95,0.1)).round(2)
      bs = [jnp.ones((nxy,))*f2[i] for i in range(ninputs)]
      b2 = jnp.stack(bs)

      # Body network 3: same scheme with a different seed offset.
      np.random.seed(4444+np.random.randint(1000)*j)
      f3 = list(np.random.choice(np.arange(-1.,2.,2.),ninputs))
      id = np.random.permutation(np.arange(ninputs))[:ninputs-1]
      for i in id:
        f3[i] = np.random.choice(np.arange(-0.8,0.95,0.1)).round(2)
      bs = [jnp.ones((nxy,))*f3[i] for i in range(ninputs)]
      b3 = jnp.stack(bs)

      # drop each active dimension's entry so the remaining values line up with fixed_batch
      f1.pop(dims[0])
      f2.pop(dims[1]-(1*ninputs))
      f3.pop(dims[2]-(2*ninputs))

      # one value per inactive dimension, in ascending dimension order (like fixed_batch)
      f = [f1, f2, f3]
      factor = [fi for i in range(3) for fi in f[i]]
      del f

      ub = allencahnNd_exact_u_bound_(index_batch, fixed_batch, [c1m, c2m, c3m], factor)  # training target
      u_gt = allencahnNd_source_term_bound_(index_batch, fixed_batch, [c1m, c2m, c3m], factor)  # source term despite the name

      # put the active samples into their columns
      b1 = b1.at[dims[0],:].set(c1_temp)
      b2 = b2.at[dims[1]-(1*ninputs),:].set(c2_temp)
      b3 = b3.at[dims[2]-(2*ninputs),:].set(c3_temp)

      b1, b2, b3 = b1.T, b2.T, b3.T

      # split each body network's inputs into (before, active column, after) blocks
      b11 = b1[:, :dims[0]]
      b12 = b1[:, dims[0]].reshape(-1,1)
      b13 = b1[:, dims[0]+1:]

      b21 = b2[:, :dims[1]-(1*ninputs)]
      b22 = b2[:, dims[1]-(1*ninputs)].reshape(-1,1)
      b23 = b2[:, dims[1]-(1*ninputs)+1:]

      b31 = b3[:, :dims[2]-(2*ninputs)]
      b32 = b3[:, dims[2]-(2*ninputs)].reshape(-1,1)
      b33 = b3[:, dims[2]-(2*ninputs)+1:]

      data.append([[b11, b12, b13], [b21, b22, b23], [b31, b32, b33], ub, u_gt])

    return data, dim_list

In [18]:
# Preview: three boundary grids with nxy=4 samples per active dimension for a 300-D problem (ninputs=100).
_anant_train_generator_allencahnNd_boundary(4, [jax.random.PRNGKey(444),jax.random.PRNGKey(445),jax.random.PRNGKey(333),], 100)

([[[Array([[ 0.9,  0.4, -0.8,  0.7,  0.5, -0.6,  0.2,  0.1, -0. , -0.8,  0.2,
            -0.8, -0. ,  0.9,  0.9, -0.8, -0.8, -0.6, -0.7, -0.5,  0.5, -0.4,
            -0.2,  0.4,  0.8, -0.3, -0. , -0.6,  0.1, -0.7, -0. ,  0.2, -0.7,
            -0.3, -0.6, -0.4,  0.8,  0.8, -0.8,  0.9,  0.3,  0.7, -0. , -0.8,
            -0.3,  0.1, -0. ,  0.8, -0.3, -0.4,  0.6,  0.6, -0. ,  0.7,  0.1,
             0.7, -0. ,  0.9, -0.5,  0.7,  0.3, -0.1,  0.2, -0.3,  0.2,  0.4,
            -0.7,  0.3, -0.2,  0.9,  0.8, -0.8,  0.4,  0.7],
           [ 0.9,  0.4, -0.8,  0.7,  0.5, -0.6,  0.2,  0.1, -0. , -0.8,  0.2,
            -0.8, -0. ,  0.9,  0.9, -0.8, -0.8, -0.6, -0.7, -0.5,  0.5, -0.4,
            -0.2,  0.4,  0.8, -0.3, -0. , -0.6,  0.1, -0.7, -0. ,  0.2, -0.7,
            -0.3, -0.6, -0.4,  0.8,  0.8, -0.8,  0.9,  0.3,  0.7, -0. , -0.8,
            -0.3,  0.1, -0. ,  0.8, -0.3, -0.4,  0.6,  0.6, -0. ,  0.7,  0.1,
             0.7, -0. ,  0.9, -0.5,  0.7,  0.3, -0.1,  0.2, -0.3,  0.2,  0.4,
   

## 3. Utils

In [19]:
def relative_l2(u, u_gt):
    """
    Relative L2 error ||u - u_gt|| / ||u_gt||, using the 2-norm of the flattened arrays.

    Arguments:
      u: predicted solution
      u_gt: reference (exact) solution

    Return:
      relative L2 error as a scalar array
    """
    residual_norm = jnp.linalg.norm(u - u_gt)
    reference_norm = jnp.linalg.norm(u_gt)
    return residual_norm / reference_norm

def mse(apply_fn, params, data):
    """
    Aggregate relative L2 error of Anant-Net over a list of test batches.

    Despite its name this is not a mean squared error. It returns
        sum_i ||u_i - u_gt_i|| / sum_i ||u_gt_i||,
    where u_i = apply_fn(params, x_i, y_i, z_i).

    Arguments:
      apply_fn: apply function called as apply_fn(params, x, y, z)
      params: trained parameters of Anant-Net
      data: non-empty sequence of (x, y, z, u_gt) entries

    Return:
      aggregate relative L2 error

    Raises:
      ValueError: if data is empty (the ratio would be 0/0)
    """
    if len(data) == 0:
        raise ValueError('mse needs at least one (x, y, z, u_gt) entry')

    error_norm = 0.0
    reference_norm = 0.0
    for x, y, z, u_gt in data:
        u = apply_fn(params, x, y, z)
        error_norm += jnp.linalg.norm(u - u_gt)
        reference_norm += jnp.linalg.norm(u_gt)
    return error_norm / reference_norm

def plot_allencahn_slice(x, y, z, u):
    """
    Show a 3-D scatter of u over the three active dimensions, with a horizontal colorbar.

    Arguments:
      x, y, z: coordinates along active dimensions 1, 2 and 3
      u: solution values used to colour the points

    Return:
      None (the figure is displayed with plt.show())
    """
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='3d')
    points = ax.scatter(x, y, z, c=u, s=0.3, cmap='seismic', edgecolor='none')
    ax.set_title('u(x, y, z)', fontsize=20)
    for set_label, text in ((ax.set_xlabel, 'x'), (ax.set_ylabel, 'y'), (ax.set_zlabel, 'z')):
        set_label(text, fontsize=18, labelpad=10)
    fig.colorbar(points, orientation='horizontal', fraction=.1)
    plt.show()

def plot_allencahn_slices(ax, x, y, z, u, title, labels):
    """
    Draw a 3-D scatter of u over the three active dimensions on an existing axis.

    Arguments:
      ax: 3-D axis to draw on
      x, y, z: coordinates along active dimensions 1, 2 and 3
      u: solution values used to colour the points
      title: title of the plot
      labels: three axis labels (x, y, z)

    Return:
      the scatter mappable, so callers can attach a shared colorbar
      (e.g. fig.colorbar(returned, ...)); callers that ignore it are unaffected
    """
    cs = ax.scatter(x, y, z, c=u, s=0.3, cmap='seismic', edgecolor='none')
    ax.set_title(title, fontsize=20)
    ax.set_xlabel(labels[0], fontsize=18, labelpad=10)
    ax.set_ylabel(labels[1], fontsize=18, labelpad=10)
    ax.set_zlabel(labels[2], fontsize=18, labelpad=10)
    return cs

### Inference/Test at Random Points

In [20]:
def test_model_random(seed, params, ninputs, feat_sizes, num_samples):
  """
  Evaluate Anant-Net on points sampled uniformly from the full (3*ninputs)-dimensional domain.

  Arguments:
    seed: random seed for the test points
    params: trained parameters of Anant-Net
    ninputs: number of inputs per body network (number of dimensions / number of body networks)
    feat_sizes: number of features in each layer of a single body network
    num_samples: number of test points to sample

  Return:
    aggregate relative L2 error over the test points, as computed by `mse`
    (sum of error norms divided by sum of reference norms)
  """
  model = Anant_test(tuple(feat_sizes))
  apply_fn = jax.jit(model.apply)

  # Use the first half of a split of PRNGKey(seed); this keeps the test points unchanged.
  key = jax.random.split(jax.random.PRNGKey(seed), 2)[0]
  data_test = _test_generator_allencahnNd_random(jax.random.split(key, num_samples), ninputs)

  return mse(apply_fn, params, data_test)

In [62]:
def test_model_ntimes(seed, dir, params, ninputs, feat_sizes, ext='adam', num_samples = 10000):
  """
  Test every parameter set on random points, print each error, and save all errors.

  Arguments:
    seed: random seed for the test points
    dir: directory where test_rl2_<ext>.pkl is written
    params: list of trained parameter sets
    ninputs: number of inputs per body network
    feat_sizes: architecture of the body network as a list of layer widths
    ext: file-name suffix, e.g. 'adam' or 'lbfgs'
    num_samples: number of test points per parameter set

  Returns:
    test_rl2: list of relative L2 errors, one per parameter set
  """
  test_rl2 = []
  for p in params:
    err = test_model_random(seed, p, ninputs, feat_sizes, num_samples)
    test_rl2.append(err)
    # Percentage with 3 decimals (the precision of round(err, 5) * 100), formatted after the
    # multiplication so float artifacts such as 1.2340001 are not printed.
    print(f"% Relative L2 Error: {round(float(err) * 100, 3)} %")

  with open(os.path.join(dir, f'test_rl2_{ext}.pkl'), 'wb') as f:
    pkl.dump(test_rl2, f)
  return test_rl2

### Load/Save Results

In [22]:
def save_results(dir, loss, time_list, params_list, best_params):
  """
  Pickle the outputs of an Anant-Net training run into `dir`.

  Files written, in this order: loss.pkl, params_hist.pkl, time_hist.pkl, best_params.pkl.

  Arguments:
    dir: existing directory to write into
    loss: training-loss history
    time_list: timing history
    params_list: parameters saved at different training iterations
    best_params: best parameters found
  """
  artifacts = {
    'loss.pkl': loss,
    'params_hist.pkl': params_list,
    'time_hist.pkl': time_list,
    'best_params.pkl': best_params,
  }
  for filename, payload in artifacts.items():
    with open(os.path.join(dir, filename), 'wb') as handle:
      pkl.dump(payload, handle)

In [23]:
def load_results(dir):
  """
  Load the pickled outputs of an Anant-Net training run from `dir`.

  Note: pickle.load can execute arbitrary code, so only load directories you produced.

  Arguments:
    dir: directory containing loss.pkl, params_hist.pkl, time_hist.pkl and best_params.pkl

  Return:
    loss: training-loss history
    time_list: timing history
    params_list: parameters saved at different training iterations
    best_params: best parameters found
  """
  def _read(filename):
    with open(os.path.join(dir, filename), 'rb') as handle:
      return pkl.load(handle)

  loss = _read('loss.pkl')
  params_list = _read('params_hist.pkl')
  time_list = _read('time_hist.pkl')
  best_params = _read('best_params.pkl')

  return loss, time_list, params_list, best_params

In [24]:
def load_test_results(dir, ext='lbfgs'):
  """
  Load the relative L2 test errors saved by test_model_ntimes.

  Note: pickle.load can execute arbitrary code, so only load files you produced.

  Arguments:
    dir: directory containing test_rl2_<ext>.pkl
    ext: file-name suffix used when saving. The default 'lbfgs' matches the previously
         hard-coded file; test_model_ntimes itself defaults to 'adam'.

  Return:
    list of relative L2 errors (test)
  """
  with open(os.path.join(dir, f'test_rl2_{ext}.pkl'), 'rb') as f:
    return pkl.load(f)

## 4. Main function

In [25]:
def main(NC, NB, SAMP_FREQ, SEED, EPOCHS, N_LAYERS, NNEURONS, LOG_ITER, PRINT_ITER, NUM_BATCHES_COLL, NUM_BATCHES_BOUND, NDIMS, BATCH_SIZE, LAMB, WDIR=None, OPT_DICT = {'Optimizer': 'Adam', 'Lr': 1e-3, 'Reg': 0.1, 'Param_Initial': None}):
    """
    The "main" function to train Anant-Net (AdamW update rule)

    Arguments:
      NC: Number of partitions for an axis in the collocation grid. For eg. a grid with 3 body networks will have NCxNCxNC collocation points
      NB: Number of partitions for an axis in the boundary grid. For eg. a grid with 3 body networks will have NBxNBxNB boundary points
          Also, note that if NB = 1, Anant-Net handles a single boundary point since NBxNBxNB = 1
      SAMP_FREQ: Number of parameter updates spent on one collocation grid before moving to the next one.
                 The validation error used to select best_params is evaluated at the same frequency.
      SEED: Random seed
      EPOCHS: Number of parameter updates
      N_LAYERS: Number of layers in the body network
      NNEURONS [List]: [Number of neurons in hidden layers, Number of neurons in final embedding layer]
      LOG_ITER: Frequency (in updates) of storing the parameters in p_list
      PRINT_ITER: Frequency (in updates) of printing the loss
      NUM_BATCHES_COLL: Number of collocation grids
      NUM_BATCHES_BOUND: Number of boundary grids
      NDIMS: Number of dimensions for the high-dimensional PDE (split across the 3 body networks)
      BATCH_SIZE: Number of collocation gradients accumulated and averaged per parameter update
                  (all evaluated on the current collocation grid)
      LAMB: Weight for the boundary loss
      WDIR: Work directory to save the results; nothing is saved when None
      OPT_DICT:
        A dictionary of optimizer settings
        Optimizer: the update rule is always AdamW. The value 'LBFGS-Rank' only changes which
                   collocation grids are visited: the indices in OPT_DICT['Rank'] instead of all grids.
        Lr: Learning rate
        Reg: Weight decay parameter of AdamW
        Param_Initial: Initial parameters for warm starting the model training (None = random init)
        Rank: (only read when Optimizer == 'LBFGS-Rank') collocation-grid indices to cycle through;
              the caller's list is never modified

    Returns:
      loss_lst: list of [update, loss] values
      time_list: list of time profiling
      p_list: list of trained parameters at different iterations
      best_params: parameters with the lowest validation error seen at the SAMP_FREQ checkpoints
                   (the starting parameters if no checkpoint was reached)
    """

    # setting numpy random seed (drives the shuffled collocation-grid order below)
    np.random.seed(SEED)

    # force jax to use one device
    # NOTE(review): these variables are only honored if set before JAX initializes its
    # backend; jax is imported at the top of the notebook, so they may have no effect here.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    jax.config.update("jax_enable_x64", False)

    # random key
    key = jax.random.PRNGKey(SEED)
    key, subkey = jax.random.split(key, 2)

    # setting up body networks
    feat_sizes = [NNEURONS[0] for _ in range(N_LAYERS)]
    feat_sizes[-1] = NNEURONS[1]
    feat_sizes = tuple(feat_sizes)
    print(f'Body Network: {feat_sizes}')

    # number of inputs per body network. Here, we have 3 body networks.
    N_INPUTS = int(NDIMS/3)

    # make & init model
    model = Anant(feat_sizes)

    # initialization: each body network gets inputs of widths (1, 1, N_INPUTS - 2)
    d1 = 1
    d2 = 1
    d3 = N_INPUTS - 2
    params = model.init(
        subkey,
        jnp.ones((NC, d1)), jnp.ones((NC, d2)), jnp.ones((NC, d3)),
        jnp.ones((NC, d1)), jnp.ones((NC, d2)), jnp.ones((NC, d3)),
        jnp.ones((NC, d1)), jnp.ones((NC, d2)), jnp.ones((NC, d3)),
    )

    # warm start, as documented for OPT_DICT['Param_Initial']
    if OPT_DICT.get('Param_Initial') is not None:
        params = OPT_DICT['Param_Initial']

    # optimizer
    optim = optax.adamw(OPT_DICT['Lr'], weight_decay=OPT_DICT['Reg'])
    state = optim.init(params)

    # dataset
    keys_coll = list(jax.random.split(key, NUM_BATCHES_COLL))
    keys_bound = list(jax.random.split(key, NUM_BATCHES_BOUND))

    # forward & loss function
    apply_fn = jax.jit(model.apply)

    # test points
    # NOTE(review): data_test is not used in this function; kept so the data-generation
    # sequence is unchanged.
    key_test = jax.random.split(key, 2000)
    data_test = _test_generator_allencahnNd_random(key_test, N_INPUTS)

    # collocation / boundary grids
    train_data_coll, _dims_coll = _anant_train_generator_allencahnNd_collocation_(NC, keys_coll, N_INPUTS, seed=SEED)
    train_data_bound, _dims_bound = _anant_train_generator_allencahnNd_boundary(NB, keys_bound, N_INPUTS)

    # training data summary (product of the first three axes of each grid's 4th element)
    n_coll = sum(math.prod(u.shape[:3]) for _, _, _, u, _ in train_data_coll)
    print(f'\nTotal number of collocation samples: {n_coll}')

    n_bound = sum(math.prod(u.shape[:3]) for _, _, _, u, _ in train_data_bound)
    print(f'Total number of boundary samples: {n_bound}')

    use_rank = OPT_DICT['Optimizer'] == 'LBFGS-Rank'

    def _grid_order():
        # Fresh list of grid indices. Always a copy, so the shuffle/pop below never
        # mutate the caller's OPT_DICT['Rank'].
        if use_rank:
            return list(OPT_DICT['Rank'])
        return list(np.arange(0, len(train_data_coll), 1))

    batch_idx = _grid_order()
    np.random.shuffle(batch_idx)
    reset_idx = 0
    update_epoch = 0

    time_coll = 0.0
    time_avg = 0.0
    time_bound = 0.0
    best_error = jnp.inf
    # fallback so saving/returning never hits an undefined name when no
    # SAMP_FREQ checkpoint is reached
    best_params = params

    time_list = []
    p_list = []
    loss_lst = []

    lr = OPT_DICT['Lr']

    start = time.time()
    for _ in trange(1, (EPOCHS*BATCH_SIZE)+1):
        updated = False

        # refill (and reshuffle) the grid order once too few grids remain
        if len(batch_idx) < BATCH_SIZE:
            batch_idx = _grid_order()
            np.random.shuffle(batch_idx)
            reset_idx = 0

        ################# process time for collocation points ###########################
        time_coll = time.process_time()
        loss_coll, gradient_coll = apply_anant_collocation(apply_fn, params, train_data_coll[batch_idx[0]])
        time_coll = time.process_time() - time_coll

        ################# process time for accumulating gradients #######################
        time_avg = time.process_time()
        if reset_idx == 0:
            gradient_avg = jax.tree.map(lambda x: x, gradient_coll)
            loss_avg = loss_coll
        else:
            gradient_avg = jax.tree.map(lambda x, y: x + y, gradient_coll, gradient_avg)
            loss_avg = loss_avg + loss_coll
        time_avg = time.process_time() - time_avg

        reset_idx += 1

        if reset_idx == BATCH_SIZE:
            ################# process time for boundary points ###########################
            time_bound = time.process_time()
            loss_bound, gradient_bound = apply_anant_boundary(apply_fn, params, train_data_bound, lamb=LAMB)
            time_bound = time.process_time() - time_bound

            ################# process time for averaging gradients #######################
            time_avg1 = time.process_time()
            loss_avg = loss_avg/BATCH_SIZE + loss_bound
            gradient_avg = jax.tree.map(lambda x: x/BATCH_SIZE, gradient_avg)
            gradient = jax.tree.map(lambda x, y: x + y, gradient_avg, gradient_bound)
            time_avg1 = time.process_time() - time_avg1
            time_avg = time_avg + time_avg1
            time_list.append([time_coll, time_bound, time_avg])

            params, state = update_model(optim, gradient, params, state)

            gradient_avg = None
            reset_idx = 0
            update_epoch += 1
            updated = True

        # Validation, printing, logging and grid switching happen once per parameter
        # update. Otherwise they would repeat for every accumulated gradient (and fire
        # at update 0) whenever BATCH_SIZE > 1.
        if not updated:
            continue

        if update_epoch % SAMP_FREQ == 0:
            # save the parameters for the best model (validation seed 333)
            rl2 = test_model_random(333, params, N_INPUTS, feat_sizes, 1000)
            if rl2 < best_error:
                best_params = params
                best_error = rl2

        if update_epoch % PRINT_ITER == 0:
            if math.isnan(loss_avg):
                print('Optimizer diverging...')
                break

            loss_lst += [[update_epoch, round(float(loss_avg), 6)]]
            print(f'Epoch: {update_epoch}/{EPOCHS}, collocation grid index: {batch_idx[0]}, lr: {lr} --> loss: {loss_avg:.8f}, best error: {best_error:.8f}')

        if update_epoch % LOG_ITER == 0:
            p_list.append(params)

        if update_epoch % SAMP_FREQ == 0:
            # move on to the next collocation grid
            batch_idx.pop(0)

    end = time.time()
    # divide by the updates actually performed (differs from EPOCHS after an early break)
    print(f'Runtime: {((end-start)/max(update_epoch, 1)*1000):.2f} ms/iter.')

    p_list.append(params)

    # save results; report failures instead of silently swallowing them
    if WDIR is not None:
        try:
            os.makedirs(WDIR, exist_ok=True)
            save_results(WDIR, loss_lst, time_list, p_list, best_params)
        except OSError as err:
            print(f'Warning: could not save results to {WDIR}: {err}')

    return loss_lst, time_list, p_list, best_params

In [26]:
def lbfgs(NC, NB, SAMP_FREQ, SEED, LR, EPOCHS, N_LAYERS, NNEURONS, LOG_ITER, PRINT_ITER, NUM_BATCHES_COLL, NUM_BATCHES_BOUND, NDIMS, BATCH_SIZE, LAMB, PARAM_INIT, WDIR=None):

    """
    The "main" function to train Anant-Net with LBFGS (optax.scale_by_lbfgs scaled by -LR)

    Arguments:
      NC: Number of partitions for an axis in the collocation grid. For eg. a grid with 3 body networks will have NCxNCxNC collocation points
      NB: Number of partitions for an axis in the boundary grid. For eg. a grid with 3 body networks will have NBxNBxNB boundary points
          Also, note that if NB = 1, Anant-Net handles a single boundary point since NBxNBxNB = 1
      SAMP_FREQ: Number of parameter updates spent on one collocation grid before moving to the next one.
                 The validation error used to select best_params is evaluated at the same frequency.
      SEED: Random seed
      LR: Learning rate (the LBFGS direction is scaled by -LR)
      EPOCHS: Number of parameter updates
      N_LAYERS: Number of layers in the body network
      NNEURONS [List]: [Number of neurons in hidden layers, Number of neurons in final embedding layer]
      LOG_ITER: Frequency (in updates) of storing the parameters in p_list
      PRINT_ITER: Frequency (in updates) of printing the loss
      NUM_BATCHES_COLL: Number of collocation grids
      NUM_BATCHES_BOUND: Number of boundary grids
      NDIMS: Number of dimensions for the high-dimensional PDE (split across the 3 body networks)
      BATCH_SIZE: Number of collocation gradients accumulated and averaged per parameter update
                  (all evaluated on the current collocation grid)
      LAMB: Weight for the boundary loss
      PARAM_INIT: Initial parameters for warm starting the model training
      WDIR: Work directory to save the results; nothing is saved when None

    Returns:
      loss_lst: list of [update, loss] values
      time_list: list of time profiling
      p_list: list of trained parameters at different iterations
      best_params: parameters with the lowest validation error seen at the SAMP_FREQ checkpoints
                   (PARAM_INIT if no checkpoint was reached)
    """

    # setting numpy random seed (drives the shuffled collocation-grid order below)
    np.random.seed(SEED)

    # force jax to use one device
    # NOTE(review): these variables are only honored if set before JAX initializes its
    # backend; jax is imported at the top of the notebook, so they may have no effect here.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    # random key
    key = jax.random.PRNGKey(SEED)
    key, subkey = jax.random.split(key, 2)

    # setting up body networks
    feat_sizes = [NNEURONS[0] for _ in range(N_LAYERS)]
    feat_sizes[-1] = NNEURONS[1]
    feat_sizes = tuple(feat_sizes)
    print(f'Body Network: {feat_sizes}')

    # number of inputs per body network. Here, we have 3 body networks.
    N_INPUTS = int(NDIMS/3)

    # make model; parameters come from the warm start
    model = Anant(feat_sizes)
    params = PARAM_INIT

    # optimizer
    optim = optax.chain(
        optax.scale_by_lbfgs(),
        optax.scale(-1*LR)
    )
    state = optim.init(params)

    # dataset
    keys_coll = list(jax.random.split(key, NUM_BATCHES_COLL))
    keys_bound = list(jax.random.split(key, NUM_BATCHES_BOUND))

    # forward & loss function
    apply_fn = jax.jit(model.apply)

    # test points
    # NOTE(review): data_test is not used in this function; kept so the data-generation
    # sequence is unchanged.
    key_test = jax.random.split(key, 2000)
    data_test = _test_generator_allencahnNd_random(key_test, N_INPUTS)

    # collocation / boundary grids
    train_data_coll, _dims_coll = _anant_train_generator_allencahnNd_collocation_(NC, keys_coll, N_INPUTS, SEED)
    train_data_bound, _dims_bound = _anant_train_generator_allencahnNd_boundary(NB, keys_bound, N_INPUTS)

    # training data summary (product of the first three axes of each grid's 4th element)
    n_coll = sum(math.prod(u.shape[:3]) for _, _, _, u, _ in train_data_coll)
    print(f'\nTotal number of collocation samples: {n_coll}')

    n_bound = sum(math.prod(u.shape[:3]) for _, _, _, u, _ in train_data_bound)
    print(f'Total number of boundary samples: {n_bound}')

    batch_idx = list(np.arange(0, len(train_data_coll), 1))
    np.random.shuffle(batch_idx)
    reset_idx = 0
    update_epoch = 0
    best_error = jnp.inf
    # fallback so saving/returning never hits an undefined name when no
    # SAMP_FREQ checkpoint is reached
    best_params = params

    time_coll = 0.0
    time_avg = 0.0
    time_bound = 0.0

    time_list = []
    p_list = []
    loss_lst = []

    start = time.time()
    for _ in trange(1, (EPOCHS*BATCH_SIZE)+1):
        updated = False

        # refill (and reshuffle) the grid order once every grid has been used
        if not batch_idx:
            batch_idx = list(np.arange(0, len(train_data_coll), 1))
            np.random.shuffle(batch_idx)
            reset_idx = 0

        ################# process time for collocation points ###########################
        time_coll = time.process_time()
        loss_coll, gradient_coll = apply_anant_collocation(apply_fn, params, train_data_coll[batch_idx[0]])
        time_coll = time.process_time() - time_coll

        ################# process time for accumulating gradients #######################
        time_avg = time.process_time()
        if reset_idx == 0:
            gradient_avg = jax.tree.map(lambda x: x, gradient_coll)
            loss_avg = loss_coll
        else:
            gradient_avg = jax.tree.map(lambda x, y: x + y, gradient_coll, gradient_avg)
            loss_avg = loss_avg + loss_coll
        time_avg = time.process_time() - time_avg

        reset_idx += 1

        if reset_idx == BATCH_SIZE:
            ################# process time for boundary points ###########################
            time_bound = time.process_time()
            loss_bound, gradient_bound = apply_anant_boundary(apply_fn, params, train_data_bound, lamb=LAMB)
            time_bound = time.process_time() - time_bound

            ################# process time for averaging gradients #######################
            time_avg1 = time.process_time()
            loss_avg = loss_avg/BATCH_SIZE + loss_bound
            gradient_avg = jax.tree.map(lambda x: x/BATCH_SIZE, gradient_avg)
            gradient = jax.tree.map(lambda x, y: x + y, gradient_avg, gradient_bound)
            time_avg1 = time.process_time() - time_avg1
            time_avg = time_avg + time_avg1
            time_list.append([time_coll, time_bound, time_avg])

            params, state = update_model(optim, gradient, params, state)

            gradient_avg = None
            reset_idx = 0
            update_epoch += 1
            updated = True

        # Validation, printing, logging and grid switching happen once per parameter
        # update. Otherwise they would repeat for every accumulated gradient (and fire
        # at update 0) whenever BATCH_SIZE > 1.
        if not updated:
            continue

        if update_epoch % SAMP_FREQ == 0:
            # save the parameters for the best model
            # change the random seed below. Current random seed = 333 for computing validation error while training the model.
            rl2 = test_model_random(333, params, N_INPUTS, feat_sizes, 1000)
            if rl2 < best_error:
                best_params = params
                best_error = rl2

        if update_epoch % PRINT_ITER == 0:
            if math.isnan(loss_avg):
                print('Optimizer diverging...')
                break
            loss_lst += [[update_epoch, round(float(loss_avg), 6)]]
            print(f'Epoch: {update_epoch}/{EPOCHS}, collocation grid index: {batch_idx[0]}, lr: {LR} --> loss: {loss_avg:.8f}, best error: {best_error:.8f}')

        if update_epoch % LOG_ITER == 0:
            p_list.append(params)

        if update_epoch % SAMP_FREQ == 0:
            # move on to the next collocation grid
            batch_idx.pop(0)

    end = time.time()
    # divide by the updates actually performed (differs from EPOCHS after an early break)
    print(f'Runtime: {((end-start)/max(update_epoch, 1)*1000):.2f} ms/iter.')

    p_list.append(params)

    # save results; report failures instead of silently swallowing them
    if WDIR is not None:
        try:
            os.makedirs(WDIR, exist_ok=True)
            save_results(WDIR, loss_lst, time_list, p_list, best_params)
        except OSError as err:
            print(f'Warning: could not save results to {WDIR}: {err}')

    return loss_lst, time_list, p_list, best_params

## 5. Run

In [30]:
# All results of this notebook live under ./ac_results
filename = 'ac_results'
root_dir = os.path.join(os.getcwd(), filename)
os.makedirs(root_dir, exist_ok=True)

In [39]:
# Stage 1: train Anant-Net with AdamW (NDIMS=21) from a fresh initialization;
# results are written to ac_results/adam.
result_dir = os.path.join(root_dir, 'adam')
os.makedirs(result_dir, exist_ok=True)

lss_list, time_list, params_list, best_params = main(
    NC=14, NB=6, SAMP_FREQ=5000, SEED=444, EPOCHS=70000,
    N_LAYERS=3, NNEURONS=[64, 10], LOG_ITER=100, PRINT_ITER=1000,
    NUM_BATCHES_COLL=14, NUM_BATCHES_BOUND=40,
    NDIMS=21, BATCH_SIZE=1, LAMB=15, WDIR=result_dir,
    OPT_DICT={'Optimizer': 'Adam', 'Lr': 1e-3, 'Reg': 0.1, 'Param_Initial': None},
)

Body Network: (64, 64, 10)

Total number of collocation samples: 38416
Total number of boundary samples: 8640


  1%|▏         | 1014/70000 [00:21<10:06, 113.66it/s]

Epoch: 1000/70000, collocation grid index: 13, lr: 0.001 --> loss: 0.00056544, best error: inf


  3%|▎         | 2029/70000 [00:29<07:19, 154.72it/s]

Epoch: 2000/70000, collocation grid index: 13, lr: 0.001 --> loss: 0.00016472, best error: inf


  4%|▍         | 3016/70000 [00:36<10:20, 107.93it/s]

Epoch: 3000/70000, collocation grid index: 13, lr: 0.001 --> loss: 0.00010481, best error: inf


  6%|▌         | 4023/70000 [00:43<06:47, 161.92it/s]

Epoch: 4000/70000, collocation grid index: 13, lr: 0.001 --> loss: 0.00006158, best error: inf


  7%|▋         | 4993/70000 [00:49<09:25, 114.95it/s]

Epoch: 5000/70000, collocation grid index: 13, lr: 0.001 --> loss: 0.00004895, best error: 0.57148540


  9%|▊         | 6019/70000 [01:02<10:15, 103.96it/s] 

Epoch: 6000/70000, collocation grid index: 2, lr: 0.001 --> loss: 0.02008322, best error: 0.57148540


 10%|█         | 7021/70000 [01:11<07:51, 133.55it/s]

Epoch: 7000/70000, collocation grid index: 2, lr: 0.001 --> loss: 0.01003081, best error: 0.57148540


 11%|█▏        | 8020/70000 [01:19<07:46, 132.83it/s]

Epoch: 8000/70000, collocation grid index: 2, lr: 0.001 --> loss: 0.00719622, best error: 0.57148540


 13%|█▎        | 9025/70000 [01:28<08:06, 125.24it/s]

Epoch: 9000/70000, collocation grid index: 2, lr: 0.001 --> loss: 0.00547241, best error: 0.57148540


 14%|█▍        | 9995/70000 [01:36<08:51, 112.93it/s]

Epoch: 10000/70000, collocation grid index: 2, lr: 0.001 --> loss: 0.00339095, best error: 0.36539310


 16%|█▌        | 11014/70000 [01:49<09:54, 99.28it/s]  

Epoch: 11000/70000, collocation grid index: 12, lr: 0.001 --> loss: 0.00078604, best error: 0.36539310


 17%|█▋        | 12016/70000 [02:00<08:24, 114.95it/s]

Epoch: 12000/70000, collocation grid index: 12, lr: 0.001 --> loss: 0.00040650, best error: 0.36539310


 19%|█▊        | 13012/70000 [02:09<09:59, 95.08it/s] 

Epoch: 13000/70000, collocation grid index: 12, lr: 0.001 --> loss: 0.00020524, best error: 0.36539310


 20%|██        | 14021/70000 [02:20<08:17, 112.55it/s]

Epoch: 14000/70000, collocation grid index: 12, lr: 0.001 --> loss: 0.00010101, best error: 0.36539310


 21%|██▏       | 14998/70000 [02:29<09:15, 99.04it/s] 

Epoch: 15000/70000, collocation grid index: 12, lr: 0.001 --> loss: 0.00005168, best error: 0.05275854


 23%|██▎       | 16019/70000 [02:42<10:18, 87.24it/s]  

Epoch: 16000/70000, collocation grid index: 6, lr: 0.001 --> loss: 0.00025945, best error: 0.05275854


 24%|██▍       | 17013/70000 [02:51<07:21, 120.08it/s]

Epoch: 17000/70000, collocation grid index: 6, lr: 0.001 --> loss: 0.00018876, best error: 0.05275854


 26%|██▌       | 18022/70000 [03:00<07:02, 122.94it/s]

Epoch: 18000/70000, collocation grid index: 6, lr: 0.001 --> loss: 0.00015692, best error: 0.05275854


 27%|██▋       | 19020/70000 [03:09<07:42, 110.25it/s]

Epoch: 19000/70000, collocation grid index: 6, lr: 0.001 --> loss: 0.00013675, best error: 0.05275854


 29%|██▊       | 19996/70000 [03:18<07:14, 114.97it/s]

Epoch: 20000/70000, collocation grid index: 6, lr: 0.001 --> loss: 0.00012290, best error: 0.05275854


 30%|███       | 21013/70000 [03:31<09:53, 82.49it/s]  

Epoch: 21000/70000, collocation grid index: 11, lr: 0.001 --> loss: 0.00023213, best error: 0.05275854


 31%|███▏      | 22021/70000 [03:40<06:55, 115.40it/s]

Epoch: 22000/70000, collocation grid index: 11, lr: 0.001 --> loss: 0.00016526, best error: 0.05275854


 33%|███▎      | 23015/70000 [03:50<08:05, 96.69it/s] 

Epoch: 23000/70000, collocation grid index: 11, lr: 0.001 --> loss: 0.00014399, best error: 0.05275854


 34%|███▍      | 24024/70000 [03:59<06:59, 109.67it/s]

Epoch: 24000/70000, collocation grid index: 11, lr: 0.001 --> loss: 0.00012938, best error: 0.05275854


 36%|███▌      | 24997/70000 [04:09<06:21, 117.93it/s]

Epoch: 25000/70000, collocation grid index: 11, lr: 0.001 --> loss: 0.00012197, best error: 0.05275854


 37%|███▋      | 26023/70000 [04:23<06:56, 105.51it/s] 

Epoch: 26000/70000, collocation grid index: 3, lr: 0.001 --> loss: 0.00210496, best error: 0.05275854


 39%|███▊      | 27015/70000 [04:30<05:19, 134.70it/s]

Epoch: 27000/70000, collocation grid index: 3, lr: 0.001 --> loss: 0.00122925, best error: 0.05275854


 40%|████      | 28019/70000 [04:39<05:07, 136.35it/s]

Epoch: 28000/70000, collocation grid index: 3, lr: 0.001 --> loss: 0.00081909, best error: 0.05275854


 41%|████▏     | 29019/70000 [04:47<04:43, 144.35it/s]

Epoch: 29000/70000, collocation grid index: 3, lr: 0.001 --> loss: 0.00070144, best error: 0.05275854


 43%|████▎     | 29992/70000 [04:54<04:34, 145.63it/s]

Epoch: 30000/70000, collocation grid index: 3, lr: 0.001 --> loss: 0.00064205, best error: 0.05275854


 44%|████▍     | 31021/70000 [05:06<06:14, 104.13it/s]

Epoch: 31000/70000, collocation grid index: 4, lr: 0.001 --> loss: 0.00103707, best error: 0.05275854


 46%|████▌     | 32022/70000 [05:14<04:32, 139.28it/s]

Epoch: 32000/70000, collocation grid index: 4, lr: 0.001 --> loss: 0.00058142, best error: 0.05275854


 47%|████▋     | 33013/70000 [05:22<06:06, 100.81it/s]

Epoch: 33000/70000, collocation grid index: 4, lr: 0.001 --> loss: 0.00040373, best error: 0.05275854


 49%|████▊     | 34016/70000 [05:30<04:13, 141.69it/s]

Epoch: 34000/70000, collocation grid index: 4, lr: 0.001 --> loss: 0.00024683, best error: 0.05275854


 50%|████▉     | 34998/70000 [05:37<06:21, 91.73it/s] 

Epoch: 35000/70000, collocation grid index: 4, lr: 0.001 --> loss: 0.00017296, best error: 0.05275854


 51%|█████▏    | 36022/70000 [05:49<05:08, 110.27it/s]

Epoch: 36000/70000, collocation grid index: 5, lr: 0.001 --> loss: 0.00005556, best error: 0.05275854


 53%|█████▎    | 37016/70000 [05:56<04:28, 122.78it/s]

Epoch: 37000/70000, collocation grid index: 5, lr: 0.001 --> loss: 0.00004314, best error: 0.05275854


 54%|█████▍    | 38030/70000 [06:03<03:37, 146.93it/s]

Epoch: 38000/70000, collocation grid index: 5, lr: 0.001 --> loss: 0.00003601, best error: 0.05275854


 56%|█████▌    | 39019/70000 [06:11<04:54, 105.29it/s]

Epoch: 39000/70000, collocation grid index: 5, lr: 0.001 --> loss: 0.00002902, best error: 0.05275854


 57%|█████▋    | 39998/70000 [06:18<03:29, 143.24it/s]

Epoch: 40000/70000, collocation grid index: 5, lr: 0.001 --> loss: 0.00002241, best error: 0.05275854


 59%|█████▊    | 41020/70000 [06:31<03:59, 120.94it/s]

Epoch: 41000/70000, collocation grid index: 1, lr: 0.001 --> loss: 0.00002394, best error: 0.05275854


 60%|██████    | 42014/70000 [06:38<03:27, 134.70it/s]

Epoch: 42000/70000, collocation grid index: 1, lr: 0.001 --> loss: 0.00002415, best error: 0.05275854


 61%|██████▏   | 43019/70000 [06:47<03:04, 145.91it/s]

Epoch: 43000/70000, collocation grid index: 1, lr: 0.001 --> loss: 0.00001429, best error: 0.05275854


 63%|██████▎   | 44028/70000 [06:53<02:56, 147.33it/s]

Epoch: 44000/70000, collocation grid index: 1, lr: 0.001 --> loss: 0.00001080, best error: 0.05275854


 64%|██████▍   | 44988/70000 [07:01<03:23, 123.09it/s]

Epoch: 45000/70000, collocation grid index: 1, lr: 0.001 --> loss: 0.00000857, best error: 0.03569463


 66%|██████▌   | 46019/70000 [07:12<03:24, 117.21it/s]

Epoch: 46000/70000, collocation grid index: 8, lr: 0.001 --> loss: 0.02133887, best error: 0.03569463


 67%|██████▋   | 47030/70000 [07:20<02:39, 144.21it/s]

Epoch: 47000/70000, collocation grid index: 8, lr: 0.001 --> loss: 0.00729798, best error: 0.03569463


 69%|██████▊   | 48019/70000 [07:26<03:09, 116.08it/s]

Epoch: 48000/70000, collocation grid index: 8, lr: 0.001 --> loss: 0.00432913, best error: 0.03569463


 70%|███████   | 49021/70000 [07:35<02:21, 147.90it/s]

Epoch: 49000/70000, collocation grid index: 8, lr: 0.001 --> loss: 0.00294325, best error: 0.03569463


 71%|███████▏  | 49996/70000 [07:42<02:17, 145.07it/s]

Epoch: 50000/70000, collocation grid index: 8, lr: 0.001 --> loss: 0.00200189, best error: 0.03569463


 73%|███████▎  | 51026/70000 [07:54<02:35, 121.81it/s]

Epoch: 51000/70000, collocation grid index: 10, lr: 0.001 --> loss: 0.00091140, best error: 0.03569463


 74%|███████▍  | 52011/70000 [08:02<02:37, 114.50it/s]

Epoch: 52000/70000, collocation grid index: 10, lr: 0.001 --> loss: 0.00057769, best error: 0.03569463


 76%|███████▌  | 53025/70000 [08:10<02:17, 123.32it/s]

Epoch: 53000/70000, collocation grid index: 10, lr: 0.001 --> loss: 0.00044008, best error: 0.03569463


 77%|███████▋  | 54011/70000 [08:18<02:29, 107.26it/s]

Epoch: 54000/70000, collocation grid index: 10, lr: 0.001 --> loss: 0.00036626, best error: 0.03569463


 79%|███████▊  | 54991/70000 [08:25<01:43, 145.70it/s]

Epoch: 55000/70000, collocation grid index: 10, lr: 0.001 --> loss: 0.00028555, best error: 0.03569463


 80%|████████  | 56026/70000 [08:38<02:00, 115.74it/s]

Epoch: 56000/70000, collocation grid index: 9, lr: 0.001 --> loss: 0.00041923, best error: 0.03569463


 81%|████████▏ | 57008/70000 [08:46<01:54, 113.39it/s]

Epoch: 57000/70000, collocation grid index: 9, lr: 0.001 --> loss: 0.00031555, best error: 0.03569463


 83%|████████▎ | 58017/70000 [08:55<01:43, 115.92it/s]

Epoch: 58000/70000, collocation grid index: 9, lr: 0.001 --> loss: 0.00025855, best error: 0.03569463


 84%|████████▍ | 59015/70000 [09:03<01:41, 108.19it/s]

Epoch: 59000/70000, collocation grid index: 9, lr: 0.001 --> loss: 0.00018469, best error: 0.03569463


 86%|████████▌ | 59990/70000 [09:11<01:15, 132.87it/s]

Epoch: 60000/70000, collocation grid index: 9, lr: 0.001 --> loss: 0.00016155, best error: 0.03569463


 87%|████████▋ | 61028/70000 [09:24<01:13, 122.57it/s]

Epoch: 61000/70000, collocation grid index: 0, lr: 0.001 --> loss: 0.00006481, best error: 0.03569463


 89%|████████▊ | 62024/70000 [09:32<00:57, 137.81it/s]

Epoch: 62000/70000, collocation grid index: 0, lr: 0.001 --> loss: 0.00005024, best error: 0.03569463


 90%|█████████ | 63013/70000 [09:41<00:52, 134.03it/s]

Epoch: 63000/70000, collocation grid index: 0, lr: 0.001 --> loss: 0.00004180, best error: 0.03569463


 91%|█████████▏| 64025/70000 [09:49<00:44, 134.86it/s]

Epoch: 64000/70000, collocation grid index: 0, lr: 0.001 --> loss: 0.00005303, best error: 0.03569463


 93%|█████████▎| 65022/70000 [10:00<03:32, 23.39it/s] 

Epoch: 65000/70000, collocation grid index: 0, lr: 0.001 --> loss: 0.00003598, best error: 0.03569463


 94%|█████████▍| 66023/70000 [10:08<00:35, 113.12it/s]

Epoch: 66000/70000, collocation grid index: 7, lr: 0.001 --> loss: 0.00011587, best error: 0.03569463


 96%|█████████▌| 67025/70000 [10:16<00:19, 149.87it/s]

Epoch: 67000/70000, collocation grid index: 7, lr: 0.001 --> loss: 0.00008862, best error: 0.03569463


 97%|█████████▋| 68012/70000 [10:23<00:21, 91.07it/s] 

Epoch: 68000/70000, collocation grid index: 7, lr: 0.001 --> loss: 0.00006695, best error: 0.03569463


 99%|█████████▊| 69026/70000 [10:31<00:06, 147.01it/s]

Epoch: 69000/70000, collocation grid index: 7, lr: 0.001 --> loss: 0.00004974, best error: 0.03569463


100%|██████████| 70000/70000 [10:40<00:00, 109.27it/s]

Epoch: 70000/70000, collocation grid index: 7, lr: 0.001 --> loss: 0.00004841, best error: 0.03569463
Runtime: 9.15 ms/iter.





In [41]:
# Stage 2: continue training with L-BFGS, warm-started from the last parameters
# stored by the AdamW run (params_list[-1]); results go to ac_results/lbfgs.
result_dir = os.path.join(root_dir, 'lbfgs')
os.makedirs(result_dir, exist_ok=True)

lss_list_2, time_list_2, params_list_2, best_params_2 = lbfgs(
    NC=14, NB=6, SAMP_FREQ=5000, SEED=444, LR=1e-2, EPOCHS=70000,
    N_LAYERS=3, NNEURONS=[64, 10], LOG_ITER=100, PRINT_ITER=10000,
    NUM_BATCHES_COLL=14, NUM_BATCHES_BOUND=40,
    NDIMS=21, BATCH_SIZE=1, LAMB=15,
    PARAM_INIT=params_list[-1], WDIR=result_dir,
)

Body Network: (64, 64, 10)

Total number of collocation samples: 38416
Total number of boundary samples: 8640


 14%|█▍        | 9989/70000 [01:57<08:29, 117.80it/s] 

Epoch: 10000/70000, collocation grid index: 2, lr: 0.01 --> loss: 0.00002102, best error: 0.03765579


 29%|██▊       | 19999/70000 [03:47<06:28, 128.78it/s] 

Epoch: 20000/70000, collocation grid index: 6, lr: 0.01 --> loss: 0.00001760, best error: 0.02957998


 43%|████▎     | 29997/70000 [05:20<05:02, 132.27it/s] 

Epoch: 30000/70000, collocation grid index: 3, lr: 0.01 --> loss: 0.00000868, best error: 0.02957998


 57%|█████▋    | 39995/70000 [06:56<03:53, 128.74it/s]

Epoch: 40000/70000, collocation grid index: 5, lr: 0.01 --> loss: 0.00000375, best error: 0.02957998


 71%|███████▏  | 49996/70000 [08:41<05:04, 65.77it/s] 

Epoch: 50000/70000, collocation grid index: 8, lr: 0.01 --> loss: 0.00000379, best error: 0.02957998


 86%|████████▌ | 59999/70000 [11:27<02:12, 75.66it/s]  

Epoch: 60000/70000, collocation grid index: 9, lr: 0.01 --> loss: 0.00001489, best error: 0.02957998


100%|██████████| 70000/70000 [14:08<00:00, 82.52it/s]

Epoch: 70000/70000, collocation grid index: 7, lr: 0.01 --> loss: 0.00000655, best error: 0.02957998
Runtime: 12.12 ms/iter.





In [63]:
# Reload the best L-BFGS parameters from disk and evaluate them on random test points.
f_filename = 'ac_results'
s_filename = 'lbfgs'
root_dir1 = os.path.join(os.getcwd(), f_filename, s_filename)
_, _, _, best_params = load_results(root_dir1)

# testing on randomly sampled test points
# NOTE(review): 7 presumably is the per-body-network input count (NDIMS / 3 = 21 / 3) and
# [64, 64, 10] the body-network sizes used in training -- confirm against test_model_ntimes.
# The test file is written to root_dir (ac_results), not root_dir1 (ac_results/lbfgs).
test = test_model_ntimes(444, root_dir, [best_params], 7, [64, 64, 10], ext='lbfgs')

% Relative L2 Error: 3.0749998092651367 %
