# MNIST Digit Addition Problem

In [19]:
import tensorflow as tf
import ltn
import baselines, data, commons
import matplotlib.pyplot as plt
from collections import defaultdict

## Data

In [20]:
def sum_of_operands(operands):
    """Combine the two operands of a sample by adding them."""
    return operands[0] + operands[1]


# Train/test datasets built from MNIST; every sample pairs two operands with
# the result of `sum_of_operands` applied to them.
ds_train, ds_test = data.get_op_dataset(
    data_loader_fn=data.get_mnist_data_as_numpy,
    count_train=3000,
    count_test=1000,
    buffer_size=3000,
    batch_size=16,
    n_operands=2,
    op=sum_of_operands,
)

## LTN

In [21]:
# Model whose outputs are treated as logits by the Digit predicate below.
# It is constructed with inputs_as_a_list=True and is always called with a
# list of inputs in this notebook, e.g. logits_model([images_x]).
logits_model = baselines.SingleDigit(inputs_as_a_list=True)

# Digit([image, d]): LTN predicate derived from the logits via a softmax.
Digit = ltn.Predicate.FromLogits(logits_model, activation_function="softmax")

# Two LTN variables, each ranging over the ten digit values 0..9.
d1 = ltn.Variable("digits1", range(10))
d2 = ltn.Variable("digits2", range(10))

# Fuzzy-logic connectives and quantifiers wrapped so they can be applied to
# LTN formulas.
# NOTE(review): only And, Forall and Exists are used in the cells shown here;
# Not, Or and Implies may be leftovers -- confirm before removing.
Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std())
And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod())
Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum())
Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach())
Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(),semantics="forall")
Exists = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMean(),semantics="exists")

In [23]:
# NOTE(review): the author observed that an error (presumably raised by this
# cell) disappears when the cell is run a second time. The cause is not
# visible here.
# TODO: reproduce with Restart Kernel -> Run All and fix the root cause
# instead of relying on a re-run.
#
# Helpers used to build the `mask` of the Exists quantifier in `axioms`:
#   add([a, b])    -> a + b
#   equals([a, b]) -> a == b
add = ltn.Function.Lambda(lambda inputs: inputs[0]+inputs[1])
equals = ltn.Predicate.Lambda(lambda inputs: inputs[0] == inputs[1])

### Axioms
@tf.function
def axioms(images_x, images_y, labels_z, p_schedule=tf.constant(2.)):
    """Return the satisfaction level of the knowledge base on one batch.

    The single axiom reads: for every (x, y, z) triple of the batch, taken
    together via ``ltn.diag``, there exists a pair of digits (d1, d2),
    restricted by the mask ``d1 + d2 == z``, such that ``Digit(x, d1)`` and
    ``Digit(y, d2)`` both hold.

    Args:
        images_x: batch of first operands, grounded as the LTN variable "x".
        images_y: batch of second operands, grounded as the LTN variable "y".
        labels_z: batch of target labels, grounded as the LTN variable "z".
        p_schedule: value passed as ``p`` to the ``Exists`` aggregator.

    Returns:
        The ``.tensor`` attribute of the quantified formula.
    """
    x = ltn.Variable("x", images_x)
    y = ltn.Variable("y", images_y)
    z = ltn.Variable("z", labels_z)

    # ltn.diag is called before the inner formula is built, exactly as in a
    # single nested call whose arguments are evaluated left to right. Keep
    # this order.
    # NOTE(review): ltn.diag is understood to mark the variables in place --
    # confirm against the LTN documentation before moving this line.
    paired_vars = ltn.diag(x, y, z)

    both_digits_hold = And(Digit([x, d1]), Digit([y, d2]))
    sum_matches_label = equals([add([d1, d2]), z])
    some_digit_pair = Exists(
        (d1, d2),
        both_digits_hold,
        mask=sum_matches_label,
        p=p_schedule,
    )
    # The Forall aggregator uses a fixed p=2 (not scheduled).
    formula = Forall(paired_vars, some_digit_pair, p=2)
    return formula.tensor

# Smoke test: pull one batch from the training set and evaluate the axioms on
# it. The bare call on the last line is what the cell displays.
images_x, images_y, labels_z = next(ds_train.as_numpy_iterator())
axioms(images_x, images_y, labels_z)

<tf.Tensor: shape=(), dtype=float32, numpy=0.010468542575836182>

## Optimizer, training steps and metrics

In [24]:
# Adam optimizer; 0.001 is passed as its first positional argument.
optimizer = tf.keras.optimizers.Adam(0.001)

# One running-mean metric per tracked quantity, keyed by (and named after)
# the quantity it tracks.
metrics_dict = {
    metric_name: tf.keras.metrics.Mean(name=metric_name)
    for metric_name in (
        'train_loss',
        'train_accuracy',
        'test_loss',
        'test_accuracy',
    )
}

def _batch_accuracy(images_x, images_y, labels_z):
    """Fraction of pairs whose predicted digits add up to the label.

    Each image is classified on its own by taking the argmax over the last
    axis of ``logits_model``'s output. The two predictions are summed and
    compared with ``labels_z`` (cast to the prediction dtype).

    Returns:
        A scalar float32 tensor: the mean of the per-sample matches.
    """
    predictions_x = tf.argmax(logits_model([images_x]), axis=-1)
    predictions_y = tf.argmax(logits_model([images_y]), axis=-1)
    predictions_z = predictions_x + predictions_y
    match = tf.equal(predictions_z, tf.cast(labels_z, predictions_z.dtype))
    return tf.reduce_mean(tf.cast(match, tf.float32))


@tf.function
def train_step(images_x, images_y, labels_z, **parameters):
    """Run one optimization step and update the training metrics.

    Args:
        images_x, images_y, labels_z: one batch, forwarded to ``axioms``.
        **parameters: extra keyword arguments forwarded to ``axioms``
            (e.g. ``p_schedule``).

    Side effects:
        Applies gradients to ``logits_model.trainable_variables`` through
        ``optimizer`` and updates ``metrics_dict['train_loss']`` and
        ``metrics_dict['train_accuracy']``.
    """
    # loss = 1 - satisfaction of the axioms
    with tf.GradientTape() as tape:
        loss = 1.- axioms(images_x, images_y, labels_z, **parameters)
    gradients = tape.gradient(loss, logits_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, logits_model.trainable_variables))
    metrics_dict['train_loss'](loss)
    # accuracy (computed after the weights have been updated)
    metrics_dict['train_accuracy'](_batch_accuracy(images_x, images_y, labels_z))


@tf.function
def test_step(images_x, images_y, labels_z, **parameters):
    """Evaluate one batch and update the test metrics (no weight update).

    Args:
        images_x, images_y, labels_z: one batch, forwarded to ``axioms``.
        **parameters: extra keyword arguments forwarded to ``axioms``
            (e.g. ``p_schedule``).

    Side effects:
        Updates ``metrics_dict['test_loss']`` and
        ``metrics_dict['test_accuracy']``.
    """
    # loss = 1 - satisfaction of the axioms
    loss = 1.- axioms(images_x, images_y, labels_z, **parameters)
    metrics_dict['test_loss'](loss)
    # accuracy
    metrics_dict['test_accuracy'](_batch_accuracy(images_x, images_y, labels_z))

## Training

In [25]:


# Stages of the `p_schedule` value handed to the axioms, one row per stage:
# (first epoch, end epoch exclusive, p value).
_P_SCHEDULE_STAGES = (
    (0, 4, 1.),
    (4, 8, 2.),
    (8, 12, 4.),
    (12, 20, 6.),
)

# Epochs outside every stage fall back to an empty parameter dict.
scheduled_parameters = defaultdict(dict)
for _first_epoch, _end_epoch, _p_value in _P_SCHEDULE_STAGES:
    for epoch in range(_first_epoch, _end_epoch):
        scheduled_parameters[epoch] = {"p_schedule": tf.constant(_p_value)}

In [26]:
# Launch training through the shared helper in `commons`, passing the
# metrics, datasets, step functions and per-epoch parameters defined earlier
# in this notebook.
# NOTE(review): the first argument (20) is presumably the number of epochs; it
# matches the 0..19 range covered by `scheduled_parameters` -- confirm against
# commons.train and keep the two in sync.
commons.train(
    20,
    metrics_dict,
    ds_train,
    ds_test,
    train_step,
    test_step,
    scheduled_parameters=scheduled_parameters
)

Epoch 0, train_loss: 0.9365, train_accuracy: 0.3803, test_loss: 0.8861, test_accuracy: 0.6290
Epoch 1, train_loss: 0.8603, train_accuracy: 0.8152, test_loss: 0.8570, test_accuracy: 0.7937
Epoch 2, train_loss: 0.8459, train_accuracy: 0.8886, test_loss: 0.8463, test_accuracy: 0.8423
Epoch 3, train_loss: 0.8398, train_accuracy: 0.9186, test_loss: 0.8496, test_accuracy: 0.8313
Epoch 4, train_loss: 0.6506, train_accuracy: 0.9146, test_loss: 0.6536, test_accuracy: 0.8730
Epoch 5, train_loss: 0.6315, train_accuracy: 0.9382, test_loss: 0.6392, test_accuracy: 0.9067
Epoch 6, train_loss: 0.6252, train_accuracy: 0.9495, test_loss: 0.6386, test_accuracy: 0.9028
Epoch 7, train_loss: 0.6222, train_accuracy: 0.9564, test_loss: 0.6406, test_accuracy: 0.8978
Epoch 8, train_loss: 0.4312, train_accuracy: 0.9505, test_loss: 0.5057, test_accuracy: 0.8502
Epoch 9, train_loss: 0.4292, train_accuracy: 0.9491, test_loss: 0.4654, test_accuracy: 0.8948
Epoch 10, train_loss: 0.4197, train_accuracy: 0.9571, test_l