In [1]:
import copy
from dataclasses import dataclass, field
from enum import IntEnum
from typing import TypeVar, Sequence, Optional, Dict, Mapping, MutableMapping, MutableSequence, List, Iterator

import clingo


In [2]:
ForwardSymbol = TypeVar('ForwardSymbol', bound='Symbol')


class Symbol:
    """Base class of every symbol (term) in this module.

    The ``is_*`` helpers are plain ``isinstance`` checks against subclasses that are
    declared in later cells; the names are looked up at call time, so the forward
    references are fine.
    """

    def is_function(self) -> bool:
        """Return True when this symbol is a ``Function``."""
        return isinstance(self, Function)

    def is_unary_operation(self) -> bool:
        """Return True when this symbol is a ``UnaryOperation``."""
        return isinstance(self, UnaryOperation)

    def is_variable(self) -> bool:
        """Return True when this symbol is a ``Variable``."""
        return isinstance(self, Variable)

    @classmethod
    def from_clingo_symbol(cls, symbol: clingo.Symbol) -> ForwardSymbol:
        """Convert a clingo function symbol into a ``Function``.

        Arguments are converted recursively through ``SubSymbol.from_clingo_symbol``.
        Any other ``clingo.SymbolType`` triggers an AssertionError.
        """
        if symbol.type is clingo.SymbolType.Function:
            function_name: str = symbol.name
            converted_args = tuple(map(SubSymbol.from_clingo_symbol, symbol.arguments))
            return Function(function_name, converted_args)
        assert False, "Unknown clingo.SymbolType {}.".format(symbol.type)


class SubSymbol(Symbol):
    """A symbol that may appear as an argument of a ``Function``."""

    @classmethod
    def from_clingo_symbol(cls, symbol: clingo.Symbol) -> ForwardSymbol:
        """Convert a clingo argument symbol.

        Numbers become ``Term(IntegerConstant(n))``; function symbols are delegated to
        ``Symbol.from_clingo_symbol``. Any other type triggers an AssertionError.
        """
        kind = symbol.type
        if kind is clingo.SymbolType.Function:
            return Symbol.from_clingo_symbol(symbol)
        if kind is clingo.SymbolType.Number:
            return Term(IntegerConstant(symbol.number))
        assert False, "Unknown clingo.SymbolType {}.".format(kind)


In [3]:
@dataclass(frozen=True, order=True)
class Variable(SubSymbol):
    """A named logic variable such as ``X``.

    Frozen (hence hashable), so it can serve as a key in binding environments.
    """
    name: str  # identifier printed verbatim

    def __str__(self):
        rendered = self.name
        return rendered


@dataclass(frozen=True, order=True)
class IntegerConstant:
    """An immutable integer literal; printing it yields the bare number."""
    number: int = 0  # the literal value

    def __str__(self):
        return "{}".format(self.number)


@dataclass(frozen=True, order=True)
class Term(SubSymbol):
    """A constant argument wrapping an ``IntegerConstant``; prints as that number."""
    constant: IntegerConstant = field(default=IntegerConstant())  # shared immutable default

    def __str__(self):
        return "{}".format(self.constant)

In [4]:
ForwardAtom = TypeVar('ForwardAtom', bound='Atom')
ForwardFunction = TypeVar('ForwardFunction', bound='Function')
# NOTE(review): no class named UnaryOperator exists in the visible code (the class is
# UnaryOperation), and this TypeVar is not referenced anywhere visible.
ForwardUnaryOperator = TypeVar('ForwardUnaryOperator', bound='UnaryOperator')


@dataclass
class Function(Symbol):
    """A function term ``name(arg1,...,argN)``, a constant ``name``, or an anonymous tuple.

    A falsy ``name`` (``None`` or ``""``) denotes a tuple. Instances are mutable:
    ``recursive_rename`` and ``Atom.fill`` rewrite ``name``/``arguments`` in place.
    """
    name: Optional[str] = None
    arguments: Sequence[SubSymbol] = ()

    @property
    def arity(self) -> int:
        """Number of arguments."""
        return len(self.arguments)

    def __str__(self):
        rendered_args = ','.join(map(str, self.arguments))
        if self.name:
            # Named function, or a plain constant when there are no arguments.
            return "{}({})".format(self.name, rendered_args) if self.arguments else self.name
        if len(self.arguments) == 1:
            # A one-element tuple needs a trailing comma; "(a)" is just the term a
            # in clingo/gringo syntax.
            return "({},)".format(rendered_args)
        return "({})".format(rendered_args)

    def match(self, name: Optional[str], arity: int = 0) -> bool:
        """Return True when this function has exactly the given name and arity."""
        return name == self.name and arity == len(self.arguments)

    def match_signature(self, other: ForwardFunction) -> bool:
        """Return True when ``other`` has the same name and arity as this function."""
        return self.match(other.name, other.arity)

    def recursive_rename(self, new_name: Optional[str], name: Optional[str], arity: int = 0) -> ForwardFunction:
        """Rename, in place, this function and every nested Function argument matching name/arity.

        Only direct Function arguments are descended into (not operands of operations).
        Returns ``self``.
        """
        stack: List[Function] = [self]
        while stack:
            current: Function = stack.pop()
            if current.match(name, arity):
                current.name = new_name
            for arg in current.arguments:
                if isinstance(arg, Function):
                    stack.append(arg)
        return self


class UnaryOperatorType(IntEnum):
    """Prefix operators understood by ``UnaryOperation``."""
    Minus = 1  # arithmetic negation, rendered as "-"


@dataclass
class UnaryOperation(SubSymbol):
    """A prefix operator applied to one operand, printed as e.g. ``(-X)``.

    It derives from SubSymbol so it carries the ``is_function``/``is_variable``/
    ``is_unary_operation`` helpers. ``Atom.fill`` calls these on every argument it
    visits, and ``Symbol.is_unary_operation`` can only be true for a Symbol subclass.
    """
    operator: UnaryOperatorType
    argument: SubSymbol

    def __str__(self) -> str:
        if self.operator is UnaryOperatorType.Minus:
            return "(-{})".format(self.argument)
        assert False, "Unknown UnaryOperatorType {}.".format(self.operator)


class BinaryOperatorType(IntEnum):
    """Infix operators understood by ``BinaryOperation``."""
    Plus = 2  # addition, rendered as "+"


@dataclass
class BinaryOperation(SubSymbol):
    """An infix operator applied to two operands, printed as e.g. ``(X+1)``.

    It derives from SubSymbol for the same reason as UnaryOperation: the
    ``is_*`` helpers are called on arbitrary arguments elsewhere in the module.
    """
    left: SubSymbol
    operator: BinaryOperatorType
    right: SubSymbol

    def __str__(self) -> str:
        if self.operator is BinaryOperatorType.Plus:
            return "({}+{})".format(self.left, self.right)
        # Previously fell through and returned None, which str() rejects with a TypeError;
        # fail explicitly, like UnaryOperation does.
        assert False, "Unknown BinaryOperatorType {}.".format(self.operator)


In [5]:
@dataclass
class Atom:
    """An atom: a top-level symbol such as ``son(bill,A)``, possibly wrapped in operations."""
    symbol: Symbol = field(default_factory=Function)

    def __str__(self) -> str:
        return str(self.symbol)

    def is_function(self) -> bool:
        """Return True when the top-level symbol is a Function."""
        return isinstance(self.symbol, Function)

    def is_unary_operation(self) -> bool:
        """Return True when the top-level symbol is a UnaryOperation."""
        return isinstance(self.symbol, UnaryOperation)

    def get_top_function(self) -> Function:
        """Return the outermost Function, peeling off any UnaryOperation wrappers.

        Raises AssertionError when something other than a UnaryOperation wraps it.
        """
        current = self.symbol
        while not isinstance(current, Function):
            if isinstance(current, UnaryOperation):
                current = current.argument
            else:
                assert False, "Unknown Type {} for Symbol {}.".format(type(current).__name__, current)
        return current

    @staticmethod
    def _bound_value(value, env: Mapping[Variable, SubSymbol], stack: List[Symbol]):
        """Return env[value] for a bound Variable; otherwise queue value for descent and return it."""
        if isinstance(value, Variable) and value in env:
            return env[value]
        stack.append(value)
        return value

    def fill(self, env: Mapping[Variable, SubSymbol]) -> ForwardAtom:
        """Replace, in place, every variable bound in ``env`` by its value and return ``self``.

        The walk starts at the top Function (see get_top_function) and descends through
        Function arguments and UnaryOperation/BinaryOperation operands. Variables missing
        from ``env`` are left untouched, and substituted values are not re-filled. The
        Function objects reachable from this atom are mutated, so any other atom sharing
        them sees the substitution too.
        """
        stack: List[Symbol] = [self.get_top_function()]
        while stack:
            current = stack.pop()
            if isinstance(current, Function):
                filled = [self._bound_value(arg, env, stack) for arg in current.arguments]
                # Keep tuples as tuples: the dataclass __eq__ compares (name, arguments), and a
                # list never equals a tuple, so a list here would break comparisons in unify().
                current.arguments = tuple(filled) if isinstance(current.arguments, tuple) else filled
            elif isinstance(current, UnaryOperation):
                current.argument = self._bound_value(current.argument, env, stack)
            elif isinstance(current, BinaryOperation):
                current.left = self._bound_value(current.left, env, stack)
                current.right = self._bound_value(current.right, env, stack)
        return self

    @staticmethod
    def from_clingo_symbol(symbol: clingo.Symbol) -> ForwardAtom:
        """Wrap a clingo function symbol (e.g. a shown model atom) in an Atom."""
        assert symbol.type is clingo.SymbolType.Function, "clingo.Symbol {} should have type {}, but has type {}.".format(
            symbol, clingo.SymbolType.Function, symbol.type)
        return Atom(Symbol.from_clingo_symbol(symbol))




In [6]:
class Sign(IntEnum):
    """Polarity of a Literal."""
    NoSign = 0      # positive literal
    DefaultNeg = 1  # default negation, rendered with a leading "not "

In [7]:
@dataclass
class Literal:
    """An atom together with its sign; ``str`` prefixes ``not `` for default negation."""
    atom: Atom = field(default_factory=Atom)
    sign: Sign = Sign.NoSign

    def __str__(self):
        prefix = "not " if self.sign is Sign.DefaultNeg else ""
        return prefix + str(self.atom)

    def __abs__(self):
        """Return a positive copy of this literal (the atom is deep-copied)."""
        positive_sign = Sign.NoSign
        return Literal(sign=positive_sign, atom=copy.deepcopy(self.atom))

In [8]:
ForwardRule = TypeVar('ForwardRule', bound='Rule')


@dataclass
class Rule:
    """A rule ``head :- body.``; a fact has no body and a constraint has no head.

    The transforming methods (substitute, rename_atoms, postulate) mutate the rule in
    place and return ``self`` so calls can be chained.
    """
    head: Optional[Literal] = None
    # Each rule gets its own mutable list, so postulate() can clear it.
    body: MutableSequence[Literal] = field(default_factory=list)

    def __str__(self) -> str:
        if self.head is None and not self.body:
            return ":-."
        elif self.head is None:
            return ":- {}.".format(', '.join(map(str, self.body)))
        elif not self.body:
            return "{}.".format(self.head)
        else:
            return "{} :- {}.".format(self.head, ', '.join(map(str, self.body)))

    def is_ground(self) -> bool:
        """Return True when no Variable occurs anywhere in the head or body."""
        literals = list(self.body) if self.head is None else [self.head, *self.body]
        stack = [literal.atom.symbol for literal in literals]
        while stack:
            symbol = stack.pop()
            if isinstance(symbol, Variable):
                return False
            if isinstance(symbol, Function):
                stack.extend(symbol.arguments)
            elif isinstance(symbol, UnaryOperation):
                stack.append(symbol.argument)
            elif isinstance(symbol, BinaryOperation):
                stack.extend((symbol.left, symbol.right))
        return True

    def is_fact(self) -> bool:
        """Return True for a rule with a head and an empty body."""
        return self.head is not None and not self.body

    def is_constraint(self) -> bool:
        """Return True for a headless rule."""
        return self.head is None

    def is_normal_rule(self) -> bool:
        """Return True for a rule with both a head and a non-empty body."""
        return self.head is not None and bool(self.body)

    def substitute(self, env: Mapping[Variable, Atom]) -> ForwardRule:
        """Instantiate, in place, every literal's atom with the bindings in ``env``."""
        # Atom.fill (not Function) implements the substitution; it locates the top
        # function itself.
        if self.head is not None:
            self.head.atom.fill(env)
        for literal in self.body:
            literal.atom.fill(env)
        return self

    def rename_atoms(self, new_name: Optional[str], name: Optional[str], arity: int = 0, head=True,
                     body=True) -> ForwardRule:
        """Rename functions matching name/arity in the head and/or body, in place."""
        if head:
            if self.head is not None:
                top_function: Function = self.head.atom.get_top_function()
                top_function.recursive_rename(new_name, name, arity)
        if body:
            for literal in self.body:
                top_function = literal.atom.get_top_function()
                top_function.recursive_rename(new_name, name, arity)
        return self

    def postulate(self) -> ForwardRule:
        """Drop the body, turning the rule into a fact (or an empty constraint)."""
        if isinstance(self.body, list):
            self.body.clear()
        else:
            # e.g. a tuple passed in explicitly, which has no clear()
            self.body = []
        return self

In [9]:
ForwardGoal = TypeVar('ForwardGoal', bound='Goal')


@dataclass
class Goal:
    """One node of the proof tree built by search().

    rule:     the rule whose body this goal is proving.
    parent:   the goal whose body literal this goal resolves; None for the root.
    children: goals spawned while proving this goal's body.
    env:      variable bindings local to this goal.
    inx:      index in rule.body of the next literal still to be proved.
    """
    rule: Optional[Rule] = None
    # parent/children form a cycle (a child's parent.children contains the child), so both are
    # excluded from repr and from the generated __eq__. Otherwise comparing two structurally
    # equal trees (e.g. a goal and its deepcopy) would recurse without end.
    parent: Optional[ForwardGoal] = field(default=None, repr=False, compare=False)
    children: Sequence[ForwardGoal] = field(default_factory=list, repr=False, compare=False)
    # NOTE(review): annotated as Term, but search()/unify() also store Function constants here.
    env: Dict[Variable, Term] = field(default_factory=dict)
    inx: int = field(default=0, repr=False)


ProofTree = Goal

In [10]:
def unify(src_atom: Atom, src_env: Mapping[Variable, Atom], dest_atom: Atom,
          dest_env: Optional[MutableMapping[Variable, Atom]] = None) -> bool:
    """One-way, argument-level matching of ``src_atom`` (under ``src_env``) into ``dest_atom``.

    Both atoms must have Function symbols with the same name and arity. For each argument
    position, the source value (a constant, or a source variable's binding) must equal the
    destination constant, or is bound to the destination variable in ``dest_env``. Unbound
    source variables impose no constraint. Nested terms are compared by equality, not unified.

    ``dest_env`` is updated only when the match succeeds; on failure it is left unchanged.
    Returns True on success.
    """
    src_symbol = src_atom.symbol
    dest_symbol = dest_atom.symbol
    if not isinstance(src_symbol, Function) or not isinstance(dest_symbol, Function):
        return False
    if not src_symbol.match_signature(dest_symbol):
        return False
    if dest_env is None:
        dest_env = {}
    # Collect new bindings separately and commit them only after every position has matched.
    pending = {}
    for src_arg, dest_arg in zip(src_symbol.arguments, dest_symbol.arguments):
        src_val = src_env.get(src_arg) if isinstance(src_arg, Variable) else src_arg
        if src_val is None:
            continue  # unbound source variable: nothing to enforce
        if isinstance(dest_arg, Variable):
            # A binding made earlier in this call takes precedence (e.g. p(X,X)).
            dest_val = pending.get(dest_arg)
            if dest_val is None:
                dest_val = dest_env.get(dest_arg)
            if dest_val is None:
                pending[dest_arg] = src_val
            elif dest_val != src_val:
                return False
        elif dest_arg != src_val:
            return False
    dest_env.update(pending)
    return True

In [11]:
def search(atom: Atom, rules: Sequence[Rule] = ()) -> Sequence[ProofTree]:
    """Depth-first, top-down search for proofs of ``atom`` using ``rules``.

    The search runs on an explicit stack of Goals, starting from a synthetic root whose
    body is ``[atom]``. When a goal's body is exhausted, its head is unified back into a
    deep copy of its parent, which then advances to its next body literal. As a side
    effect it prints each answer's bindings (or "Yes" when there are none), and "No" if
    nothing was proved. It returns the completed root goals, one per proof found.

    NOTE(review): literal signs are never consulted, so "not p" is searched as "p"; and
    there is no depth bound, so recursive rules may not terminate.
    """
    root = Goal(rule=Rule(head=Literal(), body=[Literal(atom=atom)]))
    proof_trees = []
    stack = [root]
    while stack:
        current = stack.pop()
        if current.inx >= len(current.rule.body):
            if current.parent is None:
                if current.env:
                    print(current.env)
                else:
                    print("Yes")
                proof_trees.append(current)
            else:
                # Copy so that sibling alternatives keep their own parent state.
                parent = copy.deepcopy(current.parent)
                if unify(current.rule.head.atom, current.env, parent.rule.body[parent.inx].atom, parent.env):
                    parent.inx += 1
                    stack.append(parent)
        else:
            goal_atom = current.rule.body[current.inx].atom
            for rule in rules:
                if rule.head is None:
                    continue  # constraints have no head and cannot prove a goal
                child_env = {}
                if unify(goal_atom, current.env, rule.head.atom, child_env):
                    child = Goal(env=child_env, parent=current, rule=rule)
                    current.children.append(child)
                    stack.append(child)
    if not proof_trees:
        print("No")
    return proof_trees

In [12]:
@dataclass
class Program:
    """A list of rules that can be queried top-down (search) or solved bottom-up (clingo)."""
    rules: MutableSequence[Rule] = field(default_factory=list)

    def __str__(self):
        return ' '.join(map(str, self.rules))

    def facts(self) -> Iterator[Rule]:
        """Yield the rules that are facts (a head and an empty body); the Rules themselves are yielded."""
        for rule in self.rules:
            if rule.is_fact():
                yield rule

    def query(self, atom):
        """Return the binding environment of the root goal for every proof of ``atom``."""
        prooftrees = search(atom, self.rules)
        envs = [prooftree.env for prooftree in prooftrees]
        return envs

    def evaluate_backwards(self, atom) -> Sequence[Rule]:
        """Collect, as head-only Rules, the ground body literals used in the proofs of ``atom``.

        Every goal of every proof tree returned by search() is visited. Each literal of
        the goal's rule body is copied, instantiated with the goal's env, and kept once
        if it became ground.
        """
        prooftrees = search(atom, self.rules)
        answers = []
        for prooftree in prooftrees:
            stack = [prooftree]
            while stack:
                current = stack.pop()
                body = current.rule.body if current.rule is not None else ()
                for literal in body:
                    # Deep-copy first: Atom.fill rewrites arguments in place, and search() lets
                    # several goals reference the same Rule object.
                    fact = Rule(head=copy.deepcopy(literal))
                    fact.head.atom.fill(current.env)
                    if fact.is_ground() and fact not in answers:
                        answers.append(fact)
                stack.extend(current.children)
        return answers

    def evaluate_forwards(self) -> Iterator[Sequence[Atom]]:
        """Ground and solve the program with clingo, yielding a sorted tuple of shown Atoms per model.

        ``models = 0`` asks clingo to enumerate all models.
        """
        ctl = clingo.Control()
        ctl.configuration.solve.models = 0
        ctl.add('base', [], str(self))
        ctl.ground([('base', [])])
        with ctl.solve(yield_=True) as solve_handle:
            for model in solve_handle:
                symbols = sorted(model.symbols(shown=True))
                atoms = tuple(Atom.from_clingo_symbol(symbol) for symbol in symbols)
                yield atoms



In [13]:
# Variables and constants shared by the example rules.
X, Y, A = Variable('X'), Variable('Y'), Variable('A')
bill, frank, alice, alex = (Function(person) for person in ('bill', 'frank', 'alice', 'alex'))


def _literal(name, *args):
    """Build the positive literal ``name(args...)``."""
    return Literal(atom=Atom(Function(name, args)))


program = Program([
    # child(X,Y) holds when Y is the mother or the father of X
    Rule(head=_literal('child', X, Y), body=[_literal('mother', Y, X)]),
    Rule(head=_literal('child', X, Y), body=[_literal('father', Y, X)]),
    # son(X,Y): X is a child of Y and X is a boy
    Rule(head=_literal('son', X, Y), body=[_literal('child', X, Y), _literal('boy', X)]),
    # facts
    Rule(head=_literal('boy', bill)),
    Rule(head=_literal('boy', frank)),
    Rule(head=_literal('mother', alice, bill)),
    Rule(head=_literal('father', alex, bill)),
])
print(program)

search_atom = Atom(Function('son', (bill, A)))
print(f"{search_atom}?")

child(X,Y) :- mother(Y,X). child(X,Y) :- father(Y,X). son(X,Y) :- child(X,Y), boy(X). boy(bill). boy(frank). mother(alice,bill). father(alex,bill).
son(bill,A)?


In [14]:
answer_envs = program.query(search_atom)  # one binding environment per proof found
answer_envs

{Variable(name='A'): Function(name='alex', arguments=())}
{Variable(name='A'): Function(name='alice', arguments=())}


[{Variable(name='A'): Function(name='alex', arguments=())},
 {Variable(name='A'): Function(name='alice', arguments=())}]

In [15]:
# Solve the program bottom-up with clingo and print the first model.
models = [answer_set for answer_set in program.evaluate_forwards()]
print(' '.join(str(atom) for atom in models[0]))

boy(bill) boy(frank) child(bill,alex) child(bill,alice) father(alex,bill) mother(alice,bill) son(bill,alex) son(bill,alice)


In [16]:
models  # every model as a tuple of Atom objects, shown via their dataclass repr

[(Atom(symbol=Function(name='boy', arguments=(Function(name='bill', arguments=()),))),
  Atom(symbol=Function(name='boy', arguments=(Function(name='frank', arguments=()),))),
  Atom(symbol=Function(name='child', arguments=(Function(name='bill', arguments=()), Function(name='alex', arguments=())))),
  Atom(symbol=Function(name='child', arguments=(Function(name='bill', arguments=()), Function(name='alice', arguments=())))),
  Atom(symbol=Function(name='father', arguments=(Function(name='alex', arguments=()), Function(name='bill', arguments=())))),
  Atom(symbol=Function(name='mother', arguments=(Function(name='alice', arguments=()), Function(name='bill', arguments=())))),
  Atom(symbol=Function(name='son', arguments=(Function(name='bill', arguments=()), Function(name='alex', arguments=())))),
  Atom(symbol=Function(name='son', arguments=(Function(name='bill', arguments=()), Function(name='alice', arguments=())))))]