Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,11 +391,13 @@ def test_execution_plan(aggregate_df):
ctx = SessionContext()
stream = ctx.execute(plan, 0)
# get the one and only batch
batch = stream.next()
batch = next(stream)
assert batch is not None
# there should be no more batches
batch = stream.next()
assert batch is None
try:
batch = next(stream)
except StopIteration:
pass


def test_repartition(df):
Expand Down
6 changes: 3 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl PyConfig {
}

/// Get a configuration option
pub fn get(&mut self, key: &str, py: Python) -> PyResult<PyObject> {
pub fn get(&self, key: &str, py: Python) -> PyResult<PyObject> {
let options = self.config.to_owned();
for entry in options.entries() {
if entry.key == key {
Expand All @@ -64,7 +64,7 @@ impl PyConfig {
}

/// Get all configuration options
pub fn get_all(&mut self, py: Python) -> PyResult<PyObject> {
pub fn get_all(&self, py: Python) -> PyResult<PyObject> {
let dict = PyDict::new(py);
let options = self.config.to_owned();
for entry in options.entries() {
Expand All @@ -73,7 +73,7 @@ impl PyConfig {
Ok(dict.into())
}

fn __repr__(&mut self, py: Python) -> PyResult<String> {
fn __repr__(&self, py: Python) -> PyResult<String> {
let dict = self.get_all(py);
match dict {
Ok(result) => Ok(format!("Config({result})")),
Expand Down
34 changes: 17 additions & 17 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ impl PySessionContext {

/// Register a an object store with the given name
fn register_object_store(
&mut self,
&self,
scheme: &str,
store: &PyAny,
host: Option<&str>,
Expand Down Expand Up @@ -272,14 +272,14 @@ impl PySessionContext {
}

/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
fn sql(&self, query: &str, py: Python) -> PyResult<PyDataFrame> {
let result = self.ctx.sql(query);
let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
Ok(PyDataFrame::new(df))
}

fn create_dataframe(
&mut self,
&self,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
name: Option<&str>,
py: Python,
Expand Down Expand Up @@ -310,14 +310,14 @@ impl PySessionContext {
}

/// Create a DataFrame from an existing logical plan
fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
fn create_dataframe_from_logical_plan(&self, plan: PyLogicalPlan) -> PyDataFrame {
PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
}

/// Construct datafusion dataframe from Python list
#[allow(clippy::wrong_self_convention)]
fn from_pylist(
&mut self,
&self,
data: PyObject,
name: Option<&str>,
_py: Python,
Expand All @@ -337,7 +337,7 @@ impl PySessionContext {
/// Construct datafusion dataframe from Python dictionary
#[allow(clippy::wrong_self_convention)]
fn from_pydict(
&mut self,
&self,
data: PyObject,
name: Option<&str>,
_py: Python,
Expand All @@ -357,7 +357,7 @@ impl PySessionContext {
/// Construct datafusion dataframe from Arrow Table
#[allow(clippy::wrong_self_convention)]
fn from_arrow_table(
&mut self,
&self,
data: PyObject,
name: Option<&str>,
_py: Python,
Expand All @@ -378,7 +378,7 @@ impl PySessionContext {
/// Construct datafusion dataframe from pandas
#[allow(clippy::wrong_self_convention)]
fn from_pandas(
&mut self,
&self,
data: PyObject,
name: Option<&str>,
_py: Python,
Expand All @@ -398,7 +398,7 @@ impl PySessionContext {
/// Construct datafusion dataframe from polars
#[allow(clippy::wrong_self_convention)]
fn from_polars(
&mut self,
&self,
data: PyObject,
name: Option<&str>,
_py: Python,
Expand All @@ -413,22 +413,22 @@ impl PySessionContext {
})
}

fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
fn register_table(&self, name: &str, table: &PyTable) -> PyResult<()> {
self.ctx
.register_table(name, table.table())
.map_err(DataFusionError::from)?;
Ok(())
}

fn deregister_table(&mut self, name: &str) -> PyResult<()> {
fn deregister_table(&self, name: &str) -> PyResult<()> {
self.ctx
.deregister_table(name)
.map_err(DataFusionError::from)?;
Ok(())
}

fn register_record_batches(
&mut self,
&self,
name: &str,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
) -> PyResult<()> {
Expand All @@ -445,7 +445,7 @@ impl PySessionContext {
parquet_pruning=true,
file_extension=".parquet"))]
fn register_parquet(
&mut self,
&self,
name: &str,
path: &str,
table_partition_cols: Vec<(String, String)>,
Expand All @@ -471,7 +471,7 @@ impl PySessionContext {
schema_infer_max_records=1000,
file_extension=".csv"))]
fn register_csv(
&mut self,
&self,
name: &str,
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
Expand Down Expand Up @@ -515,12 +515,12 @@ impl PySessionContext {
Ok(())
}

fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
fn register_udf(&self, udf: PyScalarUDF) -> PyResult<()> {
self.ctx.register_udf(udf.function);
Ok(())
}

fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> {
fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
self.ctx.register_udaf(udaf.function);
Ok(())
}
Expand Down Expand Up @@ -561,7 +561,7 @@ impl PySessionContext {
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![]))]
fn read_json(
&mut self,
&self,
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
schema_infer_max_records: usize,
Expand Down
4 changes: 2 additions & 2 deletions src/expr/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl PyLiteral {
extract_scalar_value!(self, Float64)
}

pub fn value_decimal128(&mut self) -> PyResult<(Option<i128>, u8, i8)> {
pub fn value_decimal128(&self) -> PyResult<(Option<i128>, u8, i8)> {
match &self.value {
ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)),
other => Err(unexpected_literal_value(other)),
Expand Down Expand Up @@ -112,7 +112,7 @@ impl PyLiteral {
extract_scalar_value!(self, Time64Nanosecond)
}

pub fn value_timestamp(&mut self) -> PyResult<(Option<i64>, Option<String>)> {
pub fn value_timestamp(&self) -> PyResult<(Option<i64>, Option<String>)> {
match &self.value {
ScalarValue::TimestampNanosecond(iv, tz)
| ScalarValue::TimestampMicrosecond(iv, tz)
Expand Down
18 changes: 13 additions & 5 deletions src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use datafusion::arrow::pyarrow::PyArrowConvert;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
use pyo3::{pyclass, pymethods, PyObject, PyRef, PyResult, Python};
use std::sync::Mutex;

#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
pub struct PyRecordBatch {
Expand All @@ -42,19 +43,26 @@ impl From<RecordBatch> for PyRecordBatch {

#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
pub struct PyRecordBatchStream {
stream: SendableRecordBatchStream,
stream: Mutex<SendableRecordBatchStream>,
}

impl PyRecordBatchStream {
pub fn new(stream: SendableRecordBatchStream) -> Self {
Self { stream }
Self {
stream: Mutex::new(stream),
}
}
}

#[pymethods]
impl PyRecordBatchStream {
fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
let result = self.stream.next();
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
let mut stream = self.stream.lock().unwrap();
let result = stream.next();
match wait_for_future(py, result) {
None => Ok(None),
Some(Ok(b)) => Ok(Some(b.into())),
Expand Down