In [151]:
import numpy as np
import matplotlib.pyplot as plt

In [152]:
def sigmoid(z):
    return 1 / (1 + np.exp(-z)) # If z is big, exp is small and that forces sigmoid to be close to one

In [153]:
x = np.array([[2, 4, 6], [3, 5, 7], [12, 14, 16], [8, 10, 12], [9, 11, 13], [13, 15, 17]])
y = np.array([1, 0, 1, 1, 0, 0])

In [154]:
xs = x.reshape(x.shape[0], -1).T
ys = y.reshape(y.shape[0], -1).T

In [155]:
w = np.zeros((xs.shape[0], 1))
b = 0
m = xs.shape[0]

In [156]:
def propogate(xs, ys, w, b):
    # forward pass 
    y_hat = sigmoid(np.dot(w.T, xs) + b)
    
    # binary cross entropy loss
    J = - (1 / m) * np.sum((ys * np.log(y_hat)) + ((1 - ys) * np.log(1 - y_hat)))
    
    # backword pass
    dw = (1 / m) * np.dot(xs, (y_hat - ys).T)
    db = (1 / m) * np.sum(y_hat - ys)
    
    grads = {
        "dw": dw,
        "db": db
    }
    
    return grads, J

In [157]:
costs = []

def optimize(xs, ys, w, b, num_iters, learning_rate, print_cost = False):
    for i in range(num_iters):
        # cost and gradient calculation
        grads, cost = propogate(xs, ys, w, b)
        
        # Retrieve derivatives from grads
        dw = grads["dw"]
        db = grads["db"]
        
        # Update rule
        w = w - learning_rate * dw
        b = b - learning_rate * db
        
        # Record the costs
        if i % 100 == 0:
            costs.append(cost)
        
        # Print the cost every 100 training iterations
        if print_cost and i % 100 == 0:
            print ("Cost after iteration %i: %f" %(i, cost))
            
    params = {"w": w,
              "b": b}
    
    grads = {"dw": dw,
             "db": db}
    
    return params, grads, costs

In [161]:
optimize(xs, ys, w, b, 5000, 0.001)

({'w': array([[-0.14250302],
         [-0.01894932],
         [ 0.10460437]]), 'b': 0.061776848012697866},
 {'dw': array([[ 0.00364435],
         [ 0.00041958],
         [-0.00280519]]), 'db': -0.0016123835815278023},
 [1.3862943611198904,
  3.5211437807338513,
  6.355271925270024,
  5.742947642422429,
  4.065789293897368,
  6.405492856040816,
  10.332823455080547,
  4.113157892045001,
  6.352467654644284,
  10.33336489651396,
  1.3862943611198904,
  3.5211437807338513,
  6.355271925270024,
  5.742947642422429,
  4.065789293897368,
  6.405492856040816,
  10.332823455080547,
  4.113157892045001,
  6.352467654644284,
  10.33336489651396,
  4.11301247956539,
  6.352582390151142,
  10.333360262479498,
  4.113012570564473,
  6.352582239180305,
  10.333360273019007,
  4.113012570614284,
  6.352582239337231,
  10.33336027300479,
  4.113012570613891,
  6.3525822393371785,
  10.33336027300479,
  4.113012570613891,
  6.3525822393371785,
  10.33336027300479,
  4.113012570613891,
  6.3525822393371