In [1]:
import jax
import optax
import time

import jax.numpy as jnp
import jax.random as jrd
import numpy as np

from jax.random import PRNGKey
from flax import linen as nn

import sys
sys.path.insert(1, '..')
from kronojax.neural.KAN import KAN, MLP
from kronojax.neural.embedding import time_embedding

In [2]:
# Set random Key
key = jrd.PRNGKey(0)

#### Parametrize flow and drift functions with neural networks

In [3]:
class NN_drift(nn.Module):
    """parametrize drift function"""

    time_embedding_dim: int
    time_freq_min: float
    time_freq_max: float
    dim_list: list

    @nn.compact
    def __call__(self, x, t):

        t_embedded = time_embedding(t, self.time_freq_min, self.time_freq_max, self.time_embedding_dim)
        input_forward = jnp.concatenate([x, t_embedded], axis=1)
        NN_forward = KAN(dim_list=self.dim_list, degree=2)

        return NN_forward(input_forward)

class NN_flow(nn.Module):
    """parametrize flow function"""

    time_embedding_dim: int
    time_freq_min: float
    time_freq_max: float
    dim_list: list

    @nn.compact
    def __call__(self, x, t):

        t_embedded = time_embedding(t, self.time_freq_min, self.time_freq_max, self.time_embedding_dim)
        input_flow = jnp.concatenate([x, t_embedded], axis=1)
        NN_flow = MLP(dim_list=self.dim_list)

        return jnp.exp(NN_flow(input_flow))
    
class NN_FlowAndDirft(nn.Module):
    """concatenate two neural networks for drift and flow"""

    time_embedding_dim: int
    time_freq_min: float
    time_freq_max: float
    dim_list_drift: list
    dim_list_flow: list

    @nn.compact
    def __call__(self, x, t):

        drift = NN_drift(time_embedding_dim=self.time_embedding_dim, 
                         time_freq_min=self.time_freq_min, time_freq_max=self.time_freq_max, dim_list=self.dim_list_drift)
        
        flow = NN_flow(time_embedding_dim=self.time_embedding_dim, 
                       time_freq_min=self.time_freq_min, time_freq_max=self.time_freq_max, dim_list=self.dim_list_flow)

        return drift(x, t), flow(x, t)

In [4]:
key, subkey = jrd.split(key)
time_embedding_dim = 16
time_freq_min = 1.
time_freq_max = 10
output_dim = 1
dim_list_drift = [32, 32, output_dim]
dim_list_flow = [32, 32, output_dim]
GFnet = NN_FlowAndDirft(time_embedding_dim, time_freq_min, time_freq_max, dim_list_drift, dim_list_flow)
# net_drift = NN_drift(time_embedding_dim, time_freq_min, time_freq_max, dim_list_drift)
# net_flow = NN_flow(time_embedding_dim, time_freq_min, time_freq_max, dim_list_flow)

# initialize the network
batch_sz = 32
key, subkey = jrd.split(key)
xs = jrd.normal(subkey, (batch_sz, 1))
ts = jrd.uniform(subkey, (batch_sz, 1))
params = GFnet.init(subkey, xs, ts)
# params_drift = net_drift.init(subkey, xs, ts)
# params_flow = net_flow.init(subkey, xs, ts)
# params = {'drift': params_drift, 'flow': params_flow}

In [5]:
# print(params_drift['params']['KAN_0']['ChebyKANLayer_0']['cheby_coeffs'].shape)
# print(params_drift['params']['KAN_0']['LayerNorm_0']['scale'].shape)
# print(params_drift['params']['KAN_0']['LayerNorm_0']['bias'].shape)

# print(params_drift['params']['KAN_0']['ChebyKANLayer_1']['cheby_coeffs'].shape)
# print(params_drift['params']['KAN_0']['LayerNorm_1']['scale'].shape)
# print(params_drift['params']['KAN_0']['LayerNorm_1']['bias'].shape)

# print(params_drift['params']['KAN_0']['ChebyKANLayer_2']['cheby_coeffs'].shape)

# print(params_flow['params']['MLP_0']['Dense_0']['kernel'].shape)
# print(params_flow['params']['MLP_0']['Dense_0']['bias'].shape)

# print(params_flow['params']['MLP_0']['Dense_1']['kernel'].shape)
# print(params_flow['params']['MLP_0']['Dense_1']['bias'].shape)

# print(params_flow['params']['MLP_0']['Dense_2']['kernel'].shape)
# print(params_flow['params']['MLP_0']['Dense_2']['bias'].shape)

#### Define the loss function

In [6]:
def normal_density(x, mu, sigma2):
    """
    x: n-dim array
    mu: n-dim array
    sigma2: scalar
    """
    return jnp.exp(-0.5 * (x - mu)**2 / sigma2) / jnp.sqrt(2 * jnp.pi * sigma2)

In [7]:
def Traj(
        params: any,
        batch_sz: int,
        N_step: int,
        key: PRNGKey
        ):
    """keep track on flow function, forward probability, and backward probability"""

    T = 1.
    dt = T / N_step
    sqt = jnp.sqrt(dt)

    def _step(carry, _):
        xo, t = carry
        to = jnp.ones_like(xo) * t
        # u = net_drift.apply(params_drift, xo, to)
        # f = net_flow.apply(params_flow, xo, to)
        u, f = GFnet.apply(params, xo, to)
        dw = sqt * jrd.normal(subkey, xo.shape)
        xn = xo + u*dt + dw
        t += dt
        PF = normal_density(xn, xo + u*dt, dt) # Forward probability
        PB = normal_density(xo, xn, dt) # Backward probability
        output_dict = {
            "t": t,
            "x": xo,
            "P_forward": PF,
            "P_backward": PB,
            "state flow": f
            }
        return (xn, t), output_dict
    
    key, subkey = jrd.split(key)
    t_init = 0.
    x_init = jnp.zeros((batch_sz, 1))
    carry_init = (x_init, t_init)
    _, trajectory = jax.lax.scan(_step, carry_init, xs = None, length=N_step)
    xT = trajectory["x"][-1]
    tT = jnp.ones_like(xT)
    # FN = net_flow.apply(params_flow, xT, tT)
    _, FN = GFnet.apply(params, xT, tT)
    return trajectory, FN

In [8]:
trajectory, FN= Traj(params, batch_sz, 10, key)
print(trajectory["t"].shape)
print(trajectory["x"].shape)
print(trajectory["P_forward"].shape)
print(trajectory["P_backward"].shape)
print(trajectory["state flow"].shape)
print(FN.shape)

(10,)
(10, 32, 1)
(10, 32, 1)
(10, 32, 1)
(10, 32, 1)
(32, 1)


In [9]:
def loss(
        params, 
        batch_sz: int, 
        N_step: int, 
        key: PRNGKey
        ):
    """loss function -- total trajectory balance"""
    # params_drift = params['drift']
    # params_flow = params['flow']
    trajectory, FN = Traj(params, batch_sz, N_step, key)
    F0 = trajectory["state flow"][0]
    ratio = jnp.log(F0 / FN) + jnp.sum(jnp.log(trajectory["P_forward"] / trajectory["P_backward"]), axis=0)
    loss = jnp.mean(ratio**2, axis=0)
    return loss.reshape(())

loss = jax.jit(loss, static_argnums=(1, 2))
loss_value_grad = jax.value_and_grad(loss, argnums=0)

In [10]:
loss(params, 32, 10, key)

Array(5.8536606, dtype=float32)

#### Update parameters via loss function

In [11]:
lr = 1e-3
optimizer = optax.adam(learning_rate=lr)

opt_state = optimizer.init(params)

def update(params, opt_state, batch_sz, N_step, key):
    """update the parameters"""

    loss_value, grads = loss_value_grad(params, N_step, batch_sz, key)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss_value

update = jax.jit(update, static_argnums=(2, 3))

In [12]:
Niter = 10
Batch_SZ = 256
LR = 10**-2
optimizer = optax.adam(learning_rate=LR)

loss_values = []

time_start = time.time()
for i in range(Niter):
    key, subkey = jrd.split(key)
    params, opt_state, loss_value = update(params, opt_state, Batch_SZ, 10, subkey)
    loss_values.append(loss_value)
    print(f"Iteration {i}, Loss {loss_value}")
    if i % 10 == 0:
        time_current = time.time()
        print(f"Iteration {i:4d}/{Niter}  |  Loss: {loss_value:.2f}  |  Time: {time_current - time_start:.2f} s")

Iteration 0, Loss 150.5684814453125
Iteration    0/10  |  Loss: 150.57  |  Time: 1.97 s
Iteration 1, Loss nan
Iteration 2, Loss nan
Iteration 3, Loss nan
Iteration 4, Loss nan
Iteration 5, Loss nan
Iteration 6, Loss nan
Iteration 7, Loss nan
Iteration 8, Loss nan
Iteration 9, Loss nan
