diff --git a/dynet_test.py b/dynet_test.py deleted file mode 100644 index e024bd41..00000000 --- a/dynet_test.py +++ /dev/null @@ -1,51 +0,0 @@ -from dotenv import load_dotenv -import dynet as dy -from neuralogic import data -import neuralogic.dynet as nldy - -load_dotenv() - - -def train(): - model = data.Mutagenesis - - deserializer = nldy.NeuraLogicLayer(model.weights) - - epochs = 400 - trainer = dy.AdamTrainer(deserializer.model, alpha=0.001) - printouts = 10 - seen_instances = 0 - total_loss = 0 - - for iter in range(epochs): - if iter > 0 and iter % printouts == 0: - print(iter, " average loss is:", total_loss / seen_instances) - - seen_instances = 0 - total_loss = 0 - - dy.renew_cg(immediate_compute=False, check_validity=False) - - losses = [] - - for sample in model.samples: - label = dy.scalarInput(sample.target) - graph_output = deserializer.build_sample(sample) - loss = dy.squared_distance(graph_output, label) - losses.append(loss) - - loss = dy.esum(losses) - total_loss += loss.value() - loss.backward() - trainer.update() - seen_instances += 1 - - for sample in model.samples: - dy.renew_cg(immediate_compute=False, check_validity=False) - - graph_output = deserializer.build_sample(sample) - label = dy.scalarInput(sample.target) - print(f"label: {label.value()}, output: {graph_output.value()}") - - -train() diff --git a/neuralogic/model/java_objects.py b/neuralogic/model/java_objects.py index 61956f5c..29736ef3 100644 --- a/neuralogic/model/java_objects.py +++ b/neuralogic/model/java_objects.py @@ -46,9 +46,16 @@ def get_term(self, term, variable_factory): return self.constant_factory.construct(str(term)) raise NotImplementedError - def get_generic_atom(self, atom_class, atom, variable_factory): + def get_generic_atom(self, atom_class, atom, variable_factory, new_weight): predicate = self.get_predicate(atom.predicate) - weight = self.get_weight(atom.weight, atom.is_fixed) if isinstance(atom, factories.atom.WeightedAtom) else None + + weight = None + if isinstance(atom, factories.atom.WeightedAtom): + if new_weight: + weight = self.get_weight(atom.weight, atom.is_fixed) + else: + weight = get_field(atom.java_object, "weight") + term_list = ListConverter().convert( [self.get_term(term, variable_factory) for term in atom.terms], get_gateway()._gateway_client ) @@ -126,10 +133,10 @@ def get_predicate_metadata_pair(self, predicate_metadata): return namespace.Pair(predicate_metadata.predicate.java_object, self.get_metadata(predicate_metadata.metadata)) def get_valued_fact(self, atom, variable_factory): - return self.get_generic_atom(self.example_namespace.ValuedFact, atom, variable_factory) + return self.get_generic_atom(self.example_namespace.ValuedFact, atom, variable_factory, True) def get_atom(self, atom, variable_factory): - return self.get_generic_atom(self.namespace.BodyAtom, atom, variable_factory) + return self.get_generic_atom(self.namespace.BodyAtom, atom, variable_factory, False) def get_rule(self, rule): java_rule = self.namespace.WeightedRule()