In [27]:
import copy
from dataclasses import dataclass, field
from enum import IntEnum
from typing import TypeVar, Union, Sequence, Optional, Dict, Tuple, Mapping, MutableMapping, MutableSequence, List

import clingo
import clingo.ast


In [28]:
@dataclass(frozen=True, order=True)
class Variable:
    """A logic variable, identified solely by its name.

    The dataclass is frozen and ordered, so instances are immutable, hashable
    (usable as dictionary keys) and comparable by ``name``.
    """
    name: str

    def __str__(self) -> str:
        # A variable is rendered as its bare name, e.g. ``X``.
        return self.name


@dataclass(frozen=True, order=True)
class IntegerConstant:
    """Immutable, ordered wrapper around one integer value (``0`` by default)."""
    number: int = 0

    def __str__(self) -> str:
        # ``!s`` applies str() to the wrapped number, e.g. ``42`` -> "42".
        return f"{self.number!s}"


@dataclass(frozen=True, order=True)
class Term:
    """Immutable, ordered term that wraps a single ``IntegerConstant``."""
    # The default instance is frozen (and therefore hashable), so it can be
    # used directly as a dataclass default.
    constant: IntegerConstant = IntegerConstant()

    def __str__(self) -> str:
        # A term prints exactly like the constant it wraps.
        return f"{self.constant!s}"

In [29]:
# TypeVars bound to the 'Atom' and 'Function' classes, which are defined further
# down; the annotations in this notebook use them wherever one of those classes
# has to be named before (or inside) its own definition.
ForwardAtom = TypeVar('ForwardAtom', bound='Atom')
ForwardFunction = TypeVar('ForwardFunction', bound='Function')

# A Union with a single member evaluates to that member, so Symbol is ForwardFunction.
Symbol = Union[ForwardFunction]
# Type of the elements stored in Function.arguments.
SubSymbol = Union[Symbol, Variable, Term]

@dataclass
class Function:
    """A (possibly nameless) function symbol applied to a sequence of arguments.

    ``Function('p', (a, b))`` prints as ``p(a,b)``, a function without
    arguments prints as its bare name, a nameless function with arguments
    prints as a tuple ``(a,b)`` and the completely empty function as ``()``.
    """
    name: Optional[str] = None
    # The annotation is quoted so that it is resolved lazily and the class does
    # not depend on helper aliases at definition time.
    arguments: Sequence["SubSymbol"] = ()

    @property
    def arity(self) -> int:
        """Number of arguments of this function."""
        return len(self.arguments)

    def __str__(self) -> str:
        if not self.arguments:
            return "()" if self.name is None else self.name
        rendered = ','.join(map(str, self.arguments))
        if self.name is None:
            return "({})".format(rendered)
        return "{}({})".format(self.name, rendered)

    def match(self, name: Optional[str], arity: int = 0) -> bool:
        """Return True if this function has exactly the given name and arity."""
        return name == self.name and arity == len(self.arguments)

    def match_signature(self, other: "Function") -> bool:
        """Return True if ``other`` has the same name and arity as this function."""
        return self.match(other.name, other.arity)

    def recursive_rename(self, new_name: Optional[str]) -> "Function":
        """Rename this function and every nested function with the same signature.

        The function is modified in place. Nested ``Function`` arguments whose
        name and arity equal the *original* name and arity of this function
        are renamed as well.

        :param new_name: the name to assign.
        :return: ``self``, to allow call chaining.
        """
        # Capture the signature up front: ``self`` is the first node renamed,
        # so reading ``self.name`` inside the loop would compare every nested
        # function against the new name and never rename it.
        old_name, old_arity = self.name, self.arity
        stack: List[Function] = [self]
        while stack:
            current = stack.pop()
            if current.match(old_name, old_arity):
                current.name = new_name
            stack.extend(arg for arg in current.arguments if isinstance(arg, Function))
        return self


In [30]:
@dataclass
class Atom:
    """An atom of the logic program; it wraps a single symbol (a ``Function``)."""
    symbol: Symbol = field(default_factory=Function)

    def __str__(self):
        return str(self.symbol)

    def is_function(self):
        """Return True if the wrapped symbol is a ``Function``."""
        return isinstance(self.symbol, Function)

    def fill(self, env: Mapping[Variable, ForwardAtom]) -> ForwardAtom:
        """Replace, in place, every variable bound in ``env`` by its value.

        The whole symbol is traversed, including nested functions. Variables
        that are not bound in ``env`` and all non-variable arguments are left
        untouched. The arity of every function is preserved.

        :param env: mapping from variables to their replacement values.
        :return: ``self``, to allow call chaining.
        """
        if not self.is_function():
            # Nothing to substitute into.
            return self
        # The work list holds Function objects only; unbound variables and
        # other leaves are never pushed, since they have no arguments to visit.
        stack: List[Function] = [self.symbol]
        while stack:
            current = stack.pop()
            original = current.arguments
            filled = list(original)
            changed = False
            for i, arg in enumerate(filled):
                if isinstance(arg, Variable):
                    if arg in env:
                        filled[i] = env[arg]
                        changed = True
                elif isinstance(arg, Function):
                    stack.append(arg)
                elif isinstance(arg, Atom) and arg.is_function():
                    stack.append(arg.symbol)
            if changed:
                # Keep tuples as tuples: dataclass equality compares the
                # containers, and ``(a,) != [a]``.
                current.arguments = tuple(filled) if isinstance(original, tuple) else filled
        return self



In [31]:
class Sign(IntEnum):
    """Sign of a literal."""
    # Plain literal; Literal.__str__ renders it as the bare atom.
    NoSign = 0
    # Literal.__str__ renders this sign with a leading "not ".
    DefaultNeg = 1

In [32]:
@dataclass
class Literal:
    """An atom together with a sign (plain or default-negated)."""
    atom: Atom = field(default_factory=Atom)
    sign: Sign = Sign.NoSign

    def __str__(self):
        # Compare by value: Sign is an IntEnum, so a sign given as the plain
        # integer 1 must be treated like Sign.DefaultNeg as well.
        if self.sign == Sign.DefaultNeg:
            return "not {}".format(self.atom)
        return str(self.atom)

    def __abs__(self):
        """Return a new, unsigned literal holding a deep copy of the atom."""
        return Literal(sign=Sign.NoSign, atom=copy.deepcopy(self.atom))

In [33]:
ForwardRule = TypeVar('ForwardRule', bound='Rule')


@dataclass
class Rule:
    """A rule ``head :- body.`` of the logic program.

    A rule without body is a fact, a rule without head is a constraint.
    All modifying methods work in place and return ``self``.
    """
    head: Optional[Literal] = None
    # NOTE: the default is an (immutable) empty tuple.
    body: MutableSequence[Literal] = ()

    def __str__(self) -> str:
        if self.head is None and not self.body:
            return ":-."
        elif self.head is None:
            return ":- {}.".format(', '.join(map(str, self.body)))
        elif not self.body:
            return "{}.".format(self.head)
        else:
            return "{} :- {}.".format(self.head, ', '.join(map(str, self.body)))

    def is_ground(self) -> bool:
        """Return True if neither the head nor the body contains a variable."""
        stack = [literal.atom.symbol for literal in self.body]
        if self.head is not None:
            stack.append(self.head.atom.symbol)
        while stack:
            term = stack.pop()
            if isinstance(term, Variable):
                return False
            if isinstance(term, Atom):
                stack.append(term.symbol)
            elif isinstance(term, Function):
                stack.extend(term.arguments)
        return True

    def substitute(self, env: Mapping[Variable, Atom]) -> ForwardRule:
        """Replace, in place, the variables bound in ``env`` in head and body.

        :param env: mapping from variables to their replacement values.
        :return: ``self``.
        """
        # The head gets the same is_function() guard as the body literals.
        if self.head is not None and self.head.atom.is_function():
            self.head.atom.fill(env)
        for literal in self.body:
            if literal.atom.is_function():
                literal.atom.fill(env)
        return self

    def rename_atoms(self, new_name: Optional[str], name: Optional[str], arity: int = 0, head=True,
                     body=True) -> ForwardRule:
        """Rename, in place, the atoms matching ``name``/``arity`` to ``new_name``.

        :param new_name: the name to assign.
        :param name: name of the atoms to rename.
        :param arity: arity of the atoms to rename.
        :param head: whether the head literal is considered.
        :param body: whether the body literals are considered.
        :return: ``self``.
        """
        candidates = []
        if head and self.head is not None:
            candidates.append(self.head)
        if body:
            candidates.extend(self.body)
        for literal in candidates:
            if literal.atom.is_function() and literal.atom.symbol.match(name, arity):
                literal.atom.symbol.recursive_rename(new_name)
        return self

    def postulate(self) -> ForwardRule:
        """Drop the body, turning the rule into a fact (or an empty constraint).

        :return: ``self``.
        """
        if hasattr(self.body, "clear"):
            self.body.clear()
        else:
            # The default body is a tuple, which cannot be cleared in place.
            self.body = ()
        return self

In [34]:
# TypeVar bound to 'Goal'; used below to annotate the parent/children links.
ForwardGoal = TypeVar('ForwardGoal', bound='Goal')


@dataclass
class Goal:
    """A node of the search tree built by ``search``."""
    # The rule whose body literals are resolved one after the other.
    rule: Optional[Rule] = None
    # The goal this one was created for; search() treats a goal without parent
    # as the root.
    parent: Optional[ForwardGoal] = field(default=None, repr=False)
    # Sub-goals; search() appends one child per rule whose head unifies.
    children: Sequence[ForwardGoal] = field(default_factory=list, repr=False)
    # Variable bindings collected for this goal so far.
    env: Dict[Variable, Term] = field(default_factory=dict)
    # Index into rule.body of the next literal to resolve; the goal is
    # finished once inx >= len(rule.body).
    inx: int = field(default=0, repr=False)


# search() returns finished root goals under this name.
ProofTree = Goal

In [35]:
def unify(src_atom: Atom, src_env: Mapping[Variable, Atom], dest_atom: Atom,
          dest_env: Optional[MutableMapping[Variable, Atom]] = None) -> bool:
    """Try to unify ``src_atom`` (read under ``src_env``) with ``dest_atom``.

    Both atoms must wrap functions with the same name and arity. Arguments are
    compared pairwise: a source variable is looked up in ``src_env`` and is
    skipped when it has no binding; a destination variable is bound to the
    source value in ``dest_env`` or, if already bound, must equal it; any other
    destination argument must equal the source value.

    :param src_atom: the atom whose values are propagated.
    :param src_env: bindings of the variables of ``src_atom``.
    :param dest_atom: the atom to unify with.
    :param dest_env: bindings of ``dest_atom``; extended in place. It may
        already have been extended when the unification fails.
    :return: True if the atoms unify, False otherwise.
    """
    source, target = src_atom.symbol, dest_atom.symbol
    if not (isinstance(source, Function) and isinstance(target, Function)):
        return False
    if not source.match_signature(target):
        return False

    bindings = {} if dest_env is None else dest_env
    # Equal arity is guaranteed by the signature check above.
    for src_arg, dest_arg in zip(source.arguments, target.arguments):
        value = src_env.get(src_arg) if isinstance(src_arg, Variable) else src_arg
        if value is None:
            # Unbound source variable: nothing to propagate.
            continue
        if not isinstance(dest_arg, Variable):
            if dest_arg != value:
                return False
            continue
        bound = bindings.get(dest_arg)
        if bound is None:
            bindings[dest_arg] = value
        elif bound != value:
            return False
    return True

In [36]:
def search(atom: Atom, rules: Sequence[Rule] = ()) -> Sequence[ProofTree]:
    """Search depth-first for proofs of ``atom`` using ``rules``.

    Every solution found is printed (its bindings, or "Yes" if there are
    none); "No" is printed when there is no solution at all.

    :param atom: the queried atom.
    :param rules: the rules of the program.
    :return: the finished root goals, one per solution; their ``env`` holds
        the bindings of the query variables.
    """
    root = Goal(rule=Rule(head=Literal(), body=[Literal(atom=atom)]))
    proof_trees = []
    stack = [root]
    while stack:
        current = stack.pop()
        if current.inx >= len(current.rule.body):
            # Every body literal of this goal has been resolved.
            if current.parent is None:
                if current.env:
                    print(current.env)
                else:
                    print("Yes")
                proof_trees.append(current)
                continue
            # Work on a copy of the parent so that alternative branches that
            # share the same parent do not see each other's bindings.
            parent = copy.deepcopy(current.parent)
            pending = parent.rule.body[parent.inx].atom
            # Propagating the result back can fail (e.g. p(X,X) against the
            # fact p(a,b)); such a branch is not a proof and is dropped.
            if unify(current.rule.head.atom, current.env, pending, parent.env):
                parent.inx += 1
                stack.append(parent)
            continue
        selected = current.rule.body[current.inx].atom
        for rule in rules:
            if rule.head is None:
                # A rule without head cannot derive anything.
                continue
            child_env = {}
            if unify(selected, current.env, rule.head.atom, child_env):
                child = Goal(env=child_env, parent=current, rule=rule)
                current.children.append(child)
                stack.append(child)
    if not proof_trees:
        print("No")
    return proof_trees

In [37]:
@dataclass
class Program:
    """A logic program: an ordered collection of rules."""
    rules: MutableSequence[Rule] = field(default_factory=list)

    def __str__(self):
        return ' '.join(map(str, self.rules))

    def query(self, atom):
        """Return one variable environment per solution found for ``atom``."""
        return [prooftree.env for prooftree in search(atom, self.rules)]

    def evaluate_backwards(self, atom) -> Sequence[Rule]:
        """Collect the ground facts found in the proof trees of ``atom``.

        Every body literal of every goal in a proof tree is instantiated with
        the environment of its goal; the ground results are returned as facts
        (rules without body), without duplicates.
        """
        answers = []
        for prooftree in search(atom, self.rules):
            stack = [prooftree]
            while stack:
                current = stack.pop()
                stack.extend(current.children)
                if current.rule is None:
                    continue
                # The literals live on the goal's rule, not on the goal itself.
                for literal in current.rule.body:
                    # substitute() works in place, so instantiate a copy and
                    # leave the literal stored in the proof tree untouched.
                    fact = Rule(head=copy.deepcopy(literal))
                    fact.substitute(current.env)
                    if fact.is_ground() and fact not in answers:
                        answers.append(fact)
        return answers




In [38]:
# Variables used in the rules (X, Y) and in the query (A).
X = Variable('X')
Y = Variable('Y')
A = Variable('A')
# Constants are functions without arguments.
bill = Function('bill')
frank = Function('frank')
alice = Function('alice')
alex = Function('alex')

program = Program([
    # child(X,Y) holds if Y is the mother or the father of X.
    Rule(head=Literal(atom= Atom(Function('child', (X, Y)))),
         body=[Literal(atom=Atom(Function('mother', (Y, X))))]),
    Rule(head=Literal(atom= Atom(Function('child', (X, Y)))),
         body=[Literal(atom=Atom(Function('father', (Y, X))))]),

    # son(X,Y) holds if X is a child of Y and X is a boy.
    Rule(head=Literal(atom= Atom(Function('son', (X, Y)))),
         body=[Literal(atom=Atom(Function('child', (X, Y)))), Literal(atom=Atom(Function('boy', (X,))))]),

    # Facts.
    Rule(head=Literal(atom= Atom(Function('boy', (bill,))))),
    Rule(head=Literal(atom= Atom(Function('boy', (frank,))))),
    Rule(head=Literal(atom= Atom(Function('mother', (alice, bill))))),
    Rule(head=Literal(atom= Atom(Function('father', (alex, bill)))))
])
print(program)

# The query used in the next cell: son(bill,A).
search_atom = Atom(Function('son', (bill, A)))
print(f"{search_atom}?")

child(X,Y) :- mother(Y,X). child(X,Y) :- father(Y,X). son(X,Y) :- child(X,Y), boy(X). boy(bill). boy(frank). mother(alice,bill). father(alex,bill).
son(bill,A)?


In [39]:
# search() prints every solution while it runs; the returned list of
# environments is displayed as the value of the cell.
program.query(search_atom)

{Variable(name='A'): Function(name='alex', arguments=())}
{Variable(name='A'): Function(name='alice', arguments=())}


[{Variable(name='A'): Function(name='alex', arguments=())},
 {Variable(name='A'): Function(name='alice', arguments=())}]