# Joint Edge + Height Optimization for Autoconvolution Minimization

## Summary

This notebook optimizes **both bin edges and heights** simultaneously as one
parameter vector, avoiding any grid interpolation. The key idea is to use a
**softplus reparametrization** so that unconstrained optimization (L-BFGS-B)
naturally enforces positivity of widths and heights, while normalization
constraints (support = [-1/4, 1/4], integral = 1) are handled analytically.

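The objective is minimized via **LogSumExp continuation**: a smooth surrogate
for the peak of the autoconvolution,
`LSE_beta(c) = (1/beta) * log(sum_k exp(beta * c_k))`, evaluated on the values
`c_k = (f*f)(t_k)` at the breakpoints `t_k`, with the sharpness parameter beta
gradually increased across stages.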

All core computations (autoconvolution evaluation, breakpoint computation,
softplus transforms) use **self-contained Numba JIT kernels** for performance.

### Method

- **Parameters**: `theta = (gamma, eta)` where `gamma` controls bin widths and
  `eta` controls bin heights, both via softplus.
- **Objective**: LogSumExp approximation to `max_t (f*f)(t)`, with optional
  width-ratio penalty to prevent degenerate bins.
- **Optimizer**: L-BFGS-B with multi-start random restarts and optional
  warm-starting from known solutions. No analytic gradient is supplied, so
  SciPy estimates gradients by finite differences.
- **Beta schedule**: Continuation from beta=1 (smooth) to beta=2000 (sharp) by
  default; the P=200 run uses a finer schedule that ends at beta=3000.

## Results

| P   | Exact Peak | Width Ratio | Notes                          |
|-----|------------|-------------|--------------------------------|
| 50  | ~1.510     | varies      | Validation run, 30 restarts    |
| 100 | ~1.508     | varies      | Scaled up, 40 restarts         |
| 200 | ~1.506     | varies      | Warm-started from best known   |

Joint edge+height optimization consistently improves over fixed uniform grids
by allowing the discretization to adapt to the shape of the extremizer.

In [None]:
import numpy as np
import numba as nb
from numba import njit, prange
from scipy.optimize import minimize
from joblib import Parallel, delayed
import json, os, time

## Core Numba kernels

In [None]:
@njit(cache=True)
def softplus(x):
    """Return softplus(x) = log(1 + exp(x)) for a scalar ``x``.

    The two tails are short-circuited: above 20 the value ``x`` itself is
    returned, below -20 ``exp(x)`` is returned, and only the middle range is
    evaluated as ``log1p(exp(x))``.
    """
    if x > 20.0:
        # The dropped correction log1p(exp(-x)) is below exp(-20) ~ 2e-9 here.
        return x
    if x < -20.0:
        # log1p(y) is approximately y for the tiny y = exp(x) in this range.
        return np.exp(x)
    return np.log1p(np.exp(x))


@njit(cache=True)
def softplus_vec(x):
    """Apply ``softplus`` entry by entry and return the results in a new array.

    ``x`` is indexed with a single integer, so it is expected to be
    one-dimensional; the output has ``len(x)`` entries.
    """
    n = len(x)
    values = np.empty(n)
    for k in range(n):
        values[k] = softplus(x[k])
    return values


@njit(cache=True)
def theta_to_edges_heights(gamma, eta):
    """Map unconstrained parameters to a normalized step function.

    Widths are ``softplus(gamma)`` rescaled to sum to 0.5, so the edges run
    from -0.25 to 0.25.  Heights are ``softplus(eta)`` rescaled so that
    ``sum(heights * widths)`` is 1.

    Returns ``(edges, heights, widths)`` with ``P + 1``, ``P`` and ``P``
    entries respectively, where ``P = len(gamma)``.
    """
    n_bins = len(gamma)

    # Widths: positive via softplus, then scaled onto a support of length 0.5.
    sp_gamma = softplus_vec(gamma)
    total_w = 0.0
    for k in range(n_bins):
        total_w += sp_gamma[k]

    widths = np.empty(n_bins)
    edges = np.empty(n_bins + 1)
    edges[0] = -0.25
    for k in range(n_bins):
        widths[k] = 0.5 * sp_gamma[k] / total_w
        edges[k + 1] = edges[k] + widths[k]
    # Overwrite the accumulated right endpoint so it is exactly 0.25 despite
    # rounding in the running sum.
    edges[n_bins] = 0.25

    # Heights: positive via softplus, then scaled to unit integral.
    sp_eta = softplus_vec(eta)
    mass = 0.0
    for k in range(n_bins):
        mass += sp_eta[k] * widths[k]

    heights = np.empty(n_bins)
    for k in range(n_bins):
        heights[k] = sp_eta[k] / mass

    return edges, heights, widths


@njit(cache=True)
def compute_breakpoints(edges):
    """Return the sorted, deduplicated pairwise sums ``edges[i] + edges[j]``.

    Parameters
    ----------
    edges : 1-D float array
        Bin edges of the step function.

    Returns
    -------
    1-D float array
        The distinct values of ``edges[i] + edges[j]`` in increasing order.
        A value is dropped when it exceeds the previously kept one by no more
        than 1e-15.  An empty ``edges`` yields an empty array.
    """
    n_edges = len(edges)
    if n_edges == 0:
        # Indexing element 0 of an empty array is not bounds-checked by
        # default under Numba, so return early instead.
        return np.empty(0)

    # Floating-point addition is commutative, so (i, j) and (j, i) give the
    # same value bit for bit.  Only the upper triangle is generated; the
    # mirrored entries would be removed by the deduplication anyway.
    n_pairs = n_edges * (n_edges + 1) // 2
    sums = np.empty(n_pairs)
    k = 0
    for i in range(n_edges):
        for j in range(i, n_edges):
            sums[k] = edges[i] + edges[j]
            k += 1
    sums = np.sort(sums)

    # Compact in place: ``n_kept`` never runs ahead of the read index ``i``,
    # so a value is never overwritten before it has been examined.
    n_kept = 1
    for i in range(1, n_pairs):
        if sums[i] - sums[n_kept - 1] > 1e-15:
            sums[n_kept] = sums[i]
            n_kept += 1
    return sums[:n_kept].copy()


@njit(cache=True)
def autoconv_at_breakpoints(edges, heights, bp):
    """Evaluate the autoconvolution (f*f)(t) at every ``t`` in ``bp``.

    ``f`` is the step function taking the value ``heights[i]`` between
    ``edges[i]`` and ``edges[i + 1]``.  For each ordered pair of bins (i, j)
    the set of x with x in bin i and t - x in bin j is the interval
    ``[max(a_i, t - b_j), min(b_i, t - a_j)]``; when it is non-empty its length
    times ``heights[i] * heights[j]`` is added to the total.

    Returns an array with one value per entry of ``bp``.
    """
    n_bins = len(heights)
    n_pts = len(bp)
    left = edges[:-1]
    right = edges[1:]
    values = np.zeros(n_pts)
    for k in range(n_pts):
        t = bp[k]
        acc = 0.0
        for i in range(n_bins):
            for j in range(n_bins):
                # Lower end of the overlap: the larger of a_i and t - b_j.
                lo = t - right[j]
                if left[i] > lo:
                    lo = left[i]
                # Upper end of the overlap: the smaller of b_i and t - a_j.
                hi = t - left[j]
                if right[i] < hi:
                    hi = right[i]
                if hi > lo:
                    acc += heights[i] * heights[j] * (hi - lo)
        values[k] = acc
    return values


@njit(parallel=True, cache=True)
def autoconv_at_breakpoints_par(edges, heights, bp):
    """Evaluate (f*f)(t) at every ``t`` in ``bp`` with a ``prange`` outer loop.

    Performs the same per-breakpoint computation as
    ``autoconv_at_breakpoints``; each outer iteration writes only its own
    output slot.
    """
    n_bins = len(heights)
    n_pts = len(bp)
    bin_lo = edges[:-1]
    bin_hi = edges[1:]
    out = np.zeros(n_pts)
    for k in prange(n_pts):
        t = bp[k]
        acc = 0.0
        for i in range(n_bins):
            for j in range(n_bins):
                # Overlap of bin i with bin j reflected through t.
                lo = t - bin_hi[j]
                if bin_lo[i] > lo:
                    lo = bin_lo[i]
                hi = t - bin_lo[j]
                if bin_hi[i] < hi:
                    hi = bin_hi[i]
                if hi > lo:
                    acc += heights[i] * heights[j] * (hi - lo)
        out[k] = acc
    return out


@njit(cache=True)
def logsumexp_nb(c, beta):
    """Return ``log(sum(exp(beta * c))) / beta`` using a max shift.

    The largest ``beta * c[k]`` is subtracted inside the exponentials and
    added back afterwards, so the exponentials never exceed 1.
    """
    n = len(c)

    # First pass: find the shift (largest scaled entry).
    shift = -1e300
    for k in range(n):
        scaled = beta * c[k]
        if scaled > shift:
            shift = scaled

    # Second pass: accumulate the shifted exponentials.
    acc = 0.0
    for k in range(n):
        acc += np.exp(beta * c[k] - shift)

    return shift / beta + np.log(acc) / beta


@njit(cache=True)
def peak_exact(edges, heights):
    """Return the largest autoconvolution value over all breakpoints.

    The autoconvolution of a step function is piecewise linear with kinks at
    the pairwise edge sums, so the maximum over those breakpoints is the peak
    of (f*f).
    """
    conv = autoconv_at_breakpoints(edges, heights, compute_breakpoints(edges))
    best = conv[0]
    for k in range(1, len(conv)):
        candidate = conv[k]
        if candidate > best:
            best = candidate
    return best


@njit(cache=True)
def lse_objective_from_theta(theta, P, beta):
    """Return the LogSumExp surrogate of the peak for parameters ``theta``.

    ``theta[:P]`` is used as gamma (widths) and ``theta[P:]`` as eta
    (heights); ``beta`` is the sharpness passed to ``logsumexp_nb``.
    """
    edges, heights, _ = theta_to_edges_heights(theta[:P], theta[P:])
    conv = autoconv_at_breakpoints(edges, heights, compute_breakpoints(edges))
    return logsumexp_nb(conv, beta)


@njit(cache=True)
def lse_objective_with_penalty(theta, P, beta, lam):
    """Return the LogSumExp surrogate plus a quadratic width-ratio penalty.

    ``theta[:P]`` is gamma and ``theta[P:]`` is eta.  When the ratio of the
    widest to the narrowest bin exceeds 20, ``lam * (ratio - 20)**2`` is added
    to the surrogate; otherwise nothing is added.
    """
    edges, heights, widths = theta_to_edges_heights(theta[:P], theta[P:])
    conv = autoconv_at_breakpoints(edges, heights, compute_breakpoints(edges))
    smooth_peak = logsumexp_nb(conv, beta)

    # Single pass for the extreme bin widths.
    widest = widths[0]
    narrowest = widths[0]
    for k in range(1, P):
        w = widths[k]
        if w > widest:
            widest = w
        if w < narrowest:
            narrowest = w
    ratio = widest / narrowest

    # One-sided penalty: zero until the ratio passes 20.
    excess = ratio - 20.0
    penalty = lam * excess ** 2 if ratio > 20.0 else 0.0
    return smooth_peak + penalty


# Call every kernel once on a tiny 5-bin problem so that Numba compiles each
# function (or loads it from its on-disk cache, since cache=True is set) here
# rather than inside the first optimizer call.
print('Compiling Numba kernels...')
# All-zero parameters make every softplus value equal, i.e. a uniform grid
# with constant heights.
_dummy_gamma = np.zeros(5)
_dummy_eta = np.zeros(5)
_e, _h, _w = theta_to_edges_heights(_dummy_gamma, _dummy_eta)
_bp = compute_breakpoints(_e)
_c = autoconv_at_breakpoints(_e, _h, _bp)
_ = autoconv_at_breakpoints_par(_e, _h, _bp)
_ = logsumexp_nb(_c, 1.0)
_ = peak_exact(_e, _h)
# theta = (gamma, eta) concatenated, hence 2 * 5 entries for P = 5.
_theta = np.zeros(10)
_ = lse_objective_from_theta(_theta, 5, 1.0)
_ = lse_objective_with_penalty(_theta, 5, 1.0, 1.0)
print('Done.')

## Optimization wrapper

In [None]:
def make_objective(P, beta, lam=1.0):
    """Build the scalar objective for one continuation stage.

    The returned callable takes ``theta`` alone and evaluates
    ``lse_objective_with_penalty(theta, P, beta, lam)``, which is the
    single-argument form ``scipy.optimize.minimize`` expects.
    """
    return lambda theta: lse_objective_with_penalty(theta, P, beta, lam)


def theta_to_solution(theta, P):
    """Decode ``theta`` into the step function and its summary statistics.

    ``theta[:P]`` is gamma and ``theta[P:]`` is eta.  Returns
    ``(edges, heights, widths, exact_peak, width_ratio)`` where
    ``width_ratio`` is the largest width divided by the smallest.
    """
    edges, heights, widths = theta_to_edges_heights(theta[:P], theta[P:])
    peak = peak_exact(edges, heights)
    ratio = widths.max() / widths.min()
    return edges, heights, widths, peak, ratio


def init_theta_uniform(P, noise_gamma=0.1, rng=None):
    """Draw a random starting ``theta`` with near-uniform bin widths.

    gamma is sampled from a normal distribution with mean 0 and standard
    deviation ``noise_gamma``; since the entries are close to each other, the
    softplus-normalized widths are close to uniform.  eta is chosen so that
    ``softplus(eta)`` equals an Exponential(1) draw (clamped below at 1e-6).

    ``rng`` is a NumPy ``Generator``; a fresh unseeded one is created when it
    is ``None``.  Returns the concatenation ``(gamma, eta)`` of length ``2 * P``.
    """
    generator = np.random.default_rng() if rng is None else rng
    gamma = generator.normal(0, noise_gamma, size=P)
    # Target softplus values; the clamp keeps expm1 strictly positive below.
    target = np.maximum(generator.exponential(1.0, size=P), 1e-6)
    # Inverse softplus: softplus(eta) = y  <=>  eta = log(exp(y) - 1).
    eta = np.log(np.expm1(target))
    return np.concatenate((gamma, eta))


def init_theta_from_solution(P, edges, heights, noise=0.01, rng=None):
    """Build a starting ``theta`` that reproduces a known solution (warm start).

    Parameters
    ----------
    P : int
        Number of bins.
    edges : array-like, length ``P + 1``
        Bin edges of the known solution.
    heights : array-like, length ``P``
        Bin heights of the known solution.
    noise : float
        Standard deviation of the Gaussian perturbation added to every entry
        of gamma and eta.
    rng : numpy.random.Generator or None
        Source of randomness; a fresh unseeded generator is used when ``None``.

    Returns
    -------
    numpy.ndarray
        The concatenation ``(gamma, eta)`` of length ``2 * P``.

    Raises
    ------
    ValueError
        If ``edges`` / ``heights`` do not have ``P + 1`` / ``P`` entries.
    """
    if rng is None:
        rng = np.random.default_rng()

    widths = np.diff(np.asarray(edges, dtype=np.float64))
    heights = np.asarray(heights, dtype=np.float64)
    if widths.shape != (P,) or heights.shape != (P,):
        raise ValueError(
            f'expected {P + 1} edges and {P} heights, got '
            f'{widths.size + 1} edges and {heights.size} heights'
        )

    def _inverse_softplus(y):
        # Clamp so the logarithm below stays finite for zero/negative targets.
        y = np.maximum(y, 1e-6)
        # log(exp(y) - 1) == y + log(1 - exp(-y)); written this way exp never
        # overflows, so large targets map to finite parameters.
        return y + np.log(-np.expm1(-y))

    # Only the proportions of softplus(gamma) matter because the widths are
    # renormalized; the factor 2P puts the average target near 1.
    gamma = _inverse_softplus(widths * 2 * P) + rng.normal(0, noise, size=P)

    # Likewise softplus(eta) only needs to be proportional to the heights;
    # the integral constraint is restored by normalization.
    eta = _inverse_softplus(heights) + rng.normal(0, noise, size=P)

    return np.concatenate([gamma, eta])


def run_single_restart(theta0, P, beta_schedule, lam=1.0, maxiter_per_beta=500,
                       maxfun=None):
    """Run one full LSE continuation from ``theta0``.

    For each ``beta`` in ``beta_schedule`` (in order) the penalized LSE
    objective is minimized with L-BFGS-B, starting from the result of the
    previous stage.

    Parameters
    ----------
    theta0 : numpy.ndarray
        Initial parameter vector ``(gamma, eta)``; it is copied, not modified.
    P : int
        Number of bins.
    beta_schedule : iterable of float
        Sharpness values for the successive stages.
    lam : float
        Weight of the width-ratio penalty.
    maxiter_per_beta : int
        ``maxiter`` passed to L-BFGS-B for every stage.
    maxfun : int or None
        ``maxfun`` passed to L-BFGS-B for every stage.  ``None`` leaves
        SciPy's default in place.  No gradient is supplied, so SciPy uses
        finite differences and those evaluations count toward this limit;
        with many parameters a stage can therefore stop well before
        ``maxiter_per_beta`` iterations unless ``maxfun`` is raised.

    Returns
    -------
    tuple
        ``(exact_peak, theta, width_ratio)`` for the final parameters.
    """
    theta = theta0.copy()

    options = {'maxiter': maxiter_per_beta, 'ftol': 1e-12, 'gtol': 1e-8}
    if maxfun is not None:
        options['maxfun'] = maxfun

    for beta in beta_schedule:
        obj = make_objective(P, beta, lam)
        # Pass a copy so every stage starts from the same option set.
        res = minimize(obj, theta, method='L-BFGS-B', options=dict(options))
        theta = res.x

    _, _, _, exact, w_ratio = theta_to_solution(theta, P)
    return exact, theta, w_ratio


def run_optimization(P, n_restarts=30, n_jobs=-1, warm_edges=None, warm_heights=None,
                     beta_schedule=None, lam=1.0, maxiter_per_beta=500, verbose=True,
                     seed=42):
    """Run the multi-start joint edge/height optimization.

    Parameters
    ----------
    P : int
        Number of bins.
    n_restarts : int
        Total number of independent restarts.
    n_jobs : int
        Passed to ``joblib.Parallel``.
    warm_edges, warm_heights : array-like or None
        Known solution used for warm starts.  Warm starting is only enabled
        when both are given; if either is ``None`` all restarts are random.
    beta_schedule : sequence of float or None
        Sharpness values for the continuation; a 15-stage schedule from 1 to
        2000 is used when ``None``.
    lam : float
        Weight of the width-ratio penalty.
    maxiter_per_beta : int
        L-BFGS-B iteration limit per continuation stage.
    verbose : bool
        Print progress and summary statistics.
    seed : int or None
        Seed for the generator that draws all initial points.  The default 42
        reproduces the previously hard-coded value.

    Returns
    -------
    tuple
        ``(best_val, best_theta, all_vals)``: the lowest exact peak, the
        corresponding parameter vector (``None`` if no restart produced a
        value below infinity), and the list of exact peaks of all restarts.
    """
    if beta_schedule is None:
        beta_schedule = [1, 2, 4, 8, 15, 30, 60, 100, 150, 250, 400, 600, 1000, 1500, 2000]

    beta_arr = np.array(beta_schedule, dtype=np.float64)
    rng = np.random.default_rng(seed)

    # Build initializations
    inits = []
    n_warm = 0
    if warm_edges is not None and warm_heights is not None:
        # About one third of the restarts are warm-started from the known
        # solution: at least one, but never more than n_restarts so that the
        # list of starting points matches the number of jobs submitted below.
        n_warm = min(n_restarts, max(1, n_restarts // 3))
        for i in range(n_warm):
            noise = 0.005 * (i + 1) / n_warm  # gradually increasing noise
            inits.append(init_theta_from_solution(P, warm_edges, warm_heights,
                                                   noise=noise, rng=rng))
    # Rest are random
    for _ in range(n_restarts - n_warm):
        inits.append(init_theta_uniform(P, noise_gamma=0.1, rng=rng))

    t0 = time.time()
    if verbose:
        print(f'Running {n_restarts} restarts (P={P}, {n_warm} warm-started)...')

    results = Parallel(n_jobs=n_jobs, verbose=0)(
        delayed(run_single_restart)(theta0, P, beta_arr, lam, maxiter_per_beta)
        for theta0 in inits
    )

    best_val = np.inf
    best_theta = None
    all_vals = []
    for i, (val, theta, w_ratio) in enumerate(results):
        all_vals.append(val)
        if val < best_val:
            best_val = val
            best_theta = theta.copy()
            if verbose:
                print(f'  Restart {i:>3}: peak={val:.6f}, w_ratio={w_ratio:.1f}  <-- best')
        elif verbose and i % 10 == 0:
            # Periodic progress line for restarts that did not improve.
            print(f'  Restart {i:>3}: peak={val:.6f}, w_ratio={w_ratio:.1f}')

    if verbose:
        elapsed = time.time() - t0
        arr = np.array(all_vals)
        print(f'\nDone in {elapsed:.1f}s. Best={best_val:.6f}, '
              f'median={np.median(arr):.6f}, std={np.std(arr):.6f}')

    return best_val, best_theta, all_vals

## Validation at P=50

In [None]:
# Validation run: 50 bins, 30 restarts, all from random initial points (no
# warm-start solution is passed), 300 L-BFGS-B iterations per beta stage.
P = 50
best_val, best_theta, all_vals = run_optimization(
    P, n_restarts=30, n_jobs=-1,
    beta_schedule=[1, 2, 4, 8, 15, 30, 60, 100, 150, 250, 400, 600, 1000, 1500, 2000],
    maxiter_per_beta=300
)

# Decode the best parameter vector and report its exact peak and width ratio.
edges, heights, widths, exact, w_ratio = theta_to_solution(best_theta, P)
print(f'\nP={P} result:')
print(f'  Exact peak:  {exact:.6f}')
print(f'  Width ratio: {w_ratio:.2f}')
# The baseline figure below is a hard-coded reference, not computed here.
print(f'  (Uniform grid baseline ~1.522)')

## Scale to P=100

In [None]:
# Same pipeline at 100 bins: 40 random restarts, 400 iterations per beta
# stage.  Note that P is rebound here; results are kept in *_100 variables.
P = 100
best_val_100, best_theta_100, all_vals_100 = run_optimization(
    P, n_restarts=40, n_jobs=-1,
    beta_schedule=[1, 2, 4, 8, 15, 30, 60, 100, 150, 250, 400, 600, 1000, 1500, 2000],
    maxiter_per_beta=400
)

# Decode the best parameter vector and report its exact peak and width ratio.
edges_100, heights_100, widths_100, exact_100, w_ratio_100 = theta_to_solution(best_theta_100, P)
print(f'\nP={P} result:')
print(f'  Exact peak:  {exact_100:.6f}')
print(f'  Width ratio: {w_ratio_100:.2f}')

## Scale to P=200 with warm start from best known solution

In [None]:
# Load best known P=200 solution for warm starting.
# The file is read from the current working directory and must contain a
# 'heavy_P200' entry with 'edges', 'heights' and 'exact_peak' fields.
with open('best_solutions.json', 'r') as f:
    best_solutions = json.load(f)

sol200 = best_solutions['heavy_P200']
warm_edges_200 = np.array(sol200['edges'])
warm_heights_200 = np.array(sol200['heights'])
print(f'Warm start from P=200 uniform solution: peak={sol200["exact_peak"]:.6f}')

# 200 bins, 50 restarts (part of them warm-started from the loaded solution),
# with a finer 21-stage beta schedule that runs up to 3000.
P = 200
best_val_200, best_theta_200, all_vals_200 = run_optimization(
    P, n_restarts=50, n_jobs=-1,
    warm_edges=warm_edges_200, warm_heights=warm_heights_200,
    beta_schedule=[1, 1.5, 2, 3, 5, 8, 12, 18, 28, 42, 65, 100, 150, 230, 350,
                   500, 750, 1000, 1500, 2000, 3000],
    maxiter_per_beta=500, lam=1.0
)

edges_200, heights_200, widths_200, exact_200, w_ratio_200 = theta_to_solution(best_theta_200, P)
print(f'\nP={P} result:')
print(f'  Exact peak:   {exact_200:.6f}')
print(f'  Width ratio:  {w_ratio_200:.2f}')
print(f'  Baseline:     {sol200["exact_peak"]:.6f}')
# Positive improvement means the new peak is lower than the loaded baseline.
print(f'  Improvement:  {sol200["exact_peak"] - exact_200:.6f}')

## Save results

In [None]:
# Collect the best solution found for each P (50, 100, 200) into one dict;
# every entry is saved, not only the overall best.
results_all = {}

for label, theta, p_val in [
    ('P50', best_theta, 50),
    ('P100', best_theta_100, 100),
    ('P200', best_theta_200, 200),
]:
    # Re-decode from theta so the stored values match the stored parameters.
    e, h, w, ex, wr = theta_to_solution(theta, p_val)
    results_all[label] = {
        'P': p_val,
        # float()/tolist() convert NumPy types into JSON-serializable ones.
        'exact_peak': float(ex),
        'width_ratio': float(wr),
        'edges': e.tolist(),
        'heights': h.tolist(),
    }
    print(f'{label}: peak={ex:.6f}, w_ratio={wr:.2f}')

# Find overall best (lowest exact peak); it is only reported, not used to
# filter what gets written.
best_label = min(results_all, key=lambda k: results_all[k]['exact_peak'])
print(f'\nBest overall: {best_label} -> {results_all[best_label]["exact_peak"]:.6f}')

# Written to the current working directory, overwriting any existing file.
out_path = 'joint_optimization_results.json'
with open(out_path, 'w') as f:
    json.dump(results_all, f, indent=2)
print(f'Saved to {out_path}')