GH-34884: [Python]: Support pickling pyarrow.dataset Partitioning subclasses (#36462)

### Rationale for this change

Add support for pickling Directory/Hive/FilenamePartitioning objects.

This does not yet fully fix #34884, because this PR only addresses the concrete Partitioning subclasses, not the PartitioningFactory subclasses.
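
For illustration, here is a minimal sketch of the round-trip this change enables (the schema and field names below are arbitrary, not taken from the diff):

```python
import pickle

import pyarrow as pa
import pyarrow.dataset as ds

# An arbitrary two-field schema, purely for illustration.
schema = pa.schema([("year", pa.int64()), ("month", pa.int64())])
partitioning = ds.HivePartitioning(schema, null_fallback="xyz")

# With this change the object survives pickling, and the restored copy
# compares equal to the original via the new Equals-backed __eq__.
restored = pickle.loads(pickle.dumps(partitioning))
assert restored == partitioning
```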

### Are these changes tested?

Yes

### Are there any user-facing changes?

Only new support for pickling and the `==` operator (illustrated below).
* Issue: #34884
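
A short sketch of the new `==` behavior (the schema is illustrative); note that comparing against a non-Partitioning object returns False instead of raising:

```python
import pyarrow as pa
import pyarrow.dataset as ds

schema = pa.schema([("key", pa.string())])

# Identically constructed partitionings now compare equal.
assert ds.DirectoryPartitioning(schema) == ds.DirectoryPartitioning(schema)

# Differing options break equality (the default segment_encoding is "uri").
assert ds.DirectoryPartitioning(schema) != ds.DirectoryPartitioning(
    schema, segment_encoding="none")

# A non-Partitioning operand is never equal; __eq__ catches the TypeError.
assert ds.DirectoryPartitioning(schema) != "not a partitioning"
```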

Lead-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
jorisvandenbossche and westonpace committed Jul 7, 2023
1 parent 375f3d9 commit dd36705
Showing 4 changed files with 97 additions and 31 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/dataset/partition.h
@@ -187,6 +187,8 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning {
 
   const ArrayVector& dictionaries() const { return dictionaries_; }
 
+  SegmentEncoding segment_encoding() const { return options_.segment_encoding; }
+
   bool Equals(const Partitioning& other) const override;
 
  protected:
40 changes: 38 additions & 2 deletions python/pyarrow/_dataset.pyx
@@ -131,12 +131,20 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=None):
 
 cdef CSegmentEncoding _get_segment_encoding(str segment_encoding):
     if segment_encoding == "none":
-        return CSegmentEncodingNone
+        return CSegmentEncoding_None
     elif segment_encoding == "uri":
-        return CSegmentEncodingUri
+        return CSegmentEncoding_Uri
     raise ValueError(f"Unknown segment encoding: {segment_encoding}")
 
 
+cdef str _wrap_segment_encoding(CSegmentEncoding segment_encoding):
+    if segment_encoding == CSegmentEncoding_None:
+        return "none"
+    elif segment_encoding == CSegmentEncoding_Uri:
+        return "uri"
+    raise ValueError("Unknown segment encoding")
+
+
 cdef Expression _true = Expression._scalar(True)

@@ -2339,6 +2347,12 @@ cdef class Partitioning(_Weakrefable):
     cdef inline shared_ptr[CPartitioning] unwrap(self):
         return self.wrapped
 
+    def __eq__(self, other):
+        try:
+            return self.partitioning.Equals(deref((<Partitioning>other).unwrap()))
+        except TypeError:
+            return False
+
     def parse(self, path):
         cdef CResult[CExpression] result
         result = self.partitioning.Parse(tobytes(path))
@@ -2393,6 +2407,7 @@ cdef vector[shared_ptr[CArray]] _partitioning_dictionaries(
 
     return c_dictionaries
 
+
 cdef class KeyValuePartitioning(Partitioning):
 
     cdef:
@@ -2407,6 +2422,15 @@ cdef class KeyValuePartitioning(Partitioning):
         self.wrapped = sp
         self.partitioning = sp.get()
 
+    def __reduce__(self):
+        dictionaries = self.dictionaries
+        if dictionaries:
+            dictionaries = dict(zip(self.schema.names, dictionaries))
+        segment_encoding = _wrap_segment_encoding(
+            deref(self.keyvalue_partitioning).segment_encoding()
+        )
+        return self.__class__, (self.schema, dictionaries, segment_encoding)
+
     @property
     def dictionaries(self):
         """
@@ -2620,6 +2644,18 @@ cdef class HivePartitioning(KeyValuePartitioning):
         KeyValuePartitioning.init(self, sp)
         self.hive_partitioning = <CHivePartitioning*> sp.get()
 
+    def __reduce__(self):
+        dictionaries = self.dictionaries
+        if dictionaries:
+            dictionaries = dict(zip(self.schema.names, dictionaries))
+        segment_encoding = _wrap_segment_encoding(
+            deref(self.keyvalue_partitioning).segment_encoding()
+        )
+        null_fallback = frombytes(deref(self.hive_partitioning).null_fallback())
+        return HivePartitioning, (
+            self.schema, dictionaries, null_fallback, segment_encoding
+        )
+
     @staticmethod
     def discover(infer_dictionary=False,
                  max_partition_dictionary_size=0,
9 changes: 6 additions & 3 deletions python/pyarrow/includes/libarrow_dataset.pxd
@@ -288,13 +288,14 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
         c_string type_name() const
         CResult[CExpression] Parse(const c_string & path) const
         const shared_ptr[CSchema] & schema()
+        c_bool Equals(const CPartitioning& other) const
 
     cdef cppclass CSegmentEncoding" arrow::dataset::SegmentEncoding":
-        pass
+        bint operator==(CSegmentEncoding)
 
-    CSegmentEncoding CSegmentEncodingNone\
+    CSegmentEncoding CSegmentEncoding_None\
         " arrow::dataset::SegmentEncoding::None"
-    CSegmentEncoding CSegmentEncodingUri\
+    CSegmentEncoding CSegmentEncoding_Uri\
         " arrow::dataset::SegmentEncoding::Uri"
 
     cdef cppclass CKeyValuePartitioningOptions \
@@ -329,6 +330,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
                               CKeyValuePartitioningOptions options)
 
         vector[shared_ptr[CArray]] dictionaries() const
+        CSegmentEncoding segment_encoding()
 
     cdef cppclass CDirectoryPartitioning \
         "arrow::dataset::DirectoryPartitioning"(CPartitioning):
@@ -352,6 +354,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
                              CHivePartitioningFactoryOptions)
 
         vector[shared_ptr[CArray]] dictionaries() const
+        c_string null_fallback() const
 
     cdef cppclass CFilenamePartitioning \
         "arrow::dataset::FilenamePartitioning"(CPartitioning):
77 changes: 51 additions & 26 deletions python/pyarrow/tests/test_dataset.py
@@ -588,13 +588,13 @@ def test_partitioning():
                   ds.FilenamePartitioning]:
         partitioning = klass(schema)
         assert isinstance(partitioning, ds.Partitioning)
+        assert partitioning == klass(schema)
 
-    partitioning = ds.DirectoryPartitioning(
-        pa.schema([
-            pa.field('group', pa.int64()),
-            pa.field('key', pa.float64())
-        ])
-    )
+    schema = pa.schema([
+        pa.field('group', pa.int64()),
+        pa.field('key', pa.float64())
+    ])
+    partitioning = ds.DirectoryPartitioning(schema)
     assert len(partitioning.dictionaries) == 2
     assert all(x is None for x in partitioning.dictionaries)
     expr = partitioning.parse('/3/3.14/')
@@ -610,13 +610,13 @@ def test_partitioning():
     expected = ds.field('group') == 3
     assert expr.equals(expected)
 
-    partitioning = ds.HivePartitioning(
-        pa.schema([
-            pa.field('alpha', pa.int64()),
-            pa.field('beta', pa.int64())
-        ]),
-        null_fallback='xyz'
-    )
+    assert partitioning != ds.DirectoryPartitioning(schema, segment_encoding="none")
+
+    schema = pa.schema([
+        pa.field('alpha', pa.int64()),
+        pa.field('beta', pa.int64())
+    ])
+    partitioning = ds.HivePartitioning(schema, null_fallback='xyz')
     assert len(partitioning.dictionaries) == 2
     assert all(x is None for x in partitioning.dictionaries)
     expr = partitioning.parse('/alpha=0/beta=3/')
@@ -636,12 +636,13 @@ def test_partitioning():
     with pytest.raises(pa.ArrowInvalid):
         partitioning.parse(shouldfail)
 
-    partitioning = ds.FilenamePartitioning(
-        pa.schema([
-            pa.field('group', pa.int64()),
-            pa.field('key', pa.float64())
-        ])
-    )
+    assert partitioning != ds.HivePartitioning(schema, null_fallback='other')
+
+    schema = pa.schema([
+        pa.field('group', pa.int64()),
+        pa.field('key', pa.float64())
+    ])
+    partitioning = ds.FilenamePartitioning(schema)
     assert len(partitioning.dictionaries) == 2
     assert all(x is None for x in partitioning.dictionaries)
     expr = partitioning.parse('3_3.14_')
@@ -653,17 +654,19 @@ def test_partitioning():
     with pytest.raises(pa.ArrowInvalid):
         partitioning.parse('prefix_3_aaa_')
 
+    assert partitioning != ds.FilenamePartitioning(schema, segment_encoding="none")
+
+    schema = pa.schema([
+        pa.field('group', pa.int64()),
+        pa.field('key', pa.dictionary(pa.int8(), pa.string()))
+    ])
     partitioning = ds.DirectoryPartitioning(
-        pa.schema([
-            pa.field('group', pa.int64()),
-            pa.field('key', pa.dictionary(pa.int8(), pa.string()))
-        ]),
-        dictionaries={
-            "key": pa.array(["first", "second", "third"]),
-        })
+        schema, dictionaries={"key": pa.array(["first", "second", "third"])}
+    )
     assert partitioning.dictionaries[0] is None
     assert partitioning.dictionaries[1].to_pylist() == [
         "first", "second", "third"]
+    assert partitioning != ds.DirectoryPartitioning(schema, dictionaries=None)
 
     partitioning = ds.FilenamePartitioning(
         pa.schema([
@@ -696,6 +699,24 @@ def test_partitioning():
     assert load_back_table.equals(table)
 
 
+def test_partitioning_pickling():
+    schema = pa.schema([
+        pa.field('i64', pa.int64()),
+        pa.field('f64', pa.float64())
+    ])
+    parts = [
+        ds.DirectoryPartitioning(schema),
+        ds.HivePartitioning(schema),
+        ds.FilenamePartitioning(schema),
+        ds.DirectoryPartitioning(schema, segment_encoding="none"),
+        ds.FilenamePartitioning(schema, segment_encoding="none"),
+        ds.HivePartitioning(schema, segment_encoding="none", null_fallback="xyz"),
+    ]
+
+    for part in parts:
+        assert pickle.loads(pickle.dumps(part)) == part
+
+
 def test_expression_arithmetic_operators():
     dataset = ds.dataset(pa.table({'a': [1, 2, 3], 'b': [2, 2, 2]}))
     a = ds.field("a")
@@ -3740,6 +3761,10 @@ def test_dataset_preserved_partitioning(tempdir):
     _, path = _create_single_file(tempdir)
     dataset = ds.dataset(path)
     assert isinstance(dataset.partitioning, ds.DirectoryPartitioning)
+    # TODO(GH-34884) partitioning attribute not preserved in pickling
+    # dataset_ = ds.dataset(path)
+    # for dataset in [dataset_, pickle.loads(pickle.dumps(dataset_))]:
+    #     assert isinstance(dataset.partitioning, ds.DirectoryPartitioning)
 
     # through discovery, with hive partitioning but not specified
     full_table, path = _create_partitioned_dataset(tempdir)
