Changes:
- in `core`: `_SoftmaxTfModel` -> `SoftmaxTfModel`
- `metrics.reset_states()` -> `metrics.reset_state()`
- removed images of the digit 0 from the dataset

In [26]:
import tensorflow as tf
import ltn
import baselines, data, commons
import matplotlib.pyplot as plt
import numpy as np

# Dataset preprocessing
## Importing dataset

In [27]:
# Seed the Python, NumPy and TensorFlow RNGs before anything random is created
# (dataset shuffling, weight initialisation, dropout) so Restart & Run All gives
# repeatable numbers. Some GPU kernels may still be non-deterministic.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

mnist = tf.keras.datasets.mnist
(img_train, label_train), (img_test, label_test) = mnist.load_data()

# normalising the uint8 pixel values to [0, 1]; casting to float32 *before*
# dividing avoids the float64 copy `/ 255.0` would otherwise create (the model
# computes in float32 anyway, so float64 only doubles memory)
img_train = img_train.astype(np.float32) / 255.0
img_test = img_test.astype(np.float32) / 255.0

# adding a channel dimension for compatibility with the convolutional layers
img_train = img_train[..., tf.newaxis]
img_test = img_test[..., tf.newaxis]

## Removing images of the digit 0 (so the modulo divisor is never zero)

In [28]:
# train data without label 0
not_zeros_train = label_train != 0
img_train = img_train[not_zeros_train]
label_train = label_train[not_zeros_train]

# test data without label 0
not_zeros_test = label_test != 0
img_test = img_test[not_zeros_test]
label_test = label_test[not_zeros_test]

# how much data will be considered: number of samples per operand
count_train = 10000
count_test = 3000
n_operands = 2


def op(args):
    """The operation to learn: the first operand modulo the second.

    `args` is indexed positionally (one entry per operand). When the entries
    are NumPy label arrays, `%` is applied element-wise, so a whole split is
    labelled in one vectorised call. The digit 0 was removed above, so the
    divisor is never zero.
    """
    return args[0] % args[1]


def _split_operands(array, count, n):
    """Return `n` consecutive, non-overlapping slices of `count` items of `array`.

    Raises:
        ValueError: if `array` holds fewer than `n * count` items; slicing would
            otherwise silently return short chunks and misalign the operand pairs.
    """
    needed = n * count
    if len(array) < needed:
        raise ValueError(
            f"need {needed} samples for {n} operands of {count}, got {len(array)}"
        )
    return [array[i * count:(i + 1) * count] for i in range(n)]


# train data: operand i takes its own disjoint slice of the filtered split
img_per_operand_train = _split_operands(img_train, count_train, n_operands)
label_per_operand_train = _split_operands(label_train, count_train, n_operands)
label_result_train = op(label_per_operand_train)

# test data
img_per_operand_test = _split_operands(img_test, count_test, n_operands)
label_per_operand_test = _split_operands(label_test, count_test, n_operands)
label_result_test = op(label_per_operand_test)

## Creating `tf.data` datasets (shuffle buffer size and batch size)

In [29]:
buffer_size = 3000
batch_size  = 16


def _pair_dataset(operand_images, labels, batch, shuffle_buffer=None):
    """Build a batched, prefetched tf.data pipeline of ((x, y), result) elements.

    `operand_images` holds the image arrays of the first and second operand;
    `labels` holds the result label of each pair. When `shuffle_buffer` is
    given, elements are shuffled through a buffer of that size before batching.
    """
    pairs = tf.data.Dataset.from_tensor_slices(
        ((operand_images[0], operand_images[1]), labels)
    )
    if shuffle_buffer is not None:
        pairs = pairs.shuffle(shuffle_buffer)
    return pairs.batch(batch).prefetch(tf.data.AUTOTUNE)


# training set: shuffled
ds_train = _pair_dataset(
    img_per_operand_train, label_result_train, batch_size, shuffle_buffer=buffer_size
)

# test set: kept in its original order
ds_test = _pair_dataset(img_per_operand_test, label_result_test, batch_size)

## Neural Network

In [30]:
def make_base_cnn(input_shape=(28, 28, 1)):
    """Build the convolutional feature extractor for one digit image.

    Two Conv2D + MaxPooling stages, then Flatten, a 128-unit ReLU dense layer
    and dropout. There is no classification head: the model outputs a
    128-dimensional feature vector so one instance can be applied to several
    image inputs.

    Args:
        input_shape: shape of a single image as (height, width, channels).
            Defaults to MNIST's (28, 28, 1).

    Returns:
        A `tf.keras.Sequential` model named "digit_cnn".
    """
    return tf.keras.Sequential([
        # Explicit Input layer instead of the `input_shape=` layer kwarg,
        # which Keras 3 deprecates with a UserWarning.
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.25)
    ], name="digit_cnn")

# A single CNN instance is applied to both inputs, so the two digit images are
# encoded with shared weights.
base_cnn = make_base_cnn()

# Two image inputs; the datasets yield the operand images as a tuple in this
# order (first operand -> "x", second operand -> "y").
inp1 = tf.keras.layers.Input(shape=(28, 28, 1), name="x")
inp2 = tf.keras.layers.Input(shape=(28, 28, 1), name="y")

feat1 = base_cnn(inp1)
feat2 = base_cnn(inp2)

# Concatenate both feature vectors and classify the result of x % y.
# 9 classes: with the digit 0 removed, x % y for x, y in 1..9 lies in 0..8.
concat = tf.keras.layers.Concatenate()([feat1, feat2])
out    = tf.keras.layers.Dense(9, activation="softmax")(concat)

model = tf.keras.Model(inputs=[inp1, inp2], outputs=out)

# Labels are integer class ids (not one-hot), hence the sparse cross-entropy.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

EPOCHS = 20
# NOTE(review): the test set doubles as validation data, so the val_* metrics
# printed each epoch are test-set metrics and should not drive model selection
# (e.g. choosing EPOCHS); consider holding out a separate validation split.
history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
    verbose=2
)

# Final evaluation on the test set (the same data used for validation above).
test_loss, test_acc = model.evaluate(ds_test, verbose=0)
print(f"\nModulo-CNN test accuracy: {test_acc:.4f}")


Epoch 1/20
625/625 - 7s - 10ms/step - accuracy: 0.6197 - loss: 1.0071 - val_accuracy: 0.7130 - val_loss: 0.7097
Epoch 2/20
625/625 - 6s - 9ms/step - accuracy: 0.7328 - loss: 0.6679 - val_accuracy: 0.7513 - val_loss: 0.6095
Epoch 3/20
625/625 - 6s - 10ms/step - accuracy: 0.7693 - loss: 0.5833 - val_accuracy: 0.7753 - val_loss: 0.5676
Epoch 4/20
625/625 - 6s - 10ms/step - accuracy: 0.7876 - loss: 0.5333 - val_accuracy: 0.7860 - val_loss: 0.5344
Epoch 5/20
625/625 - 6s - 10ms/step - accuracy: 0.7958 - loss: 0.5028 - val_accuracy: 0.7843 - val_loss: 0.5251
Epoch 6/20
625/625 - 6s - 10ms/step - accuracy: 0.8061 - loss: 0.4688 - val_accuracy: 0.7920 - val_loss: 0.5173
Epoch 7/20
625/625 - 8s - 12ms/step - accuracy: 0.8111 - loss: 0.4481 - val_accuracy: 0.7930 - val_loss: 0.5382
Epoch 8/20
625/625 - 9s - 15ms/step - accuracy: 0.8242 - loss: 0.4267 - val_accuracy: 0.7903 - val_loss: 0.5349
Epoch 9/20
625/625 - 6s - 10ms/step - accuracy: 0.8361 - loss: 0.4080 - val_accuracy: 0.7840 - val_loss: 