# ReLU-KAN Experiments

In [2]:
import jax
import jax.numpy as jnp

import optax
from flax import linen as nn

import sys
import os

import time
import scipy

from jaxkan.models.ReLUKAN import ReLUKAN
from jaxkan.utils.PIKAN import *

import numpy as np

import matplotlib.pyplot as plt

### Collocation Points

In [3]:
# Interior collocation points for the PDE residual: N Sobol samples over [-1, 1] x [-1, 1]
N = 2**12
collocs = jnp.array(sobol_sample(np.array([-1, -1]), np.array([1, 1]), N))  # (4096, 2)

# Boundary collocation points: N Sobol samples on each edge of the square
N = 2**6

BC1_colloc = jnp.array(sobol_sample(np.array([-1, -1]), np.array([-1, 1]), N))  # edge x = -1, (64, 2)
BC2_colloc = jnp.array(sobol_sample(np.array([-1, -1]), np.array([1, -1]), N))  # edge y = -1, (64, 2)
BC3_colloc = jnp.array(sobol_sample(np.array([1, -1]), np.array([1, 1]), N))    # edge x = +1, (64, 2)
BC4_colloc = jnp.array(sobol_sample(np.array([-1, 1]), np.array([1, 1]), N))    # edge y = +1, (64, 2)

bc_collocs = [BC1_colloc, BC2_colloc, BC3_colloc, BC4_colloc]

# Every boundary point gets a target value of 0, one column per point: (64, 1) each
BC1_data, BC2_data, BC3_data, BC4_data = (jnp.zeros((pts.shape[0], 1)) for pts in bc_collocs)
bc_data = [BC1_data, BC2_data, BC3_data, BC4_data]

2024-09-21 15:51:45.247628: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.6.68). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


## Loss Function

In [4]:
def pde_loss(params, collocs, state):
    """Return the pointwise residual of the 2-D Helmholtz equation.

    The manufactured problem is

        u_xx + u_yy + k^2 u = q(x, y),
        q(x, y) = (k^2 - (a1*pi)^2 - (a2*pi)^2) * sin(a1*pi*x) * sin(a2*pi*y),

    chosen so that u(x, y) = sin(a1*pi*x) * sin(a2*pi*y) solves it exactly.

    Args:
        params: The ``'params'`` collection of the module-level ``model``.
        collocs: Collocation points; column 0 is x and column 1 is y.
        state: The ``'state'`` collection of the module-level ``model``.

    Returns:
        The residual ``u_xx + u_yy + k^2 u - q`` evaluated at ``collocs``
        (presumably shape ``(collocs.shape[0], 1)`` -- confirm against
        ``gradf``'s output shape).
    """
    # Equation parameters; a1/a2 must match the analytical solution used for evaluation
    k = jnp.array(1.0, dtype=float)
    a1 = jnp.array(1.0, dtype=float)
    a2 = jnp.array(4.0, dtype=float)

    variables = {'params' : params, 'state' : state}

    def u(vec_x):
        # model.apply returns a pair; only the first element (the prediction) is needed.
        y, _ = model.apply(variables, vec_x)
        return y

    # Second-order derivatives w.r.t. x (input 0) and y (input 1)
    u_xx = gradf(u, 0, 2)
    u_yy = gradf(u, 1, 2)

    sines = jnp.sin(a1*jnp.pi*collocs[:, [0]])*jnp.sin(a2*jnp.pi*collocs[:, [1]])
    # The reaction term must be k**2 to be consistent with the residual below;
    # the previous `k*sines` only agreed because k == 1.
    source = -((a1*jnp.pi)**2)*sines - ((a2*jnp.pi)**2)*sines + (k**2)*sines

    pde_res = u_xx(collocs) + u_yy(collocs) + (k**2)*u(collocs) - source

    return pde_res

## Training

### Training Static Case

In [5]:
# Initialize model
layer_dims = [2, 8, 8, 1]
model = ReLUKAN(layer_dims=layer_dims, p=2, k=2, const_R=1.0, const_res=0.0, add_bias=True, grid_e=1.0)
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))

# Define learning rates for scheduler
lr_vals = dict()
lr_vals['init_lr'] = 0.001
lr_vals['scales'] = {0 : 1.0}

# Define epochs for grid adaptation
grid_adapt = []

# Define epochs for grid extension, along with grid sizes
grid_extend = {0 : 3}

# Define global loss weights
glob_w = [jnp.array(0.01, dtype=float), jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float),
          jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float)]

# Initialize RBA weights
loc_w = [jnp.ones((collocs.shape[0],1)), jnp.ones((BC1_colloc.shape[0],1)),
         jnp.ones((BC2_colloc.shape[0],1)), jnp.ones((BC3_colloc.shape[0],1)),
         jnp.ones((BC4_colloc.shape[0],1))]

In [6]:
num_epochs = 100_000

# Train the static model; returns the updated model, its variables and the per-epoch loss history
model, variables, train_losses = train_PIKAN(
    model, variables, pde_loss, collocs, bc_collocs, bc_data,
    glob_w=glob_w,
    lr_vals=lr_vals,
    adapt_state=True,
    loc_w=loc_w,
    nesterov=True,
    num_epochs=num_epochs,
    grid_extend=grid_extend,
    grid_adapt=grid_adapt,
    colloc_adapt={'epochs' : []},
)

Epoch 0: Performing grid update


In [None]:
# Analytical solution of the Helmholtz problem, used as the reference solution
def helm_exact(x, y, a1=1.0, a2=4.0):
    """Return u(x, y) = sin(a1*pi*x) * sin(a2*pi*y).

    Args:
        x, y: Coordinates; scalars or broadcastable NumPy arrays.
        a1, a2: Wave factors in x and y. The defaults match the values
            hard-coded in ``pde_loss``.

    Returns:
        The exact solution evaluated at (x, y), broadcast to a common shape.
    """
    return np.sin(a1*np.pi*x)*np.sin(a2*np.pi*y)

# Evaluation grid over [-1, 1] x [-1, 1]
N_x, N_y = 100, 256

x = np.linspace(-1.0, 1.0, N_x)
y = np.linspace(-1.0, 1.0, N_y)
X, Y = np.meshgrid(x, y, indexing='ij')
# 'ij' indexing + row-major flatten means reshaping the model output to
# (N_x, N_y) restores the grid layout of X and Y.
coords = np.stack([X.flatten(), Y.flatten()], axis=1)

ref = helm_exact(X, Y)

output, _ = model.apply(variables, jnp.array(coords))
static = np.array(output).reshape(N_x, N_y)

# Relative L2 error. Both operands are NumPy arrays (ref is float64), so use
# NumPy: jnp would cast to float32 by default and round-trip through the device.
l2err = np.linalg.norm(static-ref)/np.linalg.norm(ref)
print(f"L^2 Error = {l2err*100:.4f}%")

In [None]:
# Training-loss history of the static run on a logarithmic y-axis
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(np.array(train_losses), label='Train Loss', marker='o', color='blue', markersize=1)

ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Over Epochs')
ax.set_yscale('log')

ax.legend()
ax.grid(True, which='both', linestyle='--', linewidth=0.5)

plt.show()

In [None]:
# Pointwise absolute error |prediction - reference| of the static run
fig, ax = plt.subplots(figsize=(10, 5))
mesh = ax.pcolormesh(X, Y, np.abs(static-ref), shading='auto', cmap='Spectral_r')
fig.colorbar(mesh, ax=ax)

ax.set_title('Absolute Error for Helmholtz Equation')
ax.set_xlabel('x')
ax.set_ylabel('y')

fig.tight_layout()
plt.show()

In [None]:
epochs = np.arange(num_epochs)  # NOTE(review): not saved or used by this cell

# Save grid, prediction and reference; np.savez fails if the directory is missing.
os.makedirs('../Plots/data', exist_ok=True)
np.savez('../Plots/data/relu1.npz', x=x, y=y, res=static, ref=ref)

### Training Non-fully Adaptive Case

In [None]:
# Initialize model
layer_dims = [2, 8, 8, 1]
model = ReLUKAN(layer_dims=layer_dims, p=2, k=2, const_R=1.0, const_res=0.0, add_bias=True, grid_e=1.0)
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))

# Define learning rates for scheduler
lr_vals = dict()
lr_vals['init_lr'] = 0.001
lr_vals['scales'] = {0 : 1.0, 8_000 : 0.5, 15_000 : 0.5, 30_000 : 0.4, 50_000 : 0.7, 70_000 : 0.7}

# Define epochs for grid adaptation
grid_adapt = []

# Define epochs for grid extension, along with grid sizes
grid_extend = {0 : 3, 20_000 : 6, 35_000 : 12}

# Define global loss weights
glob_w = [jnp.array(0.01, dtype=float), jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float),
          jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float)]

# Initialize RBA weights
loc_w = [jnp.ones((collocs.shape[0],1)), jnp.ones((BC1_colloc.shape[0],1)),
         jnp.ones((BC2_colloc.shape[0],1)), jnp.ones((BC3_colloc.shape[0],1)),
         jnp.ones((BC4_colloc.shape[0],1))]


In [None]:
num_epochs = 100_000

# Train the non-fully adaptive model; keep its loss history separately as train_losses2
model, variables, train_losses2 = train_PIKAN(
    model, variables, pde_loss, collocs, bc_collocs, bc_data,
    glob_w=glob_w,
    lr_vals=lr_vals,
    adapt_state=True,
    loc_w=loc_w,
    nesterov=True,
    num_epochs=num_epochs,
    grid_extend=grid_extend,
    grid_adapt=grid_adapt,
    colloc_adapt={'epochs' : []},
)

In [None]:
# Analytical solution of the Helmholtz problem, used as the reference solution
def helm_exact(x, y, a1=1.0, a2=4.0):
    """Return u(x, y) = sin(a1*pi*x) * sin(a2*pi*y).

    Args:
        x, y: Coordinates; scalars or broadcastable NumPy arrays.
        a1, a2: Wave factors in x and y. The defaults match the values
            hard-coded in ``pde_loss``.

    Returns:
        The exact solution evaluated at (x, y), broadcast to a common shape.
    """
    return np.sin(a1*np.pi*x)*np.sin(a2*np.pi*y)

# Evaluation grid over [-1, 1] x [-1, 1]
N_x, N_y = 100, 256

x = np.linspace(-1.0, 1.0, N_x)
y = np.linspace(-1.0, 1.0, N_y)
X, Y = np.meshgrid(x, y, indexing='ij')
# 'ij' indexing + row-major flatten means reshaping the model output to
# (N_x, N_y) restores the grid layout of X and Y.
coords = np.stack([X.flatten(), Y.flatten()], axis=1)

ref = helm_exact(X, Y)

output, _ = model.apply(variables, jnp.array(coords))
nonfullya = np.array(output).reshape(N_x, N_y)

# Relative L2 error. Both operands are NumPy arrays (ref is float64), so use
# NumPy: jnp would cast to float32 by default and round-trip through the device.
l2err = np.linalg.norm(nonfullya-ref)/np.linalg.norm(ref)
print(f"L^2 Error = {l2err*100:.4f}%")

In [None]:
# Training-loss history of the non-fully adaptive run on a logarithmic y-axis
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(np.array(train_losses2), label='Train Loss', marker='o', color='blue', markersize=1)

ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Over Epochs')
ax.set_yscale('log')

ax.legend()
ax.grid(True, which='both', linestyle='--', linewidth=0.5)

plt.show()

In [None]:
# Pointwise absolute error |prediction - reference| of the non-fully adaptive run
fig, ax = plt.subplots(figsize=(10, 5))
mesh = ax.pcolormesh(X, Y, np.abs(nonfullya-ref), shading='auto', cmap='Spectral_r')
fig.colorbar(mesh, ax=ax)

ax.set_title('Absolute Error for Helmholtz Equation')
ax.set_xlabel('x')
ax.set_ylabel('y')

fig.tight_layout()
plt.show()

In [None]:
epochs = np.arange(num_epochs)                # NOTE(review): not saved or used by this cell
grids = np.array(list(grid_extend.keys()))    # NOTE(review): not saved or used by this cell

# Save grid, prediction and reference; np.savez fails if the directory is missing.
os.makedirs('../Plots/data', exist_ok=True)
np.savez('../Plots/data/relu2.npz', x=x, y=y, res=nonfullya, ref=ref)

### Training Fully Adaptive Case

In [None]:
# Initialize model
layer_dims = [2, 8, 8, 1]
model = ReLUKAN(layer_dims=layer_dims, p=2, k=2, const_R=False, const_res=False, add_bias=True, grid_e=0.05)
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))

# Define learning rates for scheduler
lr_vals = dict()
lr_vals['init_lr'] = 0.001
lr_vals['scales'] = {0 : 1.0, 20_000 : 0.6, 35_000 : 0.8, 50_000 : 0.7, 70_000 : 0.7}

# Define epochs for grid adaptation
adapt_every = 500
adapt_stop = 40000
grid_adapt = [i * adapt_every for i in range(1, (adapt_stop // adapt_every) + 1)]
grid_adapt = []

# Define epochs for grid extension, along with grid sizes
grid_extend = {0 : 3, 20_000 : 6, 35_000 : 12}

# Define global loss weights
glob_w = [jnp.array(0.01, dtype=float), jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float),
          jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float)]

# Initialize RBA weights
loc_w = [jnp.ones((collocs.shape[0],1)), jnp.ones((BC1_colloc.shape[0],1)),
         jnp.ones((BC2_colloc.shape[0],1)), jnp.ones((BC3_colloc.shape[0],1)),
         jnp.ones((BC4_colloc.shape[0],1))]

In [None]:
num_epochs = 100_000

# Train the fully adaptive model; keep its loss history separately as train_losses3
model, variables, train_losses3 = train_PIKAN(
    model, variables, pde_loss, collocs, bc_collocs, bc_data,
    glob_w=glob_w,
    lr_vals=lr_vals,
    adapt_state=True,
    loc_w=loc_w,
    nesterov=True,
    num_epochs=num_epochs,
    grid_extend=grid_extend,
    grid_adapt=grid_adapt,
    colloc_adapt={'epochs' : []},
)

In [None]:
# Analytical solution of the Helmholtz problem, used as the reference solution
def helm_exact(x, y, a1=1.0, a2=4.0):
    """Return u(x, y) = sin(a1*pi*x) * sin(a2*pi*y).

    Args:
        x, y: Coordinates; scalars or broadcastable NumPy arrays.
        a1, a2: Wave factors in x and y. The defaults match the values
            hard-coded in ``pde_loss``.

    Returns:
        The exact solution evaluated at (x, y), broadcast to a common shape.
    """
    return np.sin(a1*np.pi*x)*np.sin(a2*np.pi*y)

# Evaluation grid over [-1, 1] x [-1, 1]
N_x, N_y = 100, 256

x = np.linspace(-1.0, 1.0, N_x)
y = np.linspace(-1.0, 1.0, N_y)
X, Y = np.meshgrid(x, y, indexing='ij')
# 'ij' indexing + row-major flatten means reshaping the model output to
# (N_x, N_y) restores the grid layout of X and Y.
coords = np.stack([X.flatten(), Y.flatten()], axis=1)

ref = helm_exact(X, Y)

output, _ = model.apply(variables, jnp.array(coords))
fullya = np.array(output).reshape(N_x, N_y)

# Relative L2 error. Both operands are NumPy arrays (ref is float64), so use
# NumPy: jnp would cast to float32 by default and round-trip through the device.
l2err = np.linalg.norm(fullya-ref)/np.linalg.norm(ref)
print(f"L^2 Error = {l2err*100:.4f}%")

In [None]:
# Training-loss history of the fully adaptive run on a logarithmic y-axis
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(np.array(train_losses3), label='Train Loss', marker='o', color='blue', markersize=1)

ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Over Epochs')
ax.set_yscale('log')

ax.legend()
ax.grid(True, which='both', linestyle='--', linewidth=0.5)

plt.show()

In [None]:
# Pointwise absolute error |prediction - reference| of the fully adaptive run
fig, ax = plt.subplots(figsize=(10, 5))
mesh = ax.pcolormesh(X, Y, np.abs(fullya-ref), shading='auto', cmap='Spectral_r')
fig.colorbar(mesh, ax=ax)

ax.set_title('Absolute Error for Helmholtz Equation')
ax.set_xlabel('x')
ax.set_ylabel('y')

fig.tight_layout()
plt.show()

In [None]:
epochs = np.arange(num_epochs)                # NOTE(review): not saved or used by this cell
grids = np.array(list(grid_extend.keys()))    # NOTE(review): not saved or used by this cell

# Save grid, prediction and reference; np.savez fails if the directory is missing.
os.makedirs('../Plots/data', exist_ok=True)
np.savez('../Plots/data/relu3.npz', x=x, y=y, res=fullya, ref=ref)