In [2]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist

In [13]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28 * 28).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28 * 28).astype("float32") / 255.0

In [35]:
class Dense(layers.Layer):
  def __init__(self, units):
    super(Dense, self).__init__()
    self.units = units

  def build(self, input_shape):
    self.w = self.add_weight(
        name = "w",
        shape = (input_shape[-1], self.units),
        initializer = "random_normal",
        trainable = True,
    )
    self.b = self.add_weight(
        name = "b",
        shape = (self.units, ),
        initializer = "zeros",
        trainable = True,
    )


  def call(self, inputs):
    return tf.matmul(inputs, self.w) + self.b

In [37]:
class MyRelu(layers.Layer):
  def __init__(self):
    super(MyRelu, self).__init__()

  def call(self, x):
    return tf.math.maximum(x, 0)

In [38]:
class MyModel(keras.Model):
  def __init__(self, num_classes = 10):
    super(MyModel, self).__init__()
    self.dense1 = Dense(64)
    self.dense2 = Dense(num_classes)
    self.relu = MyRelu()
    
    #self.dense1 = layers.Dense(64)
    #self.dense2 = layers.Dense(num_classes)

  def call(self, input_tensor):
    x = self.relu(self.dense1(input_tensor))
    return self.dense2(x)

In [39]:
model = MyModel()

In [40]:
model.compile(
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits = True),
    optimizer = keras.optimizers.Adam(),
    metrics = ["accuracy"]
)

In [41]:
model.fit(
    x_train, y_train, batch_size = 64, epochs = 2
)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7effd530ad30>

In [42]:
model.summary()

Model: "my_model_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_10 (Dense)            multiple                  50240     
                                                                 
 dense_11 (Dense)            multiple                  650       
                                                                 
 my_relu (MyRelu)            multiple                  0         
                                                                 
Total params: 50,890
Trainable params: 50,890
Non-trainable params: 0
_________________________________________________________________
