In [2]:
import jax
import jax.numpy as jnp
import numpy as np
    
jax.config.update("jax_enable_x64", True)

In [3]:
# https://www.hello-statisticians.com/optimization/optimization5.html
# https://www.hello-statisticians.com/optimization/optimization4.html

@jax.jit
def objective(x:jnp.ndarray) -> float:
    
    return (x[0] - 1)**2 + (x[1] - 2)**2

@jax.jit
def constraints(x:jnp.ndarray) -> jnp.ndarray:

    return jnp.asarray([
        (x[0]**2 + x[1]**2 - 2),
        (-x[0] + x[1]),
        (-x[1])
        ])

@jax.jit
def lagrange(x:jnp.ndarray) -> float:

    
    return objective(x[:-3]) + x[-3:] @ constraints(x[:-3])

In [146]:
# Newton method
x = jnp.zeros(2)
lr = 0.1
epochs = 100
newton_iter = 100
alpha = 0.01

# set initial value
for i in range(epochs):

    x_old = x.copy()
    x = x - lr * jax.grad(objective)(x)

    if jnp.linalg.norm(x - x_old) < 1e-6:
        print(f"Converged after {i} iterations.: {x}\n")
        break

x_lambda = jnp.concatenate((x, jnp.ones(3)))
J = jax.jacobian(lagrange, argnums=0)
H = jax.hessian(lagrange, argnums=0)

for i in range(newton_iter):

    dx = jnp.linalg.inv(H(x_lambda)) @ J(x_lambda)

    if any(jnp.isnan(dx)):
        break
    
    x_lambda = x_lambda - alpha * dx

    print(f"Iteration {i}: x = {x_lambda[:-3]}, constraints = {constraints(x_lambda[:-3])}, objective = {objective(x_lambda[:-3])}")



Converged after 59 iterations.: [0.99999847 1.99999694]

Iteration 0: x = [1.2001163  1.12023456], constraints = [ 0.69520459 -0.07988174 -1.12023456], objective = 0.8140337686998983
Iteration 1: x = [1.18618927 1.10897396], constraints = [ 0.63686823 -0.07721531 -1.10897396], objective = 0.8285938505142163
Iteration 2: x = [1.16997149 1.09755232], constraints = [ 0.5734544  -0.07241917 -1.09755232], objective = 0.8433021145854976
Iteration 3: x = [1.1625937  1.09107278], constraints = [ 0.54206394 -0.07152092 -1.09107278], objective = 0.8525854018966292
Iteration 4: x = [1.0926595  1.05426158], constraints = [ 0.30537225 -0.03839792 -1.05426158], objective = 0.9030069411562056
Iteration 5: x = [1.11789764 1.11545657], constraints = [ 0.49393851 -0.00244107 -1.11545657], objective = 0.7963169316754178
Iteration 6: x = [1.11493987 1.10194745], constraints = [ 0.4573791  -0.01299241 -1.10194745], objective = 0.8197095511510271
Iteration 7: x = [1.08650317 1.08209988], constraints = [ 0.3

In [17]:
newton_iter = 100
alpha = 0.01

x_lambda = jax.random.uniform(jax.random.PRNGKey(11668), (5,))
J = jax.jacobian(lagrange, argnums=0)
H = jax.hessian(lagrange, argnums=0)

for i in range(newton_iter):

    dx = jnp.linalg.inv(H(x_lambda)) @ J(x_lambda)

    if any(jnp.isnan(dx)):
        print("NaN")
        break
    
    x_lambda = x_lambda - alpha * dx

    print(f"Iteration {i}: x = {x_lambda[:-3]}, constraints = {constraints(x_lambda[:-3])}, objective = {objective(x_lambda[:-3])}")

Iteration 0: x = [0.13578914 0.33314748], constraints = [-1.87057407  0.19735834 -0.33314748], objective = 3.525257736488507
Iteration 1: x = [0.20054193 0.35722084], constraints = [-1.8321762   0.15667891 -0.35722084], objective = 3.337856561799195
Iteration 2: x = [0.22946829 0.38104721], constraints = [-1.80214733  0.15157892 -0.38104721], objective = 3.214727247791555
Iteration 3: x = [0.25349235 0.36680616], constraints = [-1.80119487  0.11331381 -0.36680616], objective = 3.2245957772115545
Iteration 4: x = [0.20771193 0.36061666], constraints = [-1.82681138  0.15290473 -0.36061666], objective = 3.3152981193831836
Iteration 5: x = [0.34795582 0.42051656], constraints = [-1.70209257  0.07256075 -0.42051656], objective = 2.9199295518365926
Iteration 6: x = [0.43221553 0.43458191], constraints = [-1.6243283   0.00236637 -0.43458191], objective = 2.7729130037782754
Iteration 7: x = [0.66269617 0.23051618], constraints = [-1.50769607 -0.43217999 -0.23051618], objective = 3.244846854004

In [None]:
# Full Gradient Descent
x = jnp.zeros(2)
lr = 0.1
epochs = 100
newton_iter = 100
alpha = 0.1

history = []

# set initial value
for i in range(epochs):

    x_old = x.copy()
    x = x - lr * jax.grad(objective)(x)

    if jnp.linalg.norm(x - x_old) < 1e-6:
        print(f"Converged after {i} iterations.: {x}\n")
        break

@jax.jit
def lagrange_loss(x_lambda:jnp.ndarray) -> float:

    return jnp.linalg.norm(jax.jacobian(lagrange)(x_lambda))

x_lambda = jnp.concatenate((x, jnp.ones(3)))
grad_fn = jax.grad(lagrange_loss)

for i in range(newton_iter):

    x_lambda_old = x_lambda.copy()

    x_lambda = x_lambda - alpha * grad_fn(x_lambda)

    if any(jnp.isnan(x_lambda)):
        print(f"NaN encountered at iteration {i}: {x_lambda}")
        print(f"{x_lambda_old}")
        break
    
    if (i+1) % (newton_iter // 10) == 0:
        print(f"Iteration {i+1}: x = {x_lambda[:-3]}, constraints = {constraints(x_lambda[:-3])}, objective = {objective(x_lambda)}")

    if jnp.linalg.norm(x_lambda - x_lambda_old) < 1e-6:
        print(f"Converged after {i} iterations.: {x_lambda}, objective = {objective(x_lambda)}\n")
        break

Converged after 59 iterations.: [0.99999847 1.99999694]

Iteration 10: x = [0.80278405 0.88497999], constraints = [-0.57234819  0.08219594 -0.88497999], objective = 1.2821637519840117
Iteration 20: x = [0.84416158 0.81399162], constraints = [-0.62480887 -0.03016996 -0.81399162], objective = 1.4309014875230317
Iteration 30: x = [0.89007062 0.76998225], constraints = [-0.61490163 -0.12008836 -0.76998225], objective = 1.5250281251127131
Iteration 40: x = [0.9263145  0.73423043], constraints = [-0.60284712 -0.19208408 -0.73423043], objective = 1.607602159601633
Iteration 50: x = [0.95272283 0.70564037], constraints = [-0.59439088 -0.24708246 -0.70564037], objective = 1.6776019795417627
Iteration 60: x = [0.97153163 0.68380781], constraints = [-0.58853318 -0.28772382 -0.68380781], objective = 1.733172337896585
Iteration 70: x = [0.98491092 0.66771927], constraints = [-0.58410146 -0.31719165 -0.66771927], objective = 1.7751996262817813
Iteration 80: x = [0.99450768 0.65616436], constraints =

In [None]:
# https://tm23forest.com/contents/bfgs-formula-quasi-newton-method-scipy-motivated

newton_iter = 20
alpha = 1.0

dobjdx = jax.jacobian(objective, argnums=0)
dconstdx = jax.jacobian(constraints, argnums=0)

def f(arr:jnp.ndarray) -> jnp.array:

    x, lambda_ = arr[:2], arr[2:]

    return jnp.block([dobjdx(x) + lambda_ @ dconstdx(x),
                      constraints(x)
                      ])

# find x, lambda s.t. f(x, lambda) = 0
arr = jax.random.uniform(jax.random.PRNGKey(0), (5,))

dfdarr = jax.jacobian(f, argnums=0) # 5x5 matrix
B = jnp.eye(len(arr))   # Hessian approximation, 5x5 matrix
I = jnp.eye(len(arr))   # Identity matrix, 5x5 matrix

c1 = 1e-4
alphas = jnp.full((len(arr), ), 1.0)    # 1x5 vector

for k in range(newton_iter):

    grad = dfdarr(arr)

    if jnp.linalg.norm(grad) < 0.001: break

    p = -B @ grad   # p: 5x5 matrix

    gammas = alphas # gammas: 1x5 vector

    while all(f(arr + gammas @ p) > f(arr) + c1 * gammas @ grad @ p): gammas *= 0.9

    s = gammas @ p  # s: 1x5 vector

    if any(jnp.isnan(s)):
        print(gammas)
        print(rho)
        print(grad)
        print(p)
        print(y)
        print(B)
        break

    arr = arr + s
    y = dfdarr(arr) - grad

    rho = 1 / (y.T@s)

    B = (I - rho * s@y.T) @ B @ (I - rho * y @ s.T) + rho * s @ s.T

    print(f"Iteration {k}: x = {arr[:2]}, constraints = {constraints(arr[:2])}, objective = {objective(arr[:2])}")


Iteration 0: x = [-3.34910003 -4.14693838], constraints = [26.41356894 -0.79783834  4.14693838], objective = 56.69952251405348
Iteration 1: x = [3465.74939546 3832.03356846], constraints = [ 2.66958981e+07  3.66284173e+02 -3.83203357e+03], objective = 26673645.50884811
Iteration 2: x = [-3.00564560e+12 -3.10388655e+12], constraints = [ 1.86680172e+25 -9.82409489e+10  3.10388655e+12], objective = 1.8668017194913408e+25
Iteration 3: x = [8.89519322e+31 9.07137000e+31], constraints = [ 1.61414216e+64  1.76176789e+30 -9.07137000e+31], objective = 1.6141421609558777e+64
Iteration 4: x = [-1.59604661e+73 -1.61021110e+73], constraints = [ 5.14014459e+146 -1.41644913e+071  1.61021110e+073], objective = 5.1401445944796296e+146
Iteration 5: x = [4.91355862e+157 4.93644861e+157], constraints = [             inf  2.28899885e+155 -4.93644861e+157], objective = inf
Iteration 6: x = [-inf -inf], constraints = [inf nan inf], objective = inf
[1. 1. 1. 1. 1.]
[nan nan nan nan nan]
[[nan nan nan nan nan]

In [None]:
# Differential Multiplier
# https://github.com/crowsonkb/mdmm-jax/tree/master
@jax.jit
def eq_constraints(x:jnp.ndarray) -> jnp.ndarray:
    return jnp.asarray([
        (x[0]**2 + x[1]**2 - 2 - x[2]**2),
        (-x[0] + x[1] - x[3]**2),
        (-x[1] - x[4]**2)
    ])

x = jax.random.uniform(jax.random.PRNGKey(42), (5,))
lambda_ = jax.random.uniform(jax.random.PRNGKey(42), (3,))

epochs = 1000
lr = 0.01
c = 10

dfdx = jax.jacobian(objective, argnums=0)
dgdx = jax.jacobian(eq_constraints, argnums=0)

for i in range(epochs):

    dLdx = -dfdx(x) - lambda_ @ dgdx(x) - c * eq_constraints(x) @ dgdx(x)
    dLdlambda = eq_constraints(x)

    if any(jnp.isnan(dLdx)) or any(jnp.isnan(dLdlambda)):
        print(f"Iteration {i+1}: x = {x[:-3]}, constraints = {eq_constraints(x[:-3])}, objective = {objective(x[:-3])}")
        break

    x = x + lr * dLdx
    lambda_ = lambda_ + lr * dLdlambda

    if (i+1) % (epochs // 10) == 0:
        print(f"Iteration {i+1}: x = {x[:-3]}, constraints = {eq_constraints(x[:-3])}, objective = {objective(x[:-3])}")

Iteration 100: x = [1.09110191 0.79647743], constraints = [-0.80949662 -0.92900078 -1.43085374], objective = 1.4567661249364345
Iteration 200: x = [1.09589966 0.78889083], constraints = [-0.79900393 -0.92935757 -1.41123957], objective = 1.4759821666745985
Iteration 300: x = [1.10032121 0.78183848], constraints = [-0.78929323 -0.92975414 -1.39310988], objective = 1.4939818399995595
Iteration 400: x = [1.10438774 0.77528842], constraints = [-0.78032772 -0.93017145 -1.37636054], objective = 1.5108152660910976
Iteration 500: x = [1.10812456 0.76921354], constraints = [-0.77205996 -0.93060049 -1.36090301], objective = 1.5265262349623514
Iteration 600: x = [1.11155532 0.76358755], constraints = [-0.76444478 -0.93103371 -1.3466535 ], objective = 1.5411603336816802
Iteration 700: x = [1.11470207 0.75838499], constraints = [-0.7574393  -0.93146487 -1.33353279], objective = 1.5547643900354056
Iteration 800: x = [1.11758541 0.75358127], constraints = [-0.75100285 -0.93188887 -1.321466  ], objecti

In [19]:
import jax
import jax.numpy as jnp
from jaxopt import LBFGS
from functools import partial

@jax.jit
def objective(x: jnp.ndarray) -> jnp.ndarray:
    """目的関数 f(x)."""
    return (x[0] - 1.0)**2 + (x[1] - 2.0)**2

@jax.jit
def constraints(x: jnp.ndarray) -> jnp.ndarray:
    """不等式制約 g(x) ≤ 0 をベクトルで返す."""
    return jnp.array([
        x[0]**2 + x[1]**2 - 2.0,  # g1(x) ≤ 0
        -x[0] + x[1],             # g2(x) ≤ 0
        -x[1]                     # g3(x) ≤ 0
    ])

@partial(jax.jit, static_argnums=(3, 4, 5))
def solve_augmented_lagrangian(x0, λ0, μ0, num_outer, num_inner, μ_factor):
    x = x0
    λ = λ0
    μ = μ0

    for _ in range(num_outer):
        # Augmented Lagrangian の定義
        def aug_obj(x):
            g = constraints(x)
            # max(g, -λ/μ) で「ペナルティをかける部分」を構築
            penalty_term = jnp.sum(jnp.maximum(g, -λ/μ)**2)
            return objective(x) + λ @ g + (μ * 0.5) * penalty_term

        # L-BFGS を用いて内部最適化
        solver = LBFGS(fun=aug_obj, maxiter=num_inner)
        sol = solver.run(x)
        x = sol.params

        # multipliers 更新
        g = constraints(x)
        λ = jnp.maximum(λ + μ * g, 0.0)

        # penalty weight を増加
        μ = μ * μ_factor

    return x, λ, μ

# 初期値
x0       = jnp.array([0.0, 0.0])
λ0       = jnp.zeros(3)
μ0       = 1.0
num_outer = 10    # 外側ループ回数
num_inner = 100   # 内部 L-BFGS の反復数
μ_factor  = 10.0  # μ の増大係数

x_opt, λ_opt, μ_opt = solve_augmented_lagrangian(
    x0, λ0, μ0, num_outer, num_inner, μ_factor
)

print("最適解 x* =", x_opt)
print("乗数 λ*  =", λ_opt)
print("最終 μ    =", μ_opt)

最適解 x* = [1. 1.]
乗数 λ*  = [0.50000004 1.00000008 0.        ]
最終 μ    = 10000000000.0
