In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import datetime
import tqdm

# in a notebook, load the tensorboard extension, not needed for scripts
%load_ext tensorboard

In [2]:
class FFN(tf.keras.Model):
    def __init__(self):
        super().__init__()
    
        self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                             tf.keras.metrics.CategoricalAccuracy(name="acc")]
        
        self.optimizer = tf.keras.optimizers.SGD(momentum=0.9)
        
        self.loss_function = tf.keras.losses.CategoricalCrossentropy()
        
        # layers to be used
        self.dense1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)
        self.dense3 = tf.keras.layers.Dense(256, activation=tf.nn.relu)
        self.dense4 = tf.keras.layers.Dense(256, activation=tf.nn.relu)
        self.out_layer = tf.keras.layers.Dense(19,activation=tf.nn.softmax) 
        
    @tf.function
    def call(self, images, training=False):
        img1, img2 = images
        
        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)
        img1_x = self.dense3(img1_x)
        img1_x = self.dense4(img1_x)
        
        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)
        img2_x = self.dense3(img2_x)
        img2_x = self.dense4(img2_x)
        
        combined_x = tf.concat([img1_x, img2_x], axis=1)
        
        return self.out_layer(combined_x)
     
    @property
    def metrics(self):
        return self.metrics_list
        # return a list with all metrics in the model

    # 4. reset all metrics objects
    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_states()

    # 5. train step method
    @tf.function
    def train_step(self, data):
        img1, img2, target = data
        
        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(target, output)
            
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # update loss metric
        self.metrics[0].update_state(loss)
        
        # for all metrics except loss, update states (accuracy etc.)
        for metric in self.metrics[1:]:
            metric.update_state(target, output)

        # Return a dictionary mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
    
    @tf.function
    def test_step(self, data):
        img1, img2, target = data

        output = self((img1, img2), training=False)
        loss = self.loss_function(target, output)

        self.metrics[0].update_state(loss)
        # for accuracy metrics:
        for metric in self.metrics[1:]:
            metric.update_state(target, output)

        return {m.name: m.result() for m in self.metrics}

def training_loop(model, train_ds, val_ds, epochs, train_summary_writer, val_summary_writer):
    for epoch in range(epochs):
        print(f"Epoch {epoch}:")
        
        # Training:
        
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            metrics = model.train_step(data)
            
            # logging the validation metrics to the log file which is used by tensorboard
            with train_summary_writer.as_default():
                for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)

        # print the metrics
        print([f"{key}: {value.numpy()}" for (key, value) in metrics.items()])

        # reset all metrics (requires a reset_metrics method in the model)
        model.reset_metrics()    
        
        # Validation:
        for data in val_ds:
            metrics = model.test_step(data)
        
            # logging the validation metrics to the log file which is used by tensorboard
            with val_summary_writer.as_default():
                for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)
                    
        print([f"val_{key}: {value.numpy()}" for (key, value) in metrics.items()])

        # reset all metrics
        model.reset_metrics()
        print("\n")

In [3]:
(train_ds, test_ds), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

def preprocess(data, batch_size):
    # image should be float
    data = data.map(lambda x, t: (tf.cast(x, float), t))
    # image should be flattened
    data = data.map(lambda x, t: (tf.reshape(x, (-1,)), t))
    # image vector will here have values between -1 and 1
    data = data.map(lambda x,t: ((x/128.)-1., t))
    # we want to have two mnist images in each example
    # this leads to a single example being ((x1,y1),(x2,y2))
    zipped_ds = tf.data.Dataset.zip((data.shuffle(2000), 
                                     data.shuffle(2000)))
    # map ((x1,y1),(x2,y2)) to (x1,x2, y1==y2*) *x1[1] + x2[1] >= 5
    zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1] - x2[1]))
    # transform 
    zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.cast(t, tf.int32)))
    # target vector should be one-hot encoded
    zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.one_hot(t, 19,dtype=tf.int32)))
    # batch the dataset
    zipped_ds = zipped_ds.batch(batch_size)
    # prefetch
    zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)
    return zipped_ds

In [4]:
# Define where to save the log
config_name= "config_name"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

train_log_path = f"logs/{config_name}/{current_time}/SGDwMomentum/train"
val_log_path = f"logs/{config_name}/{current_time}/SGDwMomentum/val"

# log writer for training metrics
train_summary_writer = tf.summary.create_file_writer(train_log_path)

# log writer for validation metrics
val_summary_writer = tf.summary.create_file_writer(val_log_path)

In [5]:
%tensorboard --logdir logs/

Reusing TensorBoard on port 6006 (pid 14280), started 3 days, 18:43:48 ago. (Use '!kill 14280' to kill it.)

In [6]:
model = FFN()

training_loop(model=model,
                train_ds=preprocess(train_ds,batch_size = 128), 
                val_ds=preprocess(test_ds,batch_size = 128), 
                epochs=30, 
                train_summary_writer=train_summary_writer, 
                val_summary_writer=val_summary_writer)

Epoch 0:


100%|██████████| 469/469 [00:17<00:00, 26.33it/s]


['loss: 1.0280272960662842', 'acc: 0.47909998893737793']
['val_loss: 0.8640775084495544', 'val_acc: 0.6248000264167786']


Epoch 1:


100%|██████████| 469/469 [00:08<00:00, 53.83it/s]


['loss: 0.8081130385398865', 'acc: 0.6165833473205566']
['val_loss: 0.7454273700714111', 'val_acc: 0.6696000099182129']


Epoch 2:


100%|██████████| 469/469 [00:07<00:00, 59.99it/s]


['loss: 0.731031596660614', 'acc: 0.6802499890327454']
['val_loss: 0.701227605342865', 'val_acc: 0.6944000124931335']


Epoch 3:


100%|██████████| 469/469 [00:09<00:00, 50.64it/s]


['loss: 0.6786736249923706', 'acc: 0.7183666825294495']
['val_loss: 0.6744408011436462', 'val_acc: 0.7404000163078308']


Epoch 4:


100%|██████████| 469/469 [00:08<00:00, 54.92it/s]


['loss: 0.6417236328125', 'acc: 0.7409166693687439']
['val_loss: 0.6141703128814697', 'val_acc: 0.7688000202178955']


Epoch 5:


100%|██████████| 469/469 [00:09<00:00, 49.23it/s]


['loss: 0.6017286777496338', 'acc: 0.7703999876976013']
['val_loss: 0.6294800043106079', 'val_acc: 0.7724000215530396']


Epoch 6:


100%|██████████| 469/469 [00:09<00:00, 51.11it/s]


['loss: 0.5931857824325562', 'acc: 0.7793166637420654']
['val_loss: 0.575680673122406', 'val_acc: 0.7994999885559082']


Epoch 7:


100%|██████████| 469/469 [00:09<00:00, 48.58it/s]


['loss: 0.5567616820335388', 'acc: 0.8029999732971191']
['val_loss: 0.5910925269126892', 'val_acc: 0.8036999702453613']


Epoch 8:


100%|██████████| 469/469 [00:11<00:00, 39.81it/s]


['loss: 0.5481232404708862', 'acc: 0.8080000281333923']
['val_loss: 0.5742966532707214', 'val_acc: 0.8019000291824341']


Epoch 9:


100%|██████████| 469/469 [00:11<00:00, 40.18it/s]


['loss: 0.5247299671173096', 'acc: 0.8317833542823792']
['val_loss: 0.622374951839447', 'val_acc: 0.8144000172615051']


Epoch 10:


100%|██████████| 469/469 [00:08<00:00, 58.49it/s]


['loss: 0.5139338374137878', 'acc: 0.838366687297821']
['val_loss: 0.5014340281486511', 'val_acc: 0.864799976348877']


Epoch 11:


100%|██████████| 469/469 [00:08<00:00, 58.57it/s]


['loss: 0.5002608299255371', 'acc: 0.8510333299636841']
['val_loss: 0.5319105982780457', 'val_acc: 0.8435999751091003']


Epoch 12:


100%|██████████| 469/469 [00:08<00:00, 53.31it/s]


['loss: 0.48361313343048096', 'acc: 0.8540499806404114']
['val_loss: 0.5040201544761658', 'val_acc: 0.8799999952316284']


Epoch 13:


100%|██████████| 469/469 [00:08<00:00, 55.11it/s]


['loss: 0.4711146354675293', 'acc: 0.8670833110809326']
['val_loss: 0.5156205892562866', 'val_acc: 0.8547000288963318']


Epoch 14:


100%|██████████| 469/469 [00:09<00:00, 51.83it/s]


['loss: 0.4628044664859772', 'acc: 0.8707500100135803']
['val_loss: 0.6224249005317688', 'val_acc: 0.8445000052452087']


Epoch 15:


100%|██████████| 469/469 [00:08<00:00, 58.47it/s]


['loss: 0.4592842161655426', 'acc: 0.876800000667572']
['val_loss: 0.5061670541763306', 'val_acc: 0.9070000052452087']


Epoch 16:


100%|██████████| 469/469 [00:07<00:00, 60.45it/s]


['loss: 0.43286705017089844', 'acc: 0.8877833485603333']
['val_loss: 0.4545452296733856', 'val_acc: 0.8909000158309937']


Epoch 17:


100%|██████████| 469/469 [00:08<00:00, 58.55it/s]


['loss: 0.4433915317058563', 'acc: 0.8932499885559082']
['val_loss: 0.49473002552986145', 'val_acc: 0.8970999717712402']


Epoch 18:


100%|██████████| 469/469 [00:07<00:00, 61.22it/s]


['loss: 0.43566039204597473', 'acc: 0.8961666822433472']
['val_loss: 0.5549697279930115', 'val_acc: 0.9046000242233276']


Epoch 19:


100%|██████████| 469/469 [00:08<00:00, 56.95it/s]


['loss: 0.4391392767429352', 'acc: 0.8988000154495239']
['val_loss: 0.46355172991752625', 'val_acc: 0.9024999737739563']


Epoch 20:


100%|██████████| 469/469 [00:07<00:00, 59.22it/s]


['loss: 0.4142853617668152', 'acc: 0.9050999879837036']
['val_loss: 0.49610573053359985', 'val_acc: 0.9138000011444092']


Epoch 21:


100%|██████████| 469/469 [00:09<00:00, 49.58it/s]


['loss: 0.4083920419216156', 'acc: 0.9166333079338074']
['val_loss: 0.5142243504524231', 'val_acc: 0.8948000073432922']


Epoch 22:


100%|██████████| 469/469 [00:08<00:00, 56.04it/s]


['loss: 0.4176337420940399', 'acc: 0.9121500253677368']
['val_loss: 0.5435327291488647', 'val_acc: 0.8790000081062317']


Epoch 23:


100%|██████████| 469/469 [00:09<00:00, 50.54it/s]


['loss: 0.4094710350036621', 'acc: 0.9164333343505859']
['val_loss: 0.4599464535713196', 'val_acc: 0.9261999726295471']


Epoch 24:


100%|██████████| 469/469 [00:08<00:00, 54.77it/s]


['loss: 0.39221709966659546', 'acc: 0.9219833612442017']
['val_loss: 0.5116478204727173', 'val_acc: 0.8899999856948853']


Epoch 25:


100%|██████████| 469/469 [00:07<00:00, 59.91it/s]


['loss: 0.3955976665019989', 'acc: 0.92003333568573']
['val_loss: 0.49730178713798523', 'val_acc: 0.9150000214576721']


Epoch 26:


100%|██████████| 469/469 [00:09<00:00, 52.04it/s]


['loss: 0.3913388252258301', 'acc: 0.9251333475112915']
['val_loss: 0.5039185881614685', 'val_acc: 0.9103000164031982']


Epoch 27:


100%|██████████| 469/469 [00:08<00:00, 54.29it/s]


['loss: 0.3764664828777313', 'acc: 0.9325833320617676']
['val_loss: 0.44685372710227966', 'val_acc: 0.9276999831199646']


Epoch 28:


100%|██████████| 469/469 [00:07<00:00, 59.16it/s]


['loss: 0.3795968294143677', 'acc: 0.9311333298683167']
['val_loss: 0.45947396755218506', 'val_acc: 0.9138000011444092']


Epoch 29:


100%|██████████| 469/469 [00:08<00:00, 56.61it/s]


['loss: 0.36362743377685547', 'acc: 0.9322166442871094']
['val_loss: 0.4292239844799042', 'val_acc: 0.9351999759674072']


