In [6]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [9]:
# Load the MNIST train/test splits (images in x_*, integer labels in y_*).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten every image into a single row of 28 * 28 values, cast to float32
# and divide by 255.0. Both splits get the exact same treatment.
x_train, x_test = (
    images.reshape(-1, 28 * 28).astype("float32") / 255.0
    for images in (x_train, x_test)
)

In [23]:
class Dense(layers.Layer):
    """Fully connected layer computing ``matmul(inputs, w) + b``.

    The weights are created lazily in ``build`` once the size of the last
    input dimension is known.

    Args:
        units: Number of output features (size of the last output dimension).
        **kwargs: Standard ``layers.Layer`` keyword arguments such as
            ``name``, ``dtype`` or ``trainable``; forwarded unchanged to the
            base class.
    """

    def __init__(self, units, **kwargs):
        # Forward base-layer options so the layer can be named and so that
        # from_config(get_config()) can rebuild it.
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create the kernel ``w`` and bias ``b`` for the given input shape."""
        # Kernel maps the last input dimension onto `units` outputs.
        self.w = self.add_weight(
            name='w',
            shape=(input_shape[-1], self.units),
            initializer='random_normal',
            trainable=True,
        )
        # One bias per output unit, starting at zero.
        self.b = self.add_weight(
            name='b',
            shape=(self.units,),
            initializer='zeros',
            trainable=True,
        )

    def call(self, inputs):
        """Return the affine transform ``matmul(inputs, w) + b``."""
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [28]:
class MyReLu(layers.Layer):
    """Weight-free layer applying ``tf.math.maximum(x, 0)`` element-wise.

    Args:
        **kwargs: Standard ``layers.Layer`` keyword arguments such as
            ``name``, ``dtype`` or ``trainable``; forwarded unchanged to the
            base class.
    """

    def __init__(self, **kwargs):
        # The base get_config() emits name/trainable/dtype; accepting them
        # here is what lets from_config(get_config()) rebuild the layer.
        super().__init__(**kwargs)

    def call(self, x):
        """Return ``x`` with every negative entry replaced by zero."""
        return tf.math.maximum(x, 0)
    

In [29]:
class MyModel(keras.Model):
    """Two-layer classifier: Dense(64) -> MyReLu -> Dense(num_classes).

    Args:
        num_classes: Width of the final layer, i.e. how many scores are
            produced per input row. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Hand-written layers defined in this notebook. Built-in alternative
        # noted by the original author: layers.Dense(64) and
        # layers.Dense(num_classes).
        self.dense1 = Dense(64)
        self.dense2 = Dense(num_classes)
        self.relu = MyReLu()

    def call(self, input_tensor):
        """Run the forward pass and return the second layer's raw output.

        No activation is applied after ``dense2``.
        """
        hidden = self.dense1(input_tensor)
        activated = self.relu(hidden)
        return self.dense2(activated)

In [30]:
# Instantiate the classifier with its default number of output classes.
model = MyModel()

# Labels are passed as integer class ids (sparse loss); from_logits=True tells
# the loss that the model output has not been passed through a softmax.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [31]:
# Train for two epochs in batches of 32; verbose=2 prints one line per epoch.
model.fit(x=x_train, y=y_train, batch_size=32, epochs=2, verbose=2)

# Score the trained model on the test split. Kept as the last expression so
# the returned values are shown as the cell output.
model.evaluate(x=x_test, y=y_test, batch_size=32, verbose=2)

Epoch 1/2
1875/1875 - 2s - loss: 0.3478 - accuracy: 0.9045 - 2s/epoch - 1ms/step
Epoch 2/2
1875/1875 - 2s - loss: 0.1632 - accuracy: 0.9530 - 2s/epoch - 821us/step
313/313 - 0s - loss: 0.1349 - accuracy: 0.9606 - 342ms/epoch - 1ms/step


[0.13494901359081268, 0.9606000185012817]

In [32]:
# Persist only the model's weights (not the architecture or optimizer config).
# NOTE(review): the path ends with a slash, so the files are presumably
# written inside save_model/ with an empty name prefix; confirm this is the
# intended layout before relying on it for load_weights.
model.save_weights(filepath='save_model/')