Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
924 changes: 896 additions & 28 deletions Cargo.lock

Large diffs are not rendered by default.

22 changes: 19 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,36 @@ rust-version = "1.62"
build = "build.rs"

[dependencies]
arrow = { version = "53", features = ["pyarrow", "ipc"] }
async-stream = "0.3"
datafusion = { version = "43.0", features = ["pyarrow", "avro"] }
datafusion-python = { version = "43.0" }
datafusion-proto = "43.0"
futures = "0.3"
glob = "0.3.1"
log = "0.4"
prost = "0.13"
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
tokio = { version = "1.40", features = ["macros", "rt", "rt-multi-thread", "sync"] }
pyo3 = { version = "0.22", features = [
"extension-module",
"abi3",
"abi3-py38",
] }
pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"] }
tokio = { version = "1.40", features = [
"macros",
"rt",
"rt-multi-thread",
"sync",
] }
uuid = "1.11.0"

[build-dependencies]
prost-types = "0.13"
rustc_version = "0.4.0"
tonic-build = { version = "0.8", default-features = false, features = ["transport", "prost"] }
tonic-build = { version = "0.8", default-features = false, features = [
"transport",
"prost",
] }

[dev-dependencies]
anyhow = "1.0.89"
Expand Down
8 changes: 1 addition & 7 deletions datafusion_ray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@
except ImportError:
import importlib_metadata

from ._datafusion_ray_internal import (
Context,
ExecutionGraph,
QueryStage,
execute_partition,
)
from .context import DatafusionRayContext
from .context import RayContext

__version__ = importlib_metadata.version(__name__)
288 changes: 152 additions & 136 deletions datafusion_ray/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,150 +15,166 @@
# specific language governing permissions and limitations
# under the License.

import json
import os
import time
from typing import Iterable
import asyncio
import threading

import datafusion
import pyarrow as pa
import ray

import datafusion_ray
from datafusion_ray import Context, ExecutionGraph, QueryStage
from typing import List, Any
from datafusion import SessionContext


@ray.remote(num_cpus=0)
def execute_query_stage(
query_stages: list[QueryStage],
stage_id: int
) -> tuple[int, list[ray.ObjectRef]]:
"""
Execute a query stage on the workers.

Returns the stage ID, and a list of futures for the output partitions of the query stage.
"""
stage = QueryStage(stage_id, query_stages[stage_id])

# execute child stages first
child_futures = []
for child_id in stage.get_child_stage_ids():
child_futures.append(
execute_query_stage.remote(query_stages, child_id)
)
from datafusion_ray._datafusion_ray_internal import (
RayContext as RayContextInternal,
internal_execute_partition,
)


class RayContext:
    """User-facing context: a thin Python wrapper over the Rust-backed
    ``RayContextInternal``, wired up with a :class:`RayShuffler` so query
    stages exchange data through Ray."""

    def __init__(self) -> None:
        shuffler = RayShuffler()
        self.ctx = RayContextInternal(shuffler)

    def register_parquet(self, name: str, path: str):
        """Register the parquet file(s) at *path* as table *name*."""
        self.ctx.register_parquet(name, path)

    def sql(self, query: str) -> datafusion.DataFrame:
        """Plan *query* and return a DataFusion DataFrame for it."""
        return self.ctx.sql(query)

    def set(self, option: str, value: str) -> None:
        """Set a session configuration *option* to *value* on the inner context."""
        self.ctx.set(option, value)


class RayIterable:
    """Iterator adapter over a stream of Ray ObjectRefs.

    Each item pulled from the wrapped iterable is an ObjectRef that resolves
    to either a list of ObjectRefs (whose first element is the payload) or
    None, which is treated as the end-of-stream sentinel."""

    def __init__(self, iterable, name):
        self.iterable = iterable
        self.name = name

    def __iter__(self):
        # This object is its own iterator.
        return self

    def __next__(self):
        print(f"{self.name} ray iterable getting next")
        ref = next(self.iterable)
        list_of_ref = ray.get(ref)
        if list_of_ref is None:
            # None sentinel from the shuffle actor: no more data.
            print(f"{self.name} ray iterable got None")
            raise StopIteration
        print(f"{self.name} ray iterable got list: {list_of_ref}")
        payload = ray.get(list_of_ref[0])
        print(f"{self.name} ray iterable got object ")
        return payload


class RayShuffler:
def __init__(self):
pass

def execute_partition(
self,
plan: bytes,
partition: int,
output_partitions: int,
input_partitions: int,
unique_id: str,
) -> RayIterable:
print(f"ray executing partition {partition}")
# TODO: make name unique per query tree
self.actor = RayShuffleActor.options(
name=f"RayShuffleActor ({unique_id})",
# lifetime="detached",
get_if_exists=True,
).remote(plan, output_partitions, input_partitions)

stream = exec_stream(self.actor, partition)
return RayIterable(stream, f"partition {partition} ")

# if the query stage has a single output partition then we need to execute for the output
# partition, otherwise we need to execute in parallel for each input partition
concurrency = stage.get_execution_partition_count()
output_partitions_count = stage.get_output_partition_count()
if output_partitions_count == 1:
# reduce stage
print("Forcing reduce stage concurrency from {} to 1".format(concurrency))
concurrency = 1

print(
"Scheduling query stage #{} with {} input partitions and {} output partitions".format(
stage.id(), concurrency, output_partitions_count
)
)

# A list of (stage ID, list of futures) for each child stage
# Each list is a 2-D array of (input partitions, output partitions).
child_outputs = ray.get(child_futures)

# if we are using disk-based shuffle, wait until the child stages to finish
# writing the shuffle files to disk first.
ray.get([f for _, lst in child_outputs for f in lst])

# schedule the actual execution workers
plan_bytes = stage.get_execution_plan_bytes()
futures = []
opt = {}
for part in range(concurrency):
futures.append(
execute_query_partition.options(**opt).remote(
stage_id, plan_bytes, part
)
)

return stage_id, futures
def exec_stream(actor, partition: int):
    """Generator of ObjectRefs produced by the actor's ``stream()`` for *partition*.

    NOTE(review): ``actor.stream.remote(...)`` returns an ObjectRef, which is
    presumably never None, making the break branch unreachable; end-of-stream
    appears to be detected by RayIterable when the *resolved* value is None —
    confirm before removing the branch.
    """
    while True:
        object_ref = actor.stream.remote(partition)
        print(f"stream got {object_ref}")
        if object_ref is not None:
            print(f"yielding {object_ref}")
            yield object_ref
        else:
            print("breaking")
            return


@ray.remote
def execute_query_partition(
    stage_id: int,
    plan_bytes: bytes,
    part: int
) -> Iterable[pa.RecordBatch]:
    """Ray task: execute partition *part* of a serialized query-stage plan.

    Delegates execution to DataFusion and, as a side effect, prints a
    Chrome-trace "complete" event (``"ph": "X"``) as JSON followed by a
    comma, so a run's stdout can be assembled into a chrome://tracing file.

    NOTE(review): this appears to be the legacy path — this same diff removes
    ``execute_partition`` from the package's ``__init__`` re-exports, so
    ``datafusion_ray.execute_partition`` below may no longer resolve; confirm.
    """
    start_time = time.time()
    # plan = datafusion_ray.deserialize_execution_plan(plan_bytes)
    # print(
    #     "Worker executing plan {} partition #{} with shuffle inputs {}".format(
    #         plan.display(),
    #         part,
    #         input_partition_ids,
    #     )
    # )
    # This is delegating to DataFusion for execution, but this would be a good place
    # to plug in other execution engines by translating the plan into another engine's plan
    # (perhaps via Substrait, once DataFusion supports converting a physical plan to Substrait)
    ret = datafusion_ray.execute_partition(plan_bytes, part)
    duration = time.time() - start_time
    # Chrome-trace event: timestamps and durations are in microseconds.
    event = {
        "cat": f"{stage_id}-{part}",
        "name": f"{stage_id}-{part}",
        "pid": ray.util.get_node_ip_address(),
        "tid": os.getpid(),
        "ts": int(start_time * 1_000_000),
        "dur": int(duration * 1_000_000),
        "ph": "X",
    }
    print(json.dumps(event), end=",")
    return ret


class DatafusionRayContext:
    """Driver-side context that plans a DataFusion SQL query into stages and
    schedules them on Ray.

    NOTE(review): this looks like the legacy API that the new ``RayContext``
    in this file supersedes — confirm before extending it.
    """

    def __init__(self, df_ctx: SessionContext):
        # Keep the DataFusion session (SQL planning) alongside the
        # Rust-side Context (stage planning / distributed execution).
        self.df_ctx = df_ctx
        self.ctx = Context(df_ctx)

    def register_csv(self, table_name: str, path: str, has_header: bool):
        """Register the CSV file at *path* as table *table_name*."""
        self.ctx.register_csv(table_name, path, has_header)

    def register_parquet(self, table_name: str, path: str):
        """Register the Parquet file(s) at *path* as table *table_name*."""
        self.ctx.register_parquet(table_name, path)

    def register_data_lake(self, table_name: str, paths: List[str]):
        """Register a multi-file table backed by the given list of *paths*."""
        self.ctx.register_datalake_table(table_name, paths)

    def sql(self, sql: str) -> List[pa.RecordBatch]:
        """Execute *sql*, returning the result batches (empty for view DDL)."""
        # TODO we should parse sql and inspect the plan rather than
        # perform a string comparison here
        sql_str = sql.lower()
        # View DDL is executed directly on the context; nothing to distribute.
        if "create view" in sql_str or "drop view" in sql_str:
            self.ctx.sql(sql)
            return []

        df = self.df_ctx.sql(sql)
        return self.plan(df.execution_plan())

def plan(self, execution_plan: Any) -> List[pa.RecordBatch]:

graph = self.ctx.plan(execution_plan)
final_stage_id = graph.get_final_query_stage().id()
# serialize the query stages and store in Ray object store
query_stages = [
graph.get_query_stage(i).get_execution_plan_bytes()
for i in range(final_stage_id + 1)
class RayShuffleActor:
def __init__(
self, plan: bytes, output_partitions: int, input_partitions: int
) -> None:
self.plan = plan
self.output_partitions = output_partitions
self.input_partitions = input_partitions

self.queues = [asyncio.Queue() for _ in range(output_partitions)]

self.is_finished = [False for _ in range(input_partitions)]

print(f"creating actor with {output_partitions}, {input_partitions}")

self._start_partition_tasks()

def _start_partition_tasks(self):
ctx = ray.get_runtime_context()
my_handle = ctx.current_actor

self.tasks = [
_exec_stream.remote(self.plan, p, self.output_partitions, my_handle)
for p in range(self.input_partitions)
]
# schedule execution
future = execute_query_stage.remote(
query_stages,
final_stage_id
print(f"started tasks: {self.tasks}")

def finished(self, partition: int) -> None:
self.is_finished[partition] = True

# if we are finished with all input partitions, then signal consumers
# of our output partitions
if all(self.is_finished):
for q in self.queues:
q.put_nowait(None)
print(f"Actor finished partition {partition}")

async def put(self, partition: int, thing) -> None:
await self.queues[partition].put(thing)

async def stream(self, partition: int):
thing = await self.queues[partition].get()
return thing


@ray.remote
def _exec_stream(
plan: bytes, shadow_partition: int, output_partitions: int, ray_shuffle_actor
):
my_id = ray.get_runtime_context().get_task_id()
print(f"Task {my_id} executing shadow partition {shadow_partition}")
print(f"Task {my_id} ray_shuffle handle: {ray_shuffle_actor}")

def do_a_partition(partition):
reader: pa.RecordBatchReader = internal_execute_partition(
plan, partition, shadow_partition
)
_, partitions = ray.get(future)
# assert len(partitions) == 1, len(partitions)
record_batches = ray.get(partitions[0])
# filter out empty batches
return [batch for batch in record_batches if batch.num_rows > 0]

print(f"Task {my_id} got reader for partition {partition}")

for batch in reader:
record_batch: pa.RecordBatch = batch
print(f"Task {my_id} got batch {len(record_batch)} rows")

object_ref = ray.put(batch)
ray_shuffle_actor.put.remote(partition, [object_ref])

threads = []
for p in range(output_partitions):
t = threading.Thread(target=do_a_partition, args=(p,))
threads.append(t)
t.start()

for t in threads:
t.join()

ray_shuffle_actor.finished.remote(shadow_partition)
31 changes: 0 additions & 31 deletions datafusion_ray/ray_utils.py

This file was deleted.

Loading