In [31]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import datetime
import tqdm

# in a notebook, load the tensorboard extension, not needed for scripts
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [32]:
class FFN(tf.keras.Model):
    """Pair classifier for two flattened images with a shared embedding stack.

    Both images of a pair are passed through the *same* three 64-unit ReLU
    Dense layers (so their weights are shared). The two embeddings are
    concatenated and mapped by a single sigmoid unit to P(target == 1).

    The model owns its optimizer, loss and metrics, so a hand-written training
    loop can call ``train_step`` / ``test_step`` directly without ``compile``.
    """

    def __init__(self):
        super().__init__()

        # Order matters: index 0 is the running mean loss; every later metric
        # is updated with (target, output) in _update_metrics.
        self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                             tf.keras.metrics.BinaryAccuracy(name="acc")]

        self.optimizer = tf.keras.optimizers.Adam()

        self.loss_function = tf.keras.losses.BinaryCrossentropy()

        # Embedding stack, applied to each image of the pair (shared weights).
        self.dense1 = tf.keras.layers.Dense(64, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(64, activation=tf.nn.relu)
        self.dense3 = tf.keras.layers.Dense(64, activation=tf.nn.relu)
        self.out_layer = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)

    @property
    def metrics(self):
        """Return the model's metrics in a fixed order: [loss, acc].

        Exposed explicitly instead of relying on Keras' implicit attribute
        tracking, so ``self.metrics[0]`` is guaranteed to be the loss tracker.
        """
        return self.metrics_list

    def _embed(self, img):
        """Pass one batch of images through the shared three-layer Dense stack."""
        x = self.dense1(img)
        x = self.dense2(x)
        return self.dense3(x)

    @tf.function
    def call(self, images, training=False):
        """Return sigmoid probabilities for a pair of image batches.

        Args:
            images: tuple ``(img1, img2)`` of image batches; presumably each is
                (batch, n_pixels) float — TODO confirm against the data pipeline.
            training: unused (the model has no dropout / batch-norm); kept for
                compatibility with the Keras ``call`` signature.

        Returns:
            Output of a Dense(1) sigmoid layer, i.e. one probability per pair.
        """
        img1, img2 = images
        combined_x = tf.concat([self._embed(img1), self._embed(img2)], axis=1)
        return self.out_layer(combined_x)

    # 4. reset all metrics objects
    def reset_metrics(self):
        """Reset the accumulated state of every metric (call between phases)."""
        for metric in self.metrics:
            # reset_state() replaces the deprecated reset_states() alias,
            # which is no longer available in Keras 3.
            metric.reset_state()

    def _update_metrics(self, target, output, loss):
        """Accumulate the batch loss and update all (target, output) metrics.

        Returns:
            dict mapping each metric name to its current running value.
        """
        self.metrics[0].update_state(loss)
        # every metric after the loss tracker compares targets with predictions
        for metric in self.metrics[1:]:
            metric.update_state(target, output)
        return {m.name: m.result() for m in self.metrics}

    # 5. train step method
    @tf.function
    def train_step(self, data):
        """Run one optimisation step on a batch and update the metrics.

        Args:
            data: tuple ``(img1, img2, target)``.

        Returns:
            dict mapping metric name to its running value since the last reset.
        """
        img1, img2, target = data

        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(target, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return self._update_metrics(target, output, loss)

    @tf.function
    def test_step(self, data):
        """Evaluate one batch (no weight update) and update the metrics.

        Args:
            data: tuple ``(img1, img2, target)``.

        Returns:
            dict mapping metric name to its running value since the last reset.
        """
        img1, img2, target = data

        output = self((img1, img2), training=False)
        loss = self.loss_function(target, output)

        return self._update_metrics(target, output, loss)

def _log_epoch_metrics(model, summary_writer, epoch, prefix=""):
    """Write, print and then reset the model's accumulated metrics for one phase.

    Each metric is written once per epoch at ``step=epoch``, so TensorBoard gets
    exactly one point per epoch per metric. This replaces writing on every batch
    at the same step. ``prefix`` (e.g. ``"val_"``) only affects the printed
    names; the TensorBoard tag is the bare metric name, and ``summary_writer``
    decides where it is stored.
    """
    with summary_writer.as_default():
        for metric in model.metrics:
            tf.summary.scalar(metric.name, metric.result(), step=epoch)
    # Read results from model.metrics directly so an empty dataset cannot leave
    # a "last returned dict" variable unbound.
    print([f"{prefix}{metric.name}: {metric.result().numpy()}" for metric in model.metrics])
    model.reset_metrics()


def training_loop(model, train_ds, val_ds, epochs, train_summary_writer, val_summary_writer):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``val_ds`` after each.

    Metrics accumulate over a full pass of each dataset. They are then written
    to TensorBoard (step = epoch index), printed, and reset before the next
    phase.

    Args:
        model: object exposing ``train_step``, ``test_step``, ``metrics`` and
            ``reset_metrics`` (e.g. ``FFN``).
        train_ds: iterable of training batches passed to ``model.train_step``.
        val_ds: iterable of validation batches passed to ``model.test_step``.
        epochs: number of passes over ``train_ds``.
        train_summary_writer: tf.summary writer for training metrics.
        val_summary_writer: tf.summary writer for validation metrics.
    """
    for epoch in range(epochs):
        print(f"Epoch {epoch}:")

        # Training: metrics accumulate over every batch of the epoch.
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            model.train_step(data)
        _log_epoch_metrics(model, train_summary_writer, epoch)

        # Validation: forward passes only; metrics were reset after training.
        for data in val_ds:
            model.test_step(data)
        _log_epoch_metrics(model, val_summary_writer, epoch, prefix="val_")

        print("\n")

In [33]:
# Fetch the MNIST train/test splits as (image, label) tuples together with the
# dataset metadata object.
(train_ds, test_ds), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)

def preprocess(data, batch_size, threshold=5, shuffle_buffer=2000):
    """Turn an ``(image, label)`` dataset into batches of labelled image pairs.

    Each output element is ``(img1, img2, target)`` where:

    * the images are flattened float32 vectors, rescaled with ``x / 128 - 1``;
    * ``target = int32(label1 + label2 >= threshold)``.

    Pairs are formed by zipping two independently shuffled copies of the cached
    data, so the pairing changes each time the dataset is iterated.

    Args:
        data: tf.data.Dataset yielding ``(image, label)`` tuples
            (e.g. from ``tfds.load(..., as_supervised=True)``).
        batch_size: number of pairs per batch.
        threshold: label-sum cut-off for a positive target (default 5, the
            previously hard-coded value).
        shuffle_buffer: buffer size for each of the two shuffles (default 2000).

    Returns:
        A batched, prefetched tf.data.Dataset of ``(img1, img2, target)``.
    """
    def _normalise(x, t):
        # Cast to float, flatten, then rescale. Assuming uint8 pixels in
        # [0, 255], this maps to [-1, 255/128 - 1] (just under 1).
        x = tf.reshape(tf.cast(x, tf.float32), (-1,))
        return (x / 128.) - 1., t

    # One fused, parallel per-image map; cache so later epochs skip it.
    data = data.map(_normalise, num_parallel_calls=tf.data.AUTOTUNE).cache()

    # Each example becomes ((x1, y1), (x2, y2)).
    zipped_ds = tf.data.Dataset.zip((data.shuffle(shuffle_buffer),
                                     data.shuffle(shuffle_buffer)))

    # ((x1, y1), (x2, y2)) -> (x1, x2, int32(y1 + y2 >= threshold))
    zipped_ds = zipped_ds.map(
        lambda first, second: (first[0],
                               second[0],
                               tf.cast(first[1] + second[1] >= threshold, tf.int32)),
        num_parallel_calls=tf.data.AUTOTUNE)

    return zipped_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [34]:
# Define where to save the log
config_name= "config_name"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

train_log_path = f"logs/{config_name}/{current_time}/train"
val_log_path = f"logs/{config_name}/{current_time}/val"

# log writer for training metrics
train_summary_writer = tf.summary.create_file_writer(train_log_path)

# log writer for validation metrics
val_summary_writer = tf.summary.create_file_writer(val_log_path)

In [35]:
# Start (or reuse an already running) inline TensorBoard over everything in logs/
%tensorboard --logdir logs/

Reusing TensorBoard on port 6006 (pid 16644), started 21:56:50 ago. (Use '!kill 16644' to kill it.)

In [36]:
model = FFN()

# Both splits become batches of 32 image pairs.
# NOTE(review): the test split doubles as the validation set here, so no
# untouched hold-out data remains for a final evaluation.
train_pairs = preprocess(train_ds, batch_size=32)
val_pairs = preprocess(test_ds, batch_size=32)

training_loop(
    model=model,
    train_ds=train_pairs,
    val_ds=val_pairs,
    epochs=30,
    train_summary_writer=train_summary_writer,
    val_summary_writer=val_summary_writer,
)

Epoch 0:


100%|██████████| 1875/1875 [00:20<00:00, 91.63it/s] 


['loss: 0.18452395498752594', 'acc: 0.9280666708946228']
['val_loss: 0.13359536230564117', 'val_acc: 0.9498999714851379']


Epoch 1:


100%|██████████| 1875/1875 [00:16<00:00, 113.60it/s]


['loss: 0.12632176280021667', 'acc: 0.9559666514396667']
['val_loss: 0.10942225158214569', 'val_acc: 0.964900016784668']


Epoch 2:


100%|██████████| 1875/1875 [00:16<00:00, 111.36it/s]


['loss: 0.1118360087275505', 'acc: 0.9619166851043701']
['val_loss: 0.11882492154836655', 'val_acc: 0.9599999785423279']


Epoch 3:


100%|██████████| 1875/1875 [00:16<00:00, 111.57it/s]


['loss: 0.10397277772426605', 'acc: 0.9673333168029785']
['val_loss: 0.10694197565317154', 'val_acc: 0.9692000150680542']


Epoch 4:


100%|██████████| 1875/1875 [00:16<00:00, 112.22it/s]


['loss: 0.09879332780838013', 'acc: 0.9700833559036255']
['val_loss: 0.09497284889221191', 'val_acc: 0.972000002861023']


Epoch 5:


100%|██████████| 1875/1875 [00:16<00:00, 113.27it/s]


['loss: 0.09466612339019775', 'acc: 0.970466673374176']
['val_loss: 0.09932252019643784', 'val_acc: 0.9660000205039978']


Epoch 6:


100%|██████████| 1875/1875 [00:16<00:00, 112.46it/s]


['loss: 0.08805335313081741', 'acc: 0.9739000201225281']
['val_loss: 0.09707578271627426', 'val_acc: 0.9743000268936157']


Epoch 7:


100%|██████████| 1875/1875 [00:16<00:00, 111.47it/s]


['loss: 0.0846991017460823', 'acc: 0.9761666655540466']
['val_loss: 0.12170896679162979', 'val_acc: 0.9581000208854675']


Epoch 8:


100%|██████████| 1875/1875 [00:16<00:00, 113.21it/s]


['loss: 0.08338457345962524', 'acc: 0.9749500155448914']
['val_loss: 0.10719374567270279', 'val_acc: 0.9689000248908997']


Epoch 9:


100%|██████████| 1875/1875 [00:16<00:00, 113.04it/s]


['loss: 0.07991459965705872', 'acc: 0.9779333472251892']
['val_loss: 0.08567831665277481', 'val_acc: 0.9764999747276306']


Epoch 10:


100%|██████████| 1875/1875 [00:16<00:00, 112.25it/s]


['loss: 0.08115673065185547', 'acc: 0.977649986743927']
['val_loss: 0.08232147991657257', 'val_acc: 0.9793000221252441']


Epoch 11:


100%|██████████| 1875/1875 [00:16<00:00, 111.36it/s]


['loss: 0.07890158891677856', 'acc: 0.9780833125114441']
['val_loss: 0.08745201677083969', 'val_acc: 0.9786999821662903']


Epoch 12:


100%|██████████| 1875/1875 [00:16<00:00, 113.20it/s]


['loss: 0.07763214409351349', 'acc: 0.9794999957084656']
['val_loss: 0.08304541558027267', 'val_acc: 0.9815000295639038']


Epoch 13:


100%|██████████| 1875/1875 [00:16<00:00, 112.39it/s]


['loss: 0.07450096309185028', 'acc: 0.9802666902542114']
['val_loss: 0.07854843884706497', 'val_acc: 0.9801999926567078']


Epoch 14:


100%|██████████| 1875/1875 [00:16<00:00, 110.32it/s]


['loss: 0.07177147269248962', 'acc: 0.9815666675567627']
['val_loss: 0.09516724199056625', 'val_acc: 0.9750999808311462']


Epoch 15:


100%|██████████| 1875/1875 [00:17<00:00, 109.82it/s]


['loss: 0.07114841043949127', 'acc: 0.9815000295639038']
['val_loss: 0.09069570153951645', 'val_acc: 0.9800000190734863']


Epoch 16:


100%|██████████| 1875/1875 [00:16<00:00, 111.66it/s]


['loss: 0.06864269822835922', 'acc: 0.9819666743278503']
['val_loss: 0.0746145024895668', 'val_acc: 0.9825999736785889']


Epoch 17:


100%|██████████| 1875/1875 [00:16<00:00, 110.97it/s]


['loss: 0.06877842545509338', 'acc: 0.9823166728019714']
['val_loss: 0.0839061439037323', 'val_acc: 0.9819999933242798']


Epoch 18:


100%|██████████| 1875/1875 [00:16<00:00, 112.06it/s]


['loss: 0.0672803595662117', 'acc: 0.9822666645050049']
['val_loss: 0.08648830652236938', 'val_acc: 0.9807000160217285']


Epoch 19:


100%|██████████| 1875/1875 [00:17<00:00, 110.18it/s]


['loss: 0.06636162847280502', 'acc: 0.9834333062171936']
['val_loss: 0.0898495689034462', 'val_acc: 0.9800999760627747']


Epoch 20:


100%|██████████| 1875/1875 [00:16<00:00, 112.43it/s]


['loss: 0.06522057950496674', 'acc: 0.984666645526886']
['val_loss: 0.07836715131998062', 'val_acc: 0.98089998960495']


Epoch 21:


100%|██████████| 1875/1875 [00:16<00:00, 113.14it/s]


['loss: 0.06400840729475021', 'acc: 0.9838166832923889']
['val_loss: 0.06966858357191086', 'val_acc: 0.9824000000953674']


Epoch 22:


100%|██████████| 1875/1875 [00:16<00:00, 111.63it/s]


['loss: 0.06325466930866241', 'acc: 0.9845666885375977']
['val_loss: 0.07957512140274048', 'val_acc: 0.9836999773979187']


Epoch 23:


100%|██████████| 1875/1875 [00:16<00:00, 111.73it/s]


['loss: 0.06285630911588669', 'acc: 0.9848833084106445']
['val_loss: 0.08054092526435852', 'val_acc: 0.9829999804496765']


Epoch 24:


100%|██████████| 1875/1875 [00:17<00:00, 108.46it/s]


['loss: 0.06108089163899422', 'acc: 0.9849666953086853']
['val_loss: 0.07852412015199661', 'val_acc: 0.98089998960495']


Epoch 25:


100%|██████████| 1875/1875 [00:16<00:00, 113.29it/s]


['loss: 0.06108899414539337', 'acc: 0.9847666621208191']
['val_loss: 0.08526896685361862', 'val_acc: 0.9815000295639038']


Epoch 26:


100%|██████████| 1875/1875 [00:16<00:00, 110.42it/s]


['loss: 0.05956495180726051', 'acc: 0.9856666922569275']
['val_loss: 0.07230580598115921', 'val_acc: 0.9824000000953674']


Epoch 27:


100%|██████████| 1875/1875 [00:16<00:00, 111.68it/s]


['loss: 0.059847746044397354', 'acc: 0.9859166741371155']
['val_loss: 0.08261139690876007', 'val_acc: 0.9836000204086304']


Epoch 28:


100%|██████████| 1875/1875 [00:16<00:00, 111.94it/s]


['loss: 0.0595540888607502', 'acc: 0.9853166937828064']
['val_loss: 0.08672389388084412', 'val_acc: 0.9819999933242798']


Epoch 29:


100%|██████████| 1875/1875 [00:16<00:00, 112.47it/s]


['loss: 0.057998426258563995', 'acc: 0.9859499931335449']
['val_loss: 0.08251028507947922', 'val_acc: 0.9828000068664551']


