# Refining an exponential-sum fit with Newton's method and a line search

In [None]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path('..') / 'src'))
from kl_decomposition import rectangle_rule, fit_exp_sum, newton_with_line_search, gauss_legendre_rule
from kl_decomposition.kernel_fit import _prepare_jax_funcs
import numpy as np
import jax
import jax.numpy as jnp

In [None]:
import numpy as np
import time


def de(obj, dim,
       pop_size=15,
       max_gen=100,
       F=0.8,
       CR=0.9,
       seed=42,
       tol=None,            # e.g. 1e-8   -> activate plateau stop
       patience=10,         # how many stagnant generations to allow
       target=None,         # objective value to reach, e.g. 1e-12
       verbose=False):
    """Minimise ``obj`` over R^dim with a simple differential evolution.

    The initial population and every mutant are sorted along the parameter
    axis, which suits objectives that ignore the ordering of their inputs
    (e.g. ones that sort internally).

    Parameters
    ----------
    obj : callable
        Maps a 1-D array of length ``dim`` to a scalar score; lower is better.
    dim : int
        Number of parameters per individual.
    pop_size : int
        Population size. Must be at least 3 because three donors are drawn
        for every trial vector.
    max_gen : int
        Maximum number of generations. ``0`` only evaluates the initial
        population.
    F : float
        Differential weight applied to the donor difference.
    CR : float
        Per-coordinate crossover probability.
    seed : int
        Seed passed to ``np.random.default_rng``.
    tol : float or None
        Relative plateau tolerance. If given, a generation whose best score
        changes by at most ``tol * (|best| + 1e-12)`` counts as stagnant.
    patience : int
        Number of consecutive stagnant generations tolerated before stopping.
    target : float or None
        Stop as soon as the best score is ``<= target``.
    verbose : bool
        Print the best score of every generation and the stopping reason.

    Returns
    -------
    best_x : ndarray
        Copy of the best individual found.
    info : dict
        Contains:

        * ``best_score`` — best objective value.
        * ``iterations`` — generations actually run (0 when ``max_gen == 0``).
        * ``eval_count`` — objective calls, including the initial population.
        * ``runtime`` — wall-clock seconds, including the initial evaluation.

    Raises
    ------
    ValueError
        If ``pop_size < 3``.
    """
    if pop_size < 3:
        raise ValueError(f'pop_size must be at least 3, got {pop_size}')

    t0 = time.time()   # start before the initial evaluation so runtime covers it
    rng = np.random.default_rng(seed)
    pop = np.sort(rng.normal(size=(pop_size, dim)), 1)
    # float dtype: an int-valued first evaluation must not truncate later scores
    scores = np.array([obj(x) for x in pop], dtype=float)
    evals = pop_size
    best = scores.min()
    stall = 0
    generations = 0    # stays bound even when max_gen == 0 (loop body never runs)

    for g in range(max_gen):
        generations = g + 1
        for i in range(pop_size):
            # NOTE(review): donors are drawn from the whole population, so they
            # may include i itself (classic DE/rand/1 excludes i).
            a, b, c = rng.choice(pop_size, 3, replace=False)
            mutant = np.sort(pop[a] + F * (pop[b] - pop[c]))
            mask = rng.random(dim) < CR
            mask[rng.integers(dim)] = True        # guarantee at least one mutant coordinate
            # NOTE: mixing a sorted mutant with sorted pop[i] need not be sorted.
            trial = np.where(mask, mutant, pop[i])
            s = obj(trial)
            evals += 1
            if s < scores[i]:
                pop[i], scores[i] = trial, s

        new_best = scores.min()

        if verbose:
            print(f'gen {g:3d} | best {new_best:.3e}')

        # --- stopping checks ----------------------------------------------
        if target is not None and new_best <= target:
            if verbose:
                print('target reached – stopping')
            break

        if tol is not None:
            if abs(best - new_best) <= tol * (abs(best) + 1e-12):
                stall += 1
                if stall >= patience:
                    if verbose:
                        print('no significant improvement – stopping')
                    break
            else:
                stall = 0

        best = new_best

    k = scores.argmin()
    return pop[k].copy(), {'best_score': float(scores[k]),
                           'iterations': generations,
                           'eval_count': evals,
                           'runtime': time.time() - t0}

In [None]:
import numpy as np
# Problem set-up for the DE search below: approximate exp(-t) on [0, 5] by
# sum_k a_k * exp(-b_k * t**2). DE searches over log b only; the amplitudes a
# are solved by least squares inside `objective`.
N = 30                                    # number of terms a_k * exp(-b_k * x**2)
x, w = rectangle_rule(0.0, 5.0, 10000)      # presumably quadrature nodes/weights on [0, 5] (10000 points)
# Alternative grid, kept for comparison:
# x, w = gauss_legendre_rule(0.0, 5.0, 4000)  # integration grid
target = np.exp(-x)                       # func(t) = exp(-t), sampled on the grid x
dim = N                                   # DE dimension: one log-b per term (a comes from least squares)

sqrt_w = np.sqrt(w)                       # pre-compute once for speed
y = target * sqrt_w                       # weighted RHS: ||sqrt(w)*(F a - target)||^2 == sum(w*(F a - target)**2)


def objective(c_params):
    """Weighted least-squares misfit of the best fit for the given log-rates.

    ``c_params`` holds log b_k. The rates are exponentiated and sorted, the
    basis exp(-b_k x^2) is built on the module-level grid ``x``, and the
    amplitudes are obtained by an ordinary least-squares solve of the
    sqrt(w)-scaled system. The returned value is
    sum(w * (basis @ a - target) ** 2).
    """
    rates = np.sort(np.exp(c_params))                # enforce ascending order
    basis = np.exp(-np.outer(x ** 2, rates))        # column k: exp(-b_k x^2)
    coeffs = np.linalg.lstsq(basis * sqrt_w[:, None], y, rcond=None)[0]
    misfit = basis @ coeffs - target
    return np.sum(w * misfit * misfit)


# ---- optimise ------------------------------------------------------------
# DE searches over log b only; the amplitudes a are recovered by least squares.
best_c, info = de(
    objective,
    dim,
    seed=3,
    pop_size=30,
    max_gen=500,
    tol=1e-16,
    patience=20,
    verbose=True,
)

# ---- reconstruct final coefficients --------------------------------------
# Repeat the objective's inner least-squares solve once at the optimum to get a.
b_est = np.sort(np.exp(best_c))
F = np.exp(-np.outer(x ** 2, b_est))
A = sqrt_w[:, None] * F
a_est = np.linalg.lstsq(A, y, rcond=None)[0]

for label, value in (("\nEstimated a:", a_est),
                     ("Estimated b:", b_est),
                     ("Best weighted-LS error:", info["best_score"])):
    print(label, value)

In [None]:
# Baseline: a 4-term fit of exp(-t) on a coarse 100-point grid via the
# library's DE + least-squares driver.
N = 4
x, w = rectangle_rule(0.0, 5.0, 100)


def func(t):
    return np.exp(-t)


a_ls, b_ls, info = fit_exp_sum(
    N, x, w, func,
    method='de_ls',
    pop_size=30,
    max_gen=500,
)
for label, values in (('initial a:', a_ls), ('initial b:', b_ls)):
    print(label, values)


In [None]:
# Starting point for the Newton refinement: amplitudes as-is, rates in log space.
params0 = np.hstack((a_est, np.log(b_est)))
params0

In [None]:
# Refine the DE + least-squares estimate with Newton's method and a line search.
# Parameter layout, as built below: [a_1..a_n, log b_1..log b_n].
def target(t):
    return np.exp(-t)


obj, grad, hess = _prepare_jax_funcs(x, target, w, newton=True)
params0 = np.concatenate([a_est, np.log(b_est)])
# `N` was reset to 4 by the baseline cell, but a_est/b_est come from the
# N=30 DE fit, so derive the number of terms from the parameter vector itself.
n_terms = params0.size // 2
params_opt, stats = newton_with_line_search(params0, obj, grad, hess, max_iter=10)
print('refined a:', params_opt[:n_terms])           # amplitudes are stored unlogged
print('refined b:', np.exp(params_opt[n_terms:]))   # rates are stored as log b
print('Newton iterations:', stats.iterations)