In [10]:
import tensorflow as tf
import tensorflow_datasets as tfds
import datetime

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [11]:
# Load the MNIST "train" and "test" splits in supervised form; the test split
# is used as the validation data in this notebook.
mnist = tfds.load("mnist", split=["train", "test"], as_supervised=True)
train_ds, val_ds = mnist

In [12]:
def preprocess(data, batch_size, condition):
    """Turn a dataset of (image, label) pairs into batched image-pair examples.

    Every image is cast to float, flattened to a 1-D vector and rescaled with
    ``x / 128 - 1`` (for 0-255 pixel inputs this is roughly [-1, 1), not
    [0, 1] -- assumes 8-bit pixel values, TODO confirm). Two independently
    shuffled copies of the dataset are zipped so that each example contains
    two images and a single target derived from their two labels.

    Args:
        data: dataset yielding (image, label) pairs.
        batch_size: number of examples per batch.
        condition: 0 selects the comparison task (target is 1 when the two
            labels sum to at least 5, otherwise 0, as int32). Any other value
            selects the subtraction task (target is the first label minus the
            second label, as float32).

    Returns:
        A batched and prefetched dataset of (image1, image2, target) triples.
    """

    def normalize(image, label):
        # cast to float -> flatten -> rescale with x / 128 - 1
        flat_image = tf.reshape(tf.cast(image, float), (-1,))
        return (flat_image / 128.) - 1, label

    def comparison_example(first, second):
        # boolean "label sum >= 5" converted to an int32 target
        is_large = first[1] + second[1] >= 5
        return first[0], second[0], tf.cast(is_large, tf.int32)

    def subtraction_example(first, second):
        # label difference converted to a float32 target
        difference = first[1] - second[1]
        return first[0], second[0], tf.cast(difference, tf.float32)

    normalized = data.map(normalize)

    # Pair up two separately shuffled views of the same data, so a single
    # element looks like ((x1, y1), (x2, y2)) before the target is built.
    pairs = tf.data.Dataset.zip((normalized.shuffle(2000),
                                 normalized.shuffle(2000)))

    build_example = comparison_example if condition == 0 else subtraction_example
    examples = pairs.map(build_example)

    return examples.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [13]:
class TwinMNISTModel(tf.keras.Model):
    """Two-image MNIST model with a shared encoder and custom train/test steps.

    Both input images are passed through the same two dense layers. The two
    encodings are concatenated, fed through a third dense layer and reduced to
    a single output unit.
    """

    # 1. constructor
    def __init__(self, condition, opt):
        """Create metrics, loss, optimizer and layers for the chosen task.

        Args:
            condition: 0 for the comparison task (sigmoid output, binary
                cross-entropy loss, binary accuracy metric). Any other value
                selects the subtraction task (linear output, mean squared
                error loss, root mean squared error metric).
            opt: optimizer selector: 0 = Adam, 1 = SGD (learning rate 0.01,
                momentum 0.03), 2 = RMSprop, 3 = Adagrad.

        Raises:
            ValueError: if ``opt`` is not one of 0, 1, 2 or 3.
        """
        # inherit functionality from parent class
        super().__init__()

        # task-specific pieces: metrics, loss function and output activation
        if condition == 0:
            self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                                 tf.keras.metrics.BinaryAccuracy()]
            self.loss_function = tf.keras.losses.BinaryCrossentropy()
            output_activation = tf.nn.sigmoid
        else:
            self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                                 tf.keras.metrics.RootMeanSquaredError()]
            self.loss_function = tf.keras.losses.MeanSquaredError()
            output_activation = None

        # optimizer
        if opt == 0:
            self.optimizer = tf.keras.optimizers.Adam()
        elif opt == 1:
            self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.03)
        elif opt == 2:
            self.optimizer = tf.keras.optimizers.RMSprop()
        elif opt == 3:
            self.optimizer = tf.keras.optimizers.Adagrad()
        else:
            # Fail here instead of leaving the optimizer unset, which would
            # only surface later inside train_step.
            raise ValueError(
                f"opt must be 0 (Adam), 1 (SGD), 2 (RMSprop) or 3 (Adagrad), got {opt!r}"
            )

        # layers to encode the images (both layers are applied to both images)
        self.dense1 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)

        # layer applied to the concatenated encodings
        self.dense3 = tf.keras.layers.Dense(128, activation=tf.nn.relu)

        # single output unit; its activation depends on the task
        self.out_layer = tf.keras.layers.Dense(1, activation=output_activation)

    # 2. call method (forward computation)
    def call(self, images, training=False):
        """Compute the model output for a pair of image batches.

        Args:
            images: a pair ``(img1, img2)`` of image batches.
            training: accepted for the Keras call signature; not used by
                this method.

        Returns:
            The output of the final single-unit dense layer.
        """
        img1, img2 = images

        # shared encoder: the same two layers process each image
        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)

        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)

        combined_x = tf.concat([img1_x, img2_x], axis=1)
        combined_x = self.dense3(combined_x)
        return self.out_layer(combined_x)

    # 3. metrics property
    @property
    def metrics(self):
        """Return the list of all metric objects of the model."""
        return self.metrics_list

    # 4. reset all metrics objects
    def reset_metrics(self):
        """Reset the accumulated state of every metric."""
        for metric in self.metrics:
            # Prefer `reset_state`; fall back to the older `reset_states`
            # spelling for TensorFlow versions that only provide that one.
            reset = getattr(metric, "reset_state", None) or metric.reset_states
            reset()

    # 5. train step method
    @tf.function
    def train_step(self, data):
        """Run one optimization step on a batch and update the metrics.

        Args:
            data: a triple ``(img1, img2, label)``.

        Returns:
            A dictionary mapping metric names to their current results.
        """
        img1, img2, label = data

        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(label, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # the first metric tracks the loss, the remaining ones compare
        # labels with model outputs
        self.metrics[0].update_state(loss)
        for metric in self.metrics[1:]:
            metric.update_state(label, output)

        return {m.name: m.result() for m in self.metrics}

    # 6. test_step method
    @tf.function
    def test_step(self, data):
        """Evaluate a batch and update the metrics without changing weights.

        Args:
            data: a triple ``(img1, img2, label)``.

        Returns:
            A dictionary mapping metric names to their current results.
        """
        img1, img2, label = data

        # same as train step (without parameter updates)
        output = self((img1, img2), training=False)
        loss = self.loss_function(label, output)

        self.metrics[0].update_state(loss)
        for metric in self.metrics[1:]:
            metric.update_state(label, output)

        return {m.name: m.result() for m in self.metrics}

In [14]:
def create_summary_writers(config_name):
    """Create the summary writers for one training run.

    Logs go to ``logs/<config_name>/<timestamp>/train`` and
    ``logs/<config_name>/<timestamp>/val``, where the timestamp is the current
    local time formatted as ``YYYYmmdd-HHMMSS``. Consider storing the
    hyperparameters (or a copy of the code) under the same name so that a run
    can be identified later.

    Args:
        config_name: name of the run; used as a sub-directory of ``logs``.

    Returns:
        A pair ``(train_summary_writer, val_summary_writer)``.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    # one log directory for training metrics, one for validation metrics
    log_paths = (
        f"logs/{config_name}/{timestamp}/train",
        f"logs/{config_name}/{timestamp}/val",
    )

    train_writer, val_writer = [
        tf.summary.create_file_writer(path) for path in log_paths
    ]
    return train_writer, val_writer


In [15]:
import tqdm
def _log_and_print_metrics(model, summary_writer, step, prefix):
    """Write every metric of ``model`` to ``summary_writer`` and print them."""
    with summary_writer.as_default():
        # for scalar metrics:
        for metric in model.metrics:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=step)
        # alternatively, log metrics individually (allows for non-scalar
        # metrics such as tf.keras.metrics.MeanTensor)

    # Read the values from the metric objects themselves rather than from the
    # dictionary returned by the last step: that dictionary does not exist
    # (or is stale) when the dataset yielded no batches.
    print([f"{prefix}{metric.name}: {metric.result().numpy()}"
           for metric in model.metrics])


def training_loop(model, train_ds, val_ds, start_epoch,
                  epochs, train_summary_writer,
                  val_summary_writer, save_path):
    """Train and validate ``model``, logging metrics once per epoch.

    Args:
        model: object providing ``train_step``, ``test_step``, ``metrics``,
            ``reset_metrics`` and ``save_weights``.
        train_ds: iterable of training batches passed to ``model.train_step``.
        val_ds: iterable of validation batches passed to ``model.test_step``.
        start_epoch: first epoch index (also used as the summary step).
        epochs: epoch index at which to stop (exclusive).
        train_summary_writer: summary writer for the training metrics.
        val_summary_writer: summary writer for the validation metrics.
        save_path: where to save the weights after the last epoch; nothing is
            saved when this is falsy.
    """
    # 1. iterate over epochs
    for e in range(start_epoch, epochs):

        # 2. train steps on all batches in the training data
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            model.train_step(data)

        # 3. log and print training metrics
        _log_and_print_metrics(model, train_summary_writer, step=e, prefix="")

        # 4. reset metric objects
        model.reset_metrics()

        # 5. evaluate on validation data
        for data in val_ds:
            model.test_step(data)

        # 6. log and print validation metrics
        _log_and_print_metrics(model, val_summary_writer, step=e, prefix="val_")

        # 7. reset metric objects
        model.reset_metrics()

    # 8. save model weights if save_path is given
    if save_path:
        model.save_weights(save_path)

In [16]:
# Open TensorBoard inside the notebook, reading the event files under the
# logs/ directory and serving on localhost:8086.
%tensorboard --logdir logs/ --host localhost --port 8086

In [17]:
# 1. Apply preprocessing to small subsets of the raw MNIST splits.
#    The subsets are taken from `mnist` directly (not from an already
#    preprocessed `train_ds` / `val_ds`), so this cell can be re-run safely.
train_ds = preprocess(mnist[0].take(1000), batch_size=32, condition=0)
val_ds = preprocess(mnist[1].take(100), batch_size=32, condition=0)


# 2. instantiate one comparison-task model per optimizer
#    (opt: 0 = Adam, 1 = SGD, 2 = RMSprop, 3 = Adagrad)
models = [TwinMNISTModel(condition=0, opt=opt) for opt in range(4)]

# 3. + 4. train every model, logging to and saving under its own run name
for n, model in enumerate(models, start=1):
    run_name = "RUN" + str(n)

    # A separate weights path per run; with one shared path each model would
    # overwrite the weights saved by the previous one.
    save_path = "trained_model_" + run_name

    train_summary_writer, val_summary_writer = create_summary_writers(config_name=run_name)

    training_loop(model=model,
        train_ds=train_ds,
        val_ds=val_ds,
        start_epoch=0,
        epochs=10,
        train_summary_writer=train_summary_writer,
        val_summary_writer=val_summary_writer,
        save_path=save_path)
    

100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.79it/s]


['loss: 0.3523547053337097', 'binary_accuracy: 0.8669999837875366']
['val_loss: 0.37033623456954956', 'val_binary_accuracy: 0.8799999952316284']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 34.59it/s]


['loss: 0.2745150625705719', 'binary_accuracy: 0.8870000243186951']
['val_loss: 0.1771755814552307', 'val_binary_accuracy: 0.9200000166893005']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 50.74it/s]


['loss: 0.26036474108695984', 'binary_accuracy: 0.8989999890327454']
['val_loss: 0.25127530097961426', 'val_binary_accuracy: 0.8899999856948853']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 38.79it/s]


['loss: 0.2510366439819336', 'binary_accuracy: 0.9010000228881836']
['val_loss: 0.2726595997810364', 'val_binary_accuracy: 0.9200000166893005']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 53.32it/s]


['loss: 0.19778679311275482', 'binary_accuracy: 0.9210000038146973']
['val_loss: 0.2737528383731842', 'val_binary_accuracy: 0.9200000166893005']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 44.51it/s]


['loss: 0.2013217806816101', 'binary_accuracy: 0.9129999876022339']
['val_loss: 0.35991454124450684', 'val_binary_accuracy: 0.8100000023841858']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 54.51it/s]


['loss: 0.19838932156562805', 'binary_accuracy: 0.9210000038146973']
['val_loss: 0.29532843828201294', 'val_binary_accuracy: 0.9200000166893005']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:01<00:00, 31.23it/s]


['loss: 0.16424421966075897', 'binary_accuracy: 0.9300000071525574']
['val_loss: 0.2415727823972702', 'val_binary_accuracy: 0.8899999856948853']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 42.71it/s]


['loss: 0.1848580539226532', 'binary_accuracy: 0.9300000071525574']
['val_loss: 0.29376235604286194', 'val_binary_accuracy: 0.9200000166893005']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 45.29it/s]


['loss: 0.1548178791999817', 'binary_accuracy: 0.9390000104904175']
['val_loss: 0.3371521830558777', 'val_binary_accuracy: 0.9300000071525574']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00, 10.55it/s]


['loss: 0.40790417790412903', 'binary_accuracy: 0.8610000014305115']
['val_loss: 0.285757839679718', 'val_binary_accuracy: 0.8899999856948853']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 52.93it/s]


['loss: 0.3887278437614441', 'binary_accuracy: 0.8460000157356262']
['val_loss: 0.2750319242477417', 'val_binary_accuracy: 0.8999999761581421']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 40.82it/s]


['loss: 0.34964555501937866', 'binary_accuracy: 0.8560000061988831']
['val_loss: 0.3317318558692932', 'val_binary_accuracy: 0.8799999952316284']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 50.73it/s]


['loss: 0.33543941378593445', 'binary_accuracy: 0.8600000143051147']
['val_loss: 0.24264562129974365', 'val_binary_accuracy: 0.9399999976158142']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 41.53it/s]


['loss: 0.3037593364715576', 'binary_accuracy: 0.871999979019165']
['val_loss: 0.31958991289138794', 'val_binary_accuracy: 0.9300000071525574']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:01<00:00, 29.27it/s]


['loss: 0.2888820171356201', 'binary_accuracy: 0.8809999823570251']
['val_loss: 0.29791438579559326', 'val_binary_accuracy: 0.8299999833106995']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 32.12it/s]


['loss: 0.28823089599609375', 'binary_accuracy: 0.8830000162124634']
['val_loss: 0.21346940100193024', 'val_binary_accuracy: 0.8899999856948853']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 42.59it/s]


['loss: 0.27639833092689514', 'binary_accuracy: 0.8889999985694885']
['val_loss: 0.21890036761760712', 'val_binary_accuracy: 0.9200000166893005']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 40.13it/s]


['loss: 0.2808763384819031', 'binary_accuracy: 0.8799999952316284']
['val_loss: 0.34969455003738403', 'val_binary_accuracy: 0.8999999761581421']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:01<00:00, 30.23it/s]


['loss: 0.2891523241996765', 'binary_accuracy: 0.890999972820282']
['val_loss: 0.19459415972232819', 'val_binary_accuracy: 0.9300000071525574']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:04<00:00,  6.78it/s]


['loss: 0.4313521981239319', 'binary_accuracy: 0.8360000252723694']
['val_loss: 0.3671383857727051', 'val_binary_accuracy: 0.8799999952316284']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 58.18it/s]


['loss: 0.31508034467697144', 'binary_accuracy: 0.8730000257492065']
['val_loss: 0.590692400932312', 'val_binary_accuracy: 0.9100000262260437']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 57.71it/s]


['loss: 0.29144811630249023', 'binary_accuracy: 0.8889999985694885']
['val_loss: 0.5881556272506714', 'val_binary_accuracy: 0.75']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 58.70it/s]


['loss: 0.25521159172058105', 'binary_accuracy: 0.8849999904632568']
['val_loss: 0.11423946917057037', 'val_binary_accuracy: 0.949999988079071']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 41.97it/s]


['loss: 0.22616419196128845', 'binary_accuracy: 0.9070000052452087']
['val_loss: 0.29532599449157715', 'val_binary_accuracy: 0.8600000143051147']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 46.39it/s]


['loss: 0.20906275510787964', 'binary_accuracy: 0.9179999828338623']
['val_loss: 0.248576819896698', 'val_binary_accuracy: 0.9200000166893005']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 55.09it/s]


['loss: 0.18091349303722382', 'binary_accuracy: 0.9259999990463257']
['val_loss: 0.11350036412477493', 'val_binary_accuracy: 0.9599999785423279']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 33.34it/s]


['loss: 0.18077410757541656', 'binary_accuracy: 0.9279999732971191']
['val_loss: 0.13531364500522614', 'val_binary_accuracy: 0.949999988079071']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 47.36it/s]


['loss: 0.1626872569322586', 'binary_accuracy: 0.9350000023841858']
['val_loss: 0.3111376166343689', 'val_binary_accuracy: 0.9300000071525574']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 49.59it/s]


['loss: 0.16903971135616302', 'binary_accuracy: 0.9300000071525574']
['val_loss: 0.35222935676574707', 'val_binary_accuracy: 0.8299999833106995']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:02<00:00, 11.18it/s]


['loss: 0.42421749234199524', 'binary_accuracy: 0.8539999723434448']
['val_loss: 0.3431069552898407', 'val_binary_accuracy: 0.8399999737739563']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 41.15it/s]


['loss: 0.3887685239315033', 'binary_accuracy: 0.8690000176429749']
['val_loss: 0.2721507251262665', 'val_binary_accuracy: 0.8899999856948853']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 45.15it/s]


['loss: 0.3846343755722046', 'binary_accuracy: 0.8560000061988831']
['val_loss: 0.2989509105682373', 'val_binary_accuracy: 0.8700000047683716']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 37.21it/s]


['loss: 0.38745370507240295', 'binary_accuracy: 0.8489999771118164']
['val_loss: 0.28702491521835327', 'val_binary_accuracy: 0.8899999856948853']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 55.68it/s]


['loss: 0.3592173457145691', 'binary_accuracy: 0.8629999756813049']
['val_loss: 0.3794255554676056', 'val_binary_accuracy: 0.8600000143051147']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 37.14it/s]


['loss: 0.3557065725326538', 'binary_accuracy: 0.8669999837875366']
['val_loss: 0.27296632528305054', 'val_binary_accuracy: 0.8999999761581421']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:01<00:00, 31.80it/s]


['loss: 0.36774033308029175', 'binary_accuracy: 0.843999981880188']
['val_loss: 0.28886207938194275', 'val_binary_accuracy: 0.8899999856948853']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 47.20it/s]


['loss: 0.3430318832397461', 'binary_accuracy: 0.8650000095367432']
['val_loss: 0.3405369520187378', 'val_binary_accuracy: 0.8999999761581421']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 36.34it/s]


['loss: 0.3295154571533203', 'binary_accuracy: 0.8700000047683716']
['val_loss: 0.24748769402503967', 'val_binary_accuracy: 0.8899999856948853']


100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 52.65it/s]


['loss: 0.34645217657089233', 'binary_accuracy: 0.8529999852180481']
['val_loss: 0.24118660390377045', 'val_binary_accuracy: 0.8999999761581421']
