In [15]:
# model_utils.py

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from typing import Callable
from dataclasses import field
from typing import List
from scipy.stats.qmc import Sobol
import numpy as np
import matplotlib.pyplot as plt
import time
from jaxkan.models.KAN import KAN

class FourierFeats(nn.Module):
    """Fourier feature embedding whose projection and phase are Flax params.

    The input is multiplied by a matrix ``B`` of shape
    ``(x.shape[-1], num_output // 2)``, a phase vector ``bias`` is added,
    and the cosine and sine of the result are concatenated on the last axis.

    ``B`` is initialised as ``mean + normal * std``; ``bias`` is initialised
    to zeros.

    Attributes:
        num_output: Requested embedding width. The returned array carries
            ``2 * (num_output // 2)`` features on its last axis.
    """
    num_output: int

    # No type annotation here, so these are ordinary class attributes rather
    # than dataclass fields of the module.
    std = 5
    mean = 0

    @nn.compact
    def __call__(self, x):
        half_width = self.num_output // 2

        def _init_projection(rng, shape):
            # Normal samples scaled by `std` and shifted by `mean`.
            return self.mean + jax.random.normal(rng, shape) * self.std

        def _init_phase(rng, shape):
            # Sampled and then multiplied by zero: every entry starts at zero.
            return jax.random.normal(rng, shape) * 0

        # Parameter creation order (B, then bias) is kept as in the original.
        B = self.param('B', _init_projection, (x.shape[-1], half_width))
        bias = self.param('bias', _init_phase, (half_width,))

        phase = jnp.matmul(x, B) + bias
        return jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)

class PirateBlock(nn.Module):
    """Gated residual block with an adaptive skip connection.

    Three Dense + tanh layers are applied in sequence. After each of the
    first two, the activation ``a`` gates between the two embeddings passed
    in by the caller: ``z = a * U + (1 - a) * V``. The output blends the last
    activation ``h`` with the block input:
    ``x_next = alpha * h + (1 - alpha) * x``.

    ``alpha`` is a scalar parameter initialised to zero, so at
    initialisation the block returns its input unchanged.

    Attributes:
        kernel_init: Initializer passed to every Dense layer in the block.
        num_hidden: Output width of each Dense layer. ``x``, ``U`` and ``V``
            must be broadcast-compatible with this width on the last axis,
            since they are combined elementwise with the Dense outputs.
    """
    kernel_init: Callable
    num_hidden: int

    @nn.compact
    def __call__(self, x, U, V):
        """Apply the block.

        Args:
            x: Block input; also the skip-connection branch.
            U: First gating embedding.
            V: Second gating embedding.

        Returns:
            ``alpha * h + (1 - alpha) * x`` where ``h`` is the output of the
            third Dense + tanh layer.
        """
        f = nn.Dense(self.num_hidden, kernel_init=self.kernel_init)(x)
        f = nn.tanh(f)

        z_1 = f * U + (1 - f) * V
        g = nn.Dense(self.num_hidden, kernel_init=self.kernel_init)(z_1)
        g = nn.tanh(g)

        z_2 = g * U + (1 - g) * V
        h = nn.Dense(self.num_hidden, kernel_init=self.kernel_init)(z_2)
        h = nn.tanh(h)

        # Floating-point scalar (not the Python int 0) so that the parameter
        # can be differentiated like every other leaf of the param tree.
        alpha = self.param('alpha', lambda rng: jnp.zeros(()))

        # Convex blend of the nonlinear branch and the skip branch. The
        # previous form `alpha*h + alpha*x` produced all zeros at alpha == 0.
        x_next = alpha * h + (1 - alpha) * x

        return x_next
    
        
class PirateNet(nn.Module):
    """Fourier embedding, two gating embeddings, PirateBlocks, linear head.

    The first entry of ``layer_sizes`` sets the width of the ``FourierFeats``
    embedding and of the two Dense + tanh gating embeddings ``U`` and ``V``
    computed from it. Every later entry adds one ``PirateBlock`` of that
    width, all sharing the same ``U`` and ``V``. A final Dense layer maps to
    ``num_output`` features.

    Attributes:
        kernel_init: Initializer passed to every Dense layer and PirateBlock.
        num_input: Not read by ``__call__``; the input width is taken from
            the array passed in.
        num_output: Width of the final Dense layer.
        layer_sizes: Widths as described above. With an empty list only the
            final Dense layer is applied. NOTE(review): the elementwise
            gating and residual sum in ``PirateBlock`` appear to require all
            entries to be the same even number — confirm before using mixed
            widths.
    """
    kernel_init: Callable
    num_input: int
    num_output: int
    layer_sizes: List[int] = field(default_factory=list)

    @nn.compact
    def __call__(self, x):
        """Run the network on ``x`` and return the final Dense output."""
        # Add hidden layers
        for idx, size in enumerate(self.layer_sizes):
            if idx == 0:
                x = FourierFeats(size)(x)

                # Gating embeddings: tanh is applied to each Dense OUTPUT.
                # Previously tanh was applied to `x`, which threw away both
                # projections and made U and V identical.
                U = nn.Dense(size, kernel_init=self.kernel_init)(x)
                U = nn.tanh(U)

                V = nn.Dense(size, kernel_init=self.kernel_init)(x)
                V = nn.tanh(V)
            else:
                x = PirateBlock(self.kernel_init, size)(x, U, V)

        # Final output layer
        x = nn.Dense(self.num_output, kernel_init=self.kernel_init)(x)
        return x

In [22]:
# Build a small PirateNet and display its freshly initialised parameters.
from model_utils import KeyHandler

# NOTE(review): KeyHandler is defined outside this view; presumably it is
# seeded with 0 and hands out JAX PRNG keys via .key() — confirm in model_utils.
r = KeyHandler(0)

model = PirateNet(
    kernel_init=nn.initializers.glorot_normal(),
    num_input=2,
    num_output=1,
    layer_sizes=[64, 64] # first entry sizes the Fourier embedding and U/V; each later entry adds one PirateBlock
)

# All-ones batch of 64 two-feature rows, passed to init as the example input.
collocs = jnp.ones((64, 2))
variables = model.init(r.key(), collocs)
# Last expression of the cell, so the parameter tree is shown as cell output.
variables["params"]

{'FourierFeats_0': {'B': Array([[ -1.2065982 ,   1.8077171 ,  -5.211637  ,   3.3368049 ,
           -8.549459  ,   0.72885275,   5.904413  ,   2.890458  ,
           -4.313114  ,   7.554807  ,   7.484484  ,  -1.7095883 ,
           11.507137  ,  -2.5598488 ,  -2.8855038 ,  -7.804573  ,
            3.4581532 ,  -1.1767659 ,  -7.5463705 ,  -3.5254617 ,
           -3.4090886 ,  -0.03795527,  -3.9488685 ,  -6.2443924 ,
            0.8240086 ,  16.686298  ,   1.0624833 ,   0.6462384 ,
            3.198957  ,  -4.807271  ,   0.64688027,  -4.886603  ],
         [ -7.153842  ,  -5.066449  ,  -4.5582786 ,  -4.2030673 ,
           -0.6217977 ,  -0.8811017 ,  -4.7639885 ,   5.049013  ,
            1.7905627 ,  -1.5781153 ,  -8.4924345 ,   0.41131723,
            0.41815665,   1.3594339 ,  -3.0519276 ,  -0.84154475,
           -2.4935863 ,  -2.2750785 ,   3.0883896 ,   0.40678853,
            1.0786219 ,  11.586031  ,  -1.0293334 ,  -8.618601  ,
           -3.3506832 ,  -4.5110946 ,   5.2130175 , 