From bf7c3555aced996088cf3d5b8b2ca680a65f91df Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Wed, 3 May 2023 10:01:03 -0400 Subject: [PATCH 1/2] Fix already borrowed error when accessing PyRecordBatchStream from multiple threads. --- src/record_batch.rs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/record_batch.rs b/src/record_batch.rs index 15b70e8ce..091292ff8 100644 --- a/src/record_batch.rs +++ b/src/record_batch.rs @@ -20,7 +20,8 @@ use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::arrow::record_batch::RecordBatch; use datafusion::physical_plan::SendableRecordBatchStream; use futures::StreamExt; -use pyo3::{pyclass, pymethods, PyObject, PyResult, Python}; +use pyo3::{pyclass, pymethods, PyObject, PyRef, PyResult, Python}; +use std::sync::Mutex; #[pyclass(name = "RecordBatch", module = "datafusion", subclass)] pub struct PyRecordBatch { @@ -42,19 +43,26 @@ impl From<RecordBatch> for PyRecordBatch { #[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)] pub struct PyRecordBatchStream { - stream: SendableRecordBatchStream, + stream: Mutex<SendableRecordBatchStream>, } impl PyRecordBatchStream { pub fn new(stream: SendableRecordBatchStream) -> Self { - Self { stream } + Self { + stream: Mutex::new(stream), + } } } #[pymethods] impl PyRecordBatchStream { - fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> { - let result = self.stream.next(); + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(&self, py: Python) -> PyResult<Option<PyRecordBatch>> { + let mut stream = self.stream.lock().unwrap(); + let result = stream.next(); match wait_for_future(py, result) { None => Ok(None), Some(Ok(b)) => Ok(Some(b.into())), From 2b932abf6d9078a8a1800141950137df47ae58ca Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Wed, 3 May 2023 16:05:58 -0400 Subject: [PATCH 2/2] Remove unneeded mutable pyclass borrows which caused "Already Borrowed" Panic in multi-thread. 
--- datafusion/tests/test_dataframe.py | 8 ++++--- src/config.rs | 6 +++--- src/context.rs | 34 +++++++++++++++--------------- src/expr/literal.rs | 4 ++-- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 221b0cc09..b37290ff0 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -391,11 +391,13 @@ def test_execution_plan(aggregate_df): ctx = SessionContext() stream = ctx.execute(plan, 0) # get the one and only batch - batch = stream.next() + batch = next(stream) assert batch is not None # there should be no more batches - batch = stream.next() - assert batch is None + try: + batch = next(stream) + except StopIteration: + pass def test_repartition(df): diff --git a/src/config.rs b/src/config.rs index 228f95a0b..1538ea981 100644 --- a/src/config.rs +++ b/src/config.rs @@ -45,7 +45,7 @@ impl PyConfig { } /// Get a configuration option - pub fn get(&mut self, key: &str, py: Python) -> PyResult { + pub fn get(&self, key: &str, py: Python) -> PyResult { let options = self.config.to_owned(); for entry in options.entries() { if entry.key == key { @@ -64,7 +64,7 @@ impl PyConfig { } /// Get all configuration options - pub fn get_all(&mut self, py: Python) -> PyResult { + pub fn get_all(&self, py: Python) -> PyResult { let dict = PyDict::new(py); let options = self.config.to_owned(); for entry in options.entries() { @@ -73,7 +73,7 @@ impl PyConfig { Ok(dict.into()) } - fn __repr__(&mut self, py: Python) -> PyResult { + fn __repr__(&self, py: Python) -> PyResult { let dict = self.get_all(py); match dict { Ok(result) => Ok(format!("Config({result})")), diff --git a/src/context.rs b/src/context.rs index b7f82230f..bf8db97b6 100644 --- a/src/context.rs +++ b/src/context.rs @@ -239,7 +239,7 @@ impl PySessionContext { /// Register a an object store with the given name fn register_object_store( - &mut self, + &self, scheme: &str, store: &PyAny, host: 
Option<&str>, @@ -272,14 +272,14 @@ impl PySessionContext { } /// Returns a PyDataFrame whose plan corresponds to the SQL statement. - fn sql(&mut self, query: &str, py: Python) -> PyResult { + fn sql(&self, query: &str, py: Python) -> PyResult { let result = self.ctx.sql(query); let df = wait_for_future(py, result).map_err(DataFusionError::from)?; Ok(PyDataFrame::new(df)) } fn create_dataframe( - &mut self, + &self, partitions: PyArrowType>>, name: Option<&str>, py: Python, @@ -310,14 +310,14 @@ impl PySessionContext { } /// Create a DataFrame from an existing logical plan - fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame { + fn create_dataframe_from_logical_plan(&self, plan: PyLogicalPlan) -> PyDataFrame { PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone())) } /// Construct datafusion dataframe from Python list #[allow(clippy::wrong_self_convention)] fn from_pylist( - &mut self, + &self, data: PyObject, name: Option<&str>, _py: Python, @@ -337,7 +337,7 @@ impl PySessionContext { /// Construct datafusion dataframe from Python dictionary #[allow(clippy::wrong_self_convention)] fn from_pydict( - &mut self, + &self, data: PyObject, name: Option<&str>, _py: Python, @@ -357,7 +357,7 @@ impl PySessionContext { /// Construct datafusion dataframe from Arrow Table #[allow(clippy::wrong_self_convention)] fn from_arrow_table( - &mut self, + &self, data: PyObject, name: Option<&str>, _py: Python, @@ -378,7 +378,7 @@ impl PySessionContext { /// Construct datafusion dataframe from pandas #[allow(clippy::wrong_self_convention)] fn from_pandas( - &mut self, + &self, data: PyObject, name: Option<&str>, _py: Python, @@ -398,7 +398,7 @@ impl PySessionContext { /// Construct datafusion dataframe from polars #[allow(clippy::wrong_self_convention)] fn from_polars( - &mut self, + &self, data: PyObject, name: Option<&str>, _py: Python, @@ -413,14 +413,14 @@ impl PySessionContext { }) } - fn register_table(&mut self, name: 
&str, table: &PyTable) -> PyResult<()> { + fn register_table(&self, name: &str, table: &PyTable) -> PyResult<()> { self.ctx .register_table(name, table.table()) .map_err(DataFusionError::from)?; Ok(()) } - fn deregister_table(&mut self, name: &str) -> PyResult<()> { + fn deregister_table(&self, name: &str) -> PyResult<()> { self.ctx .deregister_table(name) .map_err(DataFusionError::from)?; @@ -428,7 +428,7 @@ impl PySessionContext { } fn register_record_batches( - &mut self, + &self, name: &str, partitions: PyArrowType>>, ) -> PyResult<()> { @@ -445,7 +445,7 @@ impl PySessionContext { parquet_pruning=true, file_extension=".parquet"))] fn register_parquet( - &mut self, + &self, name: &str, path: &str, table_partition_cols: Vec<(String, String)>, @@ -471,7 +471,7 @@ impl PySessionContext { schema_infer_max_records=1000, file_extension=".csv"))] fn register_csv( - &mut self, + &self, name: &str, path: PathBuf, schema: Option>, @@ -515,12 +515,12 @@ impl PySessionContext { Ok(()) } - fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> { + fn register_udf(&self, udf: PyScalarUDF) -> PyResult<()> { self.ctx.register_udf(udf.function); Ok(()) } - fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> { + fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> { self.ctx.register_udaf(udaf.function); Ok(()) } @@ -561,7 +561,7 @@ impl PySessionContext { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![]))] fn read_json( - &mut self, + &self, path: PathBuf, schema: Option>, schema_infer_max_records: usize, diff --git a/src/expr/literal.rs b/src/expr/literal.rs index 076f89a66..362370e73 100644 --- a/src/expr/literal.rs +++ b/src/expr/literal.rs @@ -61,7 +61,7 @@ impl PyLiteral { extract_scalar_value!(self, Float64) } - pub fn value_decimal128(&mut self) -> PyResult<(Option, u8, i8)> { + pub fn value_decimal128(&self) -> PyResult<(Option, u8, 
i8)> { match &self.value { ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)), other => Err(unexpected_literal_value(other)), @@ -112,7 +112,7 @@ impl PyLiteral { extract_scalar_value!(self, Time64Nanosecond) } - pub fn value_timestamp(&mut self) -> PyResult<(Option, Option)> { + pub fn value_timestamp(&self) -> PyResult<(Option, Option)> { match &self.value { ScalarValue::TimestampNanosecond(iv, tz) | ScalarValue::TimestampMicrosecond(iv, tz)