# Gradient Descent for Unconstrained Optimization

In [1]:
import jax
import jax.numpy as jnp

from chex import Array
from jax import grad, jit


# Problem data: a symmetric 2x2 matrix and a right-hand-side vector.
A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def f(x: Array) -> Array:
    """Evaluate the quadratic objective ``0.5 * x^T A x - b^T x``.

    Reads the module-level ``A`` and ``b``. Because this ``A`` is symmetric
    with positive determinant and trace, the unique minimizer solves
    ``A x = b``.
    """
    a_times_x = jnp.dot(A, x)
    quadratic_term = 0.5 * jnp.dot(x, a_times_x)
    linear_term = jnp.dot(b, x)
    return quadratic_term - linear_term

In [2]:
# Differentiate f with JAX's reverse-mode autodiff, then JIT-compile the
# resulting gradient so repeated calls in the descent loop skip re-tracing.
grad_f = jax.jit(jax.grad(f))

In [3]:
learning_rate = 0.1
max_iters = 100
tol = 1e-6

# Initial guess
x = jnp.zeros_like(b)

# Plain gradient descent: x <- x - learning_rate * grad f(x).
# A fixed step of 0.1 is stable for this A because it is below
# 2 / lambda_max(A) (lambda_max ~= 4.62, so the bound is ~0.43).
for i in range(max_iters):
    grad_value = grad_f(x)
    x_new = x - learning_rate * grad_value

    # Stop once the step length drops below tol; x keeps the iterate from
    # before this final, negligible step.
    if jnp.linalg.norm(x_new - x) < tol:
        print(f"Converged after {i} iterations.")
        break

    x = x_new
else:
    # Reached only if the loop never hit `break`: warn instead of silently
    # presenting an unconverged iterate as the optimum.
    print(f"Did not converge within {max_iters} iterations.")

print("Optimal solution:", x)
print("Function value at optimal solution:", f(x))
# Converged after 43 iterations.
# Optimal solution: [0.09091125 0.6363601 ]
# Function value at optimal solution: -0.6818182

Converged after 43 iterations.
Optimal solution: [0.09091125 0.6363601 ]
Function value at optimal solution: -0.6818182


# Optimization with Constraints: Penalty Method

In [4]:
import jax
import jax.numpy as jnp
from chex import Array
from jax import grad, jit
import matplotlib.pyplot as plt


# Problem data: a symmetric 2x2 matrix and a right-hand-side vector.
A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def f(x: Array) -> Array:
    """Evaluate the quadratic objective ``0.5 * x^T A x - b^T x``.

    Reads the module-level ``A`` and ``b``; this is the unconstrained
    objective that the penalty method below augments.
    """
    a_times_x = jnp.dot(A, x)
    quadratic_term = 0.5 * jnp.dot(x, a_times_x)
    linear_term = jnp.dot(b, x)
    return quadratic_term - linear_term


def g(x: Array) -> Array:
    """Residual of the equality constraint ``x1 + x2 = 1``.

    Zero exactly when the first two components of ``x`` sum to one;
    its sign tells which side of the line the point lies on.
    """
    first, second = x[0], x[1]
    return (first + second) - 1.0


def h(x: Array) -> Array:
    """Inequality constraint ``x1 >= 0.2`` written in ``h(x) <= 0`` form.

    Returns ``0.2 - x1``: non-positive when the point is feasible and
    positive (by the amount of violation) otherwise.
    """
    lower_bound = 0.2
    return lower_bound - x[0]

In [5]:
def f_rho(x: Array, rho: float) -> Array:
    """Quadratic-penalty objective for the constrained problem.

    Returns ``f(x) + rho * (g(x)**2 + max(0, h(x))**2)``. The equality
    residual is always penalized; the inequality term vanishes whenever
    ``h(x) <= 0`` (i.e. the inequality is satisfied).
    """
    squared_eq_residual = g(x) ** 2
    squared_ineq_excess = jnp.maximum(0, h(x)) ** 2
    violation = squared_eq_residual + squared_ineq_excess
    objective = f(x)
    return objective + rho * violation


@jit
def grad_f_rho(x: Array, rho: float) -> Array:
    """Gradient of ``f_rho`` with respect to ``x``, holding ``rho`` fixed.

    ``rho`` is passed through as an ordinary (traced) argument of the
    jitted function rather than being differentiated.
    """
    return grad(f_rho, argnums=0)(x, rho)

In [6]:
learning_rate = 0.01  # Smaller learning rate for stability
max_iters = 1000
tol = 1e-8
rho = 10.0
# NOTE: with a fixed, finite rho the quadratic penalty only approximately
# enforces the constraints; the residual typically shrinks as rho grows.


x = jnp.array([0.3, 0.7])  # Initial guess

# Track whether the stopping test actually fired, so convergence is only
# reported when it happened (and `i` is never read if max_iters == 0).
converged = False
for i in range(max_iters):
    grad_value = grad_f_rho(x, rho)
    x_new = x - learning_rate * grad_value

    # Step-length stopping test. Assuming JAX's default float32, tol=1e-8 is
    # below the spacing of representable values near x, so in practice this
    # fires once the update rounds to zero.
    if jnp.linalg.norm(x_new - x) < tol:
        converged = True
        break
    x = x_new

# Calculate final metrics
eq_violation = abs(float(g(x)))
ineq_violation = max(0, float(h(x)))

if converged:
    print(f"Converged after {i} iterations with rho = {rho}.")
else:
    print(f"Did not converge within {max_iters} iterations with rho = {rho}.")
print(f"Optimal solution: {x}")
print(f"Function value at optimal solution: {f(x):.2f}")
print(f"Equality constraint violation: {eq_violation:.2f}")
print(f"Inequality constraint violation: {ineq_violation:.2f}")


Converged after 181 iterations with rho = 10.0.
Optimal solution: [0.19789854 0.77583164]
Function value at optimal solution: -0.61
Equality constraint violation: 0.03
Inequality constraint violation: 0.00
