# An Empirical Investigation of Initialization Strategies for Kolmogorov–Arnold Networks

**Authors**: *Spyros Rigas, Dhruv Verma, Georgios Alexandridis, Yixuan Wang*

The present Jupyter Notebook serves as the companion to the ICML25 MOSS submission titled "*An Empirical Investigation of Initialization Strategies for Kolmogorov–Arnold Networks*". The notebook includes the code that reproduces the results shown in the paper along with the corresponding plots. Note that the results shown in Table 1 of the manuscript (see [this](#Grid-Search-Results---Data-Analysis) notebook section) use the `grid_search.csv` file as a reference, which is obtained after running a grid search over multiple architectures. This file is provided in the same directory as this notebook.

## Preliminaries

In the following, we install necessary packages and define preliminaries that are essential for the main part of the code.

In [None]:
!pip3 install -q jaxkan[gpu]
!pip3 install -q scikit-learn
!pip3 install -q pandas
!pip3 install -q matplotlib

In [None]:
import warnings
# Silence deprecation/future warnings emitted by the libraries used in this
# notebook to keep cell output readable.
# NOTE(review): this also hides pandas FutureWarnings about upcoming behavior
# changes (e.g. in groupby/apply); re-enable them when upgrading dependencies.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
from typing import Union, List

import pandas as pd

import jax
import jax.numpy as jnp
from jax.scipy.special import i1, i1e, fresnel, erfinv, erf

from flax import nnx
import optax

from jaxkan.KAN import KAN
from jaxkan.layers.Spline import SplineLayer

from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.lines as mlines
from matplotlib.cm import get_cmap

In [None]:
# Define the functions used for the task
def f1(x):
    """Return the product of the two input coordinates as an (N, 1) column."""
    first, second = x[:, [0]], x[:, [1]]
    return first * second

def f2(x):
    """Return exp(sin(pi * x0) + x1**2) as an (N, 1) column."""
    x0 = x[:, [0]]
    x1 = x[:, [1]]
    exponent = jnp.sin(jnp.pi * x0) + x1 ** 2
    return jnp.exp(exponent)

def f3(x):
    """Return i1(x0) + exp(i1e(x1)) + sin(x0 * x1) as an (N, 1) column.

    i1 is the modified Bessel function of the first kind of order 1 and i1e
    its exponentially scaled variant (both from jax.scipy.special).
    """
    u, v = x[:, [0]], x[:, [1]]
    bessel_term = i1(u)
    exp_term = jnp.exp(i1e(v))
    mixed_term = jnp.sin(u * v)
    return bessel_term + exp_term + mixed_term

def f4(x):
    """Return S(z) * C(z) with z = f3(x) + erfinv(x1), as an (N, 1) column.

    S and C are the Fresnel integrals returned by jax.scipy.special.fresnel.
    Note that erfinv diverges as x1 -> +/-1, so inputs exactly at -1 give
    non-finite values.
    """
    argument = f3(x) + erfinv(x[:, [1]])
    fresnel_s, fresnel_c = fresnel(argument)
    return fresnel_s * fresnel_c

def f5(x):
    """Piecewise target: x1 * sign + erf(x0) * g(x0 * x1), as an (N, 1) column.

    sign is +1 where x0 < 0.5 and -1 elsewhere; g(p) = p where p < 1 and 1/p
    otherwise, with p = f1(x) = x0 * x1.
    """
    x0 = x[:, 0]
    x1 = x[:, 1]
    sign = jnp.where(x0 < 0.5, 1, -1).reshape(-1, 1)
    product = f1(x)
    folded = jnp.where(product < 1, product, 1 / product)
    return x1.reshape(-1, 1) * sign + erf(x0).reshape(-1, 1) * folded

In [None]:
# Classes implementing the KAN variant with LeCun-normalized initialization
class StdSplineLayer(SplineLayer):
    """Spline KAN layer with LeCun-normalized initialization.

    Differences from the parent ``SplineLayer`` (from jaxkan):

    * ``_initialize_params`` draws the residual and basis coefficients from
      zero-mean normals whose variance is split equally across the
      ``n_in * (G + k + 1)`` terms per output (G + k basis coefficients plus
      one residual coefficient per input, matching the coefficient shapes).
    * ``basis`` standardizes every spline basis output to zero mean and unit
      variance over axis 0 of the evaluated input.

    Attributes such as ``n_in``, ``n_out``, ``k``, ``grid``, ``residual`` and
    ``rngs`` are assumed to be set up by ``SplineLayer`` before
    ``_initialize_params`` is called — TODO confirm against the jaxkan version
    in use.
    """

    def _initialize_params(self, init_scheme, seed):
        """Return the initial ``(c_res, c_basis)`` coefficients.

        Args:
            init_scheme: dict with optional keys ``"distribution"`` (``"uniform"``
                on [-1, 1) or ``"normal"``; missing/None -> ``"uniform"``) and
                ``"gain"`` (missing/None -> the standard deviation of a sample
                drawn from that distribution). ``None`` is treated as ``{}``.
            seed: seed for the 10^5-point sample used to estimate the gain and
                E[residual(x)^2].

        Returns:
            c_res with shape (n_out, n_in) and c_basis with shape
            (n_out, n_in, G + k), both float32.

        Raises:
            ValueError: if the distribution is neither "uniform" nor "normal".
        """
        init_scheme = init_scheme or {}
        key = jax.random.key(seed)

        # Input distribution assumed for the variance estimates
        distrib = init_scheme.get("distribution", "uniform")
        if distrib is None:
            distrib = "uniform"

        # Generate a sample of 10^5 points from that distribution
        if distrib == "uniform":
            sample = jax.random.uniform(key, shape=(100000,), minval=-1.0, maxval=1.0)
        elif distrib == "normal":
            sample = jax.random.normal(key, shape=(100000,))
        else:
            # Previously fell through and crashed with UnboundLocalError on `sample`
            raise ValueError(
                f"Unsupported init_scheme['distribution']={distrib!r}; "
                "expected 'uniform' or 'normal'."
            )

        # Gain defaults to the empirical std of the sampled inputs
        gain = init_scheme.get("gain", None)
        if gain is None:
            gain = sample.std().item()

        # ---- Residual Calculations --------
        # Variance equipartitioned across all terms
        scale = self.n_in * (self.grid.G + self.k + 1)
        # Monte-Carlo estimate of E[residual(x)^2] under the sampled distribution
        y_res = self.residual(sample)
        y_res_sq_mean = (y_res ** 2).mean().item()

        std_res = gain / jnp.sqrt(scale * y_res_sq_mean)
        c_res = nnx.initializers.normal(stddev=std_res)(self.rngs.params(), (self.n_out, self.n_in), jnp.float32)

        # ---- Basis Calculations -----------
        # Basis outputs are standardized in `basis`, so no E[b(x)^2] factor is needed
        std_b = gain / jnp.sqrt(scale)
        c_basis = nnx.initializers.normal(stddev=std_b)(
            self.rngs.params(), (self.n_out, self.n_in, self.grid.G + self.k), jnp.float32
        )

        return c_res, c_basis

    def basis(self, x):
        """Evaluate the parent's spline bases and standardize them.

        Each basis output is centred and scaled by its mean and variance over
        axis 0 (presumably the batch axis — TODO confirm against
        ``SplineLayer.basis``); 1e-5 is added to the variance for numerical
        stability. The statistics are computed from the current input, so the
        output for a given point depends on the rest of the batch.
        """
        basis_splines = super().basis(x)

        mean = jnp.mean(basis_splines, axis=0, keepdims=True)
        denom = jnp.sqrt(jnp.var(basis_splines, axis=0, keepdims=True) + 1e-5)
        basis_splines = (basis_splines - mean) / denom

        return basis_splines


class StdKAN(nnx.Module):
    """Feed-forward KAN whose layers are all ``StdSplineLayer`` instances.

    Args:
        layer_dims: widths of consecutive layers, e.g. ``[n_in, 8, 8, n_out]``;
            one layer is built for every adjacent pair.
        required_parameters: keyword arguments forwarded to every
            ``StdSplineLayer`` (grid size, spline order, init scheme, ...).
        seed: seed forwarded unchanged to every layer.
            NOTE(review): all layers receive the same seed; whether same-shaped
            layers then get identical initial draws depends on how SplineLayer
            seeds its RNG streams — confirm.

    Raises:
        ValueError: if ``required_parameters`` is None.
    """

    def __init__(self, layer_dims: List[int], required_parameters: Union[None, dict] = None, seed: int = 42):

        if required_parameters is None:
            # The previous message mentioned a `layer_type` argument this class does not have
            raise ValueError("required_parameters must be provided as a dictionary of StdSplineLayer arguments.")

        self.layers = [
            StdSplineLayer(
                n_in=n_in,
                n_out=n_out,
                **required_parameters,
                seed=seed
            )
            for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])
        ]

    def __call__(self, x):
        """Apply the layers in sequence and return the final output."""
        for layer in self.layers:
            x = layer(x)

        return x
        

In [None]:
# Utilities
def generate_func_data(function, dim, N, seed):
    """Sample N points uniformly from [-1, 1)^dim and evaluate `function` on them.

    Args:
        function: callable mapping an (N, dim) array to the targets.
        dim: input dimensionality.
        N: number of points.
        seed: integer seed for jax.random.

    Returns:
        (inputs, function(inputs)).
    """
    inputs = jax.random.uniform(
        jax.random.key(seed), shape=(N, dim), minval=-1.0, maxval=1.0
    )
    targets = function(inputs)
    return inputs, targets


@nnx.jit
def func_fit_step(model, optimizer, X_train, y_train):
    """Run one full-batch optimization step on the mean-squared error.

    The model and optimizer are updated in place (NNX state); the returned
    value is the loss evaluated before the update.
    """
    def mse(m):
        predictions = m(X_train)
        return jnp.mean((predictions - y_train) ** 2)

    loss_value, gradients = nnx.value_and_grad(mse)(model)
    optimizer.update(gradients)

    return loss_value

## Grid Search Results - Data Analysis

This part of the notebook performs data analysis on the results stored in `grid_search.csv` to produce the results in Table 1 of the manuscript. Note that the file `grid_search.csv` (which is supplied in the Supplementary Material) must be in the same directory as this notebook for the following cells to run successfully. If you only wish to see the results for the trained models, you may skip directly to the section titled [KAN Runs](#KAN-Runs).

In [None]:
# Load the grid-search results file to perform the analysis.
# The analysis cells below read the columns: method, function, G, width, depth,
# pow_res, pow_basis, run, loss and l2.
gs = pd.read_csv('grid_search.csv')

In [None]:
# For every configuration, keep the single run whose loss is closest to the
# group's median loss (the "median" run across repeated runs).
# Sorting by loss first makes idxmin's first-occurrence tie-breaking
# deterministic: among runs equally close to the median, the lower loss wins.
gs_sorted = gs.sort_values("loss")

# Grouping columns, including pow_res and pow_basis. dropna=False keeps groups
# whose keys are NaN (presumably pow_res/pow_basis for non-power methods).
group_cols = ['method', 'function', 'G', 'width', 'depth', 'pow_res', 'pow_basis']


def get_median_row(group):
    """Return, as a one-row DataFrame, the row of `group` whose loss is closest to the group median."""
    median_loss = group['loss'].median()
    # Use idxmin on absolute difference to median to break ties predictably
    idx = (group['loss'] - median_loss).abs().idxmin()
    return group.loc[[idx]]


def _median_loss_index(losses):
    """Return the index label of the entry of `losses` closest to its median."""
    return (losses - losses.median()).abs().idxmin()


# Select the median rows by index label rather than with DataFrameGroupBy.apply:
# apply over the grouping columns is deprecated since pandas 2.2 (the warning is
# silenced in this notebook), and once the grouping columns are excluded the
# `method`/`function`/... columns would be lost from the result.
# Groups come out in sorted-key order, exactly as with the previous apply.
median_idx = gs_sorted.groupby(group_cols, dropna=False)['loss'].agg(_median_loss_index)
mgs = gs_sorted.loc[median_idx.to_numpy()].reset_index(drop=True)

In [None]:
# Filter to only 'power' method
power_df = mgs[mgs['method'] == 'power'].copy()

# For each function and architecture (G, width, depth), keep the power-law
# exponent pair (pow_res, pow_basis) whose median run has the lowest loss.
# idxmin + loc keeps the original rows and column dtypes intact. The previous
# `.apply(lambda g: g.loc[idxmin])` turned each row into an object Series and
# relied on the deprecated inclusion of grouping columns in `apply`.
best_power_idx = (
    power_df
    .groupby(['function', 'G', 'width', 'depth'], dropna=False)['loss']
    .idxmin()
)
best_power_configs = power_df.loc[best_power_idx.to_numpy()].reset_index(drop=True)

# Drop the exponent columns and the run identifier from the whole set; a single
# run per configuration is kept from here on
mgs_nopow = mgs.drop(columns=['pow_res', 'pow_basis', 'run'])

# Drop the same columns from best_power_configs too
best_power_configs_nopow = best_power_configs.drop(columns=['pow_res', 'pow_basis', 'run'])

# Filter out original 'power' rows (all exponent pairs) from mgs_nopow
non_power_rows = mgs_nopow[mgs_nopow['method'] != 'power']

# Combine the best 'power' rows with all other methods
fgs = pd.concat([non_power_rows, best_power_configs_nopow], ignore_index=True)

At this point we have a dataframe called `fgs` with a single run per architecture and method, namely the median run (for the power-law method, the median run of the best-performing exponent pair). For each function and each method, we proceed to count how many architectures outperform the baseline in terms of:

a. the final loss:

In [None]:
# Step 1: Extract baseline rows (one per function and architecture)
baseline_df = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'loss']]
baseline_df = baseline_df.rename(columns={'loss': 'baseline_loss'})

# Step 2: Filter the methods of interest
methods_of_interest = ['lecun_norm', 'lecun_numer', 'power']
fgs_comp = fgs[fgs['method'].isin(methods_of_interest)].copy()

# Step 3: Pair every method row with the baseline of the same configuration
merged = pd.merge(
    fgs_comp,
    baseline_df,
    on=['function', 'G', 'depth', 'width'],
    how='inner'
)

# Step 4: Compare losses
merged['beats_baseline'] = merged['loss'] < merged['baseline_loss']

# Step 5: Count, per function and method, the architectures beating the baseline
result = (
    merged.groupby(['function', 'method'])['beats_baseline']
    .sum()
    .reset_index(name='num_architectures')
)

# Step 6: Express counts as a percentage of that function's baseline
# architectures. The denominator is counted per function instead of reusing the
# f1 count, so percentages stay correct if the grid differs between functions.
# `num_base` (the f1 count) is still defined for code that reads it.
num_base = baseline_df[baseline_df['function'] == 'f1'].shape[0]
num_base_per_function = baseline_df.groupby('function').size()
result['percentage'] = 100 * result['num_architectures'] / result['function'].map(num_base_per_function)

In [None]:
# Number (and percentage) of architectures, per function and method, whose final loss beats the baseline
print(result)

b. the final $L^2$ error relative to the reference solution:

In [None]:
# Step 1: Get baseline l2 values (one per function and architecture)
baseline_l2 = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'l2']]
baseline_l2 = baseline_l2.rename(columns={'l2': 'baseline_l2'})

# Step 2: Filter the methods of interest
fgs_comp_l2 = fgs[fgs['method'].isin(methods_of_interest)].copy()

# Step 3: Pair every method row with the baseline of the same configuration
merged_l2 = pd.merge(
    fgs_comp_l2,
    baseline_l2,
    on=['function', 'G', 'depth', 'width'],
    how='inner'
)

# Step 4: Compare l2 values
merged_l2['beats_baseline_l2'] = merged_l2['l2'] < merged_l2['baseline_l2']

# Step 5: Count, per function and method, the architectures beating the baseline
result_l2 = (
    merged_l2.groupby(['function', 'method'])['beats_baseline_l2']
    .sum()
    .reset_index(name='num_architectures')
)

# Step 6: Percentage of that function's baseline architectures (counted per
# function rather than assuming every function has as many as f1)
num_base_l2_per_function = baseline_l2.groupby('function').size()
result_l2['percentage'] = 100 * result_l2['num_architectures'] / result_l2['function'].map(num_base_l2_per_function)

In [None]:
# Number (and percentage) of architectures, per function and method, whose final L2 error beats the baseline
print(result_l2)

Finally, let's find the number of architectures that outperform the baseline in terms of both the final loss and the relative $L^2$ error:

In [None]:
# Count the architectures that beat the baseline on BOTH the final loss and the
# relative L2 error. A fresh merge carrying both baseline metrics is built here.

# Step 1: Collect baseline loss and l2 together
baseline_all = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'loss', 'l2']]
baseline_all = baseline_all.rename(columns={'loss': 'baseline_loss', 'l2': 'baseline_l2'})

# Step 2: Pair every method-of-interest row with the baseline of the same configuration
fgs_comp_all = fgs[fgs['method'].isin(methods_of_interest)].copy()
merged_all = pd.merge(
    fgs_comp_all,
    baseline_all,
    on=['function', 'G', 'depth', 'width'],
    how='inner'
)

# Step 3: Compare both loss and l2
merged_all['beats_both'] = (
    (merged_all['loss'] < merged_all['baseline_loss']) &
    (merged_all['l2'] < merged_all['baseline_l2'])
)

# Step 4: Group and count
result_both = (
    merged_all.groupby(['function', 'method'])['beats_both']
    .sum()
    .reset_index(name='num_architectures')
)

# Step 5: Percentage of that function's baseline architectures (counted per
# function rather than assuming every function has as many as f1)
num_base_all_per_function = baseline_all.groupby('function').size()
result_both['percentage'] = 100 * result_both['num_architectures'] / result_both['function'].map(num_base_all_per_function)

In [None]:
# Number (and percentage) of architectures, per function and method, beating the baseline on both loss and L2 error
print(result_both)

## KAN Runs

Following the data analysis stage, we proceed with the training of the two networks mentioned in the manuscript to show the evolution of the training loss for each function under the baseline and all proposed initialization techniques.

In [None]:
# Setup
# Target functions, keyed by the names used in `results`
func_dict = dict(f1=f1, f2=f2, f3=f3, f4=f4, f5=f5)

N = 5000        # samples generated per target function
seed = 42       # base random seed

num_epochs = 2000

opt_type = optax.adam(learning_rate=0.001)

# Values plugged into pow_b1/pow_b2 and pow_r1/pow_r2 of the 'power' init scheme
pow_basis = 1.75
pow_res = 0.25


def _spline_params(G, scheme):
    """Return a fresh dict of spline-layer keyword arguments for grid size `G` and init `scheme`.

    Every configuration shares k=3, grid_range (-1, 1), grid_e 1.0, a SiLU
    residual, external weights and biases; only G and the init scheme vary.
    `scheme` is one of 'baseline', 'lecun_numer', 'lecun_norm', 'power'.
    """
    init_schemes = {
        'baseline': {'type': 'default'},
        'lecun_numer': {'type': 'lecun', 'gain': None, 'distribution': 'uniform'},
        # No 'type' key: StdSplineLayer's initializer only reads 'gain' and 'distribution'
        'lecun_norm': {'gain': None, 'distribution': 'uniform'},
        'power': {'type': 'power', "const_b": 1.0, "const_r": 1.0,
                  "pow_b1": pow_basis, "pow_b2": pow_basis, "pow_r1": pow_res, "pow_r2": pow_res},
    }
    return {
        'k': 3,
        'G': G,
        'grid_range': (-1.0, 1.0),
        'grid_e': 1.0,
        'residual': nnx.silu,
        'external_weights': True,
        'add_bias': True,
        'init_scheme': init_schemes[scheme],
    }


# --------------------------
# Small architecture details
# --------------------------
G_small = 5
hidden_small = [8, 8]

params_small_baseline = _spline_params(G_small, 'baseline')
params_small_lecun_numer = _spline_params(G_small, 'lecun_numer')
params_small_lecun_norm = _spline_params(G_small, 'lecun_norm')
params_small_power = _spline_params(G_small, 'power')

# ------------------------
# Big architecture details
# ------------------------
G_big = 20
hidden_big = [32, 32, 32]

params_big_baseline = _spline_params(G_big, 'baseline')
params_big_lecun_numer = _spline_params(G_big, 'lecun_numer')
params_big_lecun_norm = _spline_params(G_big, 'lecun_norm')
params_big_power = _spline_params(G_big, 'power')

In [None]:
# Experiment
# For every target function, train the small and the big architecture under
# each initialization scheme for 5 runs (model seed = seed + run) and store the
# per-step training losses in results[func_name][arch][run][init].


def _train_losses(model, X_train, y_train, num_epochs):
    """Train `model` with a fresh optimizer for `num_epochs` full-batch steps.

    Returns a 1-D array with the loss of every step (assumes num_epochs >= 1).
    Losses are collected in a list and stacked once at the end, instead of
    rewriting a preallocated device array with `.at[epoch].set(...)` each step,
    which copies the whole array on every iteration.
    """
    optimizer = nnx.Optimizer(model, opt_type)
    losses = [func_fit_step(model, optimizer, X_train, y_train) for _ in range(num_epochs)]
    return jnp.stack(losses)


def _build_model(init, layer_dims, params, model_seed):
    """Build the model for one init scheme: StdKAN for 'norm', jaxkan's spline KAN otherwise."""
    if init == 'norm':
        return StdKAN(layer_dims=layer_dims, required_parameters=params, seed=model_seed)
    return KAN(layer_dims=layer_dims, layer_type='spline', required_parameters=params, seed=model_seed)


# Per architecture: hidden widths and the layer parameters of each init scheme,
# in the order in which they are trained
_experiment_archs = {
    'small': (hidden_small, {'baseline': params_small_baseline, 'numer': params_small_lecun_numer,
                             'norm': params_small_lecun_norm, 'power': params_small_power}),
    'big': (hidden_big, {'baseline': params_big_baseline, 'numer': params_big_lecun_numer,
                         'norm': params_big_lecun_norm, 'power': params_big_power}),
}

# Name printed in the progress log for each init scheme
_init_labels = {'baseline': 'Baseline', 'numer': 'LeCun-Numerical',
                'norm': 'LeCun-Normalized', 'power': 'Power-law'}

# Initialize results dict
results = dict()

for func_name, function in func_dict.items():
    print(f"Running Experiments for {func_name}.")
    results[func_name] = {arch: dict() for arch in _experiment_archs}

    # Generate and split data
    x, y = generate_func_data(function, 2, N, seed)
    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)

    # Model input/output
    n_in, n_out = X_train.shape[1], y_train.shape[1]

    for arch, (hidden, arch_params) in _experiment_archs.items():
        layer_dims = [n_in, *hidden, n_out]

        print(f"\tTraining model with dimensions {layer_dims}.")

        # Repeated runs with different model seeds, for confidence
        for run in range(1, 6):

            results[func_name][arch][run] = dict()

            print(f"\t\tRun No. {run}.")

            for init, params in arch_params.items():
                model = _build_model(init, layer_dims, params, seed + run)
                train_losses = _train_losses(model, X_train, y_train, num_epochs)
                results[func_name][arch][run][init] = train_losses

                print(f"\t\t\t{_init_labels[init]} model: Final Loss = {train_losses[-1]:.2e}")

In [None]:
# Plotting
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
cmap = plt.get_cmap("Spectral")
spectral_points = np.linspace(0, 1, 12)
color_indices = [-3, 3, 1, -1]

init_types = ['baseline', 'numer', 'norm', 'power']
architectures = ['small', 'big']
func_names = list(results.keys())
func_plot_names = [r'$f_1(x,y)$', r'$f_2(x,y)$', r'$f_3(x,y)$', r'$f_4(x,y)$', r'$f_5(x,y)$']

# Row labels derived from the architecture settings used for training
arch_labels = {
    'small': rf'$G = {G_small}$, depth = {len(hidden_small)}, width = {hidden_small[0]}',
    'big': rf'$G = {G_big}$, depth = {len(hidden_big)}, width = {hidden_big[0]}',
}

colors = [cmap(spectral_points[i]) for i in color_indices]
custom_colors = dict(zip(init_types, colors))

n_rows, n_cols = len(architectures), len(func_names)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)

for col, func_name in enumerate(func_names):
    for row, arch in enumerate(architectures):
        ax = axes[row, col]
        arch_runs = results[func_name][arch]

        for init in init_types:
            # Stack the per-run loss curves: shape (num_runs, num_steps)
            runs = np.stack([np.asarray(arch_runs[run][init]) for run in arch_runs])

            # Mean and standard error across runs
            # NOTE(review): np.std defaults to ddof=0 (population std); kept to reproduce the paper's figure
            mean = runs.mean(axis=0)
            stderr = runs.std(axis=0) / np.sqrt(runs.shape[0])
            steps = np.arange(runs.shape[1])

            # Plot mean with stderr shaded area
            ax.plot(steps, mean, label=init, color=custom_colors[init])
            ax.fill_between(steps, mean - stderr, mean + stderr, alpha=0.3, color=custom_colors[init])

        ax.tick_params(axis='both', labelsize=14)

        # Labeling
        if row == 0:
            ax.set_title(func_plot_names[col], fontsize=18)
        if col == 0:
            ax.set_ylabel("Training Loss", fontsize=16, labelpad=10)
        if row == n_rows - 1:
            ax.set_xlabel("Training Iteration", fontsize=16, labelpad=10)
        if col == n_cols - 1:
            ax.text(1.05, 0.5, arch_labels[arch], transform=ax.transAxes,
                    fontsize=16, rotation=270, va='center', ha='left')

        ax.set_yscale('log')
        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.35)

# Construct legend manually
handles = [
    mlines.Line2D([], [], color=custom_colors['baseline'], label='Baseline', linewidth=3),
    mlines.Line2D([], [], color=custom_colors['numer'], label='LeCun–Numerical', linewidth=3),
    mlines.Line2D([], [], color=custom_colors['norm'], label='LeCun–Normalized', linewidth=3),
    mlines.Line2D([], [], color=custom_colors['power'], label='Power-Law', linewidth=3),
]

# Add global legend
fig.legend(handles=handles, loc="lower center", ncol=4, fontsize=18, frameon=False, bbox_to_anchor=(0.5, -0.05))

plt.subplots_adjust(hspace=0.2, wspace=0.2, bottom=0.1)

# Uncomment to export the figure
#fig.savefig("losses.pdf", bbox_inches='tight')

plt.show()
