# Differential evolution with Newton refinement using JAX gradients and Hessians

In [None]:
import numpy as np
from math import isinf
from pathlib import Path
import sys
sys.path.insert(0, str(Path('..') / 'src'))
from kl_decomposition import rectangle_rule
import jax
import jax.numpy as jnp

In [None]:
def fmin_DENewton(
    F, dF, ddF,
    lower_bounds, upper_bounds,
    beta_min, beta_max, pCR,
    n_pop, tol_cost, tol_grad, max_it,
    newton_LS_size=8,
    rng=None,
):
    """Minimise ``F`` with differential evolution plus a Newton line search.

    Every generation performs a differential-evolution pass (mutation,
    binomial crossover, greedy selection) over the whole population and then
    tries to improve each member along its Newton direction.

    Parameters
    ----------
    F : callable
        Cost function; called with a 1-D array of length ``n_var`` and must
        return a scalar.
    dF : callable
        Gradient of ``F``; called with a 1-D array of length ``n_var``.
    ddF : callable
        Hessian of ``F``; its result is passed to ``np.linalg.solve`` together
        with the gradient.
    lower_bounds, upper_bounds : array_like
        1-D per-variable bounds of equal length. Entries may be infinite.
    beta_min, beta_max : float
        Range of the differential weights, drawn uniformly per coordinate.
    pCR : float
        Crossover probability per coordinate.
    n_pop : int
        Population size. Must be at least 4 because mutation needs three
        distinct members other than the one being updated.
    tol_cost : float
        Tolerance on the change of the mean population cost between two
        consecutive generations.
    tol_grad : float
        Tolerance on the gradient norm at the best point found so far.
    max_it : int
        Maximum number of generations.
    newton_LS_size : int, optional
        Number of step lengths sampled uniformly in ``[-1, 1]`` along the
        Newton direction.
    rng : numpy.random.Generator, optional
        Random generator; a fresh ``np.random.default_rng()`` is used when
        omitted.

    Returns
    -------
    best_x : numpy.ndarray
        Best point found.
    it : int
        Value of the generation counter when the loop stopped: the generation
        at which both stopping criteria were met, ``max_it`` if they never
        were, or 0 if ``max_it < 1``.

    Raises
    ------
    ValueError
        If the bounds are not 1-D arrays of equal length or ``n_pop < 4``.
    """
    if rng is None:
        rng = np.random.default_rng()
    lower_bounds = np.asarray(lower_bounds, dtype=float)
    upper_bounds = np.asarray(upper_bounds, dtype=float)
    if lower_bounds.ndim != 1 or lower_bounds.shape != upper_bounds.shape:
        raise ValueError(
            "lower_bounds and upper_bounds must be 1-D arrays of equal length"
        )
    if n_pop < 4:
        raise ValueError("n_pop must be at least 4 for DE mutation")
    n_var = lower_bounds.shape[0]

    # Initial population: the sampling distribution of each coordinate depends
    # on which of its bounds are finite.
    pop_x = np.empty((n_pop, n_var))
    for j in range(n_var):
        lo, hi = lower_bounds[j], upper_bounds[j]
        if not isinf(lo) and not isinf(hi):
            pop_x[:, j] = rng.uniform(lo, hi, size=n_pop)
        elif isinf(lo) and isinf(hi):
            pop_x[:, j] = rng.normal(size=n_pop)
        elif isinf(hi):
            pop_x[:, j] = lo + rng.exponential(size=n_pop)
        else:
            pop_x[:, j] = hi - rng.exponential(size=n_pop)
    pop_cost = np.apply_along_axis(F, 1, pop_x)

    best_idx = int(np.argmin(pop_cost))
    best_x = pop_x[best_idx].copy()
    best_cost = pop_cost[best_idx]

    # Loop invariants.
    member_ids = np.arange(n_pop)
    alphas = np.linspace(-1, 1, newton_LS_size)

    it = 0  # keeps the return value defined when max_it < 1
    prev_mean = np.inf
    for it in range(1, max_it + 1):
        mean_cost = pop_cost.mean()
        # The gradient is evaluated at the current best point, and only once
        # the cost criterion already holds, so the extra dF call is rare.
        if (abs(prev_mean - mean_cost) <= tol_cost
                and np.linalg.norm(dF(best_x)) <= tol_grad):
            break
        prev_mean = mean_cost

        # --- differential-evolution pass ---
        for i in range(n_pop):
            a, b, c = rng.choice(np.delete(member_ids, i), 3, replace=False)
            beta = rng.uniform(beta_min, beta_max, size=n_var)
            y = pop_x[a] + beta * (pop_x[b] - pop_x[c])
            y = np.minimum(np.maximum(y, lower_bounds), upper_bounds)
            # Binomial crossover; coordinate j0 always comes from the mutant
            # so the trial differs from its parent.
            j0 = rng.integers(n_var)
            mask = rng.random(n_var) < pCR
            mask[j0] = True
            z = np.where(mask, y, pop_x[i])
            z_cost = F(z)
            if z_cost < pop_cost[i]:
                pop_x[i], pop_cost[i] = z, z_cost
                if z_cost < best_cost:
                    best_x, best_cost = z.copy(), z_cost

        # --- Newton line-search pass ---
        for i in range(n_pop):
            g = dF(pop_x[i])
            H = ddF(pop_x[i])
            try:
                direction = -np.linalg.solve(H, g)
            except np.linalg.LinAlgError:
                # Singular Hessian: leave this member to the DE pass.
                continue
            candidates = pop_x[i] + np.outer(alphas, direction)
            candidates = np.clip(candidates, lower_bounds, upper_bounds)
            costs_line = np.apply_along_axis(F, 1, candidates)
            # np.argmin would select a NaN entry; rank NaN costs as worst so
            # valid candidates are not discarded.
            costs_line = np.where(np.isnan(costs_line), np.inf, costs_line)
            idx_min = int(np.argmin(costs_line))
            z_new = candidates[idx_min]
            z_cost = costs_line[idx_min]
            if z_cost < pop_cost[i]:
                pop_x[i], pop_cost[i] = z_new, z_cost
                if z_cost < best_cost:
                    best_x, best_cost = z_new.copy(), z_cost
    return best_x, it


In [None]:
# Set up the squared-exponential approximation problem.
# JAX computes in float32 unless 64-bit mode is enabled; without it the
# float64 NumPy data below would be silently truncated and the 1e-8
# tolerances used by the optimiser would be out of reach.
jax.config.update("jax_enable_x64", True)

# Presumably quadrature nodes and weights on [0, 2] with 50 points;
# confirm against kl_decomposition.rectangle_rule.
x, w = rectangle_rule(0.0, 2.0, 50)
# Target: two exponential-of-square terms with amplitudes (2.0, 0.5) and
# rates (3.0, 1.0).
target = 2.0 * np.exp(-3.0 * x**2) + 0.5 * np.exp(-1.0 * x**2)
# JAX copies of the data, used by the differentiable objective.
x_j = jnp.array(x)
w_j = jnp.array(w)
target_j = jnp.array(target)

def obj_jax(p):
    """Weighted half sum of squared residuals of a two-term exponential fit.

    ``p`` holds log-parameters: the first two entries are log-amplitudes and
    the remaining entries are log-rates, so the exponentiated values are
    always positive. The model ``sum_k a_k * exp(-b_k * x**2)`` is evaluated
    at ``x_j`` and compared with ``target_j`` using the weights ``w_j``.
    """
    log_amp, log_rate = p[:2], p[2:]
    amp = jnp.exp(log_amp)
    rate = jnp.exp(log_rate)
    x_sq = x_j[None, :] ** 2
    # One row per term, one column per node.
    terms = amp[:, None] * jnp.exp(-rate[:, None] * x_sq)
    residual = terms.sum(axis=0) - target_j
    return 0.5 * jnp.sum(w_j * residual * residual)

# The optimiser evaluates the objective and its derivatives many times per
# generation, so compile each of them once with jax.jit instead of paying
# JAX's op-by-op dispatch cost on every call.
_obj_jit = jax.jit(obj_jax)
grad_jax = jax.jit(jax.grad(obj_jax))
hess_jax = jax.jit(jax.hessian(obj_jax))


def obj(p):
    """Return the objective at ``p`` as a Python float."""
    return float(_obj_jit(jnp.array(p)))


def grad(p):
    """Return the gradient at ``p`` as a float64 NumPy array."""
    # Explicit float64 so the NumPy linear algebra downstream runs in double
    # precision whatever dtype JAX produced.
    return np.array(grad_jax(jnp.array(p)), dtype=float)


def hess(p):
    """Return the Hessian at ``p`` as a float64 NumPy array."""
    return np.array(hess_jax(jnp.array(p)), dtype=float)

# The search is unconstrained in log-parameter space: all four bounds are
# infinite.
lower = np.full(4, -np.inf)
upper = np.full(4, np.inf)
best, it = fmin_DENewton(
    obj, grad, hess,
    lower, upper,
    beta_min=0.5, beta_max=1.0, pCR=0.9,
    n_pop=20, tol_cost=1e-8, tol_grad=1e-8, max_it=100,
    # Fixed seed so the stochastic search gives the same output on re-runs.
    rng=np.random.default_rng(0),
)
print('best params:', best)
print('iterations:', it)
# The parameters are stored as logarithms; exponentiate to recover the
# amplitudes and rates.
print('recovered a:', np.exp(best[:2]))
print('recovered b:', np.exp(best[2:]))
