Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 70 additions & 108 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,26 +467,29 @@ fn python_to_scalar_bound(obj: &Bound<'_, PyAny>, dtype_tag: &str) -> PyResult<S
}
}

/// Convert Python partition metadata dict to Rust PartitionMetadata.
fn convert_python_metadata(
meta_dict: HashMap<String, (Py<PyAny>, Py<PyAny>, String)>,
) -> PyResult<PartitionMetadata> {
Python::attach(|py| {
let mut ranges = HashMap::new();
for (dim_name, (min_obj, max_obj, dtype_tag)) in meta_dict {
let min_bound = python_to_scalar_bound(min_obj.bind(py), &dtype_tag)?;
let max_bound = python_to_scalar_bound(max_obj.bind(py), &dtype_tag)?;
ranges.insert(
dim_name.clone(),
DimensionRange {
column_name: dim_name,
min: min_bound,
max: max_bound,
},
);
}
Ok(PartitionMetadata { ranges })
})
/// Convert a bound Python metadata dict to Rust PartitionMetadata.
///
/// Operates on an already-bound reference so no additional GIL acquisition
/// is needed — this is called from within a `#[pymethods]` context where
/// the GIL is already held.
fn convert_python_metadata_from_bound(meta_obj: &Bound<'_, PyAny>) -> PyResult<PartitionMetadata> {
    // dim name -> (min, max, dtype tag) as raw Python objects.
    type MetaDict = HashMap<String, (Py<PyAny>, Py<PyAny>, String)>;
    let entries: MetaDict = meta_obj.extract()?;
    let py = meta_obj.py();
    // Convert each (min, max) pair into typed scalar bounds, short-circuiting
    // on the first conversion failure.
    let ranges = entries
        .into_iter()
        .map(|(dim_name, (min_obj, max_obj, dtype_tag))| {
            let min = python_to_scalar_bound(min_obj.bind(py), &dtype_tag)?;
            let max = python_to_scalar_bound(max_obj.bind(py), &dtype_tag)?;
            let range = DimensionRange {
                column_name: dim_name.clone(),
                min,
                max,
            };
            Ok((dim_name, range))
        })
        .collect::<PyResult<HashMap<_, _>>>()?;
    Ok(PartitionMetadata { ranges })
}

impl Debug for PrunableStreamingTable {
Expand Down Expand Up @@ -663,9 +666,8 @@ impl PartitionStream for PyArrowStreamPartition {
///
/// ## Filter Pushdown
///
/// When partition metadata is provided, SQL filters on dimension columns (time, lat, lon, etc.)
/// will automatically prune partitions that can't contain matching rows. This dramatically
/// improves query performance for range queries on large datasets.
/// SQL filters on dimension columns (time, lat, lon, etc.) automatically prune
/// partitions that can't contain matching rows when metadata is supplied.
///
/// # Note
///
Expand All @@ -676,32 +678,23 @@ impl PartitionStream for PyArrowStreamPartition {
///
/// ```python
/// from datafusion import SessionContext
/// from xarray_sql import LazyArrowStreamTable
/// import pyarrow as pa
///
/// # Create factories for each partition (chunk)
/// factories = [
/// lambda: pa.RecordBatchReader.from_batches(schema, batches_chunk_0),
/// lambda: pa.RecordBatchReader.from_batches(schema, batches_chunk_1),
/// ]
/// schema = pa.schema([("time", pa.int64()), ("air", pa.float32())])
///
/// # Partition metadata for filter pushdown (optional)
/// # Each dict maps dimension name to (min, max) coordinate values
/// metadata = [
/// {'time': (0, 1000000000), 'lat': (-90.0, 0.0)}, # partition 0
/// {'time': (1000000001, 2000000000), 'lat': (0.0, 90.0)}, # partition 1
/// ]
/// # Each element is a (factory_callable, metadata_dict) pair.
/// # metadata_dict maps dim name -> (min, max, dtype_str); use {} for no pruning.
/// def make_partitions():
/// yield (lambda: pa.RecordBatchReader.from_batches(schema, batches_0),
/// {"time": (0, 1_000_000_000, "int64")})
/// yield (lambda: pa.RecordBatchReader.from_batches(schema, batches_1),
/// {"time": (1_000_000_001, 2_000_000_000, "int64")})
///
/// # Wrap factories in lazy table with metadata
/// table = LazyArrowStreamTable(factories, schema, metadata)
/// table = LazyArrowStreamTable(make_partitions(), schema)
///
/// # Register with DataFusion
/// ctx = SessionContext()
/// ctx.register_table("air", table)
///
/// # Queries with filters on dimension columns will prune partitions!
/// # This query might only read partition 1:
/// result = ctx.sql("SELECT AVG(air) FROM air WHERE lat > 0").to_arrow_table()
/// result = ctx.sql("SELECT AVG(air) FROM air WHERE time > 500000000").to_arrow_table()
/// ```
#[pyclass(name = "LazyArrowStreamTable")]
struct LazyArrowStreamTable {
Expand All @@ -711,30 +704,26 @@ struct LazyArrowStreamTable {

#[pymethods]
impl LazyArrowStreamTable {
/// Create a new LazyArrowStreamTable from stream factory functions.
/// Create a new LazyArrowStreamTable from an iterable of partition pairs.
///
/// Args:
/// stream_factories: A list of callables, each returning a Python object
/// implementing the Arrow PyCapsule interface (`__arrow_c_stream__`).
/// Each factory represents one partition, enabling parallel execution.
/// Called on each query execution to create fresh streams.
/// schema: A PyArrow Schema for the table. Required since the factories
/// haven't been called yet.
/// partition_metadata: Optional list of dicts mapping dimension names to
/// (min, max) tuples. When provided, enables filter pushdown to
/// prune partitions based on SQL WHERE clauses.
/// partitions: Any Python iterable yielding ``(factory, metadata_dict)``
/// pairs, where:
/// - ``factory`` is a zero-argument callable returning a
/// ``pa.RecordBatchReader`` (called lazily at query time).
/// - ``metadata_dict`` is a ``dict[str, tuple[Any, Any, str]]``
/// mapping dimension name to ``(min, max, dtype_str)``; pass
/// ``{}`` to skip pruning for a partition.
/// Generators are accepted, so partition state can be produced
/// one item at a time and released after Rust stores it.
/// schema: A PyArrow Schema for the table.
///
/// Raises:
/// TypeError: If the schema is not a valid PyArrow Schema.
/// ValueError: If stream_factories is empty or metadata length doesn't match.
/// ValueError: If the partitions iterable is empty.
#[new]
#[pyo3(signature = (stream_factories, schema, partition_metadata=None))]
fn new(
stream_factories: &Bound<'_, PyAny>,
schema: &Bound<'_, PyAny>,
partition_metadata: Option<&Bound<'_, PyAny>>,
) -> PyResult<Self> {
// Convert the PyArrow schema to Arrow schema
#[pyo3(signature = (partitions, schema))]
fn new(partitions: &Bound<'_, PyAny>, schema: &Bound<'_, PyAny>) -> PyResult<Self> {
use arrow::datatypes::Schema;
use arrow::pyarrow::FromPyArrow;

Expand All @@ -743,59 +732,32 @@ impl LazyArrowStreamTable {
})?;
let schema_ref = Arc::new(arrow_schema);

// Extract factories from the Python list
let factories: Vec<Py<PyAny>> = stream_factories.extract().map_err(|e| {
pyo3::exceptions::PyTypeError::new_err(format!(
"stream_factories must be a list of callables: {e}"
))
})?;

if factories.is_empty() {
return Err(pyo3::exceptions::PyValueError::new_err(
"stream_factories must not be empty",
));
}

// Extract and convert partition metadata if provided
let metadata_list: Vec<PartitionMetadata> = if let Some(meta_py) = partition_metadata {
type MetaDict = HashMap<String, (Py<PyAny>, Py<PyAny>, String)>;
let meta_dicts: Vec<MetaDict> = meta_py.extract().map_err(|e| {
// Consume the Python iterable one item at a time.
// All GIL-bound work happens here, in a single GIL-held context,
// eliminating the per-partition Python::attach() calls of the old
// three-list approach. Python can release each block dict, factory
// closure, and metadata dict as soon as Rust has ingested them.
let mut partition_list: Vec<(Arc<dyn PartitionStream>, PartitionMetadata)> = Vec::new();
for item_result in partitions.try_iter()? {
let item = item_result?;
let (factory_obj, meta_obj): (Py<PyAny>, Py<PyAny>) = item.extract().map_err(|e| {
pyo3::exceptions::PyTypeError::new_err(format!(
"partition_metadata must be a list of dicts: {e}"
"each partition must be a (factory, metadata_dict) tuple: {e}"
))
})?;
let meta = convert_python_metadata_from_bound(meta_obj.bind(partitions.py()))?;
let partition = Arc::new(PyArrowStreamPartition::new(factory_obj, schema_ref.clone()))
as Arc<dyn PartitionStream>;
partition_list.push((partition, meta));
}

if meta_dicts.len() != factories.len() {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"partition_metadata length ({}) must match stream_factories length ({})",
meta_dicts.len(),
factories.len()
)));
}

meta_dicts
.into_iter()
.map(convert_python_metadata)
.collect::<PyResult<Vec<_>>>()?
} else {
// No metadata provided - create empty metadata for each partition
vec![PartitionMetadata::default(); factories.len()]
};

// Create partitions with their metadata
let partitions: Vec<(Arc<dyn PartitionStream>, PartitionMetadata)> = factories
.into_iter()
.zip(metadata_list)
.map(|(factory, meta)| {
let partition = Arc::new(PyArrowStreamPartition::new(factory, schema_ref.clone()))
as Arc<dyn PartitionStream>;
(partition, meta)
})
.collect();

// Create the PrunableStreamingTable
let table = PrunableStreamingTable::new(schema_ref, partitions);
if partition_list.is_empty() {
return Err(pyo3::exceptions::PyValueError::new_err(
"partitions iterable must not be empty",
));
}

let table = PrunableStreamingTable::new(schema_ref, partition_list);
Ok(Self {
table: Arc::new(table),
})
Expand Down
75 changes: 40 additions & 35 deletions xarray_sql/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,45 @@ def _parse_schema(ds) -> pa.Schema:
PartitionBounds = dict[str, tuple[Any, Any, str]]


def _block_metadata(coord_arrays: dict, block: Block) -> PartitionBounds:
  """Compute min/max coordinate values for a single partition block.

  Args:
    coord_arrays: Pre-materialised coordinate arrays keyed by dimension name
      string. Hoist this outside any loop to avoid repeated remote I/O
      for Zarr-backed datasets.
    block: A single block slice dict from block_slices().

  Returns:
    Dict mapping dimension name to (min_value, max_value, dtype_str).
    Dimensions with an empty slice are omitted; the Rust pruning logic
    treats missing dimensions conservatively (never prunes on them).
  """
  bounds: PartitionBounds = {}
  for dim, slc in block.items():
    key = str(dim)
    values = coord_arrays[key][slc]
    if len(values) == 0:
      # Empty slice: omit the dimension so the pruner stays conservative.
      continue
    # Endpoints suffice for the common monotonic-coordinate case (O(1));
    # the swap handles descending axes such as latitude 90 -> -90.
    lo, hi = values[0], values[-1]
    if hi < lo:
      lo, hi = hi, lo
    if isinstance(lo, (np.datetime64, pd.Timestamp)):
      # Datetimes cross the FFI boundary as nanoseconds since epoch.
      bounds[key] = (
          int(pd.Timestamp(lo).value),
          int(pd.Timestamp(hi).value),
          "timestamp_ns",
      )
    else:
      if hasattr(lo, "item"):
        # numpy scalar -> Python native, required for PyO3 conversion.
        lo, hi = lo.item(), hi.item()
      bounds[key] = (lo, hi, "float64" if isinstance(lo, float) else "int64")
  return bounds


def partition_metadata(
ds: xr.Dataset, blocks: list[Block]
) -> list[PartitionBounds]:
Expand Down Expand Up @@ -336,38 +375,4 @@ def partition_metadata(
# N_partitions × N_dims times is wasteful and, for remote Zarr-backed datasets
# (e.g. ARCO-ERA5 on GCS), may trigger repeated network I/O.
coord_arrays = {str(dim): ds.coords[dim].values for dim in ds.dims}

metadata = []
for block in blocks:
ranges: PartitionBounds = {}
for dim, slc in block.items():
coord_values = coord_arrays[str(dim)][slc]
if len(coord_values) > 0:
# Use endpoints for the common monotonic case (O(1)).
# xarray/CF-convention dimension coordinates are almost always
# monotonic; even for descending axes (e.g. latitude 90→-90)
# first/last gives the correct bounds after the min/max swap below.
first, last = coord_values[0], coord_values[-1]
if first <= last:
min_val, max_val = first, last
else:
min_val, max_val = last, first

# Convert numpy scalar types to Python native types
# This is required for PyO3 FFI conversion
if isinstance(min_val, (np.datetime64, pd.Timestamp)):
# Convert datetime to nanoseconds since epoch
min_val = int(pd.Timestamp(min_val).value)
max_val = int(pd.Timestamp(max_val).value)
ranges[str(dim)] = (min_val, max_val, "timestamp_ns")
elif hasattr(min_val, "item"):
# numpy scalar -> Python native
min_val = min_val.item()
max_val = max_val.item()
dtype = "float64" if isinstance(min_val, float) else "int64"
ranges[str(dim)] = (min_val, max_val, dtype)
else:
dtype = "float64" if isinstance(min_val, float) else "int64"
ranges[str(dim)] = (min_val, max_val, dtype)
metadata.append(ranges)
return metadata
return [_block_metadata(coord_arrays, block) for block in blocks]
Loading
Loading