In [104]:
import tensorflow as tf
import tensorflow_datasets as tfds
import datetime

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [105]:
# 1. get mnist from tensorflow_datasets
mnist = tfds.load("mnist", split =["train","test"], as_supervised=True)
train_ds = mnist[0]
val_ds = mnist[1]

In [106]:
# 2. write function to create the dataset that we want, 
#condition==0 being the comparison, and condition == 1 being the subtraction
def preprocess(data, batch_size, condition):
    
    
    # image should be float
    data = data.map(lambda x, t: (tf.cast(x, float), t))
    # image should be flattened
    data = data.map(lambda x, t: (tf.reshape(x, (-1,)), t))
    # image vector will here have values between -1 and 1
    data = data.map(lambda x,t: ((x/255.), t))
    
    
    # we want to have two mnist images in each example
    # this leads to a single example being ((x1,y1),(x2,y2))
    zipped_ds = tf.data.Dataset.zip((data.shuffle(2000), 
                                     data.shuffle(2000)))
    
    if condition == 0:
        # use maping to change the target
        zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1]+x2[1] >= 5))
        # transform boolean target to int
        zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.cast(t, tf.int32)))
    else:
        # use maping to change the target
        zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], tf.one_hot(x1[1]-x2[1], 19, dtype=tf.float32)))
    
    # batch the dataset
    zipped_ds = zipped_ds.batch(batch_size)
    # prefetch
    zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)
    return zipped_ds

In [107]:
class TwinMNISTModel(tf.keras.Model):

    # 1. constructor
    def __init__(self, mode):
        super().__init__()
        # inherit functionality from parent class

        # optimizer, loss function and metrics
        if mode == 0:
            self.metrics_list = [tf.keras.metrics.BinaryAccuracy(),
                                 tf.keras.metrics.Mean(name="loss")]
        else:
            self.metrics_list = [tf.keras.metrics.CategoricalAccuracy(),
                                 tf.keras.metrics.Mean(name="loss"),
                                 tf.keras.metrics.TopKCategoricalAccuracy(3)]
        
        if mode == 1:
            self.optimizer = tf.keras.optimizers.Adam()
        else: 
            self.optimizer = tf.keras.optimizers.SGD(0.01, 0.03)
        
        if mode == 0:
            self.loss_function = tf.keras.losses.BinaryCrossentropy()
        else:
            self.loss_function = tf.keras.losses.CategoricalCrossentropy(from_logits = True)
            
        # layers to encode the images (both layers used for both images)
        self.dense1 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        
        self.dense3 = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        
        if mode == 0:
            self.out_layer = tf.keras.layers.Dense(1,activation=tf.nn.sigmoid)
        else:
            self.out_layer = tf.keras.layers.Dense(19,activation=None)
        
    # 2. call method (forward computation)
    def call(self, images, training=False):
        img1, img2 = images
        
        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)
        
        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)
        
        combined_x = tf.concat([img1_x, img2_x ], axis=1)
        combined_x = self.dense3(combined_x)
        return self.out_layer(combined_x)

    # 3. metrics property
    @property
    def metrics(self):
        return self.metrics_list
        # return a list with all metrics in the model

    # 4. reset all metrics objects
    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_states()

    # 5. train step method
    @tf.function
    def train_step(self, data):
        img1, img2, label = data
        
        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(label, output)
            
        gradients = tape.gradient(loss, self.trainable_variables)
        
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # update the state of the metrics according to loss
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        # return a dictionary with metric names as keys and metric results as values
        return {m.name : m.result() for m in self.metrics}

    # 6. test_step method
    @tf.function
    def test_step(self, data):
        img1, img2, label = data
        # same as train step (without parameter updates)
        output = self((img1, img2), training=False)
        loss = self.loss_function(label, output)
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        return {m.name : m.result() for m in self.metrics}

In [108]:
def create_summary_writers(config_name):
    
    # Define where to save the logs
    # along with this, you may want to save a config file with the same name so you know what the hyperparameters were used
    # alternatively make a copy of the code that is used for later reference
    
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    train_log_path = f"logs/{config_name}/{current_time}/train"
    val_log_path = f"logs/{config_name}/{current_time}/val"

    # log writer for training metrics
    train_summary_writer = tf.summary.create_file_writer(train_log_path)

    # log writer for validation metrics
    val_summary_writer = tf.summary.create_file_writer(val_log_path)
    
    return train_summary_writer, val_summary_writer

train_summary_writer, val_summary_writer = create_summary_writers(config_name="RUN1")

In [109]:
import tqdm
def training_loop(model, train_ds, val_ds, start_epoch,
                  epochs, train_summary_writer, 
                  val_summary_writer, save_path):

    # 1. iterate over epochs
    for e in range(start_epoch, epochs):

        # 2. train steps on all batches in the training data
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            metrics = model.train_step(data)

        # 3. log and print training metrics

        with train_summary_writer.as_default():
            # for scalar metrics:
            for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=e)
            # alternatively, log metrics individually (allows for non-scalar metrics such as tf.keras.metrics.MeanTensor)
            # e.g. tf.summary.image(name="mean_activation_layer3", data = metrics["mean_activation_layer3"],step=e)
        
        #print the metrics
        print([f"{key}: {value.numpy()}" for (key, value) in metrics.items()])
        
        # 4. reset metric objects
        model.reset_metrics()


        # 5. evaluate on validation data
        for data in val_ds:
            metrics = model.test_step(data)
        

        # 6. log validation metrics

        with val_summary_writer.as_default():
            # for scalar metrics:
            for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=e)
            # alternatively, log metrics individually (allows for non-scalar metrics such as tf.keras.metrics.MeanTensor)
            # e.g. tf.summary.image(name="mean_activation_layer3", data = metrics["mean_activation_layer3"],step=e)
            
        print([f"val_{key}: {value.numpy()}" for (key, value) in metrics.items()])
        # 7. reset metric objects
        model.reset_metrics()
        
    # 8. save model weights if save_path is given
    if save_path:
        model.save_weights(save_path)

In [110]:
# open the tensorboard logs
%tensorboard --logdir logs/

Reusing TensorBoard on port 6006 (pid 16516), started 1:20:08 ago. (Use '!kill 16516' to kill it.)

In [111]:
train_ds = preprocess(train_ds, batch_size=32, condition = 0) 
val_ds= preprocess(val_ds, batch_size=32, condition = 0) 


# 2. instantiate model
model = TwinMNISTModel(mode = 0)

# 3. choose a path to save the weights
save_path = "trained_model_RUN1"

# 4. pass arguments to training loop function

training_loop(model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    start_epoch=0,
    epochs=10,
    train_summary_writer=train_summary_writer,
    val_summary_writer=val_summary_writer,
    save_path=save_path)

100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:20<00:00, 92.50it/s]


['binary_accuracy: 0.8902166485786438', 'loss: 0.27418258786201477']
['val_binary_accuracy: 0.9200999736785889', 'val_loss: 0.2031259685754776']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 140.78it/s]


['binary_accuracy: 0.9273166656494141', 'loss: 0.185628741979599']
['val_binary_accuracy: 0.9406999945640564', 'val_loss: 0.15293492376804352']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:14<00:00, 127.66it/s]


['binary_accuracy: 0.9417166709899902', 'loss: 0.15113359689712524']
['val_binary_accuracy: 0.9472000002861023', 'val_loss: 0.13581843674182892']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:16<00:00, 110.65it/s]


['binary_accuracy: 0.9502666592597961', 'loss: 0.12889443337917328']
['val_binary_accuracy: 0.9555000066757202', 'val_loss: 0.11677730083465576']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 140.24it/s]


['binary_accuracy: 0.9553833603858948', 'loss: 0.11709773540496826']
['val_binary_accuracy: 0.958299994468689', 'val_loss: 0.10855156183242798']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:11<00:00, 158.29it/s]


['binary_accuracy: 0.9623666405677795', 'loss: 0.1031699851155281']
['val_binary_accuracy: 0.9656000137329102', 'val_loss: 0.0964994728565216']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:11<00:00, 170.40it/s]


['binary_accuracy: 0.9666833281517029', 'loss: 0.09238152205944061']
['val_binary_accuracy: 0.9634000062942505', 'val_loss: 0.09956786036491394']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:11<00:00, 157.04it/s]


['binary_accuracy: 0.9695000052452087', 'loss: 0.0835886299610138']
['val_binary_accuracy: 0.9692999720573425', 'val_loss: 0.08595817536115646']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:12<00:00, 153.14it/s]


['binary_accuracy: 0.9733499884605408', 'loss: 0.0762295052409172']
['val_binary_accuracy: 0.9689000248908997', 'val_loss: 0.08126877248287201']


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:11<00:00, 168.48it/s]


['binary_accuracy: 0.9745333194732666', 'loss: 0.0725889801979065']
['val_binary_accuracy: 0.9761999845504761', 'val_loss: 0.07005023956298828']


In [112]:
# load the model:
fresh_model = TwinMNISTModel(1)

# build the model's parameters by calling it on input
for img1,img2,label in train_ds:
    fresh_model((img1,img2));
    break

# load the saved weights
model = fresh_model.load_weights(save_path)

# we could now continue training (will use a fresh optimizer, without the old optimizer state)

























ValueError: Received incompatible tensor with shape (1,) when attempting to restore variable with shape (19,) and name out_layer/bias/.ATTRIBUTES/VARIABLE_VALUE.