In [121]:
import abc
from collections import defaultdict
from copy import copy
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional, Iterator, Sequence, TypeVar, Set, Mapping, Union, FrozenSet, Callable, Tuple

import clingo
import clingraph
import more_itertools
from clingraph import Factbase, compute_graphs

In [122]:


def powerset(iterable):
    for i in more_itertools.powerset(iterable):
        yield set(i)

In [123]:


def frozen_powerset(iterable):
    for i in more_itertools.powerset(iterable):
        yield frozenset(i)

In [124]:


def asp_solve(programs,
              ctl: Optional[clingo.Control] = None,
              parts=(('base', ()),),
              context=None,
              report=False,
              report_models=True,
              report_result=True,
              symbol_sep=' ',
              model_sep='\n'
              ) -> Iterator[Sequence[clingo.Symbol]]:
    if ctl is None:
        ctl = clingo.Control(logger=lambda *args, **kwargs: None)
        ctl.configuration.solve.models = 0
    if programs:
        if isinstance(programs, str):
            ctl.add('base', [], programs)
        else:
            for program in programs:
                if isinstance(program, ASPProgram):
                    ctl.add('base', [], str(program))
                elif isinstance(program, str):
                    ctl.add('base', [], program)

    ctl.ground(parts, context=context)
    with ctl.solve(yield_=True) as solve_handle:
        models = 0
        for model in solve_handle:
            symbols = sorted(model.symbols(shown=True))
            if report and report_models:
                print("Answer {}:".format(model.number), end=' ')
                print("{",
                      symbol_sep.join(map(str, sorted(symbols))), "}", sep=symbol_sep, end=model_sep)
            models += 1
            yield symbols
        if report and report_result:
            solve_result = solve_handle.get()
            print(solve_result, end='')
            if solve_result.satisfiable:
                print(" {}{}".format(models, '' if solve_result.exhausted else '+'))
            else:
                print()


def draw_graph(programs,
               ctl: Optional[clingo.Control] = None,
               parts=(('base', ()),)):
    fb = Factbase()
    if ctl is None:
        ctl = clingo.Control(logger=lambda *args, **kwargs: None)
        ctl.configuration.solve.models = 0
    ctl.add('base', [], '\n'.join(programs))
    ctl.ground(parts, clingraph.clingo_utils.ClingraphContext())
    with ctl.solve(yield_=True) as solve_handle:
        for model in solve_handle:
            fb.add_model(model)
            break
    return compute_graphs(fb)

In [125]:


ForwardASPSymbol = TypeVar('ForwardASPSymbol', bound='ASPSymbol')
ForwardASPVariable = TypeVar('ForwardASPVariable', bound='ASPVariable')


class ASPSymbol(abc.ABC):

    @staticmethod
    def from_clingo_symbol(symbol: clingo.Symbol) -> ForwardASPSymbol:
        if symbol.type is clingo.SymbolType.Number:
            return ASPTerm(ASPIntegerConstant(symbol.number))
        elif symbol.type is clingo.SymbolType.String:
            return ASPTerm(ASPStringConstant(symbol.string))
        elif symbol.type is clingo.SymbolType.Function:
            arguments = tuple(ASPFunction.from_clingo_symbol(argument) for argument in symbol.arguments)
            name = symbol.name
            if symbol.negative:
                name = "-{}".format(name)
            return ASPFunction(name, arguments)
        else:
            assert False, "Unhandled clingo.Symbol {} with type {}.".format(symbol, symbol.type.name)


@dataclass(order=True, frozen=True)
class ASPStringConstant:
    string: str = field(default="")

    def __str__(self):
        return '"{}"'.format(self.string)


@dataclass(order=True, frozen=True)
class ASPIntegerConstant:
    number: int = field(default=0)

    def __str__(self):
        return str(self.number)


@dataclass(order=True, frozen=True)
class ASPTerm(ASPSymbol):
    constant: Union[ASPStringConstant, ASPIntegerConstant] = field(default_factory=ASPIntegerConstant)

    def __str__(self):
        return str(self.constant)

    @staticmethod
    def zero():
        return ASPTerm(ASPIntegerConstant(0))

    @staticmethod
    def one():
        return ASPTerm(ASPIntegerConstant(1))


ForwardASPTopLevelSymbol = TypeVar('ForwardASPTopLevelSymbol', bound='ASPTopLevelSymbol')


class ASPTopLevelSymbol(ASPSymbol, abc.ABC):

    def __neg__(self):
        raise NotImplementedError

    @property
    def signature(self) -> str:
        return "{}/{}.".format(self.function_name, self.arity)

    @property
    @abc.abstractmethod
    def function_name(self) -> str:
        raise NotImplementedError

    @property
    def arity(self) -> int:
        return len(self.function_arguments)

    @property
    @abc.abstractmethod
    def function_arguments(self) -> Sequence[ASPSymbol]:
        raise NotImplementedError

    def match(self, other: ForwardASPTopLevelSymbol) -> bool:
        return self.function_name == other.function_name and len(self.function_arguments) == len(
            other.function_arguments)


def asp_evaluate(*programs, report=False) -> Iterator[FrozenSet[ASPTopLevelSymbol]]:
    for answer_set in asp_solve(programs=programs, report=report):
        yield frozenset(ASPFunction.from_clingo_symbol(symbol) for symbol in answer_set)


@dataclass(order=True, frozen=True)
class ASPFunction(ASPTopLevelSymbol):
    name: Optional[str] = field(default=None)
    arguments: Sequence[ASPSymbol] = field(default_factory=tuple)

    @property
    def function_name(self) -> str:
        return self.name

    @property
    def function_arguments(self) -> Sequence[ASPSymbol]:
        return self.arguments

    @property
    def arity(self) -> int:
        return len(self.arguments)

    def __str__(self):
        if self.name is None:
            return '({})'.format(','.join(map(str, self.arguments)))
        elif not self.arguments:
            return self.name
        else:
            return '{}({})'.format(self.name, ','.join(map(str, self.arguments)))

    def __neg__(self):
        if self.name.startswith('-'):
            return ASPFunction(self.name[1:], self.arguments)
        else:
            return ASPFunction('-{}'.format(self.name), self.arguments)


ForwardASPAtom = TypeVar('ForwardASPAtom', bound='ASPAtom')


@dataclass(order=True, frozen=True)
class ASPAtom:
    symbol: ASPTopLevelSymbol = field(default_factory=ASPFunction)

    @property
    def signature(self) -> str:
        return self.symbol.signature

    def match(self, other: ForwardASPAtom) -> bool:
        return self.symbol.function_name == other.symbol.function_name and len(self.symbol.function_arguments) == len(
            other.symbol.function_arguments)

    def __str__(self):
        return str(self.symbol)

    def __neg__(self):
        return ASPAtom(-self.symbol)


@dataclass(order=True, frozen=True)
class ASPClauseElement(abc.ABC):

    @property
    @abc.abstractmethod
    def signature(self) -> str:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_pos(self):
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_neg(self):
        raise NotImplementedError

    def __neg__(self):
        raise NotImplementedError

    def __abs__(self):
        raise NotImplementedError


class ASPHeadClauseElement(ASPClauseElement):
    pass


class ASPLiteral(ASPHeadClauseElement):

    @property
    @abc.abstractmethod
    def signature(self) -> str:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_pos(self) -> bool:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def is_neg(self) -> bool:
        raise NotImplementedError


class ASPSign(IntEnum):
    NoSign = 0
    Negation = 1

    def __str__(self):
        if self is ASPSign.NoSign:
            return ''
        elif self is ASPSign.Negation:
            return 'not'
        else:
            assert False, 'Unknown IntEnum {} = {}.'.format(self.name, self.value)


@dataclass(order=True, frozen=True)
class ASPBasicLiteral(ASPLiteral):
    sign: ASPSign = ASPSign.NoSign
    atom: ASPAtom = field(default_factory=ASPAtom)

    @property
    def is_pos(self) -> bool:
        return self.sign is ASPSign.NoSign

    @property
    def is_neg(self) -> bool:
        return self.sign is ASPSign.Negation

    @property
    def signature(self) -> str:
        return self.atom.signature

    def __str__(self):
        if self.sign is ASPSign.NoSign:
            return "{}".format(self.atom)
        else:
            return "{} {}".format(self.sign, self.atom)

    def __neg__(self):
        return ASPBasicLiteral(ASPSign((self.sign ^ 1) % 2), self.atom)

    def __invert__(self):
        return ASPBasicLiteral(sign=self.sign, atom=-self.atom)

    def __abs__(self):
        return ASPBasicLiteral(ASPSign.NoSign, self.atom)

    def as_classical_atom(self):
        return ClassicalAtom(str(self.atom.symbol))

    @staticmethod
    def make_literal(name: Optional[str], *arguments: Union[str, int, ASPFunction]):
        function_arguments = []
        for arg_ in arguments:
            if isinstance(arg_, str):
                arg = ASPFunction(arg_)
            elif isinstance(arg_, int):
                arg = ASPTerm(ASPIntegerConstant(arg_))
            else:
                arg = arg_
            function_arguments.append(arg)
        return ASPBasicLiteral(atom=ASPAtom(ASPFunction(name=name, arguments=tuple(function_arguments))))


@dataclass(order=True, frozen=True)
class ASPConditionalLiteral(ASPLiteral):
    literal: ASPBasicLiteral = field(default_factory=ASPBasicLiteral)
    conditions: Sequence[ASPClauseElement] = field(default_factory=tuple)

    @property
    def signature(self) -> str:
        return self.literal.signature

    @property
    def is_pos(self) -> bool:
        return self.literal.is_pos

    @property
    def is_neg(self) -> bool:
        return self.literal.is_neg

    def __neg__(self):
        return ASPConditionalLiteral(-self.literal, self.conditions)

    def __abs__(self):
        return ASPConditionalLiteral(abs(self.literal), self.conditions)

    def __str__(self):
        return "{} : {}".format(self.literal, ','.join(map(str, self.conditions)))


@dataclass(order=True, frozen=True)
class ASPDirective(ASPHeadClauseElement):
    name: str
    arguments: Sequence[Union[Sequence[ASPSymbol], ASPSymbol]] = field(default_factory=tuple)

    @property
    def is_true(self) -> bool:
        return self.name == 'true'

    @property
    def is_pos(self) -> bool:
        return self.is_true

    @property
    def is_false(self) -> bool:
        return self.name == 'false'

    @property
    def is_neg(self) -> bool:
        return self.is_false

    @property
    def is_forall(self) -> bool:
        return self.name == 'forall'

    @property
    def is_show(self) -> bool:
        return self.name == 'show'

    @property
    def signature(self) -> str:
        return '#{}/{}.'.format(self.name, len(self.arguments))

    def __abs__(self):
        raise NotImplementedError

    def __neg__(self):
        if self.name == 'true':
            return ASPDirective.false()
        elif self.name == 'false':
            return ASPDirective.true()
        else:
            raise NotImplementedError

    def __invert__(self):
        return -self

    def __str__(self):
        if not self.arguments:
            return "#{}".format(self.name)
        else:
            return "#{}({})".format(self.name, ','.join(map(str, self.arguments)))

    @staticmethod
    def true():
        return ASPDirective('true')

    @staticmethod
    def false():
        return ASPDirective('false')

    @staticmethod
    def show(*args):
        return ASPDirective('show', arguments=args)


@dataclass(order=True, frozen=True)
class ASPRule(abc.ABC):
    head: Optional[ASPHeadClauseElement] = field(default=None)
    body: Optional[Sequence[ASPClauseElement]] = field(default=None)

    @property
    @abc.abstractmethod
    def head_signature(self) -> Union[str, Set[str]]:
        raise NotImplementedError

    @staticmethod
    def fmt_body(body: Sequence[ASPClauseElement]):
        return ', '.join(map(str, body))

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        if self.body is not None:
            if any(literal.is_neg and literal.atom.symbol in elems for literal in self.body):
                return None
        raise NotImplementedError


@dataclass(order=True, frozen=True)
class ASPNormalRule(ASPRule):
    head: ASPBasicLiteral = field(default_factory=ASPBasicLiteral)
    body: Sequence[ASPClauseElement] = field(default_factory=tuple)

    @property
    def head_signature(self) -> str:
        return self.head.signature

    def __str__(self):
        if self.body:
            return "{} :- {}.".format(self.head, ASPRule.fmt_body(self.body))
        else:
            return "{}.".format(self.head)

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        if any(literal.is_neg and literal.atom.symbol in elems for literal in self.body):
            return None
        return ASPNormalRule(self.head, tuple(literal for literal in self.body if literal.is_pos))


@dataclass(order=True, frozen=True)
class ASPIntegrityConstraint(ASPRule):
    body: Sequence[ASPClauseElement] = field(default_factory=tuple)
    head: ASPDirective = field(default_factory=ASPDirective.false, init=False)

    @property
    def head_signature(self) -> str:
        return '#false/0.'

    def __str__(self):
        if self.body:
            return ":- {}.".format(ASPRule.fmt_body(self.body))
        else:
            return ":-."

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        if any(literal.is_neg and literal.atom.symbol in elems for literal in self.body):
            return None
        return ASPIntegrityConstraint(tuple(literal for literal in self.body if literal.is_pos))


@dataclass(order=True, frozen=True)
class ASPDisjunctiveRule(ASPRule):
    head: Sequence[ASPLiteral] = field(default_factory=tuple)
    body: Sequence[ASPClauseElement] = field(default_factory=tuple)

    @property
    def head_signature(self) -> Set[str]:
        return {h.signature for h in self.head}

    def __str__(self):
        return "{} :- {}.".format('; '.join(map(str, self.head)), ASPRule.fmt_body(self.body))

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        if any(literal.is_neg and literal.atom.symbol in elems for literal in self.body):
            return None
        return ASPDisjunctiveRule(
            head=tuple(literal for literal in self.head if literal.is_pos or (literal.atom.symbol not in elems)),
            body=tuple(literal for literal in self.body if literal.is_pos))


@dataclass(order=True, frozen=True)
class ASPChoiceRule(ASPRule):
    lower: Optional[ASPSymbol] = field(default=None)
    head: Sequence[ASPLiteral] = field(default_factory=tuple)
    upper: Optional[ASPSymbol] = field(default=None)
    body: Sequence[ASPClauseElement] = field(default_factory=tuple)

    @property
    def head_signature(self) -> Union[str, Set[str]]:
        return {h.signature for h in self.head}

    def __str__(self):
        lower = ""
        if self.lower is not None:
            lower = "{} <=".format(self.lower)
        upper = ""
        if self.upper is not None:
            upper = "<= {}".format(self.upper)
        body = ""
        if self.body:
            body = " :- {}".format(ASPRule.fmt_body(self.body))
        return "{}{}{}{}{}{}".format(lower, '{', "; ".join(map(str, self.head)), '}', upper, body)


@dataclass(order=True, frozen=True)
class ASPProgram:
    rules: Sequence[ASPRule] = field(default_factory=tuple)

    def fmt(self, sep='\n'):
        return sep.join(map(str, self.rules))

    def __str__(self):
        return self.fmt(' ')

    def reduct(self, elems: Set[ASPTopLevelSymbol]):
        return ASPProgram(tuple(filter(lambda r: r is not None, (rule.reduct(elems) for rule in self.rules))))


@dataclass(frozen=True, order=True)
class ClassicalAtom:
    symbol: str

    def __neg__(self):
        if self.is_complement:
            return ClassicalAtom(self.symbol[1:])
        else:
            return ClassicalAtom('-{}'.format(self.symbol))

    def __abs__(self):
        if self.is_complement:
            return -self
        return self

    def __str__(self):
        return self.symbol

    @property
    def is_complement(self) -> bool:
        return self.symbol.startswith('-')

    def as_asp_top_level_symbol(self) -> ASPTopLevelSymbol:
        return ASPFunction(self.symbol)

    def as_asp_atom(self) -> ASPAtom:
        return ASPAtom(self.as_asp_top_level_symbol())

    def as_asp_literal(self) -> ASPBasicLiteral:
        return ASPBasicLiteral(atom=self.as_asp_atom())


In [126]:

ClassicalAlphabet = Set[ClassicalAtom]
ClassicalValuation = Mapping[ClassicalAtom, bool]


@dataclass(frozen=True, order=True)
class ClassicalLiteral:
    atom: ClassicalAtom
    sign: bool = field(default=True)

    def __str__(self):
        sign_str = ""
        if not self.sign:
            sign_str = "¬"
        return "{}{}".format(sign_str, self.atom)

    def __repr__(self):
        return str(self)

    def __neg__(self):
        return ClassicalLiteral(self.atom, not self.sign)

    def __and__(self, other):
        left = ClassicalFormula(self)
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        return ClassicalFormula(left, ClassicalConnective.And, right)

    def __or__(self, other):
        left = ClassicalFormula(self)
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        return ClassicalFormula(left, ClassicalConnective.Or, right)

    def __rshift__(self, other):
        left = ClassicalFormula(self)
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        return ClassicalFormula(left, ClassicalConnective.Implies, right)


@dataclass(frozen=True, order=True)
class ClassicalTop(ClassicalLiteral):
    atom: ClassicalAtom = field(default=ClassicalAtom('⊤'), init=False)
    sign: bool = field(default=True, init=False)

    def __str__(self):
        return str(self.atom)

    def __repr__(self):
        return str(self)

    def __neg__(self):
        return ClassicalBot()

    def __and__(self, other):
        if isinstance(other, ClassicalLiteral):
            return ClassicalFormula(other)
        return other

    def __or__(self, other):
        return ClassicalFormula(ClassicalTop())


@dataclass(frozen=True, order=True)
class ClassicalBot(ClassicalLiteral):
    atom: ClassicalAtom = field(default=ClassicalAtom('⊥'), init=False)
    sign: bool = field(default=False, init=False)

    def __neg__(self):
        return ClassicalTop()

    def __str__(self):
        return str(self.atom)

    def __and__(self, other):
        return ClassicalFormula(ClassicalBot())

    def __or__(self, other):
        if isinstance(other, ClassicalLiteral):
            return ClassicalFormula(other)
        return other


class ClassicalConnective(IntEnum):
    And = 0
    Or = 1
    Implies = 2

    def __str__(self):
        if self is ClassicalConnective.And:
            return "∧"
        elif self is ClassicalConnective.Or:
            return "∨"
        elif self is ClassicalConnective.Implies:
            return "→"
        else:
            assert False, "Unhandled Connective.__str__: {} = {}".format(self.name, self.value)

    def evaluate(self, left: bool, right: bool):
        if self is ClassicalConnective.And:
            return left and right
        elif self is ClassicalConnective.Or:
            return left or right
        elif self is ClassicalConnective.Implies:
            return not left or right
        else:
            assert False, "Unhandled Connective.evaluate: {} = {}".format(self.name, self.value)


ForwardClassicalFormula = TypeVar('ForwardClassicalFormula', bound='ClassicalFormula')


@dataclass(frozen=True, order=True)
class ClassicalFormula:
    left: Union[ForwardClassicalFormula, ClassicalLiteral]
    connective: Optional[ClassicalConnective] = field(default=None)
    right: Union[ForwardClassicalFormula, None] = field(default=None)

    def __str__(self):
        left_str = str(self.left)
        connective_str = ""
        if self.connective is not None:
            connective_str = " {}".format(self.connective)
            if isinstance(self.left,
                          ClassicalFormula) and self.left.connective is not None and self.left.connective > self.connective:
                left_str = "({})".format(left_str)
        right_str = ""
        if self.right is not None:
            if self.right.left == ClassicalBot() and self.right.right is None and self.connective is ClassicalConnective.Implies:
                left_str = "¬({})".format(left_str)
                connective_str = ""
            else:
                right_str = " {}".format(self.right)
            if isinstance(self.right,
                          ClassicalFormula) and self.right.connective is not None and self.right.connective > self.connective:
                left_str = "({})".format(left_str)
        return "{}{}{}".format(left_str, connective_str, right_str)

    def __repr__(self):
        return str(self)

    def __neg__(self):
        if self.connective is not None and self.right is None:
            raise TypeError("Formula.connective present, despite Formula.right missing.")
        elif self.connective is None and self.right is not None:
            raise TypeError("Formula.connective missing, despite Formula.right present.")

        if self.connective is None and self.right is None:
            return ClassicalFormula(-self.left)
        elif self.connective is ClassicalConnective.And:
            return ClassicalFormula(-self.left, ClassicalConnective.Or, -self.right)
        elif self.connective is ClassicalConnective.Or:
            return ClassicalFormula(-self.left, ClassicalConnective.And, -self.right)
        elif self.connective is ClassicalConnective.Implies:
            if self.right.left == ClassicalBot() and self.right.right is None:
                return self.left
            return ClassicalFormula(self, ClassicalConnective.Implies, ClassicalFormula(ClassicalBot()))
        else:
            assert False, "Unknown Formula.connective. {} = {}.".format(self.connective.name, self.connective.value)

    def __and__(self, other):
        left = self
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        if left.is_top:
            return right
        elif right.is_top:
            return left
        if left.is_bot:
            return left
        elif right.is_bot:
            return right
        return ClassicalFormula(left, ClassicalConnective.And, right)

    def __or__(self, other):
        left = self
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        if left.is_top:
            return left
        elif right.is_top:
            return right
        if left.is_bot:
            return right
        elif right.is_bot:
            return left
        return ClassicalFormula(left, ClassicalConnective.Or, right)

    def __rshift__(self, other):
        left = self
        right = other
        if isinstance(other, ClassicalLiteral):
            right = ClassicalFormula(right)
        return ClassicalFormula(left, ClassicalConnective.Implies, right)

    def __call__(self, valuation: Optional[ClassicalValuation] = None) -> bool:
        return self.evaluate(valuation)

    @property
    def literals(self) -> Set[ClassicalLiteral]:
        literals = set()
        if isinstance(self.left, ClassicalLiteral):
            if not isinstance(self.left, ClassicalTop) and not isinstance(self.left, ClassicalBot):
                literals.add(self.left)
        else:
            assert isinstance(self.left, ClassicalFormula), "Unknown type for Formula.right. {}: {}".format(
                type(self.left).__name__, self.left)
            literals.update(self.left.literals)
        if self.right is not None:
            assert isinstance(self.right, ClassicalFormula), "Unknown type for Formula.right. {}: {}".format(
                type(self.right).__name__, self.right)
            literals.update(self.right.literals)
        return literals

    @property
    def atoms(self) -> Set[ClassicalAtom]:
        return {literal.atom for literal in self.literals}

    @property
    def is_top(self) -> bool:
        return self.right is None and isinstance(self.left, ClassicalTop)

    @property
    def is_bot(self) -> bool:
        return self.right is None and isinstance(self.left, ClassicalBot)

    def evaluate(self, valuation: Optional[ClassicalValuation] = None) -> bool:
        if isinstance(self.left, ClassicalLiteral):
            value_left = self.__evaluate_literal(self.left, valuation)
        else:
            assert isinstance(self.left, ClassicalFormula), "Unknown type for Formula.left. {}: {}".format(
                type(self.left).__name__, self.left)
            value_left = self.left.evaluate(valuation)
        if self.connective is not None and self.right is None:
            raise TypeError("Formula.connective present, despite Formula.right missing.")
        elif self.connective is None and self.right is not None:
            raise TypeError("Formula.connective missing, despite Formula.right present.")

        if self.connective is None and self.right is None:
            return value_left
        else:
            assert isinstance(self.right, ClassicalFormula), "Unknown type for Formula.right. {}: {}".format(
                type(self.right).__name__, self.right)
            value_right = self.right.evaluate(valuation)

            return self.connective.evaluate(value_left, value_right)

    def __evaluate_literal(self, literal: ClassicalLiteral, valuation: Optional[ClassicalValuation] = None) -> bool:
        if isinstance(literal, ClassicalTop) or isinstance(literal, ClassicalBot):
            return literal.sign
        else:
            # get assigned truth value of atom (per default false) and flip the result if negated
            return valuation is not None and bool(valuation.get(literal.atom, False) ^ (not literal.sign))

    def set_to_bot(self, *atoms: ClassicalAtom) -> ForwardClassicalFormula:
        if isinstance(self.left, ClassicalLiteral):
            assert self.connective is None
            assert self.right is None
            if self.left.atom in atoms:
                if self.left.sign:
                    return ClassicalFormula(ClassicalBot())
                else:
                    return ClassicalFormula(ClassicalTop())
            else:
                return self
        else:
            left = self.left.set_to_bot(*atoms)
            right = self.right
            if right is not None:
                right = self.right.set_to_bot(*atoms)
            return ClassicalFormula(left, self.connective, right)

    def as_asp_bodies(self) -> Sequence[Sequence[ASPBasicLiteral]]:
        bodies: Sequence[Sequence[ASPBasicLiteral]] = [()]
        if isinstance(self.left, ClassicalLiteral):
            if isinstance(self.left, ClassicalTop):
                bodies = ((ASPDirective.true(),),)
            elif isinstance(self.left, ClassicalBot):
                bodies = ((ASPDirective.false(),),)
            else:
                literal = ASPBasicLiteral(sign=ASPSign(int(not self.left.sign)),
                                          atom=ASPAtom(ASPFunction(self.left.atom.symbol)))
                bodies = ((literal,),)
        elif self.right is None:
            bodies = self.left.as_asp_bodies()
        elif self.right is not None:
            bodies_left = self.left.as_asp_bodies()
            bodies_right = self.right.as_asp_bodies()
            bodies_ = []
            if self.connective is ClassicalConnective.Or:
                for body in bodies:
                    for body_left in bodies_left:
                        bodies_.append((*body, *body_left))
                    for body_right in bodies_right:
                        bodies_.append((*body, *body_right))
            elif self.connective is ClassicalConnective.And:
                for body in bodies:
                    for body_left in bodies_left:
                        for body_right in bodies_right:
                            bodies_.append((*body, *body_left, *body_right))
            else:
                assert self.connective is ClassicalConnective.Implies
                if self.right.is_bot:
                    if self.left.connective is not None and self.left.connective is ClassicalConnective.Implies:
                        bodies_ = (-((-self.left.left) | self.left.right)).as_asp_bodies()
                    else:
                        bodies_ = (-self.left).as_asp_bodies()
                else:
                    for body in bodies:
                        for body_left in bodies_left:
                            bodies_.append((*body, *(-literal for literal in body_left)))
                        for body_right in bodies_right:
                            bodies_.append((*body, *body_right))
            bodies = bodies_
        else:
            assert False
        return bodies


def all_valuations(alphabet: ClassicalAlphabet, complete: bool = False) -> Iterator[ClassicalValuation]:
    subsets = powerset(alphabet)
    for subset in subsets:
        valuation = defaultdict(lambda: False)
        for atom in subset:
            valuation[atom] = True
        if complete:
            for atom in alphabet:
                if atom not in subset:
                    valuation[atom] = False
        yield valuation


def models(formulas: Set[ClassicalFormula], alphabet: Optional[ClassicalAlphabet] = None) -> Iterator[
    ClassicalValuation]:
    if alphabet is None:
        alphabet = {atom for formula in formulas for atom in formula.atoms}
    for valuation in all_valuations(alphabet):
        if all(formula.evaluate(valuation) for formula in formulas):
            yield valuation


def sat(formulas: Set[ClassicalFormula], alphabet: Optional[ClassicalAlphabet] = None) -> bool:
    model = next(models(formulas, alphabet), None)
    return model is not None


def unsat(formulas: Set[ClassicalFormula], alphabet: Optional[ClassicalAlphabet] = None) -> bool:
    return not sat(formulas, alphabet)


def entails(formulas: Set[ClassicalFormula], formula: ClassicalFormula) -> bool:
    return unsat(formulas | {-formula})


def valid(formulas: Set[ClassicalFormula], alphabet: Optional[ClassicalAlphabet] = None) -> bool:
    if alphabet is None:
        alphabet = {atom for formula in formulas for atom in formula.atoms}
    for valuation in all_valuations(alphabet):
        if any(not formula.evaluate(valuation) for formula in formulas):
            return False
    return True


In [127]:
@dataclass(order=True, frozen=True)
class Action:
    symbol: str

    def __str__(self):
        return self.symbol



In [128]:
State = FrozenSet[ClassicalAtom]
BeliefState = FrozenSet[State]
ActionTrajectory = Sequence[Action]
Observation = FrozenSet[State]
ObservationTrajectory = Sequence[Observation]
BeliefTrajectory = Sequence[BeliefState]
WorldView = Tuple[ActionTrajectory, ObservationTrajectory]

In [129]:
@dataclass(order=True, frozen=True)
class TransitionSystem:
    S: BeliefState
    R: Callable[[State, Action, State], bool]
    d: Callable[[State, State], int]

    def belief_update(self,
                      alpha: BeliefState,
                      A: Action) -> BeliefState:
        return frozenset(f for e in alpha for f in self.S if self.R(e, A, f))

    def belief_revision(self,
                        kappa: BeliefState,
                        alpha: BeliefState) -> BeliefState:
        return frozenset(w for w in alpha if
                         any(all(all(self.d(w, v1) <= self.d(v2, v3) for v3 in kappa) for v2 in alpha) for v1 in kappa))

    def belief_evolution(self,
                         kappa: BeliefState,
                         W: WorldView) -> BeliefTrajectory:
        obs_cap = None
        act, obs = W
        for i, act_i in enumerate(act):
            obs_i = obs[i]
            obs_cap_ = self.pred(obs_i, act_i)
            if obs_cap is None:
                obs_cap = obs_cap_
            else:
                obs_cap = obs_cap & obs_cap_
        assert obs_cap is not None
        kappa_0 = self.belief_revision(kappa, obs_cap)
        belief_trajectory = [kappa_0]
        for i, act_i in enumerate(act):
            kappa_i = self.belief_update(belief_trajectory[i], act_i)
            belief_trajectory.append(kappa_i)
        return belief_trajectory

    def repair(self, W: WorldView) -> WorldView:
        ACT, OBS = W
        n = len(ACT) - 1
        tau_OBS = []
        *_, last = self.preds(OBS[n], ACT)
        if last:
            tau_OBS.insert(0, OBS[n])
        else:
            tau_OBS.insert(0, self.S)
        for i in reversed(range(n)):
            obs_i = OBS[i]
            act_i = ACT[i]
            pred = self.pred(obs_i, act_i)
            obs_cap = pred
            for j in range(i, n + 1):
                obs_j = OBS[j]
                act_j = ACT[j]
                obs_cap_ = self.pred(obs_j, act_j)
                obs_cap = obs_cap & obs_cap_
                if not obs_cap:
                    break
            if obs_cap:
                tau_OBS.insert(0, obs_i)
            else:
                tau_OBS.insert(0, self.S)
        return ACT, tau_OBS

    def pred(self, alpha: FrozenSet[State], A: Action) -> FrozenSet[State]:
        return frozenset(w for a in alpha for w in self.S if self.R(w, A, a))

    def preds(self, alpha: FrozenSet[State], A: Sequence[Action]) -> Iterator[FrozenSet[State]]:
        alpha_ = alpha
        for a in A:
            alpha_ = self.pred(alpha_, a)
            yield alpha_

In [130]:
dip = Action('dip')
Red = ClassicalAtom('Red')
Blue = ClassicalAtom('Blue')
Acid = ClassicalAtom('Acid')
Litmus = ClassicalAtom('Litmus')

In [131]:
F = frozenset((Red, Blue, Acid, Litmus))
A = frozenset((dip,))

In [132]:
def trans(before: State, action: Action, after: State) -> bool:
    if action != dip:
        return False
    #print(before, "<>", action, "=", after)
    if before == {Litmus}:
        return after == {Litmus, Blue}
    elif before == {Litmus, Acid}:
        return after == {Litmus, Red, Acid}
    elif not before:
        return not after
    elif before == {Acid}:
        return after == {Acid}
    else:
        return False
        # assert False, "Before State: {}{}{}".format('{', ','.join(map(str, before)), '}')

In [133]:
trans({Litmus}, dip, {Litmus, Blue})


True

In [134]:
trans({Litmus, Acid}, dip, {Litmus, Acid, Red})


True

In [135]:
def hamming_distance(A: State, B: State):
    return len(A ^ B)

In [136]:
dipping_test = TransitionSystem(frozenset(frozen_powerset(F)), trans, hamming_distance)

In [137]:
alpha = frozenset((frozenset((Litmus,)), frozenset((Litmus, Acid))))

In [138]:
dipping_test.belief_update(alpha, dip)

frozenset({frozenset({ClassicalAtom(symbol='Acid'),
                      ClassicalAtom(symbol='Litmus'),
                      ClassicalAtom(symbol='Red')}),
           frozenset({ClassicalAtom(symbol='Blue'),
                      ClassicalAtom(symbol='Litmus')})})

In [139]:
O = frozenset(frozen_powerset((Litmus, Acid)))
O

frozenset({frozenset(),
           frozenset({ClassicalAtom(symbol='Acid')}),
           frozenset({ClassicalAtom(symbol='Litmus')}),
           frozenset({ClassicalAtom(symbol='Acid'),
                      ClassicalAtom(symbol='Litmus')})})

In [140]:
E_ = frozenset({frozenset({Blue, Litmus}),
                frozenset({Acid,
                           Litmus,
                           Red})})

In [141]:
dipping_test.belief_revision(E_, O)

frozenset({frozenset({ClassicalAtom(symbol='Litmus')}),
           frozenset({ClassicalAtom(symbol='Acid'),
                      ClassicalAtom(symbol='Litmus')})})

In [142]:
def consistent(W: WorldView, kappa: BeliefTrajectory, T: TransitionSystem) -> bool:
    if len(W) != len(kappa):
        return False
    ACT, OBS = W
    if len(ACT) != len(OBS):
        return False
    for i, w in enumerate(W):
        kappa_i = kappa[i]
        act_i, obs_i = w
        if not kappa_i <= obs_i:
            return False
        if i > 1:
            kappa_i_ = kappa[i - 1]
            if kappa_i != T.belief_update(kappa_i_, act_i):
                return False
    return True

In [143]:
E = copy(alpha)

In [144]:
dipping_test.belief_evolution(kappa=E, W=((dip,), (O,)))

[frozenset({frozenset(), frozenset({ClassicalAtom(symbol='Acid')})}),
 frozenset({frozenset(), frozenset({ClassicalAtom(symbol='Acid')})})]

In [145]:
@dataclass(order=True, frozen=True)
class EffectProposition:
    action: Action
    effect: ClassicalLiteral
    conditions: FrozenSet[ClassicalLiteral]

    def __call__(self, state: State, *args, **kwargs):
        if not all(condition.sign == (condition.atom in state) for condition in self.conditions):
            return self
        return (state - {-self.effect, self.effect}) | {self.effect}


ActionDescription = FrozenSet[EffectProposition]

In [146]:
def generate_transition_relation(AD: ActionDescription) -> Callable[[State, Action, State], bool]:
    def transition_relation(before: State, action: Action, after: State) -> bool:
        valid = False
        for ep in AD:
            if ep.action != action:
                continue
            if not all(condition.sign == (condition.atom in before) for condition in ep.conditions):
                continue
            if not ep.effect in after:
                continue
            if not after <= (before | {ep.effect}):
                continue
            valid = True
        return valid

    return transition_relation

In [147]:
@dataclass(order=True, frozen=True)
class SensingProposition:
    action: Action
    effect: ClassicalLiteral
    conditions: FrozenSet[ClassicalLiteral]

    def __call__(self, belief_state: BeliefState, *args, **kwargs):
        belief_state_ = set()
        for state in belief_state:
            if not all(condition.sign == (condition.atom in state) for condition in self.conditions):
                continue
            belief_state_.add((state - {-self.effect, self.effect}) | {self.effect})
        return frozenset(belief_state_)


ActionSensingDescription = FrozenSet[Union[SensingProposition, EffectProposition]]

In [148]:
def generate_epistemic_transition_relation(ASD: ActionSensingDescription, T: TransitionSystem) -> Callable[
    [BeliefState, Action, BeliefState], bool]:
    def transition_relation(before: BeliefState, action: Action, after: BeliefState) -> bool:
        k_star = frozenset(T.belief_update(k, action) for k in before)
        valid = False
        for p in ASD:
            if isinstance(p, SensingProposition):
                before_ = before - (
                            {-p.effect} | {-condition for condition in p.conditions} | {p.effect} | p.conditions)
                if after <= (before_ | p.conditions | {p.effect}):
                    valid = True
                if after <= (before_ | p.conditions | {-p.effect}):
                    valid = True
                if after <= (before_ | {-condition for condition in p.conditions}):
                    valid = True
            elif isinstance(p, EffectProposition):
                if k_star != after:
                    return False
            else:
                assert False
        return valid

    return transition_relation

In [149]:
Feed = Action('Feed')
LookAtCroc = Action('LookAtCroc')
FullChicken = ClassicalAtom('FullChicken')
FullChicken_ = ClassicalLiteral(FullChicken)
FullDuck = ClassicalAtom('FullDuck')
FullDuck_ = ClassicalLiteral(FullDuck)
Chicken = ClassicalAtom('Chicken')
Chicken_ = ClassicalLiteral(Chicken)
Sick = ClassicalAtom('Sick')
Sick_ = ClassicalLiteral(Sick)

F = frozenset({FullChicken, FullDuck, Chicken, Sick})

r1 = EffectProposition(Feed, FullChicken_, frozenset({Chicken_}))
r2 = EffectProposition(Feed, -Sick_, frozenset({Chicken_}))
r3 = EffectProposition(Feed, FullDuck_, frozenset({-Chicken_}))
r4 = EffectProposition(Feed, Sick_, frozenset({FullDuck_, -Chicken_}))
r5 = SensingProposition(LookAtCroc, Sick_, frozenset({Sick_}))

AD: ActionDescription = frozenset({r1, r2, r3, r4})
ASD: ActionSensingDescription = frozenset({r1, r2, r3, r4, r5})

In [150]:
trans = generate_transition_relation(AD)
croc_sys = TransitionSystem(frozenset(frozen_powerset(F)), trans, hamming_distance)
croc_sys

TransitionSystem(S=frozenset({frozenset({ClassicalAtom(symbol='FullChicken'), ClassicalAtom(symbol='FullDuck'), ClassicalAtom(symbol='Chicken')}), frozenset({ClassicalAtom(symbol='Sick')}), frozenset({ClassicalAtom(symbol='FullDuck'), ClassicalAtom(symbol='Chicken')}), frozenset({ClassicalAtom(symbol='FullDuck'), ClassicalAtom(symbol='Sick'), ClassicalAtom(symbol='FullChicken')}), frozenset({ClassicalAtom(symbol='FullChicken')}), frozenset({ClassicalAtom(symbol='Sick'), ClassicalAtom(symbol='FullChicken')}), frozenset({ClassicalAtom(symbol='Chicken')}), frozenset({ClassicalAtom(symbol='FullDuck'), ClassicalAtom(symbol='Sick')}), frozenset({ClassicalAtom(symbol='FullDuck'), ClassicalAtom(symbol='FullChicken')}), frozenset({ClassicalAtom(symbol='FullChicken'), ClassicalAtom(symbol='Chicken')}), frozenset({ClassicalAtom(symbol='FullChicken'), ClassicalAtom(symbol='Sick'), ClassicalAtom(symbol='Chicken')}), frozenset({ClassicalAtom(symbol='FullChicken'), ClassicalAtom(symbol='FullDuck'), C

In [151]:
def belief_state_from_models(formulas: Set[ClassicalFormula], alphabet=None) -> BeliefState:
    return frozenset(
        {frozenset({atom for atom, value in model.items() if value}) for model in models(formulas, alphabet)})

In [152]:
init_kappa = frozenset({frozenset({atom for atom, value in model.items() if value}) for model in
                        models({-FullChicken_ & -FullDuck_ & Chicken_}, F)})
init_kappa

frozenset({frozenset({ClassicalAtom(symbol='Chicken')}),
           frozenset({ClassicalAtom(symbol='Chicken'),
                      ClassicalAtom(symbol='Sick')})})

In [153]:
@dataclass(order=True, frozen=True)
class BeliefProposition:
    action: Action
    effect: ClassicalFormula
    conditions: FrozenSet[ClassicalLiteral]

    def applicable(self, state: State) -> bool:
        return all(condition.sign == (condition in state) for condition in self.conditions)

In [154]:
PointedBeliefState = Tuple[State, BeliefState]
ActionBeliefDescription = FrozenSet[Union[EffectProposition, BeliefProposition]]
BeliefDescription = FrozenSet[BeliefProposition]

In [155]:
def EFF(O: Action, BD: BeliefDescription, s: State):
    formulas = []
    for p in BD:
        if p.action == O and p.applicable(s):
            formulas.append(p.effect)
    return formulas

In [156]:
r5_ = BeliefProposition(LookAtCroc, ClassicalFormula(Sick_), frozenset({Sick_}))

AD: ActionDescription = frozenset({r1, r2, r3, r4})
BD: BeliefDescription = frozenset({r5_})

In [157]:
@dataclass(order=True, frozen=True)
class ActionDescriptionTransitionSystem:
    AD: ActionDescription
    BD: BeliefDescription
    T: TransitionSystem

    def transition(self, pointed_belief_state: PointedBeliefState, sensing_action: Optional[Action] = None,
                   *non_sensing_actions: Action) -> PointedBeliefState:
        s, k = pointed_belief_state
        s_ = frozenset({s})
        for non_sensing_action in non_sensing_actions:
            s_ = self.T.belief_update(s_, non_sensing_action)
        k_ = k
        if sensing_action is not None:
            phi = EFF(sensing_action, self.BD, s_)
            alpha = belief_state_from_models(phi, F)
            *_, k_ = self.T.preds(alpha, non_sensing_actions)
            for non_sensing_action in non_sensing_actions:
                k_ = self.T.belief_update(k_, non_sensing_action)
        return s_, k_


In [158]:
croc_ad_sys = ActionDescriptionTransitionSystem(AD, BD, croc_sys)
croc_ad_sys

ActionDescriptionTransitionSystem(AD=frozenset({EffectProposition(action=Action(symbol='Feed'), effect=¬Sick, conditions=frozenset({Chicken})), EffectProposition(action=Action(symbol='Feed'), effect=FullDuck, conditions=frozenset({¬Chicken})), EffectProposition(action=Action(symbol='Feed'), effect=Sick, conditions=frozenset({FullDuck, ¬Chicken})), EffectProposition(action=Action(symbol='Feed'), effect=FullChicken, conditions=frozenset({Chicken}))}), BD=frozenset({BeliefProposition(action=Action(symbol='LookAtCroc'), effect=Sick, conditions=frozenset({Sick}))}), T=TransitionSystem(S=frozenset({frozenset({ClassicalAtom(symbol='FullChicken'), ClassicalAtom(symbol='FullDuck'), ClassicalAtom(symbol='Chicken')}), frozenset({ClassicalAtom(symbol='Sick')}), frozenset({ClassicalAtom(symbol='FullDuck'), ClassicalAtom(symbol='Chicken')}), frozenset({ClassicalAtom(symbol='FullDuck'), ClassicalAtom(symbol='Sick'), ClassicalAtom(symbol='FullChicken')}), frozenset({ClassicalAtom(symbol='FullChicken')

In [159]:
init_s = frozenset({Chicken})

In [160]:
croc_ad_sys.transition((init_s, init_kappa), LookAtCroc, Feed)

(frozenset(), frozenset())