Skip to content

Commit

Permalink
Add conjunctions and lifted examples to java object factory
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Apr 8, 2021
1 parent d16a510 commit 8e8d0d6
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion neuralogic/model/java_objects.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from py4j.java_gateway import get_field, set_field
from typing import Optional, List
from typing import Optional, List, Iterable
from contextlib import contextmanager
from py4j.java_collections import ListConverter

Expand All @@ -11,6 +11,7 @@
class JavaFactory:
def __init__(self, settings: Settings):
namespace = get_neuralogic().cz.cvut.fel.ida.logic.constructs.building
self.settings = settings

self.namespace = get_neuralogic().cz.cvut.fel.ida.logic.constructs.template.components
self.value_namespace = get_neuralogic().cz.cvut.fel.ida.algebra.values
Expand Down Expand Up @@ -81,6 +82,32 @@ def get_metadata(self, metadata):
namespace = get_neuralogic().cz.cvut.fel.ida.logic.constructs.template.metadata
return namespace.RuleMetadata(get_field(self.builder, "settings"), map)

def get_lifted_example(self, example):
variable_factory = self.get_variable_factory()

if not isinstance(example, Iterable):
example = [example]

conjunctions = []
rules = []

for entry in example:
if isinstance(entry, Iterable):
conjunctions.append(self.get_conjunction(list(entry), variable_factory))
else:
rules.append(self.get_rule(entry))

return self.example_namespace.LiftedExample(
ListConverter().convert(conjunctions, get_gateway()._gateway_client),
ListConverter().convert(rules, get_gateway()._gateway_client),
)

def get_conjunction(self, atoms, variable_factory):
namespace = get_neuralogic().cz.cvut.fel.ida.logic.constructs
valued_facts = [self.get_valued_fact(atom, variable_factory) for atom in atoms]

return namespace.Conjunction(ListConverter().convert(valued_facts, get_gateway()._gateway_client))

def get_predicate_metadata_pair(self, predicate_metadata):
namespace = get_neuralogic().cz.cvut.fel.ida.utils.generic
return namespace.Pair(predicate_metadata.predicate.java_object, self.get_metadata(predicate_metadata.metadata))
Expand Down

0 comments on commit 8e8d0d6

Please sign in to comment.