In [9]:
#Base Class
#The abstract class Layer, which all other layers will inherit from, 
#handles simple properties which are an input, an output, and both a forward and backward methods.
class Layer:
    def __init__(self):
        self.input=None
        self.output=None
        
#computes the output Y of a layer for a given input X
    def forward_propogation(self,input):
        raise NotImplementedError
        
#compute dE/dX for a given dE/dY(and update parameters if any)
    def backward_propogation(self,output_error,learning_rate):
        raise NotImplementedError

In [14]:
import numpy as np
class FCLayer(Layer):
    # input_size = number of input neurons
    # output_size = number of output neurons
    def __init__(self,input_size,output_size):
        self.weights=np.random.rand(input_size,output_size)-0.5
        self.bias=np.random.rand(1,output_size)-0.5
        
    def forward_propogation(self,input_data):
        self.input=input_data
        self.output=np.dot(self.input,self.weights) + self.bias
        return self.output
    
    # computes dE/dW, dE/dB for a given output_error=dE/dY. Returns input_error=dE/dX.
    def backward_propogation(self,output_error,learning_rate):
        input_error=np.dot(output_error,self.weights.T)
        weights_error=np.dot(self.input.T,output_error)
        #dBias=output_error
        
        #update parameters
        self.weights -= weights_error*learning_rate
        self.bias = self.bias-(output_error*learning_rate)
        return input_error

In [4]:
#ActivationLayer
class ActivationLayer(Layer):
    def __init__(self,activation,activation_prime):
        self.activation=activation
        self.activation_prime=activation_prime
        
    #returns activated output
    def forward_propogation(self,input_data):
        self.input=input_data
        self.output=self.activation(self.input)
        return self.output
    
    # Returns input_error=dE/dX for a given output_error=dE/dY.
    # learning_rate is not used because there is no "learnable" parameters.
    def backward_propogation(self,output_error,learning_rate):
        return self.activation_prime(self.input)*output_error
    

In [5]:
#ActivationFuntions

def tanh(x):
    return np.tanh(x)

def tanh_prime(x):
    return 1-np.tanh(x)**2

In [6]:
# loss function and its derivative we use mse here

def mse(y_true,y_pred):
    return np.mean(np.power(y_true-y_pred,2))

def mse_prime(y_true,y_pred):
    return 2*(y_pred-y_true)/y_true.size

In [25]:
class Network:
    def __init__(self):
        self.layers=[]
        self.loss=None
        self.loss_prime=None
    
    #add layer to network
    def add(self,layer):
        self.layers.append(layer)
        
    #set loss func to use
    def use(self,loss,loss_prime):
        self.loss=loss
        self.loss_prime=loss_prime
    
    #predict output for given input
    def predict(self,input_data):
        n_samples=len(input_data)
        result=[]
        
        for i in range(n_samples):
            output=input_data[i]
            
            for layer in self.layers:
                output=layer.forward_propogation(output)
            result.append(output)
            
        return result
    
    #train the nerwork
    def fit(self,X_train,y_train,epochs,learning_rate):
        n_samples=len(X_train)
        
        #training loop
        for i in range(epochs):
            err=0
            for j in range(n_samples):
                #forward propogation
                output=X_train[j]
                for layer in self.layers:
                    output=layer.forward_propogation(output)
                
                # compute loss (for display purpose only)
                err += self.loss(y_train[j],output)
                
                #backward propogation
                error = self.loss_prime(y_train[j],output)
                for layer in reversed(self.layers):
                    error=layer.backward_propogation(error,learning_rate)
                
                #calculate average error on all samples
                err /= n_samples
                print("epoch %d/%d error=%f"%(i+1,epochs,err))

In [26]:
#XOR problem

#training data
X_train=np.array([[[0,0]], [[0,1]], [[1,0]], [[1,1]]])
y_train=np.array([[[0]],  [[1]],   [[1]],   [[1]]])

#network
net=Network()
net.add(FCLayer(2,3))
net.add(ActivationLayer(tanh,tanh_prime))
net.add(FCLayer(3,1))
net.add(ActivationLayer(tanh,tanh_prime))

#train
net.use(mse,mse_prime)
net.fit(X_train,y_train,epochs=1000,learning_rate=0.1)

#test
out=net.predict(X_train)
print(out)

epoch 1/1000 error=0.023031
epoch 1/1000 error=0.137204
epoch 1/1000 error=0.121302
epoch 1/1000 error=0.070473
epoch 2/1000 error=0.080862
epoch 2/1000 error=0.070082
epoch 2/1000 error=0.058265
epoch 2/1000 error=0.035459
epoch 3/1000 error=0.100583
epoch 3/1000 error=0.058768
epoch 3/1000 error=0.043366
epoch 3/1000 error=0.025600
epoch 4/1000 error=0.109338
epoch 4/1000 error=0.054597
epoch 4/1000 error=0.037090
epoch 4/1000 error=0.021093
epoch 5/1000 error=0.113400
epoch 5/1000 error=0.052460
epoch 5/1000 error=0.033779
epoch 5/1000 error=0.018555
epoch 6/1000 error=0.114925
epoch 6/1000 error=0.051122
epoch 6/1000 error=0.031817
epoch 6/1000 error=0.016941
epoch 7/1000 error=0.114863
epoch 7/1000 error=0.050150
epoch 7/1000 error=0.030575
epoch 7/1000 error=0.015824
epoch 8/1000 error=0.113702
epoch 8/1000 error=0.049358
epoch 8/1000 error=0.029757
epoch 8/1000 error=0.014997
epoch 9/1000 error=0.111715
epoch 9/1000 error=0.048650
epoch 9/1000 error=0.029202
epoch 9/1000 error=0

epoch 115/1000 error=0.000189
epoch 116/1000 error=0.000012
epoch 116/1000 error=0.000562
epoch 116/1000 error=0.000700
epoch 116/1000 error=0.000186
epoch 117/1000 error=0.000012
epoch 117/1000 error=0.000554
epoch 117/1000 error=0.000690
epoch 117/1000 error=0.000184
epoch 118/1000 error=0.000012
epoch 118/1000 error=0.000547
epoch 118/1000 error=0.000681
epoch 118/1000 error=0.000181
epoch 119/1000 error=0.000011
epoch 119/1000 error=0.000540
epoch 119/1000 error=0.000673
epoch 119/1000 error=0.000179
epoch 120/1000 error=0.000011
epoch 120/1000 error=0.000533
epoch 120/1000 error=0.000664
epoch 120/1000 error=0.000177
epoch 121/1000 error=0.000011
epoch 121/1000 error=0.000526
epoch 121/1000 error=0.000656
epoch 121/1000 error=0.000174
epoch 122/1000 error=0.000010
epoch 122/1000 error=0.000520
epoch 122/1000 error=0.000648
epoch 122/1000 error=0.000172
epoch 123/1000 error=0.000010
epoch 123/1000 error=0.000513
epoch 123/1000 error=0.000640
epoch 123/1000 error=0.000170
epoch 124/

epoch 210/1000 error=0.000002
epoch 210/1000 error=0.000243
epoch 210/1000 error=0.000302
epoch 210/1000 error=0.000080
epoch 211/1000 error=0.000002
epoch 211/1000 error=0.000241
epoch 211/1000 error=0.000300
epoch 211/1000 error=0.000079
epoch 212/1000 error=0.000002
epoch 212/1000 error=0.000240
epoch 212/1000 error=0.000298
epoch 212/1000 error=0.000079
epoch 213/1000 error=0.000002
epoch 213/1000 error=0.000238
epoch 213/1000 error=0.000296
epoch 213/1000 error=0.000078
epoch 214/1000 error=0.000002
epoch 214/1000 error=0.000237
epoch 214/1000 error=0.000294
epoch 214/1000 error=0.000078
epoch 215/1000 error=0.000002
epoch 215/1000 error=0.000235
epoch 215/1000 error=0.000292
epoch 215/1000 error=0.000077
epoch 216/1000 error=0.000002
epoch 216/1000 error=0.000234
epoch 216/1000 error=0.000291
epoch 216/1000 error=0.000077
epoch 217/1000 error=0.000002
epoch 217/1000 error=0.000233
epoch 217/1000 error=0.000289
epoch 217/1000 error=0.000076
epoch 218/1000 error=0.000002
epoch 218/

epoch 301/1000 error=0.000190
epoch 301/1000 error=0.000050
epoch 302/1000 error=0.000001
epoch 302/1000 error=0.000153
epoch 302/1000 error=0.000190
epoch 302/1000 error=0.000050
epoch 303/1000 error=0.000001
epoch 303/1000 error=0.000152
epoch 303/1000 error=0.000189
epoch 303/1000 error=0.000050
epoch 304/1000 error=0.000001
epoch 304/1000 error=0.000152
epoch 304/1000 error=0.000188
epoch 304/1000 error=0.000049
epoch 305/1000 error=0.000001
epoch 305/1000 error=0.000151
epoch 305/1000 error=0.000187
epoch 305/1000 error=0.000049
epoch 306/1000 error=0.000001
epoch 306/1000 error=0.000151
epoch 306/1000 error=0.000186
epoch 306/1000 error=0.000049
epoch 307/1000 error=0.000001
epoch 307/1000 error=0.000150
epoch 307/1000 error=0.000186
epoch 307/1000 error=0.000049
epoch 308/1000 error=0.000001
epoch 308/1000 error=0.000149
epoch 308/1000 error=0.000185
epoch 308/1000 error=0.000049
epoch 309/1000 error=0.000001
epoch 309/1000 error=0.000149
epoch 309/1000 error=0.000184
epoch 309/

epoch 382/1000 error=0.000142
epoch 382/1000 error=0.000037
epoch 383/1000 error=0.000000
epoch 383/1000 error=0.000115
epoch 383/1000 error=0.000142
epoch 383/1000 error=0.000037
epoch 384/1000 error=0.000000
epoch 384/1000 error=0.000114
epoch 384/1000 error=0.000141
epoch 384/1000 error=0.000037
epoch 385/1000 error=0.000000
epoch 385/1000 error=0.000114
epoch 385/1000 error=0.000141
epoch 385/1000 error=0.000037
epoch 386/1000 error=0.000000
epoch 386/1000 error=0.000114
epoch 386/1000 error=0.000140
epoch 386/1000 error=0.000037
epoch 387/1000 error=0.000000
epoch 387/1000 error=0.000113
epoch 387/1000 error=0.000140
epoch 387/1000 error=0.000037
epoch 388/1000 error=0.000000
epoch 388/1000 error=0.000113
epoch 388/1000 error=0.000140
epoch 388/1000 error=0.000037
epoch 389/1000 error=0.000000
epoch 389/1000 error=0.000112
epoch 389/1000 error=0.000139
epoch 389/1000 error=0.000037
epoch 390/1000 error=0.000000
epoch 390/1000 error=0.000112
epoch 390/1000 error=0.000139
epoch 390/

epoch 528/1000 error=0.000097
epoch 528/1000 error=0.000025
epoch 529/1000 error=0.000000
epoch 529/1000 error=0.000078
epoch 529/1000 error=0.000097
epoch 529/1000 error=0.000025
epoch 530/1000 error=0.000000
epoch 530/1000 error=0.000078
epoch 530/1000 error=0.000096
epoch 530/1000 error=0.000025
epoch 531/1000 error=0.000000
epoch 531/1000 error=0.000078
epoch 531/1000 error=0.000096
epoch 531/1000 error=0.000025
epoch 532/1000 error=0.000000
epoch 532/1000 error=0.000078
epoch 532/1000 error=0.000096
epoch 532/1000 error=0.000025
epoch 533/1000 error=0.000000
epoch 533/1000 error=0.000078
epoch 533/1000 error=0.000096
epoch 533/1000 error=0.000025
epoch 534/1000 error=0.000000
epoch 534/1000 error=0.000077
epoch 534/1000 error=0.000096
epoch 534/1000 error=0.000025
epoch 535/1000 error=0.000000
epoch 535/1000 error=0.000077
epoch 535/1000 error=0.000095
epoch 535/1000 error=0.000025
epoch 536/1000 error=0.000000
epoch 536/1000 error=0.000077
epoch 536/1000 error=0.000095
epoch 536/

epoch 609/1000 error=0.000000
epoch 609/1000 error=0.000067
epoch 609/1000 error=0.000082
epoch 609/1000 error=0.000021
epoch 610/1000 error=0.000000
epoch 610/1000 error=0.000066
epoch 610/1000 error=0.000082
epoch 610/1000 error=0.000021
epoch 611/1000 error=0.000000
epoch 611/1000 error=0.000066
epoch 611/1000 error=0.000082
epoch 611/1000 error=0.000021
epoch 612/1000 error=0.000000
epoch 612/1000 error=0.000066
epoch 612/1000 error=0.000082
epoch 612/1000 error=0.000021
epoch 613/1000 error=0.000000
epoch 613/1000 error=0.000066
epoch 613/1000 error=0.000081
epoch 613/1000 error=0.000021
epoch 614/1000 error=0.000000
epoch 614/1000 error=0.000066
epoch 614/1000 error=0.000081
epoch 614/1000 error=0.000021
epoch 615/1000 error=0.000000
epoch 615/1000 error=0.000066
epoch 615/1000 error=0.000081
epoch 615/1000 error=0.000021
epoch 616/1000 error=0.000000
epoch 616/1000 error=0.000066
epoch 616/1000 error=0.000081
epoch 616/1000 error=0.000021
epoch 617/1000 error=0.000000
epoch 617/

epoch 700/1000 error=0.000018
epoch 701/1000 error=0.000000
epoch 701/1000 error=0.000057
epoch 701/1000 error=0.000070
epoch 701/1000 error=0.000018
epoch 702/1000 error=0.000000
epoch 702/1000 error=0.000057
epoch 702/1000 error=0.000070
epoch 702/1000 error=0.000018
epoch 703/1000 error=0.000000
epoch 703/1000 error=0.000056
epoch 703/1000 error=0.000070
epoch 703/1000 error=0.000018
epoch 704/1000 error=0.000000
epoch 704/1000 error=0.000056
epoch 704/1000 error=0.000069
epoch 704/1000 error=0.000018
epoch 705/1000 error=0.000000
epoch 705/1000 error=0.000056
epoch 705/1000 error=0.000069
epoch 705/1000 error=0.000018
epoch 706/1000 error=0.000000
epoch 706/1000 error=0.000056
epoch 706/1000 error=0.000069
epoch 706/1000 error=0.000018
epoch 707/1000 error=0.000000
epoch 707/1000 error=0.000056
epoch 707/1000 error=0.000069
epoch 707/1000 error=0.000018
epoch 708/1000 error=0.000000
epoch 708/1000 error=0.000056
epoch 708/1000 error=0.000069
epoch 708/1000 error=0.000018
epoch 709/

epoch 803/1000 error=0.000060
epoch 803/1000 error=0.000016
epoch 804/1000 error=0.000000
epoch 804/1000 error=0.000049
epoch 804/1000 error=0.000060
epoch 804/1000 error=0.000016
epoch 805/1000 error=0.000000
epoch 805/1000 error=0.000048
epoch 805/1000 error=0.000060
epoch 805/1000 error=0.000016
epoch 806/1000 error=0.000000
epoch 806/1000 error=0.000048
epoch 806/1000 error=0.000060
epoch 806/1000 error=0.000016
epoch 807/1000 error=0.000000
epoch 807/1000 error=0.000048
epoch 807/1000 error=0.000059
epoch 807/1000 error=0.000016
epoch 808/1000 error=0.000000
epoch 808/1000 error=0.000048
epoch 808/1000 error=0.000059
epoch 808/1000 error=0.000016
epoch 809/1000 error=0.000000
epoch 809/1000 error=0.000048
epoch 809/1000 error=0.000059
epoch 809/1000 error=0.000016
epoch 810/1000 error=0.000000
epoch 810/1000 error=0.000048
epoch 810/1000 error=0.000059
epoch 810/1000 error=0.000015
epoch 811/1000 error=0.000000
epoch 811/1000 error=0.000048
epoch 811/1000 error=0.000059
epoch 811/

epoch 879/1000 error=0.000044
epoch 879/1000 error=0.000054
epoch 879/1000 error=0.000014
epoch 880/1000 error=0.000000
epoch 880/1000 error=0.000044
epoch 880/1000 error=0.000054
epoch 880/1000 error=0.000014
epoch 881/1000 error=0.000000
epoch 881/1000 error=0.000044
epoch 881/1000 error=0.000054
epoch 881/1000 error=0.000014
epoch 882/1000 error=0.000000
epoch 882/1000 error=0.000044
epoch 882/1000 error=0.000054
epoch 882/1000 error=0.000014
epoch 883/1000 error=0.000000
epoch 883/1000 error=0.000044
epoch 883/1000 error=0.000054
epoch 883/1000 error=0.000014
epoch 884/1000 error=0.000000
epoch 884/1000 error=0.000044
epoch 884/1000 error=0.000054
epoch 884/1000 error=0.000014
epoch 885/1000 error=0.000000
epoch 885/1000 error=0.000044
epoch 885/1000 error=0.000054
epoch 885/1000 error=0.000014
epoch 886/1000 error=0.000000
epoch 886/1000 error=0.000044
epoch 886/1000 error=0.000053
epoch 886/1000 error=0.000014
epoch 887/1000 error=0.000000
epoch 887/1000 error=0.000043
epoch 887/

epoch 953/1000 error=0.000040
epoch 953/1000 error=0.000049
epoch 953/1000 error=0.000013
epoch 954/1000 error=0.000000
epoch 954/1000 error=0.000040
epoch 954/1000 error=0.000049
epoch 954/1000 error=0.000013
epoch 955/1000 error=0.000000
epoch 955/1000 error=0.000040
epoch 955/1000 error=0.000049
epoch 955/1000 error=0.000013
epoch 956/1000 error=0.000000
epoch 956/1000 error=0.000040
epoch 956/1000 error=0.000049
epoch 956/1000 error=0.000013
epoch 957/1000 error=0.000000
epoch 957/1000 error=0.000040
epoch 957/1000 error=0.000049
epoch 957/1000 error=0.000013
epoch 958/1000 error=0.000000
epoch 958/1000 error=0.000040
epoch 958/1000 error=0.000049
epoch 958/1000 error=0.000013
epoch 959/1000 error=0.000000
epoch 959/1000 error=0.000040
epoch 959/1000 error=0.000049
epoch 959/1000 error=0.000013
epoch 960/1000 error=0.000000
epoch 960/1000 error=0.000040
epoch 960/1000 error=0.000049
epoch 960/1000 error=0.000013
epoch 961/1000 error=0.000000
epoch 961/1000 error=0.000040
epoch 961/