Skip to content

Commit

Permalink
Reuse weights in rule atoms
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Apr 14, 2021
1 parent 719bf10 commit af54ef8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 55 deletions.
51 changes: 0 additions & 51 deletions dynet_test.py

This file was deleted.

15 changes: 11 additions & 4 deletions neuralogic/model/java_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,16 @@ def get_term(self, term, variable_factory):
return self.constant_factory.construct(str(term))
raise NotImplementedError

def get_generic_atom(self, atom_class, atom, variable_factory):
def get_generic_atom(self, atom_class, atom, variable_factory, new_weight):
predicate = self.get_predicate(atom.predicate)
weight = self.get_weight(atom.weight, atom.is_fixed) if isinstance(atom, factories.atom.WeightedAtom) else None

weight = None
if isinstance(atom, factories.atom.WeightedAtom):
if new_weight:
weight = self.get_weight(atom.weight, atom.is_fixed)
else:
weight = get_field(atom.java_object, "weight")

term_list = ListConverter().convert(
[self.get_term(term, variable_factory) for term in atom.terms], get_gateway()._gateway_client
)
Expand Down Expand Up @@ -126,10 +133,10 @@ def get_predicate_metadata_pair(self, predicate_metadata):
return namespace.Pair(predicate_metadata.predicate.java_object, self.get_metadata(predicate_metadata.metadata))

def get_valued_fact(self, atom, variable_factory):
return self.get_generic_atom(self.example_namespace.ValuedFact, atom, variable_factory)
return self.get_generic_atom(self.example_namespace.ValuedFact, atom, variable_factory, True)

def get_atom(self, atom, variable_factory):
return self.get_generic_atom(self.namespace.BodyAtom, atom, variable_factory)
return self.get_generic_atom(self.namespace.BodyAtom, atom, variable_factory, False)

def get_rule(self, rule):
java_rule = self.namespace.WeightedRule()
Expand Down

0 comments on commit af54ef8

Please sign in to comment.