In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1,28,28,1).astype('float32')/255.0
x_test = x_test.reshape(-1,28,28,1).astype('float32')/255.0

model = keras.Sequential(
    [
        layers.Input(shape=(28,28,1)),
        layers.Conv2D(64,3,padding='same',activation='relu'),
        layers.Conv2D(128,3,padding='same',activation='relu'),
        layers.Flatten(),
        layers.Dense(10,activation='softmax')
    ],
    name='model'
)

In [3]:
class CustomFit(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
    
    def compile(self, optimizer, loss):
        super().compile()
        self.optimizer = optimizer
        self.loss = loss
    
    def train_step(self, data):
        x, y = data
        
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)
            loss = self.loss(y, y_pred)
            #loss = self.compiled_loss(y, y_pred)
        
        training_vars = self.trainable_variables
        gradients = tape.gradient(loss, training_vars)
        
        self.optimizer.apply_gradients(zip(gradients, training_vars))
        acc_metric.update_state(y, y_pred)
        #self.compiled_metrics.update_state(y, y_pred)
        
        #return {m.name: m.result() for m in self.metrics}
        return {'loss':loss, 'accuracy':acc_metric.result()}
    
    def test_step(self, data):
        x, y = data
        
        y_pred = self.model(x, training=False)
        loss = self.loss(y, y_pred)
        acc_metric.update_state(y, y_pred)
        
        return {'loss':loss, 'accuracy':acc_metric.result()}

In [4]:
acc_metric = keras.metrics.SparseCategoricalAccuracy(name='accuracy')

training = CustomFit(model)

training.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam())

training.fit(x_train, y_train, batch_size=32, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x2661c01f850>

In [5]:
training.evaluate(x_test, y_test, batch_size=32)



[0.9876096844673157, 0.00023556053929496557]