In [11]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm

In [12]:
def prepare_data(input_ds, subtask):
    # First Step
    # change datatype to float32
    ds = input_ds.map(lambda img, target: (tf.cast(img, tf.float32), target))
    # flatten images
    ds = ds.map(lambda img, target: (tf.reshape(img, (-1,)), target))
    # normalize image values
    ds = ds.map(lambda img, target: ((img/128.)-1.0, target))

    # Second Step
    zipped_ds = tf.data.Dataset.zip((ds.shuffle(1000), ds.shuffle(1000)))
    # a + b >= 5 ?
    if subtask == 1:
        zipped_ds = zipped_ds.map(lambda x1, x2:
                                 (x1[0], x2[0], x1[1] + x2[1] >= 5))
    # a - b = ?
    elif subtask == 2:
        zipped_ds = zipped_ds.map(lambda x1, x2:
                                 (x1[0], x2[0], x1[1] - x2[1]))
    else:
        print(f"{subtask} is not a valid subtask!")
        exit(-1)

    # Third Step
    # shuffle the data
    zipped_ds = zipped_ds.shuffle(1000)
    # create batches
    zipped_ds = zipped_ds.batch(32)
    # prefetch
    zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)

    return zipped_ds


In [13]:
def get_mnist_data(subtask):
    train_dataset, test_dataset = tfds.load("mnist", split=["train", "test"],
                                        as_supervised=True)

    train_ds = prepare_data(train_dataset, subtask)
    test_ds = prepare_data(test_dataset, subtask)

    return (train_ds, test_ds)

In [14]:
# load dataset
train_dataset, test_dataset = tfds.load("mnist", split=["train", "test"],
                                        as_supervised=True)

# transform dataset
# subtask 1
train_ds_t1 = prepare_data(train_dataset, 1)
test_ds_t1 = prepare_data(test_dataset, 1)

# subtask 2
train_ds_t2 = prepare_data(train_dataset, 2)
test_ds_t2 = prepare_data(test_dataset, 2)



In [15]:
class MNISTMath(tf.keras.Model):
    def __init__(self, subtask, optimizer):
        super().__init__()
        self.optimizer = optimizer

        self.dense1 = tf.keras.layers.Dense(32, activation="relu")
        self.dense2 = tf.keras.layers.Dense(32, activation="relu")

        # a + b >= 5 ?
        if subtask == 1:
            self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                                 tf.keras.metrics.BinaryAccuracy()]
            self.loss_function = tf.keras.losses.BinaryCrossentropy()
            self.out_layer = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)
        # a + b = ?
        elif subtask == 2:
            self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                                 tf.keras.metrics.MeanSquaredError()]
            self.loss_function = tf.keras.losses.MeanSquaredError()
            self.out_layer = tf.keras.layers.Dense(1, activation=None)
        else:
            print(f"{subtask} is not a valid subtask!")
            exit(-1)



    def call(self, images):
        img1, img2 = images

        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)

        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)

        combined_x = tf.concat([img1_x, img2_x], axis=1)

        return self.out_layer(combined_x)

    
    @property
    def metrics(self):
        return self.metrics_list

    
    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_states()

    
    def train_step(self, data):
        img1, img2, label = data

        with tf.GradientTape() as tape:
            output = self((img1, img2))
            loss = self.loss_function(label, output)
        
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, 
                                           self.trainable_variables))

        self.metrics[0].update_state(loss)
        self.metrics[1].update_state(output, label)

        return {m.name: m.result() for m in self.metrics}

    
    def test_step(self, data):
        img1, img2, label = data
        output = self((img1, img2))
        loss = self.loss_function(label, output)

        self.metrics[0].update_state(loss)
        self.metrics[1].update_state(output, label)

        return {f"val_{m.name}": m.result() for m in self.metrics}

In [16]:
def train_model(model, subtask):
    train_ds, test_ds = get_mnist_data(subtask)

    for e in range(10):
        for data in tqdm.tqdm(train_ds):
            metrics = model.train_step(data)

        for name, value in metrics.items():
            print(f"Epoch {e}: {name}: {value}")
        model.reset_metrics()

        for data in test_ds:
            metrics = model.test_step(data)
        
        for name, value in metrics.items():
            print(f"Epoch {e}: {name}: {value}")
        model.reset_metrics()        



In [10]:
optimizer = tf.keras.optimizers.Adam()
# optimizer = tf.keras.optimizers.SGD()

model1 = MNISTMath(1, optimizer)
model2 = MNISTMath(2, optimizer)

train_model(model1, 1)
train_model(model2, 2)

100%|██████████| 1875/1875 [00:29<00:00, 63.43it/s]


Epoch 0: loss: 0.2126404494047165
Epoch 0: binary_accuracy: 0.0
Epoch 0: val_loss: 0.15290650725364685
Epoch 0: val_binary_accuracy: 0.0


100%|██████████| 1875/1875 [00:27<00:00, 68.42it/s]


Epoch 1: loss: 0.1497514545917511
Epoch 1: binary_accuracy: 0.0013000000035390258
Epoch 1: val_loss: 0.1421184539794922
Epoch 1: val_binary_accuracy: 0.006500000134110451


100%|██████████| 1875/1875 [00:25<00:00, 73.57it/s]


Epoch 2: loss: 0.13071174919605255
Epoch 2: binary_accuracy: 0.013383333571255207
Epoch 2: val_loss: 0.1119108721613884
Epoch 2: val_binary_accuracy: 0.018799999728798866


100%|██████████| 1875/1875 [00:27<00:00, 68.02it/s]


Epoch 3: loss: 0.1191619336605072
Epoch 3: binary_accuracy: 0.04659999907016754
Epoch 3: val_loss: 0.12256928533315659
Epoch 3: val_binary_accuracy: 0.039000000804662704


100%|██████████| 1875/1875 [00:26<00:00, 71.89it/s]


Epoch 4: loss: 0.11309373378753662
Epoch 4: binary_accuracy: 0.07916666567325592
Epoch 4: val_loss: 0.10953357070684433
Epoch 4: val_binary_accuracy: 0.0674000009894371


100%|██████████| 1875/1875 [00:25<00:00, 74.05it/s]


Epoch 5: loss: 0.10686416923999786
Epoch 5: binary_accuracy: 0.10268333554267883
Epoch 5: val_loss: 0.11979203671216965
Epoch 5: val_binary_accuracy: 0.11469999700784683


100%|██████████| 1875/1875 [00:25<00:00, 74.41it/s]


Epoch 6: loss: 0.10306323319673538
Epoch 6: binary_accuracy: 0.13696666061878204
Epoch 6: val_loss: 0.10645131021738052
Epoch 6: val_binary_accuracy: 0.14800000190734863


100%|██████████| 1875/1875 [00:25<00:00, 73.57it/s]


Epoch 7: loss: 0.1007736548781395
Epoch 7: binary_accuracy: 0.14678333699703217
Epoch 7: val_loss: 0.10923969745635986
Epoch 7: val_binary_accuracy: 0.10199999809265137


100%|██████████| 1875/1875 [00:28<00:00, 65.91it/s]


Epoch 8: loss: 0.09834270179271698
Epoch 8: binary_accuracy: 0.16113333404064178
Epoch 8: val_loss: 0.12087224423885345
Epoch 8: val_binary_accuracy: 0.2029000073671341


100%|██████████| 1875/1875 [00:25<00:00, 74.60it/s]


Epoch 9: loss: 0.09840365499258041
Epoch 9: binary_accuracy: 0.16553333401679993
Epoch 9: val_loss: 0.09328638762235641
Epoch 9: val_binary_accuracy: 0.18960000574588776


100%|██████████| 1875/1875 [00:25<00:00, 74.69it/s]


Epoch 10: loss: 0.09403600543737411
Epoch 10: binary_accuracy: 0.20021666586399078
Epoch 10: val_loss: 0.10267162322998047
Epoch 10: val_binary_accuracy: 0.23090000450611115


100%|██████████| 1875/1875 [00:25<00:00, 72.89it/s]


Epoch 11: loss: 0.09114468097686768
Epoch 11: binary_accuracy: 0.21141666173934937
Epoch 11: val_loss: 0.09608200192451477
Epoch 11: val_binary_accuracy: 0.23309999704360962


100%|██████████| 1875/1875 [00:25<00:00, 74.01it/s]


Epoch 12: loss: 0.0903710126876831
Epoch 12: binary_accuracy: 0.20928333699703217
Epoch 12: val_loss: 0.10011415928602219
Epoch 12: val_binary_accuracy: 0.2289000004529953


100%|██████████| 1875/1875 [00:24<00:00, 75.42it/s]


Epoch 13: loss: 0.09052987396717072
Epoch 13: binary_accuracy: 0.2054833322763443
Epoch 13: val_loss: 0.09986131638288498
Epoch 13: val_binary_accuracy: 0.19820000231266022


100%|██████████| 1875/1875 [00:24<00:00, 76.42it/s]


Epoch 14: loss: 0.09195605665445328
Epoch 14: binary_accuracy: 0.2061833292245865
Epoch 14: val_loss: 0.10334300249814987
Epoch 14: val_binary_accuracy: 0.22339999675750732


100%|██████████| 1875/1875 [00:24<00:00, 75.48it/s]


Epoch 15: loss: 0.08883222937583923
Epoch 15: binary_accuracy: 0.2242833375930786
Epoch 15: val_loss: 0.09720482677221298
Epoch 15: val_binary_accuracy: 0.1826000064611435


100%|██████████| 1875/1875 [00:24<00:00, 76.91it/s]


Epoch 16: loss: 0.08489541709423065
Epoch 16: binary_accuracy: 0.23768332600593567
Epoch 16: val_loss: 0.09464848041534424
Epoch 16: val_binary_accuracy: 0.2176000028848648


100%|██████████| 1875/1875 [00:24<00:00, 76.34it/s]


Epoch 17: loss: 0.08626019954681396
Epoch 17: binary_accuracy: 0.23890000581741333
Epoch 17: val_loss: 0.10198549926280975
Epoch 17: val_binary_accuracy: 0.24500000476837158


100%|██████████| 1875/1875 [00:24<00:00, 76.33it/s]


Epoch 18: loss: 0.08555004745721817
Epoch 18: binary_accuracy: 0.2433999925851822
Epoch 18: val_loss: 0.09461842477321625
Epoch 18: val_binary_accuracy: 0.24899999797344208


100%|██████████| 1875/1875 [00:24<00:00, 76.22it/s]


Epoch 19: loss: 0.0880432054400444
Epoch 19: binary_accuracy: 0.23943333327770233
Epoch 19: val_loss: 0.0829295739531517
Epoch 19: val_binary_accuracy: 0.2547000050544739


100%|██████████| 1875/1875 [00:24<00:00, 75.77it/s]


Epoch 20: loss: 0.08079573512077332
Epoch 20: binary_accuracy: 0.2716333270072937
Epoch 20: val_loss: 0.08727940917015076
Epoch 20: val_binary_accuracy: 0.2897000014781952


100%|██████████| 1875/1875 [00:24<00:00, 76.88it/s]


Epoch 21: loss: 0.08404643833637238
Epoch 21: binary_accuracy: 0.2662000060081482
Epoch 21: val_loss: 0.09838862717151642
Epoch 21: val_binary_accuracy: 0.26840001344680786


100%|██████████| 1875/1875 [00:24<00:00, 76.85it/s]


Epoch 22: loss: 0.0799475908279419
Epoch 22: binary_accuracy: 0.281416654586792
Epoch 22: val_loss: 0.10145848244428635
Epoch 22: val_binary_accuracy: 0.25060001015663147


100%|██████████| 1875/1875 [00:24<00:00, 76.93it/s]


Epoch 23: loss: 0.08235591650009155
Epoch 23: binary_accuracy: 0.27364999055862427
Epoch 23: val_loss: 0.09721849113702774
Epoch 23: val_binary_accuracy: 0.31859999895095825


100%|██████████| 1875/1875 [00:24<00:00, 76.89it/s]


Epoch 24: loss: 0.07812958210706711
Epoch 24: binary_accuracy: 0.303849995136261
Epoch 24: val_loss: 0.10552582144737244
Epoch 24: val_binary_accuracy: 0.2890999913215637


100%|██████████| 1875/1875 [00:26<00:00, 70.42it/s]


Epoch 25: loss: 0.0817280113697052
Epoch 25: binary_accuracy: 0.2893333435058594
Epoch 25: val_loss: 0.10069206357002258
Epoch 25: val_binary_accuracy: 0.2985999882221222


 32%|███▏      | 603/1875 [00:09<00:19, 66.70it/s]


KeyboardInterrupt: 