In [1]:
import os
# Quiet TensorFlow's C++ logging; this must be set before tensorflow is imported to take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf

# Allocate GPU memory on demand for every visible GPU. Iterating (instead of indexing [0])
# keeps the notebook runnable on CPU-only machines, where this list is empty.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

# Fix TensorFlow's global seed so weight initialisation and fit() shuffling are repeatable
# across Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 42
tf.random.set_seed(SEED)

# Flatten each 28x28 image into a 784-vector and scale pixel values into [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28 * 28).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28 * 28).astype("float32") / 255.0

In [2]:
class Dense(layers.Layer):
  """Fully connected layer computing ``input_tensor @ w + b`` (no activation).

  Args:
    units: Number of output features.
    **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
  """

  def __init__(self, units, **kwargs):
    super().__init__(**kwargs)
    self.units = units

  def build(self, input_shape):
    # Weights are created lazily on first call, once the size of the input's
    # last axis is known.
    self.w = self.add_weight(
        name="w",
        shape=(input_shape[-1], self.units),
        initializer=tf.keras.initializers.HeNormal(),
        trainable=True,
    )
    self.b = self.add_weight(
        name="b",
        shape=(self.units,),
        initializer="zeros",
        trainable=True,
    )

  def call(self, input_tensor):
    # Affine projection: (..., in_features) @ (in_features, units) + (units,)
    return tf.matmul(input_tensor, self.w) + self.b

  def get_config(self):
    """Return constructor arguments so the layer can be re-created from its config."""
    config = super().get_config()
    config.update({"units": self.units})
    return config


class Relu(layers.Layer):
  """Element-wise rectified linear unit, ``max(input_tensor, 0)``. Has no weights.

  Args:
    **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, input_tensor):
    return tf.math.maximum(input_tensor, 0)
      

In [14]:
class MyModel(keras.Model):  # subclassed model: supports model.fit, model.evaluate, model.predict
  """Two-layer MLP: Dense(64) -> Relu -> Dense(num_classes).

  The last layer has no activation, so the outputs are raw logits
  (pair with a loss constructed with ``from_logits=True``).

  Args:
    num_classes: Number of output logits.
    **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
  """

  def __init__(self, num_classes=10, **kwargs):
    super().__init__(**kwargs)
    self.dense1 = Dense(64)
    self.dense2 = Dense(num_classes)
    self.relu = Relu()

  def call(self, x):
    x = self.dense1(x)
    x = self.relu(x)
    return self.dense2(x)

  def model(self, input_dim=28 * 28 * 1):
    """Wrap this network in a functional ``keras.Model`` built on a symbolic Input.

    Args:
      input_dim: Length of each flattened input vector (default: a 28x28x1 image).

    Returns:
      A ``keras.Model`` whose layer output shapes are known up front.
    """
    inputs = keras.Input(shape=(input_dim,))
    return keras.Model(inputs=[inputs], outputs=[self.call(inputs)])

In [15]:
model = MyModel(num_classes=10)  # ten MNIST digit classes

In [16]:
# The model emits raw logits, hence from_logits=True; labels are integer class ids.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

# Three passes over the training set, then [loss, accuracy] on the test set.
history = model.fit(x_train, y_train, epochs=3, batch_size=32)
model.evaluate(x_test, y_test, batch_size=32)

Epoch 1/3
Epoch 2/3
Epoch 3/3


[0.1043839231133461, 0.9692000150680542]