In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from numdifftools import Gradient, Hessian
from alpha_algos import get_alpha

In [6]:
def distance(x_k0, x_k1):
    if len(x_k0) > 1:
        return sum((x_k1 - x_k0)**2) ** .5
    return abs(x_k1-x_k0)

def get_direction(gradient, hessian):
    if len(gradient) > 1:
        return - gradient.dot(hessian)
    return - gradient / hessian

def quasi_newton(f, hessian, x_k0, epsilon=10e-6):
    f_prime = Gradient(f)
    logs = pd.DataFrame(columns=["x", "f(x)", "d","step", "Gradient", "Hessian", "Distance"])
    Hk_0 = np.identity(len(x_k0))
    direction = get_direction(f_prime(x_k0), Hk_0)
    step = get_alpha("exact", x_k0, f, f_prime, direction)
    logs.loc[0] = [x_k0, f(x_k0), direction, step, f_prime(x_k0), Hk_0, np.nan]

    
    x_k1 = x_k0 + step * direction
    dist = distance(x_k0, x_k1)
    i = 1
    while(dist > epsilon):
        Hk_1 = hessian(x_k0, x_k1, Hk_0, f_prime)
        direction = get_direction(f_prime(x_k1), Hk_1)
        step = get_alpha("exact", x_k1, f, f_prime, direction)
        logs.loc[i] = [x_k1, f(x_k1), direction, step, f_prime(x_k1), Hk_1, dist]
        x_k0 = x_k1
        x_k1 = x_k0 + step * direction
        dist = distance(x_k0, x_k1)
        i+=1
    return logs

In [7]:
# Quasi-Newton 
def f_double_prime(xk_0, xk_1, H, f_prime):
    y = f_prime(xk_1) - f_prime(xk_0)
    s = xk_1 - xk_0
    a = np.identity(len(xk_0)) - s.reshape(-1, 1) * y.reshape(-1, 1).T / y.dot(s)
    b = np.identity(len(xk_0)) - y.reshape(-1, 1) * s.reshape(-1, 1).T / y.dot(s)
    c = s.reshape(-1, 1) * s.reshape(-1, 1).T / y.dot(s)
    
    return a.dot(H.dot(b)) + c 
    
x_k0 = np.array([1, 2])
Q = np.array([[2, 0], [0, 1]])
b = np.array([0, 0])
c = 3
f = lambda x: x.dot(Q.dot(x)) - b.dot(x) + c

df = quasi_newton(f, f_double_prime, x_k0)
df

Unnamed: 0,x,f(x),d,step,Gradient,Hessian,Distance
0,"[1, 2]",9.0,"[-4.000000000000001, -3.999999999999993]",0.333333,"[4.000000000000001, 3.999999999999993]","[[1.0, 0.0], [0.0, 1.0]]",
1,"[-0.3333333333333335, 0.6666666666666692]",3.66667,"[0.8888888888888877, -1.7777777777777863]",0.375,"[-1.3333333333333344, 1.3333333333333397]","[[0.3888888888888875, -0.2777777777777766], [-...",1.885618
2,"[-4.996003610813204e-16, -8.881784197001252e-16]",3.0,"[-0.0, -0.0]",1.0,"[0.0, 0.0]","[[0.6388888888888906, 0.3888888888888895], [0....",0.745356
