This notebook provides the functions for running the SGD-to-NTK hybrid training strategy for classification tasks.

Import the necessary libraries. The code can run on a GPU; however, with large datasets it may not be possible to store the large kernel matrices in memory. If no GPU is found, the program falls back to the CPU.

In [None]:
import numpy as np
import jax
import jax.lib
import jax.numpy as jnp
from jax import grad, jit
from jax.nn import initializers
import matplotlib.pyplot as plt
from functools import partial
import time
import copy
from collections import deque
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler

# --- JAX Environment Check ---
# Report the installed JAX/jaxlib versions, the default backend and the
# visible devices, so that a CPU fallback is noticed before training starts.
print(f"JAX version: {jax.__version__}")
print(f"jaxlib version: {jax.lib.__version__}")
# `backend` is the name of the default JAX backend; it is compared against
# 'cpu' below to decide whether to print the warning.
backend = jax.default_backend()
print(f"Default backend: {backend}")
print(f"Available devices: {jax.devices()}")

# Running on CPU is not an error, but the user is warned and shown an example
# install command for a CUDA-enabled build.
if backend == 'cpu':
    print("\n*** WARNING: JAX is running on CPU. ***")
    print("To enable GPU, ensure a CUDA-enabled version of jaxlib is installed.")
    print("Example: pip install --upgrade \"jax[cuda12_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")


Functions for training the neural network. JAX makes it possible to handle the network's parameters and computations more efficiently.

In [None]:
def init_network_params(layer_dims, key):
    """Create the initial ``(W, b)`` pair for every layer of the network.

    Args:
        layer_dims: Layer widths from input to output. Each pair of
            consecutive entries ``(in_dim, out_dim)`` defines one layer.
        key: JAX PRNG key; it is split into one subkey per layer.

    Returns:
        A list with one ``(W, b)`` tuple per layer, where ``W`` has shape
        ``(in_dim, out_dim)`` (Glorot-normal) and ``b`` has shape
        ``(out_dim,)`` (zeros).
    """
    n_layers = len(layer_dims) - 1
    layer_keys = jax.random.split(key, n_layers)
    weight_init = initializers.glorot_normal()
    layer_shapes = zip(layer_dims[:-1], layer_dims[1:])
    return [
        (
            weight_init(layer_keys[idx], (fan_in, fan_out)),
            initializers.zeros(layer_keys[idx], (fan_out,)),
        )
        for idx, (fan_in, fan_out) in enumerate(layer_shapes)
    ]

@jit
def jax_relu(x):
    """ReLU activation: element-wise maximum of ``x`` and zero."""
    floor = 0
    return jnp.maximum(floor, x)

def jax_softmax(x, axis=-1):
    """Softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating, which
    keeps ``jnp.exp`` from overflowing on large inputs.
    """
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    weights = jnp.exp(shifted)
    normalizer = jnp.sum(weights, axis=axis, keepdims=True)
    return weights / normalizer

def jax_forward(params, x):
    """Forward pass through the network, returning the raw logits.

    Every layer except the last computes ``jax_relu(h @ W + b)``; the last
    layer is affine only (no activation).

    Args:
        params: List of ``(W, b)`` tuples, one per layer.
        x: Input that is right-multiplied by the first layer's ``W``.
    """
    hidden_layers = params[:-1]
    out_W, out_b = params[-1]

    h = x
    for W, b in hidden_layers:
        h = jax_relu(jnp.dot(h, W) + b)
    return jnp.dot(h, out_W) + out_b


def jax_update_params(params, grads, lr):
    """Apply one gradient-descent step to every ``(W, b)`` pair.

    Args:
        params: List of ``(W, b)`` tuples.
        grads: List of ``(dW, db)`` tuples, paired with ``params`` by position.
        lr: Step size.

    Returns:
        A new list of ``(W - lr * dW, b - lr * db)`` tuples.
    """
    stepped = []
    for (W, b), (dW, db) in zip(params, grads):
        stepped.append((W - lr * dW, b - lr * db))
    return stepped

@jit
def jax_compute_crossentropy_loss(params, x_batch, y_batch_one_hot):
    """Mean cross-entropy loss of the network on one batch.

    The logits from ``jax_forward`` go through ``log_softmax``; the
    log-probabilities selected by the one-hot targets are summed, negated and
    divided by the number of rows in ``x_batch``.
    """
    batch_len = x_batch.shape[0]
    targets = y_batch_one_hot.astype(jnp.float32)
    log_probs = jax.nn.log_softmax(jax_forward(params, x_batch), axis=-1)
    return -jnp.sum(targets * log_probs) / batch_len

# Jitted gradient of the loss with respect to `params` (argument 0).
jax_loss_grad_fn = jit(grad(jax_compute_crossentropy_loss, argnums=0))

def jax_predict_proba(params, x):
    """Return class probabilities: softmax applied to the network's logits."""
    logits = jax_forward(params, x)
    return jax_softmax(logits)

def compute_accuracy_jax(params, x, y_one_hot):
    """Fraction of rows whose predicted class matches the one-hot target.

    The predicted class is the argmax (axis 1) of the predicted
    probabilities; the true class is the argmax (axis 1) of ``y_one_hot``.
    """
    predicted_labels = jnp.argmax(jax_predict_proba(params, x), axis=1)
    true_labels = jnp.argmax(y_one_hot, axis=1)
    return jnp.mean(predicted_labels == true_labels)


Flatten and unflatten the parameters so that JAX functions can operate on a single parameter vector:

In [None]:
def flatten_params(params_list):
    """Flatten a parameter pytree into one 1-D vector.

    Returns:
        ``(flat_vector, treedef)``: the concatenation of every leaf (each
        ravelled, in ``tree_flatten`` order) and the tree definition needed
        to rebuild the original structure.
    """
    leaves, treedef = jax.tree_util.tree_flatten(params_list)
    ravelled = [jnp.asarray(leaf).ravel() for leaf in leaves]
    return jnp.concatenate(ravelled), treedef

def unflatten_params(flat_params_vec, treedef, shapes_and_dtypes_meta):
    """Rebuild a parameter pytree from a flat vector.

    Args:
        flat_params_vec: 1-D vector holding all leaves back to back.
        treedef: Tree definition to unflatten into.
        shapes_and_dtypes_meta: One ``(shape, dtype)`` pair per leaf, in the
            order the leaves are laid out in ``flat_params_vec``.

    Returns:
        The pytree produced by ``tree_unflatten(treedef, leaves)``.
    """
    rebuilt_leaves = []
    offset = 0
    for leaf_shape, leaf_dtype in shapes_and_dtypes_meta:
        # Number of scalars this leaf occupies in the flat vector.
        leaf_size = np.prod(leaf_shape, dtype=int)
        chunk = flat_params_vec[offset: offset + leaf_size]
        rebuilt_leaves.append(jnp.asarray(chunk, dtype=leaf_dtype).reshape(leaf_shape))
        offset = offset + leaf_size
    return jax.tree_util.tree_unflatten(treedef, rebuilt_leaves)

def get_shapes_and_dtypes(params_list):
    """Return ``(shape, dtype)`` for every leaf of the pytree, in flatten order."""
    leaves = jax.tree_util.tree_flatten(params_list)[0]
    meta = []
    for leaf in leaves:
        meta.append((leaf.shape, leaf.dtype))
    return meta

def single_sample_forward_flat_params(flat_params_vec, single_x_input, treedef, shapes_and_dtypes_meta):
    """Logits for a single sample, given the parameters as one flat vector.

    The flat vector is unflattened back into the parameter structure, the
    sample is reshaped to a batch of one row, and the first (only) row of the
    resulting logits is returned.
    """
    params = unflatten_params(flat_params_vec, treedef, shapes_and_dtypes_meta)
    as_batch = single_x_input.reshape(1, -1)
    batch_logits = jax_forward(params, as_batch)
    return batch_logits[0]


Computes the Frobenius norm of the difference between two parameter lists.

In [None]:
def compute_params_diff_norm(params1, params2):
    """Euclidean norm of the difference between two parameter lists.

    Both arguments are lists of ``(W, b)`` tuples paired by position. The
    squared differences of all weights and biases are summed per layer, then
    across layers, and the square root of the total is returned.
    """
    per_layer_sq = []
    for (w_a, b_a), (w_b, b_b) in zip(params1, params2):
        weight_term = jnp.sum((w_a - w_b)**2)
        bias_term = jnp.sum((b_a - b_b)**2)
        per_layer_sq.append(weight_term + bias_term)
    total_sq = jnp.sum(jnp.array(per_layer_sq))
    return jnp.sqrt(total_sq)


One-hot encode the output labels:

In [None]:
def one_hot(y, num_classes):
    """One-hot encode integer class labels.

    ``y`` is converted to an integer array and flattened; row ``i`` of the
    result is row ``y[i]`` of the ``num_classes`` x ``num_classes`` identity
    matrix.
    """
    labels = np.asarray(y, dtype=int).reshape(-1)
    identity = np.eye(num_classes)
    return identity[labels]


Computes the empirical NTK for a multi-class output.

In [None]:
@partial(jit, static_argnames=['num_classes_for_k0'])
def compute_empirical_ntk_k0(J_all_at_theta0, num_classes_for_k0):
    """Empirical NTK Gram matrix from a stack of per-sample Jacobians.

    Args:
        J_all_at_theta0: 3-D array indexed ``[sample, class, parameter]``
            (the axes named ``a``/``b``, ``c``, ``p`` in the einsum).
        num_classes_for_k0: Static argument; not used in the computation.

    Returns:
        Matrix ``K`` with ``K[a, b] = sum_{c, p} J[a, c, p] * J[b, c, p]``.
    """
    jac = J_all_at_theta0
    return jnp.einsum('acp,bcp->ab', jac, jac)


Run SGD training for a number of epochs:

In [None]:
def run_sgd_epochs(params_initial, X_train_sgd, Y_train_onehot_sgd, X_val_full, Y_val_onehot_full,
                   start_epoch_idx, num_epochs_to_run, # start_epoch_idx is 0-based
                   batch_size, lr_sgd, key_sgd_loop, phase_label="SGD"):
    """Run mini-batch SGD for a fixed number of epochs.

    Each epoch reshuffles the training set with a fresh subkey, applies one
    gradient step per mini-batch, then logs loss and accuracy on the full
    training and validation sets.

    Args:
        params_initial: Parameters to start from (list of ``(W, b)`` tuples).
        X_train_sgd, Y_train_onehot_sgd: Training inputs and one-hot targets.
        X_val_full, Y_val_onehot_full: Validation inputs and one-hot targets.
        start_epoch_idx: 0-based index of the first epoch; only used for the
            epoch numbers shown in the log output.
        num_epochs_to_run: Number of epochs to execute.
        batch_size: Requested mini-batch size (capped at the training-set size).
        lr_sgd: Learning rate.
        key_sgd_loop: JAX PRNG key used for shuffling.
        phase_label: Label used in the printed messages.

    Returns:
        ``(params, history, elapsed_seconds)`` where ``history`` maps
        ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and ``'val_acc'`` to
        per-epoch lists of floats.
    """
    print(f"\n--- Starting {phase_label} Training Phase (Epochs {start_epoch_idx + 1} to {start_epoch_idx + num_epochs_to_run}) ---")

    n_samples = X_train_sgd.shape[0]
    history = {metric: [] for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}
    params = params_initial  # Start from provided parameters

    sgd_phase_start_time = time.time()
    # A batch can never be larger than the training set itself.
    step = min(batch_size, n_samples)

    # Iterate directly over the 1-based epoch numbers used for display.
    first_display_epoch = start_epoch_idx + 1
    for actual_epoch_num_display in range(first_display_epoch, first_display_epoch + num_epochs_to_run):
        key_sgd_loop, shuffle_key = jax.random.split(key_sgd_loop)
        order = jax.random.permutation(shuffle_key, n_samples)

        # --- Mini-batch update loop ---
        for lo in range(0, n_samples, step):
            batch_idx = order[lo:lo + step]
            X_batch = X_train_sgd[batch_idx]
            Y_batch = Y_train_onehot_sgd[batch_idx]
            params = jax_update_params(params, jax_loss_grad_fn(params, X_batch, Y_batch), lr_sgd)

        # --- Full-dataset metric logging (end of epoch) ---
        train_loss = float(jax_compute_crossentropy_loss(params, X_train_sgd, Y_train_onehot_sgd))
        train_acc = float(compute_accuracy_jax(params, X_train_sgd, Y_train_onehot_sgd))
        val_loss = float(jax_compute_crossentropy_loss(params, X_val_full, Y_val_onehot_full))
        val_acc = float(compute_accuracy_jax(params, X_val_full, Y_val_onehot_full))

        epoch_metrics = (
            ('train_loss', train_loss),
            ('train_acc', train_acc),
            ('val_loss', val_loss),
            ('val_acc', val_acc),
        )
        for metric, value in epoch_metrics:
            history[metric].append(value)

        print(f"{phase_label} Epoch {actual_epoch_num_display} - Train L: {train_loss:.4f}, A: {train_acc*100:.2f}% | Val L: {val_loss:.4f}, A: {val_acc*100:.2f}%")

    sgd_phase_time = time.time() - sgd_phase_start_time
    print(f"{phase_label} phase ({num_epochs_to_run} epochs) took {sgd_phase_time:.2f} seconds.")
    return params, history, sgd_phase_time


Monitor the switching condition while running SGD training. Stop and return the parameters when the condition is met.

In [None]:
def run_sgd_monitoring_switch(
    params_initial, X_train_sgd, Y_train_onehot_sgd, X_val_full, Y_val_onehot_full,
    max_sgd_epochs, batch_size, lr_sgd, key_sgd_loop,
    switch_config,
    X_ntk_monitor_subset, num_classes
):
    """Train with mini-batch SGD until a switch condition fires or the epoch budget is used up.

    After every epoch the full-dataset train/validation metrics are logged and
    the switch condition selected by ``switch_config['method']`` is checked:

    * ``'fixed_epoch'``: switch once the 1-based epoch number reaches
      ``switch_config['fixed_switch_epoch']``.
    * ``'param_norm'``: switch when the norm of the parameter change over the
      last ``param_norm_window`` epochs (default 3) drops below
      ``switch_config['param_norm_threshold']``.
    * ``'ntk_norm'``: track the Frobenius distance between the current
      empirical NTK (on ``X_ntk_monitor_subset``) and the initial one; switch
      when that distance changed by less than
      ``switch_config['ntk_norm_threshold']`` over the last
      ``ntk_norm_window`` epochs (default 3).

    Any other method value never triggers a switch, so training runs for
    ``max_sgd_epochs`` epochs and the run is reported as flagged.

    Args:
        params_initial: Parameters to start from (list of ``(W, b)`` tuples).
        X_train_sgd, Y_train_onehot_sgd: Training inputs and one-hot targets.
        X_val_full, Y_val_onehot_full: Validation inputs and one-hot targets.
        max_sgd_epochs: Upper bound on the number of epochs.
        batch_size: Requested mini-batch size (capped at the training-set size).
        lr_sgd: Learning rate.
        key_sgd_loop: JAX PRNG key used for shuffling.
        switch_config: Dict holding ``'method'`` and the method-specific keys
            listed above.
        X_ntk_monitor_subset: Inputs on which the empirical NTK is evaluated;
            only read by the ``'ntk_norm'`` method.
        num_classes: Forwarded to ``compute_empirical_ntk_k0``.

    Returns:
        ``(params, history, elapsed_seconds, epoch_at_switch)``.
        ``epoch_at_switch`` is the 1-based epoch at which the condition fired,
        or ``max_sgd_epochs`` if it never did.
    """
    method = switch_config['method']
    print(f"\n--- Starting SGD Phase (Monitoring for Switch using '{method}') ---")
    params = params_initial
    N_train_sgd = X_train_sgd.shape[0]

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    sgd_phase_start_time = time.time()
    current_batch_size = min(batch_size, N_train_sgd)

    # --- Switch condition specific setup ---
    # `k` is the look-back window in epochs. It stays None for methods without
    # a window ('fixed_epoch'), so the switch message below never touches an
    # unbound name.
    k = None
    if method == 'param_norm':
        k = switch_config.get('param_norm_window', 3)
        # k + 1 snapshots are needed to compare "now" against "k epochs ago".
        param_history = deque(maxlen=k + 1)
        param_history.append(copy.deepcopy(params))

    elif method == 'ntk_norm':
        k = switch_config.get('ntk_norm_window', 3)
        ntk_total_diff_history = deque(maxlen=k + 1)

        # Per-sample Jacobian of the logits w.r.t. the flat parameter vector,
        # vmapped over the monitoring inputs.
        _, treedef = flatten_params(params)
        shapes_meta = get_shapes_and_dtypes(params)
        partial_apply_fn = partial(single_sample_forward_flat_params, treedef=treedef, shapes_and_dtypes_meta=shapes_meta)
        jac_fn_single = jax.jacrev(partial_apply_fn, argnums=0)
        J_vmap_fn = jit(jax.vmap(lambda p, x: jac_fn_single(p, x), in_axes=(None, 0), out_axes=0))
        params_flat_initial, _ = flatten_params(params_initial)
        J_initial = J_vmap_fn(params_flat_initial, X_ntk_monitor_subset)
        K_initial = compute_empirical_ntk_k0(J_initial, num_classes)
        # Distance of the initial kernel to itself.
        ntk_total_diff_history.append(0.0)

    epoch_at_switch = max_sgd_epochs
    # Defined before the loop so the post-loop check also works when
    # max_sgd_epochs == 0 and the loop body never runs.
    switch_now = False

    for epoch in range(max_sgd_epochs):
        actual_epoch_num_display = epoch + 1
        epoch_start_time = time.time()
        key_sgd_loop, subkey_perm = jax.random.split(key_sgd_loop)
        indices = jax.random.permutation(subkey_perm, N_train_sgd)

        for i in range(0, N_train_sgd, current_batch_size):
            X_batch = X_train_sgd[indices[i:i + current_batch_size]]
            Y_batch = Y_train_onehot_sgd[indices[i:i + current_batch_size]]
            grads = jax_loss_grad_fn(params, X_batch, Y_batch)
            params = jax_update_params(params, grads, lr_sgd)

        # Log metrics
        train_loss = float(jax_compute_crossentropy_loss(params, X_train_sgd, Y_train_onehot_sgd))
        train_acc = float(compute_accuracy_jax(params, X_train_sgd, Y_train_onehot_sgd))
        val_loss = float(jax_compute_crossentropy_loss(params, X_val_full, Y_val_onehot_full))
        val_acc = float(compute_accuracy_jax(params, X_val_full, Y_val_onehot_full))

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"SGD (Monitoring) Epoch {actual_epoch_num_display} - Train L: {train_loss:.4f}, A: {train_acc*100:.2f}% | Val L: {val_loss:.4f}, A: {val_acc*100:.2f}%", end='')

        # --- Check switch condition ---
        if method == 'fixed_epoch':
            if actual_epoch_num_display >= switch_config['fixed_switch_epoch']:
                switch_now = True
        elif method == 'param_norm':
            param_history.append(copy.deepcopy(params))
            # Only compare once the window is full.
            if len(param_history) == k + 1:
                diff_norm = compute_params_diff_norm(param_history[-1], param_history[0])
                print(f" | Param diff norm (k={k}) = {diff_norm:.6f}", end='')
                if diff_norm < switch_config['param_norm_threshold']:
                    switch_now = True

        elif method == 'ntk_norm':
            params_flat_current, _ = flatten_params(params)
            J_current = J_vmap_fn(params_flat_current, X_ntk_monitor_subset)
            K_current = compute_empirical_ntk_k0(J_current, num_classes)
            ntk_diff_total = float(jnp.linalg.norm(K_current - K_initial, 'fro'))
            ntk_total_diff_history.append(ntk_diff_total)

            # Only compare once the window is full.
            if len(ntk_total_diff_history) == k + 1:
                diff_norm = abs(ntk_total_diff_history[-1] - ntk_total_diff_history[0])
                print(f" | NTK stability (k={k}) = {diff_norm:.4f}", end='')
                if diff_norm < switch_config['ntk_norm_threshold']:
                    switch_now = True

        print(f" (took {time.time() - epoch_start_time:.2f}s)")
        if switch_now:
            epoch_at_switch = actual_epoch_num_display
            # The window is only mentioned for methods that actually have one.
            window_note = f" over a {k}-epoch window" if k is not None else ""
            print(f">>> Switching condition '{method}' met at epoch {epoch_at_switch}{window_note}. <<<")
            break

    if not switch_now:
        # The loop ran to completion (or not at all) without a switch.
        epoch_at_switch = max_sgd_epochs
        print(f">>> Max SGD epochs ({max_sgd_epochs}) reached without meeting switch condition. This run will be flagged. <<<")

    sgd_phase_time = time.time() - sgd_phase_start_time
    print(f"SGD monitoring phase ({epoch_at_switch} epochs) took {sgd_phase_time:.2f} seconds.")

    final_history = {name: values[:epoch_at_switch] for name, values in history.items()}
    return params, final_history, sgd_phase_time, epoch_at_switch


Run a full SGD training pass (a scouting run) that records the monitored quantity at every epoch, to help find the switching point:

In [None]:
def run_sgd_scouting(
    params_initial, X_train_sgd, Y_train_onehot_sgd, X_val_full, Y_val_onehot_full,
    scouting_epochs, batch_size, lr_sgd, key_sgd_loop,
    switch_method, # 'param_norm' or 'ntk_norm'
    X_ntk_monitor_subset, num_classes,
    param_norm_window, ntk_norm_window
):
    """Run SGD for a fixed number of epochs and record the switch metric at every epoch.

    No switching happens here; the per-epoch values are only recorded and
    returned.

    Args:
        params_initial: Parameters to start from (list of ``(W, b)`` tuples).
        X_train_sgd, Y_train_onehot_sgd: Training inputs and one-hot targets.
        X_val_full, Y_val_onehot_full: Validation inputs and one-hot targets.
        scouting_epochs: Number of epochs to run.
        batch_size: Requested mini-batch size (capped at the training-set size).
        lr_sgd: Learning rate.
        key_sgd_loop: JAX PRNG key used for shuffling.
        switch_method: ``'param_norm'`` or ``'ntk_norm'``.
        X_ntk_monitor_subset: Inputs on which the empirical NTK is evaluated;
            only read by the ``'ntk_norm'`` method.
        num_classes: Forwarded to ``compute_empirical_ntk_k0``.
        param_norm_window: Look-back window (epochs) for ``'param_norm'``.
        ntk_norm_window: Look-back window (epochs) for ``'ntk_norm'``.

    Returns:
        Dict with per-epoch lists ``'val_loss'``, ``'val_acc'`` and
        ``'norm_diff'``. ``'norm_diff'`` is NaN for epochs at which the
        look-back window is not yet full.

    Raises:
        ValueError: If ``switch_method`` is not one of the supported methods.
    """
    # Fail fast: an unsupported method would otherwise only surface as a
    # NameError in the log line, after a full epoch of training.
    if switch_method == 'param_norm':
        k = param_norm_window
    elif switch_method == 'ntk_norm':
        k = ntk_norm_window
    else:
        raise ValueError(
            f"Unsupported switch_method {switch_method!r}; expected 'param_norm' or 'ntk_norm'."
        )

    print(f"\n--- Starting SGD Scouting Run for {scouting_epochs} Epochs (Method: {switch_method}) ---")
    params = params_initial
    N_train_sgd = X_train_sgd.shape[0]
    history = {'val_loss': [], 'val_acc': [], 'norm_diff': []}
    scouting_start_time = time.time()
    current_batch_size = min(batch_size, N_train_sgd)

    # --- Monitoring setup based on the chosen method ---
    if switch_method == 'param_norm':
        # k + 1 snapshots are needed to compare "now" against "k epochs ago".
        param_history = deque(maxlen=k + 1)
        param_history.append(copy.deepcopy(params))

    else:  # 'ntk_norm'
        ntk_total_diff_history = deque(maxlen=k + 1)

        # Per-sample Jacobian of the logits w.r.t. the flat parameter vector,
        # vmapped over the monitoring inputs.
        _, treedef = flatten_params(params)
        shapes_meta = get_shapes_and_dtypes(params)
        partial_apply_fn = partial(single_sample_forward_flat_params, treedef=treedef, shapes_and_dtypes_meta=shapes_meta)
        jac_fn_single = jax.jacrev(partial_apply_fn, argnums=0)
        J_vmap_fn = jit(jax.vmap(lambda p, x: jac_fn_single(p, x), in_axes=(None, 0), out_axes=0))
        params_flat_initial, _ = flatten_params(params_initial)
        J_initial = J_vmap_fn(params_flat_initial, X_ntk_monitor_subset)
        K_initial = compute_empirical_ntk_k0(J_initial, num_classes)
        # Distance of the initial kernel to itself.
        ntk_total_diff_history.append(0.0)

    for epoch in range(scouting_epochs):
        key_sgd_loop, subkey_perm = jax.random.split(key_sgd_loop)
        indices = jax.random.permutation(subkey_perm, N_train_sgd)

        for i in range(0, N_train_sgd, current_batch_size):
            X_batch = X_train_sgd[indices[i:i + current_batch_size]]
            Y_batch = Y_train_onehot_sgd[indices[i:i + current_batch_size]]
            grads = jax_loss_grad_fn(params, X_batch, Y_batch)
            params = jax_update_params(params, grads, lr_sgd)

        # Log performance metrics
        history['val_loss'].append(float(jax_compute_crossentropy_loss(params, X_val_full, Y_val_onehot_full)))
        val_acc = float(compute_accuracy_jax(params, X_val_full, Y_val_onehot_full))
        history['val_acc'].append(val_acc)

        # NaN marks epochs where the look-back window is not yet full.
        norm_diff = np.nan
        if switch_method == 'param_norm':
            param_history.append(copy.deepcopy(params))
            if len(param_history) == k + 1:
                norm_diff = float(compute_params_diff_norm(param_history[-1], param_history[0]))

        else:  # 'ntk_norm'
            params_flat_current, _ = flatten_params(params)
            J_current = J_vmap_fn(params_flat_current, X_ntk_monitor_subset)
            K_current = compute_empirical_ntk_k0(J_current, num_classes)
            ntk_diff_total = float(jnp.linalg.norm(K_current - K_initial, 'fro'))

            ntk_total_diff_history.append(ntk_diff_total)
            if len(ntk_total_diff_history) == k + 1:
                norm_diff = abs(ntk_total_diff_history[-1] - ntk_total_diff_history[0])

        norm_diff = float(norm_diff)
        history['norm_diff'].append(norm_diff)

        print(f"Scouting Epoch {epoch + 1}/{scouting_epochs} - Val Acc: {val_acc*100:.2f}% | Norm Diff (k={k}): {norm_diff:.4f}")

    print(f"Scouting run took {time.time() - scouting_start_time:.2f} seconds.")
    return history


Run NTK 1 phase:

In [None]:
def run_ntk1_phase(params_sgd, X_train_ntk, Y_train_onehot_ntk, X_val_full, Y_val_onehot_full,
                   ntk_epochs, batch_size, lr_ntk, key_ntk_loop):
    """Run the "NTK 1" phase: full-batch updates using Jacobians frozen at the incoming parameters.

    The parameters handed over from SGD (``params_sgd``) are flattened into
    ``theta_0``. In every iteration the per-sample Jacobians are evaluated at
    ``theta_0`` (never at the current parameters), contracted with the
    residual ``softmax(logits at theta_k) - targets`` over all mini-batches,
    and the accumulated vector is applied as one parameter update.

    Args:
        params_sgd: Starting parameters (list of ``(W, b)`` tuples).
        X_train_ntk, Y_train_onehot_ntk: Training inputs and one-hot targets.
        X_val_full, Y_val_onehot_full: Validation inputs and one-hot targets.
        ntk_epochs: Number of iterations (one parameter update each).
        batch_size: Mini-batch size used while accumulating the update.
        lr_ntk: Learning rate; scaled by ``2 / N`` before use.
        key_ntk_loop: JAX PRNG key used for shuffling.

    Returns:
        ``(final_params, history, elapsed_seconds)`` where ``history`` maps
        ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and ``'val_acc'`` to
        per-iteration lists of floats.
    """
    print("\n--- Starting NTK 1 Phase ---")
    n_samples = X_train_ntk.shape[0]

    # theta_0 is the fixed linearization point; theta_k is the running iterate.
    theta_0_flat, treedef_0 = flatten_params(params_sgd)
    shapes_meta_0 = get_shapes_and_dtypes(params_sgd)
    theta_k_flat = jnp.copy(theta_0_flat)

    # Single-sample forward pass as a function of (flat params, x).
    forward_one = partial(single_sample_forward_flat_params, treedef=treedef_0, shapes_and_dtypes_meta=shapes_meta_0)
    jac_one = jax.jacrev(forward_one, argnums=0)
    # Jacobian with the parameters pinned to theta_0, then batched over x.
    jac_one_at_theta0 = jit(lambda single_x: jac_one(theta_0_flat, single_x))
    jac_batch_at_theta0 = jit(jax.vmap(jac_one_at_theta0, in_axes=0, out_axes=0))
    # Batched logits at an arbitrary flat parameter vector.
    logits_batch = jit(jax.vmap(forward_one, in_axes=(None, 0), out_axes=0))

    history = {metric: [] for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}
    ntk1_start_time = time.time()

    for ntk_iter in range(ntk_epochs):
        key_ntk_loop, shuffle_key = jax.random.split(key_ntk_loop)
        order = jax.random.permutation(shuffle_key, n_samples)

        # Sum of J(theta_0)^T @ residual over all mini-batches of this pass.
        accumulated = jnp.zeros_like(theta_k_flat)
        for lo in range(0, n_samples, batch_size):
            batch_idx = order[lo:lo + batch_size]
            X_batch = X_train_ntk[batch_idx]
            Y_batch = Y_train_onehot_ntk[batch_idx]
            jac_b = jac_batch_at_theta0(X_batch)
            residual = jax_softmax(logits_batch(theta_k_flat, X_batch)) - Y_batch
            accumulated = accumulated + jnp.einsum('bcp,bc->p', jac_b, residual)

        # NOTE(review): the 2.0 factor is kept as-is; confirm it is intended
        # for a softmax residual.
        step_scale = (2.0 * lr_ntk) / n_samples
        theta_k_flat = theta_k_flat - step_scale * accumulated

        params_now = unflatten_params(theta_k_flat, treedef_0, shapes_meta_0)
        train_loss = float(jax_compute_crossentropy_loss(params_now, X_train_ntk, Y_train_onehot_ntk))
        train_acc = float(compute_accuracy_jax(params_now, X_train_ntk, Y_train_onehot_ntk))
        val_loss = float(jax_compute_crossentropy_loss(params_now, X_val_full, Y_val_onehot_full))
        val_acc = float(compute_accuracy_jax(params_now, X_val_full, Y_val_onehot_full))

        iter_metrics = (
            ('train_loss', train_loss),
            ('train_acc', train_acc),
            ('val_loss', val_loss),
            ('val_acc', val_acc),
        )
        for metric, value in iter_metrics:
            history[metric].append(value)
        print(f"NTK 1 Iter {ntk_iter + 1}/{ntk_epochs} - Train (full) L: {train_loss:.4f}, A: {train_acc*100:.2f}% | Val (full) L: {val_loss:.4f}, A: {val_acc*100:.2f}%")

    ntk1_time = time.time() - ntk1_start_time
    print(f"NTK 1 phase ({ntk_epochs} iterations) took {ntk1_time:.2f} seconds.")
    final_params_ntk1 = unflatten_params(theta_k_flat, treedef_0, shapes_meta_0)
    return final_params_ntk1, history, ntk1_time


Two functions for computing the matrix exponential, either using a Taylor expansion or a direct calculation:

In [None]:
def matrix_exp_taylor(A, order=5):
    """Approximate exp(A) by a truncated Taylor series.

    Computes ``I + A + A^2/2! + ... + A^order/order!``.

    Args:
        A: Square 2-D matrix.
        order: Highest power of ``A`` included in the sum.

    Raises:
        ValueError: If ``A`` is not square.
    """
    n_rows, n_cols = A.shape[0], A.shape[1]
    if n_rows != n_cols:
        raise ValueError("Matrix must be square for exponential.")

    identity = jnp.eye(n_rows, dtype=A.dtype)
    series_sum = identity
    power = identity      # running A^term
    factorial = 1.0       # running term!

    for term in range(1, order + 1):
        power = jnp.dot(power, A)
        factorial = factorial * term
        series_sum = series_sum + power / factorial
    return series_sum

USE_JAX_EXPM = True # Set to False to use Taylor expansion

def compute_matrix_exp(A, taylor_order=5):
    """Compute the matrix exponential of ``A``.

    When ``USE_JAX_EXPM`` is true, ``jax.scipy.linalg.expm`` is used; if that
    call (or its import) fails for any reason, a message is printed and the
    truncated Taylor series is used instead. When ``USE_JAX_EXPM`` is false,
    the Taylor series is used directly.

    Args:
        A: Square matrix.
        taylor_order: Order of the Taylor expansion used on the fallback path.
    """
    if not USE_JAX_EXPM:
        return matrix_exp_taylor(A, order=taylor_order)

    try:
        # Import the submodule explicitly: a plain `import jax` is not
        # guaranteed to expose `jax.scipy.linalg` as an attribute, and a
        # failure here would silently downgrade to the Taylor approximation.
        from jax.scipy.linalg import expm
        return expm(A)
    except Exception as e:
        # Deliberately broad: this is a best-effort fast path with a fallback.
        print(f"jax.scipy.linalg.expm failed: {e}. Falling back to Taylor expansion.")
        return matrix_exp_taylor(A, order=taylor_order)