In [None]:
import copy

import jax
import matplotlib.pyplot as plt
import numpy as np

# NOTE(review): fetch_split_data, randomize, init_network_params, run_sgd_monitoring_switch,
# run_sgd_epochs, run_ntk{1,2,3}_phase, evaluate_on_test_jax, aggregate_histories and
# aggregate_scalar_metrics are called unqualified below; import them explicitly from the
# module(s) that define them.
from hybrid_ntk import data_utils, utils, training

Fetch the Fashion-MNIST train, validation and test splits:

In [None]:
# Load the train/validation/test splits. Keep the dataset name in sync with
# `dataset_name` in the configuration cell, which is used to name the saved figures.
(X_train_full, X_val_full, X_test_full,
 Y_train_onehot_full, Y_val_onehot_full, Y_test_onehot_full) = fetch_split_data('Fashion-MNIST')

# Guard against misaligned splits: every input row needs exactly one label row.
for _split_name, _X_split, _Y_split in (
    ('train', X_train_full, Y_train_onehot_full),
    ('val', X_val_full, Y_val_onehot_full),
    ('test', X_test_full, Y_test_onehot_full),
):
    assert _X_split.shape[0] == _Y_split.shape[0], (
        f"{_split_name} split: {_X_split.shape[0]} inputs but {_Y_split.shape[0]} labels"
    )

Create the JAX random keys: one key for the scouting run and one key set per seed for the main experiment:

In [None]:
# Number of independent seeds for the main experiment. It is bound here (not only passed
# as a keyword) because the training loop also reads `num_seeds`.
num_seeds = 10
# `run_keys[seed]` is unpacked into the per-seed keys in the training loop;
# `scout_key_loop` is presumably consumed by the scouting run (not shown here).
scout_key_loop, run_keys = randomize(num_seeds=num_seeds)

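Configure the network architecture and hyperparameters (the parameters themselves are initialized per seed in the training loop):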

In [None]:
# --- Network architecture ---
# `num_classes` sizes the output layer and is passed to the SGD monitoring phase.
# Assumes one-hot labels of shape (n_samples, n_classes); the assert guards the rank.
assert Y_train_onehot_full.ndim == 2, "expected one-hot labels of shape (n_samples, n_classes)"
num_classes = Y_train_onehot_full.shape[1]

layer_dims = [X_train_full.shape[1], 100, 100, num_classes]  # input dim, two hidden widths, output dim
print(f"Network layer_dims: {layer_dims}")

# --- Hyperparameters ---

# Scouting Run Config
scouting_epochs = 15
scouting_lr = 0.007
scouting_method = 'param_norm' # 'param_norm' or 'ntk_norm'
# Later cells assume exactly one of these two values, so reject typos here.
assert scouting_method in ('param_norm', 'ntk_norm'), f"unknown scouting_method: {scouting_method!r}"

# Main Experiment Config
epochs_sgd_total_config = 15
batch_size_config = 128

# Switching condition configuration.
# The *_threshold values are defaults; the scouting-analysis cell may overwrite the
# entry for `scouting_method` with the threshold it suggests.
switch_config = {
    'method': scouting_method,
    'fixed_switch_epoch': 7,
    'param_norm_window': 2,    # Window size k
    'param_norm_threshold': 0.3253,
    'ntk_norm_window': 4,      # Window size k
    'ntk_norm_threshold': 253892,
}

lr_sgd_config = 0.007
lr_ntk_iterative_config = 0.007
lambda_ntk3_config = 0.01  # passed to NTK3 as `lambda_ntk3_reg`
taylor_order_config = 3    # Taylor order used by both NTK2 and NTK3

dataset_name = "Fashion-MNIST"  # used to name saved figures; keep in sync with fetch_split_data(...)

# Max samples for NTK computation phases (None = use the full training set).
max_ntk_samples_val = 2000   # training subset for the NTK1/NTK2/NTK3 phases
max_ntk_scouting_val = 200   # subset used to monitor NTK stability during the first SGD phase

print(
    f"SGD will run on {X_train_full.shape[0]} samples. NTK phases will use a subset of {max_ntk_samples_val} samples.")

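Analyse the scouting run and suggest a switch threshold (requires `avg_scout_history` and `first_run_history` from a scouting run, which is not executed in the cells shown here):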

In [None]:
# --- Process AVERAGED data for threshold calculation ---
# `avg_scout_history` / `first_run_history` must be produced by a scouting run executed
# before this cell.
avg_val_loss = np.array(avg_scout_history['val_loss']['mean'])
avg_norm_diff = np.array(avg_scout_history['norm_diff']['mean'])

# --- Process SINGLE-RUN data for plotting (consumed by the plotting cell) ---
epochs = np.arange(1, scouting_epochs + 1)
val_loss = np.array(first_run_history['val_loss'])
val_acc = np.array(first_run_history['val_acc'])
norm_diff = np.array(first_run_history['norm_diff'])

# Percentile of the post-elbow stability metric used as the suggested threshold.
# A larger percentile yields a larger threshold, which can only move the first epoch
# satisfying `avg_norm_diff < threshold` earlier (the "early switch" heuristic).
THRESHOLD_PERCENTILE = 75
# switch_config entry that receives the suggested threshold for each scouting method.
THRESHOLD_KEYS = {'param_norm': 'param_norm_threshold', 'ntk_norm': 'ntk_norm_threshold'}

# --- ROBUST THRESHOLD SUGGESTION FROM AVERAGED DATA ---
print("\n--- Suggested Thresholds Based on Averaged Scouting Run ---")
# Elbow: first epoch-to-epoch change smaller than 1% of the overall loss drop.
# Falls back to the midpoint when no such epoch exists (or the history is empty).
val_loss_diff = np.diff(avg_val_loss)
try:
    initial_drop = avg_val_loss[0] - np.min(avg_val_loss)
    elbow_epoch_idx = np.where(np.abs(val_loss_diff) < 0.01 * initial_drop)[0][0] + 1
except (IndexError, TypeError):
    elbow_epoch_idx = len(avg_val_loss) // 2
print(f"Average validation loss stabilized around epoch {elbow_epoch_idx + 1}.")

# Stability metric after the elbow; NaN entries are dropped before taking the percentile.
stabilization_norms = avg_norm_diff[elbow_epoch_idx:]
stabilization_norms = stabilization_norms[~np.isnan(stabilization_norms)]

recommended_threshold = 0.0
suggested_switch_epoch = scouting_epochs

threshold_key = THRESHOLD_KEYS.get(scouting_method)
if threshold_key is None:
    print(f"Unknown scouting_method {scouting_method!r}; expected one of {list(THRESHOLD_KEYS)}.")
elif stabilization_norms.size > 1:
    print(f"\nUsing 'Early Switch' ({THRESHOLD_PERCENTILE}th percentile) heuristic for '{scouting_method}'.")
    recommended_threshold = float(np.percentile(stabilization_norms, THRESHOLD_PERCENTILE))
    # Feed the suggestion into the configuration used by the main experiment.
    switch_config[threshold_key] = recommended_threshold

    # First (1-based) epoch at which the averaged metric drops below the threshold.
    below_threshold = np.flatnonzero(avg_norm_diff < recommended_threshold)
    if below_threshold.size:
        suggested_switch_epoch = int(below_threshold[0]) + 1
    else:
        print("Threshold was not met on averaged data. Defaulting to last epoch.")
else:
    print("Could not automatically determine a threshold.")

print(f"\nRecommended Threshold for '{scouting_method}': {recommended_threshold:.4f}")
print(f"This threshold suggests a switch at Epoch: {suggested_switch_epoch}")

Plot the first scouting run against the stability metric and save the figures:

In [None]:
# --- DYNAMIC FILENAME GENERATION ---
dataset_prefix = f"{dataset_name.lower().replace('-', '')}_sc_avg"
network_width = layer_dims[1]
metric_char = 'n' if scouting_method == 'ntk_norm' else 'p'
filename_template = f"{dataset_prefix}_{{metric_type}}_{metric_char}_{network_width}_{scouting_lr}_{recommended_threshold:.4f}.png"

# --- PLOTTING SETUP ---
# Define font sizes for better readability in a report
TITLE_FONT = 18
LABEL_FONT = 24
LEGEND_FONT = 22
TICK_FONT = 12

# --- PLOTTING (Using data from the FIRST run for visualization) ---
switch_idx = min(suggested_switch_epoch - 1, len(val_loss) - 1)
loss_at_switch = val_loss[switch_idx]
acc_at_switch = val_acc[switch_idx] * 100

if scouting_method == 'param_norm':
    metric_label, linestyle, marker, color = r'Parameter Norm Difference ($\epsilon$)', '--', 'x', 'tab:blue'
    window_val, title_suffix = switch_config['param_norm_window'], 'Parameter Stability'
    legend_label = fr'$||\theta_t - \theta_{{t-{window_val}}}||_F$'
else: # ntk_norm
    metric_label, linestyle, marker, color = r'NTK Stability ($\delta$)', ':', 's', 'tab:green'
    window_val, title_suffix = switch_config['ntk_norm_window'], 'NTK Stability'
    legend_label = fr'| $||\Theta_t - \Theta_0||_F - ||\Theta_{{t-{window_val}}}-\Theta_0||_F$ |'

# PLOT 1: Validation Loss vs. Stability
fig1, ax1 = plt.subplots(1, 1, figsize=(12, 7))
ax1.plot(epochs, val_loss, color='tab:red', marker='o', label='Validation Loss (Run 1)', zorder=5)
ax1.plot(suggested_switch_epoch, loss_at_switch, '*', color='magenta', markersize=20, label=f'Suggested Switch (Ep. {suggested_switch_epoch})', zorder=10, markeredgecolor='black')
ax1.set_ylabel('Validation Loss', color='tab:red', fontsize=LABEL_FONT)
ax1.tick_params(axis='y', labelcolor='tab:red', labelsize=TICK_FONT)
ax1_twin = ax1.twinx()
ax1_twin.plot(epochs, norm_diff, color=color, marker=marker, linestyle=linestyle, label=legend_label)
ax1_twin.set_ylabel(metric_label, color=color, fontsize=LABEL_FONT)
ax1_twin.tick_params(axis='y', labelcolor=color, labelsize=TICK_FONT)
ax1.set_xlabel('Epoch', fontsize=LABEL_FONT)
ax1.tick_params(axis='x', labelsize=TICK_FONT)
ax1.grid(True)
fig1.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes, fontsize=LEGEND_FONT)
fig1.tight_layout()
val_filename = filename_template.format(metric_type='val')
plt.savefig(val_filename)
print(f"\nPlot saved as {val_filename}")
plt.show()

# PLOT 2: Accuracy vs. Stability
fig2, ax2 = plt.subplots(1, 1, figsize=(12, 7))
ax2.plot(epochs, val_acc * 100, color='tab:purple', marker='o', label='Validation Accuracy (Run 1)', zorder=5)
ax2.plot(suggested_switch_epoch, acc_at_switch, '*', color='magenta', markersize=20, label=f'Suggested Switch (Ep. {suggested_switch_epoch})', zorder=10, markeredgecolor='black')
ax2.set_ylabel('Validation Accuracy (%)', color='tab:purple', fontsize=LABEL_FONT)
ax2.tick_params(axis='y', labelcolor='tab:purple', labelsize=TICK_FONT)
ax2_twin = ax2.twinx()
ax2_twin.plot(epochs, norm_diff, color=color, marker=marker, linestyle=linestyle, label=legend_label)
ax2_twin.set_ylabel(metric_label, color=color, fontsize=LABEL_FONT)
ax2_twin.tick_params(axis='y', labelcolor=color, labelsize=TICK_FONT)
ax2.set_xlabel('Epoch', fontsize=LABEL_FONT)
ax2.tick_params(axis='x', labelsize=TICK_FONT)
ax2.grid(True)
# --- CORRECTED LINE ---
fig2.legend(loc="lower left", bbox_to_anchor=(0.1, 0.1), bbox_transform=ax2.transAxes, fontsize=LEGEND_FONT)
fig2.tight_layout()
acc_filename = filename_template.format(metric_type='acc')
plt.savefig(acc_filename)
print(f"Plot saved as {acc_filename}")
plt.show()

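Run the main experiment for every seed. Each seed trains with SGD until the switch condition is met, then compares continued SGD against the NTK1, NTK2 and NTK3 phases, all started from the parameters at the switch point: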

In [None]:
# Per-phase results collected across seeds. Phases a seed skips (switch condition not met)
# get NaN test metrics; their histories and times are not appended.
all_runs_histories = {'SGD_Part1': [], 'SGD_Part2': [], 'NTK1_from_Switch': [], 'NTK2_from_Switch': [], 'NTK3_from_Switch': []}
all_runs_times = {'SGD (Part 1)': [], 'SGD (Part 2)': [], 'NTK1': [], 'NTK2': [], 'NTK3': []}
all_runs_test_metrics = {'Params_At_Switch': [], 'SGD_Full_Run': [], 'NTK1_from_Switch': [], 'NTK2_from_Switch': [], 'NTK3_from_Switch': []}
all_runs_switch_epochs = []

# One run per entry of `run_keys`; each entry unpacks into the seven per-seed keys below.
num_runs = len(run_keys)
n_train = X_train_full.shape[0]
sgd_batch_size = min(batch_size_config, n_train)


def _random_row_subset(key, n_rows, max_rows):
    """Return `max_rows` distinct row indices in [0, n_rows) drawn with `key`.

    Returns None when `max_rows` is None or not smaller than `n_rows`, meaning the
    full array should be used unchanged.
    """
    if max_rows is None or max_rows >= n_rows:
        return None
    return jax.random.choice(key, n_rows, shape=(max_rows,), replace=False)


for seed in range(num_runs):
    print(f"\n\n===== RUNNING EXPERIMENT FOR SEED {seed + 1}/{num_runs} =====")
    key_init_params, key_sgd_part1, key_sgd_part2, key_ntk_subset_shuffle, key_ntk1, key_ntk2, key_ntk3 = run_keys[seed]
    initial_model_params = init_network_params(layer_dims, key_init_params)

    # Two independently drawn row subsets (they may overlap): one for NTK stability
    # monitoring during the first SGD phase, one for the NTK training phases.
    key_monitor_subset, key_train_subset = jax.random.split(key_ntk_subset_shuffle)

    monitor_indices = _random_row_subset(key_monitor_subset, n_train, max_ntk_scouting_val)
    X_ntk_monitor_subset = X_train_full if monitor_indices is None else X_train_full[monitor_indices]

    train_indices = _random_row_subset(key_train_subset, n_train, max_ntk_samples_val)
    if train_indices is None:
        X_train_ntk_subset, Y_train_onehot_ntk_subset = X_train_full, Y_train_onehot_full
    else:
        X_train_ntk_subset = X_train_full[train_indices]
        Y_train_onehot_ntk_subset = Y_train_onehot_full[train_indices]

    # --- Initial SGD Phase (up to switch point) ---
    params_at_switch, history_sgd_part1, time_sgd_part1, epoch_at_switch = run_sgd_monitoring_switch(
        initial_model_params, X_train_full, Y_train_onehot_full, X_val_full, Y_val_onehot_full,
        max_sgd_epochs=epochs_sgd_total_config,
        batch_size=sgd_batch_size,
        lr_sgd=lr_sgd_config, key_sgd_loop=key_sgd_part1,
        switch_config=switch_config,
        X_ntk_monitor_subset=X_ntk_monitor_subset, num_classes=num_classes
    )
    all_runs_switch_epochs.append(epoch_at_switch)
    all_runs_histories['SGD_Part1'].append(history_sgd_part1)
    all_runs_times['SGD (Part 1)'].append(time_sgd_part1)
    test_l_s, test_a_s = evaluate_on_test_jax(params_at_switch, X_test_full, Y_test_onehot_full)
    all_runs_test_metrics['Params_At_Switch'].append({'loss': test_l_s, 'acc': test_a_s})

    if epoch_at_switch >= epochs_sgd_total_config:
        # No switch: the SGD run above already used the full budget, so it doubles as the
        # full SGD run and the NTK phases are recorded as NaN.
        print(f"WARNING: Switch condition not met for seed {seed+1}. Adjust thresholds or total epochs.")
        all_runs_histories['SGD_Part2'].append({'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []})
        all_runs_times['SGD (Part 2)'].append(0.0)
        all_runs_test_metrics['SGD_Full_Run'].append({'loss': test_l_s, 'acc': test_a_s})
        all_runs_test_metrics['NTK1_from_Switch'].append({'loss': np.nan, 'acc': np.nan})
        all_runs_test_metrics['NTK2_from_Switch'].append({'loss': np.nan, 'acc': np.nan})
        all_runs_test_metrics['NTK3_from_Switch'].append({'loss': np.nan, 'acc': np.nan})
        continue

    # Remaining epoch budget after the switch (positive here because of the guard above).
    num_ntk_iterations_to_run = epochs_sgd_total_config - epoch_at_switch

    # --- Continued SGD Phase ---
    params_after_full_sgd, history_sgd_part2, time_sgd_part2 = run_sgd_epochs(
        copy.deepcopy(params_at_switch), X_train_full, Y_train_onehot_full, X_val_full, Y_val_onehot_full,
        start_epoch_idx=epoch_at_switch, num_epochs_to_run=num_ntk_iterations_to_run,
        batch_size=sgd_batch_size,
        lr_sgd=lr_sgd_config, key_sgd_loop=key_sgd_part2, phase_label="SGD (Part 2)"
    )
    all_runs_histories['SGD_Part2'].append(history_sgd_part2)
    all_runs_times['SGD (Part 2)'].append(time_sgd_part2)
    test_l_full, test_a_full = evaluate_on_test_jax(params_after_full_sgd, X_test_full, Y_test_onehot_full)
    all_runs_test_metrics['SGD_Full_Run'].append({'loss': test_l_full, 'acc': test_a_full})

    # --- NTK Phases (each starts from a fresh copy of the parameters at the switch point) ---
    current_ntk_batch_size = min(batch_size_config, X_train_ntk_subset.shape[0])
    if num_ntk_iterations_to_run > 0:
        params_after_ntk1, history_ntk1, time_ntk1 = run_ntk1_phase(
            copy.deepcopy(params_at_switch), X_train_ntk_subset, Y_train_onehot_ntk_subset,
            X_val_full, Y_val_onehot_full, num_ntk_iterations_to_run,
            current_ntk_batch_size, lr_ntk_iterative_config, key_ntk1)
        all_runs_histories['NTK1_from_Switch'].append(history_ntk1)
        all_runs_times['NTK1'].append(time_ntk1)
        test_l_ntk1, test_a_ntk1 = evaluate_on_test_jax(params_after_ntk1, X_test_full, Y_test_onehot_full)
        all_runs_test_metrics['NTK1_from_Switch'].append({'loss': test_l_ntk1, 'acc': test_a_ntk1})

        params_after_ntk2, history_ntk2, time_ntk2 = run_ntk2_phase(
            copy.deepcopy(params_at_switch), X_train_ntk_subset, Y_train_onehot_ntk_subset,
            X_val_full, Y_val_onehot_full, epoch_at_switch, num_ntk_iterations_to_run,
            lr_ntk_iterative_config, taylor_order_ntk2=taylor_order_config)
        all_runs_histories['NTK2_from_Switch'].append(history_ntk2)
        all_runs_times['NTK2'].append(time_ntk2)
        test_l_ntk2, test_a_ntk2 = evaluate_on_test_jax(params_after_ntk2, X_test_full, Y_test_onehot_full)
        all_runs_test_metrics['NTK2_from_Switch'].append({'loss': test_l_ntk2, 'acc': test_a_ntk2})

    # NOTE(review): NTK3 receives the total epoch budget as T_factor, whereas NTK1/NTK2
    # receive the remaining `num_ntk_iterations_to_run` — confirm this is intended.
    params_after_ntk3, history_ntk3, time_ntk3 = run_ntk3_phase(
        copy.deepcopy(params_at_switch), X_train_ntk_subset, Y_train_onehot_ntk_subset,
        X_val_full, Y_val_onehot_full,
        lr_ntk=lr_ntk_iterative_config,
        lambda_ntk3_reg=lambda_ntk3_config,
        T_factor_ntk3=epochs_sgd_total_config,
        taylor_order_ntk3=taylor_order_config)
    all_runs_histories['NTK3_from_Switch'].append(history_ntk3)
    all_runs_times['NTK3'].append(time_ntk3)
    test_l_ntk3, test_a_ntk3 = evaluate_on_test_jax(params_after_ntk3, X_test_full, Y_test_onehot_full)
    all_runs_test_metrics['NTK3_from_Switch'].append({'loss': test_l_ntk3, 'acc': test_a_ntk3})

Average the histories, run times and test metrics over seeds:

In [None]:
# Metric keys aggregated for the SGD part 2 and NTK1/NTK2 histories (and read from NTK3).
HISTORY_METRIC_KEYS = ['train_loss', 'train_acc', 'val_loss', 'val_acc']

avg_histories = {}
# SGD part 1 is aggregated on validation metrics only.
avg_histories['SGD_Part1'] = aggregate_histories(all_runs_histories['SGD_Part1'], metric_keys=['val_loss', 'val_acc'])
for phase in ('SGD_Part2', 'NTK1_from_Switch', 'NTK2_from_Switch'):
    avg_histories[phase] = aggregate_histories(all_runs_histories[phase], metric_keys=list(HISTORY_METRIC_KEYS))

# NTK3 is summarised by one value per seed: the last recorded entry of each metric.
# Seeds whose NTK3 history is empty for a metric are skipped for that metric.
ntk3_histories = all_runs_histories['NTK3_from_Switch']
ntk3_final_values = {key: [h[key][-1] for h in ntk3_histories if h[key]] for key in HISTORY_METRIC_KEYS}
ntk3_final_train_losses = ntk3_final_values['train_loss']
ntk3_final_val_losses   = ntk3_final_values['val_loss']
ntk3_final_train_accs   = ntk3_final_values['train_acc']
ntk3_final_val_accs     = ntk3_final_values['val_acc']

avg_histories['NTK3_from_Switch_Aggregated'] = {
    key: aggregate_scalar_metrics(ntk3_final_values[key])
    for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc')
}

# Mean switch epoch over seeds (NaN when no seed has been run yet).
avg_switch_epoch = np.mean(all_runs_switch_epochs) if all_runs_switch_epochs else np.nan

avg_times = {phase: np.mean(times) for phase, times in all_runs_times.items() if times}

# Mean/std of test metrics per phase. NaN entries (NTK phases skipped when the switch
# condition was not met) are ignored by nanmean/nanstd.
avg_test_metrics = {}
for phase, metrics_list in all_runs_test_metrics.items():
    if not metrics_list:
        continue
    losses = [m.get('loss', np.nan) for m in metrics_list]
    accs = [m.get('acc', np.nan) for m in metrics_list]
    avg_test_metrics[phase] = {
        'loss_mean': np.nanmean(losses),
        'loss_std':  np.nanstd(losses),
        'acc_mean': np.nanmean(accs),
        'acc_std':  np.nanstd(accs)
    }

Summarise results:

In [None]:
# Report the averages computed in the previous cell (this cell does not recompute them).
print(f"Switch epoch (mean over {len(all_runs_switch_epochs)} seeds): {avg_switch_epoch:.2f}")

print("\nTest metrics (mean ± std over seeds; NaN entries from skipped phases ignored):")
for phase, stats in avg_test_metrics.items():
    print(f"  {phase:<18} loss {stats['loss_mean']:.4f} ± {stats['loss_std']:.4f} | "
          f"acc {stats['acc_mean']:.4f} ± {stats['acc_std']:.4f}")

print("\nMean time per phase (units as returned by the run_* helpers):")
for phase, mean_time in avg_times.items():
    print(f"  {phase:<14} {mean_time:.3f}")