###  Created by Luis Alejandro (alejand@umich.edu)
Runs a forward and a backward pass of logistic regression on a computational graph and compares the resulting cost and gradients with analytical computations.

In [1]:
import numpy as np
import numpy.random as rnd

import sys
sys.path.append('../')

from utils.graphs.core import Param
from utils.graphs.core import DataHolder
from utils.graphs.core import Graph
from utils.graphs.core import Operation

from utils.graphs.nodes import linear_node
from utils.graphs.nodes import bias_node
from utils.graphs.nodes import sigmoid_node
from utils.graphs.nodes import bce_node

from sklearn import datasets

In [2]:
# Load the iris dataset bundled with scikit-learn.
dataset = datasets.load_iris()

In [3]:
# Feature matrix, one row per sample.
predictors = dataset['data']
# Labels as a column vector. reshape() returns a view whenever possible, so
# take a copy first; otherwise the relabelling below would silently overwrite
# dataset['target'] as well.
responses = dataset['target'].reshape(-1,1).copy()
# Fold class 2 into class 1 so that the labels become binary (0 vs. 1).
responses[responses == 2] = 1
# m = number of samples, d = number of features.
m,d = predictors.shape

In [4]:
# Holders for the data; their values are supplied later through Graph.feed.
X_node = DataHolder()
y_node = DataHolder()
# Parameters of the model: a (d, 1) node used as the weights of the linear
# node and a (1, 1) node used as the bias.
w_node = Param(shape=(d,1))
b_node = Param(shape=(1,1))

In [5]:
# Chain the nodes: linear(X, w) -> bias(., b) -> sigmoid -> BCE against y.
# NOTE(review): presumably this is X.w + b, the logistic function and the
# binary cross-entropy, as in the NumPy reference below -- confirm in
# utils.graphs.nodes.
r_node = linear_node(X_node,w_node)
z_node = bias_node(r_node,b_node)
h_node = sigmoid_node(z_node)
J_node = bce_node(h_node,y_node)

In [6]:
# Build the graph from the cost node, initialize it, and feed the data into
# the two holders.
g = Graph()
g.build(J_node).initialize().feed({X_node:predictors, y_node:responses})

<utils.graphs.core.Graph at 0x242b1e4e348>

In [7]:
# Run the forward pass followed by the backward pass; the node values and
# gradients read in the cells below come from this call.
g.forward().backward()

<utils.graphs.core.Graph at 0x242b1e4e348>

In [8]:
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise by NumPy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def compute_cost(w,X,y):
    """Negative log-likelihood of logistic regression, summed over the samples.

    Parameters
    ----------
    w : coefficients; ``X.dot(w)`` yields the logit of every sample.
    X : design matrix with one row per sample.
    y : one label per sample, as a flat vector or a column; any non-zero
        label is treated as the positive class.

    Returns
    -------
    The summed cost, reduced over the sample axis: an array of shape ``(1,)``
    when the logits form a column vector, a scalar when they are 1-D.
    """
    logits = X.dot(w)
    # One label per row, shaped so that it broadcasts against the logits
    # whether y was passed as a flat vector or as a column.
    positive = np.asarray(y).reshape((-1,) + (1,) * (logits.ndim - 1)) != 0
    # -log(sigmoid(z)) = logaddexp(0, -z)  and  -log(1 - sigmoid(z)) = logaddexp(0, z).
    # Working on the logits keeps the cost finite where the sigmoid would round
    # to exactly 0 or 1 and log() would return -inf.
    return np.logaddexp(0, np.where(positive, -logits, logits)).sum(axis=0)

def compute_grad(w,X,y):
    """Gradient of the summed logistic cost with respect to ``w``.

    Evaluates ``X.T . (sigmoid(X . w) - y)``.
    """
    residual = sigmoid(X.dot(w)) - y
    return X.T.dot(residual)

In [9]:
# Design matrix for the reference implementation: a leading column of ones
# (for the bias) followed by the predictors.
intercept_column = np.ones((predictors.shape[0], 1))
X = np.concatenate((intercept_column, predictors), axis=1)
# Stack the bias on top of the weights to match the column order of X.
w = np.vstack((b_node.value, w_node.value))
# Labels exactly as they are stored in the graph.
y = y_node.value

In [10]:
# Cost from the NumPy reference next to the value held by the graph's cost node.
print('Conventional:', compute_cost(w,X,y))
print('Graph:', J_node.value)

Convetional: [68.8358429]
Graph: [68.8358429]


In [11]:
# Gradient from the NumPy reference, flattened; the first entry belongs to
# the column of ones (bias), the rest to the feature weights.
print('Conventional:', compute_grad(w,X,y).flatten())

Conventional: [ -44.87399624 -276.94107244 -126.88571433 -215.5101709   -72.66582952]


In [12]:
# Gradients stored on the graph parameters, laid out like the reference
# vector: bias first, then the weights.
print('Graph:', np.hstack((b_node.gradient.flatten(), w_node.gradient.flatten())))

Graph: [ -44.87399624 -276.94107244 -126.88571433 -215.5101709   -72.66582952]
