From 15c3cf9320cc49d3c4f9ea26018ab5dc2627c2e0 Mon Sep 17 00:00:00 2001 From: "Korn, Uwe" Date: Sun, 21 Oct 2018 16:29:30 +0200 Subject: [PATCH] =?UTF-8?q?ARROW-3583:=20[Python/Java]=C2=A0Create=20Recor?= =?UTF-8?q?dBatch=20from=20VectorSchemaRoot?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/pyarrow/jvm.py | 25 +++++ python/pyarrow/tests/test_jvm.py | 165 +++++++++++++++++++++++++++++-- 2 files changed, 181 insertions(+), 9 deletions(-) diff --git a/python/pyarrow/jvm.py b/python/pyarrow/jvm.py index 4a43493865d..28b72711764 100644 --- a/python/pyarrow/jvm.py +++ b/python/pyarrow/jvm.py @@ -276,3 +276,28 @@ def array(jvm_array): for buf in list(jvm_array.getBuffers(False))] null_count = jvm_array.getNullCount() return pa.Array.from_buffers(dtype, length, buffers, null_count) + + +def record_batch(jvm_vector_schema_root): + """ + Construct a (Python) RecordBatch from a JVM VectorSchemaRoot + + Parameters + ---------- + jvm_vector_schema_root : org.apache.arrow.vector.VectorSchemaRoot + + Returns + ------- + record_batch: pyarrow.RecordBatch + """ + pa_schema = schema(jvm_vector_schema_root.getSchema()) + + arrays = [] + for name in pa_schema.names: + arrays.append(array(jvm_vector_schema_root.getVector(name))) + + return pa.RecordBatch.from_arrays( + arrays, + pa_schema.names, + metadata=pa_schema.metadata + ) diff --git a/python/pyarrow/tests/test_jvm.py b/python/pyarrow/tests/test_jvm.py index 3ca874e5631..e0d507ce4b9 100644 --- a/python/pyarrow/tests/test_jvm.py +++ b/python/pyarrow/tests/test_jvm.py @@ -93,7 +93,7 @@ def _jvm_schema(jvm_spec, metadata=None): # om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')() # field = … # Code to instantiate the field # jvm_spec = om.writeValueAsString(field) -@pytest.mark.parametrize('typ,jvm_spec', [ +@pytest.mark.parametrize('pa_type,jvm_spec', [ (pa.null(), '{"name":"null"}'), (pa.bool_(), '{"name":"bool"}'), (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'), @@ -142,7 +142,7 @@ def _jvm_schema(jvm_spec, metadata=None): # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])), ]) @pytest.mark.parametrize('nullable', [True, False]) -def test_jvm_types(root_allocator, typ, jvm_spec, nullable): +def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable): spec = { 'name': 'field_name', 'nullable': nullable, @@ -152,7 +152,7 @@ def test_jvm_types(root_allocator, typ, jvm_spec, nullable): } jvm_field = _jvm_field(json.dumps(spec)) result = pa_jvm.field(jvm_field) - expected_field = pa.field('field_name', typ, nullable=nullable) + expected_field = pa.field('field_name', pa_type, nullable=nullable) assert result == expected_field jvm_schema = _jvm_schema(json.dumps(spec)) @@ -168,7 +168,7 @@ def test_jvm_types(root_allocator, typ, jvm_spec, nullable): # These test parameters mostly use an integer range as an input as this is # often the only type that is understood by both Python and Java # implementations of Arrow. 
-@pytest.mark.parametrize('typ,data,jvm_type', [ +@pytest.mark.parametrize('pa_type,py_data,jvm_type', [ (pa.bool_(), [True, False, True, True], 'BitVector'), (pa.uint8(), list(range(128)), 'UInt1Vector'), (pa.uint16(), list(range(128)), 'UInt2Vector'), @@ -189,21 +189,168 @@ def test_jvm_types(root_allocator, typ, jvm_spec, nullable): (pa.date64(), list(range(128)), 'DateMilliVector'), # TODO(ARROW-2606): pa.decimal128(19, 4) ]) -def test_jvm_array(root_allocator, typ, data, jvm_type): +def test_jvm_array(root_allocator, pa_type, py_data, jvm_type): # Create vector cls = "org.apache.arrow.vector.{}".format(jvm_type) jvm_vector = jpype.JClass(cls)("vector", root_allocator) - jvm_vector.allocateNew(len(data)) - for i, val in enumerate(data): + jvm_vector.allocateNew(len(py_data)) + for i, val in enumerate(py_data): jvm_vector.setSafe(i, val) - jvm_vector.setValueCount(len(data)) + jvm_vector.setValueCount(len(py_data)) - py_array = pa.array(data, type=typ) + py_array = pa.array(py_data, type=pa_type) jvm_array = pa_jvm.array(jvm_vector) assert py_array.equals(jvm_array) +# These test parameters mostly use an integer range as an input as this is +# often the only type that is understood by both Python and Java +# implementations of Arrow. +@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [ + # TODO: null + (pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'), + ( + pa.uint8(), + list(range(128)), + 'UInt1Vector', + '{"name":"int","bitWidth":8,"isSigned":false}' + ), + ( + pa.uint16(), + list(range(128)), + 'UInt2Vector', + '{"name":"int","bitWidth":16,"isSigned":false}' + ), + ( + pa.uint32(), + list(range(128)), + 'UInt4Vector', + '{"name":"int","bitWidth":32,"isSigned":false}' + ), + ( + pa.uint64(), + list(range(128)), + 'UInt8Vector', + '{"name":"int","bitWidth":64,"isSigned":false}' + ), + ( + pa.int8(), + list(range(128)), + 'TinyIntVector', + '{"name":"int","bitWidth":8,"isSigned":true}' + ), + ( + pa.int16(), + list(range(128)), + 'SmallIntVector', + '{"name":"int","bitWidth":16,"isSigned":true}' + ), + ( + pa.int32(), + list(range(128)), + 'IntVector', + '{"name":"int","bitWidth":32,"isSigned":true}' + ), + ( + pa.int64(), + list(range(128)), + 'BigIntVector', + '{"name":"int","bitWidth":64,"isSigned":true}' + ), + # TODO: float16 + ( + pa.float32(), + list(range(128)), + 'Float4Vector', + '{"name":"floatingpoint","precision":"SINGLE"}' + ), + ( + pa.float64(), + list(range(128)), + 'Float8Vector', + '{"name":"floatingpoint","precision":"DOUBLE"}' + ), + ( + pa.timestamp('s'), + list(range(128)), + 'TimeStampSecVector', + '{"name":"timestamp","unit":"SECOND","timezone":null}' + ), + ( + pa.timestamp('ms'), + list(range(128)), + 'TimeStampMilliVector', + '{"name":"timestamp","unit":"MILLISECOND","timezone":null}' + ), + ( + pa.timestamp('us'), + list(range(128)), + 'TimeStampMicroVector', + '{"name":"timestamp","unit":"MICROSECOND","timezone":null}' + ), + ( + pa.timestamp('ns'), + list(range(128)), + 'TimeStampNanoVector', + '{"name":"timestamp","unit":"NANOSECOND","timezone":null}' + ), + # TODO(ARROW-2605): These types miss a conversion from pure Python objects + # * pa.time32('s') + # * pa.time32('ms') + # * pa.time64('us') + # * pa.time64('ns') + ( + pa.date32(), + list(range(128)), + 'DateDayVector', + '{"name":"date","unit":"DAY"}' + ), + ( + pa.date64(), + list(range(128)), + 'DateMilliVector', + '{"name":"date","unit":"MILLISECOND"}' + ), + # TODO(ARROW-2606): pa.decimal128(19, 4) +]) +def test_jvm_record_batch(root_allocator, pa_type, 
py_data, jvm_type, + jvm_spec): + # Create vector + cls = "org.apache.arrow.vector.{}".format(jvm_type) + jvm_vector = jpype.JClass(cls)("vector", root_allocator) + jvm_vector.allocateNew(len(py_data)) + for i, val in enumerate(py_data): + jvm_vector.setSafe(i, val) + jvm_vector.setValueCount(len(py_data)) + + # Create field + spec = { + 'name': 'field_name', + 'nullable': False, + 'type': json.loads(jvm_spec), + # TODO: This needs to be set for complex types + 'children': [] + } + jvm_field = _jvm_field(json.dumps(spec)) + + # Create VectorSchemaRoot + jvm_fields = jpype.JClass('java.util.ArrayList')() + jvm_fields.add(jvm_field) + jvm_vectors = jpype.JClass('java.util.ArrayList')() + jvm_vectors.add(jvm_vector) + jvm_vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot') + jvm_vsr = jvm_vsr(jvm_fields, jvm_vectors, len(py_data)) + + py_record_batch = pa.RecordBatch.from_arrays( + [pa.array(py_data, type=pa_type)], + ['col'] + ) + jvm_record_batch = pa_jvm.record_batch(jvm_vsr) + + assert py_record_batch.equals(jvm_record_batch) + + def _string_to_varchar_holder(ra, string): nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder" holder = jpype.JClass(nvch_cls)()
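
For illustration, a minimal usage sketch of the new pyarrow.jvm.record_batch helper, following the pattern of test_jvm_record_batch above. It assumes the JVM has already been started through jpype with the Arrow Java libraries on the class path, `allocator` stands in for the root_allocator fixture from test_jvm.py, and the Field is taken from ValueVector.getField() rather than the JSON spec the test uses:

    import jpype
    import pyarrow.jvm as pa_jvm

    # Sketch assumptions: jpype.startJVM(...) was already called with the
    # Arrow Java jars on the class path, and `allocator` is an
    # org.apache.arrow.memory.RootAllocator (what the root_allocator fixture
    # in test_jvm.py provides).

    # Fill a single int32 vector on the JVM side.
    jvm_vector = jpype.JClass('org.apache.arrow.vector.IntVector')(
        "vector", allocator)
    jvm_vector.allocateNew(3)
    for i, val in enumerate([1, 2, 3]):
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(3)

    # Wrap it in a VectorSchemaRoot; getField() is the Arrow Java ValueVector
    # API, whereas the test constructs the Field from a JSON spec.
    jvm_fields = jpype.JClass('java.util.ArrayList')()
    jvm_fields.add(jvm_vector.getField())
    jvm_vectors = jpype.JClass('java.util.ArrayList')()
    jvm_vectors.add(jvm_vector)
    jvm_vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')(
        jvm_fields, jvm_vectors, 3)

    # Convert only the metadata; the resulting pyarrow.RecordBatch is backed
    # by the JVM-allocated buffers rather than a copy.
    record_batch = pa_jvm.record_batch(jvm_vsr)
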