In [41]:
import tensorflow as tf
import tensorflow_datasets as tfds
import datetime

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [42]:
# Load MNIST from tensorflow_datasets as (image, label) tuples: [train split, test split].
# NOTE(review): the "test" split is used as the validation set below, so it also drives the
# optimizer comparison; holding out part of "train" would keep "test" as an untouched final check.
mnist = tfds.load("mnist", split=["train", "test"], as_supervised=True)
train_ds, val_ds = mnist

In [43]:
# write function to create the dataset that we want
# "condition" expects of of two arguments: condition==0 being the comparison, and condition == 1 being the subtraction
def preprocess(data, batch_size, condition):
    
    
    # image should be float
    data = data.map(lambda x, t: (tf.cast(x, float), t))
    # image should be flattened
    data = data.map(lambda x, t: (tf.reshape(x, (-1,)), t))
    # image vector will here have values between 0 and 1
    data = data.map(lambda x,t: ((x/128.)-1, t))
    
    
    # we want to have two mnist images in each example
    # this leads to a single example being ((x1,y1),(x2,y2))
    zipped_ds = tf.data.Dataset.zip((data.shuffle(2000), 
                                     data.shuffle(2000)))
    
    if condition == 0:
        # use maping to change the target
        zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1]+x2[1] >= 5))
        # transform boolean target to int
        zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.cast(t, tf.int32)))
    else:
        # use maping to change the target
        zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1]-x2[1]))
        # transform target to float
        zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.cast(t, tf.float32)))
        
    
    # batch the dataset
    zipped_ds = zipped_ds.batch(batch_size)
    # prefetch
    zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)
    return zipped_ds

In [44]:
class TwinMNISTModel(tf.keras.Model):
    """Twin (weight-sharing) MLP that reads two flattened MNIST images and predicts one value.

    Both images pass through the same two ``Dense(128, relu)`` layers. The two encodings
    are concatenated and fed through ``Dense(128, relu)`` and a single-unit output layer.

    Args:
        condition: 0 -> comparison task (sigmoid output, binary cross-entropy loss,
            binary accuracy metric); 1 -> subtraction task (linear output,
            mean squared error loss, root-mean-squared-error metric).
        opt: optimizer choice: 0 = Adam, 1 = SGD (learning rate 0.01, momentum 0.03),
            2 = RMSprop, 3 = Adagrad (Keras defaults for everything not listed).

    Raises:
        ValueError: if ``condition`` or ``opt`` is not one of the values above.
    """

    # 1. constructor
    def __init__(self, condition, opt):
        # Validate up front so a bad configuration fails here rather than later,
        # when train_step would find no usable optimizer.
        if condition not in (0, 1):
            raise ValueError(
                f"condition must be 0 (comparison) or 1 (subtraction), got {condition!r}")
        optimizer_factories = {
            0: tf.keras.optimizers.Adam,
            1: lambda: tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.03),
            2: tf.keras.optimizers.RMSprop,
            3: tf.keras.optimizers.Adagrad,
        }
        if opt not in optimizer_factories:
            raise ValueError(
                f"opt must be 0 (Adam), 1 (SGD), 2 (RMSprop) or 3 (Adagrad), got {opt!r}")

        # inherit functionality from parent class
        super().__init__()

        is_comparison = condition == 0

        # metrics: running mean of the loss first, then the task-specific metric
        if is_comparison:
            self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                                 tf.keras.metrics.BinaryAccuracy()]
        else:
            self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                                 tf.keras.metrics.RootMeanSquaredError()]

        # optimizer
        self.optimizer = optimizer_factories[opt]()

        # loss function
        if is_comparison:
            self.loss_function = tf.keras.losses.BinaryCrossentropy()
        else:
            self.loss_function = tf.keras.losses.MeanSquaredError()

        # image encoder: both layers are applied to both images (shared weights)
        self.dense1 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)

        # head on top of the concatenated encodings
        self.dense3 = tf.keras.layers.Dense(128, activation=tf.nn.relu)

        # probability for the comparison task, unbounded value for the subtraction task
        if is_comparison:
            self.out_layer = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)
        else:
            self.out_layer = tf.keras.layers.Dense(1, activation=None)

    # 2. call method (forward computation)
    def call(self, images, training=False):
        """Forward pass on ``images = (img1, img2)``; returns a ``(batch, 1)`` prediction."""
        img1, img2 = images

        img1_x = self.dense2(self.dense1(img1))
        img2_x = self.dense2(self.dense1(img2))

        combined_x = tf.concat([img1_x, img2_x], axis=1)
        combined_x = self.dense3(combined_x)
        return self.out_layer(combined_x)

    # 3. metrics property
    @property
    def metrics(self):
        """All metric objects of the model; index 0 is always the loss ``Mean``."""
        return self.metrics_list

    # 4. reset all metrics objects
    def reset_metrics(self):
        """Reset the accumulated state of every metric (call once per epoch/split)."""
        for metric in self.metrics:
            # reset_state() replaces the deprecated reset_states() (TF >= 2.5)
            metric.reset_state()

    def _update_metrics(self, loss, label, output):
        """Feed one batch into the metrics: the loss Mean gets ``loss``, the rest get (label, output)."""
        self.metrics[0].update_state(loss)
        for metric in self.metrics[1:]:
            metric.update_state(label, output)

    # 5. train step method
    @tf.function
    def train_step(self, data):
        """One optimisation step on a batch ``(img1, img2, label)``.

        Returns:
            dict mapping metric name to its accumulated value since the last reset.
        """
        img1, img2, label = data

        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(label, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self._update_metrics(loss, label, output)
        return {m.name: m.result() for m in self.metrics}

    # 6. test_step method
    @tf.function
    def test_step(self, data):
        """Evaluate a batch ``(img1, img2, label)`` without updating weights.

        Returns:
            dict mapping metric name to its accumulated value since the last reset.
        """
        img1, img2, label = data
        output = self((img1, img2), training=False)
        loss = self.loss_function(label, output)

        self._update_metrics(loss, label, output)
        return {m.name: m.result() for m in self.metrics}

In [45]:
def create_summary_writers(config_name):
    """Create the TensorBoard file writers for one training run.

    Logs are written to ``logs/<config_name>/<YYYYmmdd-HHMMSS>/train`` and
    ``logs/<config_name>/<YYYYmmdd-HHMMSS>/val``. The timestamp keeps repeated runs
    with the same config name apart. Consider storing the hyperparameters (or a
    copy of the code) next to these logs for later reference.

    Args:
        config_name: name of the sub-folder under ``logs/`` for this configuration.

    Returns:
        ``(train_summary_writer, val_summary_writer)``
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_dir = f"logs/{config_name}/{timestamp}"

    # one writer per split so TensorBoard can overlay train and val curves
    return (tf.summary.create_file_writer(f"{run_dir}/train"),
            tf.summary.create_file_writer(f"{run_dir}/val"))


In [46]:
import tqdm


def _log_and_reset_metrics(model, summary_writer, step, prefix=""):
    """Write every model metric to TensorBoard at ``step``, print it, then reset all metrics.

    Args:
        model: model exposing ``metrics`` (objects with ``name``/``result()``) and ``reset_metrics()``.
        summary_writer: ``tf.summary`` writer for this split.
        step: epoch index used as the TensorBoard step.
        prefix: prepended to metric names in the printed line only (e.g. ``"val_"``).
    """
    results = {metric.name: metric.result() for metric in model.metrics}
    with summary_writer.as_default():
        for name, value in results.items():
            tf.summary.scalar(name, value, step=step)
    print([f"{prefix}{name}: {value.numpy()}" for name, value in results.items()])
    model.reset_metrics()


def training_loop(model, train_ds, val_ds, start_epoch,
                  epochs, train_summary_writer,
                  val_summary_writer, save_path):
    """Train ``model`` for epochs ``start_epoch .. epochs-1``, logging train and val metrics.

    Each epoch runs ``model.train_step`` on every batch of ``train_ds``, then
    ``model.test_step`` on every batch of ``val_ds``. After each pass the metrics are
    written to the matching summary writer, printed, and reset.
    Metrics are read from the model itself, so an empty dataset logs the metrics'
    reset values instead of failing.

    Args:
        model: model with ``train_step``, ``test_step``, ``metrics`` and ``reset_metrics``.
        train_ds / val_ds: iterables of batches accepted by ``train_step`` / ``test_step``.
        start_epoch: first epoch index (lets a run be resumed with consistent TensorBoard steps).
        epochs: index one past the last epoch to run.
        train_summary_writer / val_summary_writer: ``tf.summary`` writers.
        save_path: if truthy, weights are saved there once after the final epoch.
    """
    # start from clean metrics in case the model was evaluated before
    model.reset_metrics()

    for epoch in range(start_epoch, epochs):
        # training pass
        for data in tqdm.tqdm(train_ds, desc=f"epoch {epoch + 1}/{epochs}",
                              position=0, leave=True):
            model.train_step(data)
        _log_and_reset_metrics(model, train_summary_writer, epoch)

        # validation pass (no weight updates)
        for data in val_ds:
            model.test_step(data)
        _log_and_reset_metrics(model, val_summary_writer, epoch, prefix="val_")

    if save_path:
        model.save_weights(save_path)

In [47]:
# Open TensorBoard on localhost:8087, showing every run written under logs/
# (one sub-folder per config name and timestamp, as created by create_summary_writers).
# If an instance with identical arguments is already running, the magic reuses it,
# so keep the argument string unchanged between re-runs.
%tensorboard --logdir logs/ --host localhost --port 8087

Reusing TensorBoard on port 8087 (pid 12992), started 0:38:51 ago. (Use '!kill 12992' to kill it.)

In [48]:
# Train one TwinMNISTModel per optimizer on the same task and compare them in TensorBoard.
CONDITION = 0   # 0 = comparison (label1 + label2 >= 5), 1 = subtraction (label1 - label2)
BATCH_SIZE = 32
EPOCHS = 10
SEED = 42
# optimizer indices understood by TwinMNISTModel(opt=...)
OPTIMIZER_NAMES = {0: "adam", 1: "sgd", 2: "rmsprop", 3: "adagrad"}

# make dataset shuffling and weight initialisation reproducible
tf.random.set_seed(SEED)

# 1. Apply preprocessing to the raw splits in `mnist` (not to train_ds/val_ds), so
#    re-running this cell never preprocesses already-preprocessed data a second time.
#    For a quick smoke test, use e.g. mnist[0].take(1000) / mnist[1].take(100).
train_ds = preprocess(mnist[0], batch_size=BATCH_SIZE, condition=CONDITION)
val_ds = preprocess(mnist[1], batch_size=BATCH_SIZE, condition=CONDITION)

# 2. instantiate one model per optimizer
models = [TwinMNISTModel(condition=CONDITION, opt=opt) for opt in OPTIMIZER_NAMES]

# 3.+4. train each model with its own log directory and its own weights file,
#       so later runs do not overwrite the weights of earlier ones
for run_id, (opt, model) in enumerate(zip(OPTIMIZER_NAMES, models), start=1):
    opt_name = OPTIMIZER_NAMES[opt]
    train_summary_writer, val_summary_writer = create_summary_writers(
        config_name=f"RUN{run_id}_{opt_name}")
    save_path = f"trained_model_RUN{run_id}_{opt_name}"

    training_loop(model=model,
                  train_ds=train_ds,
                  val_ds=val_ds,
                  start_epoch=0,
                  epochs=EPOCHS,
                  train_summary_writer=train_summary_writer,
                  val_summary_writer=val_summary_writer,
                  save_path=save_path)
    

100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:37<00:00, 49.61it/s]


['loss: 0.15105974674224854', 'binary_accuracy: 0.9400333166122437']
['val_loss: 0.09854016453027725', 'val_binary_accuracy: 0.9653000235557556']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:33<00:00, 55.83it/s]


['loss: 0.08515463769435883', 'binary_accuracy: 0.9694333076477051']
['val_loss: 0.07681399583816528', 'val_binary_accuracy: 0.9733999967575073']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:30<00:00, 61.96it/s]


['loss: 0.06987934559583664', 'binary_accuracy: 0.9758999943733215']
['val_loss: 0.06136864796280861', 'val_binary_accuracy: 0.9797000288963318']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:30<00:00, 62.48it/s]


['loss: 0.058697447180747986', 'binary_accuracy: 0.9790999889373779']
['val_loss: 0.05549951270222664', 'val_binary_accuracy: 0.9812999963760376']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:33<00:00, 56.01it/s]


['loss: 0.05159265175461769', 'binary_accuracy: 0.9822166562080383']
['val_loss: 0.05001470819115639', 'val_binary_accuracy: 0.983299970626831']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:33<00:00, 56.62it/s]


['loss: 0.04761778563261032', 'binary_accuracy: 0.9832833409309387']
['val_loss: 0.06132951378822327', 'val_binary_accuracy: 0.9796000123023987']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:34<00:00, 54.97it/s]


['loss: 0.04384526237845421', 'binary_accuracy: 0.9850666522979736']
['val_loss: 0.05177909880876541', 'val_binary_accuracy: 0.984499990940094']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:32<00:00, 58.17it/s]


['loss: 0.04199792444705963', 'binary_accuracy: 0.9855166673660278']
['val_loss: 0.05138665810227394', 'val_binary_accuracy: 0.9829999804496765']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:28<00:00, 64.77it/s]


['loss: 0.04169772192835808', 'binary_accuracy: 0.9855666756629944']
['val_loss: 0.06114626303315163', 'val_binary_accuracy: 0.9804999828338623']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:31<00:00, 59.28it/s]


['loss: 0.03679279237985611', 'binary_accuracy: 0.9872166514396667']
['val_loss: 0.05256982520222664', 'val_binary_accuracy: 0.9839000105857849']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:31<00:00, 59.11it/s]


['loss: 0.2495661824941635', 'binary_accuracy: 0.9003499746322632']
['val_loss: 0.1893942505121231', 'val_binary_accuracy: 0.9229999780654907']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:32<00:00, 58.08it/s]


['loss: 0.1708274632692337', 'binary_accuracy: 0.9322666525840759']
['val_loss: 0.14221331477165222', 'val_binary_accuracy: 0.9474999904632568']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:29<00:00, 63.60it/s]


['loss: 0.12937933206558228', 'binary_accuracy: 0.9504666924476624']
['val_loss: 0.11074098944664001', 'val_binary_accuracy: 0.9596999883651733']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:30<00:00, 61.29it/s]


['loss: 0.10858304798603058', 'binary_accuracy: 0.9605166912078857']
['val_loss: 0.10890229791402817', 'val_binary_accuracy: 0.9605000019073486']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:33<00:00, 56.58it/s]


['loss: 0.0956466943025589', 'binary_accuracy: 0.9665833115577698']
['val_loss: 0.09290154278278351', 'val_binary_accuracy: 0.9653000235557556']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:28<00:00, 64.83it/s]


['loss: 0.08230137825012207', 'binary_accuracy: 0.9702666401863098']
['val_loss: 0.07351826876401901', 'val_binary_accuracy: 0.974399983882904']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:28<00:00, 65.53it/s]


['loss: 0.07488652318716049', 'binary_accuracy: 0.9736166596412659']
['val_loss: 0.07644029706716537', 'val_binary_accuracy: 0.9732999801635742']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:36<00:00, 52.08it/s]


['loss: 0.0673801600933075', 'binary_accuracy: 0.9761166572570801']
['val_loss: 0.07051470130681992', 'val_binary_accuracy: 0.9749000072479248']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:33<00:00, 56.22it/s]


['loss: 0.0625852420926094', 'binary_accuracy: 0.9787499904632568']
['val_loss: 0.0615217424929142', 'val_binary_accuracy: 0.9786999821662903']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:33<00:00, 55.39it/s]


['loss: 0.056462015956640244', 'binary_accuracy: 0.980733335018158']
['val_loss: 0.05876893177628517', 'val_binary_accuracy: 0.9787999987602234']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:36<00:00, 50.71it/s]


['loss: 0.1604619026184082', 'binary_accuracy: 0.9383999705314636']
['val_loss: 0.08950544893741608', 'val_binary_accuracy: 0.9674000144004822']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:37<00:00, 49.45it/s]


['loss: 0.10879118740558624', 'binary_accuracy: 0.9647166728973389']
['val_loss: 0.11101808398962021', 'val_binary_accuracy: 0.9679999947547913']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:35<00:00, 53.57it/s]


['loss: 0.1028788685798645', 'binary_accuracy: 0.9683833122253418']
['val_loss: 0.10165619105100632', 'val_binary_accuracy: 0.9652000069618225']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:38<00:00, 49.06it/s]


['loss: 0.10520676523447037', 'binary_accuracy: 0.9685999751091003']
['val_loss: 0.08670560270547867', 'val_binary_accuracy: 0.9739999771118164']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:32<00:00, 57.91it/s]


['loss: 0.1086566299200058', 'binary_accuracy: 0.9690166711807251']
['val_loss: 0.09450317174196243', 'val_binary_accuracy: 0.9754999876022339']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:32<00:00, 57.46it/s]


['loss: 0.1027074083685875', 'binary_accuracy: 0.9705666899681091']
['val_loss: 0.10026872158050537', 'val_binary_accuracy: 0.972599983215332']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:32<00:00, 56.94it/s]


['loss: 0.10246839374303818', 'binary_accuracy: 0.9713000059127808']
['val_loss: 0.09545672684907913', 'val_binary_accuracy: 0.9690999984741211']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:22<00:00, 82.84it/s]


['loss: 0.10266315191984177', 'binary_accuracy: 0.9712833166122437']
['val_loss: 0.13631168007850647', 'val_binary_accuracy: 0.9742000102996826']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:23<00:00, 80.33it/s]


['loss: 0.10226284712553024', 'binary_accuracy: 0.9716333150863647']
['val_loss: 0.10437232255935669', 'val_binary_accuracy: 0.979200005531311']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:33<00:00, 55.80it/s]


['loss: 0.09953545778989792', 'binary_accuracy: 0.9729999899864197']
['val_loss: 0.10198183357715607', 'val_binary_accuracy: 0.9786999821662903']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:28<00:00, 65.33it/s]


['loss: 0.30581405758857727', 'binary_accuracy: 0.873283326625824']
['val_loss: 0.24555785953998566', 'val_binary_accuracy: 0.9049000144004822']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:33<00:00, 56.36it/s]


['loss: 0.23808272182941437', 'binary_accuracy: 0.9049999713897705']
['val_loss: 0.21688003838062286', 'val_binary_accuracy: 0.9146000146865845']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:28<00:00, 66.34it/s]


['loss: 0.2121051400899887', 'binary_accuracy: 0.9156000018119812']
['val_loss: 0.19754146039485931', 'val_binary_accuracy: 0.9226999878883362']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:27<00:00, 67.19it/s]


['loss: 0.203724667429924', 'binary_accuracy: 0.9194833040237427']
['val_loss: 0.18734537065029144', 'val_binary_accuracy: 0.9265999794006348']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:26<00:00, 71.55it/s]


['loss: 0.1919778287410736', 'binary_accuracy: 0.9246000051498413']
['val_loss: 0.18238277733325958', 'val_binary_accuracy: 0.9304999709129333']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:28<00:00, 64.79it/s]


['loss: 0.1823308765888214', 'binary_accuracy: 0.928600013256073']
['val_loss: 0.17032557725906372', 'val_binary_accuracy: 0.9351000189781189']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:25<00:00, 74.79it/s]


['loss: 0.17428793013095856', 'binary_accuracy: 0.9309333562850952']
['val_loss: 0.16884711384773254', 'val_binary_accuracy: 0.9330999851226807']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:25<00:00, 74.87it/s]


['loss: 0.17020629346370697', 'binary_accuracy: 0.932283341884613']
['val_loss: 0.16056907176971436', 'val_binary_accuracy: 0.9388999938964844']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:27<00:00, 67.84it/s]


['loss: 0.1593446433544159', 'binary_accuracy: 0.9375']
['val_loss: 0.1566997468471527', 'val_binary_accuracy: 0.9409999847412109']


100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:25<00:00, 73.02it/s]


['loss: 0.15623322129249573', 'binary_accuracy: 0.9389500021934509']
['val_loss: 0.15257935225963593', 'val_binary_accuracy: 0.9391999840736389']
