In [None]:
!pip install git+https://github.com/google-research/swirl-dynamics.git@main

In [None]:
#@title Imports
import functools
import os
from typing import Any, Literal, Optional

import flax.linen as nn
import jax
import matplotlib.pyplot as plt
import numpy as np
# Enable 64-bit (double-precision) floating point in JAX for the rest of the notebook.
# It is set before the swirl_dynamics imports below. NOTE(review): presumably so
# that any JAX arrays those modules create at import time are float64 too —
# confirm whether the ordering is actually required.
jax.config.update("jax_enable_x64", True)
from swirl_dynamics.data import hdf5_utils
from swirl_dynamics.lib.networks import rational_networks
from swirl_dynamics.projects.weno_nn import weno
from swirl_dynamics.projects.weno_nn import weno_nn
from swirl_dynamics.projects.weno_nn import utils

In [None]:
# Trained WENO-NN checkpoint, stored as <folder>/xid_<xid>_model_<model_num>.hdf5.
flax_model_main_folder = '../model_weights/'
xid = 94741459
model_num = 113
# os.path.join avoids the doubled separator ('..//xid_...') that plain string
# concatenation produced, because the folder path already ends with '/'.
filename = os.path.join(flax_model_main_folder, f'xid_{xid}_model_{model_num}.hdf5')
# Read every array in the checkpoint into a nested dict. The code below reads its
# 'config' entry; the trained weights live under 'params' (used when applying the model).
network_vars = hdf5_utils.read_all_arrays_as_dict(filename)
# Rebuild the network architecture from the stored config. String-valued entries
# are decoded before lookup (NOTE(review): assumes they are stored as bytes).
mlp_model = weno_nn.OmegaNN(
    features=tuple(network_vars['config']['features'].astype(int)),
    features_fun=utils.get_feature_func(network_vars['config']['features_fun'].decode()),
    act_fun=utils.get_act_func(network_vars['config']['act_fun'].decode()),
)

In [None]:
# Smooth test profile u(x) = sin(pi * x) sampled at 101 uniform points on [0, 1].
x = np.linspace(0.0, 1.0, 101)
u = np.sin(np.pi * x)
# Row k of u_nb is the 3-point stencil [u_{i-1}, u_{i}, u_{i+1}] centred on i = k + 1.
# Only interior points get a stencil, so u_nb has len(u) - 2 rows and 3 columns.
left, centre, right = u[:-2], u[1:-1], u[2:]
u_nb = np.column_stack((left, centre, right))

In [None]:
# weno.weno_interpolation and the network's apply are written for a single stencil
# [u_{i-1}, u_{i}, u_{i+1}]; jax.vmap lifts each over the leading axis of u_nb.

# Network forward pass with the `test=True` flag always set.
model_apply_func = functools.partial(mlp_model.apply, test=True)

# Batched WENO-weight evaluation: the variables dict (1st arg) is shared across
# the batch, while the stencils (2nd arg) are mapped along axis 0.
weno_nn_wt_func_vmap = jax.vmap(model_apply_func, in_axes=(None, 0))

# Batched interpolation: stencils (1st arg) are mapped along axis 0; the second
# argument is passed through unbatched to every call.
weno_interp_func_vmap = jax.vmap(weno.weno_interpolation, in_axes=(0, None))

In [None]:
# Trained parameters in the variables-dict layout passed to the network's apply.
nn_variables = {"params": network_vars["params"]}


def nn_weight_fun(x, params):
    """Weight callback handed to `weno.weno_interpolation`.

    Runs the trained network on the input it receives (presumably one stencil,
    since the call happens inside jax.vmap). The `params` argument supplied by
    the caller is ignored; the loaded checkpoint parameters are always used.
    """
    del params  # Unused: the checkpoint parameters are closed over instead.
    return model_apply_func(nn_variables, x)


# NN-estimated WENO weights for every stencil (negative side).
wt_neg = weno_nn_wt_func_vmap(nn_variables, u_nb)
# WENO interpolation on both sides: column 0 is the positive side, column 1 the negative.
u_interp = weno_interp_func_vmap(u_nb, nn_weight_fun)
u_interp_pos, u_interp_neg = u_interp[:, 0], u_interp[:, 1]

In [None]:
# NN-estimated WENO weights at each interior grid point (one row of u_nb per point).
# Explicit Axes interface; the y-range is fixed to [0, 1] once rather than per curve.
fig_wt, ax_wt = plt.subplots()
ax_wt.plot(x[1:-1], wt_neg[:, 0], label=r'$\omega_0$')
ax_wt.plot(x[1:-1], wt_neg[:, 1], label=r'$\omega_1$')
ax_wt.set(
    xlabel='X',
    ylabel='WENO Weights',
    ylim=(0.0, 1.0),
    title='NN-estimated WENO weights (negative side)',
)
ax_wt.legend();

In [None]:
# Half-cell locations x_i + dx/2 on the uniform grid.
x_half = x + (x[1] - x[0]) * 0.5
fig_interp, ax_interp = plt.subplots()
# Interpolated values exist only for interior stencils, hence the [1:-1] slice.
ax_interp.plot(x_half[1:-1], u_interp_pos, '-.b', label='Pos')
ax_interp.plot(x_half[1:-1], u_interp_neg, '-.g', label='Neg')
ax_interp.plot(x, u, '--r', label='Cell')
# The y-axis shows values of u (interpolated and original), not WENO weights.
ax_interp.set(xlabel='X', ylabel='u', title='WENO-NN interpolation')
ax_interp.legend();