diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 6a0196e2b18..40c5ab10bbb 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -668,6 +668,22 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS } } + test("test flatten with nullable inner arrays") { + val sql = + """ + |select id, flatten(arr) + |from ( + | select id, + | if(id = 0, + | array(array(cast(id + 1 as int)), cast(null as array)), + | array(array(cast(id + 1 as int)))) as arr + | from range(2) + |) + |order by id + |""".stripMargin + runQueryAndCompare(sql)(checkGlutenPlan[ProjectExecTransformer]) + } + test("test common subexpression eliminate") { def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit tag: ClassTag[T]): Unit = { diff --git a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp index 96faa9d1dc1..7ead48cac1f 100644 --- a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp +++ b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp @@ -16,6 +16,7 @@ */ #include #include +#include #include #include #include @@ -107,19 +108,27 @@ result: Row 1: [1, 2, 3], Row2: [4] const IColumn::Offsets * prev_offsets = &src_offsets; const IColumn * prev_data = &src_col->getData(); bool nullable = prev_data->isNullable(); - // when array has null element, return null + ColumnUInt8::MutablePtr result_null_map; + // When an inner array is null, only the corresponding outer row is null. if (nullable) { const ColumnNullable * nullable_column = checkAndGetColumn(prev_data); prev_data = nullable_column->getNestedColumnPtr().get(); - for (size_t i = 0; i < nullable_column->size(); i++) + result_null_map = ColumnUInt8::create(input_rows_count, 0); + auto & result_null_map_data = result_null_map->getData(); + size_t prev_offset = 0; + for (size_t row = 0; row < input_rows_count; ++row) { - if (nullable_column->isNullAt(i)) + const auto current_offset = src_offsets[row]; + for (size_t i = prev_offset; i < current_offset; ++i) { - auto res= nullable_column->cloneEmpty(); - res->insertManyDefaults(input_rows_count); - return res; + if (nullable_column->isNullAt(i)) + { + result_null_map_data[row] = 1; + break; + } } + prev_offset = current_offset; } } if (isNothing(prev_data->getDataType())) @@ -142,7 +151,7 @@ result: Row 1: [1, 2, 3], Row2: [4] prev_data->getPtr(), result_offsets_column ? std::move(result_offsets_column) : src_col->getOffsetsPtr()); if (nullable) - return makeNullable(res); + return ColumnNullable::create(std::move(res), std::move(result_null_map)); return res; }