In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten every image into a single feature vector for the dense layers,
# deriving the size from the loaded data (28 * 28 = 784 for MNIST) instead of
# hard-coding it, then scale pixels to [0, 1] as float32.
num_pixels = x_train.shape[1] * x_train.shape[2]
x_train = x_train.reshape(-1, num_pixels).astype("float32") / 255.0
x_test = x_test.reshape(-1, num_pixels).astype("float32") / 255.0

# Sanity checks: one row per label, matching feature sizes, normalized range.
assert x_train.shape == (len(y_train), num_pixels)
assert x_test.shape == (len(y_test), num_pixels)
assert 0.0 <= x_train.min() and x_train.max() <= 1.0

In [3]:
class Dense(layers.Layer):
    """Minimal fully connected layer computing ``inputs @ w + b``.

    The weights are created lazily in ``build``, once the size of the last
    input dimension is known.

    Args:
        units: Number of output features.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Kernel maps the last input dimension onto `units` outputs.
        self.w = self.add_weight(
            name="w",
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        # One bias per output unit, starting at zero.
        self.b = self.add_weight(
            name="b",
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        """Return constructor arguments so the layer can be serialized/re-created."""
        config = super().get_config()
        config.update({"units": self.units})
        return config
    

In [4]:
class MyReLU(layers.Layer):
    """Element-wise rectified linear unit, ``max(x, 0)``.

    The layer has no weights. No ``__init__`` override is needed: the base
    ``keras.layers.Layer`` constructor is used directly, so keyword arguments
    such as ``name`` are accepted.
    """

    def call(self, x):
        return tf.math.maximum(x, 0)

In [5]:
class MyModel(keras.Model):
    """Two-layer MLP: Dense(hidden_units) -> ReLU -> Dense(num_classes).

    The output layer has no activation, so the model returns raw logits.

    Args:
        num_classes: Size of the output layer (number of logits).
        hidden_units: Width of the hidden layer.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_classes=10, hidden_units=64, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = Dense(hidden_units)
        self.dense2 = Dense(num_classes)
        self.relu = MyReLU()

    def call(self, input_tensor):
        hidden = self.relu(self.dense1(input_tensor))
        return self.dense2(hidden)

In [6]:
# Seed TensorFlow's global RNG so the random weight initialisation is
# reproducible on Restart & Run All.
SEED = 42
tf.random.set_seed(SEED)

model = MyModel(10)
model.compile(
    # MyModel's last layer has no activation, so the loss consumes logits.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)

In [7]:
# Keep the returned History (per-epoch metrics in `history.history`) rather than
# discarding it and letting its bare repr become the cell output.
history = model.fit(x_train, y_train, batch_size=32, epochs=2, verbose=2)

Train on 60000 samples
Epoch 1/2
60000/60000 - 6s - loss: 0.3363 - accuracy: 0.9086
Epoch 2/2
60000/60000 - 5s - loss: 0.1644 - accuracy: 0.9527


<tensorflow.python.keras.callbacks.History at 0x1c1085d1b08>

In [8]:
# Score the trained model on the held-out test split; the result lists the loss
# followed by the compiled metrics (accuracy).
model.evaluate(x=x_test, y=y_test, batch_size=32, verbose=2)

10000/10000 - 1s - loss: 0.1454 - accuracy: 0.9576


[0.14536936699151992, 0.9576]