In [7]:
import jax.numpy as jnp
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_hyperplane

# Objective function
# Default problem data; `objective` falls back to these when no override is given.
eign_values = jnp.array([0.5, 0.6])
lam = 2
def objective(params, eigs=None, lam_val=None):
    """Return sum_i e_i / (1 + (2 / lam) * e_i * exp(params_i))**2.

    Args:
        params: array of parameters; it is exponentiated elementwise, so it
            must broadcast against the eigenvalue array.
        eigs: optional eigenvalue array. When None, the module-level
            `eign_values` is used (looked up at call time, as before).
        lam_val: optional scalar replacing the module-level `lam`. When None,
            `lam` is used (looked up at call time, as before).

    Returns:
        Scalar JAX array holding the summed objective.

    Calling `objective(params)` with a single argument behaves exactly like
    the original closure over the globals; the keyword overrides only make it
    possible to evaluate other settings without rebinding those globals.
    """
    e = eign_values if eigs is None else eigs
    # Same evaluation order as `2/lam*e*exp(params)`: the scalar factor first.
    scale = 2 / (lam if lam_val is None else lam_val)
    return jnp.sum(e / (1 + scale * e * jnp.exp(params))**2)

# Constraint: x1 + x2 = 1, represented as a hyperplane
def projection_fn(params, hyperparams_proj):
    """Project `params` with jaxopt's `projection_hyperplane`.

    `hyperparams_proj` is passed as the first element of the hyperplane
    description and the second element is fixed to 1, i.e. the hyperplane is
    described by the pair (hyperparams_proj, 1).
    """
    offset = 1
    hyperplane = (hyperparams_proj, offset)
    return projection_hyperplane(params, hyperparams=hyperplane)

# Build the projected-gradient solver for `objective`, using `projection_fn`
# as the projection step.
solver = ProjectedGradient(projection=projection_fn, fun=objective)

# Starting point of the iteration
params_init = jnp.array([0.5, 0.5])

# Run the solver. `hyperparams_proj` is the second argument that
# `projection_fn` expects (the first element of its hyperplane description).
sol = solver.run(
    params_init,
    hyperparams_proj=jnp.array([1, 1]),
)

# Check the solution: with s = exp(params), evaluate per coordinate
#   o^2 * s / (1 + 2/lam * o * s)^3
# which is the partial derivative of `objective` up to a constant factor
# (-4/lam). Printed for visual inspection only; nothing is asserted.
o = eign_values
s = jnp.exp(sol.params)
cc = (o**2 * s) / (1 + 2/lam*o*s)**3
print('cc', cc)

# Report the parameters returned by the solver
print("Optimal Solution: ", sol.params)

cc [0.07055375 0.07050372]
Optimal Solution:  [0.37525702 0.6247432 ]


In [8]:
import jax.numpy as jnp
# Evaluate the individual (unsummed) terms e / (1 + 2/lam * e * exp(p))**2
# for lam = 0.01; the last expression is displayed as the cell output.
lam = 0.01
eign_values = jnp.array([0.1, 0.2])
params = jnp.array([0.5, 0.5])
scaled_exp = 2/lam*eign_values*jnp.exp(params)
eign_values / (1 + scaled_exp)**2

Array([8.6635475e-05, 4.4621454e-05], dtype=float32)