In [1]:
import sys
from typing import Any

import jax
import jax.numpy as jnp
import jax.random as jrd
import numpy as np
import optax
from flax import linen as nn
from jax.random import PRNGKey

sys.path.insert(1, '..')
from kronojax.neural.KAN import KAN, MLP
from kronojax.neural.embedding import time_embedding

In [2]:
# Root PRNG key for the whole notebook; change SEED to get a different run.
SEED = 0
key = jrd.PRNGKey(SEED)

#### Parametrize the drift and flow functions with neural networks

In [3]:
class NN_drift(nn.Module):
    """Drift network: a KAN applied to the concatenation [x, time_embedding(t)].

    Attributes:
        time_embedding_dim: size requested from `time_embedding`.
        time_freq_min: lower frequency bound forwarded to `time_embedding`.
        time_freq_max: upper frequency bound forwarded to `time_embedding`.
        dim_list: layer widths forwarded to `KAN` (in this notebook the last
            entry is the output dimension).
        degree: polynomial degree forwarded to `KAN`. Defaults to 2, the value
            that used to be hard-coded, so existing constructions are unchanged.
    """

    time_embedding_dim: int
    time_freq_min: float
    time_freq_max: float
    dim_list: list
    degree: int = 2

    @nn.compact
    def __call__(self, x, t):
        """Evaluate the drift at states `x` and times `t`.

        The embedded time is concatenated to `x` along axis 1, so `x` and
        `time_embedding(t)` must be at least 2-D with a shared leading batch
        axis (this notebook uses x, t of shape (batch, 1)).
        """
        t_embedded = time_embedding(t, self.time_freq_min, self.time_freq_max, self.time_embedding_dim)
        input_forward = jnp.concatenate([x, t_embedded], axis=1)
        # Submodule is auto-named 'KAN_0' in the parameter tree.
        kan = KAN(dim_list=self.dim_list, degree=self.degree)

        return kan(input_forward)

class NN_flow(nn.Module):
    """Flow network F(x, t) = exp(MLP([x, time_embedding(t)])).

    The MLP output is treated as a log-flow and exponentiated, so the returned
    flow is non-negative.

    Attributes:
        time_embedding_dim: size requested from `time_embedding`.
        time_freq_min: lower frequency bound forwarded to `time_embedding`.
        time_freq_max: upper frequency bound forwarded to `time_embedding`.
        dim_list: layer widths forwarded to `MLP`.
    """

    time_embedding_dim: int
    time_freq_min: float
    time_freq_max: float
    dim_list: list

    @nn.compact
    def __call__(self, x, t):
        """Evaluate the flow at states `x` and times `t` (concatenated on axis 1)."""
        t_emb = time_embedding(t, self.time_freq_min, self.time_freq_max, self.time_embedding_dim)
        features = jnp.concatenate([x, t_emb], axis=1)
        # Local name no longer shadows the enclosing class; the submodule is
        # still auto-named 'MLP_0' in the parameter tree.
        mlp = MLP(dim_list=self.dim_list)
        log_flow = mlp(features)
        return jnp.exp(log_flow)
    
class NN_FlowAndDirft(nn.Module):
    """Bundle an `NN_drift` and an `NN_flow` that share one time-embedding setup.

    Calling the module returns the tuple `(drift(x, t), flow(x, t))`.

    Attributes:
        time_embedding_dim, time_freq_min, time_freq_max: shared time-embedding
            settings passed to both sub-networks.
        dim_list_drift: layer widths for the drift sub-network.
        dim_list_flow: layer widths for the flow sub-network.
    """

    time_embedding_dim: int
    time_freq_min: float
    time_freq_max: float
    dim_list_drift: list
    dim_list_flow: list

    @nn.compact
    def __call__(self, x, t):
        shared = dict(
            time_embedding_dim=self.time_embedding_dim,
            time_freq_min=self.time_freq_min,
            time_freq_max=self.time_freq_max,
        )
        # Construction order fixes the auto names 'NN_drift_0' then 'NN_flow_0'.
        drift_net = NN_drift(**shared, dim_list=self.dim_list_drift)
        flow_net = NN_flow(**shared, dim_list=self.dim_list_flow)

        drift_value = drift_net(x, t)
        flow_value = flow_net(x, t)
        return drift_value, flow_value

In [4]:
# Hyper-parameters of the time embedding and of the two networks.
time_embedding_dim = 16
time_freq_min = 1.
time_freq_max = 10.
output_dim = 1
dim_list_drift = [32, 32, output_dim]
dim_list_flow = [32, 32, output_dim]
net_drift = NN_drift(time_embedding_dim, time_freq_min, time_freq_max, dim_list_drift)
net_flow = NN_flow(time_embedding_dim, time_freq_min, time_freq_max, dim_list_flow)

# Initialize both networks on a dummy batch. Each random draw gets its own key:
# reusing a single key made `xs` and `ts` correlated and seeded both
# initializations identically.
batch_sz = 32
key, key_x, key_t, key_drift, key_flow = jrd.split(key, 5)
xs = jrd.normal(key_x, (batch_sz, 1))
ts = jrd.uniform(key_t, (batch_sz, 1))
params_drift = net_drift.init(key_drift, xs, ts)
params_flow = net_flow.init(key_flow, xs, ts)

In [5]:
# print(params)

In [6]:
# Parameter shapes of the two standalone networks. They are initialized
# directly (not through NN_FlowAndDirft), so their trees have no
# 'NN_drift_0' / 'NN_flow_0' prefix. The previous version read an undefined
# `params` and failed on a fresh kernel.
drift_leaves = [
    ('ChebyKANLayer_0', 'cheby_coeffs'),
    ('LayerNorm_0', 'scale'),
    ('LayerNorm_0', 'bias'),
    ('ChebyKANLayer_1', 'cheby_coeffs'),
    ('LayerNorm_1', 'scale'),
    ('LayerNorm_1', 'bias'),
    ('ChebyKANLayer_2', 'cheby_coeffs'),
]
for layer_name, leaf_name in drift_leaves:
    print(params_drift['params']['KAN_0'][layer_name][leaf_name].shape)

for layer_name in ('Dense_0', 'Dense_1', 'Dense_2'):
    for leaf_name in ('kernel', 'bias'):
        print(params_flow['params']['MLP_0'][layer_name][leaf_name].shape)

(17, 32, 3)
(32,)
(32,)
(32, 32, 3)
(32,)
(32,)
(32, 1, 3)
(17, 32)
(32,)
(32, 32)
(32,)
(32, 1)
(1,)


#### Define the loss function

In [5]:
def normal_density(x, mu, sigma2):
    """Gaussian probability density N(x; mu, sigma2), evaluated elementwise.

    Args:
        x: point(s) at which to evaluate the density.
        mu: mean(s); broadcast against `x`.
        sigma2: variance (a scalar in this notebook; broadcasts like `mu`).

    Returns:
        exp(-(x - mu)^2 / (2 sigma2)) / sqrt(2 pi sigma2), with the broadcast
        shape of the inputs.
    """
    sq_dist = (x - mu) ** 2
    normalizer = jnp.sqrt(2 * jnp.pi * sigma2)
    exponent = -0.5 * sq_dist / sigma2
    return jnp.exp(exponent) / normalizer

In [21]:
def loss(
        params_drift: Any,
        params_flow: Any,
        batch_sz: int,
        N_step: int,
        key: PRNGKey
        ):
    """Roll out Euler-Maruyama trajectories of dX = u(X, t) dt + dW on [0, T], T = 1.

    Intended as the basis of a trajectory-balance loss. It does not yet reduce
    to a scalar loss; it returns the per-step quantities such a loss is built
    from. Uses the notebook-level `net_drift`, `net_flow` and `normal_density`.

    Args:
        params_drift: parameters for `net_drift`.
        params_flow: parameters for `net_flow`.
        batch_sz: number of trajectories simulated in parallel (all start at 0).
        N_step: number of time steps; dt = T / N_step.
        key: PRNG key, split into one independent key per step.

    Returns:
        trajectory: dict of per-step arrays stacked along a leading N_step axis:
            "t"          time *after* the step, t_k + dt (shape (N_step,));
            "x"          state *before* the step, x_k (shape (N_step, batch_sz, 1));
            "P_forward"  normal_density(x_{k+1}, x_k + u dt, dt);
            "P_backward" normal_density(x_k, x_{k+1}, dt);
            "state flow" net_flow evaluated at (x_k, t_k).
        FN: net_flow evaluated at the terminal state x_T and time T.
    """

    T = 1.
    dt = T / N_step
    sqt = jnp.sqrt(dt)

    def _step(carry, step_key):
        xo, t = carry
        to = jnp.ones_like(xo) * t
        u = net_drift.apply(params_drift, xo, to)
        f = net_flow.apply(params_flow, xo, to)
        # Fresh noise each step: the previous version reused one closed-over
        # key, so every step added the same Brownian increment.
        dw = sqt * jrd.normal(step_key, xo.shape)
        xn = xo + u * dt + dw
        t = t + dt
        PF = normal_density(xn, xo + u * dt, dt)  # Forward probability
        # NOTE(review): Gaussian centred at x_{k+1} with no drift; confirm this
        # is the intended backward kernel before building the TB loss.
        PB = normal_density(xo, xn, dt)  # Backward probability
        output_dict = {
            "t": t,
            "x": xo,
            "P_forward": PF,
            "P_backward": PB,
            "state flow": f
            }
        return (xn, t), output_dict

    step_keys = jrd.split(key, N_step)
    t_init = 0.
    x_init = jnp.zeros((batch_sz, 1))
    carry_init = (x_init, t_init)
    (xT, _), trajectory = jax.lax.scan(_step, carry_init, xs=step_keys, length=N_step)
    # Terminal state comes from the final carry; trajectory["x"][-1] is x_{N-1},
    # one step short of time T.
    tT = jnp.ones_like(xT) * T
    FN = net_flow.apply(params_flow, xT, tT)
    return trajectory, FN

In [24]:
# Roll out a batch (same size as the init batch) and show the terminal-flow shape.
n_steps = 10
trajectory, FN = loss(params_drift, params_flow, batch_sz, n_steps, key)
FN.shape

(32, 1)