diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index f1d1f5b3ea800..eb71428677763 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -338,7 +338,7 @@ public final VariantVal getVariant(int rowId) { * Sets up the data type of this column vector. */ protected ColumnVector(DataType type) { - this.type = type.transformRecursively( + this.type = type == null ? null : type.transformRecursively( new PartialFunction() { @Override public boolean isDefinedAt(DataType x) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala index 2bee643df4eff..1bebfafb9f388 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala @@ -204,4 +204,8 @@ class ConstantColumnVectorSuite extends SparkFunSuite { assert(vector.getChild(2).getLong(i) == 12345L) } } + + testVector("null DataType", 0, null) { vector => + assert(vector.dataType == null) + } }