In [6]:
from sympy import symbols, diff, cos, sin, N, Function
import numpy as np
import matplotlib.pyplot as plt
import math



In [16]:
def accelerated_gradient(gradient, y0, alpha, n_iterations=1900):
    """Run accelerated (momentum) gradient descent and return the iterates.

    Every iteration takes a gradient step from the current look-ahead point
    ``y`` and then extrapolates past the new iterate with the momentum weight
    ``(t - 1) / (t + 2)``, where ``t`` is the 1-based iteration number.

    Parameters
    ----------
    gradient : callable
        ``gradient(y)`` returns the gradient evaluated at ``y``.
    y0 : number or array-like supporting ``-`` and ``*``
        Starting point; used both as the first iterate and as the first
        look-ahead point.
    alpha : callable
        ``alpha(y, g)`` returns the step size to use at look-ahead point ``y``
        with gradient ``g``.
    n_iterations : int, optional
        Number of iterations to perform.

    Returns
    -------
    list
        ``[y0, x_1, ..., x_n]`` -- the starting point followed by one iterate
        per iteration (``n_iterations + 1`` entries in total).
    """
    xs = [y0]
    y = y0  # only the most recent look-ahead point is ever needed
    for t in range(1, n_iterations + 1):
        x = xs[-1]
        g = gradient(y)
        x_plus = y - alpha(y, g) * g
        # float() keeps this a true division under Python 2 as well.
        momentum = (t - 1) / float(t + 2)
        y = x_plus + momentum * (x_plus - x)
        xs.append(x_plus)
    return xs

class BacktrackingLineSearch(object):
    """Callable step-size rule based on a backtracking (sufficient-decrease) test.

    Starting from ``alpha``, the step ``a`` is multiplied by ``shrink`` until

        f(y - a * g) <= f(y) - c * a * ||g||^2

    holds, and that ``a`` is returned.

    Parameters
    ----------
    function : callable
        Objective ``f``; must accept the same kind of value as ``y``.
    alpha : float, optional
        Initial trial step size (default 0.25).
    c : float, optional
        Sufficient-decrease coefficient (default 0.5).
    shrink : float, optional
        Factor applied to the step on every rejection; must lie strictly
        between 0 and 1 so that the search can make progress (default 0.99).

    Raises
    ------
    ValueError
        If ``shrink`` is not strictly between 0 and 1.
    """

    def __init__(self, function, alpha=0.25, c=0.5, shrink=0.99):
        if not 0.0 < shrink < 1.0:
            raise ValueError("shrink must be strictly between 0 and 1")
        self.function = function
        self.alpha    = alpha
        self.c        = c
        self.shrink   = shrink

    def __call__(self, y, g):
        """Return a step size for point ``y`` with gradient ``g``.

        ``y`` and ``g`` may be scalars or NumPy arrays of matching shape.
        """
        f = self.function
        a = self.alpha

        # Both quantities are independent of ``a``; evaluate them once instead
        # of on every pass through the loop.
        f_y = f(y)
        sq_norm = g * g
        if np.ndim(sq_norm) > 0:
            # Array-valued gradient: reduce the elementwise squares to the
            # squared Euclidean norm so the comparison below stays scalar.
            sq_norm = np.sum(sq_norm)

        while f(y - a * g) > f_y - self.c * a * sq_norm:
            a *= self.shrink
        return a

if __name__ == '__main__':
    # Demo: minimize f(x) = x**4 from x0 = 1.0 with accelerated gradient
    # descent and a backtracking line search, then print every iterate.

    # NOTE(review): neither import is used in this cell; kept in case later
    # cells rely on these names being bound here -- confirm before removing.
    import os

    import numpy as np

    ### ACCELERATED GRADIENT ###

    # problem definition
    function = lambda x: x ** 4      # the function to minimize
    gradient = lambda x: 4 * x ** 3  # its gradient
    alpha = BacktrackingLineSearch(function)
    x0 = 1.0
    n_iterations = 100

    # run accelerated gradient descent
    iterates = accelerated_gradient(gradient, x0, alpha, n_iterations=n_iterations)

    # print(x) with a single argument prints the same text under Python 2 and
    # is valid syntax under Python 3, unlike the `print x` statement.
    for iterate in iterates:
        print(iterate)

1.0
0.547956349734
0.38342907965
0.302191176538
0.250079421927
0.212780576124
0.184447849508
0.162104160669
0.144018318074
0.129091644114
0.116582976293
0.105969499313
0.0968696932042
0.0889976321313
0.0821343743332
0.0761092615814
0.0707872591989
0.0660601333036
0.0618401509432
0.0580554878179
0.0546468207171
0.0515647596127
0.0487678859851
0.0462212360362
0.0438951151203
0.0417641619617
0.0398066034393
0.0380036562865
0.0363390431416
0.0347985983816
0.0333699450188
0.0320422282649
0.030805894594
0.0296525075746
0.0285745935946
0.0275655120265
0.0266193454821
0.0257308066612
0.024895158975
0.0241081486531
0.0233659464655
0.022665097528
0.0220024779301
0.0213752571402
0.0207808653237
0.0202169648497
0.0196814253824
0.0191723020485
0.018687816254
0.0182263387856
0.0177863748927
0.017366551086
0.016965603431
0.0165823671454
0.0162157673359
0.0158648107345
0.0155285783109
0.0152062186581
0.0148969420585
0.0146000151514
0.0143147561334
0.0140405304319
0.0137767467978
0.0135228537725
0.0132