# Descent Methods

In [50]:
using LaTeXStrings
using Plots
using ForwardDiff: gradient
using LinearAlgebra: norm, dot

In [51]:
function bisect_min(f; a=-1e12, b=1e12, tol=1e-6, max_iters=1000)
  """
  Find the minimum of a one-dimensional function f(x) using binary search.

  Args:
      f (Function): The function to minimize.
      a (Float64): The left bound of the search interval.
      b (Float64): The right bound of the search interval.
      tol (Float64): The desired tolerance for the minimum.
      max_iters (Int64): The maximum number of iterations possible.
            
  Returns:
      Float64: The minimum of the function f(x) within the given interval.
  """
  k = 1
  a₁ = a
  b₁ = b

  while k < max_iters
    ∇f = gradient(f, [(a₁ + b₁) / 2])[1]
    
    if b₁ - a₁ < tol
      break
    end

    if ∇f > 0
      b₁ = (a₁ + b₁) / 2
    else
      a₁ = (a₁ + b₁) / 2
    end

    k += 1
  end
  
  return (a₁ + b₁) / 2
end;

function exact(f, x, Δx; tol=1e-6, max_iters=100000)
    """
    Performs exact line search to find the optimal step size.

    Parameters:
    f (function): The objective function.
    x (Array{Float64,1}): The current point.
    Δx (Array{Float64,1}): The search direction.
    tol (Float64): The tolerance for the termination condition.

    Returns:
    Float64: The optimal step size.
    """
    h(s) = x .+ s[1] * Δx 
    return bisect_min(f ∘ h)
end;

In [52]:
function backtracking(f, x, Δx; t₀=1, β=0.5, α=0.25)
  """
  Performs backtracking line search to find a suitable step length.

  Args:
    f (Function): the objective function
    x (Vector): the current point
    t₀ (Float64): initial step length
    Δx (Vector): the search direction
    β (Float64): backtracking factor. β ∈ (0,1)
    α (Float64): sufficient decrease condition parameter. α ∈ (0, 0.5)
  """

  t = t₀
  fₓ = f(x)

  while f(x + t * Δx) > fₓ + α * t * dot(gradient(f, x), Δx)
    t *= β
  end

  return t
end;

## Gradient Descent

In [141]:
function gradient_descent(f, x₀; tol=1e-8, max_iters=1000, linesearch=:backtracking)
  """
  Performs the gradient descent algorithm with backtracking line search

  Args:
    f (Function): the objective function
    x (Vector): the initial point
    tol (Float64): tolerance to break the loop
    max_iters (Int64): maximum number of iterations
  """
  x = copy(x₀)
  f_val = f(x)

  iteration = 0

  for i = 1:max_iters
    iteration = i
    ∇f = gradient(f, x)

    if norm(∇f) < tol
      break
    end

    if linesearch == :exact
      t = exact(f, x, -∇f)
    else
      t = backtracking(f, x, -∇f, α=0.1, β=0.7)
    end

    x = x - t * ∇f
    f_val = f(x)
  end

  return x, f_val, iteration
end;

In [150]:
f(x) = x[1]^2
x₀ = [5]

x, f_min, iter = @time gradient_descent(f, x₀, tol=1e-10, max_iters=1000)
println("x_min = $x at f(x_min)=$f_min ($iter iterations - backtracking)")

x, f_min, iter = @time gradient_descent(f, x₀, tol=1e-10, max_iters=1000, linesearch=:exact)
println("x_min = $x at f(x_min)=$f_min ($iter iterations - exact)")

  0.869742 seconds (1.67 M allocations: 114.637 MiB, 1.88% gc time, 99.87% compilation time: 2% of which was recompilation)
x_min = [3.602879701896376e-11] at f(x_min)=1.298074214633692e-21 (29 iterations - backtracking)
  0.292914 seconds (809.86 k allocations: 55.930 MiB, 2.87% gc time, 99.77% compilation time: 3% of which was recompilation)
x_min = [9.578043450053121e-13] at f(x_min)=9.17389163311055e-25 (3 iterations - exact)


In [153]:
# 2-dimensional function
g(x) = exp(x[1] + 3 * x[2] - 0.1) + exp(x[1] - 3 * x[2] - 0.1) + exp(-x[1] - 0.1)
x₀ = [0, 1]

x, g_min, iter = @time gradient_descent(g, x₀, tol=1e-10, max_iters=1000)
println("x_min = $x at f(x_min)=$g_min ($iter iterations - backtracking)")

x, g_min, iter = @time gradient_descent(g, x₀, tol=1e-10, max_iters=1000, linesearch=:exact)
println("x_min = $x at f(x_min)=$g_min ($iter iterations - exact)")

  0.140520 seconds (123.83 k allocations: 6.243 MiB, 93.21% compilation time: 94% of which was recompilation)
x_min = [-0.34657359027997275, 7.852103101148952e-9] at f(x_min)=2.5592666966582156 (1000 iterations - backtracking)
  0.030946 seconds (34.80 k allocations: 2.150 MiB, 91.69% compilation time: 100% of which was recompilation)
x_min = [-0.3465735902696454, 7.239816783138962e-12] at f(x_min)=2.5592666966582156 (27 iterations - exact)


## Steepest Descent

In [146]:
function steepest_descent(f, x₀, Δxsd, norm; tol=1e-4, max_iters=1000, linesearch=:backtracking)
  x = copy(x₀)
  f_val = f(x)

  iter = 0
  for i = 1:max_iters
    iter = i
    Δx = Δxsd(f, x)

    if norm(Δx) < tol
      break
    end

    if linesearch == :exact
      t = exact(f, x, Δx)
    else
      t = backtracking(f, x, Δx, α=0.1, β=0.7)
    end

    x += t * Δx
    f_val = f(x)
  end

  return x, f_val, iter

end;

In [147]:
f(x) = exp(x[1] + 3 * x[2] - 0.1) + exp(x[1] - 3 * x[2] - 0.1) + exp(-x[1] - 0.1)
P1 = [2 0; 0 8]
# P2 = [8 0; 0 2]
norm_fn(z) = z' * P1 * z
Δxsd(f, x) = -inv(P1) * gradient(f, x);

In [148]:
x, f_val, iter = @time steepest_descent(f, [-5, 5], Δxsd, norm_fn, tol=1e-10, linesearch=:backtracking)

println("x_min = $x at f(x_min)=$f_val ($iter iterations - backtracking)")

  0.103401 seconds (30.93 k allocations: 1.991 MiB, 99.49% compilation time: 80% of which was recompilation)
x_min = [-0.3465735859500477, -1.5162050377715008e-6] at f(x_min)=2.5592666966714535 (22 iterations - backtracking)


In [149]:
x, f_val, iter = @time steepest_descent(f, [-5, 5], Δxsd, norm_fn, tol=1e-10, linesearch=:exact)

println("x_min = $x at f(x_min)=$f_val ($iter iterations - exact line search)")

  0.197424 seconds (25.15 k allocations: 1.628 MiB, 53.29% compilation time: 100% of which was recompilation)
x_min = [-0.3465731907770624, -1.1983563574664841e-7] at f(x_min)=2.5592666966585025 (10 iterations - exact line search)
