In [57]:
class Layer:
  """Abstract base class for every network layer.

  Concrete layers override both ``forward_propogation`` and
  ``backward_propogation`` (the misspelled names are part of the public
  interface used by ``Network`` and are therefore kept).
  """

  def __init__(self):
    # Concrete layers store the most recent forward-pass input/output here.
    self.input, self.output = None, None

  def forward_propogation(self, input):
    """Return this layer's output for ``input``. Must be overridden."""
    raise NotImplementedError()

  def backward_propogation(self, output_error, learning_rate):
    """Take dE/dY, update any parameters, and return dE/dX. Must be overridden."""
    raise NotImplementedError()


In [58]:
import numpy as np

class ConnLayer(Layer):
  """Fully connected layer: ``output = input @ weights + bias``.

  Weights have shape (input_size, output_size) and bias has shape
  (1, output_size). Both are initialised uniformly in [-0.5, 0.5).
  """

  def __init__(self, input_size, output_size, rng=None):
    """Create the layer.

    Parameters
    ----------
    input_size, output_size : int
        Number of input and output features.
    rng : numpy.random.Generator, optional
        Source of randomness for initialisation, e.g.
        ``np.random.default_rng(SEED)`` for reproducible runs. When omitted,
        numpy's global random state is used (the original behaviour).
    """
    super().__init__()  # initialise self.input / self.output like every Layer
    if rng is None:
      self.weights = np.random.rand(input_size, output_size) - 0.5
      self.bias = np.random.rand(1, output_size) - 0.5
    else:
      self.weights = rng.random((input_size, output_size)) - 0.5
      self.bias = rng.random((1, output_size)) - 0.5

  def forward_propogation(self, input_data):
    """Return ``input_data @ weights + bias`` and cache input/output.

    ``input_data`` is expected to be 2-D (n_rows, input_size). The training
    cell passes one (1, input_size) row per call.
    """
    self.input = input_data
    self.output = np.dot(self.input, self.weights) + self.bias
    return self.output

  def backward_propogation(self, output_error, learning_rate):
    """Apply one gradient-descent step and return dE/dX.

    Parameters
    ----------
    output_error : ndarray
        dE/dY with the same row count as the cached input.
    learning_rate : float
        Step size for the weight and bias updates.
    """
    # dE/dX must use the weights *before* they are updated.
    input_error = np.dot(output_error, self.weights.T)
    weights_error = np.dot(self.input.T, output_error)
    # Sum the bias gradient over rows so batched (n_rows > 1) input works.
    # The in-place update below cannot broadcast (n_rows, out) into (1, out).
    # For a single row this is the original update exactly.
    bias_error = np.reshape(output_error, (-1, self.bias.shape[1])).sum(axis=0, keepdims=True)

    self.weights -= learning_rate * weights_error
    self.bias -= learning_rate * bias_error
    return input_error


In [59]:

class ActivationLayer(Layer):
    """Layer that applies an activation function; it has no trainable parameters."""

    def __init__(self, activation, activation_prime):
        """Store the activation and its derivative.

        Parameters
        ----------
        activation : callable
            Applied to the input in the forward pass (presumably element-wise,
            e.g. ``sig`` -- confirm for custom activations).
        activation_prime : callable
            Derivative of ``activation``, evaluated at the cached input during
            the backward pass.
        """
        super().__init__()  # initialise self.input / self.output like every Layer
        self.activation = activation
        self.activation_prime = activation_prime

    def forward_propogation(self, input_data):
        """Return ``activation(input_data)``, caching both input and output."""
        self.input = input_data
        self.output = self.activation(self.input)
        return self.output

    def backward_propogation(self, output_error, learning_rate):
        """Return dE/dX = activation_prime(X) * dE/dY (chain rule, element-wise product).

        ``learning_rate`` is unused because there are no parameters to update.
        It is accepted only to match the ``Layer`` interface.
        """
        return self.activation_prime(self.input) * output_error


def sig(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, computed without overflow.

    The naive form overflows in ``exp(-x)`` for large negative ``x``. That
    emits a RuntimeWarning, or raises under ``np.errstate(over='raise')``.
    Working with ``z = exp(-|x|) <= 1`` avoids it. For ``x >= 0`` the result is
    bit-identical to the naive formula.

    Parameters
    ----------
    x : scalar or array_like

    Returns
    -------
    ndarray of the same shape as ``x``, or a numpy scalar for scalar input.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x), written with z = e^-|x|.
    out = np.where(x >= 0, 1 / (1 + z), z / (1 + z))
    return out[()] if out.ndim == 0 else out

def sig_prime(x):
    """Derivative of the sigmoid, using sig'(x) = sig(x) * (1 - sig(x))."""
    activated = sig(x)
    complement = 1 - activated
    return activated * complement

def mse(y_true, y_pred):
    """Mean squared error: the average of (y_true - y_pred)**2 over all elements."""
    residual = y_true - y_pred
    return np.mean(np.power(residual, 2))

def mse_prime(y_true, y_pred):
    """Gradient of the mean squared error with respect to ``y_pred``.

    Returns ``2 * (y_pred - y_true) / N``, where N is the number of elements
    in ``y_true``. Both arguments may be any array_like (lists included).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return 2 * (y_pred - y_true) / y_true.size

In [60]:
class Network:
    """Sequential stack of layers trained with per-sample gradient descent.

    Each layer must provide ``forward_propogation(x)`` and
    ``backward_propogation(error, learning_rate)``.
    """

    def __init__(self):
        self.layers = []
        self.loss = None
        self.loss_prime = None
        # Mean training error per epoch from the most recent fit() call.
        self.history = []

    def add(self, layer):
        """Append ``layer`` to the end of the network."""
        self.layers.append(layer)

    def use(self, loss, loss_prime):
        """Set the loss function ``loss(y_true, y_pred)`` and its derivative w.r.t. ``y_pred``."""
        self.loss = loss
        self.loss_prime = loss_prime

    def _forward(self, sample):
        """Run one sample through every layer in order and return the final output."""
        output = sample
        for layer in self.layers:
            output = layer.forward_propogation(output)
        return output

    def predict(self, input_data):
        """Return a list holding the network output for each sample in ``input_data``."""
        return [self._forward(sample) for sample in input_data]

    def fit(self, x_train, y_train, epochs, learning_rate, verbose=True):
        """Train the network, updating parameters after every sample.

        Parameters
        ----------
        x_train, y_train : sequences of equal length
            Inputs and targets, iterated sample by sample.
        epochs : int
            Number of full passes over the training set.
        learning_rate : float
            Step size passed to every layer's ``backward_propogation``.
        verbose : bool, default True
            Print the mean error after each epoch.

        Raises
        ------
        ValueError
            If ``use()`` has not been called, the training set is empty, or
            ``x_train`` and ``y_train`` differ in length.
        """
        if self.loss is None or self.loss_prime is None:
            raise ValueError("call use(loss, loss_prime) before fit()")
        samples = len(x_train)
        if samples == 0:
            raise ValueError("x_train is empty")
        if len(y_train) != samples:
            raise ValueError(
                "x_train and y_train differ in length (%d vs %d)" % (samples, len(y_train))
            )

        self.history = []
        for i in range(epochs):
            err = 0
            for x, y in zip(x_train, y_train):
                output = self._forward(x)
                err += self.loss(y, output)

                # Backpropagate dE/dY from the last layer to the first.
                error = self.loss_prime(y, output)
                for layer in reversed(self.layers):
                    error = layer.backward_propogation(error, learning_rate)

            # Average error over all samples in this epoch.
            err /= samples
            self.history.append(err)
            if verbose:
                print('epoch %d/%d   error=%f' % (i + 1, epochs, err))

In [61]:


# Reproducibility: ConnLayer draws its initial weights from numpy's global RNG,
# so seed it before building the network.
SEED = 42
EPOCHS = 500
LEARNING_RATE = 0.1
np.random.seed(SEED)

# training data: the XOR truth table, one (1, 2) input row per sample
x_train = np.array([[[0,0]], [[0,1]], [[1,0]], [[1,1]]])
y_train = np.array([[[0]], [[1]], [[1]], [[0]]])

# network: 2 inputs -> 3 sigmoid units -> 1 sigmoid output
net = Network()
net.add(ConnLayer(2, 3))
net.add(ActivationLayer(sig, sig_prime))
net.add(ConnLayer(3, 1))
net.add(ActivationLayer(sig, sig_prime))

# train
net.use(mse, mse_prime)
net.fit(x_train, y_train, epochs=EPOCHS, learning_rate=LEARNING_RATE)

# test: predictions for the four training inputs, stacked into a (4, 1) array
out = net.predict(x_train)
np.vstack(out).round(3)

epoch 1/500   error=0.258947
epoch 2/500   error=0.258421
epoch 3/500   error=0.257975
epoch 4/500   error=0.257599
epoch 5/500   error=0.257281
epoch 6/500   error=0.257012
epoch 7/500   error=0.256785
epoch 8/500   error=0.256594
epoch 9/500   error=0.256432
epoch 10/500   error=0.256296
epoch 11/500   error=0.256181
epoch 12/500   error=0.256084
epoch 13/500   error=0.256002
epoch 14/500   error=0.255932
epoch 15/500   error=0.255873
epoch 16/500   error=0.255824
epoch 17/500   error=0.255781
epoch 18/500   error=0.255746
epoch 19/500   error=0.255715
epoch 20/500   error=0.255689
epoch 21/500   error=0.255667
epoch 22/500   error=0.255648
epoch 23/500   error=0.255632
epoch 24/500   error=0.255618
epoch 25/500   error=0.255606
epoch 26/500   error=0.255596
epoch 27/500   error=0.255587
epoch 28/500   error=0.255579
epoch 29/500   error=0.255572
epoch 30/500   error=0.255566
epoch 31/500   error=0.255560
epoch 32/500   error=0.255556
epoch 33/500   error=0.255551
epoch 34/500   erro