In [1]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numdifftools import Gradient

from alpha_algos import get_alpha

In [2]:
# --- Stopping criterion -------------------------------------------------------
# Tolerance on the relative gradient norm ||grad f(x)|| / (1 + |f(x)|).
# Written as 1e-4, which is the same float as the literal 10e-5.
# NOTE(review): confirm 1e-4 (and not 1e-5) is the intended tolerance.
EPSILON = 1e-4
# Hard cap on the number of descent iterations.
MAX_ITER = 1000

# --- Line-search parameters (forwarded to alpha_algos.get_alpha) ---------------
BACKTRACKING_GAMMA = 0.25
BACKTRACKING_BETA = 0.5
BACKTRACKING_INIT_STEP = 2

# --- Logging -------------------------------------------------------------------
# Decimal places kept when rounding vectors in the iteration log.
NUM_DEC = 5

In [3]:
def met_stopping_condition(x, f, f_prime, num_iter, max_iter=None, epsilon=None):
    """Return True when the descent loop should stop at iterate ``x``.

    The loop stops when either condition holds:

    * convergence: the relative gradient norm
      ``||f_prime(x)|| / (1 + |f(x)|)`` is at most ``epsilon``;
    * budget exhausted: ``num_iter >= max_iter``.

    A warning is issued only when stopping because of the budget without
    having converged.

    Parameters
    ----------
    x : array_like
        Current iterate.
    f : callable
        Objective function; ``f(x)`` must return a scalar.
    f_prime : callable
        Gradient of ``f``.
    num_iter : int
        Number of iterations performed so far.
    max_iter : int, optional
        Iteration budget. Defaults to the module-level ``MAX_ITER``.
    epsilon : float, optional
        Convergence tolerance. Defaults to the module-level ``EPSILON``.

    Returns
    -------
    bool
        True if iteration should stop.
    """
    if max_iter is None:
        max_iter = MAX_ITER
    if epsilon is None:
        epsilon = EPSILON
    converged = bool(np.linalg.norm(f_prime(x)) / (1 + abs(f(x))) <= epsilon)
    # ">=" rather than "==" so a caller that overshoots the budget still stops.
    reached_max_iter = num_iter >= max_iter
    if reached_max_iter and not converged:
        # stacklevel=2 attributes the warning to the caller's loop, not this helper.
        warnings.warn("Failed to converge. Max iterations reached", stacklevel=2)
    return converged or reached_max_iter

def gradient_descent(x_k0, f, method="exact"):
    """Minimise ``f`` by steepest descent starting from ``x_k0``.

    Each iteration works as follows:

    * the search direction is the negative numerical gradient
      (``numdifftools.Gradient(f)``);
    * the step length comes from ``get_alpha(method, ...)``, called with the
      module-level ``BACKTRACKING_*`` constants.

    Iteration ends when ``met_stopping_condition`` reports True.

    Parameters
    ----------
    x_k0 : array_like
        Starting point (converted to a float ndarray).
    f : callable
        Objective, mapping a 1-D array to a scalar.
    method : str, default "exact"
        Line-search strategy name passed as the first argument of ``get_alpha``.
        NOTE(review): the accepted names are defined in ``alpha_algos`` — confirm.

    Returns
    -------
    pandas.DataFrame
        One row per iterate, indexed by iteration number ``k``, with columns:

        * ``x``: the iterate;
        * ``α``: the step length taken from that iterate;
        * ``d``: the search direction at that iterate.

        Vectors are rounded to ``NUM_DEC`` decimals. The final row holds the
        last iterate; its ``α`` is NaN because no step is taken from it.
    """
    f_prime = Gradient(f)
    columns = ["x", "\u03B1", "d"]
    # Collect rows in a list and build the frame once; growing a DataFrame
    # row by row with .loc is quadratic and fragile with array-valued cells.
    records = []
    x_k = np.asarray(x_k0, dtype=float)
    k = 0
    while not met_stopping_condition(x_k, f, f_prime, k):
        direction = -f_prime(x_k)
        alpha = get_alpha(method, x_k, f, f_prime, direction,
                          BACKTRACKING_GAMMA, BACKTRACKING_BETA, BACKTRACKING_INIT_STEP)
        records.append(dict(zip(columns, [np.round(x_k, NUM_DEC), alpha, np.round(direction, NUM_DEC)])))
        x_k = x_k + alpha * direction
        k += 1
    # Final iterate: no step is taken from here, so alpha is NaN rather than
    # a stale value carried over from the previous iteration.
    records.append(dict(zip(columns, [np.round(x_k, NUM_DEC), np.nan, np.round(-f_prime(x_k), NUM_DEC)])))
    logs = pd.DataFrame(records, columns=columns)
    logs.index.name = "k"
    return logs


In [4]:
# Test problem: the convex quadratic f(x) = 1/2 x^T Q x - b^T x with a diagonal Q.
# Its unique minimiser is Q^{-1} b = [1, 1/3, 1/9], which the iterates should approach.
Q = np.diag([1, 3, 9])
b = np.array([1, 1, 1])


def f(x):
    """Quadratic objective 1/2 x^T Q x - b^T x (reads the cell-level Q and b)."""
    return 0.5*x.T@Q@x - b.dot(x)


# Start from the origin (integer dtype, as before).
x0 = np.zeros(3, dtype=int)
gradient_descent(x0, f)

Unnamed: 0_level_0,x,α,d
k,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,"[0, 0, 0]",0.230769,"[1.0, 1.0, 1.0]"
1,"[0.23077, 0.23077, 0.23077]",0.16318,"[0.76923, 0.30769, -1.07692]"
2,"[0.35629, 0.28098, 0.05504]",0.249488,"[0.64371, 0.15706, 0.50467]"
3,"[0.51689, 0.32016, 0.18095]",0.166078,"[0.48311, 0.03951, -0.62851]"
4,"[0.59712, 0.32673, 0.07656]",0.250958,"[0.40288, 0.01982, 0.31092]"
5,"[0.69823, 0.3317, 0.15459]",0.166211,"[0.30177, 0.0049, -0.39133]"
6,"[0.74839, 0.33251, 0.08955]",0.251018,"[0.25161, 0.00246, 0.19406]"
7,"[0.81155, 0.33313, 0.13826]",0.166216,"[0.18845, 0.00061, -0.24435]"
8,"[0.84287, 0.33323, 0.09765]",0.25102,"[0.15713, 0.0003, 0.12119]"
9,"[0.88231, 0.33331, 0.12807]",0.166216,"[0.11769, 8e-05, -0.15259]"
