Commit

ARROW-8292: [Python] Allow to manually specify schema in dataset() function

This just needed to be passed through to enable the functionality in `dataset()`

Closes #6788 from jorisvandenbossche/ARROW-8292

Authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
jorisvandenbossche authored and nealrichardson committed Apr 2, 2020
1 parent 67cd34a commit db81f0a
Showing 2 changed files with 56 additions and 3 deletions.
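
In short, the change lets callers pass a Schema to dataset() up front rather than having one inferred from the files. A minimal usage sketch (assuming a local "data.parquet" file written beforehand; the pattern mirrors the test added below):

    import pyarrow as pa
    import pyarrow.dataset as ds

    # Provide the schema explicitly; dataset() uses it as-is instead of
    # inferring one from the Parquet file.
    schema = pa.schema([('a', pa.int64()), ('b', pa.float64())])
    dataset = ds.dataset("data.parquet", schema=schema)
    assert dataset.schema.equals(schema)  # no inference performed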
9 changes: 6 additions & 3 deletions python/pyarrow/dataset.py
@@ -298,7 +298,7 @@ def _ensure_factory(src, **kwargs):
 
 
 def dataset(paths_or_factories, filesystem=None, partitioning=None,
-            format=None):
+            format=None, schema=None):
     """
     Open a dataset.
 
@@ -317,6 +317,9 @@ def dataset(paths_or_factories, filesystem=None, partitioning=None,
         field names a DirectoryPartitioning will be inferred.
     format : str
        Currently only "parquet" is supported.
+    schema : Schema, optional
+        Optionally provide the Schema for the Dataset, in which case it will
+        not be inferred from the source.
 
     Returns
     -------
@@ -347,8 +350,8 @@ def dataset(paths_or_factories, filesystem=None, partitioning=None,
 
     factories = [_ensure_factory(f, **kwargs) for f in paths_or_factories]
     if single_dataset:
-        return factories[0].finish()
-    return UnionDatasetFactory(factories).finish()
+        return factories[0].finish(schema=schema)
+    return UnionDatasetFactory(factories).finish(schema=schema)
 
 
 def field(name):
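
As the last hunk shows, the schema is simply forwarded to the factory's finish() call. Per the tests below, it then also acts as a read-time projection: columns can be reordered or omitted relative to the file, and columns absent from the file come back as nulls. A sketch of that behavior, assuming the same "data.parquet" file as in the example above:

    import pyarrow as pa
    import pyarrow.dataset as ds

    # Schema keeping only column 'a' and adding a column 'c' that is not in
    # the file: 'b' is dropped on read and 'c' is filled with nulls.
    schema = pa.schema([('a', pa.int64()), ('c', pa.int32())])
    dataset = ds.dataset("data.parquet", schema=schema)
    table = dataset.to_table()  # columns: a (from the file), c (all null)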
50 changes: 50 additions & 0 deletions python/pyarrow/tests/test_dataset.py
@@ -1002,6 +1002,56 @@ def test_multiple_factories_with_selectors(multisourcefs):
     assert dataset.schema.equals(expected_schema)
 
 
+@pytest.mark.parquet
+def test_specified_schema(tempdir):
+    import pyarrow.parquet as pq
+
+    table = pa.table({'a': [1, 2, 3], 'b': [.1, .2, .3]})
+    pq.write_table(table, tempdir / "data.parquet")
+
+    def _check_dataset(schema, expected, expected_schema=None):
+        dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema)
+        if expected_schema is not None:
+            assert dataset.schema.equals(expected_schema)
+        else:
+            assert dataset.schema.equals(schema)
+        result = dataset.to_table()
+        assert result.equals(expected)
+
+    # no schema specified
+    schema = None
+    expected = table
+    _check_dataset(schema, expected, expected_schema=table.schema)
+
+    # identical schema specified
+    schema = table.schema
+    expected = table
+    _check_dataset(schema, expected)
+
+    # Specifying schema with changed column order
+    schema = pa.schema([('b', 'float64'), ('a', 'int64')])
+    expected = pa.table([[.1, .2, .3], [1, 2, 3]], names=['b', 'a'])
+    _check_dataset(schema, expected)
+
+    # Specifying schema with missing column
+    schema = pa.schema([('a', 'int64')])
+    expected = pa.table({'a': [1, 2, 3]})
+    _check_dataset(schema, expected)
+
+    # Specifying schema with additional column
+    schema = pa.schema([('a', 'int64'), ('c', 'int32')])
+    expected = pa.table({'a': [1, 2, 3],
+                         'c': pa.array([None, None, None], type='int32')})
+    _check_dataset(schema, expected)
+
+    # Specifying an incompatible schema
+    schema = pa.schema([('a', 'int32'), ('b', 'float64')])
+    dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema)
+    assert dataset.schema.equals(schema)
+    with pytest.raises(TypeError):
+        dataset.to_table()
+
+
 def test_ipc_format(tempdir):
     table = pa.table({'a': pa.array([1, 2, 3], type="int8"),
                       'b': pa.array([.1, .2, .3], type="float64")})
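
Note the final case in the test above: an incompatible schema is accepted when the dataset is constructed, and the type mismatch only surfaces once the data is materialized. A small sketch of that failure mode, under the same assumptions as the earlier examples:

    import pyarrow as pa
    import pyarrow.dataset as ds

    # The file stores 'a' as int64; declaring it int32 is only caught on read.
    schema = pa.schema([('a', pa.int32()), ('b', pa.float64())])
    dataset = ds.dataset("data.parquet", schema=schema)  # constructing succeeds
    try:
        dataset.to_table()
    except TypeError:
        print("schema incompatible with the file's types")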
