In [301]:
import copy
from dataclasses import dataclass, field
from enum import IntEnum
from typing import TypeVar, Sequence, Optional, List, Iterator, Union

import clingo
import clingox.program


In [302]:
ForwardSymbol = TypeVar('ForwardSymbol', bound='Symbol')


class Symbol:
    """Common base of every symbolic term in this module.

    Offers ``is_*`` predicates that test the concrete subclass, and construction of a
    :class:`Function` from a clingo function symbol.
    """

    def is_function(self) -> bool:
        """True if this symbol is a :class:`Function`."""
        return isinstance(self, Function)

    def is_unary_operation(self) -> bool:
        """True if this symbol is a :class:`UnaryOperation`."""
        return isinstance(self, UnaryOperation)

    def is_binary_operation(self) -> bool:
        """True if this symbol is a :class:`BinaryOperation`."""
        return isinstance(self, BinaryOperation)

    def is_operation(self) -> bool:
        """True for either a unary or a binary operation."""
        return self.is_unary_operation() or self.is_binary_operation()

    def is_variable(self) -> bool:
        """True if this symbol is a :class:`Variable`."""
        return isinstance(self, Variable)

    def is_term(self) -> bool:
        """True if this symbol is a constant :class:`Term`."""
        return isinstance(self, Term)

    @classmethod
    def from_clingo_symbol(cls, symbol: clingo.Symbol) -> ForwardSymbol:
        """Build a :class:`Function` from a clingo function symbol.

        Every argument is converted with :meth:`SubSymbol.from_clingo_symbol`. Any other
        clingo symbol type fails the assertion.
        """
        if symbol.type is clingo.SymbolType.Function:
            converted = tuple(map(SubSymbol.from_clingo_symbol, symbol.arguments))
            return Function(symbol.name, converted)
        assert False, "Unknown clingo.SymbolType {}.".format(symbol.type)


class SubSymbol(Symbol):
    """A symbol that may appear as an argument of a function (a term, variable, operation...)."""

    @classmethod
    def from_clingo_symbol(cls, symbol: clingo.Symbol) -> ForwardSymbol:
        """Convert a clingo argument symbol into the matching symbol class.

        Numbers and strings become constant :class:`Term` objects. Functions go through
        :meth:`Symbol.from_clingo_symbol`, and a negative function (``-f``) is wrapped in a
        unary minus. Other symbol types fail the assertion.
        """
        kind = symbol.type
        if kind is clingo.SymbolType.Number:
            return Term(IntegerConstant(symbol.number))
        if kind is clingo.SymbolType.String:
            return Term(StringConstant(symbol.string))
        if kind is clingo.SymbolType.Function:
            function = Symbol.from_clingo_symbol(symbol)
            return UnaryOperation(UnaryOperatorType.Minus, function) if symbol.negative else function
        assert False, "Unknown clingo.SymbolType {}.".format(kind)


In [303]:
@dataclass(frozen=True, order=True)
class Variable(SubSymbol):
    """An ASP variable such as ``F`` or ``_``; it prints as its bare name."""

    name: str

    def __str__(self):
        return str(self.name)


@dataclass(frozen=True, order=True)
class StringConstant:
    """A string constant, printed as a double-quoted ASP string literal.

    ``string`` holds the raw (unescaped) text, e.g. the value of ``clingo.Symbol.string``.
    """

    string: str = ""

    def __str__(self):
        # Escape backslashes first, then newlines and quotes, so that the printed literal
        # denotes exactly ``self.string`` when parsed again as program text.
        escaped = self.string.replace('\\', '\\\\').replace('\n', '\\n').replace('"', '\\"')
        return '"{}"'.format(escaped)


@dataclass(frozen=True, order=True)
class IntegerConstant:
    """An integer constant (default ``0``), printed in decimal notation."""

    number: int = 0

    def __str__(self):
        return '{}'.format(self.number)


@dataclass(frozen=True, order=True)
class Term(SubSymbol):
    """A constant leaf term wrapping an integer or a string constant."""

    # IntegerConstant is frozen (and therefore hashable), so one shared default instance is safe.
    constant: Union[IntegerConstant, StringConstant] = field(default=IntegerConstant())

    def __str__(self):
        return '{}'.format(self.constant)

In [304]:
ForwardAtom = TypeVar('ForwardAtom', bound='Atom')
ForwardFunction = TypeVar('ForwardFunction', bound='Function')
# NOTE(review): the bound names 'UnaryOperator' but the class is called UnaryOperation;
# this TypeVar is not used anywhere visible.
ForwardUnaryOperator = TypeVar('ForwardUnaryOperator', bound='UnaryOperator')


@dataclass(frozen=True, order=True)
class Function(Symbol):
    """A function term ``name(arg1,...,argN)``; a constant is a function without arguments.

    A ``name`` of ``None`` or ``''`` (clingo reports tuples with an empty name) denotes a tuple.
    ``arguments`` is always stored as a tuple, whatever sequence type the caller passed.
    """
    name: Optional[str] = None
    arguments: Sequence[SubSymbol] = ()

    def __post_init__(self):
        # Callers sometimes pass a list (e.g. ``Function('e', [...])``). Storing a list would
        # make this frozen dataclass unhashable, unequal to the same arguments given as a
        # tuple, and raise TypeError when ordered against tuple-built Functions.
        if not isinstance(self.arguments, tuple):
            object.__setattr__(self, 'arguments', tuple(self.arguments))

    @property
    def arity(self):
        """Number of arguments."""
        return len(self.arguments)

    def __str__(self):
        rendered = ','.join(map(str, self.arguments))
        if not self.name:
            # Tuple: "(a)" would merely be a parenthesised term, so a 1-tuple needs "(a,)".
            if len(self.arguments) == 1:
                return "({},)".format(rendered)
            return "({})".format(rendered)
        if not self.arguments:
            return self.name
        return "{}({})".format(self.name, rendered)

    def match(self, name: Optional[str], arity: int = 0) -> bool:
        """True if this function has exactly the given name and arity."""
        return name == self.name and arity == len(self.arguments)

    def match_signature(self, other: ForwardFunction) -> bool:
        """True if ``other`` has the same name and arity as this function."""
        return self.match(other.name, other.arity)


class UnaryOperatorType(IntEnum):
    """Operators that can be applied to a single term."""

    Minus = 1  # printed as a leading '-' by UnaryOperation


@dataclass(frozen=True, order=True)
class UnaryOperation(Symbol):
    """A unary operator applied to a sub-term, e.g. ``-a`` or ``-(X+1)``."""

    operator: UnaryOperatorType
    argument: SubSymbol

    def __str__(self) -> str:
        if self.operator is UnaryOperatorType.Minus:
            # Nested operations are parenthesised so that e.g. -(X+1) does not print as -X+1.
            template = "-({})" if self.argument.is_operation() else "-{}"
            return template.format(self.argument)
        assert False, "Unknown UnaryOperatorType {}.".format(self.operator)


class BinaryOperatorType(IntEnum):
    """Infix operators combining two terms."""

    Plus = 2  # printed as '+' by BinaryOperation


@dataclass(frozen=True, order=True)
class BinaryOperation(SubSymbol):
    """An infix operation ``left <operator> right``, e.g. ``__t+1``."""
    left: SubSymbol
    operator: BinaryOperatorType
    right: SubSymbol

    @staticmethod
    def _operand(operand: SubSymbol) -> str:
        """Render one operand, parenthesising nested operations to keep their grouping."""
        if operand.is_operation():
            return "({})".format(operand)
        return str(operand)

    def __str__(self) -> str:
        if self.operator is BinaryOperatorType.Plus:
            return "{}+{}".format(self._operand(self.left), self._operand(self.right))
        # Mirror UnaryOperation: fail with a clear message instead of implicitly returning
        # None, which would surface as an opaque "__str__ returned non-string" TypeError.
        assert False, "Unknown BinaryOperatorType {}.".format(self.operator)


In [305]:
@dataclass(frozen=True, order=True)
class Atom:
    """An atom wrapping a :class:`Symbol`: a function, possibly under unary minus (e.g. ``-a``)."""
    symbol: Symbol = field(default_factory=Function)

    def __str__(self) -> str:
        return str(self.symbol)

    def is_function(self) -> bool:
        """True if the wrapped symbol is a plain :class:`Function`."""
        return isinstance(self.symbol, Function)

    def is_unary_operation(self) -> bool:
        """True if the wrapped symbol is a :class:`UnaryOperation`."""
        return isinstance(self.symbol, UnaryOperation)

    def get_top_function(self) -> Function:
        """Return the outermost :class:`Function`, peeling off any unary operations around it."""
        current = self.symbol
        while not isinstance(current, Function):
            if isinstance(current, UnaryOperation):
                current = current.argument
            else:
                assert False, "Unknown Type {} for Symbol {}.".format(type(current).__name__, current)
        return current

    def is_isomorph_to(self, other: ForwardAtom) -> bool:
        """Structurally compare this atom with ``other``.

        Both symbol trees are walked in lock-step:
        - two variables must have the same name;
        - two non-variables must have the same type, and then equal operators (operations),
          equal constants (terms) or equal name/arity (functions), with their children
          compared recursively;
        - when exactly one side is a variable, the pair is accepted without further checks.
        """
        assert isinstance(other, Atom), "Atom {} should have type {}, but has type {}.".format(other, Atom.__name__,
                                                                                               type(other).__name__)
        stack_self: List[Symbol] = [self.symbol]
        stack_other: List[Symbol] = [other.symbol]
        while stack_self:
            if len(stack_self) != len(stack_other):
                return False
            current_self = stack_self.pop()
            current_other = stack_other.pop()
            if current_self.is_variable() and current_other.is_variable():
                if current_self.name != current_other.name:
                    return False
            elif not current_self.is_variable() and not current_other.is_variable():
                if type(current_self) != type(current_other):
                    return False
                # Types are equal from here on, so checking one side is enough.
                if current_self.is_unary_operation():
                    if current_self.operator != current_other.operator:
                        return False
                    stack_self.append(current_self.argument)
                    stack_other.append(current_other.argument)
                elif current_self.is_binary_operation():
                    # Previously unhandled: operations like __t+1 and X+2 were accepted
                    # without comparing their operator or operands.
                    if current_self.operator != current_other.operator:
                        return False
                    stack_self.extend((current_self.left, current_self.right))
                    stack_other.extend((current_other.left, current_other.right))
                elif current_self.is_term():
                    if current_self != current_other:
                        return False
                elif current_self.is_function():
                    if not current_self.match_signature(current_other):
                        return False
                    stack_self.extend(current_self.arguments)
                    stack_other.extend(current_other.arguments)
            # Otherwise exactly one side is a variable: no check applies and the pair is accepted.

        return True

    @staticmethod
    def from_clingo_symbol(symbol: clingo.Symbol) -> ForwardAtom:
        """Build an Atom from a clingo function symbol (other symbol types fail the assertion)."""
        assert symbol.type is clingo.SymbolType.Function, "clingo.Symbol {} should have type {}, but has type {}.".format(
            symbol, clingo.SymbolType.Function, symbol.type)
        return Atom(Symbol.from_clingo_symbol(symbol))




In [306]:
class Sign(IntEnum):
    """Sign of a literal."""

    NoSign = 0      # positive literal: Literal prints just the atom
    DefaultNeg = 1  # default negation: Literal prints "not <atom>"

In [307]:
@dataclass(frozen=True, order=True)
class Literal:
    """An atom together with its sign (positive or default-negated)."""
    atom: Atom = field(default_factory=Atom)
    sign: Sign = Sign.NoSign

    def is_neg(self):
        """True if the literal is default-negated."""
        return self.sign is Sign.DefaultNeg

    def is_pos(self):
        """True if the literal is positive."""
        return self.sign is Sign.NoSign

    def __str__(self):
        prefix = "not " if self.sign is Sign.DefaultNeg else ""
        return prefix + str(self.atom)

    def __abs__(self):
        # Positive literal over a deep copy of the atom.
        return Literal(sign=Sign.NoSign, atom=copy.deepcopy(self.atom))

    def __neg__(self):
        # Flip between NoSign (0) and DefaultNeg (1).
        flipped = Sign(self.sign ^ 1)
        return Literal(sign=flipped, atom=self.atom)

    def __invert__(self):
        # ``~lit`` is the same flip as ``-lit``.
        return self.__neg__()

In [308]:
class RuleLike:
    """Common base of program statements: rules and ``#external`` declarations."""

    def is_rule(self) -> bool:
        """True if this statement is a :class:`Rule`."""
        return isinstance(self, Rule)

    def is_external(self) -> bool:
        """True if this statement is an :class:`External` declaration."""
        return isinstance(self, External)

In [309]:

ForwardRule = TypeVar('ForwardRule', bound='Rule')
ForwardExternal = TypeVar('ForwardExternal', bound='External')


@dataclass(frozen=True, order=True)
class Rule(RuleLike):
    """A rule ``head :- body.``; a ``None`` head makes it an integrity constraint."""
    head: Optional[Literal] = None
    body: Sequence[Literal] = ()

    def __str__(self) -> str:
        if self.head is None and not self.body:
            return ":-."
        elif self.head is None:
            return ":- {}.".format(', '.join(map(str, self.body)))
        elif not self.body:
            return "{}.".format(self.head)
        else:
            return "{} :- {}.".format(self.head, ', '.join(map(str, self.body)))

    def is_ground(self) -> bool:
        """True if no :class:`Variable` occurs anywhere in the head or the body.

        Parameter placeholders represented as functions (such as ``__t``) count as ground.
        """
        literals = list(self.body) if self.head is None else [self.head, *self.body]
        pending = [literal.atom.symbol for literal in literals]
        while pending:
            current = pending.pop()
            if current.is_variable():
                return False
            if current.is_function():
                pending.extend(current.arguments)
            elif current.is_unary_operation():
                pending.append(current.argument)
            elif current.is_binary_operation():
                pending.extend((current.left, current.right))
        return True

    def is_fact(self) -> bool:
        """True for a rule with a head and an empty body."""
        return self.head is not None and not self.body

    def as_external(self) -> ForwardExternal:
        """Turn this rule into a ``[false]`` external over its head atom, keeping the body.

        Requires a head: for a constraint, ``self.head.atom`` raises AttributeError.
        """
        return External(self.head.atom, self.body, ExternalType.false)

    def is_constraint(self) -> bool:
        """True if the rule has no head."""
        return self.head is None

    def is_normal_rule(self) -> bool:
        """True for a rule with both a head and a non-empty body."""
        # bool() so the annotated return type holds (previously the body sequence leaked out).
        return self.head is not None and bool(self.body)

    def is_head_relevant(self, atom: Atom) -> bool:
        """True if this rule has a head whose atom is isomorphic to ``atom``."""
        if self.head is None:
            return False
        return self.head.atom.is_isomorph_to(atom)


In [310]:
ForwardExternalType = TypeVar('ForwardExternalType', bound='ExternalType')


class ExternalType(IntEnum):
    """Truth value annotation of an ``#external`` declaration; its name is printed in brackets."""
    false = 0
    true = 1
    free = 2

    @staticmethod
    def from_truth_value(tv: clingo.TruthValue) -> ForwardExternalType:
        """Map a ``clingo.TruthValue`` onto the matching member.

        False_, True_ and Free are supported; any other value fails the assertion.
        """
        if tv is clingo.TruthValue.True_:
            return ExternalType.true
        if tv is clingo.TruthValue.False_:
            return ExternalType.false
        assert tv is clingo.TruthValue.Free
        return ExternalType.free


@dataclass(frozen=True, order=True)
class External(RuleLike):
    """An ``#external atom : body. [type]`` declaration (the ``: body`` part only if non-empty)."""
    atom: Atom = field(default_factory=Atom)
    body: Sequence[Literal] = ()
    external_type: ExternalType = ExternalType.false

    def __str__(self):
        kind = self.external_type.name
        if self.body:
            conditions = ', '.join(map(str, self.body))
            return "#external {} : {}. [{}]".format(self.atom, conditions, kind)
        return "#external {}. [{}]".format(self.atom, kind)


In [311]:
ForwardProgram = TypeVar('ForwardProgram', bound='Program')


@dataclass(frozen=True, order=True)
class Program:
    """A program part ``#program part(parameters).`` together with its statements.

    ``str(program)`` renders the whole part as clingo input text.
    """
    part: str = field(default='base')
    parameters: Sequence[Function] = field(default_factory=tuple)
    rules: Sequence[RuleLike] = field(default_factory=tuple)

    def __str__(self):
        if not self.parameters:
            return "#program {}. {}".format(self.part, ' '.join(map(str, self.rules)))
        return "#program {}({}). {}".format(self.part, ','.join(map(str, self.parameters)),
                                            ' '.join(map(str, self.rules)))

    def facts(self) -> Iterator[RuleLike]:
        """Yield the fact-like statements: headed body-less rules and body-less externals.

        The statements themselves are yielded, not their atoms.
        """
        for rule in self.rules:
            if rule.is_rule():
                assert isinstance(rule, Rule)
                if rule.is_fact():
                    yield rule
            elif rule.is_external():
                assert isinstance(rule, External)
                if not rule.body:
                    yield rule

    def facts_as_external(self) -> ForwardProgram:
        """Return a copy in which every fact becomes a ``[false]`` external.

        All other statements (including existing externals) are kept as they are, and the
        part name and parameters are preserved.
        """
        # Only Rule has is_fact(); checking is_rule() first keeps External statements working.
        new_rules = tuple(
            rule.as_external() if rule.is_rule() and rule.is_fact() else rule
            for rule in self.rules)
        # Keywords are required: the first positional field is ``part``, not ``rules``.
        return Program(part=self.part, parameters=self.parameters, rules=new_rules)

    def support_rules(self, atom: Atom) -> Iterator[RuleLike]:
        """Yield rules whose head, and externals whose atom, are isomorphic to ``atom``."""
        for rule in self.rules:
            if rule.is_rule():
                assert isinstance(rule, Rule)
                if rule.is_head_relevant(atom):
                    yield rule
            elif rule.is_external():
                assert isinstance(rule, External)
                if rule.atom.is_isomorph_to(atom):
                    yield rule

    def ground(self, ctl: Optional[clingo.Control] = None, parts=(('base', ()),)) -> ForwardProgram:
        """Ground this program with clingo and translate the ground result back into a Program.

        Args:
            ctl: control object to use; a fresh ``clingo.Control()`` when ``None``.
            parts: the ``(name, arguments)`` pairs passed to ``ctl.ground``.
        """
        # The module-level ``ground`` performs exactly this translation for a sequence of
        # programs; joining a one-element sequence yields just ``str(self)``. Inside a method
        # the bare name resolves to that module-level function, not to this method.
        return ground((self,), ctl, parts)

    def evaluate_forwards(self, ctl: Optional[clingo.Control] = None, parts=(('base', ()),)) -> Iterator[
        Sequence[Atom]]:
        """Lazily yield every answer set of this program as a sorted tuple of shown atoms."""
        # Delegates to the module-level generator of the same name (see ``ground`` above).
        yield from evaluate_forwards((self,), ctl=ctl, parts=parts)

    def cautious_consequences(self, ctl: Optional[clingo.Control] = None,
                              parts=(('base', ()),)) -> Sequence[Atom]:
        """Return the shown atoms that are true in every answer set (cautious reasoning).

        Args:
            ctl: control object to use; a fresh ``clingo.Control()`` when ``None``.
            parts: the ``(name, arguments)`` pairs passed to ``ctl.ground``.

        Raises:
            ValueError: if the program has no answer set.
        """
        if ctl is None:
            ctl = clingo.Control()
        ctl.configuration.solve.models = 0
        ctl.configuration.solve.enum_mode = 'cautious'
        ctl.add('base', [], str(self))
        ctl.ground(parts)
        symbols = None
        with ctl.solve(yield_=True) as solve_handle:
            for model in solve_handle:
                # A clingo Model must not be used once the next model is requested, so copy
                # its symbols now; the last model's symbols are the result.
                symbols = model.symbols(shown=True)
        if symbols is None:
            raise ValueError("Program has no answer set; cautious consequences are undefined.")
        return tuple(Atom.from_clingo_symbol(symbol) for symbol in sorted(symbols))


def ground(programs: Sequence[Program], ctl: Optional[clingo.Control] = None, parts=(('base', ()),)):
    """Ground ``programs`` together with clingo and translate the result into one ``Program``.

    The programs' text is concatenated and added to the ``base`` part. A clingox observer
    records the ground program, which is converted back into facts, rules and externals.

    Args:
        programs: programs (or anything whose ``str`` is clingo input) to ground.
        ctl: control object to use; a fresh ``clingo.Control()`` when ``None``.
        parts: the ``(name, arguments)`` pairs passed to ``ctl.ground``.

    Returns:
        A ``base`` Program whose ``rules`` tuple holds the ground statements.
    """
    if ctl is None:
        ctl = clingo.Control()
    prg = clingox.program.Program()
    ctl.register_observer(clingox.program.ProgramObserver(prg))
    ctl.add('base', [], '\n'.join(map(str, programs)))
    ctl.ground(parts)

    def to_atom(literal: int) -> Atom:
        # Observer literals are signed atom ids; the sign is handled by the caller.
        # NOTE(review): raises KeyError for atoms without a symbol in output_atoms.
        return Atom.from_clingo_symbol(prg.output_atoms[abs(literal)])

    new_rules = [Rule(Literal(Atom.from_clingo_symbol(fact.symbol))) for fact in prg.facts]
    for rule in prg.rules:
        # NOTE(review): the observer's choice flag is not consulted, so choice rules come back
        # as normal rules; confirm this is intended.
        if len(rule.head) == 0:
            head = None
        elif len(rule.head) == 1:
            if rule.head[0] not in prg.output_atoms:
                # Skip rules defining auxiliary (unshown) atoms.
                continue
            head = Literal(to_atom(rule.head[0]))
        else:
            assert False, "Unexpected length of rule head"
        body = tuple(Literal(sign=Sign(literal < 0), atom=to_atom(literal)) for literal in rule.body)
        new_rules.append(Rule(head, body))
    new_rules.extend(
        External(to_atom(external.atom), (), ExternalType.from_truth_value(external.value))
        for external in prg.externals)
    # A tuple keeps the frozen Program genuinely immutable (and hashable).
    return Program(rules=tuple(new_rules))


def evaluate_forwards(programs: Sequence[Program], ctl: Optional[clingo.Control] = None, parts=(('base', ()),)) -> Iterator[Sequence[Atom]]:
    """Ground and solve ``programs`` together, lazily yielding every answer set.

    The programs' text is concatenated and added to the ``base`` part, so any ``#program``
    header inside that text decides which part the following statements belong to. Each
    answer set is yielded as a tuple of :class:`Atom` built from the sorted shown symbols.
    """
    control = clingo.Control() if ctl is None else ctl
    control.configuration.solve.models = 0  # 0 = enumerate all models
    control.add('base', [], '\n'.join(str(program) for program in programs))
    control.ground(parts)
    with control.solve(yield_=True) as handle:
        for model in handle:
            yield tuple(map(Atom.from_clingo_symbol, sorted(model.symbols(shown=True))))


In [312]:
# Building blocks of the example action description: events e(1)..e(3), fluents a..f and
# their negated forms -a..-f. Function arguments are given as tuples (as everywhere else in
# this notebook) so the frozen Function dataclass stays hashable and compares equal to
# Functions built from tuples.
e1 = Function('e', (Term(IntegerConstant(1)),))
e2 = Function('e', (Term(IntegerConstant(2)),))
e3 = Function('e', (Term(IntegerConstant(3)),))
a = Function('a')
na = UnaryOperation(UnaryOperatorType.Minus, Function('a'))
b = Function('b')
nb = UnaryOperation(UnaryOperatorType.Minus, Function('b'))
c = Function('c')
nc = UnaryOperation(UnaryOperatorType.Minus, Function('c'))
d = Function('d')
nd = UnaryOperation(UnaryOperatorType.Minus, Function('d'))
e = Function('e')
ne = UnaryOperation(UnaryOperatorType.Minus, Function('e'))
f = Function('f')
nf = UnaryOperation(UnaryOperatorType.Minus, Function('f'))

# The raw action description as facts over impossible/2, causes/3 and if/2.
r1 = Rule(Literal(Atom(Function('impossible', (e1, a)))))
r2 = Rule(Literal(Atom(Function('causes', (e1, e, ne)))))
r3 = Rule(Literal(Atom(Function('causes', (e2, d, nd)))))
r4 = Rule(Literal(Atom(Function('causes', (e3, a, na)))))
r5 = Rule(Literal(Atom(Function('causes', (e3, c, nc)))))
r6 = Rule(Literal(Atom(Function('impossible', (e3, ne)))))
r7 = Rule(Literal(Atom(Function('impossible', (e3, nf)))))
r8 = Rule(Literal(Atom(Function('if', (b, c)))))

rules = (r1, r2, r3, r4, r5, r6, r7, r8)
p_raw = Program(rules=rules)
print(p_raw)

#program base. impossible(e(1),a). causes(e(1),e,-e). causes(e(2),d,-d). causes(e(3),a,-a). causes(e(3),c,-c). impossible(e(3),-e). impossible(e(3),-f). if(b,c).


In [313]:
# ``__t`` stands for the time-step parameter of the parameterised program parts below;
# ``__t1`` is its successor ``__t+1``.
__t = Function(name='__t')
__t1 = BinaryOperation(left=__t, operator=BinaryOperatorType.Plus, right=Term(constant=IntegerConstant(number=1)))

In [314]:
# Translate the raw action description into time-indexed rules of the parameterised
# part ``action_language(__t)``:
#   impossible(E, C1, ...) ->  :- occ_at(E,__t), obs_at(C1,__t), ...
#   if(F, C1, ...)         ->  obs_at(F,__t) :- obs_at(C1,__t), ...
#   causes(E, F, C1, ...)  ->  obs_at(F,__t+1) :- occ_at(E,__t), obs_at(C1,__t), ...
# Any other statement is copied unchanged.
new_rules = []
for rule in p_raw.rules:
    if rule.head.atom.symbol.name == 'impossible':
        impossible_if_stm: Function = rule.head.atom.symbol
        new_rules.append(
            Rule(None, (Literal(Atom(Function('occ_at', (impossible_if_stm.arguments[0], __t)))),
                        *(Literal(Atom(Function('obs_at', (argument, __t)))) for argument in
                          impossible_if_stm.arguments[1:]))))
    elif rule.head.atom.symbol.name == 'if':
        if_stm: Function = rule.head.atom.symbol
        head = Literal(Atom(Function('obs_at', (if_stm.arguments[0], __t))))
        body = tuple(
            Literal(Atom(Function('obs_at', (argument, __t)))) for argument in if_stm.arguments[1:])
        new_rules.append(Rule(head, body))
    elif rule.head.atom.symbol.name == 'causes':
        causes_stm: Function = rule.head.atom.symbol
        head = Literal(Atom(Function('obs_at', (
            causes_stm.arguments[1], __t1))))
        body = (Literal(Atom(Function('occ_at', (causes_stm.arguments[0], __t)))),
                *(Literal(Atom(Function('obs_at', (argument, __t)))) for argument in
                  causes_stm.arguments[2:]))
        new_rules.append(Rule(head, body))
    else:
        new_rules.append(rule)

# Inertia: F keeps holding at __t+1 unless -F is derived there, and symmetrically for -F.
new_rules.append(
    Rule(
        Literal(Atom(Function('obs_at', (
            Variable('F'), __t1)))),
        (
            Literal(Atom(Function('obs_at', (Variable('F'), __t)))),
            Literal(sign=Sign.DefaultNeg, atom=Atom(
                Function('obs_at', (UnaryOperation(UnaryOperatorType.Minus, Variable('F')), __t1))))),
    )
)

new_rules.append(
    Rule(
        Literal(Atom(Function('obs_at', (UnaryOperation(UnaryOperatorType.Minus, Variable('F')),
                                         __t1)))),
        (
            Literal(sign=Sign.DefaultNeg, atom=Atom(Function('obs_at', (Variable('F'), __t1)))),
            Literal(atom=Atom(Function('obs_at', (UnaryOperation(UnaryOperatorType.Minus, Variable('F')), __t)))),
        )
    )
)

# Freeze the rules: Program is a frozen dataclass, and storing the list itself would leave it
# unhashable and aliased to the ``new_rules`` name that later cells reuse.
p = Program(part='action_language', parameters=(__t,), rules=tuple(new_rules))
print('\n'.join(map(str, p.rules)))

:- occ_at(e(1),__t), obs_at(a,__t).
obs_at(e,__t+1) :- occ_at(e(1),__t), obs_at(-e,__t).
obs_at(d,__t+1) :- occ_at(e(2),__t), obs_at(-d,__t).
obs_at(a,__t+1) :- occ_at(e(3),__t), obs_at(-a,__t).
obs_at(c,__t+1) :- occ_at(e(3),__t), obs_at(-c,__t).
:- occ_at(e(3),__t), obs_at(-e,__t).
:- occ_at(e(3),__t), obs_at(-f,__t).
obs_at(b,__t) :- obs_at(c,__t).
obs_at(F,__t+1) :- obs_at(F,__t), not obs_at(-F,__t+1).
obs_at(-F,__t+1) :- not obs_at(F,__t+1), obs_at(-F,__t).


In [315]:
# Scenario facts: event e(k) occurs at step k (k = 1..3). At step 1, fluents a..e are
# observed in negated form and f is observed positively.
a1 = Rule(head=Literal(atom=Atom(symbol=Function(name='occ_at', arguments=(e1, Term(IntegerConstant(1)))))))
a2 = Rule(head=Literal(atom=Atom(symbol=Function(name='occ_at', arguments=(e2, Term(IntegerConstant(2)))))))
a3 = Rule(head=Literal(atom=Atom(symbol=Function(name='occ_at', arguments=(e3, Term(IntegerConstant(3)))))))
s1 = Rule(head=Literal(atom=Atom(symbol=Function(name='obs_at', arguments=(na, Term(IntegerConstant(1)))))))
s2 = Rule(head=Literal(atom=Atom(symbol=Function(name='obs_at', arguments=(nb, Term(IntegerConstant(1)))))))
s3 = Rule(head=Literal(atom=Atom(symbol=Function(name='obs_at', arguments=(nc, Term(IntegerConstant(1)))))))
s4 = Rule(head=Literal(atom=Atom(symbol=Function(name='obs_at', arguments=(nd, Term(IntegerConstant(1)))))))
s5 = Rule(head=Literal(atom=Atom(symbol=Function(name='obs_at', arguments=(ne, Term(IntegerConstant(1)))))))
s6 = Rule(head=Literal(atom=Atom(symbol=Function(name='obs_at', arguments=(f, Term(IntegerConstant(1)))))))

scenario_path = Program(part='base', rules=(a1, a2, a3, s1, s2, s3, s4, s5, s6))
print(scenario_path)

#program base. occ_at(e(1),1). occ_at(e(2),2). occ_at(e(3),3). obs_at(-a,1). obs_at(-b,1). obs_at(-c,1). obs_at(-d,1). obs_at(-e,1). obs_at(f,1).


In [316]:
# Control object (no command-line arguments) used by the forward evaluation in the next cell.
ctl = clingo.Control([])

In [317]:



# Ground the base part plus the transition rules for steps 1..4, then print the first answer
# set ordered by time step (the last argument of every atom).
states = list(evaluate_forwards(
    (p, scenario_path),
    ctl=ctl,
    parts=(('base', ()),) + tuple(('action_language', (clingo.Number(step),)) for step in range(1, 5)),
))
print("SAT", len(states))
print('\n'.join(str(atom) for atom in sorted(states[0], key=lambda atom: atom.symbol.arguments[-1])))

SAT 1
obs_at(f,1)
obs_at(-a,1)
obs_at(-b,1)
obs_at(-c,1)
obs_at(-d,1)
obs_at(-e,1)
occ_at(e(1),1)
obs_at(e,2)
obs_at(f,2)
obs_at(-a,2)
obs_at(-b,2)
obs_at(-c,2)
obs_at(-d,2)
occ_at(e(2),2)
obs_at(d,3)
obs_at(e,3)
obs_at(f,3)
obs_at(-a,3)
obs_at(-b,3)
obs_at(-c,3)
occ_at(e(3),3)
obs_at(a,4)
obs_at(b,4)
obs_at(c,4)
obs_at(d,4)
obs_at(e,4)
obs_at(f,4)
obs_at(a,5)
obs_at(b,5)
obs_at(c,5)
obs_at(d,5)
obs_at(e,5)
obs_at(f,5)


In [318]:
# Outcome tracking for the parameterised part ``transition_event(__t)``. The rules built below
# print as:
#   missed_outcome(__t) :- outcome(F), obs_at(-F,__t).
#   missed_outcome(__t) :- outcome(-F), obs_at(F,__t).
#   reached_outcome(__t) :- not missed_outcome(__t).
#   transition_event(E,__t) :- occ_at(E,__t), reached_outcome(__t+1), missed_outcome(__t).
transition_event_rules = [
    Rule(
        Literal(Atom(Function('missed_outcome', (__t,)))),
        (
            Literal(Atom(Function('outcome', (Variable('F'),)))),
            Literal(Atom(Function('obs_at', (UnaryOperation(UnaryOperatorType.Minus, Variable('F')), __t)))),
        )
    ),
    Rule(
        Literal(Atom(Function('missed_outcome', (__t,)))),
        (
            Literal(Atom(Function('outcome', (UnaryOperation(UnaryOperatorType.Minus, Variable('F')),)))),
            Literal(Atom(Function('obs_at', (Variable('F'), __t)))),
        )
    ),
    Rule(
        Literal(Atom(Function('reached_outcome', (__t,)))),
        (
            Literal(sign=Sign.DefaultNeg, atom=Atom(Function('missed_outcome', (__t,)))),
        )
    ),
    Rule(
        Literal(Atom(Function('transition_event', (Variable('E'), __t,)))),
        (
            Literal(Atom(Function('occ_at', (Variable('E'), __t)))),
            Literal(Atom(Function('reached_outcome', (__t1,)))),
            Literal(Atom(Function('missed_outcome', (__t,)))),
        )
    ),
]

p_transition_event = Program(part='transition_event', parameters=(__t,), rules=tuple(transition_event_rules))
print(p_transition_event)

#program transition_event(__t). missed_outcome(__t) :- outcome(F), obs_at(-F,__t). missed_outcome(__t) :- outcome(-F), obs_at(F,__t). reached_outcome(__t) :- not missed_outcome(__t). transition_event(E,__t) :- occ_at(E,__t), reached_outcome(__t+1), missed_outcome(__t).


In [319]:
# Effect bookkeeping for the parameterised part ``effects(__t)``:
#   direct_effect(E,F,__t)       one rule per causes(E, F, C...) fact, guarded by obs_at(C,__t)
#   direct_effect_in_outcome     direct effects on fluents that are part of the outcome
#   inertial(F,__t+1)            F observed at both __t and __t+1
#   indirect_effect(F,__t)       F not observed at __t but observed at __t+1, with no direct
#                                effect on F and no inertial(F,__t)
#   indirect_effect_in_outcome   indirect effects on fluents that are part of the outcome
new_rules = []
for rule in p_raw.rules:
    if rule.head.atom.symbol.name == 'causes':
        causes_stm = rule.head.atom.symbol
        event = causes_stm.arguments[0]
        fluent = causes_stm.arguments[1]
        conditions = causes_stm.arguments[2:]
        head = Literal(Atom(Function('direct_effect', (event, fluent, __t))))
        body = tuple(Literal(Atom(Function('obs_at', (condition, __t)))) for condition in conditions)
        new_rules.append(Rule(head, body))

new_rules.append(
    Rule(
        Literal(Atom(Function('direct_effect_in_outcome', (Variable('E'), Variable('F'), __t)))),
        (
            Literal(Atom(Function('direct_effect', (Variable('E'), Variable('F'), __t)))),
            Literal(Atom(Function('outcome', (Variable('F'),))))
        )
    )
)
new_rules.append(
    Rule(
        Literal(Atom(Function('inertial', (Variable('F'), __t1)))),
        (
            Literal(Atom(Function('obs_at', (Variable('F'), __t)))),
            Literal(Atom(Function('obs_at', (Variable('F'), __t1)))),
        )
    )
)
new_rules.append(
    Rule(
        Literal(Atom(Function('indirect_effect', (Variable('F'), __t)))),
        (
            Literal(sign=Sign.DefaultNeg, atom=Atom(Function('obs_at', (Variable('F'), __t)))),
            Literal(Atom(Function('obs_at', (Variable('F'), __t1)))),
            Literal(sign=Sign.DefaultNeg, atom=Atom(Function('direct_effect', (Variable('_'), Variable('F'), __t)))),
            Literal(sign=Sign.DefaultNeg, atom=Atom(Function('inertial', (Variable('F'), __t))))
        )
    )
)
new_rules.append(
    Rule(
        Literal(Atom(Function('indirect_effect_in_outcome', (Variable('F'), __t)))),
        (
            Literal(Atom(Function('indirect_effect', (Variable('F'), __t)))),
            Literal(Atom(Function('outcome', (Variable('F'),))))
        )
    )
)

# Reuse the shared ``__t`` placeholder, and freeze the rules so this Program does not alias
# the ``new_rules`` list.
p_effects = Program(part="effects", parameters=(__t,), rules=tuple(new_rules))
print(p_effects)

#program effects(__t). direct_effect(e(1),e,__t) :- obs_at(-e,__t). direct_effect(e(2),d,__t) :- obs_at(-d,__t). direct_effect(e(3),a,__t) :- obs_at(-a,__t). direct_effect(e(3),c,__t) :- obs_at(-c,__t). direct_effect_in_outcome(E,F,__t) :- direct_effect(E,F,__t), outcome(F). inertial(F,__t+1) :- obs_at(F,__t), obs_at(F,__t+1). indirect_effect(F,__t) :- not obs_at(F,__t), obs_at(F,__t+1), not direct_effect(_,F,__t), not inertial(F,__t). indirect_effect_in_outcome(F,__t) :- indirect_effect(F,__t), outcome(F).


In [320]:
# First causal explanations: direct or indirect effects on outcome fluents at a step whose
# occurring event is a transition event. Only these atoms are shown.
first_causal_explanation = """
#program first_causal_explanation(__t).

first_causal_explanation_direct(E, F,__t) :- direct_effect_in_outcome(E,F,__t), transition_event(E, __t).
first_causal_explanation_indirect(F,__t) :- indirect_effect_in_outcome(F,__t), transition_event(E, __t).

#show first_causal_explanation_direct/3.
#show first_causal_explanation_indirect/2.

"""

# The explicit ``#program base.`` header matters: all programs are concatenated into one text,
# and without it these facts would fall under the preceding ``#program action_language(__t).``
# header of ``p`` and would only exist once that part is grounded.
outcome = """
#program base.

outcome(a).
outcome(b).
outcome(c).
outcome(d).
outcome(e).
outcome(f).

"""

# One (part, (step,)) entry per instantiation: steps 2..4 for the explanation parts and
# steps 1..4 for the transition rules.
explanation_parts = (('base', ()),) + tuple(
    (part, (clingo.Number(step),))
    for part, steps in (('transition_event', range(2, 5)),
                        ('effects', range(2, 5)),
                        ('first_causal_explanation', range(2, 5)),
                        ('action_language', range(1, 5)))
    for step in steps)

# Use a fresh Control: ``ctl`` already had the action and scenario programs added and grounded
# by an earlier cell, so reusing it would make this result depend on whether (and how often)
# that cell ran.
explanations = list(evaluate_forwards((p,
                                       outcome,
                                       scenario_path,
                                       p_transition_event,
                                       p_effects,
                                       first_causal_explanation
                                       ), ctl=clingo.Control(), parts=explanation_parts))
print('SAT', len(explanations))
print('\n'.join(map(str, explanations[0])))

SAT 1
first_causal_explanation_indirect(b,3)
first_causal_explanation_direct(e(3),a,3)
first_causal_explanation_direct(e(3),c,3)
