In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [30]:
def prepare_data(input_ds, subtask, batch_size=32, shuffle_buffer_size=1000):
    """Turn a dataset of single images into batches of image pairs.

    Every image is cast to float32, flattened to a vector and rescaled with
    ``img / 128 - 1``. Two independently shuffled views of the result are
    zipped, so each element combines two images with a label derived from
    their two targets.

    Args:
        input_ds: ``tf.data.Dataset`` yielding ``(image, target)`` tuples.
        subtask: ``1`` labels a pair with ``a + b >= 5``;
            ``2`` labels a pair with ``a - b``.
        batch_size: Number of pairs per batch.
        shuffle_buffer_size: Buffer size used by every shuffle step.

    Returns:
        A shuffled, batched and prefetched ``tf.data.Dataset`` yielding
        ``(img1, img2, label)`` tuples.

    Raises:
        ValueError: If ``subtask`` is neither 1 nor 2.
    """
    # Fail fast with a catchable error: calling exit() inside a notebook
    # shuts the kernel down instead of reporting the bad argument.
    if subtask not in (1, 2):
        raise ValueError(f"{subtask} is not a valid subtask!")

    def preprocess(img, target):
        # Single fused pass: cast -> flatten -> rescale.
        img = tf.cast(img, tf.float32)
        img = tf.reshape(img, (-1,))
        return (img / 128.) - 1.0, target

    if subtask == 1:
        # a + b >= 5 ?
        def make_label(a, b):
            return a + b >= 5
    else:
        # a - b = ?
        def make_label(a, b):
            return a - b

    def make_pair(x1, x2):
        # x1 and x2 are (image, target) tuples from the two shuffled views.
        return x1[0], x2[0], make_label(x1[1], x2[1])

    ds = input_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

    # Zipping two separately shuffled views pairs each image with another
    # image drawn in a different order.
    paired_ds = tf.data.Dataset.zip((ds.shuffle(shuffle_buffer_size),
                                     ds.shuffle(shuffle_buffer_size)))
    paired_ds = paired_ds.map(make_pair, num_parallel_calls=tf.data.AUTOTUNE)

    # Shuffle the pairs, batch them and overlap input work with training.
    paired_ds = paired_ds.shuffle(shuffle_buffer_size)
    paired_ds = paired_ds.batch(batch_size)
    paired_ds = paired_ds.prefetch(tf.data.AUTOTUNE)

    return paired_ds


In [None]:
def get_mnist_data(subtask):
    """Load MNIST and return its prepared train and test pipelines.

    Args:
        subtask: Subtask id forwarded unchanged to ``prepare_data``.

    Returns:
        A ``(train_ds, test_ds)`` tuple of datasets produced by
        ``prepare_data``.
    """
    raw_splits = tfds.load(
        "mnist",
        split=["train", "test"],
        as_supervised=True,
    )
    raw_train, raw_test = raw_splits

    prepared_train = prepare_data(raw_train, subtask)
    prepared_test = prepare_data(raw_test, subtask)
    return (prepared_train, prepared_test)

In [31]:
# Fetch the raw MNIST train/test splits as (image, target) tuples.
train_dataset, test_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    as_supervised=True,
)

# Build the image-pair pipelines, one train/test pair per subtask.
# Subtask 1: a + b >= 5 ?
train_ds_t1 = prepare_data(input_ds=train_dataset, subtask=1)
test_ds_t1 = prepare_data(input_ds=test_dataset, subtask=1)

# Subtask 2: a - b = ?
train_ds_t2 = prepare_data(input_ds=train_dataset, subtask=2)
test_ds_t2 = prepare_data(input_ds=test_dataset, subtask=2)



In [32]:
class MNISTMath(tf.keras.Model):
    """Shared-weight dense network operating on a pair of flattened images.

    Both images pass through the same two hidden layers; the two encodings
    are concatenated and mapped to a single output unit.

    Subtasks:
        1: binary decision ``a + b >= 5`` (sigmoid output, binary
           cross-entropy loss, binary accuracy metric).
        2: regression of the label built by ``prepare_data`` for subtask 2,
           i.e. ``a - b`` (linear output, mean squared error loss/metric).
    """

    def __init__(self, subtask, optimizer):
        """Build the layers, loss and metrics for the requested subtask.

        Args:
            subtask: ``1`` or ``2``, see the class docstring.
            optimizer: Optimizer whose ``apply_gradients`` is used in
                ``train_step``.

        Raises:
            ValueError: If ``subtask`` is neither 1 nor 2.
        """
        super().__init__()

        self.optimizer = optimizer

        self.dense1 = tf.keras.layers.Dense(32, activation="relu")
        self.dense2 = tf.keras.layers.Dense(32, activation="relu")

        # Running mean of the per-batch loss, reported next to the task metric.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

        # a + b >= 5 ?
        if subtask == 1:
            task_metric = tf.keras.metrics.BinaryAccuracy()
            self.loss_function = tf.keras.losses.BinaryCrossentropy()
            self.out_layer = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)
        # a - b = ?
        elif subtask == 2:
            task_metric = tf.keras.metrics.MeanSquaredError()
            self.loss_function = tf.keras.losses.MeanSquaredError()
            self.out_layer = tf.keras.layers.Dense(1, activation=None)
        else:
            # Raise instead of exit(): exit() inside a notebook shuts the
            # kernel down rather than reporting the bad argument.
            raise ValueError(f"{subtask} is not a valid subtask!")

        # The task metric stays at index 0; the loss tracker follows it.
        self.metrics_list = [task_metric, self.loss_tracker]

    def call(self, images):
        """Run the forward pass.

        Args:
            images: Pair ``(img1, img2)`` of batched, flattened images.

        Returns:
            Output of the final single-unit layer for each pair.
        """
        img1, img2 = images

        # The same dense1/dense2 instances process both images, so the two
        # branches share their weights.
        img1_x = self.dense2(self.dense1(img1))
        img2_x = self.dense2(self.dense1(img2))

        combined_x = tf.concat([img1_x, img2_x], axis=1)

        return self.out_layer(combined_x)

    @property
    def metrics(self):
        """Metrics reported by ``train_step`` / ``test_step``."""
        return self.metrics_list

    def reset_metrics(self):
        """Reset the accumulated state of every metric."""
        for metric in self.metrics:
            # `reset_state` is the current name; fall back to the older
            # `reset_states` on TensorFlow versions that lack it.
            reset = getattr(metric, "reset_state", None) or metric.reset_states
            reset()

    def _update_metrics(self, label, output, loss):
        """Feed one batch into the metrics and return their current values."""
        self.loss_tracker.update_state(loss)
        for metric in self.metrics:
            if metric is not self.loss_tracker:
                # Task metrics compare targets with predictions; they do not
                # accept a bare loss value.
                metric.update_state(label, output)

        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: Tuple ``(img1, img2, label)``.

        Returns:
            Dict mapping each metric name to its current result.
        """
        img1, img2, label = data

        with tf.GradientTape() as tape:
            output = self((img1, img2))
            loss = self.loss_function(label, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients,
                                           self.trainable_variables))

        return self._update_metrics(label, output, loss)

    def test_step(self, data):
        """Evaluate one batch without updating the weights.

        Args:
            data: Tuple ``(img1, img2, label)``.

        Returns:
            Dict mapping each metric name to its current result.
        """
        img1, img2, label = data
        output = self((img1, img2))
        # Keras losses take (y_true, y_pred), in that order.
        loss = self.loss_function(label, output)

        return self._update_metrics(label, output, loss)

In [None]:
def train_model(subtask, optimizer):
    """Create an ``MNISTMath`` model and load the prepared MNIST data for ``subtask``.

    NOTE(review): the visible part of this function ends after the data is
    loaded -- no training loop or return value is shown here. Confirm
    whether the rest of the body exists beyond this excerpt.
    """
    model = MNISTMath(subtask, optimizer)
    train_ds, test_ds = get_mnist_data(subtask)
