# Gradient Descent

In [2]:
import jax
import jax.numpy as jnp

from chex import Array
from jax import grad, jit


# Problem data: symmetric 2x2 matrix and linear-term vector of the quadratic.
A = jnp.array(
    [
        [4.0, 1.0],
        [1.0, 3.0],
    ]
)
b = jnp.array([1.0, 2.0])


def f(x: Array) -> Array:
    """Evaluate the quadratic objective 0.5 * x^T A x - b^T x at ``x``.

    Reads the module-level ``A`` and ``b``; returns a scalar array.
    """
    A_times_x = jnp.dot(A, x)
    quadratic_term = 0.5 * jnp.dot(x, A_times_x)
    linear_term = jnp.dot(b, x)
    return quadratic_term - linear_term

In [None]:
# Differentiate f with JAX autodiff and JIT-compile the resulting gradient function
grad_f = jax.jit(jax.grad(f))

In [None]:
# Fixed-step gradient descent settings
learning_rate = 0.1
max_iters = 100
tol = 1e-6

# Initial guess
x = jnp.zeros_like(b)

# Gradient descent loop
for i in range(max_iters):
    grad_value = grad_f(x)
    x_new = x - learning_rate * grad_value

    # Stop once a full step moves the iterate by less than tol
    if jnp.linalg.norm(x_new - x) < tol:
        print(f"Converged after {i} iterations.")
        break

    x = x_new
else:
    # for/else: reached only when the loop ran out of iterations without a
    # break, so the values printed below are not a converged solution.
    print(f"Did not converge within {max_iters} iterations; result may be inaccurate.")

print("Optimal solution:", x)
print("Function value at optimal solution:", f(x))
# Sanity check: the exact minimiser is A^{-1} b = [1/11, 7/11] ≈ [0.0909, 0.6364],
# where f = -15/22 ≈ -0.6818.

Converged after 43 iterations.
Optimal solution: [0.09091125 0.6363601 ]
Function value at optimal solution: -0.6818182


# Optimization with Constraints: Penalty Method

In [6]:
import jax
import jax.numpy as jnp
from chex import Array
from jax import grad, jit

# Objective data (same quadratic as in the unconstrained section)
A = jnp.array(
    [
        [4.0, 1.0],
        [1.0, 3.0],
    ]
)
b = jnp.array([1.0, 2.0])

def f(x: Array) -> Array:
    """Evaluate the quadratic objective 0.5 * x^T A x - b^T x at ``x``.

    Reads the module-level ``A`` and ``b``; returns a scalar array.
    """
    A_times_x = jnp.dot(A, x)
    quadratic_term = 0.5 * jnp.dot(x, A_times_x)
    linear_term = jnp.dot(b, x)
    return quadratic_term - linear_term

# Define constraints
def g(x: Array) -> Array:
    """Residual of the equality constraint x1 + x2 = 1 (zero when satisfied)."""
    first, second = x[0], x[1]
    return (first + second) - 1.0

def h(x: Array) -> Array:
    """Inequality constraint x1 >= 0.2 in the form h(x) = 0.2 - x1 <= 0.

    A positive return value is the amount by which x1 falls short of 0.2.
    """
    lower_bound = 0.2
    return lower_bound - x[0]

# Penalty function
def penalty(x: Array, rho: float) -> Array:
    """Penalised objective: f(x) plus rho times the squared constraint violations.

    The equality residual g(x) is always penalised; the inequality value h(x)
    contributes only where it is positive (i.e. where the constraint is violated).
    """
    equality_residual = g(x)
    inequality_excess = jnp.maximum(0, h(x))
    violation_cost = equality_residual**2 + inequality_excess**2
    return f(x) + rho * violation_cost

# Create gradient function for penalized objective
@jit
def grad_penalty(x: Array, rho: float) -> Array:
    """Gradient of ``penalty(x, rho)`` with respect to ``x`` for the given ``rho``."""
    # argnums=0 differentiates with respect to x only; rho is passed through unchanged.
    penalty_gradient = grad(penalty, argnums=0)
    return penalty_gradient(x, rho)

# Solve the constrained problem with the quadratic-penalty method
def solve_constrained():
    """Solve the constrained quadratic problem with a quadratic-penalty method.

    Minimises ``f`` subject to ``g(x) = 0`` and ``h(x) <= 0`` by running
    gradient descent with a backtracking step on ``penalty(x, rho)`` for a
    doubling sequence of penalty weights ``rho``. Per-iteration progress, the
    final answer, a comparison with the unconstrained minimiser and the
    analytical solution are printed.

    Returns:
        The outer-loop iterate with the smallest total constraint violation
        ``|g(x)| + max(0, h(x))``. If the very first outer iteration already
        produces NaN, the initial guess is returned.
    """
    # Inner-loop (gradient descent) settings
    max_iters = 500
    tol = 1e-8

    # Penalty schedule: rho starts small and doubles after every outer iteration
    rho = 0.1
    rho_multiplier = 2.0
    max_outer_iters = 10

    # Initial guess; it satisfies both constraints (0.3 + 0.7 = 1, 0.3 >= 0.2)
    x = jnp.array([0.3, 0.7])

    print("Solving constrained optimization problem:")
    print("min f(x) = 0.5 * x^T A x - b^T x")
    print("s.t. x1 + x2 = 1")
    print("     x1 >= 0.2")
    print("\n" + "="*50 + "\n")

    # Best iterate so far, ranked by total constraint violation. The objective
    # f is not used for ranking: it rises as rho grows (the iterates approach
    # the feasible set from outside), so preferring a lower f would prefer a
    # less feasible point.
    best_x = x
    best_violation = float("inf")

    # Outer loop: gradually increase the penalty parameter
    for outer_iter in range(max_outer_iters):
        print(f"Outer iteration {outer_iter + 1}, rho = {rho:.2f}")

        # Base step size shrinks as rho grows and the penalised problem stiffens
        learning_rate = min(0.1, 1.0 / (1.0 + rho))

        # Inner loop: gradient descent for fixed rho
        converged = False
        for i in range(max_iters):
            grad_value = grad_penalty(x, rho)

            # Backtracking: halve the step until the penalised objective
            # decreases (at most 10 trials). The current value does not
            # change between trials, so evaluate it once.
            current_penalty = penalty(x, rho)
            alpha = learning_rate
            for _ in range(10):
                x_new = x - alpha * grad_value
                if penalty(x_new, rho) < current_penalty:
                    break
                alpha *= 0.5

            # When no trial decreased the objective, alpha has been halved
            # once more after the last trial; the step is taken with that alpha.
            x_new = x - alpha * grad_value

            # Check for convergence
            if jnp.linalg.norm(x_new - x) < tol:
                converged = True
                break

            x = x_new

        # Check for NaN
        if jnp.any(jnp.isnan(x)):
            print("  WARNING: NaN detected, reverting to best solution")
            x = best_x
            break

        # Print current solution and constraint violations
        eq_violation = float(abs(g(x)))
        ineq_violation = max(0.0, float(h(x)))
        obj_value = float(f(x))

        print(f"  Solution: x = [{x[0]:.6f}, {x[1]:.6f}]")
        print(f"  Objective value: f(x) = {obj_value:.6f}")
        print(f"  Equality constraint violation: |g(x)| = {eq_violation:.2e}")
        print(f"  Inequality constraint violation: max(0, h(x)) = {ineq_violation:.2e}")
        print(f"  Converged: {converged} (iterations: {i+1})")
        print()

        # Keep the most feasible iterate seen so far
        total_violation = eq_violation + ineq_violation
        if total_violation < best_violation:
            best_x = x
            best_violation = total_violation

        # Check if constraints are satisfied
        if eq_violation < 1e-5 and ineq_violation < 1e-5:
            print("Constraints satisfied!")
            break

        # Increase penalty parameter
        rho *= rho_multiplier

    # Use best solution found
    x = best_x

    print("\n" + "="*50 + "\n")
    print("Final solution:")
    print(f"x* = [{x[0]:.6f}, {x[1]:.6f}]")
    print(f"f(x*) = {f(x):.6f}")
    print(f"g(x*) = {g(x):.6f} (should be 0)")
    print(f"h(x*) = {h(x):.6f} (should be <= 0)")

    # Verify constraints
    print("\nConstraint satisfaction:")
    print(f"x1 + x2 = {x[0] + x[1]:.6f} (should be 1.0)")
    print(f"x1 = {x[0]:.6f} (should be >= 0.2)")

    # Compare with unconstrained solution
    print("\n" + "="*50 + "\n")
    print("Comparison with unconstrained solution:")

    # Solve the unconstrained problem with fixed-step gradient descent
    x_unconstrained = jnp.zeros_like(b)
    grad_f = jit(grad(f))

    for i in range(100):
        grad_value = grad_f(x_unconstrained)
        x_new = x_unconstrained - 0.1 * grad_value
        if jnp.linalg.norm(x_new - x_unconstrained) < tol:
            break
        x_unconstrained = x_new

    print(f"Unconstrained: x = [{x_unconstrained[0]:.6f}, {x_unconstrained[1]:.6f}], f(x) = {f(x_unconstrained):.6f}")
    print(f"Constrained:   x = [{x[0]:.6f}, {x[1]:.6f}], f(x) = {f(x):.6f}")

    # Analytical check - for this problem we can verify the solution
    print("\n" + "="*50 + "\n")
    print("Analytical verification:")
    print("At the optimal point, the inequality constraint should be active (x1 = 0.2)")
    print("Combined with x1 + x2 = 1, we get x2 = 0.8")
    x_analytical = jnp.array([0.2, 0.8])
    print(f"Analytical: x = [{x_analytical[0]:.6f}, {x_analytical[1]:.6f}], f(x) = {f(x_analytical):.6f}")

    return x

# Run the constrained optimization; the call prints its progress report and
# returns the solution point, kept here as optimal_x.
optimal_x = solve_constrained()

Solving constrained optimization problem:
min f(x) = 0.5 * x^T A x - b^T x
s.t. x1 + x2 = 1
     x1 >= 0.2


Outer iteration 1, rho = 0.10
  Solution: x = [0.105090, 0.648057]
  Objective value: f(x) = -0.681045
  Equality constraint violation: |g(x)| = 2.47e-01
  Inequality constraint violation: max(0, h(x)) = 9.49e-02
  Converged: True (iterations: 34)

Outer iteration 2, rho = 0.20
  Solution: x = [0.116476, 0.657869]
  Objective value: f(x) = -0.679267
  Equality constraint violation: |g(x)| = 2.26e-01
  Inequality constraint violation: max(0, h(x)) = 8.35e-02
  Converged: True (iterations: 21)

Outer iteration 3, rho = 0.40
  Solution: x = [0.133505, 0.673575]
  Objective value: f(x) = -0.674527
  Equality constraint violation: |g(x)| = 1.93e-01
  Inequality constraint violation: max(0, h(x)) = 6.65e-02
  Converged: True (iterations: 27)

Outer iteration 4, rho = 0.80
  Solution: x = [0.154497, 0.695263]
  Objective value: f(x) = -0.664783
  Equality constraint violation: |g(x)| =