In [1]:
import yaml
from model_utils import GeneralizedMLP, FourierKAN
from model_utils import get_mse_loss, get_train_step
from model_utils import KeyHandler, sobol_sample
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as np
from functools import partial
from tqdm import tqdm
import pandas as pd

In [2]:
from interpolated_funcs import circular_wave_interference

def get_model(config):
    if config["MODEL"] == "MLP":
        return GeneralizedMLP(
            kernel_init=nn.initializers.glorot_normal(),
            num_input=config['N_INPUT'],
            num_output=1,
            use_fourier_feats=config['FourierFeatures'],
            layer_sizes=config['layers']
        )
    if config["MODEL"] == "KAN":
        return FourierKAN(
            kernel_init=nn.initializers.glorot_normal(),
            num_input=config['N_INPUT'],
            num_output=1,
            use_fourier_feats=config['FourierFeatures'],
            layer_sizes=config['layers']
        )

def get_target_func(config):
    if experiment["learnable_func"] == "circular_wave_interference":
        learnable_func = circular_wave_interference

    learnable_func = partial(learnable_func, FREQ=experiment["FREQ"])
    return learnable_func

def sample_collocs(config):
    collocs = jnp.array(sobol_sample(np.array([config["X_MIN"],config["Y_MIN"]]), 
                                     np.array([config["X_MAX"],config["Y_MAX"]]), config["BS"]))
    return collocs

def train_model(config):    
    collocs = sample_collocs(config)

    model = get_model(config)
    variables = model.init(keygen.key(), collocs)
    loss_fn = get_mse_loss(model, MODEL=config["MODEL"])
    
    # Define a cosine decay learning rate schedule
    schedule_fn = optax.cosine_decay_schedule(
        init_value=1e-2,       # Initial learning rate
        decay_steps=config["EPOCHS"],  # Total number of decay steps
        alpha=1e-3             # Final learning rate multiplier
    )
    optimizer = optax.adamw(learning_rate=schedule_fn, weight_decay=1e-4)
    opt_state = optimizer.init(variables['params'])
    train_step = get_train_step(model, optimizer, loss_fn)

    learnable_func = get_target_func(config)

    if config["MODEL"] == "MLP":
        variables["state"] = []

    # train always on same colloc points
    collocs = sample_collocs(experiment)
    losses = []
    
    loc_w = jnp.array([])
    for i in (pbar:= tqdm(range(experiment["EPOCHS"]))):
        params, state = variables['params'], variables['state']
        y = learnable_func(collocs).reshape(-1,1)
        params, opt_state, loss, loc_w = train_step(params, collocs, y,
                                                    opt_state, state, loc_w)
        variables = {'params': params, 'state':state}

        losses.append(loss)
        
        if i % 50 == 0: # dont waste a lot of time printing
            pbar.set_description(f"Loss {loss: .8f}")

    return variables, losses

import pickle
def save_dict_to_file(dictionary, filename):
    """Saves a dictionary to a file using pickle."""
    with open(filename, 'wb') as file:
        pickle.dump(dictionary, file)

def load_dict_from_file(filename):
    """Loads a dictionary from a file using pickle."""
    with open(filename, 'rb') as file:
        return pickle.load(file)

In [3]:
filename = "increase_params_fourier"
with open(f"{filename}.yaml", 'r') as file:
    config = yaml.safe_load(file)

keygen = KeyHandler(0)
config["experiments"].keys()

dict_keys(['increase_params_1', 'increase_params_2', 'increase_params_3', 'increase_params_4', 'increase_params_5'])

In [11]:
import jax 

def get_mse_loss(model, MODEL='MLP'):
#    @jax.jit
    def mse_loss_mlp(params, x, y, state):
        def u(vec_x, variables):
            y = model.apply(variables, vec_x)
            return y
        variables = {'params' : params}
        
        y_hat = u(x, variables)
        loss = (y_hat - y)**2
#        print(jnp.squeeze(loss))
        
        return jnp.squeeze(loss)

    if MODEL == 'MLP':
        return mse_loss_mlp
    
    @jax.jit
    def mse_loss_kan(params, x, y, state, loc_w):
        def u(vec_x, variables):
            y = model.apply(variables, vec_x)
            return y
        variables = {'params' : params, 'state': state}
        
        y_hat = u(x, variables)
        loss = jnp.mean((y_hat - y)**2)

        new_loc_w = loc_w
        return loss, new_loc_w
    
    if MODEL == 'KAN':
        return mse_loss_kan
        
experiment = config["experiments"]['increase_params_1']

collocs = sample_collocs(experiment)
collocs = collocs[:1]

model = get_model(experiment)
variables = model.init(keygen.key(), collocs)
loss_fn = get_mse_loss(model, MODEL=experiment["MODEL"])

preds = model.apply(variables, collocs)
true = jnp.ones_like(preds)

loc_w = jnp.array([])

params = variables["params"]
state = []


res = jax.vmap(loss_fn, (None, 0, 0, None))(params, collocs, true, state)
loss_grad = jax.value_and_grad(loss_fn)

res, grads = jax.vmap(loss_grad, (None, 0, 0, None))(params, collocs, true, state)
# grads

{'Dense_0': {'bias': Array([[-1.0785582]], dtype=float32),
  'kernel': Array([[[ 1.0503103 ],
          [-0.9971523 ],
          [-0.9746476 ],
          [-0.2376494 ],
          [-1.076255  ],
          [-0.93965465],
          [ 0.5540534 ],
          [ 0.5681294 ],
          [ 1.0190935 ],
          [ 0.5734641 ],
          [-1.0649058 ],
          [-1.0368441 ],
          [ 1.0537382 ],
          [ 1.0652359 ],
          [-0.79729855],
          [ 0.50856835],
          [ 0.24522667],
          [ 0.4110658 ],
          [ 0.46189806],
          [ 1.0520507 ],
          [ 0.07044991],
          [-0.5294685 ],
          [-0.9253716 ],
          [-0.916797  ],
          [-0.35318023],
          [-0.9134696 ],
          [-0.17106532],
          [-0.29705578],
          [-0.23005116],
          [-0.16899806],
          [ 0.7263627 ],
          [ 0.9511288 ]]], dtype=float32)},
 'FourierFeats_0': {'B': Array([[[ 0.04291802, -0.09637446, -0.15213463,  0.01053628,
            0.09130981, -0

In [None]:
def compute_mean_and_std_nested(batchgrad):
    mean_gradients = {}
    std_gradients = {}

    def traverse_and_compute(d, parent_key=''):
        for key, value in d.items():
            full_key = f"{parent_key}.{key}" if parent_key else key
            if isinstance(value, dict):
                # Recurse into the nested dictionary
                traverse_and_compute(value, full_key)
            else:
                # Compute mean and std for the leaf-level array
                mean_gradients[full_key] = jnp.mean(value, axis=0)
                std_gradients[full_key] = jnp.std(value, axis=0)

    traverse_and_compute(batchgrad)
    return mean_gradients, std_gradients

# Compute mean and std gradients
mean_gradients, std_gradients = compute_mean_and_std_nested(batchgrad)

# Compute the L2 norms of the mean and std gradients
mean_l2 = jnp.sqrt(sum(jnp.sum(jnp.square(mu)) for mu in mean_gradients.values()))
std_l2 = jnp.sqrt(sum(jnp.sum(jnp.square(std)) for std in std_gradients.values()))

# Compute the SNR
snr = mean_l2 / std_l2

print(f"SNR: {snr}")

In [None]:
for exp_key in config["experiments"].keys():
    experiment = config["experiments"][exp_key]
    variables, losses = train_model(experiment)
    save_dict_to_file(variables, f"trained_models/{exp_key}")

In [None]:
def l2_error(results, true):
    err = jnp.sum((results - true)**2) / jnp.sum(true**2)
    err = jnp.sqrt(err)
    return err

def get_l2_error(config, variables):
    model = get_model(config)
    learnable_func = get_target_func(experiment)
    
    N = 300
    X_1 = jnp.linspace(experiment["X_MIN"], experiment["X_MAX"], N)
    X_2 = jnp.linspace(experiment["Y_MIN"], experiment["Y_MAX"], N)
    X_1, X_2 = jnp.meshgrid(X_1, X_2, indexing='ij')
    coords = jnp.stack([X_1.flatten(), X_2.flatten()], axis=1)

    y = learnable_func(coords).reshape(-1,1)
    y_hat = model.apply(variables, coords)
    
    err = l2_error(y_hat, y)
    
    return err

def sum_params(data, verbose=False):
    total = 0
    if isinstance(data, type(jnp.array([]))):  # If the current node is a leaf array
        return len(data.reshape(-1))
    elif isinstance(data, dict):  # If the current node is a dictionary
        for key, value in data.items():
            if verbose:
                print(f"Processing key: {key}")  # Print the current key
            branch_total = sum_params(value)  # Compute the total for this subbranch
            if verbose:
                print(f"Total parameters in subbranch '{key}': {branch_total}")
            total += branch_total
    
    return total

In [None]:
import pandas as pd

df = pd.DataFrame(config["experiments"]).T
df["params"] = -1
df["L2%"] = -1

for idx, exp_key in enumerate(config["experiments"].keys()):
    experiment = config["experiments"][exp_key]
    
    variables = load_dict_from_file(f"trained_models/{exp_key}")
    l2_err = get_l2_error(experiment, variables)

    df.loc[exp_key, "params"] = sum_params(variables["params"], verbose=False)
    df.loc[exp_key, "L2%"] = l2_err * 100

    print(f"Results from {exp_key}:")
    print(f"L2 {l2_err*100:.4f}%")
    print(f"#params {df.iloc[idx]['params']}")
    
df.to_csv(f'{filename}.csv')

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from matplotlib import cm
import os

def plot_true_approx(config, variables, exp_key):
    model = get_model(config)
    learnable_func = get_target_func(config)  # Fixed to use 'config' instead of 'experiment'
    
    N = 300
    X_1 = jnp.linspace(config["X_MIN"], config["X_MAX"], N)  # Fixed to use 'config'
    X_2 = jnp.linspace(config["Y_MIN"], config["Y_MAX"], N)  # Fixed to use 'config'
    X_1, X_2 = jnp.meshgrid(X_1, X_2, indexing='ij')
    coords = jnp.stack([X_1.flatten(), X_2.flatten()], axis=1)

    y = learnable_func(coords).reshape(-1, 1)
    y_hat = model.apply(variables, coords).reshape(-1, 1)

    # Compute the absolute error
    abs_error = jnp.abs(y - y_hat).reshape(-1, 1)

    # Create a figure and axis grid for the 3 subplots
    fig = plt.figure(figsize=(18, 6))

    # Plotting the approximated function (y_hat) on the left using the plasma colormap
    ax1 = fig.add_subplot(1, 3, 1, projection='3d')
    ax1.plot_trisurf(coords[:, 0], coords[:, 1], y_hat.flatten(), cmap=cm.plasma)
    ax1.set_title('Approximated Function (y_hat)')
    ax1.set_xlabel('X1')
    ax1.set_ylabel('X2')
    ax1.set_zlabel('y_hat')

    # Plotting the true function (y) in the middle
    ax2 = fig.add_subplot(1, 3, 2, projection='3d')
    ax2.plot_trisurf(coords[:, 0], coords[:, 1], y.flatten(), cmap='viridis')
    ax2.set_title('True Function (y)')
    ax2.set_xlabel('X1')
    ax2.set_ylabel('X2')
    ax2.set_zlabel('y')

    # Plotting the absolute error on the right
    ax3 = fig.add_subplot(1, 3, 3, projection='3d')
    ax3.plot_trisurf(coords[:, 0], coords[:, 1], abs_error.flatten(), cmap='inferno')
    ax3.set_title('Absolute Error')
    ax3.set_xlabel('X1')
    ax3.set_ylabel('X2')
    ax3.set_zlabel('Error')

    # Adjust layout to ensure all plots fit well within the figure
    plt.tight_layout(pad=3.0)

    # Saving the figure
    save_folder = f'visuals/{filename}'
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    save_path = os.path.join(save_folder, f'{exp_key}_plot.png')

    # Show the plot
    plt.show()

    # Save the plot to the specified path
    fig.savefig(save_path)

In [None]:
for exp_key in config["experiments"].keys():
    experiment = config["experiments"][exp_key]

    print(f"Results from {exp_key}:")
    variables = load_dict_from_file(f'trained_models/{filename}/{exp_key}')
    plot_true_approx(experiment,variables, f'{exp_key}')

In [None]:
pd.read_csv('increase_params_fourier.csv')

In [None]:
pd.read_csv('KAN_increase_params_fourier.csv')

In [None]:
model = get_model(experiment)
print(model)

In [None]:
sum_params(variables["params"], verbose=True)

In [None]:
variables["params"]

In [None]:
experiment