# Creating our own Customized layer 

In [1]:
from keras import backend as K
from keras.layers import Layer
# Layer is the base class; we subclass it to create our own layer.

class Kushal(Layer):
    """Linear layer without bias or activation: ``output = dot(input, kernel)``.

    The layer owns one trainable weight, ``kernel``, of shape
    ``(input_dim, output_dim)``, where ``input_dim`` is the size of the last
    axis of the input.

    Arguments:
        output_dim: size of the last axis of the layer's output.
        **kwargs: forwarded unchanged to the base ``Layer`` constructor
            (for example ``input_shape`` or ``name``).
    """

    def __init__(self, output_dim, **kwargs):
        # Store our own argument before delegating the rest to the base class.
        self.output_dim = output_dim
        super(Kushal, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the trainable kernel once the input shape is known.

        Arguments:
            input_shape: shape of the incoming tensor, batch axis included.
        """
        # Size the kernel from the LAST input axis rather than axis 1: K.dot
        # contracts the last axis of its first argument, so this is identical
        # for 2-D input and stays correct for inputs of higher rank.
        self.kernel = self.add_weight(
            name='kernel',
            shape=(input_shape[-1], self.output_dim),
            initializer='normal',
            trainable=True,
        )
        super(Kushal, self).build(input_shape)  # calls the base build function

    def call(self, input_data):
        """Return the dot product of ``input_data`` and the layer's kernel."""
        return K.dot(input_data, self.kernel)

    def compute_output_shape(self, input_shape):
        """Return the input shape with its last axis replaced by ``output_dim``."""
        # For 2-D input this is (input_shape[0], output_dim), as before.
        return tuple(input_shape[:-1]) + (self.output_dim,)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer.

        Without this override ``output_dim`` is missing from the serialized
        config, so a saved model containing this layer cannot be rebuilt.
        """
        config = super(Kushal, self).get_config()
        config['output_dim'] = self.output_dim
        return config

Using TensorFlow backend.


# Checking whether our layer named Kushal was created

In [2]:
from keras.models import Sequential 
from keras.layers import Dense 

# Stack the custom layer (16 input features, output_dim of 32) in front of a
# sigmoid-activated Dense layer with 8 units, then print the architecture.
model = Sequential([
    Kushal(32, input_shape=(16,)),
    Dense(8, activation='sigmoid'),
])
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
kushal_1 (Kushal)            (None, 32)                512       
_________________________________________________________________
dense_1 (Dense)              (None, 8)                 264       
Total params: 776
Trainable params: 776
Non-trainable params: 0
_________________________________________________________________
