In [32]:
import abc
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Callable, Tuple, Sized
from typing import Optional, Iterator, Sequence, TypeVar, Set, Mapping, Union, FrozenSet, Collection, Iterable

import clingo
import clingraph
import more_itertools
from clingraph import Factbase, compute_graphs


In [33]:

def powerset(iterable):
    """Lazily yield every subset of ``iterable`` as a fresh mutable ``set``.

    Subsets are produced in the order ``more_itertools.powerset`` emits them.
    """
    yield from map(set, more_itertools.powerset(iterable))


In [34]:

def frozen_powerset(iterable):
    """Lazily yield every subset of ``iterable`` as a hashable ``frozenset``.

    Subsets are produced in the order ``more_itertools.powerset`` emits them.
    """
    yield from map(frozenset, more_itertools.powerset(iterable))

In [35]:


def asp_solve(programs,
              ctl: Optional[clingo.Control] = None,
              parts=(('base', ()),),
              context=None,
              report=False,
              report_models=True,
              report_result=True,
              symbol_sep=' ',
              model_sep='\n'
              ) -> Iterator[Sequence[clingo.Symbol]]:
    """Ground and solve ASP programs, yielding the shown symbols of each model.

    :param programs: a single program string, or an iterable whose ``str`` and
        ``ASPProgram`` items are added to the ``base`` part; items of any other
        type are skipped. A falsy value adds nothing.
    :param ctl: control object to use; when ``None`` a new one is created with a
        silent logger and ``solve.models = 0``.
    :param parts: program parts handed to ``ctl.ground``.
    :param context: context object handed to ``ctl.ground``.
    :param report: master switch for printing.
    :param report_models: with ``report``, print every model as it is found.
    :param report_result: with ``report``, print the solve result once the
        model iteration has finished.
    :param symbol_sep: separator between printed symbols (and the braces).
    :param model_sep: text printed after each model.
    :return: generator of sorted lists of the shown symbols, one per model.
    """
    if ctl is None:
        ctl = clingo.Control(logger=lambda *args, **kwargs: None)
        ctl.configuration.solve.models = 0

    if programs:
        # Treat a lone string like a one-element collection of programs.
        sources = (programs,) if isinstance(programs, str) else programs
        for source in sources:
            if isinstance(source, ASPProgram):
                source = str(source)
            elif not isinstance(source, str):
                continue
            ctl.add('base', [], source)

    ctl.ground(parts, context=context)

    print_models = report and report_models
    print_result = report and report_result
    with ctl.solve(yield_=True) as solve_handle:
        count = 0
        for model in solve_handle:
            symbols = sorted(model.symbols(shown=True))
            if print_models:
                print("Answer {}:".format(model.number), end=' ')
                rendered = symbol_sep.join(map(str, symbols))
                print("{", rendered, "}", sep=symbol_sep, end=model_sep)
            count += 1
            yield symbols
        if print_result:
            outcome = solve_handle.get()
            summary = ''
            if outcome.satisfiable:
                # A trailing '+' marks that the search space was not exhausted.
                summary = " {}{}".format(count, '' if outcome.exhausted else '+')
            print(outcome, end='')
            print(summary)


def draw_graph(programs,
               ctl: Optional[clingo.Control] = None,
               parts=(('base', ()),)):
    """Solve the given programs and build clingraph graphs from the first model.

    :param programs: a single program string, or an iterable of programs. Each
        item of an iterable is converted with ``str`` (so ``ASPProgram``
        instances work as well as plain strings) and the items are joined by
        newlines.
    :param ctl: control object to use; when ``None`` a new one is created with a
        silent logger and ``solve.models = 0``.
    :param parts: program parts handed to ``ctl.ground``.
    :return: the result of ``compute_graphs`` on a ``Factbase`` holding the
        first model (or no model, if the program has none).
    """
    # A bare string must not be iterated: joining it would split it into
    # single characters separated by newlines.
    if isinstance(programs, str):
        source = programs
    else:
        source = '\n'.join(map(str, programs))

    fb = Factbase()
    if ctl is None:
        ctl = clingo.Control(logger=lambda *args, **kwargs: None)
        ctl.configuration.solve.models = 0
    ctl.add('base', [], source)
    ctl.ground(parts, clingraph.clingo_utils.ClingraphContext())
    with ctl.solve(yield_=True) as solve_handle:
        for model in solve_handle:
            # Only the first model is used for drawing.
            fb.add_model(model)
            break
    return compute_graphs(fb)


In [36]:

# Type variables bound by name to classes that are defined later, for use in annotations.
# NOTE(review): no 'ASPVariable' class is visible in this part of the notebook — confirm it exists elsewhere.
ForwardASPSymbol = TypeVar('ForwardASPSymbol', bound='ASPSymbol')
ForwardASPVariable = TypeVar('ForwardASPVariable', bound='ASPVariable')


class ASPSymbol(abc.ABC):
    """Base class of the symbol wrappers defined in this notebook."""

    @staticmethod
    def from_clingo_symbol(symbol: clingo.Symbol) -> ForwardASPSymbol:
        """Convert a ``clingo.Symbol`` into the matching wrapper object.

        Numbers and strings become an ``ASPTerm``; functions become an
        ``ASPFunction`` whose arguments are converted recursively and whose name
        is prefixed with ``-`` when the clingo symbol is negative. Any other
        symbol type trips an assertion.
        """
        kind = symbol.type
        if kind is clingo.SymbolType.Number:
            return ASPTerm(ASPIntegerConstant(symbol.number))
        if kind is clingo.SymbolType.String:
            return ASPTerm(ASPStringConstant(symbol.string))
        if kind is clingo.SymbolType.Function:
            converted = tuple(map(ASPFunction.from_clingo_symbol, symbol.arguments))
            prefix = "-" if symbol.negative else ""
            return ASPFunction("{}{}".format(prefix, symbol.name), converted)
        assert False, "Unhandled clingo.Symbol {} with type {}.".format(symbol, symbol.type.name)


@dataclass(order=True, frozen=True)
class ASPStringConstant:
    """String constant of an ASP term; renders as a double-quoted literal."""
    string: str = field(default="")

    def __str__(self):
        # Escape the characters that would otherwise end or corrupt the quoted
        # literal. The backslash has to be handled first so the escapes added
        # for newline and quote are not escaped a second time.
        escaped = (self.string
                   .replace('\\', '\\\\')
                   .replace('\n', '\\n')
                   .replace('"', '\\"'))
        return '"{}"'.format(escaped)


@dataclass(frozen=True, order=True)
class ASPIntegerConstant:
    """Integer constant of an ASP term; renders as ``str(number)``."""
    number: int = 0

    def __str__(self):
        return str(self.number)


@dataclass(frozen=True, order=True)
class ASPTerm(ASPSymbol):
    """Symbol wrapping a single string or integer constant (integer 0 by default)."""
    constant: Union[ASPStringConstant, ASPIntegerConstant] = field(default_factory=ASPIntegerConstant)

    @staticmethod
    def zero():
        """Return the term for the integer 0."""
        return ASPTerm(constant=ASPIntegerConstant(number=0))

    @staticmethod
    def one():
        """Return the term for the integer 1."""
        return ASPTerm(constant=ASPIntegerConstant(number=1))

    def __str__(self):
        return str(self.constant)


# Type variable bound by name to ASPTopLevelSymbol, usable before that class is defined.
ForwardASPTopLevelSymbol = TypeVar('ForwardASPTopLevelSymbol', bound='ASPTopLevelSymbol')


class ASPTopLevelSymbol(ASPSymbol, abc.ABC):
    """Symbol that exposes a function name and an argument sequence."""

    @property
    @abc.abstractmethod
    def function_name(self) -> str:
        """Name part of the symbol; supplied by subclasses."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def function_arguments(self) -> Sequence[ASPSymbol]:
        """Argument sequence of the symbol; supplied by subclasses."""
        raise NotImplementedError

    @property
    def arity(self) -> int:
        """Number of arguments."""
        return len(self.function_arguments)

    @property
    def signature(self) -> str:
        """The text ``name/arity.`` for this symbol."""
        return "{}/{}.".format(self.function_name, self.arity)

    def match(self, other: ForwardASPTopLevelSymbol) -> bool:
        """Tell whether ``other`` has the same function name and argument count."""
        if self.function_name != other.function_name:
            return False
        return len(self.function_arguments) == len(other.function_arguments)

    def __neg__(self):
        raise NotImplementedError


def asp_evaluate(*programs, report=False) -> Iterator[FrozenSet[ASPTopLevelSymbol]]:
    """Solve ``programs`` with ``asp_solve`` and yield each answer set as a frozenset.

    Every clingo symbol of an answer set is converted with
    ``ASPFunction.from_clingo_symbol``; ``report`` is forwarded to ``asp_solve``.
    """
    convert = ASPFunction.from_clingo_symbol
    for answer_set in asp_solve(programs=programs, report=report):
        yield frozenset(map(convert, answer_set))


@dataclass(frozen=True, order=True)
class ASPFunction(ASPTopLevelSymbol):
    """Function symbol: a name plus arguments.

    A ``None`` name renders as a bare parenthesised argument list; a name
    without arguments renders as the name alone.
    """
    name: Optional[str] = None
    arguments: Sequence[ASPSymbol] = field(default_factory=tuple)

    @property
    def function_name(self) -> str:
        return self.name

    @property
    def function_arguments(self) -> Sequence[ASPSymbol]:
        return self.arguments

    @property
    def arity(self) -> int:
        return len(self.arguments)

    def __str__(self):
        rendered_arguments = ','.join(map(str, self.arguments))
        if self.name is None:
            return '({})'.format(rendered_arguments)
        if not self.arguments:
            return self.name
        return '{}({})'.format(self.name, rendered_arguments)

    def __neg__(self):
        """Toggle the leading ``-`` of the name, keeping the arguments."""
        if self.name.startswith('-'):
            toggled = self.name[1:]
        else:
            toggled = '-{}'.format(self.name)
        return ASPFunction(toggled, self.arguments)


# Type variable bound by name to ASPAtom, usable before that class is defined.
ForwardASPAtom = TypeVar('ForwardASPAtom', bound='ASPAtom')


@dataclass(frozen=True, order=True)
class ASPAtom:
    """Atom wrapping a top-level symbol (an empty ``ASPFunction`` by default)."""
    symbol: ASPTopLevelSymbol = field(default_factory=ASPFunction)

    @property
    def signature(self) -> str:
        """Signature of the wrapped symbol."""
        return self.symbol.signature

    def match(self, other: ForwardASPAtom) -> bool:
        """Tell whether both atoms share function name and argument count."""
        mine, theirs = self.symbol, other.symbol
        if mine.function_name != theirs.function_name:
            return False
        return len(mine.function_arguments) == len(theirs.function_arguments)

    def __neg__(self):
        """Return the atom over the negated symbol."""
        return ASPAtom(-self.symbol)

    def __str__(self):
        return str(self.symbol)


@dataclass(frozen=True, order=True)
class ASPClauseElement(abc.ABC):
    """Abstract element of a rule; subclasses supply signature and polarity."""

    @property
    @abc.abstractmethod
    def signature(self) -> str:
        """Signature text of the element."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_pos(self):
        """Whether the element counts as positive."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_neg(self):
        """Whether the element counts as negative."""
        raise NotImplementedError

    def __abs__(self):
        raise NotImplementedError

    def __neg__(self):
        raise NotImplementedError


class ASPHeadClauseElement(ASPClauseElement):
    """Marker subclass of ``ASPClauseElement``; adds no members of its own."""


class ASPLiteral(ASPHeadClauseElement):
    """Abstract literal; re-declares the abstract members with ``bool`` polarity."""

    @property
    @abc.abstractmethod
    def signature(self) -> str:
        """Signature text of the literal."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_pos(self) -> bool:
        """Whether the literal is positive."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_neg(self) -> bool:
        """Whether the literal is negative."""
        raise NotImplementedError


class ASPSign(IntEnum):
    """Sign of a basic literal: none, or default negation (``not``)."""
    NoSign = 0
    Negation = 1

    def __str__(self):
        if self is ASPSign.NoSign:
            return ''
        elif self is ASPSign.Negation:
            return 'not'
        else:
            assert False, 'Unknown IntEnum {} = {}.'.format(self.name, self.value)

    def __format__(self, format_spec):
        # On newer Python versions (3.11+) format() on an IntEnum member uses
        # the integer value and ignores a custom __str__, so
        # "{}".format(sign) would print 0/1. Route formatting through
        # __str__ so it is the same on every version.
        return format(str(self), format_spec)


@dataclass(order=True, frozen=True)
class ASPBasicLiteral(ASPLiteral):
    """An atom with an optional default-negation sign.

    ``-literal`` flips the sign (adds or removes ``not``), ``~literal`` negates
    the atom itself, and ``abs(literal)`` drops the sign.
    """
    sign: ASPSign = ASPSign.NoSign
    atom: ASPAtom = field(default_factory=ASPAtom)

    @property
    def is_pos(self) -> bool:
        return self.sign is ASPSign.NoSign

    @property
    def is_neg(self) -> bool:
        return self.sign is ASPSign.Negation

    @property
    def signature(self) -> str:
        return self.atom.signature

    def __str__(self):
        if self.sign is ASPSign.NoSign:
            return "{}".format(self.atom)
        # Convert the sign with str() explicitly: passing the IntEnum member
        # straight to str.format can render its integer value instead of
        # 'not' on newer Python versions (3.11+).
        return "{} {}".format(str(self.sign), self.atom)

    def __neg__(self):
        """Return the literal with the default-negation sign toggled."""
        return ASPBasicLiteral(ASPSign((self.sign ^ 1) % 2), self.atom)

    def __invert__(self):
        """Return the literal over the negated atom, keeping the sign."""
        return ASPBasicLiteral(sign=self.sign, atom=-self.atom)

    def __abs__(self):
        """Return the literal without a sign."""
        return ASPBasicLiteral(ASPSign.NoSign, self.atom)

    def as_classical_atom(self):
        """Return a ``ClassicalAtom`` named after the rendered atom symbol."""
        return ClassicalAtom(str(self.atom.symbol))

    @staticmethod
    def make_literal(name: Optional[str], *arguments: Union[str, int, ASPFunction]):
        """Build an unsigned literal ``name(arguments...)``.

        String arguments become argument-less ``ASPFunction`` symbols, integers
        become integer terms, and anything else is used unchanged.
        """
        function_arguments = []
        for arg_ in arguments:
            if isinstance(arg_, str):
                arg = ASPFunction(arg_)
            elif isinstance(arg_, int):
                arg = ASPTerm(ASPIntegerConstant(arg_))
            else:
                arg = arg_
            function_arguments.append(arg)
        return ASPBasicLiteral(atom=ASPAtom(ASPFunction(name=name, arguments=tuple(function_arguments))))


@dataclass(frozen=True, order=True)
class ASPConditionalLiteral(ASPLiteral):
    """A basic literal with a condition list, rendered as ``literal : c1,c2``.

    Signature and polarity are taken from the wrapped literal.
    """
    literal: ASPBasicLiteral = field(default_factory=ASPBasicLiteral)
    conditions: Sequence[ASPClauseElement] = field(default_factory=tuple)

    @property
    def signature(self) -> str:
        return self.literal.signature

    @property
    def is_pos(self) -> bool:
        return self.literal.is_pos

    @property
    def is_neg(self) -> bool:
        return self.literal.is_neg

    def __str__(self):
        rendered_conditions = ','.join(str(condition) for condition in self.conditions)
        return "{} : {}".format(self.literal, rendered_conditions)

    def __abs__(self):
        """Apply ``abs`` to the wrapped literal, keeping the conditions."""
        return ASPConditionalLiteral(literal=abs(self.literal), conditions=self.conditions)

    def __neg__(self):
        """Negate the wrapped literal, keeping the conditions."""
        return ASPConditionalLiteral(literal=-self.literal, conditions=self.conditions)


@dataclass(frozen=True, order=True)
class ASPDirective(ASPHeadClauseElement):
    """A ``#name`` directive with optional arguments, rendered ``#name(a,b)``.

    ``#true`` counts as positive and ``#false`` as negative; only these two can
    be negated.
    """
    name: str
    arguments: Sequence[Union[Sequence[ASPSymbol], ASPSymbol]] = field(default_factory=tuple)

    # --- constructors -----------------------------------------------------

    @staticmethod
    def true():
        return ASPDirective(name='true')

    @staticmethod
    def false():
        return ASPDirective(name='false')

    @staticmethod
    def show(*args):
        return ASPDirective(name='show', arguments=args)

    # --- name tests -------------------------------------------------------

    @property
    def is_true(self) -> bool:
        return self.name == 'true'

    @property
    def is_false(self) -> bool:
        return self.name == 'false'

    @property
    def is_forall(self) -> bool:
        return self.name == 'forall'

    @property
    def is_show(self) -> bool:
        return self.name == 'show'

    @property
    def is_pos(self) -> bool:
        return self.is_true

    @property
    def is_neg(self) -> bool:
        return self.is_false

    @property
    def signature(self) -> str:
        return '#{}/{}.'.format(self.name, len(self.arguments))

    # --- operators --------------------------------------------------------

    def __neg__(self):
        """Swap ``#true`` and ``#false``; other directives are not negatable."""
        if self.name == 'true':
            return ASPDirective.false()
        if self.name == 'false':
            return ASPDirective.true()
        raise NotImplementedError

    def __invert__(self):
        return -self

    def __abs__(self):
        raise NotImplementedError

    def __str__(self):
        rendered = "#{}".format(self.name)
        if self.arguments:
            rendered = "{}({})".format(rendered, ','.join(map(str, self.arguments)))
        return rendered


@dataclass(frozen=True, order=True)
class ASPRule(abc.ABC):
    """Abstract rule made of an optional head and an optional body."""
    head: Optional[ASPHeadClauseElement] = field(default=None)
    body: Optional[Sequence[ASPClauseElement]] = field(default=None)

    @property
    @abc.abstractmethod
    def head_signature(self) -> Union[str, Set[str]]:
        """Signature, or set of signatures, of the head."""
        raise NotImplementedError

    @staticmethod
    def fmt_body(body: Sequence[ASPClauseElement]):
        """Render body elements separated by ``', '``."""
        return ', '.join(str(element) for element in body)

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        """Return ``None`` when a negative body literal's symbol is in ``elems``.

        In every other case this base implementation raises
        ``NotImplementedError``.
        """
        blocked = self.body is not None and any(
            literal.is_neg and literal.atom.symbol in elems for literal in self.body)
        if blocked:
            return None
        raise NotImplementedError


@dataclass(frozen=True, order=True)
class ASPNormalRule(ASPRule):
    """Rule with one basic literal as head: ``head :- body.`` or the fact ``head.``."""
    head: ASPBasicLiteral = field(default_factory=ASPBasicLiteral)
    body: Sequence[ASPClauseElement] = field(default_factory=tuple)

    @property
    def head_signature(self) -> str:
        return self.head.signature

    def __str__(self):
        if not self.body:
            return "{}.".format(self.head)
        return "{} :- {}.".format(self.head, ASPRule.fmt_body(self.body))

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        """Return the rule reduced with respect to ``elems``.

        The result is ``None`` when a negative body literal's symbol is in
        ``elems``; otherwise it is the rule with only its positive body
        literals.
        """
        for literal in self.body:
            if literal.is_neg and literal.atom.symbol in elems:
                return None
        positives = tuple(filter(lambda literal: literal.is_pos, self.body))
        return ASPNormalRule(self.head, positives)


@dataclass(frozen=True, order=True)
class ASPIntegrityConstraint(ASPRule):
    """Headless rule ``:- body.``; the head is fixed to ``#false`` and not an init argument."""
    body: Sequence[ASPClauseElement] = field(default_factory=tuple)
    head: ASPDirective = field(default_factory=ASPDirective.false, init=False)

    @property
    def head_signature(self) -> str:
        return '#false/0.'

    def __str__(self):
        if not self.body:
            return ":-."
        return ":- {}.".format(ASPRule.fmt_body(self.body))

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        """Return the constraint reduced with respect to ``elems``.

        The result is ``None`` when a negative body literal's symbol is in
        ``elems``; otherwise it is the constraint with only its positive body
        literals.
        """
        for literal in self.body:
            if literal.is_neg and literal.atom.symbol in elems:
                return None
        positives = tuple(filter(lambda literal: literal.is_pos, self.body))
        return ASPIntegrityConstraint(positives)


@dataclass(frozen=True, order=True)
class ASPDisjunctiveRule(ASPRule):
    """Rule whose head is a ``;``-separated sequence of literals."""
    head: Sequence[ASPLiteral] = field(default_factory=tuple)
    body: Sequence[ASPClauseElement] = field(default_factory=tuple)

    @property
    def head_signature(self) -> Set[str]:
        return {literal.signature for literal in self.head}

    def __str__(self):
        rendered_head = '; '.join(str(literal) for literal in self.head)
        return "{} :- {}.".format(rendered_head, ASPRule.fmt_body(self.body))

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        """Return the rule reduced with respect to ``elems``.

        The result is ``None`` when a negative body literal's symbol is in
        ``elems``. Otherwise the body keeps its positive literals and the head
        keeps literals that are positive or whose symbol is not in ``elems``.
        """
        for literal in self.body:
            if literal.is_neg and literal.atom.symbol in elems:
                return None
        kept_head = tuple(
            literal for literal in self.head
            if literal.is_pos or (literal.atom.symbol not in elems))
        kept_body = tuple(filter(lambda literal: literal.is_pos, self.body))
        return ASPDisjunctiveRule(head=kept_head, body=kept_body)


@dataclass(order=True, frozen=True)
class ASPChoiceRule(ASPRule):
    """Choice rule ``lower <={h1; h2}<= upper :- body.`` with optional bounds and body.

    ``reduct`` is not overridden here, so ``ASPRule.reduct`` applies.
    """
    lower: Optional[ASPSymbol] = field(default=None)
    head: Sequence[ASPLiteral] = field(default_factory=tuple)
    upper: Optional[ASPSymbol] = field(default=None)
    body: Sequence[ASPClauseElement] = field(default_factory=tuple)

    @property
    def head_signature(self) -> Union[str, Set[str]]:
        return {h.signature for h in self.head}

    def __str__(self):
        lower = ""
        if self.lower is not None:
            lower = "{} <=".format(self.lower)
        upper = ""
        if self.upper is not None:
            upper = "<= {}".format(self.upper)
        body = ""
        if self.body:
            body = " :- {}".format(ASPRule.fmt_body(self.body))
        # End with '.' like every other rule type: ASPProgram.fmt joins the
        # rendered rules with a separator only, so an unterminated choice rule
        # would run into the rule that follows it.
        return "{}{}{}{}{}{}.".format(lower, '{', "; ".join(map(str, self.head)), '}', upper, body)


@dataclass(frozen=True, order=True)
class ASPProgram:
    """An ordered collection of rules."""
    rules: Sequence[ASPRule] = field(default_factory=tuple)

    def fmt(self, sep='\n'):
        """Render all rules joined by ``sep``."""
        return sep.join(str(rule) for rule in self.rules)

    def __str__(self):
        return self.fmt(' ')

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        """Return the program of the rules' reducts, dropping rules whose reduct is ``None``."""
        reduced_rules = (rule.reduct(elems) for rule in self.rules)
        return ASPProgram(tuple(reduced for reduced in reduced_rules if reduced is not None))

In [37]:


@dataclass(order=True, frozen=True)
class ClassicalAtom:
    """Propositional atom named by a string; a leading ``-`` marks its complement."""
    symbol: str

    @property
    def is_complement(self) -> bool:
        return self.symbol.startswith('-')

    def __neg__(self):
        """Toggle the leading ``-`` of the symbol."""
        if self.is_complement:
            flipped = self.symbol[1:]
        else:
            flipped = '-{}'.format(self.symbol)
        return ClassicalAtom(flipped)

    def __abs__(self):
        """Return the atom without a complement marker."""
        return -self if self.is_complement else self

    def __str__(self):
        return self.symbol

    def as_asp_top_level_symbol(self) -> ASPTopLevelSymbol:
        """Return an argument-less ``ASPFunction`` named after the symbol."""
        return ASPFunction(self.symbol)

    def as_asp_atom(self) -> ASPAtom:
        return ASPAtom(self.as_asp_top_level_symbol())

    def as_asp_literal(self) -> ASPBasicLiteral:
        return ASPBasicLiteral(atom=self.as_asp_atom())


# An alphabet is a set of atoms; a valuation maps atoms to truth values.
ClassicalAlphabet = Set[ClassicalAtom]
ClassicalValuation = Mapping[ClassicalAtom, bool]


@dataclass(order=True, frozen=True)
class ClassicalLiteral:
    """An atom with a truth sign; ``sign=False`` renders with a leading ``¬``.

    ``-literal`` flips the sign, ``~literal`` negates the atom. The operators
    ``&``, ``|`` and ``>>`` build ``ClassicalFormula`` objects for conjunction,
    disjunction and implication.
    """
    atom: ClassicalAtom
    sign: bool = field(default=True)

    def __str__(self):
        prefix = "" if self.sign else "¬"
        return "{}{}".format(prefix, self.atom)

    def __repr__(self):
        return str(self)

    def __neg__(self):
        return ClassicalLiteral(self.atom, not self.sign)

    def __invert__(self):
        return ClassicalLiteral(-self.atom, self.sign)

    def __and__(self, other):
        # A literal operand is wrapped; anything else is used as given.
        rhs = ClassicalFormula(other) if isinstance(other, ClassicalLiteral) else other
        return ClassicalFormula(ClassicalFormula(self), ClassicalConnective.And, rhs)

    def __or__(self, other):
        rhs = ClassicalFormula(other) if isinstance(other, ClassicalLiteral) else other
        return ClassicalFormula(ClassicalFormula(self), ClassicalConnective.Or, rhs)

    def __rshift__(self, other):
        rhs = ClassicalFormula(other) if isinstance(other, ClassicalLiteral) else other
        return ClassicalFormula(ClassicalFormula(self), ClassicalConnective.Implies, rhs)

    def __call__(self, valuation: ClassicalValuation, *args, **kwargs):
        """Evaluate under ``valuation``; atoms missing from it count as false."""
        assigned = valuation.get(self.atom, False)
        return self.sign == assigned

@dataclass(order=True, frozen=True)
class ClassicalTop(ClassicalLiteral):
    """The constant ``⊤``; its atom and sign are fixed and not init arguments."""
    atom: ClassicalAtom = field(default=ClassicalAtom('⊤'), init=False)
    sign: bool = field(default=True, init=False)

    def __str__(self):
        return str(self.atom)

    def __repr__(self):
        return str(self)

    def __neg__(self):
        return ClassicalBot()

    def __and__(self, other):
        """``⊤ ∧ x`` is ``x``; a literal operand is wrapped in a formula."""
        return ClassicalFormula(other) if isinstance(other, ClassicalLiteral) else other

    def __or__(self, other):
        """``⊤ ∨ x`` is ``⊤``."""
        return ClassicalFormula(ClassicalTop())


@dataclass(frozen=True, order=True)
class ClassicalBot(ClassicalLiteral):
    """The constant ``⊥``; its atom and sign are fixed and not init arguments."""
    atom: ClassicalAtom = field(default=ClassicalAtom('⊥'), init=False)
    sign: bool = field(default=False, init=False)

    def __neg__(self):
        return ClassicalTop()

    def __str__(self):
        return str(self.atom)

    def __repr__(self):
        # Defined here because @dataclass generates a field-listing __repr__
        # for any class that lacks its own, which would replace the inherited
        # symbolic one. ClassicalTop defines it for the same reason.
        return str(self)

    def __and__(self, other):
        """``⊥ ∧ x`` is ``⊥``."""
        return ClassicalFormula(ClassicalBot())

    def __or__(self, other):
        """``⊥ ∨ x`` is ``x``; a literal operand is wrapped in a formula."""
        if isinstance(other, ClassicalLiteral):
            return ClassicalFormula(other)
        return other


class ClassicalConnective(IntEnum):
    """Binary connectives; the integer values are compared when formulas are rendered."""
    And = 0
    Or = 1
    Implies = 2

    def __str__(self):
        if self is ClassicalConnective.And:
            return "∧"
        elif self is ClassicalConnective.Or:
            return "∨"
        elif self is ClassicalConnective.Implies:
            return "→"
        else:
            assert False, "Unhandled Connective.__str__: {} = {}".format(self.name, self.value)

    def __format__(self, format_spec):
        # On newer Python versions (3.11+) format() on an IntEnum member uses
        # the integer value and ignores a custom __str__, so
        # " {}".format(connective) would print 0/1/2. Route formatting
        # through __str__ so it is the same on every version.
        return format(str(self), format_spec)

    def evaluate(self, left: bool, right: bool):
        """Apply the connective to two truth values."""
        if self is ClassicalConnective.And:
            return left and right
        elif self is ClassicalConnective.Or:
            return left or right
        elif self is ClassicalConnective.Implies:
            return not left or right
        else:
            assert False, "Unhandled Connective.evaluate: {} = {}".format(self.name, self.value)


# Type variable bound by name to ClassicalFormula, usable before that class is defined.
ForwardClassicalFormula = TypeVar('ForwardClassicalFormula', bound='ClassicalFormula')


@dataclass(frozen=True, order=True)
class ClassicalFormula:
    """Propositional formula built from literals and binary connectives.

    A formula is either a single operand (``left`` only, a literal or a nested
    formula) or ``left <connective> right``. The negation of an implication
    ``f`` is represented as ``f → ⊥`` and rendered as ``¬(f)``.
    """
    left: Union[ForwardClassicalFormula, ClassicalLiteral]
    connective: Optional[ClassicalConnective] = field(default=None)
    right: Union[ForwardClassicalFormula, None] = field(default=None)

    def _needs_parentheses(self, operand) -> bool:
        """Tell whether ``operand`` must be bracketed under this formula's connective.

        That is the case when the operand is a compound formula whose
        connective has a larger enum value than this one's. Only meaningful
        when ``self.connective`` is set.
        """
        return (isinstance(operand, ClassicalFormula)
                and operand.connective is not None
                and operand.connective > self.connective)

    def __str__(self):
        left_str = str(self.left)
        connective_str = ""
        if self.connective is not None:
            # str() is applied explicitly: formatting an IntEnum member directly
            # can yield its integer value on newer Python versions (3.11+).
            connective_str = " {}".format(str(self.connective))
            if self._needs_parentheses(self.left):
                left_str = "({})".format(left_str)
        right_str = ""
        if self.right is not None:
            if self.connective is ClassicalConnective.Implies and self.right.is_bot:
                # "f → ⊥" is shown as a negation.
                left_str = "¬({})".format(left_str)
                connective_str = ""
            else:
                right_str = str(self.right)
                # Bracket the RIGHT operand when it needs it, so that e.g.
                # a ∧ (b ∨ c) is not rendered as the ambiguous "a ∧ b ∨ c".
                if self.connective is not None and self._needs_parentheses(self.right):
                    right_str = "({})".format(right_str)
                right_str = " {}".format(right_str)
        return "{}{}{}".format(left_str, connective_str, right_str)

    def __repr__(self):
        return str(self)

    def __neg__(self):
        """Return the negation of the formula.

        Single operands are negated directly, conjunctions and disjunctions via
        De Morgan, ``f → ⊥`` yields ``f``, and any other implication ``f`` is
        wrapped as ``f → ⊥``.

        :raises TypeError: if exactly one of ``connective`` / ``right`` is set.
        """
        if self.connective is not None and self.right is None:
            raise TypeError("Formula.connective present, despite Formula.right missing.")
        elif self.connective is None and self.right is not None:
            raise TypeError("Formula.connective missing, despite Formula.right present.")

        if self.connective is None and self.right is None:
            return ClassicalFormula(-self.left)
        elif self.connective is ClassicalConnective.And:
            return ClassicalFormula(-self.left, ClassicalConnective.Or, -self.right)
        elif self.connective is ClassicalConnective.Or:
            return ClassicalFormula(-self.left, ClassicalConnective.And, -self.right)
        elif self.connective is ClassicalConnective.Implies:
            if self.right.left == ClassicalBot() and self.right.right is None:
                return self.left
            return ClassicalFormula(self, ClassicalConnective.Implies, ClassicalFormula(ClassicalBot()))
        else:
            assert False, "Unknown Formula.connective. {} = {}.".format(self.connective.name, self.connective.value)

    def __and__(self, other):
        """Conjunction, simplified when either side is ``⊤`` or ``⊥``."""
        left = self
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        if left.is_top:
            return right
        elif right.is_top:
            return left
        if left.is_bot:
            return left
        elif right.is_bot:
            return right
        return ClassicalFormula(left, ClassicalConnective.And, right)

    def __or__(self, other):
        """Disjunction, simplified when either side is ``⊤`` or ``⊥``."""
        left = self
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        if left.is_top:
            return left
        elif right.is_top:
            return right
        if left.is_bot:
            return right
        elif right.is_bot:
            return left
        return ClassicalFormula(left, ClassicalConnective.Or, right)

    def __rshift__(self, other):
        """Implication ``self → other`` (no simplification)."""
        left = self
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        return ClassicalFormula(left, ClassicalConnective.Implies, right)

    def __call__(self, valuation: Optional[ClassicalValuation] = None) -> bool:
        return self.evaluate(valuation)

    @property
    def literals(self) -> Set[ClassicalLiteral]:
        """All literals occurring in the formula, excluding ``⊤`` and ``⊥``."""
        literals = set()
        if isinstance(self.left, ClassicalLiteral):
            if not isinstance(self.left, ClassicalTop) and not isinstance(self.left, ClassicalBot):
                literals.add(self.left)
        else:
            assert isinstance(self.left, ClassicalFormula), "Unknown type for Formula.left. {}: {}".format(
                type(self.left).__name__, self.left)
            literals.update(self.left.literals)
        if self.right is not None:
            assert isinstance(self.right, ClassicalFormula), "Unknown type for Formula.right. {}: {}".format(
                type(self.right).__name__, self.right)
            literals.update(self.right.literals)
        return literals

    @property
    def atoms(self) -> Set[ClassicalAtom]:
        """Atoms of all literals occurring in the formula."""
        return {literal.atom for literal in self.literals}

    @property
    def is_top(self) -> bool:
        return self.right is None and isinstance(self.left, ClassicalTop)

    @property
    def is_bot(self) -> bool:
        return self.right is None and isinstance(self.left, ClassicalBot)

    def evaluate(self, valuation: Optional[ClassicalValuation] = None) -> bool:
        """Evaluate the formula under ``valuation``.

        :raises TypeError: if exactly one of ``connective`` / ``right`` is set.
        """
        if isinstance(self.left, ClassicalLiteral):
            value_left = self.__evaluate_literal(self.left, valuation)
        else:
            assert isinstance(self.left, ClassicalFormula), "Unknown type for Formula.left. {}: {}".format(
                type(self.left).__name__, self.left)
            value_left = self.left.evaluate(valuation)
        if self.connective is not None and self.right is None:
            raise TypeError("Formula.connective present, despite Formula.right missing.")
        elif self.connective is None and self.right is not None:
            raise TypeError("Formula.connective missing, despite Formula.right present.")

        if self.connective is None and self.right is None:
            return value_left
        else:
            assert isinstance(self.right, ClassicalFormula), "Unknown type for Formula.right. {}: {}".format(
                type(self.right).__name__, self.right)
            value_right = self.right.evaluate(valuation)

            return self.connective.evaluate(value_left, value_right)

    def __evaluate_literal(self, literal: ClassicalLiteral, valuation: Optional[ClassicalValuation] = None) -> bool:
        if isinstance(literal, ClassicalTop) or isinstance(literal, ClassicalBot):
            return literal.sign
        else:
            # get assigned truth value of atom (per default false) and flip the result if negated
            return valuation is not None and bool(valuation.get(literal.atom, False) ^ (not literal.sign))

    def set_to_bot(self, *atoms: ClassicalAtom) -> ForwardClassicalFormula:
        """Return the formula with every literal over one of ``atoms`` replaced.

        A positive literal becomes ``⊥`` and a negative one becomes ``⊤``.
        """
        if isinstance(self.left, ClassicalLiteral):
            assert self.connective is None
            assert self.right is None
            if self.left.atom in atoms:
                if self.left.sign:
                    return ClassicalFormula(ClassicalBot())
                else:
                    return ClassicalFormula(ClassicalTop())
            else:
                return self
        else:
            left = self.left.set_to_bot(*atoms)
            right = self.right
            if right is not None:
                right = self.right.set_to_bot(*atoms)
            return ClassicalFormula(left, self.connective, right)

    def as_asp_bodies(self) -> Sequence[Sequence[ASPBasicLiteral]]:
        """Translate the formula into alternative ASP rule bodies.

        The result is a disjunction of conjunctions: the formula holds when all
        elements of at least one returned body hold. ``⊤`` and ``⊥`` become the
        ``#true`` / ``#false`` directives, and a negative classical literal
        becomes a ``not`` literal.
        """
        if isinstance(self.left, ClassicalLiteral):
            if isinstance(self.left, ClassicalTop):
                return ((ASPDirective.true(),),)
            if isinstance(self.left, ClassicalBot):
                return ((ASPDirective.false(),),)
            literal = ASPBasicLiteral(sign=ASPSign(int(not self.left.sign)),
                                      atom=ASPAtom(ASPFunction(self.left.atom.symbol)))
            return ((literal,),)

        if self.right is None:
            # A formula that merely wraps another formula.
            return self.left.as_asp_bodies()

        if self.connective is ClassicalConnective.Or:
            # Either side suffices: the alternatives of both sides.
            return [*map(tuple, self.left.as_asp_bodies()),
                    *map(tuple, self.right.as_asp_bodies())]

        if self.connective is ClassicalConnective.And:
            # Both sides are needed: every left alternative with every right one.
            bodies_left = self.left.as_asp_bodies()
            bodies_right = self.right.as_asp_bodies()
            return [(*body_left, *body_right)
                    for body_left in bodies_left
                    for body_right in bodies_right]

        assert self.connective is ClassicalConnective.Implies
        if self.right.is_bot:
            # Negation f → ⊥.
            if self.left.connective is not None and self.left.connective is ClassicalConnective.Implies:
                # ¬(l → r) is ¬(¬l ∨ r); translate the negated disjunction.
                return (-((-self.left.left) | self.left.right)).as_asp_bodies()
            return (-self.left).as_asp_bodies()

        # l → r is ¬l ∨ r. The negation is taken on the formula level so that
        # compound antecedents are handled correctly; negating each literal of
        # the antecedent's alternatives in place is only right when the
        # antecedent is a single literal.
        return [*map(tuple, (-self.left).as_asp_bodies()),
                *map(tuple, self.right.as_asp_bodies())]


In [38]:

def all_valuations(alphabet: ClassicalAlphabet, complete: bool = False) -> Iterator[ClassicalValuation]:
    """Yield one valuation per subset of ``alphabet``.

    Each valuation is a ``defaultdict`` that answers ``False`` for missing
    atoms and maps the atoms of the subset to ``True``. With ``complete`` the
    remaining atoms of ``alphabet`` are stored explicitly as ``False``.
    """
    for true_atoms in powerset(alphabet):
        valuation = defaultdict(lambda: False)
        valuation.update((atom, True) for atom in true_atoms)
        if complete:
            valuation.update((atom, False) for atom in alphabet if atom not in true_atoms)
        yield valuation


def models(formulas: Set[ClassicalFormula], alphabet: Optional[ClassicalAlphabet] = None) -> Iterator[
    ClassicalValuation]:
    """Yield every valuation over ``alphabet`` that satisfies all ``formulas``.

    When ``alphabet`` is ``None`` it is the set of atoms occurring in
    ``formulas``.
    """
    if alphabet is None:
        alphabet = set()
        for formula in formulas:
            alphabet.update(formula.atoms)
    yield from (
        valuation for valuation in all_valuations(alphabet)
        if all(formula.evaluate(valuation) for formula in formulas))


def sat(formulas: Set[ClassicalFormula], alphabet: Optional[ClassicalAlphabet] = None) -> bool:
    """Tell whether ``models`` yields at least one valuation for ``formulas``."""
    for _ in models(formulas, alphabet):
        return True
    return False


def unsat(formulas: Set[ClassicalFormula], alphabet: Optional[ClassicalAlphabet] = None) -> bool:
    """Tell whether ``formulas`` has no model; the opposite of ``sat``."""
    satisfiable = sat(formulas, alphabet)
    return not satisfiable


def entails(formulas: Set[ClassicalFormula], formula: ClassicalFormula) -> bool:
    """Tell whether ``formulas`` entails ``formula``.

    This holds when ``formulas`` together with the negation of ``formula`` is
    unsatisfiable.
    """
    with_negated_goal = formulas | {-formula}
    return unsat(with_negated_goal)


def valid(formulas: Set[ClassicalFormula], alphabet: Optional[ClassicalAlphabet] = None) -> bool:
    """Tell whether every formula holds under every valuation over ``alphabet``.

    When ``alphabet`` is ``None`` it is the set of atoms occurring in
    ``formulas``.
    """
    if alphabet is None:
        alphabet = set()
        for formula in formulas:
            alphabet.update(formula.atoms)
    return all(
        formula.evaluate(valuation)
        for valuation in all_valuations(alphabet)
        for formula in formulas)

In [39]:
# Fluents and fluent literals reuse the classical atom and literal classes.
Fluent = ClassicalAtom
FluentLiteral = ClassicalLiteral
# Type variable bound by name to State, usable before that class is defined.
_ForwardState = TypeVar("_ForwardState", bound="State")

In [40]:
@dataclass(frozen=True, order=True)
class Event:
    """An event identified by its symbol; renders as that symbol."""
    symbol: str

    def __str__(self):
        return self.symbol



In [41]:
class State(FrozenSet[FluentLiteral]):
    """Frozen set of fluent literals with a few logic-oriented helpers."""

    @property
    def coherent(self) -> bool:
        """``True`` iff no member's negation (``-member``) is also a member.

        Raises:
            NotImplementedError: if some member is not a ``FluentLiteral``.
        """
        if any(not isinstance(member, FluentLiteral) for member in self):
            raise NotImplementedError
        for literal in self:
            if -literal in self:
                return False
        return True

    def __str__(self):
        """Render the sorted members, comma-separated, inside braces."""
        rendered = ','.join(str(member) for member in sorted(self))
        return '{' + rendered + '}'

    def __neg__(self):
        """Return the state obtained by applying unary ``-`` to every member."""
        return State(-member for member in self)

    def __invert__(self):
        """Return the state obtained by applying ``~`` to every member."""
        return State(~member for member in self)

    def complete(self, fluents: Collection[Fluent]) -> bool:
        """Return ``True`` iff this state has exactly as many members as ``fluents``.

        Only the sizes are compared; membership is not inspected.

        Raises:
            NotImplementedError: if some member is not a ``FluentLiteral``.
        """
        if any(not isinstance(member, FluentLiteral) for member in self):
            raise NotImplementedError
        return len(self) == len(fluents)

    def as_valuation(self) -> ClassicalValuation:
        """Map each member's ``atom`` to its ``sign``.

        The result is a ``defaultdict`` whose default is ``False``, so atoms
        not mentioned by the state read as false.

        Raises:
            NotImplementedError: if some member is not a ``FluentLiteral``.
        """
        if any(not isinstance(member, FluentLiteral) for member in self):
            raise NotImplementedError
        assignment = {literal.atom: literal.sign for literal in self}
        return defaultdict(lambda: False, assignment)

In [42]:
# Binary relation on events, encoded as a function returning an int.
EventPO = Callable[[Event, Event], int]
Fluents = Collection[Fluent]
Events = Collection[Event]
EventChain = Iterable[Iterable[Event]]
# One set of events per step.
EventSequence = Sequence[Set[Event]]
# Per-event sets of formulas.
Preconditions = Mapping[Event, FrozenSet[ClassicalFormula]]
Triggers = Mapping[Event, FrozenSet[ClassicalFormula]]
# Per-event set of (conditions, effect literals) pairs.
Effects = Mapping[Event, FrozenSet[Tuple[FrozenSet[ClassicalFormula], FrozenSet[ClassicalLiteral]]]]
# 8-tuple: fluents, events, preconditions, triggers, effects, state, event relation, range.
Context = Tuple[Fluents, Events, Preconditions, Triggers, Effects, State, EventPO, range]

In [43]:
def fmt_cond(cond: Collection[ClassicalFormula]) -> str:
    """Render a collection of conditions as a conjunction.

    Returns:
        ``"⊤"`` for an empty collection, the plain string of the element for
        a single condition, and otherwise the parenthesised elements joined
        by ``" ∧ "``.
    """
    if not cond:
        return "⊤"
    if len(cond) == 1:
        pieces = [str(condition) for condition in cond]
    else:
        # Parenthesise each conjunct once there is more than one.
        pieces = ["({})".format(condition) for condition in cond]
    return " ∧ ".join(pieces)

def fmt_eff(conditionals, effects) -> str:
    """Render one conditional effect as ``[conditions]effects``.

    The conditions are rendered by ``fmt_cond``. The effects are joined with
    ``", "`` and wrapped in braces unless there is exactly one effect.
    """
    condition_text = fmt_cond(conditionals)
    single = len(effects) == 1
    effect_text = ", ".join(str(effect) for effect in effects)
    if not single:
        effect_text = '{' + effect_text + '}'
    return "[" + condition_text + "]" + effect_text

def fmt_effs(events: Iterable[Event], eff: Effects, fmt: Optional[str] = None):
    """Render the effects of each event as ``event = {effect, ...}``.

    Args:
        events: events to render, in iteration order.
        eff: mapping from an event to its (conditions, effects) pairs.
        fmt: optional text to which the rendered entries are appended, each
            preceded by ``", "``.

    Returns:
        The entries joined by ``", "`` (after ``fmt`` when given). When
        ``events`` is empty, ``fmt`` is returned unchanged, which is ``None``
        by default.
    """
    rendered = fmt
    for event in events:
        head = "{} = ".format(event)
        body = ', '.join(fmt_eff(conditionals, effects) for conditionals, effects in eff[event])
        entry = head + '{' + body + '}'
        rendered = entry if rendered is None else rendered + ", " + entry
    return rendered


def print_events(events: Iterable[Event], eff: Effects):
    """Print the text produced by ``fmt_effs`` for ``events`` and ``eff``."""
    rendered = fmt_effs(events, eff)
    print(rendered)

In [44]:
def _actual_eff_gen(context: Context, events: Iterable[Event], state: State) -> Iterator[FluentLiteral]:
    _, _, _, _, eff, _, _, _ = context
    return (l for ev in events for (cond,lits) in eff[ev] for l in lits if all(c(state.as_valuation()) for c in cond) and l not in state)


def actual_eff(context: Context, events: Iterable[Event], state: State) -> State:
    """Collect the literals yielded by ``_actual_eff_gen`` into a ``State``."""
    changed_literals = _actual_eff_gen(context, events, state)
    return State(changed_literals)

In [45]:
def update(context: Context, state: State, events: Iterable[Event]) -> State:
    """Return the state that results from applying ``events`` to ``state``.

    The negations of the actual effects are removed from ``state`` and the
    actual effects themselves are added.
    """
    effects = actual_eff(context, events, state)
    overridden = -effects
    surviving = state - overridden
    return State(surviving | effects)

def _update_inv_gen(context: Context, state: State, events: Iterable[Event]) -> Iterator[State]:
    """Yield every coherent state that ``update`` maps to ``state`` under ``events``.

    This is a brute-force search: every subset of the positive and negated
    literals over the context's fluents is tried as a candidate.
    """
    fluents, _, _, _, _, _, _, _ = context
    positive = State(ClassicalLiteral(fluent) for fluent in fluents)
    literal_pool = positive | -positive
    for members in more_itertools.powerset(literal_pool):
        candidate = State(members)
        # Skip candidates containing a literal together with its negation.
        if candidate.coherent and update(context, candidate, events) == state:
            yield candidate


def update_inv(context: Context, state: State, events: Iterable[Event]) -> FrozenSet[State]:
    """Return, as a frozenset, all states yielded by ``_update_inv_gen``."""
    predecessors = _update_inv_gen(context, state, events)
    return frozenset(predecessors)

def update_inv_min(context: Context, state: State, events: Iterable[Event]) -> FrozenSet[State]:
    """Return the subset-minimal states yielded by ``_update_inv_gen``.

    A candidate is dropped when a proper subset of it has already been kept,
    and any kept state that is a proper superset of a new candidate is
    evicted. The result therefore contains no two states where one is a
    proper subset of the other, regardless of the order in which the
    candidates are generated.

    Args:
        context: context passed through to ``_update_inv_gen``.
        state: state that the candidates must be updated into.
        events: events applied by the update.

    Returns:
        Frozenset of the minimal predecessor states.
    """
    minimal = set()
    for inv in _update_inv_gen(context, state, events):
        # A strictly smaller state is already kept: this one is not minimal.
        if any(kept < inv for kept in minimal):
            continue
        # Evict the kept states that the new, smaller candidate subsumes.
        minimal = {kept for kept in minimal if not inv < kept}
        minimal.add(inv)
    return frozenset(minimal)




In [46]:
def _induced_state_sequence_gen(context: Context, state: State, event_sequence: EventSequence) -> Iterator[State]:
    """Yield the state reached after each step of ``event_sequence``.

    Starting from ``state``, each set of events is applied with ``update``
    to the previously reached state. The starting state itself is not
    yielded.
    """
    # No component is used here; the unpack only requires an 8-element context.
    _, _, _, _, _, _, _, _ = context
    current = state
    for step_events in event_sequence:
        current = update(context, current, step_events)
        yield current


def induced_state_sequence(context: Context, state: State, event_sequence: EventSequence) -> Sequence[State]:
    """Return the states from ``_induced_state_sequence_gen`` as a tuple."""
    reached = _induced_state_sequence_gen(context, state, event_sequence)
    return tuple(reached)

In [47]:
def executable_in_context(context: Context, event_sequence: EventSequence) -> bool:
    """Check the preconditions of all events in ``event_sequence``.

    The previous implementation called ``event.pre_sat(...)``, a method that
    ``Event`` does not define. Preconditions are now looked up in the
    context's precondition mapping and evaluated with ``evaluate`` on the
    valuation of the paired state.

    Args:
        context: 8-tuple; the third component is the precondition mapping and
            the sixth the initial state.
        event_sequence: one set of events per step.

    Returns:
        ``True`` iff, for every step ``i``, every precondition formula of
        every event in ``event_sequence[i]`` holds in the ``i``-th induced
        state.
    """
    _, _, pre, _, _, s0, _, _ = context
    induced_state_seq = _induced_state_sequence_gen(context, s0, event_sequence)
    for i, induced_state in enumerate(induced_state_seq):
        # NOTE(review): as before, step i is paired with the state reached
        # *after* applying event_sequence[i]; confirm this is the intended
        # state for checking preconditions.
        valuation = induced_state.as_valuation()
        for event in event_sequence[i]:
            # Assumes precondition formulas expose ``evaluate(valuation)``
            # like the formulas handled by ``models`` -- TODO confirm.
            if not all(formula.evaluate(valuation) for formula in pre[event]):
                return False
    return True


def concurrent_correct_in_context(context: Context, event_sequence: EventSequence) -> bool:
    """Check that no step contains two events related by the event relation.

    Args:
        context: 8-tuple whose seventh component is the relation ``po``.
        event_sequence: one set of events per step.

    Returns:
        ``False`` iff some step contains events ``a`` and ``b`` (possibly the
        same event) with ``po(a, b) > 0``; ``True`` otherwise.
    """
    _, _, _, _, _, _, po, _ = context
    return not any(
        po(first, second) > 0
        for step_events in event_sequence
        for first in step_events
        for second in step_events
    )


def trigger_correct_in_context(context: Context, event_sequence: EventSequence) -> bool:
    """Check that every triggered event occurs or is overridden at its step.

    The previous implementation indexed the trigger mapping with the whole
    ``events`` collection instead of the individual event and then called the
    resulting set of formulas as a function. Each event's own trigger
    formulas are now evaluated on the valuation of the induced state.

    Args:
        context: 8-tuple; the second component is the collection of events,
            the fourth the trigger mapping, the sixth the initial state and
            the seventh the event relation ``po``.
        event_sequence: one set of events per step.

    Returns:
        ``False`` iff for some step ``t`` and some event ``e_`` that counts as
        triggered in the ``t``-th induced state, ``e_`` is neither in
        ``event_sequence[t]`` nor preceded (``po(e, e_) > 0``) by an event
        ``e`` of that step; ``True`` otherwise.
    """
    _, events, _, tri, _, s0, po, _ = context
    induced_state_seq = tuple(_induced_state_sequence_gen(context, s0, event_sequence))
    for t, s in enumerate(induced_state_seq):
        valuation = s.as_valuation()
        occurring = event_sequence[t]
        for e_ in events:
            # As before, an event without an entry in ``tri`` counts as triggered.
            # Assumes trigger formulas expose ``evaluate(valuation)`` like the
            # formulas handled by ``models`` -- TODO confirm.
            triggered = e_ not in tri or all(formula.evaluate(valuation) for formula in tri[e_])
            if not triggered:
                continue
            if e_ in occurring or any(po(e, e_) > 0 for e in occurring):
                continue
            return False
    return True


def valid_in_context(context: Context, event_sequence: EventSequence) -> bool:
    """Return ``True`` iff the sequence passes all three context checks.

    The checks run in this order and stop at the first failure:
    executability, concurrency correctness, trigger correctness.
    """
    checks = (
        executable_in_context,
        concurrent_correct_in_context,
        trigger_correct_in_context,
    )
    return all(check(context, event_sequence) for check in checks)

In [48]:
def empty_partial_order() -> EventPO:
    """Return a relation function that yields ``0`` for any arguments."""

    def empty(*args, **kwargs) -> int:
        # Every argument is ignored: no pair of events is related.
        unrelated = 0
        return unrelated

    return empty

In [49]:
def fmt_context(context: Context) -> str:
    """Render a context as ``<{fluents}, {events}, {effects}, {state}, <, start-stop>``.

    Fluents, events and the state's members are joined with ``", "``; the
    effects are rendered by ``fmt_effs``. The preconditions, triggers and the
    event relation are not part of the output (a literal ``<`` is printed in
    the relation's place).
    """
    fluents, events, _, _, eff, s0, _, t = context

    def braced(content):
        # ``format`` also renders a ``None`` coming from ``fmt_effs``.
        return "{{{}}}".format(content)

    parts = [
        braced(', '.join(str(fluent) for fluent in fluents)),
        braced(', '.join(str(event) for event in events)),
        braced(fmt_effs(events, eff)),
        braced(', '.join(str(member) for member in s0)),
        "<",
        "{}-{}".format(t.start, t.stop),
    ]
    return "<" + ", ".join(parts) + ">"


def print_context(context: Context):
    """Print the text produced by ``fmt_context`` for ``context``."""
    rendered = fmt_context(context)
    print(rendered)

In [50]:
# Positive literals over the four atoms used by this context.
l1 = ClassicalLiteral(ClassicalAtom('l1'))
l2 = ClassicalLiteral(ClassicalAtom('l2'))
l3 = ClassicalLiteral(ClassicalAtom('l3'))
l4 = ClassicalLiteral(ClassicalAtom('l4'))
# Events; their effects are given in ex_4_eff below.
ini_l1 = Event('ini_l1')
ini_neg_l2 = Event('ini_neg_l2')
ini_neg_l3 = Event('ini_neg_l3')
ini_neg_l4 = Event('ini_neg_l4')
e_neg_1 = Event('e_neg_1')
e_1 = Event('e_1')
e_2 = Event('e_2')
e_3 = Event('e_3')
e_4 = Event('e_4')
# Preconditions and triggers default to the single formula built from ClassicalTop().
ex_4_pre: Preconditions = defaultdict(lambda: frozenset({ClassicalFormula(ClassicalTop())}))
ex_4_tri: Triggers = defaultdict(lambda: frozenset({ClassicalFormula(ClassicalTop())}))
# Every listed event has one effect with no conditions and a single literal;
# unlisted events default to one effect with no conditions and no literals.
ex_4_eff: Effects = defaultdict(lambda: frozenset({tuple((frozenset(), frozenset()))}), {
    ini_l1: frozenset({tuple((frozenset(), frozenset({l1})))}),
    ini_neg_l2: frozenset({tuple((frozenset(), frozenset({-l2})))}),
    ini_neg_l3: frozenset({tuple((frozenset(), frozenset({-l3})))}),
    ini_neg_l4: frozenset({tuple((frozenset(), frozenset({-l4})))}),
    e_neg_1: frozenset({tuple((frozenset(), frozenset({-l1})))}),
    e_1: frozenset({tuple((frozenset(), frozenset({l1})))}),
    e_2: frozenset({tuple((frozenset(), frozenset({l2})))}),
    e_3: frozenset({tuple((frozenset(), frozenset({l3})))}),
    e_4: frozenset({tuple((frozenset(), frozenset({l4})))}),
})
ex_4_fluents: Fluents = frozenset({l1.atom, l2.atom, l3.atom, l4.atom})
ex_4_events: Events = frozenset({ini_l1, ini_neg_l2, ini_neg_l3, ini_neg_l4, e_neg_1, e_1, e_2, e_3, e_4})
# Initial state: l1 holds, l2, l3 and l4 are negated.
ex_4_s0: State = State({l1, -l2, -l3, -l4})
# Assemble the 8-tuple context with an event relation that relates nothing.
ex_4_context: Context = (ex_4_fluents, ex_4_events, ex_4_pre, ex_4_tri, ex_4_eff, ex_4_s0, empty_partial_order(), range(-1, 3))
print_context(ex_4_context)

<{l2, l3, l4, l1}, {e_1, e_3, e_2, ini_neg_l2, ini_l1, ini_neg_l4, e_4, ini_neg_l3, e_neg_1}, {e_1 = {[⊤]l1}, e_3 = {[⊤]l3}, e_2 = {[⊤]l2}, ini_neg_l2 = {[⊤]¬l2}, ini_l1 = {[⊤]l1}, ini_neg_l4 = {[⊤]¬l4}, e_4 = {[⊤]l4}, ini_neg_l3 = {[⊤]¬l3}, e_neg_1 = {[⊤]¬l1}}, {l1, ¬l2, ¬l4, ¬l3}, <, -1-3>


In [51]:
# Two steps: first e_neg_1 together with e_2, then e_3 together with e_4.
ex_4_event_sequence: EventSequence = [{e_neg_1, e_2}, {e_3, e_4}]
ex_4_event_sequence

[{Event(symbol='e_2'), Event(symbol='e_neg_1')},
 {Event(symbol='e_3'), Event(symbol='e_4')}]

In [52]:
# Despite its name, this variable holds the induced *state* sequence,
# i.e. the state reached after each step of ex_4_event_sequence.
ex_4_induced_event_sequence = induced_state_sequence(ex_4_context, ex_4_s0, ex_4_event_sequence)
# Print the initial state as index 0, then the induced states from index 1.
print("0:", ex_4_s0)
for i, state in enumerate(ex_4_induced_event_sequence):
    print("{}:".format(i + 1), state)

0: {l1,¬l2,¬l3,¬l4}
1: {¬l1,l2,¬l3,¬l4}
2: {¬l1,l2,l3,l4}


In [53]:
# Additional literals, used below as effect conditions of event e.
l_c1 = ClassicalLiteral(ClassicalAtom('l_c1'))
l_c3 = ClassicalLiteral(ClassicalAtom('l_c3'))
# NOTE(review): ini_l_c1 and ini_neg_l_c3 are created but not referenced
# again in this cell (not in ex_5_events, not in ex_5_eff).
ini_l_c1 = Event('ini_l_c1')
ini_neg_l_c3 = Event('ini_neg_l_c3')
e = Event('e')
e_ = Event("e'")
ex_5_fluents: Fluents = frozenset({l1.atom, l2.atom, l3.atom, l_c1.atom, l_c3.atom})
ex_5_events: Events = frozenset({ini_l1, ini_neg_l2, ini_neg_l3, e, e_})
# Event e gets an empty precondition set; all other events use the default.
ex_5_pre = defaultdict(lambda: frozenset({ClassicalFormula(ClassicalTop())}), {
    e: frozenset({})
})
ex_5_tri = defaultdict(lambda: frozenset({ClassicalFormula(ClassicalTop())}))
# NOTE(review): e_neg_1 has an entry here although it is not in ex_5_events.
ex_5_eff: Effects = defaultdict(lambda: frozenset({tuple((frozenset(), frozenset()))}), {
    ini_l1: frozenset({tuple((frozenset(), frozenset({l1})))}),
    ini_neg_l2: frozenset({tuple((frozenset(), frozenset({-l2})))}),
    ini_neg_l3: frozenset({tuple((frozenset(), frozenset({-l3})))}),
    e_neg_1: frozenset({tuple((frozenset(), frozenset({-l1})))}),
    # e: -l1 under condition l_c1, l2 unconditionally, l3 under condition l_c3.
    e: frozenset({
        (frozenset({l_c1}), frozenset({-l1})),
        (frozenset(), frozenset({l2})),
        (frozenset({l_c3}), frozenset({l3})),
    }),
    # e': unconditionally -l_c1 and l_c3.
    e_: frozenset({
        (frozenset(), frozenset({-l_c1})),
        (frozenset(), frozenset({l_c3})),
    }),
})
ex_5_s0: State = State({l1, -l2, -l3, l_c1, -l_c3})
# Assemble the 8-tuple context with an event relation that relates nothing.
ex_5_context: Context = (ex_5_fluents, ex_5_events, ex_5_pre, ex_5_tri, ex_5_eff, ex_5_s0, empty_partial_order(), range(-1, 3))
print_context(ex_5_context)


<{l_c3, l3, l1, l2, l_c1}, {ini_l1, e', e, ini_neg_l3, ini_neg_l2}, {ini_l1 = {[⊤]l1}, e' = {[⊤]l_c3, [⊤]¬l_c1}, e = {[l_c3]l3, [⊤]l2, [l_c1]¬l1}, ini_neg_l3 = {[⊤]¬l3}, ini_neg_l2 = {[⊤]¬l2}}, {l1, ¬l2, ¬l3, l_c1, ¬l_c3}, <, -1-3>


In [54]:
# Two steps: first e', then e.
ex_5_event_sequence: EventSequence = [{e_}, {e}]
ex_5_event_sequence

[{Event(symbol="e'")}, {Event(symbol='e')}]

In [55]:
# Despite its name, this variable holds the induced *state* sequence,
# i.e. the state reached after each step of ex_5_event_sequence.
ex_5_induced_event_sequence = induced_state_sequence(ex_5_context, ex_5_s0, ex_5_event_sequence)
# Print the initial state as index 0, then the induced states from index 1.
print("0:", ex_5_s0)
for i, state in enumerate(ex_5_induced_event_sequence):
    print("{}:".format(i + 1), state)

0: {l1,¬l2,¬l3,l_c1,¬l_c3}
1: {l1,¬l2,¬l3,¬l_c1,l_c3}
2: {l1,l2,l3,¬l_c1,l_c3}


In [56]:
# All coherent states that event e updates into {l1, l2, l3, -l_c1, l_c3}.
prev_states = update_inv(ex_5_context, State({l1, l2, l3, -l_c1, l_c3}), {e})
for prev_state in prev_states:
    print(prev_state)

{l1,¬l2,l3,¬l_c1,l_c3}
{l1,¬l2,¬l3,¬l_c1,l_c3}
{l1,¬l2,¬l_c1,l_c3}
{l1,l2,¬l3,¬l_c1,l_c3}
{l1,¬l3,¬l_c1,l_c3}
{l1,l2,¬l_c1,l_c3}
{l1,l3,¬l_c1,l_c3}
{l1,¬l_c1,l_c3}
{l1,l2,l3,¬l_c1,l_c3}


In [57]:
# Same query as the previous cell, reduced by update_inv_min.
# Note that this rebinds prev_states from the previous cell.
prev_states = update_inv_min(ex_5_context, State({l1, l2, l3, -l_c1, l_c3}), {e})
for prev_state in prev_states:
    print(prev_state)

{l1,¬l_c1,l_c3}


In [58]:
def backing_candidates(context:Context, query: Set[ClassicalFormula]) -> FrozenSet[State]:
    """Collect the leading models of ``query`` that have the fewest true atoms.

    Models are taken from ``models(query, fluents)`` in generation order.
    Collection stops at the first model whose number of true entries exceeds
    the smallest count seen so far.

    Args:
        context: 8-tuple whose first component is the collection of fluents.
        query: formulas that every model must satisfy.

    Returns:
        Frozenset with one ``State`` per collected model, each holding the
        keys of the model that are mapped to a truthy value.
    """
    fluents, _, _, _, _, _, _, _ = context
    collected = []
    smallest = float('inf')
    for model in models(query, fluents):
        true_count = sum(model.values())
        if true_count > smallest:
            break
        smallest = true_count
        collected.append(State(atom for atom, value in model.items() if value))
    return frozenset(collected)

In [59]:
# Backing candidates for the query (l1 & l4) | (l2 & l4) | (l3 & l4).
ex_4_backing_candidates = backing_candidates(ex_4_context, {(l1 & l4) | (l2 & l4) | (l3 & l4)})
for backing_candidate in ex_4_backing_candidates:
    print(backing_candidate)

{l3,l4}
{l2,l4}
{l1,l4}


In [60]:
def decreasing_seq(context: Context, backing: State, event_sequence: EventSequence, induced_state_seq: Optional[Sequence[State]] = None):
    """Walk the induced states backwards and record which events produced ``backing``.

    Args:
        context: 8-tuple; only its initial state (sixth component) is used
            directly here, the other components are consumed by the helpers.
        backing: set of literals to account for.
        event_sequence: one set of events per step.
        induced_state_seq: states still to be processed, oldest first. When
            ``None`` it is initialised to the initial state followed by the
            states induced by ``event_sequence``.

    Returns:
        A tuple of entries ``(w_i, (t_i, event_sets))``, latest step first,
        where ``t_i`` is the step index, ``event_sets`` maps subset-minimal
        event sets of that step to their actual effects, and ``w_i`` is the
        part of the current backing that already holds in the state examined
        at that step. Returns ``()`` when ``backing`` or the state sequence
        is empty.
    """
    fluents, events, pre, tri, eff, s0, po, t = context
    if induced_state_seq is None:
        induced_state_seq = (s0, *induced_state_sequence(context, s0, event_sequence))
    if not backing or not induced_state_seq:
        return ()
    # Split off the last remaining state; t_i is the number of states before it.
    *induced_state_seq_,state = induced_state_seq
    t_i = len(induced_state_seq_)
    # Literals of the backing that do not hold in this state.
    actual_effects = backing - state
    if not actual_effects:
        # The whole backing already holds here: look one state further back.
        return decreasing_seq(context, backing, event_sequence, induced_state_seq_)
    # NOTE(review): with the default state sequence the first call has
    # t_i == len(event_sequence), so this lookup raises IndexError when the
    # backing does not fully hold in the final induced state.
    events = event_sequence[t_i]
    event_sets = {}
    for subset in frozen_powerset(events):
        actual_effect = actual_eff(context, subset, state)
        # Keep subsets whose non-empty actual effect lies within the missing literals.
        if actual_effect and actual_effect <= actual_effects:
            subsumed = set()
            is_subsumed = False
            for event_set in event_sets:
                if subset < event_set:
                    subsumed.add(event_set)
                elif event_set < subset:
                    is_subsumed = True
            # Only subset-minimal event sets are retained.
            for s in subsumed:
                del event_sets[s]
            if not is_subsumed:
                event_sets[subset] = actual_effect

    # Continue backwards with the part of the backing that holds in this state.
    w_i = state & backing
    seq = decreasing_seq(context, w_i, event_sequence, induced_state_seq_)
    return ((w_i, (t_i, event_sets)), *seq)


# Decreasing sequence for the backing {l2, l4} over the two-step event sequence.
ex_4_decreasing_seq = decreasing_seq(ex_4_context, State({l2,l4}), [{e_neg_1, e_2}, {e_3, e_4}])
ex_4_decreasing_seq


((frozenset({l2}), (1, {frozenset({Event(symbol='e_4')}): frozenset({l4})})),
 (frozenset(), (0, {frozenset({Event(symbol='e_2')}): frozenset({l2})})))

In [61]:
# Same event sequence, now for the backing {l3, l4}; rebinds ex_4_decreasing_seq.
ex_4_decreasing_seq = decreasing_seq(ex_4_context, State({l3,l4}), [{e_neg_1, e_2}, {e_3, e_4}])
ex_4_decreasing_seq

((frozenset(),
  (1,
   {frozenset({Event(symbol='e_4')}): frozenset({l4}),
    frozenset({Event(symbol='e_3')}): frozenset({l3})})),)

In [62]:
def direct_ness_causes(context: Context, backing: State, event_sequence: EventSequence):
    """Placeholder without an implementation; always returns ``None``."""
    return None

In [63]:
# Backing candidates for the query consisting of the formulas l1, l2 and l3.
ex_5_backing_candidates = backing_candidates(ex_5_context, {ClassicalFormula(l1), ClassicalFormula(l2), ClassicalFormula(l3)})
for backing_candidate in ex_5_backing_candidates:
    print(backing_candidate)

{l1,l2,l3}


In [64]:
# Decreasing sequence for the backing {l1, l2, l3} over the steps [{e'}, {e}].
ex_5_decreasing_seq = decreasing_seq(ex_5_context, State({l1,l2,l3}), [{e_}, {e}])
ex_5_decreasing_seq

((frozenset({l1}),
  (1, {frozenset({Event(symbol='e')}): frozenset({l2, l3})})),)