In [56]:
from collections.abc import Callable
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import optax
import plotly.graph_objects as go
from jax_tqdm.loop_pbar import loop_tqdm

# Loss value recorded at each optimizer iteration (one slot per iteration).
LossHistory = jax.Array
# Carry threaded through `jax.lax.fori_loop`: (params, optimizer state, loss history).
StepState = tuple[optax.Params, optax.OptState, LossHistory]
# Scalar objective evaluated on the parameters.
LossFn = Callable[[optax.Params], jax.Array]
# Returns (loss, gradient) for the given parameters, e.g. `jax.value_and_grad(loss_fn)`.
ValueAndGradFn = Callable[[optax.Params], tuple[jax.Array, jax.Array]]
# Body of `jax.lax.fori_loop`: (iteration index, carry) -> new carry.
StepFn = Callable[[int, StepState], StepState]


@dataclass(frozen=True)
class Config:
    """Immutable settings for the triangle-fitting optimization.

    Fields cover grid resolution (``N``), optimizer budget (``MAX_ITERS``),
    the coordinate range the grid spans (``X_MIN``/``X_MAX``) and the weights
    of the two loss terms (``LAMBDA_LINEAR``/``LAMBDA_PEAK``).

    Raises:
        ValueError: if ``N < 3`` (the loss splits the N - 1 slopes at ``N // 2``
            and needs both halves non-empty, otherwise ``jnp.var`` yields NaN)
            or if ``MAX_ITERS < 1`` (the loss history would be empty).
    """

    # Number of samples in the signal / coordinate grid.
    N: int = 1000
    # Number of optimizer iterations; also the length of the loss history.
    MAX_ITERS: int = 200
    # Endpoints (inclusive) of the coordinate grid used for the initial state and plot.
    X_MIN: float = -10.0
    X_MAX: float = 10.0
    # Weight of the "each side is a straight line" loss term.
    LAMBDA_LINEAR: float = 1.0
    # Weight of the "value at the centre sample equals 1" loss term.
    LAMBDA_PEAK: float = 1.0

    def __post_init__(self) -> None:
        """Reject settings that make the downstream pipeline ill-defined."""
        if self.N < 3:
            raise ValueError(f"N must be >= 3 so both slope halves are non-empty, got {self.N}")
        if self.MAX_ITERS < 1:
            raise ValueError(f"MAX_ITERS must be >= 1, got {self.MAX_ITERS}")


def get_initial_state(config: Config) -> jax.Array:
    """Return the starting signal: ``sinc(x) ** 2`` sampled on the config grid.

    The grid has ``config.N`` evenly spaced points from ``config.X_MIN`` to
    ``config.X_MAX`` inclusive. ``jnp.sinc`` is the normalized sinc,
    ``sin(pi x) / (pi x)``, so the squared profile is 1 at ``x = 0``.
    """
    grid = jnp.linspace(config.X_MIN, config.X_MAX, config.N)
    profile = jnp.sinc(grid)
    return profile ** 2


def create_triangle_loss_fn(config: Config) -> LossFn:
    """Build a jitted loss that is minimized by a triangle peaking at sample ``N // 2``.

    The loss is ``LAMBDA_LINEAR * linearity + LAMBDA_PEAK * peak_err``:

    * ``linearity`` is the variance of the first differences to the left of
      the centre plus the variance to the right. It is zero when each side
      is a straight line.
    * ``peak_err`` is ``(s[N // 2] - 1) ** 2``.

    NOTE(review): the endpoints and the slope signs are not constrained, so
    e.g. a constant signal of 1 also reaches zero loss. Because ``linearity``
    is built from adjacent-sample differences, its magnitude depends on the
    grid spacing, and hence on ``N``.
    """
    peak_idx = config.N // 2

    def loss_fn(s: jax.Array) -> jax.Array:
        slopes = jnp.diff(s, n=1)
        # slopes[k] = s[k + 1] - s[k]; splitting at peak_idx puts s[peak_idx] at the junction.
        before_peak, after_peak = slopes[:peak_idx], slopes[peak_idx:]
        linearity = jnp.var(before_peak) + jnp.var(after_peak)
        peak_err = (s[peak_idx] - 1.0) ** 2
        return config.LAMBDA_LINEAR * linearity + config.LAMBDA_PEAK * peak_err

    return jax.jit(loss_fn)


def create_step_fn(
    solver: optax.GradientTransformationExtraArgs,
    loss_fn: LossFn,
    value_and_grad_fn: ValueAndGradFn,
    max_iters: int,
) -> StepFn:
    """Wrap one optimizer iteration as a ``jax.lax.fori_loop`` body with a progress bar.

    Args:
        solver: optax transformation whose ``update`` accepts the extra
            ``value``/``grad``/``value_fn`` keyword arguments (e.g. ``optax.lbfgs``).
        loss_fn: scalar objective, forwarded to the solver as ``value_fn``.
        value_and_grad_fn: returns ``(loss, gradient)`` for the given params.
        max_iters: total number of iterations shown by the progress bar; it
            should match the ``fori_loop`` upper bound.

    Returns:
        A jitted ``(i, (params, opt_state, loss_history)) -> same`` function.
        ``loss_history[i]`` receives the loss of the params *entering*
        iteration ``i``, i.e. before that iteration's update is applied.
    """

    @jax.jit
    @loop_tqdm(max_iters, desc="Optimizing")
    def step_fn(i: int, state: StepState) -> StepState:
        current, solver_state, history = state
        loss, grads = value_and_grad_fn(current)
        # Line-search solvers such as L-BFGS need the current value/gradient
        # and the objective itself, passed via the extra keyword arguments.
        updates, next_solver_state = solver.update(
            grads,
            solver_state,
            current,
            value=loss,
            grad=grads,
            value_fn=loss_fn,
        )
        next_params = optax.apply_updates(current, updates)
        next_history = history.at[i].set(loss)
        return next_params, next_solver_state, next_history

    return step_fn


def run_optimization(
    config: Config,
    s_initial: jax.Array,
    solver: optax.GradientTransformation,
    step_fn: StepFn,
) -> tuple[jax.Array, LossHistory]:
    """Run ``step_fn`` for ``config.MAX_ITERS`` iterations starting from ``s_initial``.

    Args:
        config: supplies the iteration count ``MAX_ITERS``.
        s_initial: starting parameters.
        solver: used only to build the initial optimizer state.
        step_fn: loop body, e.g. from ``create_step_fn``.

    Returns:
        ``(s_final, loss_history)``. ``loss_history`` starts as zeros of
        length ``config.MAX_ITERS`` and is filled in by ``step_fn``.
    """
    n_iters = config.MAX_ITERS
    carry = (s_initial, solver.init(s_initial), jnp.zeros(n_iters))

    s_final, _solver_state, history = jax.lax.fori_loop(0, n_iters, step_fn, carry)

    # JAX dispatches work asynchronously; wait so the result is actually computed on return.
    s_final.block_until_ready()
    return s_final, history


def plot_results(
    x_coords: jax.Array,
    s_initial: jax.Array,
    s_final: jax.Array,
    initial_loss: jax.Array,
    final_loss: jax.Array,
) -> None:
    """Display a plotly figure overlaying the initial and final signals.

    Each trace is labelled with its loss, formatted to 4 decimal places.
    Calls ``fig.show()`` and returns nothing.
    """
    fig = go.Figure()
    curves = (
        (s_initial, f"Initial - Loss: {initial_loss:.4f}"),
        (s_final, f"Final - Loss: {final_loss:.4f}"),
    )
    for y_values, label in curves:
        fig.add_trace(go.Scatter(x=x_coords, y=y_values, name=label))
    fig.update_layout(
        title="Optimization Results",
        xaxis_title="Coordinates",
        yaxis_title="Amplitude",
        legend_title="State",
    )
    fig.show()


def main() -> None:
    """Fit a triangle-shaped signal with L-BFGS, report the losses and plot the result.

    Starts from the squared-sinc profile built by ``get_initial_state`` and
    minimizes the loss from ``create_triangle_loss_fn`` for
    ``config.MAX_ITERS`` iterations. It then prints the initial and final
    loss and plots both signals.
    """
    config = Config()
    s_initial = get_initial_state(config)
    loss_fn = create_triangle_loss_fn(config)
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    solver = optax.lbfgs(memory_size=100)
    step_fn = create_step_fn(solver, loss_fn, value_and_grad_fn, config.MAX_ITERS)
    s_final, _ = run_optimization(config, s_initial, solver, step_fn)

    initial_loss = loss_fn(s_initial)
    # The step function stores each iteration's loss *before* applying that
    # iteration's update, so the last history entry belongs to the
    # second-to-last iterate. Evaluate the returned parameters directly so the
    # reported "final" loss really is the loss of `s_final`.
    final_loss = loss_fn(s_final)
    print(f"Initial Loss: {initial_loss:.6f}, Final Loss: {final_loss:.6f}")

    # Same grid that get_initial_state samples the signal on.
    x_coords = jnp.linspace(config.X_MIN, config.X_MAX, config.N)

    plot_results(x_coords, s_initial, s_final, initial_loss, final_loss)


# Script entry point. In a Jupyter kernel `__name__` is also "__main__", so this
# runs whenever the cell is executed.
if __name__ == "__main__":
    main()

Optimizing:   0%|          | 0/200 [00:00<?, ?it/s]

Initial Loss: 0.000098, Final Loss: 0.000007
