In [1]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import numpy as np
from pikan.model_utils import GeneralizedMLP, sobol_sample
from jax import grad, vmap, jit
from jax.scipy.special import gamma
from functools import partial
import optax

import os
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt

jax.devices()

[CpuDevice(id=0)]

In [2]:
model = GeneralizedMLP(
    kernel_init=nn.initializers.glorot_normal(),
    num_input=2,
    num_output=1,
    use_fourier_feats=True,
    layer_sizes=[128, 128],
)

key = jax.random.PRNGKey(0)
collocs = jnp.ones((2))
params = model.init(key, collocs)['params']
model.apply({"params": params}, collocs)

Array([0.8603496], dtype=float32)

In [3]:
# Define the inference function
def inference(params, model, x, t):
    x = jnp.stack([x,t])
    return model.apply({'params': params}, x)

inference(params, model, 0, 1)

Array([-0.10571431], dtype=float32)

In [4]:
def inference(params, x, t):
    output = model.apply({'params': params}, jnp.array([x, t]))
    return output[0]

def get_caputo_derivative(inference):
    @jax.jit
    def caputo_derivative(params, x, t, alpha, dt=1e-3, num_steps = 10000):
        """
        Compute the Caputo derivative of order alpha for a function f(x, t) with respect to t.
    
        Parameters:
        - f: A function f(x, t) that takes two arguments, x and t.
        - x: The spatial variable.
        - t: The time variable.
        - alpha: The order of the Caputo derivative (0 < alpha < 1).
        - dt: The time step for discretization.
    
        Returns:
        - The Caputo derivative of f(x, t) at time t.
        """
        # Define the integrand
        def integrand(tau):
            return grad(inference, 2)(params, x, tau) / (t - tau)**alpha
    
        # Fixed number of steps for static shape
        tau_values = jnp.linspace(0, t - dt, num_steps)  # Exclude t
        integrand_values = vmap(integrand)(tau_values)
    
        # Compute the integral using the trapezoidal rule
        integral = jnp.trapezoid(integrand_values, tau_values)
    
        # Normalize by the gamma function
        return integral / gamma(1 - alpha)
    
    return caputo_derivative
    
x = 1.0
t = 1.0
alpha = 0.5

caputo_derivative_fn = get_caputo_derivative(inference)
caputo_deriv = caputo_derivative_fn(params, x, t, alpha)

print("Caputo Derivative:", caputo_deriv)

Caputo Derivative: 1.4690284


In [5]:
BS = 64
collocs = sobol_sample(np.array([-1, 0]), np.array([1, 1]), BS)

jax.vmap(caputo_derivative_fn, (None, 0,0,None))(params, collocs[:, 0], collocs[:, 1], alpha), collocs.shape

(Array([ 0.352297  ,  0.07153108,  0.9043222 , -1.1922547 ,  1.1729538 ,
         0.7631223 ,  2.1819003 ,  0.29982382,  1.008047  ,  2.9932196 ,
        -1.1980772 , -1.2293242 , -0.48812997,  0.40089384,  1.9867368 ,
        -0.26317292, -0.30621415,  1.7191843 ,  0.6605068 , -1.1688274 ,
        -1.2866187 , -0.56121   ,  2.1097345 , -2.1567328 ,  1.998726  ,
         2.600072  ,  1.7475344 ,  0.35622153,  0.20362371, -1.2614639 ,
        -0.8023467 ,  1.7286482 ,  0.80583006,  1.7793441 ,  0.0047167 ,
        -1.3672754 , -0.34065434,  0.40130576,  2.1432748 , -2.4950643 ,
         1.7482466 ,  1.1350166 ,  1.0107727 ,  1.9275477 , -1.0320338 ,
        -1.9398096 ,  0.8786229 ,  1.0118887 ,  1.9089093 ,  3.3738277 ,
        -2.0842304 , -0.6886029 ,  1.3699871 , -0.94748884, -0.77802664,
         1.131459  , -0.40089476,  2.266505  , -0.867456  , -1.2834089 ,
         0.36519396,  0.6913541 ,  1.5833902 ,  0.01851468], dtype=float32),
 (64, 2))

In [6]:
# caputo diffusion on 1d
class fractional_diffusion():
    def __init__(self, model, bc_l, bc_r, ic_func, alpha, dom=[-1,1]):
        self.bc_l = bc_l # boundary vals
        self.bc_r = bc_r

        self.ic_collocs = jnp.linspace(dom[0], dom[1], 1000)
        self.ic_vals = ic_func(self.ic_collocs) # t_0 values
        
        self.alpha = alpha
        self.dom = dom

        self.caputo_derivative = get_caputo_derivative(self.neural_net)

        # paralellize for faster computation
        self.neural_net_fn = jax.vmap(self.neural_net, (None, 0, 0))
        self.residual_loss_fn = jax.vmap(self.residual_loss, (None, 0, 0))

    def neural_net(self, params, x, t):
        output = model.apply({'params': params}, jnp.array([x, t]))
        return output[0]

    def residual_loss(self, params, x, t):
        f_derivative = self.caputo_derivative(params, x, t, self.alpha)
        laplacian = grad(grad(self.neural_net, argnums=1), argnums=1)(params, x, t)
        
        return f_derivative - laplacian

    def mse(self, arr):
        return jnp.sum(arr**2)
    
    @partial(jit, static_argnums=(0,))
    def loss(self, params, collocs):
        ic_loss = self.neural_net_fn(params, self.ic_collocs, jnp.zeros_like(self.ic_collocs))
        bc_l_loss = self.neural_net_fn(params, jnp.full_like(collocs[:,1], self.dom[0]), collocs[:,1])
        bc_r_loss = self.neural_net_fn(params, jnp.full_like(collocs[:,1], self.dom[1]), collocs[:,1])

        eq_loss = self.residual_loss_fn(params, collocs[:,0], collocs[:,1])

        # losses could be added with custom weights, TODO gradnorm
        loss = self.mse(ic_loss) + self.mse(bc_l_loss) + self.mse(bc_r_loss) + self.mse(eq_loss)

        new_loc_w = 0 # maybe implement local weights
        return loss, new_loc_w 

fdiff = fractional_diffusion(model, 0, 0, ic_func=lambda x: jnp.cos(x*jnp.pi/2), alpha=.5)
fdiff.neural_net(params, 0., 1.), fdiff.residual_loss(params, 0., 1.)

BS = 64
collocs = sobol_sample(np.array([-1, 0]), np.array([1, 1]), BS)
fdiff.loss(params, collocs)

(Array(70966.03, dtype=float32), Array(0, dtype=int32, weak_type=True))

In [7]:
# Define your gradient function
grad_fn = jax.value_and_grad(fdiff.loss, has_aux=True)

# Define the training loop
def train_step(params, collocs, opt_state):
    # Compute loss and gradients
    (loss, new_loc_w), grads = grad_fn(params, collocs)

    # Apply gradients to update the parameters
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state, loss

In [8]:
EPOCHS = 1000

# Define a cosine decay learning rate schedule
# Learning rate schedule (cosine decay)
schedule_fn = optax.cosine_decay_schedule(
    init_value=1e-3,      # Initial learning rate
    decay_steps=2000,     # Total number of decay steps
    alpha=0.1             # Final learning rate multiplier
)

# Optimizer setup with Adam
optimizer = optax.adamw(
    learning_rate=schedule_fn,
    b1=0.9,               # Beta1
    b2=0.999,             # Beta2
    eps=1e-8              # Epsilon
)

opt_state = optimizer.init(params)

In [9]:
# Function to save parameters and state
def save_checkpoint(params, opt_state, epoch, filename):
    with open(filename, "wb") as f:
        pickle.dump({'params': params, 'opt_state': opt_state, 'epoch': epoch}, f)
    print(f"Checkpoint saved at epoch {epoch}")

def load_checkpoint(filename, params, state):
    if os.path.exists(filename):
        with open(filename, "rb") as f:
            checkpoint = pickle.load(f)
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}")
        return checkpoint['params'], checkpoint['opt_state'], checkpoint['epoch']
    return params, state, 0 

In [None]:
# Define constants
BS = 64
EPOCHS = 1000
TMAX = 1
CHECKPOINT_FILE = "diff_v1.pkl"

# Initialize or load checkpoint
params, opt_state, start_epoch = load_checkpoint(CHECKPOINT_FILE, params, opt_state)

# Main training loop
for i in (pbar := tqdm(range(start_epoch, EPOCHS))):
    collocs = sobol_sample(np.array([-1, 0]), np.array([1, TMAX]), BS)
    params, opt_state, loss = train_step(params, collocs, opt_state)
    
    if i % 10 == 0:
        pbar.set_description(f"Loss {loss: .8f}")
    
    if i % 50 == 0:  # Save every x epochs
        save_checkpoint(params, opt_state, i, CHECKPOINT_FILE)

save_checkpoint(params, opt_state, i, CHECKPOINT_FILE)

Checkpoint loaded from epoch 0


Loss  30014.42382812:   0%|                  | 1/1000 [00:10<2:59:52, 10.80s/it]

Checkpoint saved at epoch 0


Loss  30014.42382812:   1%|▏                 | 9/1000 [00:58<1:40:07,  6.06s/it]