This notebook prepares the data for the subsequent notebook `1-Step-Analyze-Baselines.ipynb`, which generates the figure illustrating LLM one-step prediction performance against repeat-based baselines, as described in Supplementary Material Section 7.

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import random
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
from tqdm import tqdm

from data_processing import (
    SimpleSerializerSettings,
    scale_2d_array,
    unscale_2d_array,
    serialize_2d_integers,
    deserialize_2d_integers,
)
from allen_cahn_equation import (
    compute_exact_solution_random_ic_vary_Nx,
    visualize_spline_ic,
    plot_both_grids,
)

# --- Experiment configuration ----------------------------------------------
# L, k and T are forwarded unchanged to the spline and solver helpers
# (presumably domain length, an equation parameter and the final time --
# TODO confirm against allen_cahn_equation).
L = 2
k = 0.001
T = 0.5
seed = 42  # not referenced again in this script; every run reseeds with its own index
n_seeds = 50  # number of independent random initial conditions
settings = SimpleSerializerSettings(space_sep=",", time_sep=";")
Nx = 14  # length of every random initial condition drawn below
all_Nt_values = range(2, 41, 2)  # Nt = 2, 4, ..., 40


def _draw_initial_condition(seed_index):
    """Draw one reproducible random initial condition and build its spline.

    Both the stdlib and NumPy RNGs are reseeded with ``seed_index`` before
    ``Nx`` uniform values in [-0.5, 0.5) are drawn.

    Returns:
        ``(snapshot, spline)`` where ``snapshot`` is a copy of the drawn values
        taken before they are handed to ``visualize_spline_ic`` and ``spline``
        is the second object that helper returns.
    """
    random.seed(seed_index)
    np.random.seed(seed_index)
    values = np.random.uniform(-0.5, 0.5, size=Nx)
    snapshot = values.copy()
    figure, spline = visualize_spline_ic(L, Nx, values)
    plt.close(figure)  # only the spline object is kept; the figure is discarded
    return snapshot, spline


# One (initial condition, spline) pair per seed index, in seed order.
_seeded_draws = [_draw_initial_condition(index) for index in range(n_seeds)]
stored_initial_conditions = [snapshot for snapshot, _ in _seeded_draws]
stored_spline_objects = [spline for _, spline in _seeded_draws]

stored_initial_conditions_array = np.array(stored_initial_conditions)

# ---------------------------------------------------------------------------
# Time-discretization sweep (varying Nt, fixed Nx).
#
# For every Nt and every stored spline, the reference solution is computed and
# two trivial one-step predictors of the final time slice u_exact[Nt] are
# scored against the scale/unscale round trip of that slice:
#   * persistence: repeat the whole previous slice u_exact[Nt - 1];
#   * repeat-last: fill the slice with the last spatial value of the
#     (round-tripped) previous slice.
# The serialize/deserialize round trip of the full solution is scored too.
# Per-Nt means and sample standard deviations (ddof=1) over the seeds are
# collected in the lists below and converted to arrays at the end.
# ---------------------------------------------------------------------------
persistence_final_max_diff = []
persistence_final_rmse = []
persistence_final_max_diff_std = []
persistence_final_rmse_std = []
repeat_last_final_max_diff = []
repeat_last_final_rmse = []
repeat_last_final_max_diff_std = []
repeat_last_final_rmse_std = []
temporal_baseline_max_errors = []
temporal_baseline_rmse_errors = []

for Nt in tqdm(all_Nt_values):
    seed_max_diffs_persistence = []
    seed_rmses_persistence = []
    seed_max_diffs_repeat_last = []
    seed_rmses_repeat_last = []
    seed_baseline_max_errors = []
    seed_baseline_rmse_errors = []

    for seed_idx in range(n_seeds):
        cs = stored_spline_objects[seed_idx]

        u_exact = compute_exact_solution_random_ic_vary_Nx(L, k, T, Nx, Nt, spline_obj=cs)
        u_exact_scaled, vmin_exact, vmax_exact = scale_2d_array(u_exact)
        u_exact_serialized = serialize_2d_integers(u_exact_scaled, settings)

        # Target: final slice pushed through scale -> unscale with the bounds
        # of the full solution, so predictions and target share one scaling.
        final_exact = u_exact[Nt]
        quantized_gt_2d, _, _ = scale_2d_array(final_exact[np.newaxis, :], vmin_exact, vmax_exact)
        quantized_gt_2d = unscale_2d_array(quantized_gt_2d, vmin_exact, vmax_exact)
        quantized_ground_truth = quantized_gt_2d[0, :]

        # Same round trip for the penultimate slice, which feeds both baselines.
        penultimate_exact = u_exact[Nt - 1]
        scaled_penultimate_2d, _, _ = scale_2d_array(
            penultimate_exact[np.newaxis, :], vmin_exact, vmax_exact)
        penultimate_quantized_2d = unscale_2d_array(scaled_penultimate_2d, vmin_exact, vmax_exact)
        penultimate_quantized = penultimate_quantized_2d[0, :]

        # temporal-repeat baseline: uses final in-context time slice as prediction
        persistence_residual = penultimate_quantized - quantized_ground_truth
        seed_max_diffs_persistence.append(np.max(np.abs(persistence_residual)))
        seed_rmses_persistence.append(np.sqrt(np.mean(persistence_residual**2)))

        # last-token baseline: fills next time slice with final scalar token
        repeat_last_prediction = np.full_like(quantized_ground_truth, penultimate_quantized[-1])
        repeat_last_residual = repeat_last_prediction - quantized_ground_truth
        seed_max_diffs_repeat_last.append(np.max(np.abs(repeat_last_residual)))
        seed_rmses_repeat_last.append(np.sqrt(np.mean(repeat_last_residual**2)))

        # Error introduced by the serialize -> deserialize -> unscale round trip
        # of the whole solution.
        u_exact_parsed = deserialize_2d_integers(u_exact_serialized, settings)
        u_exact_unscaled = unscale_2d_array(u_exact_parsed, vmin_exact, vmax_exact)
        round_trip_residual = u_exact - u_exact_unscaled
        seed_baseline_max_errors.append(np.max(np.abs(round_trip_residual)))
        seed_baseline_rmse_errors.append(np.sqrt(np.mean(round_trip_residual**2)))

    persistence_final_max_diff.append(np.mean(seed_max_diffs_persistence))
    persistence_final_rmse.append(np.mean(seed_rmses_persistence))
    persistence_final_max_diff_std.append(np.std(seed_max_diffs_persistence, ddof=1))
    persistence_final_rmse_std.append(np.std(seed_rmses_persistence, ddof=1))

    repeat_last_final_max_diff.append(np.mean(seed_max_diffs_repeat_last))
    repeat_last_final_rmse.append(np.mean(seed_rmses_repeat_last))
    repeat_last_final_max_diff_std.append(np.std(seed_max_diffs_repeat_last, ddof=1))
    repeat_last_final_rmse_std.append(np.std(seed_rmses_repeat_last, ddof=1))

    temporal_baseline_max_errors.append(np.mean(seed_baseline_max_errors))
    temporal_baseline_rmse_errors.append(np.mean(seed_baseline_rmse_errors))

# Convert the per-Nt summaries to arrays (one entry per value in all_Nt_values).
persistence_final_max_diff = np.array(persistence_final_max_diff)
persistence_final_rmse = np.array(persistence_final_rmse)
persistence_final_max_diff_std = np.array(persistence_final_max_diff_std)
persistence_final_rmse_std = np.array(persistence_final_rmse_std)

repeat_last_final_max_diff = np.array(repeat_last_final_max_diff)
repeat_last_final_rmse = np.array(repeat_last_final_rmse)
repeat_last_final_max_diff_std = np.array(repeat_last_final_max_diff_std)
repeat_last_final_rmse_std = np.array(repeat_last_final_rmse_std)

temporal_baseline_max_errors = np.array(temporal_baseline_max_errors)
temporal_baseline_rmse_errors = np.array(temporal_baseline_rmse_errors)

def log_ci(mean, std, n, tcrit):
    """Return (lower, upper) confidence bounds suited to a log10 axis.

    The standard error of the mean, ``std / sqrt(n)``, is propagated to
    log10 space with the first-order delta method
    (``d log10(m) = dm / (m * ln 10)``). The interval is symmetric around
    ``log10(mean)`` in log space and is mapped back to linear units.

    Args:
        mean: Sample mean(s); must be positive for the logarithm to be finite.
        std: Sample standard deviation(s) matching ``mean``.
        n: Number of samples behind each mean.
        tcrit: Critical t value; the 0.975 quantile yields a 95% interval.

    Returns:
        Tuple ``(lower, upper)`` of bounds in the same units as ``mean``.
    """
    standard_error = std / np.sqrt(n)
    half_width = tcrit * (standard_error / (mean * np.log(10)))
    center = np.log10(mean)
    lower = 10 ** (center - half_width)
    upper = 10 ** (center + half_width)
    return lower, upper

# Two-sided 95% Student-t critical value with n_seeds - 1 degrees of freedom.
t_critical = stats.t.ppf(0.975, df=n_seeds - 1)

# Log-axis confidence bands for every (baseline, metric) pair of the time sweep.
persistence_lower_max_diff_log, persistence_upper_max_diff_log = log_ci(
    mean=persistence_final_max_diff,
    std=persistence_final_max_diff_std,
    n=n_seeds,
    tcrit=t_critical,
)
persistence_lower_rmse_log, persistence_upper_rmse_log = log_ci(
    mean=persistence_final_rmse,
    std=persistence_final_rmse_std,
    n=n_seeds,
    tcrit=t_critical,
)

repeat_last_lower_max_diff_log, repeat_last_upper_max_diff_log = log_ci(
    mean=repeat_last_final_max_diff,
    std=repeat_last_final_max_diff_std,
    n=n_seeds,
    tcrit=t_critical,
)
repeat_last_lower_rmse_log, repeat_last_upper_rmse_log = log_ci(
    mean=repeat_last_final_rmse,
    std=repeat_last_final_rmse_std,
    n=n_seeds,
    tcrit=t_critical,
)

# Everything written to the archive, keyed by the name it is stored under.
time_sweep_payload = {
    # persistence baseline: means, sample stds, log-axis bands
    "persistence_final_max_diff": persistence_final_max_diff,
    "persistence_final_rmse": persistence_final_rmse,
    "persistence_final_max_diff_std": persistence_final_max_diff_std,
    "persistence_final_rmse_std": persistence_final_rmse_std,
    "persistence_lower_max_diff_log": persistence_lower_max_diff_log,
    "persistence_upper_max_diff_log": persistence_upper_max_diff_log,
    "persistence_lower_rmse_log": persistence_lower_rmse_log,
    "persistence_upper_rmse_log": persistence_upper_rmse_log,
    # repeat-last baseline: means, sample stds, log-axis bands
    "repeat_last_final_max_diff": repeat_last_final_max_diff,
    "repeat_last_final_rmse": repeat_last_final_rmse,
    "repeat_last_final_max_diff_std": repeat_last_final_max_diff_std,
    "repeat_last_final_rmse_std": repeat_last_final_rmse_std,
    "repeat_last_lower_max_diff_log": repeat_last_lower_max_diff_log,
    "repeat_last_upper_max_diff_log": repeat_last_upper_max_diff_log,
    "repeat_last_lower_rmse_log": repeat_last_lower_rmse_log,
    "repeat_last_upper_rmse_log": repeat_last_upper_rmse_log,
    # serialization round-trip errors of the full solution
    "temporal_baseline_max_errors": temporal_baseline_max_errors,
    "temporal_baseline_rmse_errors": temporal_baseline_rmse_errors,
    # sweep metadata
    "initial_conditions": stored_initial_conditions_array,
    "all_Nt_values": list(all_Nt_values),
    "n_seeds": n_seeds,
    "t_critical": t_critical,
    "Nx": Nx,
}

np.savez_compressed("baseline_1_step_time_discretization.npz", **time_sweep_payload)

# ---------------------------------------------------------------------------
# Spatial discretization experiment (varying Nx, fixed Nt=50)
#
# Same two one-step baselines as in the time sweep (persistence of the
# previous slice, and repeat of its last spatial value), scored against the
# scale/unscale round trip of the final slice u_exact[Nt], for each Nx.
# The solver is always driven by the stored spline of each seed together with
# the current Nx; no separately resampled initial condition is consumed here,
# so none is built.
# ---------------------------------------------------------------------------

Nt = 50
Nx_base = 14  # grid size the stored initial conditions were drawn on
all_Nx_values = range(2, 41, 2)  # Nx = 2, 4, ..., 40

persistence_final_max_diff = []
persistence_final_rmse = []
persistence_final_max_diff_std = []
persistence_final_rmse_std = []

repeat_last_final_max_diff = []
repeat_last_final_rmse = []
repeat_last_final_max_diff_std = []
repeat_last_final_rmse_std = []

spatial_baseline_max_errors = []
spatial_baseline_rmse_errors = []

# NOTE: the loop variable rebinds the module-level name Nx; after this loop
# Nx holds the last swept value rather than the value used for the time sweep.
for Nx in tqdm(all_Nx_values):
    seed_max_diffs_persistence = []
    seed_rmses_persistence = []
    seed_max_diffs_repeat_last = []
    seed_rmses_repeat_last = []
    seed_baseline_max_errors = []
    seed_baseline_rmse_errors = []

    for seed_idx in range(n_seeds):
        cs = stored_spline_objects[seed_idx]

        u_exact = compute_exact_solution_random_ic_vary_Nx(L, k, T, Nx, Nt, spline_obj=cs)
        u_exact_scaled, vmin_exact, vmax_exact = scale_2d_array(u_exact)
        u_exact_serialized = serialize_2d_integers(u_exact_scaled, settings)

        # Target: final slice pushed through scale -> unscale with the bounds
        # of the full solution, so predictions and target share one scaling.
        final_exact = u_exact[Nt]
        quantized_gt_2d, _, _ = scale_2d_array(final_exact[np.newaxis, :], vmin_exact, vmax_exact)
        quantized_gt_2d = unscale_2d_array(quantized_gt_2d, vmin_exact, vmax_exact)
        quantized_ground_truth = quantized_gt_2d[0, :]

        # Same round trip for the penultimate slice, which feeds both baselines.
        penultimate_exact = u_exact[Nt - 1]
        scaled_penultimate_2d, _, _ = scale_2d_array(
            penultimate_exact[np.newaxis, :], vmin_exact, vmax_exact)
        penultimate_quantized_2d = unscale_2d_array(scaled_penultimate_2d, vmin_exact, vmax_exact)
        penultimate_quantized = penultimate_quantized_2d[0, :]

        # temporal-repeat baseline: uses final in-context time slice as prediction
        persistence_residual = penultimate_quantized - quantized_ground_truth
        seed_max_diffs_persistence.append(np.max(np.abs(persistence_residual)))
        seed_rmses_persistence.append(np.sqrt(np.mean(persistence_residual**2)))

        # last-token baseline: fills next time slice with final scalar token
        repeat_last_prediction = np.full_like(quantized_ground_truth, penultimate_quantized[-1])
        repeat_last_residual = repeat_last_prediction - quantized_ground_truth
        seed_max_diffs_repeat_last.append(np.max(np.abs(repeat_last_residual)))
        seed_rmses_repeat_last.append(np.sqrt(np.mean(repeat_last_residual**2)))

        # Error introduced by the serialize -> deserialize -> unscale round trip
        # of the whole solution.
        u_exact_parsed = deserialize_2d_integers(u_exact_serialized, settings)
        u_exact_unscaled = unscale_2d_array(u_exact_parsed, vmin_exact, vmax_exact)
        round_trip_residual = u_exact - u_exact_unscaled
        seed_baseline_max_errors.append(np.max(np.abs(round_trip_residual)))
        seed_baseline_rmse_errors.append(np.sqrt(np.mean(round_trip_residual**2)))

    persistence_final_max_diff.append(np.mean(seed_max_diffs_persistence))
    persistence_final_rmse.append(np.mean(seed_rmses_persistence))
    persistence_final_max_diff_std.append(np.std(seed_max_diffs_persistence, ddof=1))
    persistence_final_rmse_std.append(np.std(seed_rmses_persistence, ddof=1))

    repeat_last_final_max_diff.append(np.mean(seed_max_diffs_repeat_last))
    repeat_last_final_rmse.append(np.mean(seed_rmses_repeat_last))
    repeat_last_final_max_diff_std.append(np.std(seed_max_diffs_repeat_last, ddof=1))
    repeat_last_final_rmse_std.append(np.std(seed_rmses_repeat_last, ddof=1))

    spatial_baseline_max_errors.append(np.mean(seed_baseline_max_errors))
    spatial_baseline_rmse_errors.append(np.mean(seed_baseline_rmse_errors))

# Floor the per-Nx means at epsilon before they go through log_ci (presumably
# so log10 stays finite when a baseline error is exactly zero). The floored
# means are also the values written to the archive; the stds are not floored.
epsilon = 1e-4
persistence_final_max_diff = np.maximum(np.array(persistence_final_max_diff), epsilon)
persistence_final_rmse = np.maximum(np.array(persistence_final_rmse), epsilon)
persistence_final_max_diff_std = np.array(persistence_final_max_diff_std)
persistence_final_rmse_std = np.array(persistence_final_rmse_std)

repeat_last_final_max_diff = np.maximum(np.array(repeat_last_final_max_diff), epsilon)
repeat_last_final_rmse = np.maximum(np.array(repeat_last_final_rmse), epsilon)
repeat_last_final_max_diff_std = np.array(repeat_last_final_max_diff_std)
repeat_last_final_rmse_std = np.array(repeat_last_final_rmse_std)

spatial_baseline_max_errors = np.array(spatial_baseline_max_errors)
spatial_baseline_rmse_errors = np.array(spatial_baseline_rmse_errors)

# Log-axis confidence bands, reusing the t critical value computed earlier.
persistence_lower_max_diff_log, persistence_upper_max_diff_log = log_ci(
    persistence_final_max_diff, persistence_final_max_diff_std, n_seeds, t_critical)
persistence_lower_rmse_log, persistence_upper_rmse_log = log_ci(
    persistence_final_rmse, persistence_final_rmse_std, n_seeds, t_critical)

repeat_last_lower_max_diff_log, repeat_last_upper_max_diff_log = log_ci(
    repeat_last_final_max_diff, repeat_last_final_max_diff_std, n_seeds, t_critical)
repeat_last_lower_rmse_log, repeat_last_upper_rmse_log = log_ci(
    repeat_last_final_rmse, repeat_last_final_rmse_std, n_seeds, t_critical)

np.savez_compressed(
    "baseline_1_step_space_discretization.npz",
    persistence_final_max_diff=persistence_final_max_diff,
    persistence_final_rmse=persistence_final_rmse,
    persistence_final_max_diff_std=persistence_final_max_diff_std,
    persistence_final_rmse_std=persistence_final_rmse_std,
    persistence_lower_max_diff_log=persistence_lower_max_diff_log,
    persistence_upper_max_diff_log=persistence_upper_max_diff_log,
    persistence_lower_rmse_log=persistence_lower_rmse_log,
    persistence_upper_rmse_log=persistence_upper_rmse_log,
    repeat_last_final_max_diff=repeat_last_final_max_diff,
    repeat_last_final_rmse=repeat_last_final_rmse,
    repeat_last_final_max_diff_std=repeat_last_final_max_diff_std,
    repeat_last_final_rmse_std=repeat_last_final_rmse_std,
    repeat_last_lower_max_diff_log=repeat_last_lower_max_diff_log,
    repeat_last_upper_max_diff_log=repeat_last_upper_max_diff_log,
    repeat_last_lower_rmse_log=repeat_last_lower_rmse_log,
    repeat_last_upper_rmse_log=repeat_last_upper_rmse_log,
    spatial_baseline_max_errors=spatial_baseline_max_errors,
    spatial_baseline_rmse_errors=spatial_baseline_rmse_errors,
    n_seeds=n_seeds,
    t_critical=t_critical,
    all_Nx_values=list(all_Nx_values),
    Nt=Nt,
)