Skip to content

Commit

Permalink
Extend alldiff special predicate with special term
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Jul 18, 2021
1 parent 0860b5d commit 0d7e2e4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
11 changes: 9 additions & 2 deletions neuralogic/core/constructs/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def __init__(self, predicate: Predicate, terms=None, negated=False):
elif not isinstance(self.terms, Iterable):
self.terms = [self.terms]

if self.predicate.special and self.predicate.name == "alldiff":
for term in self.terms:
if term is Ellipsis:
self.java_object = None
return
self.java_object = get_java_factory().get_valued_fact(self, get_java_factory().get_variable_factory())

def __neg__(self) -> "BaseAtom":
Expand Down Expand Up @@ -53,6 +58,8 @@ def __call__(self, *args) -> "BaseAtom":
return BaseAtom(predicate, terms, self.negated)

def __getitem__(self, item) -> "WeightedAtom":
if self.java_object is None:
raise NotImplementedError
return WeightedAtom(self, item)

def __le__(self, other: Body) -> rule.Rule:
Expand All @@ -78,7 +85,7 @@ def __copy__(self):
atom.java_object = self.java_object


class WeightedAtom: #todo gusta: mozna dedeni namisto kompozice?
class WeightedAtom: # todo gusta: mozna dedeni namisto kompozice?
def __init__(self, atom: BaseAtom, weight, fixed=False):
self.atom = atom
self.weight = weight
Expand All @@ -105,7 +112,7 @@ def predicate(self):
return self.atom.predicate

@property
def terms(self): #todo gusta: ...tim bys usetril toto volani atp.
def terms(self): # todo gusta: ...tim bys usetril toto volani atp.
return self.atom.terms

def __invert__(self) -> "WeightedAtom":
Expand Down
39 changes: 39 additions & 0 deletions neuralogic/core/constructs/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

class Rule:
def __init__(self, head, body):
from neuralogic.core import Atom

self.head = head

if not isinstance(body, Iterable):
Expand All @@ -13,8 +15,45 @@ def __init__(self, head, body):
self.body = list(body)
self.metadata: Optional[Metadata] = None

if self.is_ellipsis_templated():
variable_set = {term for term in head.terms if term is not Ellipsis and str(term)[0].isupper()}

for body_atom in self.body:
if body_atom.predicate.special and body_atom.predicate.name == "alldiff":
continue

for term in body_atom.terms:
if term is not Ellipsis and str(term)[0].isupper():
variable_set.add(term)

for atom_index, body_atom in enumerate(self.body):
if not body_atom.predicate.special or body_atom.predicate.name != "alldiff":
continue

new_terms = []
found_replacement = False

for index, term in enumerate(body_atom.terms):
if term is Ellipsis:
if found_replacement:
raise NotImplementedError
found_replacement = True
new_terms.extend(variable_set)
else:
new_terms.append(term)
if found_replacement:
self.body[atom_index] = Atom.special.alldiff(*new_terms)
self.java_object = get_java_factory().get_rule(self)

def is_ellipsis_templated(self) -> bool:
for body_atom in self.body:
if not body_atom.predicate.special or body_atom.predicate.name != "alldiff":
continue
for term in body_atom.terms:
if term is Ellipsis:
return True
return False

def __str__(self):
metadata = "" if self.metadata is None is None else f" {self.metadata}"
return f"{self.head.to_str()} :- {', '.join(atom.to_str() for atom in self.body)}.{metadata}"
Expand Down

0 comments on commit 0d7e2e4

Please sign in to comment.