diff --git a/docs/source/conf.py b/docs/source/conf.py index 28db17d35..01813b032 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -91,6 +91,13 @@ def autoapi_skip_member_fn(app, what, name, obj, skip, options) -> bool: # noqa ("method", "datafusion.context.SessionContext.tables"), ("method", "datafusion.dataframe.DataFrame.unnest_column"), ] + # Explicitly skip certain members listed above. These are either + # re-exports, duplicate module-level documentation, deprecated + # API surfaces, or private variables that would otherwise appear + # in the generated docs and cause confusing duplication. + # Keeping this explicit list avoids surprising entries in the + # AutoAPI output and gives us a single place to opt-out items + # when we intentionally hide them from the docs. if (what, name) in skip_contents: skip = True diff --git a/docs/source/contributor-guide/ffi.rst b/docs/source/contributor-guide/ffi.rst index e8a0398b8..64413866f 100644 --- a/docs/source/contributor-guide/ffi.rst +++ b/docs/source/contributor-guide/ffi.rst @@ -34,7 +34,7 @@ as performant as possible and to utilize the features of DataFusion, you may dec your source in Rust and then expose it through `PyO3 `_ as a Python library. At first glance, it may appear the best way to do this is to add the ``datafusion-python`` -crate as a dependency, provide a ``PyTable``, and then to register it with the +crate as a dependency, provide a ``PyTable``, and then to register it with the ``SessionContext``. Unfortunately, this will not work. When you produce your code as a Python library and it needs to interact with the DataFusion diff --git a/docs/source/user-guide/data-sources.rst b/docs/source/user-guide/data-sources.rst index a9b119b93..26f1303c4 100644 --- a/docs/source/user-guide/data-sources.rst +++ b/docs/source/user-guide/data-sources.rst @@ -154,11 +154,11 @@ as Delta Lake. This will require a recent version of from deltalake import DeltaTable delta_table = DeltaTable("path_to_table") - ctx.register_table_provider("my_delta_table", delta_table) + ctx.register_table("my_delta_table", delta_table) df = ctx.table("my_delta_table") df.show() -On older versions of ``deltalake`` (prior to 0.22) you can use the +On older versions of ``deltalake`` (prior to 0.22) you can use the `Arrow DataSet `_ interface to import to DataFusion, but this does not support features such as filter push down which can lead to a significant performance difference. diff --git a/docs/source/user-guide/io/table_provider.rst b/docs/source/user-guide/io/table_provider.rst index bd1d6b80f..29e5d9880 100644 --- a/docs/source/user-guide/io/table_provider.rst +++ b/docs/source/user-guide/io/table_provider.rst @@ -37,22 +37,26 @@ A complete example can be found in the `examples folder , ) -> PyResult> { - let name = CString::new("datafusion_table_provider").unwrap(); + let name = cr"datafusion_table_provider".into(); - let provider = Arc::new(self.clone()) - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - let provider = FFI_TableProvider::new(Arc::new(provider), false); + let provider = Arc::new(self.clone()); + let provider = FFI_TableProvider::new(provider, false, None); PyCapsule::new_bound(py, provider, Some(name.clone())) } } -Once you have this library available, in python you can register your table provider -to the ``SessionContext``. +Once you have this library available, you can construct a +:py:class:`~datafusion.Table` in Python and register it with the +``SessionContext``. .. 
code-block:: python + from datafusion import SessionContext, Table + + ctx = SessionContext() provider = MyTableProvider() - ctx.register_table_provider("my_table", provider) - ctx.table("my_table").show() + ctx.register_table("capsule_table", provider) + + ctx.table("capsule_table").show() diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py index 72aadf64c..1bf1bf136 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -36,9 +36,9 @@ def test_catalog_provider(): my_catalog_schemas = my_catalog.names() assert expected_schema_name in my_catalog_schemas - my_database = my_catalog.database(expected_schema_name) - assert expected_table_name in my_database.names() - my_table = my_database.table(expected_table_name) + my_schema = my_catalog.schema(expected_schema_name) + assert expected_table_name in my_schema.names() + my_table = my_schema.table(expected_table_name) assert expected_table_columns == my_table.schema.names result = ctx.table( diff --git a/examples/datafusion-ffi-example/python/tests/_test_table_function.py b/examples/datafusion-ffi-example/python/tests/_test_table_function.py index f3c56a90a..4b8b21454 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_table_function.py +++ b/examples/datafusion-ffi-example/python/tests/_test_table_function.py @@ -53,7 +53,7 @@ def test_ffi_table_function_call_directly(): table_udtf = udtf(table_func, "my_table_func") my_table = table_udtf() - ctx.register_table_provider("t", my_table) + ctx.register_table("t", my_table) result = ctx.table("t").collect() assert len(result) == 2 diff --git a/examples/datafusion-ffi-example/python/tests/_test_table_provider.py b/examples/datafusion-ffi-example/python/tests/_test_table_provider.py index 6b24da06c..48feaff64 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_table_provider.py +++ b/examples/datafusion-ffi-example/python/tests/_test_table_provider.py @@ -25,7 +25,7 @@ def test_table_loading(): ctx = SessionContext() table = MyTableProvider(3, 2, 4) - ctx.register_table_provider("t", table) + ctx.register_table("t", table) result = ctx.table("t").collect() assert len(result) == 4 @@ -40,3 +40,7 @@ def test_table_loading(): ] assert result == expected + + result = ctx.read_table(table).collect() + result = [r.column(0) for r in result] + assert result == expected diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index e9d2dba75..9ebd58ea6 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -28,17 +28,16 @@ try: import importlib.metadata as importlib_metadata except ImportError: - import importlib_metadata + import importlib_metadata # type: ignore[import] +# Public submodules from . import functions, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. 
from ._internal import Config from .catalog import Catalog, Database, Table from .col import col, column -from .common import ( - DFSchema, -) +from .common import DFSchema from .context import ( RuntimeEnvBuilder, SessionConfig, @@ -47,10 +46,7 @@ ) from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions from .dataframe_formatter import configure_formatter -from .expr import ( - Expr, - WindowFrame, -) +from .expr import Expr, WindowFrame from .io import read_avro, read_csv, read_json, read_parquet from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 536b3a790..da54d233d 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -20,13 +20,16 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, Any, Protocol import datafusion._internal as df_internal if TYPE_CHECKING: import pyarrow as pa + from datafusion import DataFrame + from datafusion.context import TableProviderExportable + try: from warnings import deprecated # Python 3.13+ except ImportError: @@ -82,7 +85,11 @@ def database(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" return self.schema(name) - def register_schema(self, name, schema) -> Schema | None: + def register_schema( + self, + name: str, + schema: Schema | SchemaProvider | SchemaProviderExportable, + ) -> Schema | None: """Register a schema with this catalog.""" if isinstance(schema, Schema): return self.catalog.register_schema(name, schema._raw_schema) @@ -122,10 +129,12 @@ def table(self, name: str) -> Table: """Return the table with the given ``name`` from this schema.""" return Table(self._raw_schema.table(name)) - def register_table(self, name, table) -> None: - """Register a table provider in this schema.""" - if isinstance(table, Table): - return self._raw_schema.register_table(name, table.table) + def register_table( + self, + name: str, + table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset, + ) -> None: + """Register a table in this schema.""" return self._raw_schema.register_table(name, table) def deregister_table(self, name: str) -> None: @@ -139,30 +148,45 @@ class Database(Schema): class Table: - """DataFusion table.""" + """A DataFusion table. 
- def __init__(self, table: df_internal.catalog.RawTable) -> None: - """This constructor is not typically called by the end user.""" - self.table = table + Internally we currently support the following types of tables: + + - Tables created using built-in DataFusion methods, such as + reading from CSV or Parquet + - pyarrow datasets + - DataFusion DataFrames, which will be converted into a view + - Externally provided tables implemented with the FFI PyCapsule + interface (advanced) + """ + + __slots__ = ("_inner",) + + def __init__( + self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset + ) -> None: + """Constructor.""" + self._inner = df_internal.catalog.RawTable(table) def __repr__(self) -> str: """Print a string representation of the table.""" - return self.table.__repr__() + return repr(self._inner) @staticmethod + @deprecated("Use Table() constructor instead.") def from_dataset(dataset: pa.dataset.Dataset) -> Table: - """Turn a pyarrow Dataset into a Table.""" - return Table(df_internal.catalog.RawTable.from_dataset(dataset)) + """Turn a :mod:`pyarrow.dataset` ``Dataset`` into a :class:`Table`.""" + return Table(dataset) @property def schema(self) -> pa.Schema: """Returns the schema associated with this table.""" - return self.table.schema + return self._inner.schema @property def kind(self) -> str: """Returns the kind of table.""" - return self.table.kind + return self._inner.kind class CatalogProvider(ABC): @@ -219,14 +243,16 @@ def table(self, name: str) -> Table | None: """Retrieve a specific table from this schema.""" ... - def register_table(self, name: str, table: Table) -> None: # noqa: B027 - """Add a table from this schema. + def register_table( # noqa: B027 + self, name: str, table: Table | TableProviderExportable | Any + ) -> None: + """Add a table to this schema. This method is optional. If your schema provides a fixed list of tables, you do not need to implement this method. """ - def deregister_table(self, name, cascade: bool) -> None: # noqa: B027 + def deregister_table(self, name: str, cascade: bool) -> None: # noqa: B027 """Remove a table from this schema. This method is optional. 
If your schema provides a fixed list of tables, you do diff --git a/python/datafusion/context.py b/python/datafusion/context.py index b6e728b51..0aa2f27c4 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -29,11 +29,10 @@ import pyarrow as pa -from datafusion.catalog import Catalog, CatalogProvider, Table +from datafusion.catalog import Catalog from datafusion.dataframe import DataFrame -from datafusion.expr import SortKey, sort_list_to_raw_sort_list +from datafusion.expr import sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream -from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal from ._internal import SessionConfig as SessionConfigInternal @@ -48,7 +47,15 @@ import pandas as pd import polars as pl # type: ignore[import] + from datafusion.catalog import CatalogProvider, Table + from datafusion.expr import SortKey from datafusion.plan import ExecutionPlan, LogicalPlan + from datafusion.user_defined import ( + AggregateUDF, + ScalarUDF, + TableFunction, + WindowUDF, + ) class ArrowStreamExportable(Protocol): @@ -733,7 +740,7 @@ def from_polars(self, data: pl.DataFrame, name: str | None = None) -> DataFrame: # https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 # is the discussion on how we arrived at adding register_view def register_view(self, name: str, df: DataFrame) -> None: - """Register a :py:class: `~datafusion.detaframe.DataFrame` as a view. + """Register a :py:class:`~datafusion.dataframe.DataFrame` as a view. Args: name (str): The name to register the view under. @@ -742,16 +749,21 @@ def register_view(self, name: str, df: DataFrame) -> None: view = df.into_view() self.ctx.register_table(name, view) - def register_table(self, name: str, table: Table) -> None: - """Register a :py:class: `~datafusion.catalog.Table` as a table. + def register_table( + self, + name: str, + table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset, + ) -> None: + """Register a :py:class:`~datafusion.Table` with this context. - The registered table can be referenced from SQL statement executed against. + The registered table can be referenced from SQL statements executed against + this context. Args: name: Name of the resultant table. - table: DataFusion table to add to the session context. + table: Any object that can be converted into a :class:`Table`. """ - self.ctx.register_table(name, table.table) + self.ctx.register_table(name, table) def deregister_table(self, name: str) -> None: """Remove a table from the session.""" @@ -770,15 +782,17 @@ def register_catalog_provider( else: self.ctx.register_catalog_provider(name, provider) + @deprecated("Use register_table() instead.") def register_table_provider( - self, name: str, provider: TableProviderExportable + self, + name: str, + provider: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset, ) -> None: """Register a table provider. - This table provider must have a method called ``__datafusion_table_provider__`` - which returns a PyCapsule that exposes a ``FFI_TableProvider``. + Deprecated: use :meth:`register_table` instead. 
""" - self.ctx.register_table_provider(name, provider) + self.register_table(name, provider) def register_udtf(self, func: TableFunction) -> None: """Register a user defined table function.""" @@ -1163,14 +1177,11 @@ def read_avro( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) - def read_table(self, table: Table) -> DataFrame: - """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table. - - For a :py:class:`~datafusion.catalog.Table` such as a - :py:class:`~datafusion.catalog.ListingTable`, create a - :py:class:`~datafusion.dataframe.DataFrame`. - """ - return DataFrame(self.ctx.read_table(table.table)) + def read_table( + self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset + ) -> DataFrame: + """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table.""" + return DataFrame(self.ctx.read_table(table)) def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: """Execute the ``plan`` and return the results.""" diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c1b649e33..5a21d773b 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -60,6 +60,8 @@ import polars as pl import pyarrow as pa + from datafusion.catalog import Table + from enum import Enum @@ -313,9 +315,21 @@ def __init__(self, df: DataFrameInternal) -> None: """ self.df = df - def into_view(self) -> pa.Table: - """Convert DataFrame as a ViewTable which can be used in register_table.""" - return self.df.into_view() + def into_view(self) -> Table: + """Convert ``DataFrame`` into a :class:`~datafusion.Table`. + + Examples: + >>> from datafusion import SessionContext + >>> ctx = SessionContext() + >>> df = ctx.sql("SELECT 1 AS value") + >>> view = df.into_view() + >>> ctx.register_table("values_view", view) + >>> df.collect() # The DataFrame is still usable + >>> ctx.sql("SELECT value FROM values_view").collect() + """ + from datafusion.catalog import Table as _Table + + return _Table(self.df.into_view()) def __getitem__(self, key: str | list[str]) -> DataFrame: """Return a new :py:class`DataFrame` with the specified column or columns. 
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 5d1180bd1..82e30a78c 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -25,14 +25,12 @@ import typing as _typing from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Sequence -import pyarrow as pa - try: from warnings import deprecated # Python 3.13+ except ImportError: from typing_extensions import deprecated # Python 3.12 -from datafusion.common import NullTreatment +import pyarrow as pa from ._internal import expr as expr_internal from ._internal import functions as functions_internal @@ -40,8 +38,11 @@ if TYPE_CHECKING: from collections.abc import Sequence - # Type-only imports - from datafusion.common import DataTypeMap, RexType + from datafusion.common import ( # type: ignore[import] + DataTypeMap, + NullTreatment, + RexType, + ) from datafusion.plan import LogicalPlan diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 551e20a6f..67dbc730f 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -22,13 +22,13 @@ from typing import TYPE_CHECKING from datafusion.context import SessionContext -from datafusion.dataframe import DataFrame if TYPE_CHECKING: import pathlib import pyarrow as pa + from datafusion.dataframe import DataFrame from datafusion.expr import Expr diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 1f9ecbfc3..08f494dee 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -20,7 +20,7 @@ import pyarrow as pa import pyarrow.dataset as ds import pytest -from datafusion import SessionContext, Table +from datafusion import SessionContext, Table, udtf # Note we take in `database` as a variable even though we don't use @@ -53,7 +53,7 @@ def create_dataset() -> Table: names=["a", "b"], ) dataset = ds.dataset([batch]) - return Table.from_dataset(dataset) + return Table(dataset) class CustomSchemaProvider(dfn.catalog.SchemaProvider): @@ -164,6 +164,28 @@ def test_python_table_provider(ctx: SessionContext): assert schema.table_names() == {"table4"} +def test_schema_register_table_with_pyarrow_dataset(ctx: SessionContext): + schema = ctx.catalog().schema() + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + dataset = ds.dataset([batch]) + table_name = "pa_dataset" + + try: + schema.register_table(table_name, dataset) + assert table_name in schema.table_names() + + result = ctx.sql(f"SELECT a, b FROM {table_name}").collect() + + assert len(result) == 1 + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array([4, 5, 6]) + finally: + schema.deregister_table(table_name) + + def test_in_end_to_end_python_providers(ctx: SessionContext): """Test registering all python providers and running a query against them.""" @@ -210,3 +232,19 @@ def test_in_end_to_end_python_providers(ctx: SessionContext): assert len(batches) == 1 assert batches[0].column(0) == pa.array([1, 2, 3]) assert batches[0].column(1) == pa.array([4, 5, 6]) + + +def test_register_python_function_as_udtf(ctx: SessionContext): + basic_table = Table(ctx.sql("SELECT 3 AS value")) + + @udtf("my_table_function") + def my_table_function_udtf() -> Table: + return basic_table + + ctx.register_udtf(my_table_function_udtf) + + result = ctx.sql("SELECT * FROM my_table_function()").collect() + assert len(result) == 1 + assert len(result[0]) == 1 + assert len(result[0][0]) == 1 + assert result[0][0][0].as_py() == 3 diff --git 
a/python/tests/test_context.py b/python/tests/test_context.py index 6dbcc0d5e..94d1e6a39 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -27,6 +27,7 @@ SessionConfig, SessionContext, SQLOptions, + Table, column, literal, ) @@ -311,7 +312,7 @@ def test_register_table(ctx, database): assert public.names() == {"csv", "csv1", "csv2", "csv3"} -def test_read_table(ctx, database): +def test_read_table_from_catalog(ctx, database): default = ctx.catalog() public = default.schema("public") assert public.names() == {"csv", "csv1", "csv2"} @@ -321,6 +322,25 @@ def test_read_table(ctx, database): table_df.show() +def test_read_table_from_df(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + result = ctx.read_table(df).collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_read_table_from_dataset(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + dataset = ds.dataset([batch]) + + result = ctx.read_table(dataset).collect() + + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array([4, 5, 6]) + + def test_deregister_table(ctx, database): default = ctx.catalog() public = default.schema("public") @@ -330,6 +350,40 @@ def test_deregister_table(ctx, database): assert public.names() == {"csv1", "csv2"} +def test_register_table_from_dataframe(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + ctx.register_table("df_tbl", df) + result = ctx.sql("SELECT * FROM df_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_register_table_from_dataframe_into_view(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + table = df.into_view() + assert isinstance(table, Table) + ctx.register_table("view_tbl", table) + result = ctx.sql("SELECT * FROM view_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_table_from_dataframe(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + table = Table(df) + assert isinstance(table, Table) + ctx.register_table("from_dataframe_tbl", table) + result = ctx.sql("SELECT * FROM from_dataframe_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + +def test_table_from_dataframe_internal(ctx): + df = ctx.from_pydict({"a": [1, 2]}) + table = Table(df.df) + assert isinstance(table, Table) + ctx.register_table("from_internal_dataframe_tbl", table) + result = ctx.sql("SELECT * FROM from_internal_dataframe_tbl").collect() + assert [b.to_pydict() for b in result] == [{"a": [1, 2]}] + + def test_register_dataset(ctx): # create a RecordBatch and register it as a pyarrow.dataset.Dataset batch = pa.RecordBatch.from_arrays( diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py index f484cb282..cf6719ecf 100644 --- a/python/tests/test_wrapper_coverage.py +++ b/python/tests/test_wrapper_coverage.py @@ -28,7 +28,27 @@ from enum import EnumMeta as EnumType -def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 +def _check_enum_exports(internal_obj, wrapped_obj) -> None: + """Check that all enum values are present in wrapped object.""" + expected_values = [v for v in dir(internal_obj) if not v.startswith("__")] + for value in expected_values: + assert value in dir(wrapped_obj) + + +def _check_list_attribute(internal_attr, wrapped_attr) -> None: + """Check that list attributes match between internal and wrapped objects.""" + assert isinstance(wrapped_attr, list) + + # We have cases like __all__ that are a list and we want to be 
certain that + # every value in the list in the internal object is also in the wrapper list + for val in internal_attr: + if isinstance(val, str) and val.startswith("Raw"): + assert val[3:] in wrapped_attr + else: + assert val in wrapped_attr + + +def missing_exports(internal_obj, wrapped_obj) -> None: """ Identify if any of the rust exposted structs or functions do not have wrappers. @@ -40,9 +60,7 @@ def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 # Special case enums - EnumType overrides a some of the internal functions, # so check all of the values exist and move on if isinstance(wrapped_obj, EnumType): - expected_values = [v for v in dir(internal_obj) if not v.startswith("__")] - for value in expected_values: - assert value in dir(wrapped_obj) + _check_enum_exports(internal_obj, wrapped_obj) return if "__repr__" in internal_obj.__dict__ and "__repr__" not in wrapped_obj.__dict__: @@ -50,6 +68,7 @@ def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 for internal_attr_name in dir(internal_obj): wrapped_attr_name = internal_attr_name.removeprefix("Raw") + assert wrapped_attr_name in dir(wrapped_obj) internal_attr = getattr(internal_obj, internal_attr_name) @@ -66,15 +85,7 @@ def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 continue if isinstance(internal_attr, list): - assert isinstance(wrapped_attr, list) - - # We have cases like __all__ that are a list and we want to be certain that - # every value in the list in the internal object is also in the wrapper list - for val in internal_attr: - if isinstance(val, str) and val.startswith("Raw"): - assert val[3:] in wrapped_attr - else: - assert val in wrapped_attr + _check_list_attribute(internal_attr, wrapped_attr) elif hasattr(internal_attr, "__dict__"): # Check all submodules recursively missing_exports(internal_attr, wrapped_attr) diff --git a/src/catalog.rs b/src/catalog.rs index b5fa3da72..398c5881f 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -17,17 +17,16 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; +use crate::table::PyTable; use crate::utils::{validate_pycapsule, wait_for_future}; use async_trait::async_trait; use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; use datafusion::common::DataFusionError; use datafusion::{ - arrow::pyarrow::ToPyArrow, catalog::{CatalogProvider, SchemaProvider}, - datasource::{TableProvider, TableType}, + datasource::TableProvider, }; use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; use pyo3::types::PyCapsule; @@ -48,12 +47,6 @@ pub struct PySchema { pub schema: Arc, } -#[pyclass(frozen, name = "RawTable", module = "datafusion.catalog", subclass)] -#[derive(Clone)] -pub struct PyTable { - pub table: Arc, -} - impl From> for PyCatalog { fn from(catalog: Arc) -> Self { Self { catalog } @@ -66,16 +59,6 @@ impl From> for PySchema { } } -impl PyTable { - pub fn new(table: Arc) -> Self { - Self { table } - } - - pub fn table(&self) -> Arc { - self.table.clone() - } -} - #[pymethods] impl PyCatalog { #[new] @@ -181,7 +164,7 @@ impl PySchema { fn table(&self, name: &str, py: Python) -> PyDataFusionResult { if let Some(table) = wait_for_future(py, self.schema.table(name))?? 
{ - Ok(PyTable::new(table)) + Ok(PyTable::from(table)) } else { Err(PyDataFusionError::Common(format!( "Table not found: {name}" @@ -195,31 +178,12 @@ impl PySchema { Ok(format!("Schema(table_names=[{}])", names.join(";"))) } - fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { - let provider = if table_provider.hasattr("__datafusion_table_provider__")? { - let capsule = table_provider - .getattr("__datafusion_table_provider__")? - .call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - Arc::new(provider) as Arc - } else { - match table_provider.extract::() { - Ok(py_table) => py_table.table, - Err(_) => { - let py = table_provider.py(); - let provider = Dataset::new(&table_provider, py)?; - Arc::new(provider) as Arc - } - } - }; + fn register_table(&self, name: &str, table_provider: &Bound<'_, PyAny>) -> PyResult<()> { + let table = PyTable::new(table_provider)?; let _ = self .schema - .register_table(name.to_string(), provider) + .register_table(name.to_string(), table.table) .map_err(py_datafusion_err)?; Ok(()) @@ -235,43 +199,6 @@ impl PySchema { } } -#[pymethods] -impl PyTable { - /// Get a reference to the schema for this table - #[getter] - fn schema(&self, py: Python) -> PyResult { - self.table.schema().to_pyarrow(py) - } - - #[staticmethod] - fn from_dataset(py: Python<'_>, dataset: &Bound<'_, PyAny>) -> PyResult { - let ds = Arc::new(Dataset::new(dataset, py).map_err(py_datafusion_err)?) - as Arc; - - Ok(Self::new(ds)) - } - - /// Get the type of this table for metadata/catalog purposes. - #[getter] - fn kind(&self) -> &str { - match self.table.table_type() { - TableType::Base => "physical", - TableType::View => "view", - TableType::Temporary => "temporary", - } - } - - fn __repr__(&self) -> PyResult { - let kind = self.kind(); - Ok(format!("Table(kind={kind})")) - } - - // fn scan - // fn statistics - // fn has_exact_statistics - // fn supports_filter_pushdown -} - #[derive(Debug)] pub(crate) struct RustWrappedPySchemaProvider { schema_provider: PyObject, @@ -304,30 +231,9 @@ impl RustWrappedPySchemaProvider { return Ok(None); } - if py_table.hasattr("__datafusion_table_provider__")? 
{ - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - - Ok(Some(Arc::new(provider) as Arc)) - } else { - if let Ok(inner_table) = py_table.getattr("table") { - if let Ok(inner_table) = inner_table.extract::() { - return Ok(Some(inner_table.table)); - } - } + let table = PyTable::new(&py_table)?; - match py_table.extract::() { - Ok(py_table) => Ok(Some(py_table.table)), - Err(_) => { - let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; - Ok(Some(Arc::new(ds) as Arc)) - } - } - } + Ok(Some(table.table)) }) } } @@ -368,7 +274,7 @@ impl SchemaProvider for RustWrappedPySchemaProvider { name: String, table: Arc, ) -> datafusion::common::Result>> { - let py_table = PyTable::new(table); + let py_table = PyTable::from(table); Python::with_gil(|py| { let provider = self.schema_provider.bind(py); let _ = provider diff --git a/src/context.rs b/src/context.rs index e3f978ee1..dc18a7676 100644 --- a/src/context.rs +++ b/src/context.rs @@ -31,7 +31,7 @@ use uuid::Uuid; use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; -use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider}; +use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; @@ -41,6 +41,7 @@ use crate::record_batch::PyRecordBatchStream; use crate::sql::exceptions::py_value_err; use crate::sql::logical::PyLogicalPlan; use crate::store::StorageContexts; +use crate::table::PyTable; use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udtf::PyTableFunction; @@ -71,7 +72,6 @@ use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; @@ -417,12 +417,7 @@ impl PySessionContext { .with_listing_options(options) .with_schema(resolved_schema); let table = ListingTable::try_new(config)?; - self.register_table( - name, - &PyTable { - table: Arc::new(table), - }, - )?; + self.ctx.register_table(name, Arc::new(table))?; Ok(()) } @@ -599,8 +594,10 @@ impl PySessionContext { Ok(df) } - pub fn register_table(&self, name: &str, table: &PyTable) -> PyDataFusionResult<()> { - self.ctx.register_table(name, table.table())?; + pub fn register_table(&self, name: &str, table: Bound<'_, PyAny>) -> PyDataFusionResult<()> { + let table = PyTable::new(&table)?; + + self.ctx.register_table(name, table.table)?; Ok(()) } @@ -643,23 +640,8 @@ impl PySessionContext { name: &str, provider: Bound<'_, PyAny>, ) -> PyDataFusionResult<()> { - if provider.hasattr("__datafusion_table_provider__")? 
{ - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - - let _ = self.ctx.register_table(name, Arc::new(provider))?; - - Ok(()) - } else { - Err(crate::errors::PyDataFusionError::Common( - "__datafusion_table_provider__ does not exist on Table Provider object." - .to_string(), - )) - } + // Deprecated: use `register_table` instead + self.register_table(name, provider) } pub fn register_record_batches( @@ -1094,7 +1076,8 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } - pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult { + pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult { + let table = PyTable::new(&table)?; let df = self.ctx.read_table(table.table())?; Ok(PyDataFrame::new(df)) } diff --git a/src/dataframe.rs b/src/dataframe.rs index 555a8500d..bfdc35e13 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -31,12 +31,10 @@ use datafusion::arrow::util::pretty; use datafusion::common::UnnestOptions; use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; -use datafusion::datasource::TableProvider; use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; use datafusion::prelude::*; -use datafusion_ffi::table_provider::FFI_TableProvider; use futures::{StreamExt, TryStreamExt}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -44,12 +42,12 @@ use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods}; use tokio::task::JoinHandle; -use crate::catalog::PyTable; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError}; use crate::expr::sort_expr::to_sort_expressions; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; +use crate::table::PyTable; use crate::utils::{ get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, }; @@ -65,40 +63,6 @@ use parking_lot::Mutex; type CachedBatches = Option<(Vec, bool)>; type SharedCachedBatches = Arc>; -// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 -// - we have not decided on the table_provider approach yet -// this is an interim implementation -#[pyclass(frozen, name = "TableProvider", module = "datafusion")] -pub struct PyTableProvider { - provider: Arc, -} - -impl PyTableProvider { - pub fn new(provider: Arc) -> Self { - Self { provider } - } - - pub fn as_table(&self) -> PyTable { - let table_provider: Arc = self.provider.clone(); - PyTable::new(table_provider) - } -} - -#[pymethods] -impl PyTableProvider { - fn __datafusion_table_provider__<'py>( - &self, - py: Python<'py>, - ) -> PyResult> { - let name = CString::new("datafusion_table_provider").unwrap(); - - let runtime = get_tokio_runtime().0.handle().clone(); - let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime)); - - PyCapsule::new(py, provider, Some(name.clone())) - } -} - /// Configuration for DataFrame display formatting #[derive(Debug, Clone)] pub struct FormatterConfig { @@ -309,6 +273,11 @@ impl PyDataFrame { } } + /// 
Return a clone of the inner Arc for crate-local callers. + pub(crate) fn inner_df(&self) -> Arc { + Arc::clone(&self.df) + } + fn prepare_repr_string(&self, py: Python, as_html: bool) -> PyDataFusionResult { // Get the Python formatter and config let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; @@ -443,22 +412,18 @@ impl PyDataFrame { PyArrowType(self.df.schema().into()) } - /// Convert this DataFrame into a Table that can be used in register_table + /// Convert this DataFrame into a Table Provider that can be used in register_table /// By convention, into_... methods consume self and return the new object. /// Disabling the clippy lint, so we can use &self /// because we're working with Python bindings /// where objects are shared - /// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 - /// - we have not decided on the table_provider approach yet #[allow(clippy::wrong_self_convention)] - fn into_view(&self) -> PyDataFusionResult { + pub fn into_view(&self) -> PyDataFusionResult { // Call the underlying Rust DataFrame::into_view method. // Note that the Rust method consumes self; here we clone the inner Arc - // so that we don’t invalidate this PyDataFrame. + // so that we don't invalidate this PyDataFrame. let table_provider = self.df.as_ref().clone().into_view(); - let table_provider = PyTableProvider::new(table_provider); - - Ok(table_provider.as_table()) + Ok(PyTable::from(table_provider)) } #[pyo3(signature = (*args))] diff --git a/src/lib.rs b/src/lib.rs index 29d3f41da..0361c7315 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,7 @@ pub mod pyarrow_util; mod record_batch; pub mod sql; pub mod store; +pub mod table; pub mod unparser; #[cfg(feature = "substrait")] diff --git a/src/table.rs b/src/table.rs new file mode 100644 index 000000000..b830f7764 --- /dev/null +++ b/src/table.rs @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::pyarrow::ToPyArrow; +use datafusion::datasource::{TableProvider, TableType}; +use pyo3::prelude::*; +use std::sync::Arc; + +use crate::dataframe::PyDataFrame; +use crate::dataset::Dataset; +use crate::utils::table_provider_from_pycapsule; + +/// This struct is used as a common method for all TableProviders, +/// whether they refer to an FFI provider, an internally known +/// implementation, a dataset, or a dataframe view. +#[pyclass(frozen, name = "RawTable", module = "datafusion.catalog", subclass)] +#[derive(Clone)] +pub struct PyTable { + pub table: Arc, +} + +impl PyTable { + pub fn table(&self) -> Arc { + self.table.clone() + } +} + +#[pymethods] +impl PyTable { + /// Instantiate from any Python object that supports any of the table + /// types. 
We do not know a priori when using this method if the object + /// will be passed a wrapped or raw class. Here we handle all of the + /// following object types: + /// + /// - PyTable (essentially a clone operation), but either raw or wrapped + /// - DataFrame, either raw or wrapped + /// - FFI Table Providers via PyCapsule + /// - PyArrow Dataset objects + #[new] + pub fn new(obj: &Bound<'_, PyAny>) -> PyResult { + if let Ok(py_table) = obj.extract::() { + Ok(py_table) + } else if let Ok(py_table) = obj + .getattr("_inner") + .and_then(|inner| inner.extract::()) + { + Ok(py_table) + } else if let Ok(py_df) = obj.extract::() { + let provider = py_df.inner_df().as_ref().clone().into_view(); + Ok(PyTable::from(provider)) + } else if let Ok(py_df) = obj + .getattr("df") + .and_then(|inner| inner.extract::()) + { + let provider = py_df.inner_df().as_ref().clone().into_view(); + Ok(PyTable::from(provider)) + } else if let Some(provider) = table_provider_from_pycapsule(obj)? { + Ok(PyTable::from(provider)) + } else { + let py = obj.py(); + let provider = Arc::new(Dataset::new(obj, py)?) as Arc; + Ok(PyTable::from(provider)) + } + } + + /// Get a reference to the schema for this table + #[getter] + fn schema(&self, py: Python) -> PyResult { + self.table.schema().to_pyarrow(py) + } + + /// Get the type of this table for metadata/catalog purposes. + #[getter] + fn kind(&self) -> &str { + match self.table.table_type() { + TableType::Base => "physical", + TableType::View => "view", + TableType::Temporary => "temporary", + } + } + + fn __repr__(&self) -> PyResult { + let kind = self.kind(); + Ok(format!("Table(kind={kind})")) + } +} + +impl From> for PyTable { + fn from(table: Arc) -> Self { + Self { table } + } +} diff --git a/src/udtf.rs b/src/udtf.rs index 55f306b17..f6604e5bc 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -18,16 +18,14 @@ use pyo3::prelude::*; use std::sync::Arc; -use crate::dataframe::PyTableProvider; use crate::errors::{py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; +use crate::table::PyTable; use crate::utils::validate_pycapsule; use datafusion::catalog::{TableFunctionImpl, TableProvider}; use datafusion::error::Result as DataFusionResult; use datafusion::logical_expr::Expr; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction}; -use pyo3::exceptions::PyNotImplementedError; use pyo3::types::{PyCapsule, PyTuple}; /// Represents a user defined table function @@ -71,11 +69,11 @@ impl PyTableFunction { } #[pyo3(signature = (*args))] - pub fn __call__(&self, args: Vec) -> PyResult { + pub fn __call__(&self, args: Vec) -> PyResult { let args: Vec = args.iter().map(|e| e.expr.clone()).collect(); let table_provider = self.call(&args).map_err(py_datafusion_err)?; - Ok(PyTableProvider::new(table_provider)) + Ok(PyTable::from(table_provider)) } fn __repr__(&self) -> PyResult { @@ -99,20 +97,7 @@ fn call_python_table_function( let provider_obj = func.call1(py, py_args)?; let provider = provider_obj.bind(py); - if provider.hasattr("__datafusion_table_provider__")? 
{ - let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_provider")?; - - let provider = unsafe { capsule.reference::() }; - let provider: ForeignTableProvider = provider.into(); - - Ok(Arc::new(provider) as Arc) - } else { - Err(PyNotImplementedError::new_err( - "__datafusion_table_provider__ does not exist on Table Provider object.", - )) - } + Ok::, PyErr>(PyTable::new(provider)?.table) }) .map_err(to_datafusion_err) } diff --git a/src/utils.rs b/src/utils.rs index 3b30de5de..0fcfadcea 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -15,18 +15,26 @@ // specific language governing permissions and limitations // under the License. +use crate::errors::py_datafusion_err; use crate::{ common::data_type::PyScalarValue, errors::{PyDataFusionError, PyDataFusionResult}, TokioRuntime, }; use datafusion::{ - common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility, + common::ScalarValue, datasource::TableProvider, execution::context::SessionContext, + logical_expr::Volatility, }; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::prelude::*; use pyo3::{exceptions::PyValueError, types::PyCapsule}; -use std::{future::Future, sync::OnceLock, time::Duration}; +use std::{ + future::Future, + sync::{Arc, OnceLock}, + time::Duration, +}; use tokio::{runtime::Runtime, time::sleep}; + /// Utility to get the Tokio Runtime from Python #[inline] pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { @@ -91,7 +99,7 @@ pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult { "volatile" => Volatility::Volatile, value => { return Err(PyDataFusionError::Common(format!( - "Unsupportad volatility type: `{value}`, supported \ + "Unsupported volatility type: `{value}`, supported \ values are: immutable, stable and volatile." ))) } @@ -101,9 +109,9 @@ pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult { pub(crate) fn validate_pycapsule(capsule: &Bound, name: &str) -> PyResult<()> { let capsule_name = capsule.name()?; if capsule_name.is_none() { - return Err(PyValueError::new_err( - "Expected schema PyCapsule to have name set.", - )); + return Err(PyValueError::new_err(format!( + "Expected {name} PyCapsule to have name set." + ))); } let capsule_name = capsule_name.unwrap().to_str()?; @@ -116,6 +124,23 @@ pub(crate) fn validate_pycapsule(capsule: &Bound, name: &str) -> PyRe Ok(()) } +pub(crate) fn table_provider_from_pycapsule( + obj: &Bound, +) -> PyResult>> { + if obj.hasattr("__datafusion_table_provider__")? { + let capsule = obj.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + + Ok(Some(Arc::new(provider))) + } else { + Ok(None) + } +} + pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult { // convert Python object to PyScalarValue to ScalarValue