In [94]:
import numpy as np
from sklearn.metrics import mean_squared_error

# The neural network class

In [95]:
class NNetwork:
    """Fully connected feed-forward network trained by per-sample backpropagation.

    The network has no bias terms: layer ``i`` maps its input vector to
    ``activate(W_i @ input)`` where ``W_i`` has shape ``(layers[i], layers[i-1])``.

    Parameters
    ----------
    layers : list of int
        Number of units per layer, input layer first and output layer last
        (e.g. ``[3, 4, 4, 3]``).
    activate_function : str
        Activation applied on every layer. Only ``'sigmoid'`` is implemented.
    cost_function : str
        Per-output cost. Only ``'square_error'`` (``0.5 * (out - y)**2``) is
        implemented.

    Raises
    ------
    ValueError
        If ``activate_function`` or ``cost_function`` is not implemented.
    """

    def __init__(self, layers: list, activate_function='sigmoid', cost_function='square_error'):
        self.layers = layers
        self.network = None

        if activate_function == 'sigmoid':
            self._activate_function = lambda x: 1 / (1 + np.exp(-x))
            # Takes the sigmoid *output* s (what `_forward_propagation` caches),
            # not the pre-activation: ds/dz = s * (1 - s).
            self._derivative_activate_function = lambda s: s * (1.0 - s)
        else:
            # Fail fast: previously an unknown name left the attributes unset and
            # the first forward pass died with an obscure AttributeError.
            raise ValueError(
                "unsupported activate_function %r; only 'sigmoid' is implemented"
                % (activate_function,))

        if cost_function == 'square_error':
            self._cost_function = lambda out, y: 1 / 2 * (out - y) ** 2
            # Derivative of the cost above with respect to `out`.
            self._derivative_cost_function = lambda out, y: out - y
        else:
            raise ValueError(
                "unsupported cost_function %r; only 'square_error' is implemented"
                % (cost_function,))

        self._initialize_network()

    def _initialize_network(self):
        """(Re)create one weight matrix per pair of consecutive layers.

        Weights are drawn uniformly from [0, 1) with the global NumPy RNG, so call
        ``np.random.seed`` beforehand for a reproducible initialisation.
        """
        self.network = []
        for n_in, n_out in zip(self.layers[:-1], self.layers[1:]):
            weights = np.random.rand(n_out, n_in)
            self.network.append({"weights": weights, "forward": None, "deltas": None})

    def _forward_propagation(self, inputs: np.ndarray) -> np.ndarray:
        """Propagate one input vector through the network.

        Each layer's activation is cached in ``layer['forward']`` for use by
        ``_back_propagation`` / ``_update_weights``. Returns the output layer's
        activation.
        """
        for layer in self.network:
            inputs = self._activate_function(layer['weights'] @ inputs)
            layer['forward'] = inputs
        return inputs

    def _back_propagation(self, y):
        """Fill ``layer['deltas']`` for every layer, output layer first.

        Must be called right after ``_forward_propagation`` on the matching input.
        Deltas are the *negative* gradient of the cost with respect to each
        layer's pre-activation, which is why ``_update_weights`` adds them.
        """
        last = len(self.network) - 1
        for index in range(last, -1, -1):
            layer = self.network[index]
            slope = self._derivative_activate_function(layer['forward'])
            if index == last:
                layer['deltas'] = -self._derivative_cost_function(layer['forward'], y) * slope
            else:
                next_layer = self.network[index + 1]
                # delta @ W == W.T @ delta: send the error back through the weights.
                layer['deltas'] = slope * (next_layer['deltas'] @ next_layer['weights'])

    def _update_weights(self, inputs, lr):
        """Apply one gradient-descent step using the deltas from ``_back_propagation``.

        ``inputs`` is the sample that was just propagated; it feeds the first
        layer, while every later layer is fed by the previous layer's cached
        activation.
        """
        layer_input = inputs
        for layer in self.network:
            # Outer product: dW[i, j] = lr * delta[i] * input[j] (same shape as W).
            layer['weights'] += lr * np.outer(layer['deltas'], layer_input)
            layer_input = layer['forward']

    def _rmse(self, X, y):
        """Root-mean-squared error of ``predict(X)`` against ``y`` over all outputs."""
        predictions = self.predict(X)
        targets = np.asarray(y, dtype=float).reshape(predictions.shape)
        return float(np.sqrt(np.mean((targets - predictions) ** 2)))

    def fit(self, X, y, n_epoch=10000, lr=0.01, verbose_epoch=500):
        """Train with per-sample gradient descent and print progress.

        Parameters
        ----------
        X : array-like
            One training sample per row; presumably each row has ``layers[0]``
            entries (the first weight matrix is multiplied by it).
        y : array-like
            Target per row of ``X``; presumably ``layers[-1]`` entries each.
        n_epoch : int
            Number of full passes over the data.
        lr : float
            Learning rate.
        verbose_epoch : int
            Print the training RMSE every ``verbose_epoch`` epochs. The error
            after the last epoch is always printed.
        """
        print("learning rate = %.3f" % lr)
        for epoch in range(n_epoch):
            for x_, y_ in zip(X, y):
                self._forward_propagation(x_)
                self._back_propagation(y_)
                self._update_weights(x_, lr)

            if epoch % verbose_epoch == 0:
                print('epoch=%d, error=%.3f' % (epoch, self._rmse(X, y)))
            elif epoch == n_epoch - 1:
                # Recompute here: previously this re-printed the stale error of the
                # last verbose epoch under the final epoch's label.
                print('epoch=%d, error=%.3f' % (epoch + 1, self._rmse(X, y)))

    def predict(self, inputs: np.ndarray, around=None):
        """Run a forward pass for every row of ``inputs``.

        Note that this overwrites the cached per-layer activations.

        Parameters
        ----------
        inputs : array-like
            One sample per row.
        around : int or None
            If given, round the outputs to this many decimals (``np.around``).

        Returns
        -------
        np.ndarray
            One output row per input row.
        """
        output = np.array([self._forward_propagation(x) for x in inputs])
        if around is not None:
            output = np.around(output, decimals=around)
        return output

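# The 3-bit counter

Each input row is a 3-bit binary number (most significant bit first), and its target is the next number in sequence, with 7 (`1 1 1`) wrapping around to 0 (`0 0 0`).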

In [96]:
# All eight 3-bit patterns, most significant bit first: row i is i in binary.
X = np.array([[(i >> shift) & 1 for shift in (2, 1, 0)] for i in range(8)])

# Each target is the following pattern; the last row (1, 1, 1) wraps to (0, 0, 0).
y = np.roll(X, -1, axis=0)

In [97]:
# Seed the global NumPy RNG that NNetwork uses for its random weight
# initialisation, so this training run is reproducible under Restart & Run All.
np.random.seed(42)

nn = NNetwork([3, 4, 4, 3])
nn.fit(X, y, lr=0.2)

learning rate = 0.200
epoch=0, error=0.599
epoch=500, error=0.433
epoch=1000, error=0.377
epoch=1500, error=0.318
epoch=2000, error=0.280
epoch=2500, error=0.267
epoch=3000, error=0.250
epoch=3500, error=0.198
epoch=4000, error=0.133
epoch=4500, error=0.099
epoch=5000, error=0.078
epoch=5500, error=0.063
epoch=6000, error=0.053
epoch=6500, error=0.046
epoch=7000, error=0.041
epoch=7500, error=0.037
epoch=8000, error=0.034
epoch=8500, error=0.031
epoch=9000, error=0.029
epoch=9500, error=0.028
epoch=10000, error=0.028


In [98]:
# Round each sigmoid output to 0 decimals so it reads as a bit; compare with `y`,
# the expected successor of each input row.
nn.predict(X, around=0)

array([[0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 1.],
       [1., 0., 0.],
       [1., 0., 1.],
       [1., 1., 0.],
       [1., 1., 1.],
       [0., 0., 0.]])