-
Notifications
You must be signed in to change notification settings - Fork 147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PERF] enable metadata preservation across materialization points #2216
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
from daft.runners import runner_io | ||
from daft.runners.partitioning import ( | ||
MaterializedResult, | ||
PartialPartitionMetadata, | ||
PartID, | ||
PartitionCacheEntry, | ||
PartitionMetadata, | ||
|
@@ -28,23 +29,27 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class LocalPartitionSet(PartitionSet[MicroPartition]): | ||
_partitions: dict[PartID, MicroPartition] | ||
_partitions: dict[PartID, MaterializedResult[MicroPartition]] | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self._partitions = {} | ||
|
||
def items(self) -> list[tuple[PartID, MicroPartition]]: | ||
def items(self) -> list[tuple[PartID, MaterializedResult[MicroPartition]]]: | ||
return sorted(self._partitions.items()) | ||
|
||
def _get_merged_vpartition(self) -> MicroPartition: | ||
ids_and_partitions = self.items() | ||
assert ids_and_partitions[0][0] == 0 | ||
assert ids_and_partitions[-1][0] + 1 == len(ids_and_partitions) | ||
return MicroPartition.concat([part for id, part in ids_and_partitions]) | ||
return MicroPartition.concat([part.partition() for id, part in ids_and_partitions]) | ||
|
||
def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]: | ||
ids_and_partitions = self.items() | ||
preview_parts = [] | ||
for _, part in ids_and_partitions: | ||
for _, mat_result in ids_and_partitions: | ||
part: MicroPartition = mat_result.partition() | ||
part_len = len(part) | ||
if part_len >= num_rows: # if this part has enough rows, take what we need and break | ||
preview_parts.append(part.slice(0, num_rows)) | ||
|
@@ -54,11 +59,14 @@ def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]: | |
preview_parts.append(part) | ||
return preview_parts | ||
|
||
def get_partition(self, idx: PartID) -> MicroPartition: | ||
def get_partition(self, idx: PartID) -> MaterializedResult[MicroPartition]: | ||
return self._partitions[idx] | ||
|
||
def set_partition(self, idx: PartID, part: MaterializedResult[MicroPartition]) -> None: | ||
self._partitions[idx] = part.partition() | ||
self._partitions[idx] = part | ||
|
||
def set_partition_from_table(self, idx: PartID, part: MicroPartition) -> None: | ||
self._partitions[idx] = PyMaterializedResult(part, PartitionMetadata.from_table(part)) | ||
|
||
def delete_partition(self, idx: PartID) -> None: | ||
del self._partitions[idx] | ||
|
@@ -67,10 +75,10 @@ def has_partition(self, idx: PartID) -> bool: | |
return idx in self._partitions | ||
|
||
def __len__(self) -> int: | ||
return sum(len(partition) for partition in self._partitions.values()) | ||
return sum(len(partition.partition()) for partition in self._partitions.values()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting. After this PR we actually have metadata, and don't necessarily need to reach for the partition to get the length... Would it not be possible/safe to let There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it doesn't really matter given that this is a local MicroPartition though |
||
|
||
def size_bytes(self) -> int | None: | ||
size_bytes_ = [partition.size_bytes() for partition in self._partitions.values()] | ||
size_bytes_ = [partition.partition().size_bytes() for partition in self._partitions.values()] | ||
size_bytes: list[int] = [size for size in size_bytes_ if size is not None] | ||
if len(size_bytes) != len(size_bytes_): | ||
return None | ||
|
@@ -126,7 +134,7 @@ def runner_io(self) -> PyRunnerIO: | |
def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry: | ||
results = list(self.run_iter(builder)) | ||
|
||
result_pset = LocalPartitionSet({}) | ||
result_pset = LocalPartitionSet() | ||
for i, result in enumerate(results): | ||
result_pset.set_partition(i, result) | ||
|
||
|
@@ -144,6 +152,7 @@ def run_iter( | |
|
||
# Optimize the logical plan. | ||
builder = builder.optimize() | ||
|
||
# Finalize the logical plan and get a physical plan scheduler for translating the | ||
# physical plan to executable tasks. | ||
plan_scheduler = builder.to_physical_plan_scheduler(daft_execution_config) | ||
|
@@ -209,8 +218,10 @@ def _physical_plan_to_partitions( | |
) | ||
): | ||
logger.debug("Running task synchronously in main thread: %s", next_step) | ||
partitions = self.build_partitions(next_step.instructions, *next_step.inputs) | ||
next_step.set_result([PyMaterializedResult(partition) for partition in partitions]) | ||
materialized_results = self.build_partitions( | ||
next_step.instructions, next_step.inputs, next_step.partial_metadatas | ||
) | ||
next_step.set_result(materialized_results) | ||
|
||
else: | ||
# Submit the task for execution. | ||
|
@@ -220,7 +231,10 @@ def _physical_plan_to_partitions( | |
pbar.mark_task_start(next_step) | ||
|
||
future = thread_pool.submit( | ||
self.build_partitions, next_step.instructions, *next_step.inputs | ||
self.build_partitions, | ||
next_step.instructions, | ||
next_step.inputs, | ||
next_step.partial_metadatas, | ||
) | ||
# Register the inflight task and resources used. | ||
future_to_task[future] = next_step.id() | ||
|
@@ -239,12 +253,13 @@ def _physical_plan_to_partitions( | |
done_id = future_to_task.pop(done_future) | ||
del inflight_tasks_resources[done_id] | ||
done_task = inflight_tasks.pop(done_id) | ||
partitions = done_future.result() | ||
materialized_results = done_future.result() | ||
|
||
pbar.mark_task_done(done_task) | ||
|
||
logger.debug("Task completed: %s -> <%s partitions>", done_id, len(partitions)) | ||
done_task.set_result([PyMaterializedResult(partition) for partition in partitions]) | ||
logger.debug("Task completed: %s -> <%s partitions>", done_id, len(materialized_results)) | ||
|
||
done_task.set_result(materialized_results) | ||
|
||
if next_step is None: | ||
next_step = next(plan) | ||
|
@@ -278,17 +293,23 @@ def _can_admit_task(self, resource_request: ResourceRequest, inflight_resources: | |
return all((cpus_okay, gpus_okay, memory_okay)) | ||
|
||
@staticmethod | ||
def build_partitions(instruction_stack: list[Instruction], *inputs: MicroPartition) -> list[MicroPartition]: | ||
partitions = list(inputs) | ||
def build_partitions( | ||
instruction_stack: list[Instruction], | ||
partitions: list[MicroPartition], | ||
final_metadata: list[PartialPartitionMetadata], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: we could enforce same length using |
||
) -> list[MaterializedResult[MicroPartition]]: | ||
for instruction in instruction_stack: | ||
partitions = instruction.run(partitions) | ||
return [ | ||
PyMaterializedResult(part, PartitionMetadata.from_table(part).merge_with_partial(partial)) | ||
for part, partial in zip(partitions, final_metadata) | ||
] | ||
|
||
return partitions | ||
|
||
|
||
@dataclass(frozen=True) | ||
@dataclass | ||
class PyMaterializedResult(MaterializedResult[MicroPartition]): | ||
_partition: MicroPartition | ||
_metadata: PartitionMetadata | None = None | ||
|
||
def partition(self) -> MicroPartition: | ||
return self._partition | ||
|
@@ -297,7 +318,9 @@ def vpartition(self) -> MicroPartition: | |
return self._partition | ||
|
||
def metadata(self) -> PartitionMetadata: | ||
return PartitionMetadata.from_table(self._partition) | ||
if self._metadata is None: | ||
self._metadata = PartitionMetadata.from_table(self._partition) | ||
return self._metadata | ||
|
||
def cancel(self) -> None: | ||
return None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we have a
LocalPartitionSet.from_tables()
?