diff --git a/python/docs/source/user_guide/arrow_pandas.rst b/python/docs/source/user_guide/arrow_pandas.rst
index fe04315f87ad5..91d8155523391 100644
--- a/python/docs/source/user_guide/arrow_pandas.rst
+++ b/python/docs/source/user_guide/arrow_pandas.rst
@@ -341,8 +341,9 @@ Supported SQL Types
 
 .. currentmodule:: pyspark.sql.types
 
-Currently, all Spark SQL data types are supported by Arrow-based conversion except :class:`MapType`,
+Currently, all Spark SQL data types are supported by Arrow-based conversion except
 :class:`ArrayType` of :class:`TimestampType`, and nested :class:`StructType`.
+:class:`MapType` is only supported when using PyArrow 2.0.0 and above.
 
 Setting Arrow Batch Size
 ~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 3456c12e59c09..d8a241417532e 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -22,7 +22,7 @@
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
 from pyspark.sql.types import IntegralType
 from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
-    DoubleType, BooleanType, TimestampType, StructType, DataType
+    DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
 from pyspark.traceback_utils import SCCallSiteSync
@@ -100,7 +100,8 @@ def toPandas(self):
         # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
         if use_arrow:
             try:
-                from pyspark.sql.pandas.types import _check_series_localize_timestamps
+                from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
+                    _convert_map_items_to_dict
                 import pyarrow
                 # Rename columns to avoid duplicated column names.
                 tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
@@ -117,6 +118,9 @@ def toPandas(self):
                         if isinstance(field.dataType, TimestampType):
                             pdf[field.name] = \
                                 _check_series_localize_timestamps(pdf[field.name], timezone)
+                        elif isinstance(field.dataType, MapType):
+                            pdf[field.name] = \
+                                _convert_map_items_to_dict(pdf[field.name])
                     return pdf
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 16462e8702a0b..750aa4b0e6c56 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -284,7 +284,6 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
         should be checked for accuracy by users.
 
     Currently,
-    :class:`pyspark.sql.types.MapType`,
     :class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and
     nested :class:`pyspark.sql.types.StructType`
     are currently not supported as output types.
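A minimal sketch of the documented PyArrow version gate (not part of the patch; it assumes the patched pyspark.sql.pandas.types module and some PyArrow version installed, and reuses the helper and error message from the hunks below):

    from distutils.version import LooseVersion

    import pyarrow as pa
    from pyspark.sql.pandas.types import to_arrow_type
    from pyspark.sql.types import MapType, StringType, LongType

    # MapType converts to pa.map_(key, value) only when PyArrow 2.0.0+ is present;
    # on older PyArrow the conversion raises TypeError and, with
    # spark.sql.execution.arrow.pyspark.fallback.enabled=true, Spark falls back
    # to the non-Arrow path.
    if LooseVersion(pa.__version__) >= LooseVersion("2.0.0"):
        print(to_arrow_type(MapType(StringType(), LongType())))  # map<string, int64>
    else:
        try:
            to_arrow_type(MapType(StringType(), LongType()))
        except TypeError as e:
            print(e)  # MapType is only supported with pyarrow 2.0.0 and above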
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 73d36ee555fb5..2dcfdc1046049 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -117,7 +117,8 @@ def __init__(self, timezone, safecheck, assign_cols_by_name):
         self._assign_cols_by_name = assign_cols_by_name
 
     def arrow_to_pandas(self, arrow_column):
-        from pyspark.sql.pandas.types import _check_series_localize_timestamps
+        from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
+            _convert_map_items_to_dict
         import pyarrow
 
         # If the given column is a date type column, creates a series of datetime.date directly
@@ -127,6 +128,8 @@ def arrow_to_pandas(self, arrow_column):
 
         if pyarrow.types.is_timestamp(arrow_column.type):
             return _check_series_localize_timestamps(s, self._timezone)
+        elif pyarrow.types.is_map(arrow_column.type):
+            return _convert_map_items_to_dict(s)
         else:
             return s
 
@@ -147,7 +150,8 @@ def _create_batch(self, series):
         """
         import pandas as pd
         import pyarrow as pa
-        from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
+        from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \
+            _convert_dict_to_map_items
        from pandas.api.types import is_categorical_dtype
         # Make input conform to [(series1, type1), (series2, type2), ...]
         if not isinstance(series, (list, tuple)) or \
@@ -160,6 +164,8 @@ def create_array(s, t):
             # Ensure timestamp series are in expected form for Spark internal representation
             if t is not None and pa.types.is_timestamp(t):
                 s = _check_series_convert_timestamps_internal(s, self._timezone)
+            elif t is not None and pa.types.is_map(t):
+                s = _convert_dict_to_map_items(s)
             elif is_categorical_dtype(s.dtype):
                 # Note: This can be removed once minimum pyarrow version is >= 0.16.1
                 s = s.astype(s.dtypes.categories.dtype)
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 67557120715ac..7e4d61b0d21b8 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -20,14 +20,15 @@
 pandas instances during the type conversion.
""" -from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \ - DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, ArrayType, \ - StructType, StructField, BooleanType +from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \ + FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \ + ArrayType, MapType, StructType, StructField def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ + from distutils.version import LooseVersion import pyarrow as pa if type(dt) == BooleanType: arrow_type = pa.bool_() @@ -58,6 +59,13 @@ def to_arrow_type(dt): if type(dt.elementType) in [StructType, TimestampType]: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) + elif type(dt) == MapType: + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + raise TypeError("MapType is only supported with pyarrow 2.0.0 and above") + if type(dt.keyType) in [StructType, TimestampType] or \ + type(dt.valueType) in [StructType, TimestampType]: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) + arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType)) elif type(dt) == StructType: if any(type(field.dataType) == StructType for field in dt): raise TypeError("Nested StructType not supported in conversion to Arrow") @@ -81,6 +89,8 @@ def to_arrow_schema(schema): def from_arrow_type(at): """ Convert pyarrow type to Spark data type. """ + from distutils.version import LooseVersion + import pyarrow as pa import pyarrow.types as types if types.is_boolean(at): spark_type = BooleanType() @@ -110,6 +120,12 @@ def from_arrow_type(at): if types.is_timestamp(at.value_type): raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) + elif types.is_map(at): + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + raise TypeError("MapType is only supported with pyarrow 2.0.0 and above") + if types.is_timestamp(at.key_type) or types.is_timestamp(at.item_type): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type)) elif types.is_struct(at): if any(types.is_struct(field.type) for field in at): raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at)) @@ -306,3 +322,23 @@ def _check_series_convert_timestamps_tz_local(s, timezone): `pandas.Series` where if it is a timestamp, has been converted to tz-naive """ return _check_series_convert_timestamps_localize(s, timezone, None) + + +def _convert_map_items_to_dict(s): + """ + Convert a series with items as list of (key, value), as made from an Arrow column of map type, + to dict for compatibility with non-arrow MapType columns. + :param s: pandas.Series of lists of (key, value) pairs + :return: pandas.Series of dictionaries + """ + return s.apply(lambda m: None if m is None else {k: v for k, v in m}) + + +def _convert_dict_to_map_items(s): + """ + Convert a series of dictionaries to list of (key, value) pairs to match expected data + for Arrow column of map type. 
+    :param s: pandas.Series of dictionaries
+    :return: pandas.Series of lists of (key, value) pairs
+    """
+    return s.apply(lambda d: list(d.items()) if d is not None else None)
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 55d5e9017b345..e764c42d88a31 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -21,13 +21,13 @@
 import time
 import unittest
 import warnings
+from distutils.version import LooseVersion
 
 from pyspark import SparkContext, SparkConf
 from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import udf
 from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
-    FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, MapType, \
-    ArrayType
+    FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, ArrayType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -114,9 +114,10 @@ def create_pandas_data_frame(self):
         return pd.DataFrame(data=data_dict)
 
     def test_toPandas_fallback_enabled(self):
+        ts = datetime.datetime(2015, 11, 1, 0, 30)
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
-            schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
-            df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
+            schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
+            df = self.spark.createDataFrame([([ts],)], schema=schema)
             with QuietTest(self.sc):
                 with self.warnings_lock:
                     with warnings.catch_warnings(record=True) as warns:
@@ -129,10 +130,10 @@ def test_toPandas_fallback_enabled(self):
                         self.assertTrue(len(user_warns) > 0)
                         self.assertTrue(
                             "Attempting non-optimization" in str(user_warns[-1]))
-                        assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+                        assert_frame_equal(pdf, pd.DataFrame({"a": [[ts]]}))
 
     def test_toPandas_fallback_disabled(self):
-        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
+        schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
             with self.warnings_lock:
@@ -336,6 +337,62 @@ def test_toPandas_with_array_type(self):
                 self.assertTrue(expected[r][e] == result_arrow[r][e] and
                                 result[r][e] == result_arrow[r][e])
 
+    def test_createDataFrame_with_map_type(self):
+        map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
+
+        pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
+        schema = "id long, m map<string, long>"
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+            df = self.spark.createDataFrame(pdf, schema=schema)
+
+        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+            with QuietTest(self.sc):
+                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                    self.spark.createDataFrame(pdf, schema=schema)
+        else:
+            df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
+            result = df.collect()
+            result_arrow = df_arrow.collect()
+
+            self.assertEqual(len(result), len(result_arrow))
+            for row, row_arrow in zip(result, result_arrow):
+                i, m = row
+                _, m_arrow = row_arrow
+                self.assertEqual(m, map_data[i])
+                self.assertEqual(m_arrow, map_data[i])
+
+    def test_toPandas_with_map_type(self):
+        pdf = pd.DataFrame({"id": [0, 1, 2, 3],
+                            "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]})
"c": 3}]}) + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + df = self.spark.createDataFrame(pdf, schema="id long, m map") + + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + with QuietTest(self.sc): + with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"): + df.toPandas() + else: + pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df) + assert_frame_equal(pdf_arrow, pdf_non) + + def test_toPandas_with_map_type_nulls(self): + pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], + "m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]}) + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + df = self.spark.createDataFrame(pdf, schema="id long, m map") + + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + with QuietTest(self.sc): + with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"): + df.toPandas() + else: + pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df) + assert_frame_equal(pdf_arrow, pdf_non) + def test_createDataFrame_with_int_col_names(self): import numpy as np pdf = pd.DataFrame(np.random.rand(4, 2)) @@ -345,26 +402,28 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df_arrow.columns) def test_createDataFrame_fallback_enabled(self): + ts = datetime.datetime(2015, 11, 1, 0, 30) with QuietTest(self.sc): with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}): with warnings.catch_warnings(record=True) as warns: # we want the warnings to appear even if this test is run from a subclass warnings.simplefilter("always") df = self.spark.createDataFrame( - pd.DataFrame([[{u'a': 1}]]), "a: map") + pd.DataFrame({"a": [[ts]]}), "a: array") # Catch and check the last UserWarning. user_warns = [ warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( "Attempting non-optimization" in str(user_warns[-1])) - self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + self.assertEqual(df.collect(), [Row(a=[ts])]) def test_createDataFrame_fallback_disabled(self): with QuietTest(self.sc): with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( - pd.DataFrame([[{u'a': 1}]]), "a: map") + pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}), + "a: array") # Regression test for SPARK-23314 def test_timestamp_dst(self): diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py index f9a7dd69b61fb..4afc1dfcc1c6e 100644 --- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py @@ -176,9 +176,9 @@ def test_wrong_return_type(self): with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, - 'Invalid return type.*MapType'): + 'Invalid return type.*ArrayType.*TimestampType'): left.groupby('id').cogroup(right.groupby('id')).applyInPandas( - lambda l, r: l, 'id long, v map') + lambda l, r: l, 'id long, v array') def test_wrong_args(self): left = self.data1 diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py index 93e37125eaa33..ee68b95fc478d 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py @@ -26,7 +26,7 @@ window from pyspark.sql.types import IntegerType, DoubleType, ArrayType, BinaryType, ByteType, \ LongType, DecimalType, ShortType, FloatType, StringType, 
     LongType, DecimalType, ShortType, FloatType, StringType, BooleanType, StructType, \
-    StructField, NullType, MapType, TimestampType
+    StructField, NullType, TimestampType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -246,10 +246,10 @@ def test_wrong_return_type(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
                     NotImplementedError,
-                    'Invalid return type.*grouped map Pandas UDF.*MapType'):
+                    'Invalid return type.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
                 pandas_udf(
                     lambda pdf: pdf,
-                    'id long, v map<int, int>',
+                    'id long, v array<timestamp>',
                     PandasUDFType.GROUPED_MAP)
 
     def test_wrong_args(self):
@@ -276,7 +276,6 @@ def test_wrong_args(self):
     def test_unsupported_types(self):
         common_err_msg = 'Invalid return type.*grouped map Pandas UDF.*'
         unsupported_types = [
-            StructField('map', MapType(StringType(), IntegerType())),
             StructField('arr_ts', ArrayType(TimestampType())),
             StructField('null', NullType()),
             StructField('struct', StructType([StructField('l', LongType())])),
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
index 451308927629b..2cbcf31f6e7b3 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -21,7 +21,7 @@
 from pyspark.sql import Row
 from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
     udf, pandas_udf, PandasUDFType
-from pyspark.sql.types import ArrayType, TimestampType, DoubleType, MapType
+from pyspark.sql.types import ArrayType, TimestampType
 from pyspark.sql.utils import AnalysisException
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -159,7 +159,7 @@ def mean_and_std_udf(v):
 
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
-                @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
+                @pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
                 def mean_and_std_udf(v):
                     return {v.mean(): v.std()}
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 6d325c9085ce1..5da5d043ceca4 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -22,6 +22,7 @@
 import unittest
 from datetime import date, datetime
 from decimal import Decimal
+from distutils.version import LooseVersion
 
 from pyspark import TaskContext
 from pyspark.rdd import PythonEvalType
@@ -379,6 +380,20 @@ def test_vectorized_udf_nested_struct(self):
                         'Invalid return type with scalar Pandas UDFs'):
                     pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type)
 
+    def test_vectorized_udf_map_type(self):
+        data = [({},), ({"a": 1},), ({"a": 1, "b": 2},), ({"a": 1, "b": 2, "c": 3},)]
+        schema = StructType([StructField("map", MapType(StringType(), LongType()))])
+        df = self.spark.createDataFrame(data, schema=schema)
+        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
+            if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+                with QuietTest(self.sc):
+                    with self.assertRaisesRegex(Exception, "MapType.*not supported"):
+                        pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type)
+            else:
+                map_f = pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type)
+                result = df.select(map_f(col('map')))
+                self.assertEquals(df.collect(), result.collect())
+
     def test_vectorized_udf_complex(self):
         df = self.spark.range(10).select(
             col('id').cast('int').alias('a'),
@@ -504,8 +519,8 @@ def test_vectorized_udf_wrong_return_type(self):
         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
             with self.assertRaisesRegexp(
                     NotImplementedError,
-                    'Invalid return type.*scalar Pandas UDF.*MapType'):
-                pandas_udf(lambda x: x, MapType(LongType(), LongType()), udf_type)
+                    'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
+                pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
 
     def test_vectorized_udf_return_scalar(self):
         df = self.spark.range(10)
@@ -577,8 +592,8 @@ def test_vectorized_udf_unsupported_types(self):
         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
             with self.assertRaisesRegexp(
                     NotImplementedError,
-                    'Invalid return type.*scalar Pandas UDF.*MapType'):
-                pandas_udf(lambda x: x, MapType(StringType(), IntegerType()), udf_type)
+                    'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
+                pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
             with self.assertRaisesRegexp(
                     NotImplementedError,
                     'Invalid return type.*scalar Pandas UDF.*ArrayType.StructType'):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6f79d1a91c814..5c17f0434bc79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1903,7 +1903,7 @@ object SQLConf {
       "1. pyspark.sql.DataFrame.toPandas " +
       "2. pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame " +
       "The following data types are unsupported: " +
-      "MapType, ArrayType of TimestampType, and nested StructType.")
+      "ArrayType of TimestampType, and nested StructType.")
     .version("3.0.0")
     .fallbackConf(ARROW_EXECUTION_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 501e1c460f9c9..f62aa5db0872f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -63,10 +63,10 @@ object ArrowWriter {
       val elementVector = createFieldWriter(vector.getDataVector())
       new ArrayWriter(vector, elementVector)
     case (MapType(_, _, _), vector: MapVector) =>
-      val entryWriter = createFieldWriter(vector.getDataVector).asInstanceOf[StructWriter]
-      val keyWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.KEY_NAME))
-      val valueWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.VALUE_NAME))
-      new MapWriter(vector, keyWriter, valueWriter)
+      val structVector = vector.getDataVector.asInstanceOf[StructVector]
+      val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
+      val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
+      new MapWriter(vector, structVector, keyWriter, valueWriter)
     case (StructType(_), vector: StructVector) =>
       val children = (0 until vector.size()).map { ordinal =>
         createFieldWriter(vector.getChildByOrdinal(ordinal))
@@ -331,11 +331,11 @@ private[arrow] class StructWriter(
   override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
     val struct = input.getStruct(ordinal, children.length)
     var i = 0
+    valueVector.setIndexDefined(count)
     while (i < struct.numFields) {
       children(i).write(struct, i)
       i += 1
     }
-    valueVector.setIndexDefined(count)
   }
 
   override def finish(): Unit = {
@@ -351,6 +351,7 @@ private[arrow] class StructWriter(
 
 private[arrow] class MapWriter(
     val valueVector: MapVector,
+    val structVector: StructVector,
     val keyWriter: ArrowFieldWriter,
     val valueWriter: ArrowFieldWriter) extends ArrowFieldWriter {
 
@@ -363,6 +364,7 @@ private[arrow] class MapWriter(
     val values = map.valueArray()
     var i = 0
     while (i < map.numElements()) {
+      structVector.setIndexDefined(keyWriter.count)
       keyWriter.write(keys, i)
       valueWriter.write(values, i)
       i += 1
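Taken together, these changes let map columns ride the Arrow path in both directions, with dicts converted to Arrow map items on the way in and back to dicts on the way out by the new helpers. A minimal usage sketch mirroring the added tests (assumes a SparkSession built with this patch and PyArrow 2.0.0 or newer installed):

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    pdf = pd.DataFrame({"id": [0, 1, 2], "m": [{"a": 1}, {"b": 2, "c": 3}, None]})

    # pandas -> Spark: dict values are converted to Arrow map items per batch.
    df = spark.createDataFrame(pdf, schema="id long, m map<string, long>")

    # Spark -> pandas: Arrow map items come back as Python dicts (None stays None).
    print(df.toPandas()["m"].tolist())  # [{'a': 1}, {'b': 2, 'c': 3}, None]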