diff --git a/neuralogic/core/constructs/atom.py b/neuralogic/core/constructs/atom.py index 49b46e52..ec01a53e 100644 --- a/neuralogic/core/constructs/atom.py +++ b/neuralogic/core/constructs/atom.py @@ -25,6 +25,11 @@ def __init__(self, predicate: Predicate, terms=None, negated=False): elif not isinstance(self.terms, Iterable): self.terms = [self.terms] + if self.predicate.special and self.predicate.name == "alldiff": + for term in self.terms: + if term is Ellipsis: + self.java_object = None + return self.java_object = get_java_factory().get_valued_fact(self, get_java_factory().get_variable_factory()) def __neg__(self) -> "BaseAtom": @@ -53,6 +58,8 @@ def __call__(self, *args) -> "BaseAtom": return BaseAtom(predicate, terms, self.negated) def __getitem__(self, item) -> "WeightedAtom": + if self.java_object is None: + raise NotImplementedError return WeightedAtom(self, item) def __le__(self, other: Body) -> rule.Rule: @@ -78,7 +85,7 @@ def __copy__(self): atom.java_object = self.java_object -class WeightedAtom: #todo gusta: mozna dedeni namisto kompozice? +class WeightedAtom: # todo gusta: mozna dedeni namisto kompozice? def __init__(self, atom: BaseAtom, weight, fixed=False): self.atom = atom self.weight = weight @@ -105,7 +112,7 @@ def predicate(self): return self.atom.predicate @property - def terms(self): #todo gusta: ...tim bys usetril toto volani atp. + def terms(self): # todo gusta: ...tim bys usetril toto volani atp. return self.atom.terms def __invert__(self) -> "WeightedAtom": diff --git a/neuralogic/core/constructs/rule.py b/neuralogic/core/constructs/rule.py index 844c4096..5eb23d9d 100644 --- a/neuralogic/core/constructs/rule.py +++ b/neuralogic/core/constructs/rule.py @@ -5,6 +5,8 @@ class Rule: def __init__(self, head, body): + from neuralogic.core import Atom + self.head = head if not isinstance(body, Iterable): @@ -13,8 +15,45 @@ def __init__(self, head, body): self.body = list(body) self.metadata: Optional[Metadata] = None + if self.is_ellipsis_templated(): + variable_set = {term for term in head.terms if term is not Ellipsis and str(term)[0].isupper()} + + for body_atom in self.body: + if body_atom.predicate.special and body_atom.predicate.name == "alldiff": + continue + + for term in body_atom.terms: + if term is not Ellipsis and str(term)[0].isupper(): + variable_set.add(term) + + for atom_index, body_atom in enumerate(self.body): + if not body_atom.predicate.special or body_atom.predicate.name != "alldiff": + continue + + new_terms = [] + found_replacement = False + + for index, term in enumerate(body_atom.terms): + if term is Ellipsis: + if found_replacement: + raise NotImplementedError + found_replacement = True + new_terms.extend(variable_set) + else: + new_terms.append(term) + if found_replacement: + self.body[atom_index] = Atom.special.alldiff(*new_terms) self.java_object = get_java_factory().get_rule(self) + def is_ellipsis_templated(self) -> bool: + for body_atom in self.body: + if not body_atom.predicate.special or body_atom.predicate.name != "alldiff": + continue + for term in body_atom.terms: + if term is Ellipsis: + return True + return False + def __str__(self): metadata = "" if self.metadata is None is None else f" {self.metadata}" return f"{self.head.to_str()} :- {', '.join(atom.to_str() for atom in self.body)}.{metadata}"