In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import math
import datetime
import pprint
import tqdm
# in a notebook, load the tensorboard extension, not needed for scripts
%load_ext tensorboard

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class FFN(tf.keras.Model):
    def __init__(self):
        super().__init__()
    
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        
        self.metrics_list = [
                        tf.keras.metrics.Mean(name="loss"),
                        tf.keras.metrics.CategoricalAccuracy(name="acc"),
                        tf.keras.metrics.TopKCategoricalAccuracy(3,name="top-3-acc") 
                       ]
        
        self.loss_function = tf.keras.losses.CategoricalCrossentropy(from_logits=False)   
        
        # define layers
        self.layer1 = tf.keras.layers.Dense(256,activation=tf.nn.relu)
        self.layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)
        self.output_layer = tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    
    def call(self, x, training=False):
        x = self.layer1(x)
        x = self.layer2(x)
        out = self.output_layer(x)
       
        return out
    
    def reset_metrics(self):
        
        for metric in self.metrics:
            metric.reset_states()
            
    @tf.function
    def train_step(self, data):
        
        x, targets = data
        
        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            
            loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)
        
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # update loss metric
        self.metrics[0].update_state(loss)
        
        # for all metrics except loss, update states (accuracy etc.)
        for metric in self.metrics[1:]:
            metric.update_state(targets,predictions)

        # Return a dictionary mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def test_step(self, data):

        x, targets = data
        predictions = self(x, training=False)
        loss = self.loss_function(targets, predictions) + tf.reduce_sum(self.losses)

        self.metrics[0].update_state(loss)
        # for accuracy metrics:
        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics}

In [None]:
train_ds, val_ds = tfds.load("mnist", split=["train", "test"], as_supervised=True)

train_ds = train_ds.map(lambda img, target:)

In [4]:
train_ds, val_ds = tfds.load("mnist", split=["train", "test"], as_supervised=True)

train_ds = train_ds.map(lambda img, target: ((tf.cast(tf.reshape(img, (-1,)), tf.float32)/128.)-1., tf.one_hot(target, 10, dtype=tf.float32)),\
                        num_parallel_calls=tf.data.AUTOTUNE).cache().shuffle(5000).batch(32).prefetch(tf.data.AUTOTUNE)

val_ds = val_ds.map(lambda img, target: ((tf.cast(tf.reshape(img, (-1,)), tf.float32)/128.)-1., tf.one_hot(target, 10, dtype=tf.float32)),\
                    num_parallel_calls=tf.data.AUTOTUNE).cache().batch(32).prefetch(tf.data.AUTOTUNE)


In [4]:
list(train_ds.take(1))

[(<tf.Tensor: shape=(32, 784), dtype=float32, numpy=
  array([[-1., -1., -1., ..., -1., -1., -1.],
         [-1., -1., -1., ..., -1., -1., -1.],
         [-1., -1., -1., ..., -1., -1., -1.],
         ...,
         [-1., -1., -1., ..., -1., -1., -1.],
         [-1., -1., -1., ..., -1., -1., -1.],
         [-1., -1., -1., ..., -1., -1., -1.]], dtype=float32)>,
  <tf.Tensor: shape=(32, 10), dtype=float32, numpy=
  array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 

In [5]:
# instantiate the model
model = FFN()

# run model on input once so the layers are built
model(tf.keras.Input((32, 784)))
model.summary()

Model: "ffn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               multiple                  200960    
                                                                 
 dense_1 (Dense)             multiple                  65792     
                                                                 
 dense_2 (Dense)             multiple                  2570      
                                                                 
Total params: 269,328
Trainable params: 269,322
Non-trainable params: 6
_________________________________________________________________


In [6]:
# Define where to save the log
config_name= "config_name"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

train_log_path = f"logs/{config_name}/{current_time}/train"
val_log_path = f"logs/{config_name}/{current_time}/val"

# log writer for training metrics
train_summary_writer = tf.summary.create_file_writer(train_log_path)

# log writer for validation metrics
val_summary_writer = tf.summary.create_file_writer(val_log_path)

In [7]:
def training_loop(model, train_ds, val_ds, epochs, train_summary_writer, val_summary_writer):
    for epoch in range(epochs):
        print(f"Epoch {epoch}:")
        
        # Training:
        tq = tqdm.tqdm(train_ds, position=0, leave=True)
        for data in tq:
            metrics = model.train_step(data)
            # logging the validation metrics to the log file which is used by tensorboard
            with train_summary_writer.as_default():
                for metric in model.metrics:

                    # if metric
                    # tq.set_description(f"MSE: {}",refresh=True)
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)

        # print the metrics
        print([f"{key}: {value.numpy()}" for (key, value) in metrics.items()])

        # reset all metrics (requires a reset_metrics method in the model)
        model.reset_metrics()    
        
        # Validation:
        for data in val_ds:
            metrics = model.test_step(data)
        
            # logging the validation metrics to the log file which is used by tensorboard
            with val_summary_writer.as_default():
                for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)
                    
        print([f"val_{key}: {value.numpy()}" for (key, value) in metrics.items()])

        # reset all metrics
        model.reset_metrics()
        print("\n")

In [8]:
%tensorboard --logdir logs/

Reusing TensorBoard on port 6006 (pid 16280), started 1:27:34 ago. (Use '!kill 16280' to kill it.)

In [9]:
# run the training loop 
training_loop(model=model, 
                train_ds=train_ds, 
                val_ds=val_ds, 
                epochs=50, 
                train_summary_writer=train_summary_writer, 
                val_summary_writer=val_summary_writer)

Epoch 0:


100%|██████████| 1875/1875 [00:26<00:00, 71.94it/s]


['loss: 0.34150853753089905', 'acc: 0.8912166953086853', 'top-3-acc: 0.9734166860580444']
['val_loss: 0.15371127426624298', 'val_acc: 0.9517999887466431', 'val_top-3-acc: 0.9930999875068665']


Epoch 1:


100%|██████████| 1875/1875 [00:26<00:00, 70.42it/s]


['loss: 0.13463154435157776', 'acc: 0.9584833383560181', 'top-3-acc: 0.994949996471405']
['val_loss: 0.1248740628361702', 'val_acc: 0.9607999920845032', 'val_top-3-acc: 0.9939000010490417']


Epoch 2:


100%|██████████| 1875/1875 [00:27<00:00, 67.61it/s]


['loss: 0.09958094358444214', 'acc: 0.968666672706604', 'top-3-acc: 0.9969000220298767']
['val_loss: 0.09707801043987274', 'val_acc: 0.9678000211715698', 'val_top-3-acc: 0.9968000054359436']


Epoch 3:


100%|██████████| 1875/1875 [00:25<00:00, 73.67it/s]


['loss: 0.07865321636199951', 'acc: 0.9748666882514954', 'top-3-acc: 0.9980833530426025']
['val_loss: 0.0983724370598793', 'val_acc: 0.968999981880188', 'val_top-3-acc: 0.9976000189781189']


Epoch 4:


100%|██████████| 1875/1875 [00:27<00:00, 68.05it/s]


['loss: 0.06431134790182114', 'acc: 0.9798166751861572', 'top-3-acc: 0.9986000061035156']
['val_loss: 0.08546897023916245', 'val_acc: 0.9710999727249146', 'val_top-3-acc: 0.9975000023841858']


Epoch 5:


100%|██████████| 1875/1875 [00:23<00:00, 81.15it/s]


['loss: 0.054179172962903976', 'acc: 0.98253333568573', 'top-3-acc: 0.9990500211715698']
['val_loss: 0.09746911376714706', 'val_acc: 0.9690999984741211', 'val_top-3-acc: 0.9980000257492065']


Epoch 6:


100%|██████████| 1875/1875 [00:20<00:00, 92.43it/s]


['loss: 0.046427104622125626', 'acc: 0.984666645526886', 'top-3-acc: 0.9994666576385498']
['val_loss: 0.0979091078042984', 'val_acc: 0.9706000089645386', 'val_top-3-acc: 0.9975000023841858']


Epoch 7:


100%|██████████| 1875/1875 [00:23<00:00, 81.32it/s]


['loss: 0.038938894867897034', 'acc: 0.9867166876792908', 'top-3-acc: 0.9995333552360535']
['val_loss: 0.08633707463741302', 'val_acc: 0.9760000109672546', 'val_top-3-acc: 0.9977999925613403']


Epoch 8:


100%|██████████| 1875/1875 [00:25<00:00, 74.85it/s]


['loss: 0.033453091979026794', 'acc: 0.9888666868209839', 'top-3-acc: 0.9996333122253418']
['val_loss: 0.07460270822048187', 'val_acc: 0.9776999950408936', 'val_top-3-acc: 0.9983000159263611']


Epoch 9:


100%|██████████| 1875/1875 [00:25<00:00, 73.63it/s]


['loss: 0.02890862710773945', 'acc: 0.9899166822433472', 'top-3-acc: 0.9997000098228455']
['val_loss: 0.07793331146240234', 'val_acc: 0.9778000116348267', 'val_top-3-acc: 0.9979000091552734']


Epoch 10:


100%|██████████| 1875/1875 [00:25<00:00, 74.73it/s]


['loss: 0.02432647906243801', 'acc: 0.9924666881561279', 'top-3-acc: 0.999750018119812']
['val_loss: 0.07568110525608063', 'val_acc: 0.9794999957084656', 'val_top-3-acc: 0.9979000091552734']


Epoch 11:


100%|██████████| 1875/1875 [00:24<00:00, 75.48it/s]


['loss: 0.019032234326004982', 'acc: 0.9939500093460083', 'top-3-acc: 0.999916672706604']
['val_loss: 0.06958574801683426', 'val_acc: 0.9811000227928162', 'val_top-3-acc: 0.9983999729156494']


Epoch 12:


100%|██████████| 1875/1875 [00:24<00:00, 76.44it/s]


['loss: 0.01953180879354477', 'acc: 0.9933666586875916', 'top-3-acc: 0.999916672706604']
['val_loss: 0.07967589795589447', 'val_acc: 0.9790999889373779', 'val_top-3-acc: 0.998199999332428']


Epoch 13:


100%|██████████| 1875/1875 [00:24<00:00, 75.51it/s]


['loss: 0.018371999263763428', 'acc: 0.9935500025749207', 'top-3-acc: 0.9998999834060669']
['val_loss: 0.08659365773200989', 'val_acc: 0.9761999845504761', 'val_top-3-acc: 0.9983000159263611']


Epoch 14:


100%|██████████| 1875/1875 [00:23<00:00, 78.41it/s]


['loss: 0.0127332154661417', 'acc: 0.9957166910171509', 'top-3-acc: 0.9999499917030334']
['val_loss: 0.09034107625484467', 'val_acc: 0.9789999723434448', 'val_top-3-acc: 0.9977999925613403']


Epoch 15:


100%|██████████| 1875/1875 [00:27<00:00, 69.01it/s]


['loss: 0.013365108519792557', 'acc: 0.9956333041191101', 'top-3-acc: 0.9999333620071411']
['val_loss: 0.08867717534303665', 'val_acc: 0.9787999987602234', 'val_top-3-acc: 0.9976999759674072']


Epoch 16:


100%|██████████| 1875/1875 [00:25<00:00, 72.99it/s]


['loss: 0.011150045320391655', 'acc: 0.9960833191871643', 'top-3-acc: 0.9999666810035706']
['val_loss: 0.09422758966684341', 'val_acc: 0.9783999919891357', 'val_top-3-acc: 0.998199999332428']


Epoch 17:


100%|██████████| 1875/1875 [00:25<00:00, 72.93it/s]


['loss: 0.010992918163537979', 'acc: 0.9964666962623596', 'top-3-acc: 0.9999833106994629']
['val_loss: 0.08364678174257278', 'val_acc: 0.9803000092506409', 'val_top-3-acc: 0.9980999827384949']


Epoch 18:


100%|██████████| 1875/1875 [00:26<00:00, 69.47it/s]


['loss: 0.007680567447096109', 'acc: 0.9976166486740112', 'top-3-acc: 0.9999833106994629']
['val_loss: 0.08036080002784729', 'val_acc: 0.9818000197410583', 'val_top-3-acc: 0.998199999332428']


Epoch 19:


100%|██████████| 1875/1875 [00:22<00:00, 84.42it/s]


['loss: 0.009261597879230976', 'acc: 0.996833324432373', 'top-3-acc: 0.9999833106994629']
['val_loss: 0.09114983677864075', 'val_acc: 0.9805999994277954', 'val_top-3-acc: 0.9976999759674072']


Epoch 20:


100%|██████████| 1875/1875 [00:22<00:00, 81.97it/s]


['loss: 0.008768885396420956', 'acc: 0.9970999956130981', 'top-3-acc: 1.0']
['val_loss: 0.08397651463747025', 'val_acc: 0.9812999963760376', 'val_top-3-acc: 0.9980000257492065']


Epoch 21:


100%|██████████| 1875/1875 [00:22<00:00, 81.68it/s]


['loss: 0.0049278209917247295', 'acc: 0.9983999729156494', 'top-3-acc: 1.0']
['val_loss: 0.08147501945495605', 'val_acc: 0.980400025844574', 'val_top-3-acc: 0.9983999729156494']


Epoch 22:


100%|██████████| 1875/1875 [00:22<00:00, 81.81it/s]


['loss: 0.007986768148839474', 'acc: 0.9971833229064941', 'top-3-acc: 0.9999833106994629']
['val_loss: 0.08740774542093277', 'val_acc: 0.9800999760627747', 'val_top-3-acc: 0.9979000091552734']


Epoch 23:


100%|██████████| 1875/1875 [00:23<00:00, 79.39it/s]


['loss: 0.007285546977072954', 'acc: 0.9974499940872192', 'top-3-acc: 1.0']
['val_loss: 0.0820651650428772', 'val_acc: 0.9830999970436096', 'val_top-3-acc: 0.9979000091552734']


Epoch 24:


100%|██████████| 1875/1875 [00:22<00:00, 81.89it/s]


['loss: 0.0027594880666583776', 'acc: 0.9991666674613953', 'top-3-acc: 1.0']
['val_loss: 0.08859468996524811', 'val_acc: 0.9818000197410583', 'val_top-3-acc: 0.9977999925613403']


Epoch 25:


100%|██████████| 1875/1875 [00:23<00:00, 80.70it/s]


['loss: 0.004593678750097752', 'acc: 0.998449981212616', 'top-3-acc: 1.0']
['val_loss: 0.103614941239357', 'val_acc: 0.9789999723434448', 'val_top-3-acc: 0.9976999759674072']


Epoch 26:


100%|██████████| 1875/1875 [00:25<00:00, 73.30it/s]


['loss: 0.008674440905451775', 'acc: 0.9972833395004272', 'top-3-acc: 0.9999833106994629']
['val_loss: 0.10539736598730087', 'val_acc: 0.9782999753952026', 'val_top-3-acc: 0.9976999759674072']


Epoch 27:


100%|██████████| 1875/1875 [00:25<00:00, 73.91it/s]


['loss: 0.007340386509895325', 'acc: 0.9974333047866821', 'top-3-acc: 0.9999833106994629']
['val_loss: 0.09869898110628128', 'val_acc: 0.9797999858856201', 'val_top-3-acc: 0.998199999332428']


Epoch 28:


100%|██████████| 1875/1875 [00:24<00:00, 76.34it/s]


['loss: 0.004301588516682386', 'acc: 0.9984999895095825', 'top-3-acc: 1.0']
['val_loss: 0.10573913156986237', 'val_acc: 0.9799000024795532', 'val_top-3-acc: 0.9980000257492065']


Epoch 29:


100%|██████████| 1875/1875 [00:24<00:00, 76.00it/s]


['loss: 0.0015110253589227796', 'acc: 0.9997166395187378', 'top-3-acc: 1.0']
['val_loss: 0.09421123564243317', 'val_acc: 0.9818000197410583', 'val_top-3-acc: 0.9984999895095825']


Epoch 30:


100%|██████████| 1875/1875 [00:24<00:00, 75.80it/s]


['loss: 0.0018317471258342266', 'acc: 0.9995499849319458', 'top-3-acc: 1.0']
['val_loss: 0.09092629700899124', 'val_acc: 0.9825999736785889', 'val_top-3-acc: 0.9983999729156494']


Epoch 31:


100%|██████████| 1875/1875 [00:24<00:00, 75.34it/s]


['loss: 0.0008259740425273776', 'acc: 0.9998000264167786', 'top-3-acc: 1.0']
['val_loss: 0.08755786716938019', 'val_acc: 0.9833999872207642', 'val_top-3-acc: 0.9984999895095825']


Epoch 32:


100%|██████████| 1875/1875 [00:25<00:00, 74.53it/s]


['loss: 0.00047650118358433247', 'acc: 0.9998999834060669', 'top-3-acc: 1.0']
['val_loss: 0.08798333257436752', 'val_acc: 0.9828000068664551', 'val_top-3-acc: 0.9984999895095825']


Epoch 33:


100%|██████████| 1875/1875 [00:26<00:00, 71.13it/s]


['loss: 0.00037925707874819636', 'acc: 0.999916672706604', 'top-3-acc: 1.0']
['val_loss: 0.08921276777982712', 'val_acc: 0.9825999736785889', 'val_top-3-acc: 0.9984999895095825']


Epoch 34:


100%|██████████| 1875/1875 [00:25<00:00, 72.83it/s]


['loss: 0.0001407239178661257', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08678590506315231', 'val_acc: 0.9833999872207642', 'val_top-3-acc: 0.9987000226974487']


Epoch 35:


100%|██████████| 1875/1875 [00:24<00:00, 75.32it/s]


['loss: 0.0001016379173961468', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08670958131551743', 'val_acc: 0.9837999939918518', 'val_top-3-acc: 0.9987000226974487']


Epoch 36:


100%|██████████| 1875/1875 [00:30<00:00, 60.69it/s]


['loss: 8.748135587666184e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.0868687778711319', 'val_acc: 0.9836999773979187', 'val_top-3-acc: 0.9987000226974487']


Epoch 37:


100%|██████████| 1875/1875 [00:27<00:00, 67.85it/s]


['loss: 8.079349208856001e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.0869741439819336', 'val_acc: 0.9836999773979187', 'val_top-3-acc: 0.9987000226974487']


Epoch 38:


100%|██████████| 1875/1875 [00:25<00:00, 72.16it/s]


['loss: 7.585734419990331e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.0871395394206047', 'val_acc: 0.9836000204086304', 'val_top-3-acc: 0.9987000226974487']


Epoch 39:


100%|██████████| 1875/1875 [00:25<00:00, 73.96it/s]


['loss: 7.058738265186548e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08735715597867966', 'val_acc: 0.9836000204086304', 'val_top-3-acc: 0.9987000226974487']


Epoch 40:


100%|██████████| 1875/1875 [00:23<00:00, 79.12it/s]


['loss: 6.645343819400296e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08757936209440231', 'val_acc: 0.9836999773979187', 'val_top-3-acc: 0.9987000226974487']


Epoch 41:


100%|██████████| 1875/1875 [00:22<00:00, 82.26it/s]


['loss: 6.296357605606318e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08785348385572433', 'val_acc: 0.9836999773979187', 'val_top-3-acc: 0.9986000061035156']


Epoch 42:


100%|██████████| 1875/1875 [00:23<00:00, 78.13it/s]


['loss: 6.0229056543903425e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.0880521610379219', 'val_acc: 0.9836000204086304', 'val_top-3-acc: 0.9986000061035156']


Epoch 43:


100%|██████████| 1875/1875 [00:22<00:00, 83.83it/s]


['loss: 5.775853060185909e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.0881437435746193', 'val_acc: 0.9837999939918518', 'val_top-3-acc: 0.9986000061035156']


Epoch 44:


100%|██████████| 1875/1875 [00:23<00:00, 81.05it/s]


['loss: 5.522563515114598e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08845759183168411', 'val_acc: 0.9839000105857849', 'val_top-3-acc: 0.9986000061035156']


Epoch 45:


100%|██████████| 1875/1875 [00:23<00:00, 81.38it/s]


['loss: 5.329292980604805e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.0886065661907196', 'val_acc: 0.9839000105857849', 'val_top-3-acc: 0.9986000061035156']


Epoch 46:


100%|██████████| 1875/1875 [00:22<00:00, 83.29it/s]


['loss: 5.10859172209166e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08877644687891006', 'val_acc: 0.984000027179718', 'val_top-3-acc: 0.9986000061035156']


Epoch 47:


100%|██████████| 1875/1875 [00:22<00:00, 82.03it/s]


['loss: 4.9554233555682003e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08892601728439331', 'val_acc: 0.984000027179718', 'val_top-3-acc: 0.9986000061035156']


Epoch 48:


100%|██████████| 1875/1875 [00:23<00:00, 79.66it/s]


['loss: 4.811776307178661e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08911961317062378', 'val_acc: 0.9839000105857849', 'val_top-3-acc: 0.9986000061035156']


Epoch 49:


100%|██████████| 1875/1875 [00:23<00:00, 80.34it/s]


['loss: 4.651385097531602e-05', 'acc: 1.0', 'top-3-acc: 1.0']
['val_loss: 0.08929443359375', 'val_acc: 0.9839000105857849', 'val_top-3-acc: 0.9986000061035156']




In [10]:
# save the model with a meaningful name
model.save_weights(f"saved_model_{config_name}", save_format="tf")

# load the model:
# instantiate a new model from our CNN class
loaded_model = FFN()

# build the model
inp= tf.keras.Input((28,28,1))
loaded_model(inp)

# load the model weights to continue training. 
loaded_model.load_weights(f"saved_model_{config_name}");

# continue training (but: optimizer state is lost)

# run the training loop 
training_loop(model=loaded_model, 
                train_ds=train_ds, 
                val_ds=val_ds, 
                epochs=10, 
                train_summary_writer=train_summary_writer, 
                val_summary_writer=val_summary_writer)

ValueError: Received incompatible tensor with shape (784, 256) when attempting to restore variable with shape (1, 256) and name layer1/kernel/.ATTRIBUTES/VARIABLE_VALUE.