In [1]:
import jax
import jax.numpy as jnp
from pikan.model_utils import GeneralizedMLP
from flax import linen as nn

# Enumerate the devices JAX can see; `devices` is not read by the other cells shown here.
devices = jax.devices(backend=None)

In [2]:
# Network mapping a 2-vector (x, t) to a single output value.
model = GeneralizedMLP(
    num_input=2,
    num_output=1,
    layer_sizes=[128, 128],
    use_fourier_feats=True,
    kernel_init=nn.initializers.glorot_normal(),
)

# Initialise parameters from a fixed seed with a dummy input of shape (2,),
# then run one forward pass on that same input as a smoke test.
key = jax.random.PRNGKey(0)
collocs = jnp.ones(2)
params = model.init(key, collocs)["params"]
model.apply({"params": params}, collocs)

Array([0.8603496], dtype=float32)

In [3]:
# Inference helper that takes the model explicitly.
def inference(params, model, x, t):
    """Evaluate ``model`` at the single point (x, t).

    Parameters:
    - params: Parameter tree passed to ``model.apply`` under the "params" key.
    - model: Object exposing ``apply(variables, inputs)``.
    - x: The spatial coordinate.
    - t: The time coordinate.

    Returns:
    - Whatever ``model.apply`` returns for the stacked input ``[x, t]``.
    """
    point = jnp.stack([x, t])
    variables = {"params": params}
    return model.apply(variables, point)

inference(params, model, x=0, t=1)

Array([-0.10571431], dtype=float32)

In [34]:
import jax
import jax.numpy as jnp
from jax import grad, vmap
from jax.scipy.special import gamma

def inference(params, x, t):
    """Return the scalar network output at the point (x, t).

    Reads the module-level ``model`` (a global, not an argument). This
    3-argument form replaces the earlier ``inference(params, model, x, t)``.
    Indexing ``[0]`` turns the length-1 output into a scalar, which
    ``jax.grad`` requires.
    """
    point = jnp.array([x, t])
    return model.apply({"params": params}, point)[0]

def get_caputo_derivative(inference):
    @jax.jit
    def caputo_derivative(params, x, t, alpha, dt=1e-3, num_steps = 10000):
        """
        Compute the Caputo derivative of order alpha for a function f(x, t) with respect to t.
    
        Parameters:
        - f: A function f(x, t) that takes two arguments, x and t.
        - x: The spatial variable.
        - t: The time variable.
        - alpha: The order of the Caputo derivative (0 < alpha < 1).
        - dt: The time step for discretization.
    
        Returns:
        - The Caputo derivative of f(x, t) at time t.
        """
        # Define the integrand
        def integrand(tau):
            return grad(inference, 2)(params, x, tau) / (t - tau)**alpha
    
        # Fixed number of steps for static shape
        tau_values = jnp.linspace(0, t - dt, num_steps)  # Exclude t
        integrand_values = vmap(integrand)(tau_values)
    
        # Compute the integral using the trapezoidal rule
        integral = jnp.trapezoid(integrand_values, tau_values)
    
        # Normalize by the gamma function
        return integral / gamma(1 - alpha)
    
    return caputo_derivative
    
# Evaluate the order-alpha Caputo time derivative of the network at one point.
alpha = 0.5
x = t = 1.0

caputo_derivative_fn = get_caputo_derivative(inference)
caputo_deriv = caputo_derivative_fn(params, x, t, alpha)

print("Caputo Derivative:", caputo_deriv)

Caputo Derivative: 1.4690284


In [44]:
from pikan.model_utils import sobol_sample
import numpy as np

# Number of Sobol collocation points. Defined here so the cell runs on a fresh kernel;
# 64 reproduces the recorded output shape (64, 2).
BS = 64

collocs = sobol_sample(np.array([0, 0]), np.array([1, 1]), BS)

# Caputo derivative at every collocation point: map over the x (column 0) and
# t (column 1) coordinates while sharing params and alpha.
batched_caputo = jax.vmap(caputo_derivative_fn, in_axes=(None, 0, 0, None))
batched_caputo(params, collocs[:, 0], collocs[:, 1], alpha), collocs.shape

(Array([-1.9368526 ,  1.5279231 ,  1.6757048 , -0.86210805,  1.9856383 ,
         1.916959  , -0.23063819, -0.11893816, -3.0236068 ,  1.0715395 ,
        -2.8631775 ,  2.1318824 , -0.08895866,  0.6813856 ,  1.2043933 ,
         1.4326512 , -0.7403909 ,  2.353486  , -1.3957287 ,  0.734305  ,
         1.6372179 , -2.115138  ,  1.8639367 ,  0.07038371, -0.36098865,
         2.0357447 , -0.4261263 , -1.1245093 , -1.5903128 ,  0.70380837,
         0.4670559 , -0.57454264, -1.7833678 ,  0.37029892, -2.7208304 ,
         1.5304441 ,  2.3418148 ,  1.1240007 ,  0.91855377, -0.49901232,
        -1.0983696 , -1.2291723 , -0.60905015, -1.3679131 , -1.5742713 ,
         2.0972848 ,  0.584064  ,  0.84370744, -1.340998  ,  1.5299377 ,
         1.2763233 , -0.38129574, -1.714182  ,  1.1769115 , -1.8024554 ,
        -1.5488288 , -0.98314834,  1.4269894 , -2.7450976 ,  1.4711574 ,
         1.4616169 , -0.38075423,  1.2749709 ,  0.12267178], dtype=float32),
 (64, 2))