From 0ec4db36140382d4c3b580aa5278ed9f9cceb0d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Zahradn=C3=ADk?= Date: Fri, 9 Apr 2021 00:23:38 +0200 Subject: [PATCH] Prepare template building --- neuralogic/model/template.py | 78 +++++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/neuralogic/model/template.py b/neuralogic/model/template.py index 77fa3d58..9fce3ac9 100644 --- a/neuralogic/model/template.py +++ b/neuralogic/model/template.py @@ -1,13 +1,27 @@ +from neuralogic import get_neuralogic, get_gateway +from py4j.java_collections import ListConverter +from py4j.java_gateway import get_field + from typing import Union, List + +from neuralogic.builder import Weight, Sample, Model, Neuron from neuralogic.model.atom import BaseAtom, WeightedAtom from neuralogic.model.rule import Rule +from neuralogic.model.predicate import PredicateMetadata +from neuralogic.model.java_objects import get_java_factory, init_java_factory TemplateEntries = Union[BaseAtom, WeightedAtom, Rule] +def stream_to_list(stream) -> List: + return list(stream.collect(get_gateway().jvm.java.util.stream.Collectors.toList())) + + class Template: def __init__(self): + self.java_model = None + self.template: List[TemplateEntries] = [] self.examples: List[TemplateEntries] = [] self.queries: List[TemplateEntries] = [] @@ -30,8 +44,70 @@ def add_query(self, query: TemplateEntries): def add_queries(self, queries: List[TemplateEntries]): self.queries.extend(queries) + def build_examples(self, examples_builder): + java_factory = get_java_factory() + logic_samples = [] + namespace = get_neuralogic().cz.cvut.fel.ida.logic.constructs.example + + for query_counter, example in enumerate(self.examples): + lifted_example = java_factory.get_lifted_example(example) + query_atom = examples_builder.createQueryAtom(str(query_counter), None, lifted_example) + logic_samples.append(namespace.LogicSample(None, query_atom)) + return logic_samples + + def get_parsed_template(self): + predicate_metadata = [] + weighted_rules = [] + valued_facts = [] + + for rule in self.template: + if isinstance(rule, PredicateMetadata): + predicate_metadata.append(rule.java_object) + elif isinstance(rule, Rule): + weighted_rules.append(rule.java_object) + elif isinstance(rule, (WeightedAtom, BaseAtom)): + valued_facts.append(rule.java_object) + + weighted_rules = ListConverter().convert(weighted_rules, get_gateway()._gateway_client) + valued_facts = ListConverter().convert(valued_facts, get_gateway()._gateway_client) + predicate_metadata = ListConverter().convert(predicate_metadata, get_gateway()._gateway_client) + + template_namespace = get_neuralogic().cz.cvut.fel.ida.logic.constructs.template.types + return template_namespace.ParsedTemplate(weighted_rules, valued_facts) + def build(self): - pass + java_factory = init_java_factory(get_java_factory().settings) + namespace = get_neuralogic().cz.cvut.fel.ida.logic.constructs.building + examples_builder = namespace.ExamplesBuilder(java_factory.settings.settings) + + logic_samples = self.build_examples(examples_builder) + logic_samples = ListConverter().convert(logic_samples, get_gateway()._gateway_client).stream() + parsed_template = self.get_parsed_template() + + namespace = get_neuralogic().cz.cvut.fel.ida.pipelines.building + pipes_namespace = get_neuralogic().cz.cvut.fel.ida.pipelines.pipes.specific + + builder = namespace.End2endTrainigBuilder(java_factory.settings.settings, None) + nn_builder = builder.getEnd2endNNBuilder() + + pipeline = nn_builder.buildPipelineFromTemplate(parsed_template, logic_samples) + serializer_pipe = pipes_namespace.NeuralSerializerPipe() + + pipeline.connectAfter(serializer_pipe) + pipeline.execute(None) + result = serializer_pipe.get() + + serialized_weights = list(get_field(result, "r")) + weights: List = [None] * len(serialized_weights) + + for x in serialized_weights: + weight = Weight(x) + + if weight.index >= len(weights): + weights.extend([None] * (weight.index - len(weights) + 1)) + weights[weight.index] = weight + + sample = [Sample(x) for x in stream_to_list(get_field(result, "s"))] def __str__(self): return "\n".join(str(r) for r in self.template)