Test new join #80

Merged
merged 12 commits on Feb 23, 2024
120 changes: 87 additions & 33 deletions btrdb/stream.py
@@ -17,12 +17,16 @@
import json
import logging
import re
import uuid
import uuid as uuidlib
import warnings
from collections import deque
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Dict, List

import pyarrow
import pyarrow.compute as pc

try:
    import pyarrow as pa
@@ -2183,22 +2187,10 @@ def arrow_values(
            )
            stream_uus = [str(s.uuid) for s in self._streams]
            data = list(aligned_windows_gen)
            tablex = data.pop(0)
            uu = stream_uus.pop(0)
            tab_columns = [
                c if c == "time" else uu + "/" + c for c in tablex.column_names
            ]
            tablex = tablex.rename_columns(tab_columns)
            if data:
                for tab, uu in zip(data, stream_uus):
                    tab_columns = [
                        c if c == "time" else uu + "/" + c for c in tab.column_names
                    ]
                    tab = tab.rename_columns(tab_columns)
                    tablex = tablex.join(tab, "time", join_type="full outer")
                data = tablex
            else:
                data = tablex
            table_joined = _merge_pyarrow_tables(
                {uu: tab for uu, tab in zip(stream_uus, data)}
            )
            data = table_joined

        elif self.width is not None and self.depth is not None:
            # create list of stream.windows data (the windows method should
@@ -2217,22 +2209,10 @@
            )
            stream_uus = [str(s.uuid) for s in self._streams]
            data = list(windows_gen)
            tablex = data.pop(0)
            uu = stream_uus.pop(0)
            tab_columns = [
                c if c == "time" else uu + "/" + c for c in tablex.column_names
            ]
            tablex = tablex.rename_columns(tab_columns)
            if data:
                for tab, uu in zip(data, stream_uus):
                    tab_columns = [
                        c if c == "time" else uu + "/" + c for c in tab.column_names
                    ]
                    tab = tab.rename_columns(tab_columns)
                    tablex = tablex.join(tab, "time", join_type="full outer")
                data = tablex
            else:
                data = tablex
            table_joined = _merge_pyarrow_tables(
                {uu: tab for uu, tab in zip(stream_uus, data)}
            )
            data = table_joined
        else:
            sampling_freq = params.pop("sampling_frequency", 0)
            period_ns = 0
@@ -2349,3 +2329,77 @@ def _coalesce_table_deque(tables: deque):
t2, "time", join_type="full outer", right_suffix=f"_{idx}"
)
return main_table


def _extract_unique_times(stream_map: Dict[str, pa.Table]) -> pa.Array:
    """Return the sorted union of 'time' values across all tables."""
    time_arrays = [
        table.column("time").combine_chunks()
        for table in stream_map.values()
        if table.num_rows > 0
    ]
    if not time_arrays:  # no stream returned any rows
        return pa.array([], type=pa.timestamp("ns"))

    all_times = pa.concat_arrays(time_arrays)
    return pc.unique(all_times).sort()
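
# Illustrative sketch of what this helper returns (hypothetical stream keys,
# not part of the PR); assumes pyarrow imported as ``pa`` per the module:
#
#   t1 = pa.table({"time": pa.array([100, 110], type=pa.timestamp("ns"))})
#   t2 = pa.table({"time": pa.array([110, 120], type=pa.timestamp("ns"))})
#   _extract_unique_times({"uu-1": t1, "uu-2": t2}).cast(pa.int64()).to_pylist()
#   # -> [100, 110, 120], the sorted union of both "time" columns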


def _build_combined_schema(
    stream_map: Dict[str, pa.Table], unique_times: pa.Array
) -> pa.Schema:
    """Construct the merged table's schema, prefixing column names with stream UUIDs to keep them unique."""
    combined_schema = [("time", unique_times.type)]
    # Flatten the per-stream columns into a single list of prefixed fields
    combined_schema += [
        pa.field(
            f"{uu}/{col_name}",
            table.column(col_name).type if table.num_rows > 0 else pa.null(),
        )
        for uu, table in stream_map.items()
        for col_name in table.column_names
        if col_name != "time"
    ]

    return pa.schema(combined_schema)
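
# Sketch of the resulting naming convention (hypothetical UUID key "aaaa-bbbb",
# illustration only): every non-"time" column is prefixed with its stream's UUID:
#
#   tab = pa.table({"time": pa.array([100], type=pa.timestamp("ns")), "mean": [1.0]})
#   times = _extract_unique_times({"aaaa-bbbb": tab})
#   _build_combined_schema({"aaaa-bbbb": tab}, times).names
#   # -> ['time', 'aaaa-bbbb/mean']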


def _merge_pyarrow_tables(stream_map: Dict[str, pa.Table]) -> pa.Table:
    """Merge per-stream tables on their 'time' columns (full outer join)."""
    unique_times = _extract_unique_times(stream_map)
    combined_schema = _build_combined_schema(stream_map, unique_times)

    none_data = [None] * len(unique_times)
    preallocated_data = {
        field.name: pa.array(none_data, type=field.type) for field in combined_schema
    }
    preallocated_data["time"] = unique_times

    for uu, table in stream_map.items():
        if table.num_rows > 0:
            time_indices = pc.index_in(
                preallocated_data["time"],
                value_set=table.column("time"),
                skip_nulls=True,
            )
            for col_name in table.column_names:
                if col_name == "time":
                    continue
                combined_col_name = f"{uu}/{col_name}"
                preallocated_data[combined_col_name] = pc.take(
                    table.column(col_name), indices=time_indices
                )
        else:
            # Defensive: the preallocation above already null-fills every
            # schema field, but keep empty tables' columns explicitly covered
            for col_name in combined_schema.names:
                if (
                    col_name.startswith(f"{uu}/")
                    and col_name not in preallocated_data
                ):
                    field_type = combined_schema.field(col_name).type
                    preallocated_data[col_name] = pa.array(
                        [None] * preallocated_data["time"].length(), type=field_type
                    )

    arrays = [preallocated_data[col] for col in combined_schema.names]
    return pa.Table.from_arrays(arrays=arrays, schema=combined_schema)
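
# End-to-end sketch of the outer-join semantics (hypothetical keys, illustration
# only): timestamps present in only one stream are null-filled in the other's columns.
#
#   a = pa.table({"time": pa.array([100, 110], type=pa.timestamp("ns")), "mean": [1.0, 2.0]})
#   b = pa.table({"time": pa.array([110, 120], type=pa.timestamp("ns")), "mean": [3.0, 4.0]})
#   merged = _merge_pyarrow_tables({"uu-a": a, "uu-b": b})
#   merged.column("uu-a/mean").to_pylist()  # -> [1.0, 2.0, None]
#   merged.column("uu-b/mean").to_pylist()  # -> [None, 3.0, 4.0]
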
41 changes: 40 additions & 1 deletion tests/btrdb_integration/test_streamset.py
@@ -209,7 +209,46 @@ def test_streamset_arrow_aligned_windows_vs_aligned_windows(
    ss = (
        btrdb.stream.StreamSet([s1, s2, s3])
        .filter(start=100, end=121)
        .windows(width=btrdb.utils.general.pointwidth.from_nanoseconds(10))
        .aligned_windows(pointwidth=btrdb.utils.general.pointwidth.from_nanoseconds(10))
    )
    values_arrow = ss.arrow_to_dataframe(name_callable=name_callable)
    values_arrow.index = pd.DatetimeIndex(values_arrow.index)
    values_prev = ss.to_dataframe(
        name_callable=name_callable
    )  # .convert_dtypes(dtype_backend='pyarrow')
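    # The legacy to_dataframe() path returns numpy-backed dtypes; cast to
    # pyarrow-backed dtypes (counts are unsigned) so the frames compare equal.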
    values_prev = values_prev.apply(lambda x: x.astype(str(x.dtype) + "[pyarrow]"))
    values_prev = values_prev.apply(
        lambda x: x.astype("uint64[pyarrow]") if "count" in x.name else x
    )
    values_prev.index = pd.DatetimeIndex(values_prev.index, tz="UTC")
    col_map = {old_col: old_col + "/mean" for old_col in values_prev.columns}
    values_prev = values_prev.rename(columns=col_map)
    assert values_arrow.equals(values_prev)


@pytest.mark.parametrize(
    "name_callable",
    [(None), (lambda s: str(s.uuid)), (lambda s: s.name + "/" + s.collection)],
    ids=["empty", "uu_as_str", "name_collection"],
)
def test_streamset_arrow_aligned_windows_join_logic(
    conn, tmp_collection, name_callable
):
    s1 = conn.create(new_uuid(), tmp_collection, tags={"name": "s1"})
    s2 = conn.create(new_uuid(), tmp_collection, tags={"name": "s2"})
    s3 = conn.create(new_uuid(), tmp_collection, tags={"name": "s3"})
    t1 = [100, 105, 110, 115, 120]
    t2 = [101, 106, 110, 132, 140]
    d1 = [0.0, 1.0, 2.0, 3.0, 4.0]
    d2 = [5.0, 6.0, 7.0, 8.0, 9.0]
    d3 = [1.0, 9.0, 44.0, 8.0, 9.0]
    s1.insert(list(zip(t1, d1)))
    s2.insert(list(zip(t2, d2)))
    s3.insert(list(zip(t2, d3)))
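    # t1 and t2 overlap only at t=110, so the joined table must null-fill the
    # non-overlapping rows; s2 and s3 share the timestamps t2 and stay aligned.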
    ss = (
        btrdb.stream.StreamSet([s1, s2, s3])
        .filter(start=100, end=141)
        .aligned_windows(pointwidth=btrdb.utils.general.pointwidth.from_nanoseconds(8))
    )
    values_arrow = ss.arrow_to_dataframe(name_callable=name_callable)
    values_arrow.index = pd.DatetimeIndex(values_arrow.index)