25 changes: 25 additions & 0 deletions python/pyarrow/jvm.py
@@ -276,3 +276,28 @@ def array(jvm_array):
for buf in list(jvm_array.getBuffers(False))]
null_count = jvm_array.getNullCount()
return pa.Array.from_buffers(dtype, length, buffers, null_count)


def record_batch(jvm_vector_schema_root):
"""
Construct a (Python) RecordBatch from a JVM VectorSchemaRoot

Parameters
----------
jvm_vector_schema_root : org.apache.arrow.vector.VectorSchemaRoot

Returns
-------
record_batch : pyarrow.RecordBatch
"""
pa_schema = schema(jvm_vector_schema_root.getSchema())

arrays = []
for name in pa_schema.names:
arrays.append(array(jvm_vector_schema_root.getVector(name)))

return pa.RecordBatch.from_arrays(
arrays,
pa_schema.names,
metadata=pa_schema.metadata
)
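
# A minimal usage sketch for the new helper (not part of the diff). It assumes a
# JVM has already been started via jpype with the Arrow Java jars on the class
# path; the allocator limit, vector name and values below are illustrative only.
import jpype
import pyarrow.jvm as pa_jvm

allocator = jpype.JClass('org.apache.arrow.memory.RootAllocator')(2 ** 20)

# Build a small Java-side IntVector, mirroring the pattern used in the tests below.
vector = jpype.JClass('org.apache.arrow.vector.IntVector')("ints", allocator)
vector.allocateNew(3)
for i, val in enumerate([1, 2, 3]):
    vector.setSafe(i, val)
vector.setValueCount(3)

# Wrap the vector in a VectorSchemaRoot and convert it to a pyarrow.RecordBatch.
fields = jpype.JClass('java.util.ArrayList')()
fields.add(vector.getField())
vectors = jpype.JClass('java.util.ArrayList')()
vectors.add(vector)
vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')(fields, vectors, 3)

batch = pa_jvm.record_batch(vsr)
print(batch.num_rows, batch.schema)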
165 changes: 156 additions & 9 deletions python/pyarrow/tests/test_jvm.py
@@ -93,7 +93,7 @@ def _jvm_schema(jvm_spec, metadata=None):
# om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
# field = … # Code to instantiate the field
# jvm_spec = om.writeValueAsString(field)
@pytest.mark.parametrize('typ,jvm_spec', [
@pytest.mark.parametrize('pa_type,jvm_spec', [
(pa.null(), '{"name":"null"}'),
(pa.bool_(), '{"name":"bool"}'),
(pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
@@ -142,7 +142,7 @@ def _jvm_schema(jvm_spec, metadata=None):
# pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
])
@pytest.mark.parametrize('nullable', [True, False])
def test_jvm_types(root_allocator, typ, jvm_spec, nullable):
def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable):
spec = {
'name': 'field_name',
'nullable': nullable,
@@ -152,7 +152,7 @@ def test_jvm_types(root_allocator, typ, jvm_spec, nullable):
}
jvm_field = _jvm_field(json.dumps(spec))
result = pa_jvm.field(jvm_field)
expected_field = pa.field('field_name', typ, nullable=nullable)
expected_field = pa.field('field_name', pa_type, nullable=nullable)
assert result == expected_field

jvm_schema = _jvm_schema(json.dumps(spec))
@@ -168,7 +168,7 @@ def test_jvm_types(root_allocator, typ, jvm_spec, nullable):
# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('typ,data,jvm_type', [
@pytest.mark.parametrize('pa_type,py_data,jvm_type', [
(pa.bool_(), [True, False, True, True], 'BitVector'),
(pa.uint8(), list(range(128)), 'UInt1Vector'),
(pa.uint16(), list(range(128)), 'UInt2Vector'),
@@ -189,21 +189,168 @@ def test_jvm_types(root_allocator, typ, jvm_spec, nullable):
(pa.date64(), list(range(128)), 'DateMilliVector'),
# TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_array(root_allocator, typ, data, jvm_type):
def test_jvm_array(root_allocator, pa_type, py_data, jvm_type):
# Create vector
cls = "org.apache.arrow.vector.{}".format(jvm_type)
jvm_vector = jpype.JClass(cls)("vector", root_allocator)
jvm_vector.allocateNew(len(data))
for i, val in enumerate(data):
jvm_vector.allocateNew(len(py_data))
for i, val in enumerate(py_data):
jvm_vector.setSafe(i, val)
jvm_vector.setValueCount(len(data))
jvm_vector.setValueCount(len(py_data))

py_array = pa.array(data, type=typ)
py_array = pa.array(py_data, type=pa_type)
jvm_array = pa_jvm.array(jvm_vector)

assert py_array.equals(jvm_array)


# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [
# TODO: null
(pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'),
(
pa.uint8(),
list(range(128)),
'UInt1Vector',
'{"name":"int","bitWidth":8,"isSigned":false}'
),
(
pa.uint16(),
list(range(128)),
'UInt2Vector',
'{"name":"int","bitWidth":16,"isSigned":false}'
),
(
pa.uint32(),
list(range(128)),
'UInt4Vector',
'{"name":"int","bitWidth":32,"isSigned":false}'
),
(
pa.uint64(),
list(range(128)),
'UInt8Vector',
'{"name":"int","bitWidth":64,"isSigned":false}'
),
(
pa.int8(),
list(range(128)),
'TinyIntVector',
'{"name":"int","bitWidth":8,"isSigned":true}'
),
(
pa.int16(),
list(range(128)),
'SmallIntVector',
'{"name":"int","bitWidth":16,"isSigned":true}'
),
(
pa.int32(),
list(range(128)),
'IntVector',
'{"name":"int","bitWidth":32,"isSigned":true}'
),
(
pa.int64(),
list(range(128)),
'BigIntVector',
'{"name":"int","bitWidth":64,"isSigned":true}'
),
# TODO: float16
(
pa.float32(),
list(range(128)),
'Float4Vector',
'{"name":"floatingpoint","precision":"SINGLE"}'
),
(
pa.float64(),
list(range(128)),
'Float8Vector',
'{"name":"floatingpoint","precision":"DOUBLE"}'
),
(
pa.timestamp('s'),
list(range(128)),
'TimeStampSecVector',
'{"name":"timestamp","unit":"SECOND","timezone":null}'
),
(
pa.timestamp('ms'),
list(range(128)),
'TimeStampMilliVector',
'{"name":"timestamp","unit":"MILLISECOND","timezone":null}'
),
(
pa.timestamp('us'),
list(range(128)),
'TimeStampMicroVector',
'{"name":"timestamp","unit":"MICROSECOND","timezone":null}'
),
(
pa.timestamp('ns'),
list(range(128)),
'TimeStampNanoVector',
'{"name":"timestamp","unit":"NANOSECOND","timezone":null}'
),
# TODO(ARROW-2605): These types miss a conversion from pure Python objects
# * pa.time32('s')
# * pa.time32('ms')
# * pa.time64('us')
# * pa.time64('ns')
(
pa.date32(),
list(range(128)),
'DateDayVector',
'{"name":"date","unit":"DAY"}'
),
(
pa.date64(),
list(range(128)),
'DateMilliVector',
'{"name":"date","unit":"MILLISECOND"}'
),
# TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type,
jvm_spec):
# Create vector
cls = "org.apache.arrow.vector.{}".format(jvm_type)
jvm_vector = jpype.JClass(cls)("vector", root_allocator)
jvm_vector.allocateNew(len(py_data))
for i, val in enumerate(py_data):
jvm_vector.setSafe(i, val)
jvm_vector.setValueCount(len(py_data))

# Create field
spec = {
'name': 'field_name',
'nullable': False,
'type': json.loads(jvm_spec),
# TODO: This needs to be set for complex types
'children': []
}
jvm_field = _jvm_field(json.dumps(spec))

# Create VectorSchemaRoot
jvm_fields = jpype.JClass('java.util.ArrayList')()
jvm_fields.add(jvm_field)
jvm_vectors = jpype.JClass('java.util.ArrayList')()
jvm_vectors.add(jvm_vector)
jvm_vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
jvm_vsr = jvm_vsr(jvm_fields, jvm_vectors, len(py_data))

py_record_batch = pa.RecordBatch.from_arrays(
[pa.array(py_data, type=pa_type)],
['col']
)
jvm_record_batch = pa_jvm.record_batch(jvm_vsr)

assert py_record_batch.equals(jvm_record_batch)


def _string_to_varchar_holder(ra, string):
nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
holder = jpype.JClass(nvch_cls)()