In [None]:
tensorflow_version 2.x

In [None]:
import tensorflow as tf
from tensorflow import keras

In [None]:
from tensorflow.keras.utils import plot_model

In [None]:
# define custom class
class MyLayer(tf.keras.layers.Layer):
  """Bias-free linear projection layer computing ``tf.matmul(x, W)``.

  The kernel ``W`` is created lazily in ``build`` with shape
  ``(input_dim, output_dim)``, where ``input_dim`` is the size of the last
  axis of the input.

  Args:
      output_dim: Size of the last axis of the layer's output.
      **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
  """

  def __init__(self, output_dim, ** kwargs):
      # Initialise the base Layer before assigning any attribute of our own.
      super(MyLayer, self).__init__( ** kwargs)
      self.output_dim = output_dim

  def build(self, input_shape):
      """Create the trainable kernel once the input feature size is known.

      Raises:
          ValueError: If the last dimension of ``input_shape`` is ``None``.
      """
      # Size the kernel from the last axis instead of assuming axis 1, so
      # the layer is not restricted to rank-2 inputs.
      input_dim = input_shape[-1]
      if input_dim is None:
          raise ValueError(
              'The last dimension of the input to MyLayer must be defined; '
              'found None.')
      self.W = self.add_weight(
                  name = 'kernel',
                  shape = (input_dim,
                           self.output_dim),
                  initializer = 'uniform',
                  trainable = True)
      self.built = True

  def call(self, x):
      """Return the matrix product of ``x`` and the kernel ``W``."""
      return tf.matmul(x, self.W)

  def compute_output_shape(self, input_shape):
      """Return ``input_shape`` with its last axis replaced by ``output_dim``."""
      # For a 2-D input this is (input_shape[0], output_dim), as before.
      return tuple(input_shape[:-1]) + (self.output_dim,)

  def get_config(self):
      """Return the constructor arguments needed to re-create this layer.

      Extends the base config with ``output_dim`` so the layer can be rebuilt
      from its config when a model containing it is serialized or cloned.
      """
      config = super(MyLayer, self).get_config()
      config.update({'output_dim': self.output_dim})
      return config

In [None]:
# define model containing custom layer
# Layers are appended one at a time: two 256-unit Dense layers (the second
# with ReLU), the custom MyLayer projecting to 10 units, and a 10-unit
# softmax output layer.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(256, input_shape=(784,)))
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(MyLayer(10))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [None]:
# Draw a diagram of the model's layers.
plot_model(model)

In [None]:
# Training configuration used by the cells below.
num_classes = 10   # width of the one-hot label vectors
epochs = 20        # number of passes over the training data
batch_size = 128   # samples per gradient update

In [None]:
# load dataset
# load_data() yields a (train, test) pair, each an (inputs, labels) tuple.
train_split, test_split = tf.keras.datasets.mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
# preprocessing
# Flatten every sample into a single feature row, convert to float32 and
# divide by 255. The row count is taken from the array itself rather than
# hard-coded, so the cell also works when only a subset of the data is loaded.
# NOTE: this cell rebinds x_train/x_test, so running it twice divides by 255
# twice -- re-run the dataset loading cell before repeating it.
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

In [None]:
# one-hot encode the class labels of both splits into num_classes columns
y_train, y_test = [
    tf.keras.utils.to_categorical(labels, num_classes)
    for labels in (y_train, y_test)
]

In [None]:
# compile
# RMSprop optimizer, categorical cross-entropy loss, accuracy as the metric.
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# train
# Fit on the training split, passing the test split as validation data.
model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    epochs=epochs,
    batch_size=batch_size,
    verbose=1,
)