In [164]:
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from enum import IntEnum
from functools import cached_property
from typing import Optional, Sequence, Iterator, Dict, Set, TypeVar, MutableSequence

In [165]:
class Symbol:
    pass

In [166]:
class TopLevelSymbol(Symbol):
    pass

In [167]:
@dataclass(order=True, frozen=True)
class Function(TopLevelSymbol):
    name: Optional[str] = None

    def __str__(self):
        if self.name is None:
            return '()'
        else:
            return self.name

In [168]:
@dataclass(order=True, frozen=True)
class Atom:
    symbol: TopLevelSymbol = field(default_factory=Function)

    def __str__(self):
        return str(self.symbol)

In [169]:
class Literal:
    pass

In [170]:
class Sign(IntEnum):
    NoSign = 0
    Negation = 1

    def __str__(self):
        if self is Sign.NoSign:
            return ''
        elif self is Sign.Negation:
            return 'not'
        else:
            assert False, 'Unknown IntEnum {} = {}.'.format(self.name, self.value)


In [171]:
@dataclass(order=True, frozen=True)
class BasicLiteral(Literal):
    sign: Sign = Sign.NoSign
    atom: Atom = field(default_factory=Atom)

    def __str__(self):
        if self.sign is Sign.NoSign:
            return "{}".format(self.atom)
        else:
            return "{} {}".format(self.sign, self.atom)

    def __neg__(self):
        return BasicLiteral(Sign((self.sign ^ 1) % 2), self.atom)

    def __abs__(self):
        return BasicLiteral(Sign.NoSign, self.atom)

    @property
    def is_pos(self) -> bool:
        return self.sign is Sign.NoSign

    @property
    def is_neg(self) -> bool:
        return self.sign is Sign.Negation

In [172]:
class Rule:
    @staticmethod
    def fmt_body(body: Sequence[BasicLiteral]):
        return ', '.join(map(str, body))

In [173]:
@dataclass(order=True, frozen=True)
class NormalRule(Rule):
    head: BasicLiteral = field(default_factory=BasicLiteral)
    body: Sequence[BasicLiteral] = ()

    def __str__(self):
        if self.body:
            return "{} :- {}.".format(self.head, Rule.fmt_body(self.body))
        else:
            return "{}.".format(self.head)

In [174]:
@dataclass(order=True, frozen=True)
class IntegrityConstraint(Rule):
    body: Sequence[BasicLiteral] = ()

    @property
    def head(self):
        return False

    def __str__(self):
        if self.body:
            return '#false :- {}.'.format(Rule.fmt_body(self.body))
        else:
            return '#false.'


In [175]:
@dataclass(order=True, frozen=True)
class Goal(Rule):
    body: Sequence[BasicLiteral] = ()

    @property
    def head(self):
        return True

    def __str__(self):
        if self.body:
            return '#true :- {}.'.format(Rule.fmt_body(self.body))
        else:
            return '#true.'

In [176]:
ForwardProof = TypeVar('ForwardProof', bound='Proof')


@dataclass
class Proof:
    parent: ForwardProof = field(repr=False, default=None)
    idx: int = 0
    subject: Optional[Rule] = field(default=None)
    children: MutableSequence[ForwardProof] = field(default_factory=list)
    hypotheses: Set[Literal] = field(default_factory=set)

In [177]:
@dataclass(order=True, frozen=True)
class Program:
    rules: Sequence[Rule] = ()

    def fmt(self, sep=' ', begin=None, end=None):
        b = begin + sep if begin is not None else ''
        e = sep + end if end is not None else ''
        return "{}{}{}".format(b, sep.join(map(str, self.rules)), e)

    def __str__(self):
        return self.fmt()

    @cached_property
    def dual(self):  #type Program
        return Program.dual_of(self.rules)

    @cached_property
    def sASP(self):
        sasp_rules = list(self.rules)
        sasp_rules.extend(self.dual_of(tuple(self.non_constraint_rules)).rules)

        chk_rules = []
        nmr_chk_head = BasicLiteral(atom=Atom(Function("__nmr_chk")))
        nmr_chk_body = []
        for i, c_rule in enumerate(self.constraint_rules):
            chk_head = BasicLiteral(atom=Atom(Function("__chk_{}_{}".format(c_rule.head.atom.symbol.name, i))))
            chk_body = (-c_rule.head, *(body_literal for body_literal in c_rule.body if -c_rule.head != body_literal))
            chk_rule = NormalRule(head=chk_head, body=chk_body)
            nmr_chk_body.append(-chk_head)

            chk_rules.append(chk_rule)
        sasp_rules.extend(Program.dual_of(chk_rules).rules)
        nmr_chk_rule = NormalRule(head=nmr_chk_head, body=nmr_chk_body)
        sasp_rules.append(nmr_chk_rule)
        return Program(rules=sasp_rules)

    @cached_property
    def reachable(self) -> Dict[Rule, Set[Literal]]:
        reachable = defaultdict(set)
        for rule in self.rules:
            considered = set()
            literal_stack = []
            naf_stack = []
            literal_stack.extend(rule.body)
            naf_stack.extend([0 for _ in rule.body])
            while literal_stack:
                literal = literal_stack.pop()
                naf = naf_stack.pop()
                naf = (naf + literal.is_neg) % 2
                reachable[rule].add(BasicLiteral(Sign(naf), literal.atom))
                for adj in self.rules:
                    if adj not in considered and adj.head == abs(literal):
                        considered.add(adj)
                        literal_stack.extend(adj.body)
                        naf_stack.extend([naf for _ in adj.body])
        return reachable

    @property
    def constraint_rules(self) -> Iterator[Rule]:
        for rule, reachable in self.reachable.items():
            if -rule.head in reachable:
                yield rule

    @property
    def non_constraint_rules(self) -> Iterator[Rule]:
        for rule, reachable in self.reachable.items():
            if rule.head in reachable or -rule.head not in reachable:
                yield rule

    def evaluate_top_down(self, *atoms: Atom):
        hypothesis_set = set()
        for rule in self.rules:
            if not rule.body:
                hypothesis_set.add(rule.head)
        proofs = []
        rules = self.sASP.rules
        __nmr_chk = BasicLiteral(atom=Atom(Function('__nmr_chk')))
        root = Proof(subject=Goal(body=(*atoms, __nmr_chk)), hypotheses=hypothesis_set)
        stack = [root]
        while stack:
            current = stack.pop()
            assert isinstance(current, Proof)
            if current.idx >= len(current.subject.body):
                if current.parent is None:
                    proofs.append(current)
                else:
                    parent = deepcopy(current.parent)
                    parent.hypotheses = current.hypotheses
                    parent.idx += 1
                    stack.append(parent)
            else:
                literal = current.subject.body[current.idx]
                for rule in rules:
                    if rule.head == literal:
                        child = Proof(parent=current, subject=rule, hypotheses=current.hypotheses)
                        current.children.append(child)
                        stack.append(child)

        return proofs

    @staticmethod
    def dual_of(rules):
        lit_rules = dict()
        for rule in rules:
            head = rule.head
            if isinstance(rule, NormalRule):
                lit_rules.setdefault(head, set()).add(rule)

        dual_rules = []
        for literal, rules in lit_rules.items():
            if isinstance(literal, bool):
                pass
            elif isinstance(literal, BasicLiteral):
                if len(rules) == 1:
                    rules = tuple(rules)
                    rule = rules[0]
                    dual_head = -literal
                    if not rule.body:
                        dual_bodies = ((),)
                    else:
                        dual_bodies = []
                        for body_literal in rule.body:
                            dual_bodies.append((-body_literal,))
                    for dual_body in dual_bodies:
                        dual_rules.append(NormalRule(head=dual_head, body=dual_body))
                elif len(rules) > 1:
                    dual_head = -literal
                    dual_body = []
                    support_dual_rules = []
                    for i, rule in enumerate(rules):
                        if len(rule.body) == 1:
                            dual_body.append(-rule.body[0])
                        elif len(rule.body) > 1:
                            support_dual_head = BasicLiteral(
                                atom=Atom(Function("__not_{}_{}".format(literal.atom.symbol.name, i))))
                            dual_body.append(support_dual_head)
                            if not rule.body:
                                support_dual_bodies = ((),)
                            else:
                                support_dual_bodies = []
                                for body_literal in rule.body:
                                    support_dual_bodies.append((-body_literal,))
                            for support_dual_body in support_dual_bodies:
                                support_dual_rules.append(NormalRule(head=support_dual_head, body=support_dual_body))
                    dual_rules.append(NormalRule(head=dual_head, body=dual_body))
                    dual_rules.extend(support_dual_rules)
        return Program(rules=dual_rules)

In [178]:
q = BasicLiteral(atom=Atom(Function('q')))
p = BasicLiteral(atom=Atom(Function('p')))
r = BasicLiteral(atom=Atom(Function('r')))
###
a = BasicLiteral(atom=Atom(Function('a')))
b = BasicLiteral(atom=Atom(Function('b')))
c = BasicLiteral(atom=Atom(Function('c')))
d = BasicLiteral(atom=Atom(Function('d')))
e = BasicLiteral(atom=Atom(Function('e')))
f = BasicLiteral(atom=Atom(Function('f')))
k = BasicLiteral(atom=Atom(Function('k')))


In [179]:
p1 = Program(rules=(
    NormalRule(head=p, body=(-q,)),
    NormalRule(head=q, body=(-r,)),
    NormalRule(head=r, body=(-p,)),
    NormalRule(head=q, body=(-p,)),
))
print(p1.fmt('\n'))  # AS: {{q, r}}
print('-' * 10)
d1 = p1.dual
print(d1.fmt('\n'))
print('-' * 10)
s1 = p1.sASP
print(s1.fmt('\n'))


p :- not q.
q :- not r.
r :- not p.
q :- not p.
----------
not p :- q.
not q :- p, r.
not r :- p.
----------
p :- not q.
q :- not r.
r :- not p.
q :- not p.
not p :- q.
not q :- p.
not __chk_p_0 :- p.
not __chk_p_0 :- q.
not __chk_q_1 :- q.
not __chk_q_1 :- r.
not __chk_r_2 :- r.
not __chk_r_2 :- p.
__nmr_chk :- not __chk_p_0, not __chk_q_1, not __chk_r_2.


In [180]:
answer_sets = p1.evaluate_top_down(Goal((q,)))
print("q:")
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

print("r:")
answer_sets = p1.evaluate_top_down(Goal((r,)))
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

q:
r:


In [181]:
p2 = Program(rules=(
    NormalRule(head=q, body=(-r,)),
    NormalRule(head=r, body=(-q,)),
    NormalRule(head=p, body=(-p,)),
    NormalRule(head=p, body=(-r,)),
))
print(p2.fmt('\n'))  # AS: {{q, p}}
print('-' * 10)
d2 = p2.dual
print(d2.fmt('\n'))
print('-' * 10)
s2 = p2.sASP
print(s2.fmt('\n'))


q :- not r.
r :- not q.
p :- not p.
p :- not r.
----------
not q :- r.
not r :- q.
not p :- p, r.
----------
q :- not r.
r :- not q.
p :- not p.
p :- not r.
not q :- r.
not r :- q.
not p :- p, r.
not __chk_p_0 :- p.
__nmr_chk :- not __chk_p_0.


In [182]:
answer_sets = p2.evaluate_top_down(Goal((p,)))
print("p:")
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

print("q:")
answer_sets = p2.evaluate_top_down(Goal((q,)))
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

p:
q:


In [183]:
p3 = Program(rules=(
    NormalRule(head=a, body=(b, d)),
    NormalRule(head=b, body=(d,)),
    NormalRule(head=c, body=(d,)),
    NormalRule(head=d, body=()),
))
print(p3.fmt('\n'))
print('-' * 10)
d3 = p3.dual
print(d3.fmt('\n'))

a :- b, d.
b :- d.
c :- d.
d.
----------
not a :- not b.
not a :- not d.
not b :- not d.
not c :- not d.
not d.


In [184]:
answer_sets = p3.evaluate_top_down(Goal((a,)))
print("a:")
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

print("c:")
answer_sets = p3.evaluate_top_down(Goal((c,)))
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

a:
c:


In [185]:
p4 = Program(rules=(
    NormalRule(head=a, body=(k, -b)),
    NormalRule(head=k, body=(e, -b)),
    NormalRule(head=c, body=(a, b)),
    NormalRule(head=b, body=(-a,)),
    NormalRule(head=c, body=(k,)),
    NormalRule(head=f, body=(e, k, -c)),
    NormalRule(head=e),
))
print(p4.fmt('\n'))
print('-' * 10)
d4 = p4.dual
print(d4.fmt('\n'))

a :- k, not b.
k :- e, not b.
c :- a, b.
b :- not a.
c :- k.
f :- e, k, not c.
e.
----------
not a :- not k.
not a :- b.
not k :- not e.
not k :- b.
not c :- not k, __not_c_1.
__not_c_1 :- not a.
__not_c_1 :- not b.
not b :- a.
not f :- not e.
not f :- not k.
not f :- c.
not e.


In [186]:
answer_sets = p4.evaluate_top_down(Goal((b,)))
print("b:")
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

answer_sets = p4.evaluate_top_down(Goal((e,)))
print("e:")
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

answer_sets = p4.evaluate_top_down(Goal((b, e, f)))
print("b,e,f:")
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

print("c:")
answer_sets = p4.evaluate_top_down(Goal((c,)))
for answer_set in answer_sets:
    print("{", end='')
    print(', '.join(map(str, answer_set)), end='')
    print("}")

b:
e:
b,e,f:
c:


In [187]:
p5 = Program(rules=(
    NormalRule(head=p, body=(a, -q)),
    NormalRule(head=q, body=(b, -r)),
    NormalRule(head=r, body=(c, -p)),
    NormalRule(head=q, body=(d, -p)),
))
print("% C:")
print('\n'.join(map(str, set(p5.constraint_rules))))
print("% NC:")
print('\n'.join(map(str, set(p5.non_constraint_rules))))

% C:
p :- a, not q.
r :- c, not p.
q :- b, not r.
% NC:
q :- d, not p.
p :- a, not q.


In [188]:
p6 = Program(rules=(
    NormalRule(head=a, body=(-b,)),
    NormalRule(head=b, body=(-a,)),
))
print(p6.fmt('\n'))  # AS: {{q, p}}
print('-' * 10)
d6 = p6.dual
print(d6.fmt('\n'))
print('-' * 10)
s6 = p6.sASP
print(s6.fmt('\n'))

a :- not b.
b :- not a.
----------
not a :- b.
not b :- a.
----------
a :- not b.
b :- not a.
not a :- b.
not b :- a.
__nmr_chk.


In [189]:
p7 = Program(rules=(
    NormalRule(head=a, body=(-b,)),
    NormalRule(head=b, body=(-c,)),
    NormalRule(head=c, body=(-a,)),
))
print(p7.fmt('\n'))  # AS: {{q, p}}
print('-' * 10)
d7 = p7.dual
print(d7.fmt('\n'))
print('-' * 10)
s7 = p7.sASP
print(s7.fmt('\n'))

a :- not b.
b :- not c.
c :- not a.
----------
not a :- b.
not b :- c.
not c :- a.
----------
a :- not b.
b :- not c.
c :- not a.
not __chk_a_0 :- a.
not __chk_a_0 :- b.
not __chk_b_1 :- b.
not __chk_b_1 :- c.
not __chk_c_2 :- c.
not __chk_c_2 :- a.
__nmr_chk :- not __chk_a_0, not __chk_b_1, not __chk_c_2.
