# In [1]:
import numpy as np

def xavier_init(in_size, out_size):
    """Return an (out_size, in_size) weight matrix using Glorot/Xavier uniform init.

    Entries are drawn i.i.d. from U(-limit, limit) with
    ``limit = sqrt(6 / (in_size + out_size))`` (Glorot & Bengio, 2010).

    Args:
        in_size: Number of inputs to the layer (fan-in).
        out_size: Number of outputs of the layer (fan-out).

    Returns:
        ndarray of shape (out_size, in_size).
    """
    # The former bound 1/sqrt(in_size) is the older heuristic that Xavier init
    # replaces; the normalized Glorot bound accounts for fan-in AND fan-out.
    limit = np.sqrt(6.0 / (in_size + out_size))
    return np.random.uniform(low=-limit, high=limit, size=(out_size, in_size))


def he_init(in_size, out_size):
    """Draw an (out_size, in_size) weight matrix with He (Kaiming) normal init.

    Every entry is sampled independently from a normal distribution with
    mean 0 and standard deviation sqrt(2 / in_size).
    """
    std_dev = np.sqrt(2.0 / in_size)
    shape = (out_size, in_size)
    return np.random.normal(loc=0, scale=std_dev, size=shape)


# A single layer of neurons
class Layer:
    """A dense layer computing ``activation(weights @ a_prev + bias)``.

    Data is column-oriented: inputs have shape (input_size, m), where ``m``
    is the number of samples (``backward_propagation`` reads
    ``a_prev.shape[1]`` as the batch size).

    Args:
        input_size: Number of input features.
        output_size: Number of neurons in the layer.
        activation: Callable mapping the pre-activation ``z`` to the output.
        backward_activation: Callable ``(dA, z) -> dZ`` that applies the
            activation's derivative to the upstream gradient.
        weight_init: ``"he"`` (default) or ``"xavier"``.

    Raises:
        ValueError: If ``weight_init`` is not a supported scheme.
    """

    _WEIGHT_INITS = ("he", "xavier")

    def __init__(self, input_size, output_size, activation, backward_activation,
                 weight_init="he"):
        # Validate up front: previously any unknown string silently fell back
        # to Xavier init, hiding typos such as "He".
        if weight_init not in self._WEIGHT_INITS:
            raise ValueError(
                f"weight_init must be one of {self._WEIGHT_INITS}, got {weight_init!r}"
            )
        self.input_size = input_size
        self.output_size = output_size
        if weight_init == "he":
            self.weights = he_init(input_size, output_size)
        else:
            self.weights = xavier_init(input_size, output_size)
        # Small random biases; shape (output_size, 1) broadcasts across samples.
        self.bias = np.random.randn(output_size, 1) * 0.1
        self.activation = activation
        self.backward_activation = backward_activation
        # Caches populated by forward/backward propagation; None until computed.
        self.a_prev = None
        self.z = None
        self.dW = None
        self.db = None
        self.dA = None

    def update(self, lr):
        """Apply one gradient-descent step: ``param <- param - lr * grad``.

        Args:
            lr: Learning rate.

        Raises:
            RuntimeError: If called before ``backward_propagation``.
        """
        if self.dW is None or self.db is None:
            raise RuntimeError("update() called before backward_propagation()")
        # Rebind rather than subtract in place, so arrays shared elsewhere
        # are not mutated.
        self.weights = self.weights - lr * self.dW
        self.bias = self.bias - lr * self.db

    def forward_propagation(self, input_data):
        """Compute the layer output and cache values needed for backprop.

        Args:
            input_data: Array-like of shape (input_size, m). A 1-D input of
                length input_size is treated as a single sample (input_size, 1).

        Returns:
            ``activation(z)`` where ``z = weights @ input_data + bias``.
        """
        input_data = np.asarray(input_data)
        if input_data.ndim == 1:
            # A 1-D vector would broadcast against the (output_size, 1) bias
            # into an (output_size, output_size) matrix, and backprop could
            # not read a batch dimension; promote it to a column vector.
            input_data = input_data.reshape(-1, 1)
        self.a_prev = input_data
        self.z = np.dot(self.weights, self.a_prev) + self.bias
        return self.activation(self.z)

    def backward_propagation(self, dA):
        """Back-propagate the gradient ``dA`` of the loss w.r.t. this layer's output.

        Uses ``a_prev`` and ``z`` cached by the last ``forward_propagation``.
        ``dW`` and ``db`` are divided by the batch size ``m``.

        Args:
            dA: Upstream gradient, same shape as the layer output.

        Returns:
            Tuple ``(dA_prev, dW, db)``: gradient w.r.t. the layer input, the
            weights and the bias. All three are also stored on the instance.
        """
        m = self.a_prev.shape[1]
        dZ = self.backward_activation(dA, self.z)
        self.dW = np.dot(dZ, self.a_prev.T) / m
        self.db = np.sum(dZ, axis=1, keepdims=True) / m
        # Computed with the current (not yet updated) weights.
        self.dA = np.dot(self.weights.T, dZ)
        return self.dA, self.dW, self.db