# Flipped Classroom 04: Training a Twin Model on MNIST 

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
import datetime

# magic line only needed in jupyter notebooks!
%load_ext tensorboard 

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


## Creating a dataset

**Task: take two random mnist digit 28x28x1 images and compute the equality of the numbers that they display.**


Input to the model will be: batch-size times two randomly chosen mnist digits, flattened to vectors.

Output of the model (its targets) should be: for each element in the batch, 1 if the two mnist images have the same label, and 0 otherwise.

*Notice how this will lead to an imbalanced classification dataset. We can address this by setting a sample_weight argument in the loss calculation, weighting the contribution to gradients stronger for the rare class than for the common class. We have similar issues with class imbalance in many applied fields (medicine, finance, fraud-detection etc.)*

In [9]:
# 1. get mnist from tensorflow_datasets
mnist = tfds.load("mnist", split =["train","test"], as_supervised=True)
train_ds = mnist[0]
val_ds = mnist[1]

# 2. write function to create the dataset that we want
def preprocess(data, batch_size):
    # image should be float
    data = data.map(lambda x, t: (tf.cast(x, float), t))
    # image should be flattened
    data = data.map(lambda x, t: (tf.reshape(x, (-1,)), t))
    # image vector will here have values between -1 and 1
    data = data.map(lambda x,t: ((x/128.)-1., t))
    # we want to have two mnist images in each example
    # this leads to a single example being ((x1,y1),(x2,y2))
    zipped_ds = tf.data.Dataset.zip((data.shuffle(2000), 
                                     data.shuffle(2000)))
    # map ((x1,y1),(x2,y2)) to (x1,x2, y1==y2*) *boolean
    zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1]==x2[1]))
    # transform boolean target to int
    zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.cast(t, tf.int32)))
    # batch the dataset
    zipped_ds = zipped_ds.batch(batch_size)
    # prefetch
    zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)
    return zipped_ds

train_ds = preprocess(train_ds, batch_size=32) #train_ds.apply(preprocess)
val_ds = preprocess(val_ds, batch_size=32) #val_ds.apply(preprocess)

In [10]:
# check the contents of the dataset
for img1, img2, label in train_ds:
    print(img1.shape, img2.shape, label.shape)
    break

(32, 784) (32, 784) (32,)


2022-11-25 12:31:05.975626: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-11-25 12:31:05.978701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


## Create a model for the task

- The two images should be processed using the same parameters (sometimes called a siamese or twin model)

- What do we need?

1. Subclass from **tf.keras.Model**

2. Implement **constructor method** that contains 
    - a) metrics objects to keep track (tf.keras.metrics)
    - b) optimizer
    - c) loss function
    - d) the layers that are used for prediction, incl. output layer
3. A **call method** that applies the computations (makes use of the model's layers)
    - has a training flag argument to allow for different training and inference behavior
    - uses the same layers for both images and then e.g. concatenates the representations before feeding them to an output layer.
    
4. A **metrics property** that returns a list of all metrics objects in the model

5. A **reset_metrics** method that iterates through the metrics in the metrics property and resets their state.
    - This is called before after each epoch and between training and validation during an epoch (to not mix train score with val score)
    
6. A **train_step** method that takes a data tuple as input, containing one batch of e.g. inputs, targets. In this case it is (inputs_1, inputs_2, target)
    - compute prediction and loss within a gradient tape context
    - get gradients w.r.t. trainable variables from the tape
    - pass the gradients and the trainable variables to the optimizer
    - update the metrics' states using the loss, the prediction and the target
    - return a dictionary containing each metric's name as a key and the corresponding result state as its value. This allows to easily print, and it is needed for the more high level compile and fit methods.
    
    
7. A **test_step** method that does the same as the train_step but without gradient calculations or optimization. 
    - The training flag in the model's call method should be set to False.

In [11]:
class TwinMNISTModel(tf.keras.Model):

    # 1. constructor
    def __init__(self):
        super().__init__()
        # inherit functionality from parent class

        # optimizer, loss function and metrics
        self.metrics_list = [tf.keras.metrics.BinaryAccuracy(),
                             tf.keras.metrics.Mean(name="loss")]
        
        self.optimizer = tf.keras.optimizers.Adam()
        
        self.loss_function = tf.keras.losses.BinaryCrossentropy()

        # layers to encode the images (both layers used for both images)
        self.dense1 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        
        self.dense3 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        
        self.out_layer = tf.keras.layers.Dense(1,activation=tf.nn.sigmoid)
        
    # 2. call method (forward computation)
    def call(self, images, training=False):
        img1, img2 = images
        
        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)
        
        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)
        
        combined_x = tf.concat([img1_x, img2_x ], axis=1)
        combined_x = self.dense3(combined_x)
        return self.out_layer(combined_x)

    # 3. metrics property
    @property
    def metrics(self):
        return self.metrics_list
        # return a list with all metrics in the model

    # 4. reset all metrics objects
    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_states()

    # 5. train step method
    @tf.function
    def train_step(self, data):
        img1, img2, label = data
        
        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(label, output)
            
        gradients = tape.gradient(loss, self.trainable_variables)
        
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # update the state of the metrics according to loss
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        # return a dictionary with metric names as keys and metric results as values
        return {m.name : m.result() for m in self.metrics}

    # 6. test_step method
    @tf.function
    def test_step(self, data):
        img1, img2, label = data
        # same as train step (without parameter updates)
        output = self((img1, img2), training=False)
        loss = self.loss_function(label, output)
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        return {m.name : m.result() for m in self.metrics}

## Create a summary writer to log data

- use tf.summary.create_file_writer(log_path)

In [12]:
def create_summary_writers(config_name):
    
    # Define where to save the logs
    # along with this, you may want to save a config file with the same name so you know what the hyperparameters were used
    # alternatively make a copy of the code that is used for later reference
    
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    train_log_path = f"logs/{config_name}/{current_time}/train"
    val_log_path = f"logs/{config_name}/{current_time}/val"

    # log writer for training metrics
    train_summary_writer = tf.summary.create_file_writer(train_log_path)

    # log writer for validation metrics
    val_summary_writer = tf.summary.create_file_writer(val_log_path)
    
    return train_summary_writer, val_summary_writer

train_summary_writer, val_summary_writer = create_summary_writers(config_name="RUN1")

## Write a training loop function

Arguments: 
 - the model to train, 
 - the data to train on, 
 - the data to test on, 
 - what the first epoch is,
 - how many epochs to train, 
 - the train summary writer object to use for logging
 - the validation summary writer object to use for logging
 - a path to save trained model weights to

In [13]:
import tqdm
def training_loop(model, train_ds, val_ds, start_epoch,
                  epochs, train_summary_writer, 
                  val_summary_writer, save_path):

    # 1. iterate over epochs
    for e in range(start_epoch, epochs):

        # 2. train steps on all batches in the training data
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            metrics = model.train_step(data)

        # 3. log and print training metrics

        with train_summary_writer.as_default():
            # for scalar metrics:
            for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=e)
            # alternatively, log metrics individually (allows for non-scalar metrics such as tf.keras.metrics.MeanTensor)
            # e.g. tf.summary.image(name="mean_activation_layer3", data = metrics["mean_activation_layer3"],step=e)
        
        #print the metrics
        print([f"{key}: {value.numpy()}" for (key, value) in metrics.items()])
        
        # 4. reset metric objects
        model.reset_metrics()


        # 5. evaluate on validation data
        for data in val_ds:
            metrics = model.test_step(data)
        

        # 6. log validation metrics

        with val_summary_writer.as_default():
            # for scalar metrics:
            for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=e)
            # alternatively, log metrics individually (allows for non-scalar metrics such as tf.keras.metrics.MeanTensor)
            # e.g. tf.summary.image(name="mean_activation_layer3", data = metrics["mean_activation_layer3"],step=e)
            
        print([f"val_{key}: {value.numpy()}" for (key, value) in metrics.items()])
        # 7. reset metric objects
        model.reset_metrics()
        
    # 8. save model weights if save_path is given
    if save_path:
        model.save_weights(save_path)

## Use the training loop function to train the model

In [14]:
# open the tensorboard logs
%tensorboard --logdir logs/

In [15]:
# 1. instantiate model
model = TwinMNISTModel()

# 2. choose a path to save the weights
save_path = "trained_model_RUN1"

# 2. pass arguments to training loop function

training_loop(model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    start_epoch=0,
    epochs=10,
    train_summary_writer=train_summary_writer,
    val_summary_writer=val_summary_writer,
    save_path=save_path)

100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:05<00:00, 354.65it/s]


['binary_accuracy: 0.9060333371162415', 'loss: 0.26210615038871765']
['val_binary_accuracy: 0.9147999882698059', 'val_loss: 0.21181651949882507']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 476.34it/s]


['binary_accuracy: 0.9337499737739563', 'loss: 0.16407448053359985']
['val_binary_accuracy: 0.9488000273704529', 'val_loss: 0.14193446934223175']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 481.36it/s]


['binary_accuracy: 0.9487500190734863', 'loss: 0.13859908282756805']
['val_binary_accuracy: 0.9605000019073486', 'val_loss: 0.11177009344100952']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 488.37it/s]


['binary_accuracy: 0.960099995136261', 'loss: 0.11381417512893677']
['val_binary_accuracy: 0.9620000123977661', 'val_loss: 0.10693037509918213']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 480.44it/s]


['binary_accuracy: 0.9645333290100098', 'loss: 0.10216216742992401']
['val_binary_accuracy: 0.9664999842643738', 'val_loss: 0.09265995025634766']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 472.63it/s]


['binary_accuracy: 0.9659166932106018', 'loss: 0.09643787145614624']
['val_binary_accuracy: 0.9732999801635742', 'val_loss: 0.07639753073453903']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:05<00:00, 355.76it/s]


['binary_accuracy: 0.9688000082969666', 'loss: 0.090479277074337']
['val_binary_accuracy: 0.9628000259399414', 'val_loss: 0.10456037521362305']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:05<00:00, 357.71it/s]


['binary_accuracy: 0.9696666598320007', 'loss: 0.08797740936279297']
['val_binary_accuracy: 0.9688000082969666', 'val_loss: 0.08604127168655396']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:05<00:00, 348.25it/s]


['binary_accuracy: 0.9714499711990356', 'loss: 0.0812680646777153']
['val_binary_accuracy: 0.9722999930381775', 'val_loss: 0.07755609601736069']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:05<00:00, 362.66it/s]


['binary_accuracy: 0.9710833430290222', 'loss: 0.08124373108148575']
['val_binary_accuracy: 0.9736999869346619', 'val_loss: 0.07744208723306656']


In [16]:
# load the model:
fresh_model = TwinMNISTModel()

# build the model's parameters by calling it on input
for img1,img2,label in train_ds:
    fresh_model((img1,img2));
    break

# load the saved weights
model = fresh_model.load_weights(save_path)

# we could now continue training (will use a fresh optimizer, without the old optimizer state)

# (OPTIONAL) Bonus: Sample weights

Below you can find the same model where the loss is weighted by the probability of both classes occuring. The class "images are showing the same number" has a probability of 0.1, while the case in which the numbers are not the same has a probability of 0.9. We can weight the loss contribution (and hence the gradient contribution) by the probability of not occuring. For every element in a batch, we weight its contribution to the loss by 0.9 if this element is showing label 1 (images are the same), and by 0.1 if the label is 0. Since this class is much more frequent, but we want equal contribution to gradients across the classes.

To get these sample weights, we write a new method "get_sample_weights" and use tf.where to fill the sample weights tensor according to a boolean tensor obtained from the labels.

In [19]:
class TwinMNISTModelSW(tf.keras.Model):

    # 1. constructor
    def __init__(self):
        super().__init__()
        # inherit functionality from parent class

        # optimizer, loss function and metrics
        self.metrics_list = [tf.keras.metrics.BinaryAccuracy(),
                             tf.keras.metrics.Mean(name="loss")]
        
        self.optimizer = tf.keras.optimizers.Adam()
        
        self.loss_function = tf.keras.losses.BinaryCrossentropy()

        # layers to encode the images (both layers used for both images)
        self.dense1 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        
        self.dense3 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        
        self.out_layer = tf.keras.layers.Dense(1,activation=tf.nn.sigmoid)
        
    # 2. call method (forward computation)
    def call(self, images, training=False):
        img1, img2 = images
        
        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)
        
        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)
        
        combined_x = tf.concat([img1_x, img2_x ], axis=1)
        combined_x = self.dense3(combined_x)
        return self.out_layer(combined_x)

    # 3. metrics property
    @property
    def metrics(self):
        return self.metrics_list
        # return a list with all metrics in the model

    # 4. reset all metrics objects
    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_states()
            
    def get_sample_weights(self, label):
        """
        Tensor of shape (1,32), containing 0.9 where the label is 1 and 0.1 where the label is 0
        """
        return tf.reshape(tf.where(tf.cast(label,tf.bool), 0.9, 0.1), (1,-1))

    # 5. train step method
    @tf.function
    def train_step(self, data):
        img1, img2, label = data
        
        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            binary_sample_weights = self.get_sample_weights(label)
            loss = self.loss_function(label, output, sample_weight=binary_sample_weights)
            
        gradients = tape.gradient(loss, self.trainable_variables)
        
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # update the state of the metrics according to loss
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        # return a dictionary with metric names as keys and metric results as values
        return {m.name : m.result() for m in self.metrics}

    # 6. test_step method
    @tf.function
    def test_step(self, data):
        img1, img2, label = data
        # same as train step (without parameter updates)
        output = self((img1, img2), training=False)
        binary_sample_weights = self.get_sample_weights(label)
        loss = self.loss_function(label, output, sample_weight=binary_sample_weights)
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        return {m.name : m.result() for m in self.metrics}

In [20]:
# create new summary_writers
train_summary_writer, val_summary_writer = create_summary_writers(config_name="RUN1_with_sample_weights")

# 1. instantiate model
model = TwinMNISTModelSW()

# 2. choose a path to save the weights
save_path = "trained_model_RUN1_sample_weights"

# 2. pass arguments to training loop function

training_loop(model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    start_epoch=0,
    epochs=10,
    train_summary_writer=train_summary_writer,
    val_summary_writer=val_summary_writer,
    save_path=save_path)

100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:04<00:00, 390.88it/s]


['binary_accuracy: 0.9121500253677368', 'loss: 0.04621931165456772']
['val_binary_accuracy: 0.9301999807357788', 'val_loss: 0.03250011056661606']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:04<00:00, 466.90it/s]


['binary_accuracy: 0.9449999928474426', 'loss: 0.02709951251745224']
['val_binary_accuracy: 0.9544000029563904', 'val_loss: 0.023736122995615005']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:04<00:00, 463.73it/s]


['binary_accuracy: 0.9616166949272156', 'loss: 0.021192438900470734']
['val_binary_accuracy: 0.9593999981880188', 'val_loss: 0.02040999010205269']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 470.75it/s]


['binary_accuracy: 0.9663666486740112', 'loss: 0.01814686320722103']
['val_binary_accuracy: 0.9699000120162964', 'val_loss: 0.017676211893558502']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:04<00:00, 468.09it/s]


['binary_accuracy: 0.9714333415031433', 'loss: 0.015792982652783394']
['val_binary_accuracy: 0.9753999710083008', 'val_loss: 0.014733675867319107']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:04<00:00, 466.47it/s]


['binary_accuracy: 0.9740833044052124', 'loss: 0.01481231115758419']
['val_binary_accuracy: 0.9757999777793884', 'val_loss: 0.013603908009827137']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:04<00:00, 461.19it/s]


['binary_accuracy: 0.9753000140190125', 'loss: 0.0138473529368639']
['val_binary_accuracy: 0.9776999950408936', 'val_loss: 0.012952522374689579']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 471.00it/s]


['binary_accuracy: 0.9769999980926514', 'loss: 0.013092205859720707']
['val_binary_accuracy: 0.9812999963760376', 'val_loss: 0.011002237908542156']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:05<00:00, 332.39it/s]


['binary_accuracy: 0.9783166646957397', 'loss: 0.011921226978302002']
['val_binary_accuracy: 0.974399983882904', 'val_loss: 0.013268744572997093']


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:04<00:00, 444.00it/s]


['binary_accuracy: 0.9800166487693787', 'loss: 0.011332252994179726']
['val_binary_accuracy: 0.9793000221252441', 'val_loss: 0.011616488918662071']
