Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PERF] enable metadata preservation across materialization points #2216

Merged
merged 2 commits into from
May 2, 2024

Conversation

samster25
Copy link
Member

@samster25 samster25 commented May 2, 2024

  • When enabling AQE, we introduce intermediate materializations for a query with multiple shuffles.
  • The problem with this is that metadata is not preserved across materialization boundaries.
  • So if we are running a SortMergeJoin and we draw a boundary after the sort and before the join, the algorithm errors out because the boundaries value is not set on the MaterializedResult.
  • This happens because at the .collect() point, we place Micropartitions into the cache rather than the MaterializedResult which contains both the data and PartitionMetadata.

We already do this behavior for the ray runner, this PR formalizes it for all runners.

Copy link
Contributor

@jaychia jaychia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall

@@ -314,7 +316,10 @@ def _from_tables(cls, *parts: MicroPartition) -> "DataFrame":
if not parts:
raise ValueError("Can't create a DataFrame from an empty list of tables.")

result_pset = LocalPartitionSet({i: part for i, part in enumerate(parts)})
result_pset = LocalPartitionSet()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a LocalPartitionSet.from_tables()?

@@ -67,10 +75,10 @@ def has_partition(self, idx: PartID) -> bool:
return idx in self._partitions

def __len__(self) -> int:
return sum(len(partition) for partition in self._partitions.values())
return sum(len(partition.partition()) for partition in self._partitions.values())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. After this PR we actually have metadata, and don't necessarily need to reach for the partition to get the length...

Would it not be possible/safe to let MaterializedResult.__len__ delegate appropriately between the metadata and the partition to get the length of the partition?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it doesn't really matter given that this is a local MicroPartition though

def build_partitions(
instruction_stack: list[Instruction],
partitions: list[MicroPartition],
final_metadata: list[PartialPartitionMetadata],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we could enforce same length using partitions: list[tuple[MicroPartition, PartialPartitionMetdata]]

@@ -266,9 +269,12 @@ def partition_set_from_ray_dataset(
daft_vpartitions = [
_make_daft_partition_from_ray_dataset_blocks.remote(block, daft_schema) for block in block_refs
]
pset = RayPartitionSet()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RayPartitionSet.from_ray_materialized_results might be nice

@@ -536,8 +547,25 @@ def place_in_queue(item):
elif len(next_step.instructions) == 0:
logger.debug("Running task synchronously in main thread: %s", next_step)
assert isinstance(next_step, SingleOutputPartitionTask)
[single_partial] = next_step.partial_metadatas
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems new - why was this necessary when before we didn't need to ensure that num_rows is available?

@samster25 samster25 merged commit 24d0831 into main May 2, 2024
29 checks passed
@samster25 samster25 deleted the sammy/enable-metadata-across-materialization-points branch May 2, 2024 19:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants