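This notebook demonstrates the basic idea behind gradient descent: repeatedly taking small steps in the direction of the negative gradient to minimize a function.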

In [1]:
import numpy

def gradient_descent(func, grad_func, w_init, n_epochs=100, lr=0.001, verbose=0, tol=None):
    """Minimize ``func`` with fixed-step-size gradient descent.

    Starting from ``w_init``, repeatedly applies the update
    ``w <- w - lr * grad_func(w)``.

    Parameters
    ----------
    func : callable
        Objective function, called as ``func(w)``. It is only evaluated
        for logging (``verbose > 0``) or for early stopping (``tol`` given).
    grad_func : callable
        Gradient of ``func``, called as ``grad_func(w)``.
    w_init : array-like
        Starting point. Updates are applied out of place
        (``w = w + delta_w``), so a NumPy array passed here is not modified.
    n_epochs : int, optional
        Maximum number of update steps (default 100).
    lr : float, optional
        Learning rate / step size (default 0.001).
    verbose : int, optional
        If > 0, print the function value and the current point after
        every step.
    tol : float or None, optional
        If given, stop early once the absolute change of the function
        value between two consecutive iterates is <= ``tol``. With the
        default ``None``, exactly ``n_epochs`` steps are performed.

    Returns
    -------
    The last iterate ``w`` (``w_init`` itself if no step was taken).
    """
    w = w_init

    # function value of the previous iterate; only needed for early stopping
    f_prev = func(w) if tol is not None else None

    for _ in range(n_epochs):
        # gradient update step (out of place: `w + delta_w` builds a new array)
        delta_w = -lr * grad_func(w)
        w = w + delta_w

        if verbose > 0 or tol is not None:
            # evaluate the objective once and reuse it for logging and stopping
            f_curr = func(w)

            if verbose > 0:
                print("f={}; w: {}".format(f_curr, w))

            if tol is not None:
                if abs(f_curr - f_prev) <= tol:
                    break
                f_prev = f_curr

    return w

In [2]:
# simple test function f(w) = sum_i w_i^2, e.g. f(w_1, w_2) = w_1*w_1 + w_2*w_2 for d=2
def f(w):
    """Return the sum of the squared entries of ``w``."""
    squared = w * w
    return numpy.sum(squared)

# gradient of f: the partial derivative with respect to w_i is 2*w_i
def grad(w):
    """Return the gradient ``2 * w`` of ``f`` evaluated at ``w``."""
    doubled = 2 * w
    return doubled

In [3]:
# starting point (integer entries are fine: the update w = w + delta_w
# builds new float arrays)
w_init = numpy.array([10, 10])

# learning rate (usually has a big impact!)
lr = 0.1

# number of gradient steps
n_epochs = 25

# apply gradient descent; verbose=1 prints f(w) and w after every step
w_opt = gradient_descent(f, grad, w_init, n_epochs=n_epochs, lr=lr, verbose=1)

f=128.0; w: [8. 8.]
f=81.92000000000002; w: [6.4 6.4]
f=52.4288; w: [5.12 5.12]
f=33.554432; w: [4.096 4.096]
f=21.47483648; w: [3.2768 3.2768]
f=13.743895347200002; w: [2.62144 2.62144]
f=8.796093022208003; w: [2.097152 2.097152]
f=5.629499534213123; w: [1.6777216 1.6777216]
f=3.602879701896398; w: [1.34217728 1.34217728]
f=2.305843009213695; w: [1.07374182 1.07374182]
f=1.475739525896765; w: [0.85899346 0.85899346]
f=0.9444732965739295; w: [0.68719477 0.68719477]
f=0.6044629098073148; w: [0.54975581 0.54975581]
f=0.38685626227668146; w: [0.43980465 0.43980465]
f=0.24758800785707613; w: [0.35184372 0.35184372]
f=0.1584563250285287; w: [0.28147498 0.28147498]
f=0.10141204801825837; w: [0.22517998 0.22517998]
f=0.06490371073168535; w: [0.18014399 0.18014399]
f=0.041538374868278626; w: [0.14411519 0.14411519]
f=0.026584559915698323; w: [0.11529215 0.11529215]
f=0.017014118346046925; w: [0.09223372 0.09223372]
f=0.010889035741470033; w: [0.07378698 0.07378698]
f=0.00696898287454082; w: [0