Skip to content

Commit

Permalink
ARROW-3903: [Python] Random array generator for Arrow conversion and …
Browse files Browse the repository at this point in the history
…Parquet testing

Generate random schemas, arrays, chunked_arrays, columns, record_batches and tables.
Slow, but makes quiet easy to isolate corner cases (already created jira issues). In follow up PRs We should use these strategies to increase the coverage. It'll enable us to reduce the issues, We could even use it for generate benchmark datasets periodically (only if We persist somewhere).

Example usage:

Run 10 samples (dev profile):
`pytest -sv pyarrow/tests/test_strategies.py::test_tables --enable-hypothesis --hypothesis-show-statistics --hypothesis-profile=dev`

Print the generated examples (debug):
`pytest -sv pyarrow/tests/test_strategies.py::test_schemas --enable-hypothesis --hypothesis-show-statistics --hypothesis-profile=debug`

Author: Krisztián Szűcs <szucs.krisztian@gmail.com>

Closes #3301 from kszucs/ARROW-3903 and squashes the following commits:

ff6654c <Krisztián Szűcs> finalize
8b5e7ea <Krisztián Szűcs> rat
61fe01d <Krisztián Szűcs> strategies for chunked_arrays, columns, record batches; test the strategies themselves
bdb63df <Krisztián Szűcs> hypothesis array strategy
  • Loading branch information
kszucs authored and xhochy committed Feb 8, 2019
1 parent 74f3f5f commit d06c664
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 18 deletions.
6 changes: 3 additions & 3 deletions python/pyarrow/table.pxi
Expand Up @@ -1155,9 +1155,9 @@ cdef class Table(_PandasConvertible):
Parameters
----------
arrays: list of pyarrow.Array or pyarrow.Column
arrays : list of pyarrow.Array or pyarrow.Column
Equal-length arrays that should form the table.
names: list of str, optional
names : list of str, optional
Names for the table columns. If Columns passed, will be
inferred. If Arrays passed, this argument is required
schema : Schema, default None
Expand Down Expand Up @@ -1224,7 +1224,7 @@ cdef class Table(_PandasConvertible):
Parameters
----------
batches: sequence or iterator of RecordBatch
batches : sequence or iterator of RecordBatch
Sequence of RecordBatch to be converted, all schemas must be equal
schema : Schema, default None
If not passed, will be inferred from the first RecordBatch
Expand Down
158 changes: 143 additions & 15 deletions python/pyarrow/tests/strategies.py
Expand Up @@ -15,8 +15,14 @@
# specific language governing permissions and limitations
# under the License.

import pyarrow as pa
import pytz
import hypothesis as h
import hypothesis.strategies as st
import hypothesis.extra.numpy as npst
import hypothesis.extra.pytz as tzst
import numpy as np

import pyarrow as pa


# TODO(kszucs): alphanum_text, surrogate_text
Expand Down Expand Up @@ -69,12 +75,11 @@
pa.time64('us'),
pa.time64('ns')
])
timestamp_types = st.sampled_from([
pa.timestamp('s'),
pa.timestamp('ms'),
pa.timestamp('us'),
pa.timestamp('ns')
])
timestamp_types = st.builds(
pa.timestamp,
unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
tz=tzst.timezones()
)
temporal_types = st.one_of(date_types, time_types, timestamp_types)

primitive_types = st.one_of(
Expand Down Expand Up @@ -106,20 +111,21 @@ def complex_types(inner_strategy=primitive_types):
return list_types(inner_strategy) | struct_types(inner_strategy)


def nested_list_types(item_strategy=primitive_types):
return st.recursive(item_strategy, list_types)
def nested_list_types(item_strategy=primitive_types, max_leaves=3):
return st.recursive(item_strategy, list_types, max_leaves=max_leaves)


def nested_struct_types(item_strategy=primitive_types):
return st.recursive(item_strategy, struct_types)
def nested_struct_types(item_strategy=primitive_types, max_leaves=3):
return st.recursive(item_strategy, struct_types, max_leaves=max_leaves)


def nested_complex_types(inner_strategy=primitive_types):
return st.recursive(inner_strategy, complex_types)
def nested_complex_types(inner_strategy=primitive_types, max_leaves=3):
return st.recursive(inner_strategy, complex_types, max_leaves=max_leaves)


def schemas(type_strategy=primitive_types):
return st.builds(pa.schema, st.lists(fields(type_strategy)))
def schemas(type_strategy=primitive_types, max_fields=None):
children = st.lists(fields(type_strategy), max_size=max_fields)
return st.builds(pa.schema, children)


complex_schemas = schemas(complex_types())
Expand All @@ -128,3 +134,125 @@ def schemas(type_strategy=primitive_types):
all_types = st.one_of(primitive_types, complex_types(), nested_complex_types())
all_fields = fields(all_types)
all_schemas = schemas(all_types)


_default_array_sizes = st.integers(min_value=0, max_value=20)


@st.composite
def arrays(draw, type, size=None):
if isinstance(type, st.SearchStrategy):
type = draw(type)
elif not isinstance(type, pa.DataType):
raise TypeError('Type must be a pyarrow DataType')

if isinstance(size, st.SearchStrategy):
size = draw(size)
elif size is None:
size = draw(_default_array_sizes)
elif not isinstance(size, int):
raise TypeError('Size must be an integer')

shape = (size,)

if pa.types.is_list(type):
offsets = draw(npst.arrays(np.uint8(), shape=shape)).cumsum() // 20
offsets = np.insert(offsets, 0, 0, axis=0) # prepend with zero
values = draw(arrays(type.value_type, size=int(offsets.sum())))
return pa.ListArray.from_arrays(offsets, values)

if pa.types.is_struct(type):
h.assume(len(type) > 0)
names, child_arrays = [], []
for field in type:
names.append(field.name)
child_arrays.append(draw(arrays(field.type, size=size)))
# fields' metadata are lost here, because from_arrays doesn't accept
# a fields argumentum, only names
return pa.StructArray.from_arrays(child_arrays, names=names)

if (pa.types.is_boolean(type) or pa.types.is_integer(type) or
pa.types.is_floating(type)):
values = npst.arrays(type.to_pandas_dtype(), shape=(size,))
return pa.array(draw(values), type=type)

if pa.types.is_null(type):
value = st.none()
elif pa.types.is_time(type):
value = st.times()
elif pa.types.is_date(type):
value = st.dates()
elif pa.types.is_timestamp(type):
tz = pytz.timezone(type.tz) if type.tz is not None else None
value = st.datetimes(timezones=st.just(tz))
elif pa.types.is_binary(type):
value = st.binary()
elif pa.types.is_string(type):
value = st.text()
elif pa.types.is_decimal(type):
# TODO(kszucs): properly limit the precision
# value = st.decimals(places=type.scale, allow_infinity=False)
h.reject()
else:
raise NotImplementedError(type)

values = st.lists(value, min_size=size, max_size=size)
return pa.array(draw(values), type=type)


@st.composite
def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None):
if isinstance(type, st.SearchStrategy):
type = draw(type)

# TODO(kszucs): remove it, field metadata is not kept
h.assume(not pa.types.is_struct(type))

chunk = arrays(type, size=chunk_size)
chunks = st.lists(chunk, min_size=min_chunks, max_size=max_chunks)

return pa.chunked_array(draw(chunks), type=type)


def columns(type, min_chunks=0, max_chunks=None, chunk_size=None):
chunked_array = chunked_arrays(type, chunk_size=chunk_size,
min_chunks=min_chunks,
max_chunks=max_chunks)
return st.builds(pa.column, st.text(), chunked_array)


@st.composite
def record_batches(draw, type, rows=None, max_fields=None):
if isinstance(rows, st.SearchStrategy):
rows = draw(rows)
elif rows is None:
rows = draw(_default_array_sizes)
elif not isinstance(rows, int):
raise TypeError('Rows must be an integer')

schema = draw(schemas(type, max_fields=max_fields))
children = [draw(arrays(field.type, size=rows)) for field in schema]
# TODO(kszucs): the names and schame arguments are not consistent with
# Table.from_array's arguments
return pa.RecordBatch.from_arrays(children, names=schema)


@st.composite
def tables(draw, type, rows=None, max_fields=None):
if isinstance(rows, st.SearchStrategy):
rows = draw(rows)
elif rows is None:
rows = draw(_default_array_sizes)
elif not isinstance(rows, int):
raise TypeError('Rows must be an integer')

schema = draw(schemas(type, max_fields=max_fields))
children = [draw(arrays(field.type, size=rows)) for field in schema]
return pa.Table.from_arrays(children, schema=schema)


all_arrays = arrays(all_types)
all_chunked_arrays = chunked_arrays(all_types)
all_columns = columns(all_types)
all_record_batches = record_batches(all_types)
all_tables = tables(all_types)
15 changes: 15 additions & 0 deletions python/pyarrow/tests/test_array.py
Expand Up @@ -18,6 +18,8 @@

import collections
import datetime
import hypothesis as h
import hypothesis.strategies as st
import pickle
import pytest
import struct
Expand All @@ -32,6 +34,7 @@
pickle5 = None

import pyarrow as pa
import pyarrow.tests.strategies as past
from pyarrow.pandas_compat import get_logical_type


Expand Down Expand Up @@ -802,6 +805,18 @@ def test_array_pickle(data, typ):
assert array.equals(result)


@h.given(
past.arrays(
past.all_types,
size=st.integers(min_value=0, max_value=10)
)
)
def test_pickling(arr):
data = pickle.dumps(arr)
restored = pickle.loads(data)
assert arr.equals(restored)


@pickle_test_parametrize
def test_array_pickle5(data, typ):
# Test zero-copy pickling with protocol 5 (PEP 574)
Expand Down
61 changes: 61 additions & 0 deletions python/pyarrow/tests/test_strategies.py
@@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import hypothesis as h

import pyarrow as pa
import pyarrow.tests.strategies as past


@h.given(past.all_types)
def test_types(ty):
assert isinstance(ty, pa.lib.DataType)


@h.given(past.all_fields)
def test_fields(field):
assert isinstance(field, pa.lib.Field)


@h.given(past.all_schemas)
def test_schemas(schema):
assert isinstance(schema, pa.lib.Schema)


@h.given(past.all_arrays)
def test_arrays(array):
assert isinstance(array, pa.lib.Array)


@h.given(past.all_chunked_arrays)
def test_chunked_arrays(chunked_array):
assert isinstance(chunked_array, pa.lib.ChunkedArray)


@h.given(past.all_columns)
def test_columns(column):
assert isinstance(column, pa.lib.Column)


@h.given(past.all_record_batches)
def test_record_batches(record_bath):
assert isinstance(record_bath, pa.lib.RecordBatch)


@h.given(past.all_tables)
def test_tables(table):
assert isinstance(table, pa.lib.Table)

0 comments on commit d06c664

Please sign in to comment.