In [1]:


import tensorflow as tf
import tensorflow_datasets as tfds
import datetime

%load_ext tensorboard



In [2]:
# 1. get mnist from tensorflow_datasets
# Requesting the splits as a list gives one dataset per split, in the same
# order; as_supervised=True makes every element an (image, label) pair.
mnist = tfds.load("mnist", split=["train", "test"], as_supervised=True)
train_ds, val_ds = mnist

# 2. write function to create the dataset that we want
def preprocess(data, batch_size):
    """Turn an (image, label) dataset into batched image-pair examples.

    Args:
        data: dataset yielding (image, label) tuples.
        batch_size: number of pair examples per batch.

    Returns:
        A batched, prefetched dataset whose elements are
        (image_a, image_b, same_label), where same_label is an int32 tensor
        that is 1 where both labels are equal and 0 otherwise.
    """

    def _prepare_image(image, target):
        # cast to float, flatten to a vector, then rescale with x/128 - 1
        pixels = tf.reshape(tf.cast(image, float), (-1,))
        return (pixels / 128.) - 1., target

    def _pair_up(first, second):
        # first and second are (image, label) tuples, one from each of the
        # two shuffled streams; the boolean comparison becomes the int target
        same_label = first[1] == second[1]
        return first[0], second[0], tf.cast(same_label, tf.int32)

    prepared = data.map(_prepare_image)

    # two independently shuffled views of the same data form the pairs,
    # so a single example is ((x1, y1), (x2, y2)) before _pair_up
    pairs = tf.data.Dataset.zip((prepared.shuffle(2000),
                                 prepared.shuffle(2000)))
    pairs = pairs.map(_pair_up)

    # batch, then prefetch
    return pairs.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Replace the raw splits with their batched image-pair versions.
train_ds, val_ds = (preprocess(split, batch_size=32)
                    for split in (train_ds, val_ds))

2022-11-23 14:23:10.814281: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-23 14:23:10.816648: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [3]:
# Sanity check: pull a single batch and print the shape of each component
# (first image, second image, label).
for img1, img2, label in train_ds.take(1):
    print(*(tensor.shape for tensor in (img1, img2, label)))

(32, 784) (32, 784) (32,)


2022-11-23 14:23:22.440064: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-11-23 14:23:22.442298: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [None]:
class TwinMNISTModel(tf.keras.Model):
    """Twin network that predicts whether two images show the same digit.

    Both images are passed through the same two dense layers (shared
    weights). The two resulting embeddings are concatenated and a single
    sigmoid unit outputs the probability that the labels are equal.
    """

    # 1. constructor
    def __init__(self):
        # inherit functionality from parent class
        super().__init__()

        # optimizer, loss function and metrics
        # NOTE: train_step/test_step address the metrics by index, so the
        # order (accuracy first, mean loss second) must not change.
        self.metrics_list = [tf.keras.metrics.BinaryAccuracy(),
                             tf.keras.metrics.Mean(name="loss")]

        self.optimizer = tf.keras.optimizers.Adam()

        self.loss_function = tf.keras.losses.BinaryCrossentropy()

        # layers to be used (dense1/dense2 are applied to both images)
        self.dense1 = tf.keras.layers.Dense(32, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(32, activation=tf.nn.relu)

        # the module is `tf.keras.layers`; `tf.keras.layer` does not exist
        self.out_layer = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)

    # 2. call method (forward computation)
    def call(self, images, training=False):
        """Compute the same-digit probability for a pair of image batches.

        Args:
            images: tuple (img1, img2) of image batches.
            training: accepted for Keras compatibility; not used by the layers.

        Returns:
            Output of the sigmoid layer, one value per example.
        """
        img1, img2 = images

        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)

        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)

        combined_x = tf.concat([img1_x, img2_x], axis=1)

        return self.out_layer(combined_x)

    # 3. metrics property
    @property
    def metrics(self):
        """Return a list with all metrics in the model."""
        return self.metrics_list

    # 4. reset all metrics objects
    def reset_metrics(self):
        """Reset the accumulated state of every metric."""
        for metric in self.metrics:
            # reset_state() is the current name; reset_states() is only kept
            # by Keras as a backwards-compatibility alias
            metric.reset_state()

    # 5. train step method
    def train_step(self, data):
        """Run one optimisation step on a batch and update the metrics.

        Args:
            data: tuple (img1, img2, label).

        Returns:
            Dict mapping each metric name to its current result.
        """
        img1, img2, label = data

        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(label, output)

        gradients = tape.gradient(loss, self.trainable_variables)

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # update the state of the metrics according to loss
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)

        # return a dictionary with metric names as keys and metric results as values
        return {m.name: m.result() for m in self.metrics}

    # 6. test_step method
    def test_step(self, data):
        """Evaluate one batch without updating any weights.

        Args:
            data: tuple (img_1, img_2, label).

        Returns:
            Dict mapping "val_" + metric name to the metric's current result.
        """
        img_1, img_2, label = data
        # same as train step (without parameter updates)
        output = self((img_1, img_2), training=False)
        loss = self.loss_function(label, output)

        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)

        # prefix the keys so validation values are distinguishable from the
        # training ones that train_step returns
        return {"val_" + m.name: m.result() for m in self.metrics}
        

In [None]:
# One timestamped directory per execution, so TensorBoard shows each execution
# as its own run instead of mixing event files from different executions.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Separate sub-directories keep the training and validation curves apart;
# both live below logs/, which is the directory TensorBoard is pointed at.
train_log_path = f"logs/{current_time}/train"

val_log_path = f"logs/{current_time}/val"

train_summary_writer = tf.summary.create_file_writer(train_log_path)

val_summary_writer = tf.summary.create_file_writer(val_log_path)

In [None]:
def training_loop(model, train_ds, val_ds,
                  epochs, trains_summary_writer,
                  val_summary_writer, save_path):

    # 1. iterate over epochs
    for e in range(epochs):
        
    
        # 2. train steps on all batches in the training data
        for data in tqdm.tqdm(train_ds): #tqdm optional
            metrics = model.train_step(data)


        # 3. log and print training metrics

        with train_summary_writer.as_default():
            tf.summary.scalar(name="binary_accuracy",data=metrics["binary_accuracy"], step=e)
            tf.summary.scalar(name="loss",data=metrics["loss"], step=e)
            
        print(metrics.items())


        # 4. reset metric objects
        model.reset_metrics()



        # 5. evaluate on validation data
        for data in val_ds:
            metrics = model.test_step(data)


        # 6. log validation metrics

        with val_summary_writer.as_default():
            tf.summary.scalar(name="val_binary_accuracy",data=metrics["val_binary_accuracy"], step=e)
            tf.summary.scalar(name="val_loss",data=metrics["val_loss"], step=e)
        
        print(metrics.items())

        # 7. reset metric objects
        model.reset_metrics()

    # 8. save model weights
    model.save_weights(save_path)

In [None]:


# Open TensorBoard inside the notebook, reading the event files below logs/.
# The %tensorboard magic is provided by the extension loaded with
# `%load_ext tensorboard` in the first cell.
%tensorboard --logdir logs/



In [None]:
# 1. instantiate model
model = TwinMNISTModel()

# 2. choose a path to save the weights

save_path = "trained_model"

# 3. pass arguments to training loop function
# The leading arguments are positional and follow training_loop's parameter
# order: model, training data, validation data, epochs, training writer.

training_loop(model,
              train_ds,
              val_ds,
              5,  # epochs
              train_summary_writer,
              val_summary_writer=val_summary_writer,
              save_path=save_path)