In [1]:
import functools
from typing import Callable, Union

import clingo
import clingo.ast
from clingo.ast import ProgramBuilder
from clingox.backend import SymbolicBackend

In [2]:
yale = """

time(0..3).

observe(chicken, alive, 0).
observe(alice, alive, 0).
occurs(alice, (shoot, chicken), 1).

observe(Agent, alive, T+1) :-
  time(T+1),
  observe(Agent, alive, T),
  not observe(Agent, dead, T+1).

observe(Agent, dead, T+1) :-
  time(T+1),
  observe(Agent, dead, T).

observe(Agent1, dead, T+1) :-
  time(T+1),
  occurs(Agent2, (shoot, Agent1), T),
  observe(Agent1, alive, T).

"""

In [3]:
ctl = clingo.Control()
ctl.configuration.solve.models = 0

In [4]:
ctl.add("base", (), yale)

In [5]:
ctl.ground([("base", [])])

In [6]:
def temporal(symbol: clingo.Symbol, default: Union[int, Callable[[clingo.Symbol], int], None] = None):
    if symbol.match('occurs', 3) or symbol.match('observe', 3):
        return symbol.arguments[-1].number
    if callable(default):
        return default(symbol)
    return default

In [7]:
with ctl.solve(yield_=True) as solve_handle:
    models = []
    for model in solve_handle:
        symbols = sorted(sorted(model.symbols(shown=True)), key=functools.partial(temporal, default=-1))
        print("Answer {}:".format(model.number), end=' ')
        print("{",
              '\n'.join(map(str, symbols)), "}", sep='\n')
        models.append(symbols)
    solve_result = solve_handle.get()
    print(solve_result, end='')
    if models:
        print(" {}{}".format(len(models), '' if solve_result.exhausted else '+'))

Answer 1: {
time(0)
time(1)
time(2)
time(3)
observe(alice,alive,0)
observe(chicken,alive,0)
observe(alice,alive,1)
observe(chicken,alive,1)
occurs(alice,(shoot,chicken),1)
observe(alice,alive,2)
observe(chicken,dead,2)
observe(alice,alive,3)
observe(chicken,dead,3)
}
SAT 1


In [8]:
nodes = []
clingo.ast.parse_string(yale, lambda stm: nodes.append(stm))
print('\n'.join(map(str, nodes)))

#program base.
time((0..3)).
observe(chicken,alive,0).
observe(alice,alive,0).
occurs(alice,(shoot,chicken),1).
observe(Agent,alive,(T+1)) :- time((T+1)); observe(Agent,alive,T); not observe(Agent,dead,(T+1)).
observe(Agent,dead,(T+1)) :- time((T+1)); observe(Agent,dead,T).
observe(Agent1,dead,(T+1)) :- time((T+1)); occurs(Agent2,(shoot,Agent1),T); observe(Agent1,alive,T).


In [9]:
model = models[0]
print('. '.join(map(str, model)), '.', sep='')

time(0). time(1). time(2). time(3). observe(alice,alive,0). observe(chicken,alive,0). observe(alice,alive,1). observe(chicken,alive,1). occurs(alice,(shoot,chicken),1). observe(alice,alive,2). observe(chicken,dead,2). observe(alice,alive,3). observe(chicken,dead,3).


In [10]:
def ast_match(node: clingo.ast.AST, symbol: clingo.Symbol):
    print(f"{node} ({node.ast_type})", '=?=', f"{symbol} ({symbol.type})")
    if node.ast_type is clingo.ast.ASTType.SymbolicTerm:
        return node.symbol == symbol
    elif node.ast_type is clingo.ast.ASTType.Variable:
        return True
    elif node.ast_type is clingo.ast.ASTType.Function:
        return ast_function_match(node, symbol)
    elif node.ast_type is clingo.ast.ASTType.BinaryOperation:
        return ast_descend_variable(node)
    else:
        return False


def ast_descend_variable(node: clingo.ast.AST):
    queue = [node]
    while queue:
        current = queue.pop()
        if current.ast_type is clingo.ast.ASTType.Variable:
            return True
        elif current.ast_type is clingo.ast.ASTType.BinaryOperation:
            queue.append(current.right)
            queue.append(current.left)

    return False


def ast_function_match(function: clingo.ast.AST, symbol: clingo.Symbol):
    if function.ast_type is not clingo.ast.ASTType.Function or symbol.type is not clingo.SymbolType.Function:
        print(function.ast_type, '=x=', clingo.SymbolType)
        return False
    else:
        matches = True
        if function.name != symbol.name:
            print(function.name, '=/=', symbol.name)
            matches = False
        if len(function.arguments) != len(symbol.arguments):
            print(len(function.arguments), '=/=', len(symbol.arguments))
            matches = False
        for i in range(len(symbol.arguments)):
            matches = matches and ast_match(function.arguments[i], symbol.arguments[i])
        if matches:
            print(function, '==', symbol)
        return matches

In [11]:
observations = [symbol for symbol in model if symbol.match('observe', 3)]
print('\n'.join(map(str, observations)))

observe(alice,alive,0)
observe(chicken,alive,0)
observe(alice,alive,1)
observe(chicken,alive,1)
observe(alice,alive,2)
observe(chicken,dead,2)
observe(alice,alive,3)
observe(chicken,dead,3)


In [12]:
target_belief = observations[-1]
print(target_belief)  # Why is the chicken dead at the end?

observe(chicken,dead,3)


In [13]:
support_axioms = [node for node in nodes if
                  node.ast_type is clingo.ast.ASTType.Rule and ast_match(node.head.atom.symbol, target_belief)]

time((0..3)) (ASTType.Function) =?= observe(chicken,dead,3) (SymbolType.Function)
time =/= observe
1 =/= 3
observe(chicken,alive,0) (ASTType.Function) =?= observe(chicken,dead,3) (SymbolType.Function)
chicken (ASTType.SymbolicTerm) =?= chicken (SymbolType.Function)
alive (ASTType.SymbolicTerm) =?= dead (SymbolType.Function)
observe(alice,alive,0) (ASTType.Function) =?= observe(chicken,dead,3) (SymbolType.Function)
alice (ASTType.SymbolicTerm) =?= chicken (SymbolType.Function)
occurs(alice,(shoot,chicken),1) (ASTType.Function) =?= observe(chicken,dead,3) (SymbolType.Function)
occurs =/= observe
observe(Agent,alive,(T+1)) (ASTType.Function) =?= observe(chicken,dead,3) (SymbolType.Function)
Agent (ASTType.Variable) =?= chicken (SymbolType.Function)
alive (ASTType.SymbolicTerm) =?= dead (SymbolType.Function)
observe(Agent,dead,(T+1)) (ASTType.Function) =?= observe(chicken,dead,3) (SymbolType.Function)
Agent (ASTType.Variable) =?= chicken (SymbolType.Function)
dead (ASTType.SymbolicTerm) =?

In [14]:
print('\n'.join(map(str, support_axioms)))

observe(Agent,dead,(T+1)) :- time((T+1)); observe(Agent,dead,T).
observe(Agent1,dead,(T+1)) :- time((T+1)); occurs(Agent2,(shoot,Agent1),T); observe(Agent1,alive,T).


In [15]:
pos = clingo.ast.Position('<string>', 1, 1)
loc = clingo.ast.Location(pos, pos)

In [16]:
ctl = clingo.Control()
ctl.configuration.solve.models = 0

In [17]:
with SymbolicBackend(ctl.backend()) as sb:
    for fact in model:
        sb.add_rule(head=[fact])

In [18]:
class VariableRewriter(clingo.ast.Transformer):

    def __init__(self, rewrite=None):
        self.rewrite: Callable[[clingo.ast.AST], clingo.ast.AST] = rewrite or (lambda x: x)

    def visit_Variable(self, variable: clingo.ast.AST):
        return self.rewrite(variable)

class BinaryOperationRewriter(clingo.ast.Transformer):

    def visit_BinaryOperation(self, binop: clingo.ast.AST):
        if binop.operator_type == clingo.ast.BinaryOperator.Plus:
            return clingo.ast.Function(loc, 'plus', (binop.left, binop.right), False)
        return binop

In [19]:
def rewrite_to_any(variable: clingo.ast.AST, condition=True):
    assert variable.ast_type is clingo.ast.ASTType.Variable
    if (callable(condition) and condition(variable)) or (type(condition) == bool and condition):
        print("Rewriting", variable, "to", "__any")
        return clingo.ast.SymbolicTerm(loc, clingo.Function('__any'))
    else:
        return variable


def rewrite_to_varstr(variable: clingo.ast.AST, condition=True):
    assert variable.ast_type is clingo.ast.ASTType.Variable
    if (callable(condition) and condition(variable)) or (type(condition) == bool and condition):
        print("Rewriting", variable, "to", f'var("{variable}")')
        return clingo.ast.Function(loc, 'var', [clingo.ast.SymbolicTerm(loc, clingo.String(variable.name))], False)
    else:
        return variable

to_varstr = VariableRewriter(rewrite_to_varstr)

def debug_in(s, e):
    is_in = e in s
    if is_in:
        print(e, 'in', s)
    else:
        print(e, 'not in', s)
    return is_in


In [20]:
rules = []
for support_axiom in support_axioms:
    head = support_axiom.head.atom.symbol
    head_name = clingo.ast.SymbolicTerm(loc, clingo.Function(head.name))
    head_parameters = head.arguments
    body = support_axiom.body
    fires_head = clingo.ast.Literal(loc, sign=0, atom=clingo.ast.SymbolicAtom(
        clingo.ast.Function(loc, '__fires', (head_name, clingo.ast.Function(loc, '', head_parameters, False)), False)))
    fires_rule = clingo.ast.Rule(loc, fires_head, body)
    rules.append(fires_rule)
    holds_body_variables = set()
    for symbolic_atom in head_parameters:
        queue = [symbolic_atom]
        while queue:
            current = queue.pop()
            if current.ast_type is clingo.ast.ASTType.Variable:
                holds_body_variables.add(current.name)
            elif current.ast_type is clingo.ast.ASTType.BinaryOperation:
                queue.append(current.left)
                queue.append(current.right)
            elif current.ast_type is clingo.ast.ASTType.Function:
                queue.extend(current.arguments)
    print(holds_body_variables)
    holds_head_variables = []
    holds_rules = []
    for literal in body:
        symbol = literal.atom.symbol
        name = symbol.name
        args = symbol.arguments
        body_parameters = []
        body_variables = set()
        for symbolic_atom in args:
            queue = [symbolic_atom]
            while queue:
                current = queue.pop()
                if current.ast_type is clingo.ast.ASTType.Variable:
                    body_variables.add(current.name)
                elif current.ast_type is clingo.ast.ASTType.BinaryOperation:
                    queue.append(current.left)
                    queue.append(current.right)
                elif current.ast_type is clingo.ast.ASTType.Function:
                    queue.extend(current.arguments)

        holds_head_variables.append(body_variables)
        if holds_body_variables < body_variables:
            print(holds_body_variables, '<', body_variables)
            toAny = VariableRewriter(
                functools.partial(rewrite_to_any, condition=lambda v: not debug_in(holds_body_variables, v.name)))
            for arg in args:
                body_parameters.append(toAny.visit(arg))
        else:
            body_parameters.extend(args)

        holds_head = clingo.ast.Literal(loc, sign=0, atom=clingo.ast.SymbolicAtom(
            clingo.ast.Function(loc, '__holds', (
                clingo.ast.SymbolicTerm(loc, clingo.Number(literal.sign)),
                clingo.ast.SymbolicTerm(loc, clingo.Function(name)),
                clingo.ast.Function(loc, '', body_parameters, False)
            ), False)
        ))
        holds_rule = clingo.ast.Rule(loc, holds_head, (fires_head,))
        holds_rules.append(holds_rule)
    rules.extend(holds_rules)

    rule_fires_head = BinaryOperationRewriter().visit(to_varstr.visit(fires_head))
    rule_fires_rule = clingo.ast.Rule(loc, rule_fires_head, (fires_head,))
    rules.append(rule_fires_rule)

print('\n'.join(map(str, rules)))

{'T', 'Agent'}
Rewriting Agent to var("Agent")
Rewriting T to var("T")
{'T', 'Agent1'}
{'T', 'Agent1'} < {'T', 'Agent1', 'Agent2'}
Agent2 not in {'T', 'Agent1'}
Rewriting Agent2 to __any
Agent1 in {'T', 'Agent1'}
T in {'T', 'Agent1'}
Rewriting Agent1 to var("Agent1")
Rewriting T to var("T")
__fires(observe,(Agent,dead,(T+1))) :- time((T+1)); observe(Agent,dead,T).
__holds(0,time,((T+1),)) :- __fires(observe,(Agent,dead,(T+1))).
__holds(0,observe,(Agent,dead,T)) :- __fires(observe,(Agent,dead,(T+1))).
__fires(observe,(var("Agent"),dead,plus(var("T"),1))) :- __fires(observe,(Agent,dead,(T+1))).
__fires(observe,(Agent1,dead,(T+1))) :- time((T+1)); occurs(Agent2,(shoot,Agent1),T); observe(Agent1,alive,T).
__holds(0,time,((T+1),)) :- __fires(observe,(Agent1,dead,(T+1))).
__holds(0,occurs,(__any,(shoot,Agent1),T)) :- __fires(observe,(Agent1,dead,(T+1))).
__holds(0,observe,(Agent1,alive,T)) :- __fires(observe,(Agent1,dead,(T+1))).
__fires(observe,(var("Agent1"),dead,plus(var("T"),1))) :- __fi

In [21]:
with ProgramBuilder(ctl) as pb:
    for rule in rules:
        pb.add(rule)

In [22]:
ctl.add('base', (), '#show __fires/2. #show __holds/3.')

In [23]:
ctl.ground([('base', [])])
with ctl.solve(yield_=True) as solve_handle:
    models = []
    for model in solve_handle:
        symbols = model.symbols(shown=True)
        print("Answer {}:".format(model.number), "{",
              ' '.join(map(str, sorted(symbols))), "}")
        models.append(symbols)
    solve_result = solve_handle.get()
    print(solve_result)


Answer 1: { __fires(observe,(chicken,dead,2)) __fires(observe,(chicken,dead,3)) __fires(observe,(var("Agent"),dead,plus(var("T"),1))) __fires(observe,(var("Agent1"),dead,plus(var("T"),1))) __holds(0,observe,(chicken,alive,1)) __holds(0,observe,(chicken,alive,2)) __holds(0,observe,(chicken,dead,1)) __holds(0,observe,(chicken,dead,2)) __holds(0,occurs,(__any,(shoot,chicken),1)) __holds(0,occurs,(__any,(shoot,chicken),2)) __holds(0,time,(2,)) __holds(0,time,(3,)) }
SAT


In [24]:
with ctl.solve(yield_=True) as solve_handle:
    models = []
    for model in solve_handle:
        symbols = sorted(sorted(model.symbols(shown=True)), key=functools.partial(temporal, default=-1))
        print("Answer {}:".format(model.number), end=' ')
        print("{",
              '\n'.join(map(str, symbols)), "}", sep='\n')
        models.append(symbols)
    solve_result = solve_handle.get()
    print(solve_result, end='')
    if models:
        print(" {}{}".format(len(models), '' if solve_result.exhausted else '+'))

Answer 1: {
__fires(observe,(chicken,dead,2))
__fires(observe,(chicken,dead,3))
__fires(observe,(var("Agent"),dead,plus(var("T"),1)))
__fires(observe,(var("Agent1"),dead,plus(var("T"),1)))
__holds(0,observe,(chicken,alive,1))
__holds(0,observe,(chicken,alive,2))
__holds(0,observe,(chicken,dead,1))
__holds(0,observe,(chicken,dead,2))
__holds(0,occurs,(__any,(shoot,chicken),1))
__holds(0,occurs,(__any,(shoot,chicken),2))
__holds(0,time,(2,))
__holds(0,time,(3,))
}
SAT 1
