Gradient descent implementation with some improvements

In [25]:
import math
import matplotlib
import numpy as np
import matplotlib.cm as cm
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from sympy import symbols, diff, cos, sin, N, Function
import cvxpy

def f(x,y):
    return (x - 1)**2 + (y - x**2)**2 # Rosenbrock function

In [35]:
TOL = 1e-6
noise = 1e-6
n_iterates = 100

def gradient_descent(f, diff_x, diff_y, start, f_prev, n, path):
    step = linesearch(start, diff_x, diff_y) # calculate optimal step size using line search
    
    start[x] -= step*(N(diff_x.subs(start)))
    start[y] -= step*(N(diff_y.subs(start)))
    print (start[x], start[y])
    
    if n >= n_iterates: # condition for algorithm to terminate
        return N(f.subs(start)), start, path
    
    f_prev = N(f.subs(start))
    path.append((start[x], start[y]))
    n += 1
    
    return gradient_descent(f, diff_x, diff_y, start, f_prev, n, path) # recursion

def linesearch(start, diff_x, diff_y):
    global x, y, alpha, x_alpha, y_alpha, f_alpha, TOL, noise
    
    l = noise
    h = start[alpha]
    mid = (l+h)/2
    
    x_alpha = diff_x*alpha
    y_alpha = diff_y*alpha

    for _ in range(100):
        res = diff(f_alpha, alpha).subs({alpha:h, x: start[x], y: start[y]})
        if res <= TOL:
            return mid
        if res < 0:
            h *= 2
        if res > 0:
            mid = l + h
            if mid <= 0:
                l = mid
            else:
                h = mid
    return mid

x, y, alpha = symbols('x y alpha', real = True)
f = x**4

# function h'(a) for line search
x_alpha, y_alpha = symbols('x_alpha y_alpha', real = True)
f_alpha = (x_alpha)**4

# partial derivatives wrt to x and y
diff_x = diff(f, x)
diff_y = diff(f, y)

# initial state
start = {alpha:0.25, x:1, y:0}

# search result
res = gradient_descent(f, diff_x, diff_y, start, N(f.subs(start)), 0, [(start[x], start[y])])

# print out result
print("Optimum point:    %f" %res[0])
print("Error: %f" %(res[0])) # based on our analytical solution
print("At values:  ", res[1])

(-2.00000000005751e-6, 0)
(-2.00000000004951e-6, 0)
(-2.00000000004151e-6, 0)
(-2.00000000003351e-6, 0)
(-2.00000000002551e-6, 0)
(-2.00000000001751e-6, 0)
(-2.00000000000951e-6, 0)
(-2.00000000000151e-6, 0)
(-1.99999999999351e-6, 0)
(-1.99999999998551e-6, 0)
(-1.99999999997751e-6, 0)
(-1.99999999996951e-6, 0)
(-1.99999999996151e-6, 0)
(-1.99999999995351e-6, 0)
(-1.99999999994551e-6, 0)
(-1.99999999993751e-6, 0)
(-1.99999999992951e-6, 0)
(-1.99999999992151e-6, 0)
(-1.99999999991351e-6, 0)
(-1.99999999990551e-6, 0)
(-1.99999999989751e-6, 0)
(-1.99999999988951e-6, 0)
(-1.99999999988151e-6, 0)
(-1.99999999987351e-6, 0)
(-1.99999999986551e-6, 0)
(-1.99999999985751e-6, 0)
(-1.99999999984951e-6, 0)
(-1.99999999984151e-6, 0)
(-1.99999999983350e-6, 0)
(-1.99999999982550e-6, 0)
(-1.99999999981750e-6, 0)
(-1.99999999980950e-6, 0)
(-1.99999999980150e-6, 0)
(-1.99999999979350e-6, 0)
(-1.99999999978550e-6, 0)
(-1.99999999977750e-6, 0)
(-1.99999999976950e-6, 0)
(-1.99999999976150e-6, 0)
(-1.99999999