In [3]:
import numpy as np

In [4]:
x_train = np.array([1.0, 2.0])
y_train = np.array([300.0, 500.0])

In [5]:
def cost_function(x_train, y_train, w, b):
    
    m = x_train.shape[0]
    cost = 0
    
    for i in range (m):
        cost = cost + ((w * x_train[i] + b) - y_train[i])**2
    total_cost = (1/(2*m)) * cost
    
    return total_cost

In [6]:
def compute_gradient(x_train, y_train, w, b):
    
    m = x_train.shape[0]
    dj_dw = 0
    dj_db = 0
    
    for i in range (m):
        dj_dw_i = ((w * x_train[i] + b) - y_train[i]) * x_train[i]
        dj_db_i = (w * x_train[i] + b) - y_train[i]
        dj_dw = dj_dw + dj_dw_i
        dj_db = dj_db + dj_db_i
        
    dj_dw = dj_dw/m
    dj_db = dj_db/m
    
    return dj_dw, dj_db

In [7]:
def gradient_descent(x_train, y_train, w_init, b_init, iterations, alpha):
    
    J_hist = []
    w = w_init
    b = b_init
    
    for i in range (iterations):
        dj_dw, dj_db = compute_gradient(x_train, y_train, w, b)
        w = w - alpha * dj_dw
        b = b - alpha * dj_db
        
        J_hist.append(cost_function(x_train, y_train, w, b))
        
    return w, b, J_hist

In [8]:
w_init = 0
b_init = 0
iterations = 1000
alpha = 1.0e-1
w_final, b_final, J_hist = gradient_descent(x_train, y_train, w_init, b_init, iterations, alpha)

In [9]:
w_final, b_final

(199.9930208199552, 100.01129255052611)

In [10]:
J_hist

[36731.25,
 15877.265625,
 6867.464765625002,
 2974.7862325195297,
 1292.8934389768071,
 566.1426188404786,
 252.04969997801044,
 116.24169135612853,
 57.46086242104911,
 31.960145876394353,
 20.83918547363745,
 15.932313551094376,
 13.711640998365713,
 12.652969043845532,
 12.097766766840465,
 11.761503740522716,
 11.521231740166298,
 11.323812264759727,
 11.146267464222902,
 10.97865028498776,
 10.816643741905665,
 10.658363528753364,
 10.502976624868193,
 10.35010447861183,
 10.199565156790765,
 10.051261939828962,
 9.905135186921623,
 9.761141536192111,
 9.619244914625524,
 9.479412650313236,
 9.341613789136423,
 9.20581836390118,
 9.071997075018702,
 8.940121149043303,
 8.810162274098094,
 8.682092568571598,
 8.55588456423881,
 8.43151119566047,
 8.308945792345563,
 8.18816207215246,
 8.069134135273774,
 7.951836458518722,
 7.836243889769112,
 7.722331642555574,
 7.610075290728943,
 7.49945076321742,
 7.390434338862366,
 7.283002641330246,
 7.177132634100417,
 7.072801615524524,
 