In [6]:
import logging; logging.basicConfig(level=logging.INFO)
import tensorflow as tf
import numpy as np
import ltn

In [7]:
# Reproducibility: P below is grounded by an MLP whose weights are randomly
# initialised, and P(x) feeds every satisfaction level printed during training.
# Seeding TensorFlow's global RNG before P is created makes Restart & Run All
# reproduce the same numbers.
SEED = 42
tf.random.set_seed(SEED)

# 0-ary groundings: a, b, c are marked trainable; w1, w2 are fixed constants.
a = ltn.Proposition(0.2, trainable=True)
b = ltn.Proposition(0.5, trainable=True)
c = ltn.Proposition(0.5, trainable=True)
w1 = ltn.Proposition(0.3, trainable=False)
w2 = ltn.Proposition(0.9, trainable=False)

# Variable x takes three 2-dimensional values. The explicit float32 dtype avoids
# relying on an implicit int -> float cast before the data reaches the MLP.
x = ltn.Variable("x", np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32))

# P is grounded by an MLP. `input_shapes` presumably lists one shape per predicate
# argument (TODO confirm against the ltn docs); here there is a single 2-feature
# input. Note: the former spelling `[(2)]` is exactly `[2]`, because `(2)` is not
# a tuple.
P = ltn.Predicate.MLP(input_shapes=[2])

In [8]:
# Fuzzy-logic semantics for the logical symbols used when writing the axioms.
# Connectives: each name wraps the ltn.fuzzy_ops operator listed beside it.
# (`Or` is not used by the axioms in this notebook section but is kept available.)
Not, And, Or, Implies = (
    ltn.Wrapper_Connective(fuzzy_op)
    for fuzzy_op in (
        ltn.fuzzy_ops.Not_Std(),              # -> Not
        ltn.fuzzy_ops.And_Prod(),             # -> And
        ltn.fuzzy_ops.Or_ProbSum(),           # -> Or
        ltn.fuzzy_ops.Implies_Reichenbach(),  # -> Implies
    )
)

# Quantifiers: p-mean style aggregators; the exponent p is a tunable hyperparameter.
Forall = ltn.Wrapper_Quantifier(
    ltn.fuzzy_ops.Aggreg_pMeanError(p=5),
    semantics="forall",
)
Exists = ltn.Wrapper_Quantifier(
    ltn.fuzzy_ops.Aggreg_pMean(p=10),
    semantics="exists",
)

In [9]:
# Combines the satisfaction levels of the individual axioms (Aggreg_Mean).
formula_aggregator = ltn.Wrapper_Formula_Aggregator(ltn.fuzzy_ops.Aggreg_Mean())

@tf.function
def axioms():
    """Build the knowledge base and return its aggregated satisfaction level.

    Axioms:
      1. [ (a and b and (forall x: P(x))) -> not c ] and c
      2. w1 -> (forall x: P(x))
      3. w2 -> (exists x: P(x))

    Returns:
        The ``.tensor`` produced by ``formula_aggregator`` over the three axioms.
        The function is traced into a graph by ``tf.function``.
    """
    # Axiom 1: [ (A and B and (forall x: P(x))) -> Not C ] and C
    premise = And(And(a, b), Forall(x, P(x)))
    kb_axiom_1 = And(Implies(premise, Not(c)), c)
    # Axiom 2: w1 -> (forall x: P(x))
    kb_axiom_2 = Implies(w1, Forall(x, P(x)))
    # Axiom 3: w2 -> (exists x: P(x))
    kb_axiom_3 = Implies(w2, Exists(x, P(x)))
    # The list is deliberately not named `axioms`, which would shadow this function.
    return formula_aggregator([kb_axiom_1, kb_axiom_2, kb_axiom_3]).tensor

In [10]:
# Optimise the trainable propositions a, b, c by minimising
# loss = 1 - satisfaction level of the knowledge base.
# NOTE(review): only a, b, c are passed to the optimizer. P's MLP weights are not
# in `trainable_variables`, so this loop never updates P, and the axioms that
# depend only on w1, w2 and P(x) cannot improve. Confirm this is intended.
LEARNING_RATE = 0.01
N_EPOCHS = 1000
LOG_EVERY = 100  # report the satisfaction level every LOG_EVERY epochs

trainable_variables = ltn.as_tensors([a, b, c])
optimizer = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE)

for epoch in range(N_EPOCHS):
    with tf.GradientTape() as tape:
        loss_value = 1. - axioms()
    grads = tape.gradient(loss_value, trainable_variables)
    optimizer.apply_gradients(zip(grads, trainable_variables))
    # Logging re-evaluates the KB after this epoch's update.
    if not epoch % LOG_EVERY:
        print("Epoch %d: Sat Level %.3f" % (epoch, axioms()))

print("Training finished at Epoch %d with Sat Level %.3f" % (epoch, axioms()))

Epoch 0: Sat Level 0.575
Epoch 100: Sat Level 0.678
Epoch 200: Sat Level 0.737
Epoch 300: Sat Level 0.740
Epoch 400: Sat Level 0.742
Epoch 500: Sat Level 0.743
Epoch 600: Sat Level 0.743
Epoch 700: Sat Level 0.743
Epoch 800: Sat Level 0.743
Epoch 900: Sat Level 0.743
Training finished at Epoch 999 with Sat Level 0.743
