In [1]:
from dataclasses import dataclass, field
from enum import IntEnum
from functools import cached_property
from typing import Optional, Sequence

In [2]:
class Symbol:
    pass

In [3]:
class TopLevelSymbol(Symbol):
    pass

In [4]:
@dataclass(order=True, frozen=True)
class Function(TopLevelSymbol):
    name: Optional[str] = None

    def __str__(self):
        if self.name is None:
            return '()'
        else:
            return self.name

In [5]:
@dataclass(order=True, frozen=True)
class Atom:
    symbol: TopLevelSymbol = field(default_factory=Function)

    def __str__(self):
        return str(self.symbol)

In [6]:
class Literal:
    pass

In [7]:
class Sign(IntEnum):
    NoSign = 0
    Negation = 1

    def __str__(self):
        if self is Sign.NoSign:
            return ''
        elif self is Sign.Negation:
            return 'not'
        else:
            assert False, 'Unknown IntEnum {} = {}.'.format(self.name, self.value)


In [8]:
@dataclass(order=True, frozen=True)
class BasicLiteral(Literal):
    sign: Sign = Sign.NoSign
    atom: Atom = field(default_factory=Atom)

    def __str__(self):
        if self.sign is Sign.NoSign:
            return "{}".format(self.atom)
        else:
            return "{} {}".format(self.sign, self.atom)

    def __neg__(self):
        return BasicLiteral(Sign((self.sign ^ 1) % 2), self.atom)


In [9]:
class Rule:
    @staticmethod
    def fmt_body(body: Sequence[BasicLiteral]):
        return ', '.join(map(str, body))

In [10]:
@dataclass(order=True, frozen=True)
class NormalRule(Rule):
    head: BasicLiteral = field(default_factory=BasicLiteral)
    body: Sequence[BasicLiteral] = ()

    def __str__(self):
        if self.body:
            return "{} :- {}.".format(self.head, Rule.fmt_body(self.body))
        else:
            return "{}.".format(self.head)

In [11]:
@dataclass(order=True, frozen=True)
class Constraint(Rule):
    body: Sequence[BasicLiteral] = ()

    @property
    def head(self):
        return False

    def __str__(self):
        if self.body:
            return '#false :- {}.'.format(Rule.fmt_body(self.body))
        else:
            return '#false.'


In [12]:
@dataclass(order=True, frozen=True)
class Goal(Rule):
    body: Sequence[BasicLiteral] = ()

    @property
    def head(self):
        return True

    def __str__(self):
        if self.body:
            return '#true :- {}.'.format(Rule.fmt_body(self.body))
        else:
            return '#true.'

In [13]:
@dataclass(order=True, frozen=True)
class Program:
    rules: Sequence[Rule] = ()

    def fmt(self, sep=' ', begin=None, end=None):
        b = begin + sep if begin is not None else ''
        e = sep + end if end is not None else ''
        return "{}{}{}".format(b, sep.join(map(str, self.rules)), e)

    def __str__(self):
        return self.fmt()

    @cached_property
    def dual(self):  #type Program
        lit_rules = dict()
        for rule in self.rules:
            head = rule.head
            if isinstance(rule, NormalRule):
                lit_rules.setdefault(head, set()).add(rule)

        dual_rules = []
        for literal, rules in lit_rules.items():
            if isinstance(literal, bool):
                pass
            elif isinstance(literal, BasicLiteral):
                if len(rules) == 1:
                    rules = tuple(rules)
                    rule = rules[0]
                    dual_head = -literal
                    if not rule.body:
                        dual_bodies = ((),)
                    else:
                        dual_bodies = []
                        for body_literal in rule.body:
                            dual_bodies.append((-body_literal,))
                    for dual_body in dual_bodies:
                        dual_rules.append(NormalRule(head=dual_head, body=dual_body))
                elif len(rules) > 1:
                    dual_head = -literal
                    dual_body = []
                    support_dual_rules = []
                    for i, rule in enumerate(rules):
                        if len(rule.body) == 1:
                            dual_body.append(-rule.body[0])
                        elif len(rule.body) > 1:
                            support_dual_head = BasicLiteral(
                                atom=Atom(Function("__not_{}_{}".format(literal.atom.symbol.name, i))))
                            dual_body.append(support_dual_head)
                            if not rule.body:
                                support_dual_bodies = ((),)
                            else:
                                support_dual_bodies = []
                                for body_literal in rule.body:
                                    support_dual_bodies.append((-body_literal,))
                            for support_dual_body in support_dual_bodies:
                                support_dual_rules.append(NormalRule(head=support_dual_head, body=support_dual_body))
                    dual_rules.append(NormalRule(head=dual_head, body=dual_body))
                    dual_rules.extend(support_dual_rules)
        return Program(rules=dual_rules)



In [14]:
q = BasicLiteral(atom=Atom(Function('q')))
p = BasicLiteral(atom=Atom(Function('p')))
r = BasicLiteral(atom=Atom(Function('r')))
###
a = BasicLiteral(atom=Atom(Function('a')))
b = BasicLiteral(atom=Atom(Function('b')))
c = BasicLiteral(atom=Atom(Function('c')))
d = BasicLiteral(atom=Atom(Function('d')))
e = BasicLiteral(atom=Atom(Function('e')))
f = BasicLiteral(atom=Atom(Function('f')))
k = BasicLiteral(atom=Atom(Function('k')))


In [15]:
p1 = Program(rules=(
    NormalRule(head=p, body=(-q,)),
    NormalRule(head=q, body=(-r,)),
    NormalRule(head=r, body=(-p,)),
    NormalRule(head=q, body=(-p,)),
))
print(p1.fmt('\n'))  # AS: {{q, r}}
print('-' * 10)
d1 = p1.dual
print(d1.fmt('\n'))

p :- not q.
q :- not r.
r :- not p.
q :- not p.
----------
not p :- q.
not q :- r, p.
not r :- p.


In [16]:
p2 = Program(rules=(
    NormalRule(head=q, body=(-r,)),
    NormalRule(head=r, body=(-q,)),
    NormalRule(head=p, body=(-p,)),
    NormalRule(head=p, body=(-r,)),
))
print(p2.fmt('\n'))  # AS: {{q, p}}
print('-' * 10)
d2 = p2.dual
print(d2.fmt('\n'))

q :- not r.
r :- not q.
p :- not p.
p :- not r.
----------
not q :- r.
not r :- q.
not p :- r, p.


In [17]:
p3 = Program(rules=(
    NormalRule(head=a, body=(b, d)),
    NormalRule(head=b, body=(d,)),
    NormalRule(head=c, body=(d,)),
    NormalRule(head=d, body=()),
))
print(p3.fmt('\n'))
print('-' * 10)
d3 = p3.dual
print(d3.fmt('\n'))

a :- b, d.
b :- d.
c :- d.
d.
----------
not a :- not b.
not a :- not d.
not b :- not d.
not c :- not d.
not d.


In [18]:
p4 = Program(rules=(
    NormalRule(head=a, body=(k, -b)),
    NormalRule(head=k, body=(e, -b)),
    NormalRule(head=c, body=(a, b)),
    NormalRule(head=b, body=(-a,)),
    NormalRule(head=c, body=(k,)),
    NormalRule(head=f, body=(e, k, -c)),
    NormalRule(head=e),
))
print(p4.fmt('\n'))
print('-' * 10)
d4 = p4.dual
print(d4.fmt('\n'))

a :- k, not b.
k :- e, not b.
c :- a, b.
b :- not a.
c :- k.
f :- e, k, not c.
e.
----------
not a :- not k.
not a :- b.
not k :- not e.
not k :- b.
not c :- __not_c_0, not k.
__not_c_0 :- not a.
__not_c_0 :- not b.
not b :- a.
not f :- not e.
not f :- not k.
not f :- c.
not e.
