In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from numdifftools import Gradient, Hessian
from alpha_algos import get_alpha

In [3]:
def get_direction(gradient, hessian):
    """Return the Newton direction ``d = -H^{-1} g``.

    Parameters
    ----------
    gradient : scalar or array-like
        Gradient ``g`` at the current iterate; flattened to shape (n,).
    hessian : scalar or array-like
        Hessian ``H`` at the current iterate; promoted to at least 2-D,
        expected shape (n, n).

    Returns
    -------
    numpy.float64 when n == 1, otherwise an ndarray of shape (n,).

    Notes
    -----
    The system ``H d = -g`` is solved with ``np.linalg.solve`` instead of
    forming ``inv(H)`` explicitly, which is cheaper and numerically more
    stable. For a symmetric Hessian this equals the row-vector form
    ``-g @ inv(H)``. Raises ``numpy.linalg.LinAlgError`` if ``H`` is singular.
    """
    g = np.asarray(gradient, dtype=float).ravel()
    H = np.atleast_2d(np.asarray(hessian, dtype=float))
    d = -np.linalg.solve(H, g)
    # One-dimensional problems get a plain scalar step so scalar iterates stay scalar.
    return d[0] if d.size == 1 else d

def distance(x_k0, x_k1):
    """Euclidean distance between two iterates.

    Accepts scalars, lists, or numpy arrays of matching shape and always
    returns a scalar. For scalars this is ``abs(x_k1 - x_k0)``; for vectors it
    is ``sqrt(sum((x_k1 - x_k0) ** 2))``.

    The previous implementation called ``len()`` on its input, so scalar
    iterates raised ``TypeError``. It also returned a 1-element array for
    length-1 arrays.
    """
    return np.linalg.norm(np.subtract(x_k1, x_k0))

def newton_method(f, x_k0, epsilon=10e-6, max_iter=1000):
    """Run Newton's method with a fixed unit step and log every iteration.

    Parameters
    ----------
    f : callable
        Objective function; its gradient and Hessian are obtained numerically
        with numdifftools' ``Gradient`` and ``Hessian``.
    x_k0 : scalar or array-like
        Starting point.
    epsilon : float, default 10e-6
        Iteration stops once the distance between successive iterates is no
        longer greater than ``epsilon``. Note that ``10e-6 == 1e-5``.
    max_iter : int or None, default 1000
        Upper bound on the number of iterations, so a diverging run cannot
        loop forever. ``None`` removes the bound. If the bound is reached, a
        ``RuntimeWarning`` is issued and the log collected so far is returned.

    Returns
    -------
    pandas.DataFrame
        One row per iteration with columns "x", "f(x)", "step", "Gradient",
        "Hessian", "Distance". All values are taken at the iterate the step
        started from. Every column except "x" is rounded to 4 decimals.
    """
    import itertools
    import warnings

    f_prime, hessian = Gradient(f), Hessian(f)
    columns = ["x", "f(x)", "step", "Gradient", "Hessian", "Distance"]
    # Collect rows in a list and build the frame once; growing a DataFrame
    # row by row with .loc is quadratic.
    records = []
    iterations = itertools.count() if max_iter is None else range(max_iter)
    for _ in iterations:
        # Numerical derivatives are expensive: evaluate once per iterate and
        # reuse them for both the step and the log.
        grad_k = f_prime(x_k0)
        hess_k = hessian(x_k0)
        direction = get_direction(grad_k, hess_k)
        step = 1  # pure Newton step; a line search for alpha could be plugged in here
        x_k1 = x_k0 + step * direction
        dist = distance(x_k0, x_k1)
        records.append({
            "x": x_k0,
            "f(x)": np.round(f(x_k0), 4),
            "step": np.round(step, 4),
            "Gradient": np.round(grad_k, 4),
            "Hessian": np.round(hess_k, 4),
            "Distance": np.round(dist, 4),
        })
        x_k0 = x_k1
        # Same test as the original `while dist > epsilon`; a NaN distance also stops.
        if not (dist > epsilon):
            break
    else:
        warnings.warn(
            f"newton_method did not converge within {max_iter} iterations",
            RuntimeWarning,
        )
    return pd.DataFrame(records, columns=columns)

In [10]:
# Newton's method on a two-dimensional test function.
def f(x):
    """Test objective f(x) = (x0 - 2)^4 + (x0 - 2)^2 * x1^2 + (x1 + 1)^2.

    Every term is non-negative and all vanish at (2, -1), the global minimiser.
    """
    shifted, second = x[0] - 2, x[1]
    return shifted**4 + shifted**2 * second**2 + (second + 1)**2

x_k0 = np.array([1, 2])
newton_method(f, x_k0)

Unnamed: 0,x,f(x),step,Gradient,Hessian,Distance
0,"[1, 2]",14.0,1,"[-12.0, 10.0]","[[20.0, -8.0], [-8.0, 4.0]]",6.8007
1,"[-1.000000000000088, -4.500000000000192]",275.5,1,"[-229.5, -88.0]","[[148.5, 54.0], [54.0, 20.0]]",12.855
2,"[-4.000000000005886, 8.00000000001581]",3681.0,1,"[-1632.0, 594.0]","[[560.0, -192.0], [-192.0, 74.0]]",4.4652
3,"[-2.5314685314797236, 3.783216783206946]",738.433,1,"[-501.9158, 164.9371]","[[275.0359, -68.5741], [-68.5741, 43.0684]]",2.1047
4,"[-1.088605932332375, 2.250907593024177]",149.903,1,"[-149.1522, 49.4468]","[[124.607, -27.8087], [-27.8087, 21.079]]",1.4463
5,"[-0.13411488973268226, 1.1643414757532518]",31.6018,1,"[-44.6652, 14.9345]","[[57.3647, -9.9394], [-9.9394, 11.1089]]",1.0023
6,"[0.5116834609779106, 0.39777307048930366]",7.2109,1,"[-13.658, 4.5577]","[[26.8975, -2.368], [-2.368, 6.4302]]",0.709
7,"[0.9719831435701063, -0.14151842004429394]",1.875,1,"[-4.3869, 1.4178]","[[12.7219, 0.5819], [0.5819, 4.1136]]",0.5372
8,"[1.3349279553874698, -0.5375315074693571]",0.5373,1,"[-1.561, 0.4494]","[[5.8857, 1.43], [1.43, 2.8846]]",0.4748
9,"[1.6795040946008999, -0.864141995352457]",0.1057,1,"[-0.6103, 0.0942]","[[2.7261, 1.1078], [1.1078, 2.2054]]",0.3604
