# The Model Subclassing API

In [1]:
import tensorflow as tf

In [2]:
# MNIST image dimensions in pixels; used below to reshape the raw arrays.
img_rows, img_cols = 28, 28

In [3]:
# Load MNIST, then add a trailing single-channel axis to every image so each sample
# is (img_rows, img_cols, 1), matching Conv2D's default channels-last input layout.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train, X_test = (
    images.reshape(images.shape[0], img_rows, img_cols, 1)
    for images in (X_train, X_test)
)

In [4]:
# Scale raw pixel intensities to [0, 1] as float32.
# The dtype guard keeps this cell idempotent: without it, re-running the cell would
# divide the already-scaled float data by 255 a second time.
if X_train.dtype == 'uint8':
    X_train = X_train.astype('float32') / 255
if X_test.dtype == 'uint8':
    X_test = X_test.astype('float32') / 255

# Sanity check: both splits are now float32 and within the unit interval.
assert X_train.dtype == 'float32' and X_test.dtype == 'float32'
assert 0.0 <= X_train.min() and X_train.max() <= 1.0
assert 0.0 <= X_test.min() and X_test.max() <= 1.0

In [5]:
class CNNClassifier(tf.keras.Model):
    """Small convolutional image classifier built with the Keras Model Subclassing API.

    Layers, in call order:
        Conv2D(32 filters, 3x3, ReLU) -> MaxPooling2D(2x2) -> Dropout(0.25)
        -> Flatten -> Dense(128, ReLU) -> Dropout(0.5) -> Dense(num_classes, softmax)

    Args:
        num_classes: number of output classes, i.e. the width of the final
            softmax layer. Stored as ``self.num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__(name='CNNClassifier')
        self.num_classes = num_classes
        self.conv = tf.keras.layers.Conv2D(32, kernel_size=(3, 3),
                                           activation='relu')
        self.max_pool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        self.dropout_1 = tf.keras.layers.Dropout(.25)
        self.flatten = tf.keras.layers.Flatten()
        self.dense_1 = tf.keras.layers.Dense(128, activation='relu')
        self.dropout_2 = tf.keras.layers.Dropout(.5)
        self.dense_2 = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, input_tensor, training=None):
        """Run the forward pass.

        Args:
            input_tensor: batch of images; assumed channels-last
                (batch, rows, cols, channels) as prepared earlier in this
                notebook -- TODO confirm for other callers.
            training: whether dropout is active. Defaults to ``None`` so Keras
                can supply the correct mode itself (training during ``fit``,
                inference during ``evaluate``/``predict``).

        Returns:
            Softmax class probabilities of shape (batch, num_classes).
        """
        x = self.conv(input_tensor)
        x = self.max_pool(x)
        # Dropout honours its own `training` argument (identity at inference), so no
        # Python-level branch is needed; passing the flag straight through also works
        # when Keras provides it as a tensor rather than a Python bool.
        x = self.dropout_1(x, training=training)
        x = self.flatten(x)
        x = self.dense_1(x)
        x = self.dropout_2(x, training=training)
        return self.dense_2(x)


In [6]:
# MNIST has ten digit classes (0-9).
minst_clf = CNNClassifier(num_classes=10)
minst_clf.compile(
    optimizer="rmsprop",
    # Categorical (not sparse) cross-entropy: labels are one-hot encoded with
    # to_categorical when they are passed to fit/evaluate.
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [7]:
# One-hot width is tied to the model's output layer instead of being inferred from
# the label values, so labels and predictions always have matching shapes.
minst_clf.fit(
    X_train,
    tf.keras.utils.to_categorical(y_train, num_classes=minst_clf.num_classes),
    batch_size=32,
    epochs=5,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x234861bd0b8>

In [9]:
# Returns [loss, accuracy] on the held-out test split. The one-hot width is taken
# from the model so it matches the softmax output even if a class were absent.
minst_clf.evaluate(
    X_test,
    tf.keras.utils.to_categorical(y_test, num_classes=minst_clf.num_classes),
)



[0.04672080550612882, 0.9853]