# Gradient Descent

**Import Libraries**

In [1]:
import pandas as pd

**Cost Function (MSE)**

In [2]:
def cost_function(Y, b, w, X): 
    
    m = len(Y)
    sse = 0
    for i in range(0, m):
        y_hat = b + w * X[i]
        y = Y[i]
        sse += (y_hat - y) ** 2   
    mse = sse / m  
    
    return mse

**Update Weights Function**

In [3]:
def update_weights(Y, b, w, X, learning_rate):
    
    m = len(Y)
    b_deriv_sum = 0
    w_deriv_sum = 0
    for i in range(0, m):
        y_hat = b + w * X[i]
        y = Y[i]
        b_deriv_sum += (y_hat - y)
        w_deriv_sum += (y_hat - y) * X[i]
    new_b = b - (learning_rate * (1/m) * b_deriv_sum)
    new_w = w - (learning_rate * (1/m) * w_deriv_sum)
    
    return new_b, new_w

**Train Function**

In [4]:
def train(Y, initial_b, initial_w, X, learning_rate, num_iters):
    
    print('Starting gradient descent at b={0}, w={1}, mse={2}'.format(initial_b, initial_w,
                                                                     cost_function(Y, initial_b, initial_w, X)))
    b = initial_b
    w = initial_w
    cost_history = []
    for i in range(num_iters):
        b, w = update_weights(Y, b, w, X, learning_rate)
        mse = cost_function(Y, b, w, X)
        cost_history.append(mse)
        
        if i%100 == 0:
            print('iter={:d}    b={:.2f}    w={:.4f}    mse={:.4f}'.format(i, b, w, mse))
            
    print('After {0} iterations b={1}    w={2}    mse={3}'.format(num_iters, b, w, cost_function(Y, b, w, X)))
    
    return cost_history, b, w  

**Read csv file**

In [5]:
df = pd.read_csv('advertising.csv')

In [6]:
df.head()

Unnamed: 0,TV,radio,newspaper,sales
0,230.1,37.8,69.2,22.1
1,44.5,39.3,45.1,10.4
2,17.2,45.9,69.3,9.3
3,151.5,41.3,58.5,18.5
4,180.8,10.8,58.4,12.9


**Feature and Label**

In [7]:
X = df['radio']
Y = df['sales']

**Hyperparameters**

In [8]:
learning_rate = 0.001
initial_b = 0.001
initial_w = 0.001
num_iters = 10000

In [9]:
train(Y, initial_b, initial_w, X, learning_rate, num_iters)

Starting gradient descent at b=0.001, w=0.001, mse=222.9477491673001
iter=0    b=0.01    w=0.3708    mse=53.2540
iter=100    b=0.28    w=0.4788    mse=41.6028
iter=200    b=0.54    w=0.4709    mse=40.2861
iter=300    b=0.79    w=0.4633    mse=39.0433
iter=400    b=1.03    w=0.4559    mse=37.8700
iter=500    b=1.27    w=0.4487    mse=36.7624
iter=600    b=1.49    w=0.4417    mse=35.7169
iter=700    b=1.72    w=0.4349    mse=34.7299
iter=800    b=1.93    w=0.4283    mse=33.7981
iter=900    b=2.14    w=0.4219    mse=32.9186
iter=1000    b=2.35    w=0.4157    mse=32.0883
iter=1100    b=2.54    w=0.4096    mse=31.3045
iter=1200    b=2.74    w=0.4037    mse=30.5646
iter=1300    b=2.92    w=0.3980    mse=29.8662
iter=1400    b=3.10    w=0.3925    mse=29.2068
iter=1500    b=3.28    w=0.3871    mse=28.5844
iter=1600    b=3.45    w=0.3818    mse=27.9968
iter=1700    b=3.62    w=0.3767    mse=27.4422
iter=1800    b=3.78    w=0.3718    mse=26.9186
iter=1900    b=3.94    w=0.3670    mse=26.4243
ite

([53.25401123914189,
  43.56776882023221,
  43.00212068546803,
  42.95636329298037,
  42.94024797669445,
  42.92583007024646,
  42.91151669535878,
  42.897217051903226,
  42.88292595970384,
  42.86864311900961,
  42.85436850826132,
  42.840102121759344,
  42.825843954710805,
  42.81159400237707,
  42.79735226002503,
  42.78311872292461,
  42.76889338634844,
  42.75467624557186,
  42.74046729587293,
  42.72626653253241,
  42.712073950833854,
  42.69788954606341,
  42.68371331351007,
  42.66954524846553,
  42.65538534622403,
  42.64123360208273,
  42.627090011341444,
  42.61295456930256,
  42.598827271271354,
  42.58470811255565,
  42.570597088466194,
  42.55649419431616,
  42.54239942542163,
  42.528312777101284,
  42.514234244676544,
  42.50016382347146,
  42.48610150881289,
  42.4720472960303,
  42.4580011804558,
  42.443963157424406,
  42.42993322227355,
  42.4159113703435,
  42.40189759697717,
  42.387891897520234,
  42.37389426732086,
  42.35990470173017,
  42.345923196101694,
  42