Skip to content

Commit

Permalink
Show correct length of dataset when running neuralization
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Apr 28, 2024
1 parent a6f24ba commit 68322fc
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
20 changes: 14 additions & 6 deletions neuralogic/core/builder/components.py
Expand Up @@ -234,21 +234,29 @@ def draw_grounding(
class GroundedDataset:
"""GroundedDataset represents grounded examples that are not neuralized yet."""

__slots__ = "length", "_groundings", "_groundings_list", "_builder"
__slots__ = "_groundings", "_groundings_list", "_builder"

def __init__(self, groundings, length, builder):
self.length = length
def __init__(self, groundings, builder):
self._groundings = groundings
self._groundings_list = None
self._builder = builder

def __getitem__(self, item):
def _to_list(self):
if self._groundings_list is None:
self._groundings = self._groundings.collect(jpype.JClass("java.util.stream.Collectors").toList())
self._groundings_list = [Grounding(g) for g in self._groundings]

def __getitem__(self, item):
self._to_list()
return self._groundings_list[item]

def __len__(self):
self._to_list()
return len(self._groundings_list)

def neuralize(self, progress: bool):
if self._groundings_list is not None:
return self._builder.neuralize(self._groundings.stream(), progress, self.length)
return self._builder.neuralize(self._groundings, progress, self.length)
return self._builder.neuralize(self._groundings.stream(), progress, len(self))
if progress:
return self._builder.neuralize(self._groundings, progress, len(self))
return self._builder.neuralize(self._groundings, progress, 0)
4 changes: 1 addition & 3 deletions neuralogic/core/builder/dataset_builder.py
Expand Up @@ -134,7 +134,6 @@ def ground_dataset(
settings.settings.parallelTraining = True

builder = Builder(settings)
length = None

if isinstance(dataset, datasets.Dataset):
self.examples_counter = 0
Expand Down Expand Up @@ -171,7 +170,6 @@ def ground_dataset(
queries, examples, one_query_per_example, example_queries
)

length = len(logic_samples)
groundings = builder.ground_from_logic_samples(self.parsed_template, logic_samples)

self.java_factory.weight_factory = weight_factory
Expand All @@ -187,7 +185,7 @@ def ground_dataset(
else:
raise NotImplementedError

return GroundedDataset(groundings, length, builder)
return GroundedDataset(groundings, builder)

def build_dataset(
self,
Expand Down

0 comments on commit 68322fc

Please sign in to comment.