# Momentum-Based Gradient Descent

## Implementation of the momentum-based gradient descent algorithm on a toy 2-D dataset of 40 data points

In [1]:
import pandas as pd
import numpy as np

In [2]:
def f(w,b,x):
    """Sigmoid neuron: the logistic function applied to the affine input w*x + b.

    Works element-wise when ``x`` is a NumPy array or pandas Series.
    """
    z = w * x + b
    return 1.0 / (1.0 + np.exp(-z))

In [3]:
def grad_b(w,b,x,y):
    """Partial derivative of 0.5*(f(w,b,x) - y)**2 with respect to the bias b.

    Uses the sigmoid identity f' = f * (1 - f), so the gradient is
    (f - y) * f * (1 - f).
    """
    pred = f(w, b, x)
    residual = pred - y
    return residual * pred * (1 - pred)

In [4]:
def grad_w(w,b,x,y):
    """Partial derivative of 0.5*(f(w,b,x) - y)**2 with respect to the weight w.

    Same as the bias gradient, scaled by the input x (chain rule through w*x).
    """
    pred = f(w, b, x)
    residual = pred - y
    return residual * pred * (1 - pred) * x

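#### Using the squared-error loss to compute the error

The loss is summed over all data points $(x_i, y_i)$, where $f$ is the sigmoid neuron defined above:

$$
\mathcal{L}(w, b) = \sum_{i} \frac{1}{2} \left( f(w, b, x_i) - y_i \right)^2,
\qquad f(w, b, x) = \frac{1}{1 + e^{-(w x + b)}}
$$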


In [5]:
def error(w,b,xs=None,ys=None):
    """Total squared-error loss of the sigmoid neuron over a dataset.

    Computes sum_i 0.5 * (f(w, b, x_i) - y_i)**2.

    Parameters
    ----------
    w, b : float
        Weight and bias of the model.
    xs, ys : iterable of float, optional
        Inputs and targets, paired with ``zip``. When omitted, the function
        falls back to the notebook-level globals ``X`` and ``Y`` (the original
        behaviour); those must then already be defined.

    Returns
    -------
    float
        The summed loss (0.0 for an empty dataset).
    """
    # Explicit data arguments let callers evaluate the loss on any dataset
    # instead of silently depending on whatever X/Y are in the kernel.
    if xs is None:
        xs = X
    if ys is None:
        ys = Y
    err = 0.0
    for x, y in zip(xs, ys):
        fx = f(w, b, x)
        err += 0.5 * (fx - y) ** 2
    return err

In [6]:
def do_momentum_gradient_descent(X,Y,init_w,init_b,max_epochs,eta=1.0,gamma=0.9,verbose=True):
    """Fit (w, b) of the sigmoid neuron with momentum-based gradient descent.

    Each epoch accumulates the full-batch gradient over all (x, y) pairs and
    then applies the momentum update

        v_t   = gamma * v_{t-1} + eta * grad
        theta = theta - v_t

    Parameters
    ----------
    X, Y : iterable of float
        Inputs and targets. They are iterated once per epoch (and once more
        for the loss when ``verbose``), so they must be re-iterable
        (list, NumPy array, pandas Series) rather than one-shot generators.
    init_w, init_b : float
        Starting weight and bias.
    max_epochs : int
        Number of full passes over the data.
    eta : float, default 1.0
        Learning rate.
    gamma : float, default 0.9
        Momentum coefficient: the fraction of the previous update carried over.
    verbose : bool, default True
        If true, print the loss on (X, Y) after every epoch.

    Returns
    -------
    tuple
        The final ``(w, b)``.
    """
    w, b = init_w, init_b
    v_w, v_b = 0.0, 0.0  # velocities, i.e. the previous update steps
    for epoch in range(max_epochs):
        dw, db = 0.0, 0.0
        for x, y in zip(X, Y):
            dw += grad_w(w, b, x, y)
            db += grad_b(w, b, x, y)

        v_w = gamma * v_w + eta * dw
        v_b = gamma * v_b + eta * db
        w = w - v_w
        b = b - v_b
        if verbose:
            # Evaluate the loss on the data passed in. Do not use the global
            # X/Y or the module-level `error` name, which a driver cell may
            # have rebound.
            loss = sum((0.5 * (f(w, b, x) - y) ** 2 for x, y in zip(X, Y)), 0.0)
            print("Epoch{}: Loss={}".format(epoch, loss))

    return w, b

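##### Comparison with plain gradient descent

Compared with plain gradient descent run for 100 iterations, momentum-based gradient descent reaches a lower loss in only 10 iterations. Note that the loss in the output below is not monotonic. It is lowest at epoch 2, rises until epoch 6, and then falls again. This is likely because the accumulated velocity carries the parameters past the minimum before they settle back.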

In [7]:
if __name__=="__main__":
    # Driver: load the toy dataset, fit with momentum GD, report the final loss.
    filename='A2_Q4_data.csv'  # CSV with an input column 'X' and a target column 'Y'
    df=pd.read_csv(filename)
    X=df['X']
    Y=df['Y']
    initial_w=1
    initial_b=1
    max_epochs=10
    w,b=do_momentum_gradient_descent(X,Y,initial_w,initial_b,max_epochs)
    # Use a new name for the result. Assigning to `error` would replace the
    # error() function with a float, so re-running this cell (or any later
    # call to error()) would fail with
    # "TypeError: 'float' object is not callable".
    final_error=error(w,b)
    print("Error={}".format(final_error))

Epoch0: Loss=0.025154357598516662
Epoch1: Loss=0.005742165827785046
Epoch2: Loss=0.004429526156669953
Epoch3: Loss=0.009321023932759602
Epoch4: Loss=0.014098841123542628
Epoch5: Loss=0.01675413689732964
Epoch6: Loss=0.017207092096132772
Epoch7: Loss=0.01607726583922297
Epoch8: Loss=0.01412132787692757
Epoch9: Loss=0.011963769844975991
Error=0.011963769844975991
