diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java index 90857c667abf9..c5a7d34281fca 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; +import org.apache.spark.sql.catalyst.types.*; import org.apache.spark.sql.types.*; public final class SpecializedGettersReader { @@ -28,70 +29,56 @@ public static Object read( DataType dataType, boolean handleNull, boolean handleUserDefinedType) { - if (handleNull && (obj.isNullAt(ordinal) || dataType instanceof NullType)) { + PhysicalDataType physicalDataType = dataType.physicalDataType(); + if (handleNull && (obj.isNullAt(ordinal) || physicalDataType instanceof PhysicalNullType)) { return null; } - if (dataType instanceof BooleanType) { + if (physicalDataType instanceof PhysicalBooleanType) { return obj.getBoolean(ordinal); } - if (dataType instanceof ByteType) { + if (physicalDataType instanceof PhysicalByteType) { return obj.getByte(ordinal); } - if (dataType instanceof ShortType) { + if (physicalDataType instanceof PhysicalShortType) { return obj.getShort(ordinal); } - if (dataType instanceof IntegerType) { + if (physicalDataType instanceof PhysicalIntegerType) { return obj.getInt(ordinal); } - if (dataType instanceof LongType) { + if (physicalDataType instanceof PhysicalLongType) { return obj.getLong(ordinal); } - if (dataType instanceof FloatType) { + if (physicalDataType instanceof PhysicalFloatType) { return obj.getFloat(ordinal); } - if (dataType instanceof DoubleType) { + if (physicalDataType instanceof PhysicalDoubleType) { return obj.getDouble(ordinal); } - if (dataType instanceof StringType) { + if (physicalDataType instanceof PhysicalStringType) { return obj.getUTF8String(ordinal); } - if (dataType instanceof DecimalType) { - DecimalType dt = (DecimalType) dataType; + if (physicalDataType instanceof PhysicalDecimalType) { + PhysicalDecimalType dt = (PhysicalDecimalType) physicalDataType; return obj.getDecimal(ordinal, dt.precision(), dt.scale()); } - if (dataType instanceof DateType) { - return obj.getInt(ordinal); - } - if (dataType instanceof TimestampType) { - return obj.getLong(ordinal); - } - if (dataType instanceof TimestampNTZType) { - return obj.getLong(ordinal); - } - if (dataType instanceof CalendarIntervalType) { + if (physicalDataType instanceof PhysicalCalendarIntervalType) { return obj.getInterval(ordinal); } - if (dataType instanceof BinaryType) { + if (physicalDataType instanceof PhysicalBinaryType) { return obj.getBinary(ordinal); } - if (dataType instanceof StructType) { - return obj.getStruct(ordinal, ((StructType) dataType).size()); + if (physicalDataType instanceof PhysicalStructType) { + return obj.getStruct(ordinal, ((PhysicalStructType) physicalDataType).fields().length); } - if (dataType instanceof ArrayType) { + if (physicalDataType instanceof PhysicalArrayType) { return obj.getArray(ordinal); } - if (dataType instanceof MapType) { + if (physicalDataType instanceof PhysicalMapType) { return obj.getMap(ordinal); } if (handleUserDefinedType && dataType instanceof UserDefinedType) { return obj.get(ordinal, ((UserDefinedType)dataType).sqlType()); } - if (dataType instanceof DayTimeIntervalType) { - return obj.getLong(ordinal); - 
} - if (dataType instanceof YearMonthIntervalType) { - return obj.getInt(ordinal); - } throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java index 32f6e71f77aac..05922a0cd5daa 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java @@ -19,6 +19,7 @@ import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.types.*; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -48,36 +49,33 @@ public InternalRow copy() { row.setNullAt(i); } else { DataType dt = columns[i].dataType(); - if (dt instanceof BooleanType) { + PhysicalDataType pdt = dt.physicalDataType(); + if (pdt instanceof PhysicalBooleanType) { row.setBoolean(i, getBoolean(i)); - } else if (dt instanceof ByteType) { + } else if (pdt instanceof PhysicalByteType) { row.setByte(i, getByte(i)); - } else if (dt instanceof ShortType) { + } else if (pdt instanceof PhysicalShortType) { row.setShort(i, getShort(i)); - } else if (dt instanceof IntegerType || dt instanceof YearMonthIntervalType) { + } else if (pdt instanceof PhysicalIntegerType) { row.setInt(i, getInt(i)); - } else if (dt instanceof LongType || dt instanceof DayTimeIntervalType) { + } else if (pdt instanceof PhysicalLongType) { row.setLong(i, getLong(i)); - } else if (dt instanceof FloatType) { + } else if (pdt instanceof PhysicalFloatType) { row.setFloat(i, getFloat(i)); - } else if (dt instanceof DoubleType) { + } else if (pdt instanceof PhysicalDoubleType) { row.setDouble(i, getDouble(i)); - } else if (dt instanceof StringType) { + } else if (pdt instanceof PhysicalStringType) { row.update(i, getUTF8String(i).copy()); - } else if (dt instanceof BinaryType) { + } else if (pdt instanceof PhysicalBinaryType) { row.update(i, getBinary(i)); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType)dt; + } else if (pdt instanceof PhysicalDecimalType) { + PhysicalDecimalType t = (PhysicalDecimalType)pdt; row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); - } else if (dt instanceof DateType) { - row.setInt(i, getInt(i)); - } else if (dt instanceof TimestampType) { - row.setLong(i, getLong(i)); - } else if (dt instanceof StructType) { - row.update(i, getStruct(i, ((StructType) dt).fields().length).copy()); - } else if (dt instanceof ArrayType) { + } else if (pdt instanceof PhysicalStructType) { + row.update(i, getStruct(i, ((PhysicalStructType) pdt).fields().length).copy()); + } else if (pdt instanceof PhysicalArrayType) { row.update(i, getArray(i).copy()); - } else if (dt instanceof MapType) { + } else if (pdt instanceof PhysicalMapType) { row.update(i, getMap(i).copy()); } else { throw new RuntimeException("Not implemented. 
" + dt); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index fd4e8ff5cab53..9c2b183334888 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -19,6 +19,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.types.*; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -55,36 +56,33 @@ public InternalRow copy() { row.setNullAt(i); } else { DataType dt = data.getChild(i).dataType(); - if (dt instanceof BooleanType) { + PhysicalDataType pdt = dt.physicalDataType(); + if (pdt instanceof PhysicalBooleanType) { row.setBoolean(i, getBoolean(i)); - } else if (dt instanceof ByteType) { + } else if (pdt instanceof PhysicalByteType) { row.setByte(i, getByte(i)); - } else if (dt instanceof ShortType) { + } else if (pdt instanceof PhysicalShortType) { row.setShort(i, getShort(i)); - } else if (dt instanceof IntegerType || dt instanceof YearMonthIntervalType) { + } else if (pdt instanceof PhysicalIntegerType) { row.setInt(i, getInt(i)); - } else if (dt instanceof LongType || dt instanceof DayTimeIntervalType) { + } else if (pdt instanceof PhysicalLongType) { row.setLong(i, getLong(i)); - } else if (dt instanceof FloatType) { + } else if (pdt instanceof PhysicalFloatType) { row.setFloat(i, getFloat(i)); - } else if (dt instanceof DoubleType) { + } else if (pdt instanceof PhysicalDoubleType) { row.setDouble(i, getDouble(i)); - } else if (dt instanceof StringType) { + } else if (pdt instanceof PhysicalStringType) { row.update(i, getUTF8String(i).copy()); - } else if (dt instanceof BinaryType) { + } else if (pdt instanceof PhysicalBinaryType) { row.update(i, getBinary(i)); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType)dt; + } else if (pdt instanceof PhysicalDecimalType) { + PhysicalDecimalType t = (PhysicalDecimalType)pdt; row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); - } else if (dt instanceof DateType) { - row.setInt(i, getInt(i)); - } else if (dt instanceof TimestampType) { - row.setLong(i, getLong(i)); - } else if (dt instanceof StructType) { - row.update(i, getStruct(i, ((StructType) dt).fields().length).copy()); - } else if (dt instanceof ArrayType) { + } else if (pdt instanceof PhysicalStructType) { + row.update(i, getStruct(i, ((PhysicalStructType) pdt).fields().length).copy()); + } else if (pdt instanceof PhysicalArrayType) { row.update(i, getArray(i).copy()); - } else if (dt instanceof MapType) { + } else if (pdt instanceof PhysicalMapType) { row.update(i, getMap(i).copy()); } else { throw new RuntimeException("Not implemented. 
" + dt); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 2b4482be4b69e..a44dca7dda937 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -129,24 +130,25 @@ object InternalRow { */ def getAccessor(dt: DataType, nullable: Boolean = true): (SpecializedGetters, Int) => Any = { val getValueNullSafe: (SpecializedGetters, Int) => Any = dt match { - case BooleanType => (input, ordinal) => input.getBoolean(ordinal) - case ByteType => (input, ordinal) => input.getByte(ordinal) - case ShortType => (input, ordinal) => input.getShort(ordinal) - case IntegerType | DateType | _: YearMonthIntervalType => - (input, ordinal) => input.getInt(ordinal) - case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => - (input, ordinal) => input.getLong(ordinal) - case FloatType => (input, ordinal) => input.getFloat(ordinal) - case DoubleType => (input, ordinal) => input.getDouble(ordinal) - case StringType => (input, ordinal) => input.getUTF8String(ordinal) - case BinaryType => (input, ordinal) => input.getBinary(ordinal) - case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) - case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) - case _: ArrayType => (input, ordinal) => input.getArray(ordinal) - case _: MapType => (input, ordinal) => input.getMap(ordinal) case u: UserDefinedType[_] => getAccessor(u.sqlType, nullable) - case _ => (input, ordinal) => input.get(ordinal, dt) + case _ => dt.physicalDataType match { + case PhysicalBooleanType => (input, ordinal) => input.getBoolean(ordinal) + case PhysicalByteType => (input, ordinal) => input.getByte(ordinal) + case PhysicalShortType => (input, ordinal) => input.getShort(ordinal) + case PhysicalIntegerType => (input, ordinal) => input.getInt(ordinal) + case PhysicalLongType => (input, ordinal) => input.getLong(ordinal) + case PhysicalFloatType => (input, ordinal) => input.getFloat(ordinal) + case PhysicalDoubleType => (input, ordinal) => input.getDouble(ordinal) + case PhysicalStringType => (input, ordinal) => input.getUTF8String(ordinal) + case PhysicalBinaryType => (input, ordinal) => input.getBinary(ordinal) + case PhysicalCalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: PhysicalDecimalType => (input, ordinal) => + input.getDecimal(ordinal, t.precision, t.scale) + case t: PhysicalStructType => (input, ordinal) => input.getStruct(ordinal, t.fields.size) + case _: PhysicalArrayType => (input, ordinal) => input.getArray(ordinal) + case _: PhysicalMapType => (input, ordinal) => input.getMap(ordinal) + case _ => (input, ordinal) => input.get(ordinal, dt) + } } if (nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d7e497fafa86a..8eb3475acaaa0 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.SerializerBuildHelper._ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -214,6 +215,8 @@ object RowEncoder { } else { nonNullOutput } + // For other data types, return the internal catalyst value as it is. + case _ => inputObject } /** @@ -253,13 +256,17 @@ object RowEncoder { } case _: DayTimeIntervalType => ObjectType(classOf[java.time.Duration]) case _: YearMonthIntervalType => ObjectType(classOf[java.time.Period]) - case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) - case StringType => ObjectType(classOf[java.lang.String]) - case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) - case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) - case _: StructType => ObjectType(classOf[Row]) case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) case udt: UserDefinedType[_] => ObjectType(udt.userClass) + case _ => dt.physicalDataType match { + case _: PhysicalArrayType => ObjectType(classOf[scala.collection.Seq[_]]) + case _: PhysicalDecimalType => ObjectType(classOf[java.math.BigDecimal]) + case _: PhysicalMapType => ObjectType(classOf[scala.collection.Map[_, _]]) + case PhysicalStringType => ObjectType(classOf[java.lang.String]) + case _: PhysicalStructType => ObjectType(classOf[Row]) + // For other data types, return the data type as it is. + case _ => dt + } } private def deserializerFor(input: Expression, schema: StructType): Expression = { @@ -358,6 +365,9 @@ object RowEncoder { If(IsNull(input), Literal.create(null, externalDataTypeFor(input.dataType)), CreateExternalRow(convertedFields, schema)) + + // For other data types, return the internal catalyst value as it is. + case _ => input } private def expressionForNullableExpr( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 731ad16cc7d9f..c27863b7ef8b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{UserDefinedType, _} @@ -147,114 +148,103 @@ object InterpretedUnsafeProjection { // Create the basic writer. 
val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match { - case BooleanType => - (v, i) => writer.write(i, v.getBoolean(i)) + case udt: UserDefinedType[_] => generateFieldWriter(writer, udt.sqlType, nullable) + case _ => dt.physicalDataType match { + case PhysicalBooleanType => (v, i) => writer.write(i, v.getBoolean(i)) - case ByteType => - (v, i) => writer.write(i, v.getByte(i)) + case PhysicalByteType => (v, i) => writer.write(i, v.getByte(i)) - case ShortType => - (v, i) => writer.write(i, v.getShort(i)) + case PhysicalShortType => (v, i) => writer.write(i, v.getShort(i)) - case IntegerType | DateType | _: YearMonthIntervalType => - (v, i) => writer.write(i, v.getInt(i)) + case PhysicalIntegerType => (v, i) => writer.write(i, v.getInt(i)) - case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => - (v, i) => writer.write(i, v.getLong(i)) + case PhysicalLongType => (v, i) => writer.write(i, v.getLong(i)) - case FloatType => - (v, i) => writer.write(i, v.getFloat(i)) + case PhysicalFloatType => (v, i) => writer.write(i, v.getFloat(i)) - case DoubleType => - (v, i) => writer.write(i, v.getDouble(i)) + case PhysicalDoubleType => (v, i) => writer.write(i, v.getDouble(i)) - case DecimalType.Fixed(precision, scale) => - (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale) + case PhysicalDecimalType(precision, scale) => + (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale) - case CalendarIntervalType => - (v, i) => writer.write(i, v.getInterval(i)) + case PhysicalCalendarIntervalType => (v, i) => writer.write(i, v.getInterval(i)) - case BinaryType => - (v, i) => writer.write(i, v.getBinary(i)) + case PhysicalBinaryType => (v, i) => writer.write(i, v.getBinary(i)) - case StringType => - (v, i) => writer.write(i, v.getUTF8String(i)) + case PhysicalStringType => (v, i) => writer.write(i, v.getUTF8String(i)) - case StructType(fields) => - val numFields = fields.length - val rowWriter = new UnsafeRowWriter(writer, numFields) - val structWriter = generateStructWriter(rowWriter, fields) - (v, i) => { - v.getStruct(i, fields.length) match { - case row: UnsafeRow => - writer.write(i, row) - case row => - val previousCursor = writer.cursor() - // Nested struct. We don't know where this will start because a row can be - // variable length, so we need to update the offsets and zero out the bit mask. - rowWriter.resetRowWriter() - structWriter.apply(row) - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) + case PhysicalStructType(fields) => + val numFields = fields.length + val rowWriter = new UnsafeRowWriter(writer, numFields) + val structWriter = generateStructWriter(rowWriter, fields) + (v, i) => { + v.getStruct(i, fields.length) match { + case row: UnsafeRow => + writer.write(i, row) + case row => + val previousCursor = writer.cursor() + // Nested struct. We don't know where this will start because a row can be + // variable length, so we need to update the offsets and zero out the bit mask. 
+ rowWriter.resetRowWriter() + structWriter.apply(row) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) + } } - } - - case ArrayType(elementType, containsNull) => - val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) - val elementWriter = generateFieldWriter( - arrayWriter, - elementType, - containsNull) - (v, i) => { - val previousCursor = writer.cursor() - writeArray(arrayWriter, elementWriter, v.getArray(i)) - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) - } - case MapType(keyType, valueType, valueContainsNull) => - val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) - val keyWriter = generateFieldWriter( - keyArrayWriter, - keyType, - nullable = false) - val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) - val valueWriter = generateFieldWriter( - valueArrayWriter, - valueType, - valueContainsNull) - (v, i) => { - v.getMap(i) match { - case map: UnsafeMapData => - writer.write(i, map) - case map => - val previousCursor = writer.cursor() - - // preserve 8 bytes to write the key array numBytes later. - valueArrayWriter.grow(8) - valueArrayWriter.increaseCursor(8) - - // Write the keys and write the numBytes of key array into the first 8 bytes. - writeArray(keyArrayWriter, keyWriter, map.keyArray()) - Platform.putLong( - valueArrayWriter.getBuffer, - previousCursor, - valueArrayWriter.cursor - previousCursor - 8 - ) - - // Write the values. - writeArray(valueArrayWriter, valueWriter, map.valueArray()) - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) + case PhysicalArrayType(elementType, containsNull) => + val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) + val elementWriter = generateFieldWriter( + arrayWriter, + elementType, + containsNull) + (v, i) => { + val previousCursor = writer.cursor() + writeArray(arrayWriter, elementWriter, v.getArray(i)) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - } - case udt: UserDefinedType[_] => - generateFieldWriter(writer, udt.sqlType, nullable) + case PhysicalMapType(keyType, valueType, valueContainsNull) => + val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) + val keyWriter = generateFieldWriter( + keyArrayWriter, + keyType, + nullable = false) + val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) + val valueWriter = generateFieldWriter( + valueArrayWriter, + valueType, + valueContainsNull) + (v, i) => { + v.getMap(i) match { + case map: UnsafeMapData => + writer.write(i, map) + case map => + val previousCursor = writer.cursor() + + // preserve 8 bytes to write the key array numBytes later. + valueArrayWriter.grow(8) + valueArrayWriter.increaseCursor(8) + + // Write the keys and write the numBytes of key array into the first 8 bytes. + writeArray(keyArrayWriter, keyWriter, map.keyArray()) + Platform.putLong( + valueArrayWriter.getBuffer, + previousCursor, + valueArrayWriter.cursor - previousCursor - 8 + ) + + // Write the values. 
+ writeArray(valueArrayWriter, valueWriter, map.valueArray()) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) + } + } - case NullType => - (_, _) => {} + case PhysicalNullType => (_, _) => {} - case _ => - throw new IllegalStateException(s"The data type '${dt.typeName}' is not supported in " + - "generating a writer function for a struct field, array element, map key or map value.") + case _ => + throw new IllegalStateException(s"The data type '${dt.typeName}' is not supported in " + + "generating a writer function for a struct field, array element, map key or map value.") + } } // Always wrap the writer with a null safe version. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 175f4561f3cc7..c871b91e8bd80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, SQLOrderingUtil} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.errors.QueryExecutionErrors @@ -1622,17 +1623,19 @@ object CodeGenerator extends Logging { def getValue(input: String, dataType: DataType, ordinal: String): String = { val jt = javaType(dataType) dataType match { - case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" - case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$input.getUTF8String($ordinal)" - case BinaryType => s"$input.getBinary($ordinal)" - case CalendarIntervalType => s"$input.getInterval($ordinal)" - case t: StructType => s"$input.getStruct($ordinal, ${t.size})" - case _: ArrayType => s"$input.getArray($ordinal)" - case _: MapType => s"$input.getMap($ordinal)" - case NullType => "null" case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) - case _ => s"($jt)$input.get($ordinal, null)" + case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + case _ => dataType.physicalDataType match { + case _: PhysicalArrayType => s"$input.getArray($ordinal)" + case PhysicalBinaryType => s"$input.getBinary($ordinal)" + case PhysicalCalendarIntervalType => s"$input.getInterval($ordinal)" + case t: PhysicalDecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case _: PhysicalMapType => s"$input.getMap($ordinal)" + case PhysicalNullType => "null" + case PhysicalStringType => s"$input.getUTF8String($ordinal)" + case t: PhysicalStructType => s"$input.getStruct($ordinal, ${t.fields.size})" + case _ => s"($jt)$input.get($ordinal, null)" + } } } @@ -1901,24 +1904,26 @@ object CodeGenerator extends Logging { * Returns the Java type for a DataType. 
*/ def javaType(dt: DataType): String = dt match { - case BooleanType => JAVA_BOOLEAN - case ByteType => JAVA_BYTE - case ShortType => JAVA_SHORT - case IntegerType | DateType | _: YearMonthIntervalType => JAVA_INT - case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => JAVA_LONG - case FloatType => JAVA_FLOAT - case DoubleType => JAVA_DOUBLE - case _: DecimalType => "Decimal" - case BinaryType => "byte[]" - case StringType => "UTF8String" - case CalendarIntervalType => "CalendarInterval" - case _: StructType => "InternalRow" - case _: ArrayType => "ArrayData" - case _: MapType => "MapData" case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName - case _ => "Object" + case _ => dt.physicalDataType match { + case _: PhysicalArrayType => "ArrayData" + case PhysicalBinaryType => "byte[]" + case PhysicalBooleanType => JAVA_BOOLEAN + case PhysicalByteType => JAVA_BYTE + case PhysicalCalendarIntervalType => "CalendarInterval" + case PhysicalIntegerType => JAVA_INT + case _: PhysicalDecimalType => "Decimal" + case PhysicalDoubleType => JAVA_DOUBLE + case PhysicalFloatType => JAVA_FLOAT + case PhysicalLongType => JAVA_LONG + case _: PhysicalMapType => "MapData" + case PhysicalShortType => JAVA_SHORT + case PhysicalStringType => "UTF8String" + case _: PhysicalStructType => "InternalRow" + case _ => "Object" + } } @tailrec diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 38bef3bc36ee8..e8ac858eb1173 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreePattern import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL} +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE @@ -205,39 +206,41 @@ object Literal { private[expressions] def validateLiteralValue(value: Any, dataType: DataType): Unit = { def doValidate(v: Any, dataType: DataType): Boolean = dataType match { case _ if v == null => true - case BooleanType => v.isInstanceOf[Boolean] - case ByteType => v.isInstanceOf[Byte] - case ShortType => v.isInstanceOf[Short] - case IntegerType | DateType | _: YearMonthIntervalType => v.isInstanceOf[Int] - case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => - v.isInstanceOf[Long] - case FloatType => v.isInstanceOf[Float] - case DoubleType => v.isInstanceOf[Double] - case _: DecimalType => v.isInstanceOf[Decimal] - case CalendarIntervalType => v.isInstanceOf[CalendarInterval] - case BinaryType => v.isInstanceOf[Array[Byte]] - case StringType => v.isInstanceOf[UTF8String] - case st: StructType => - v.isInstanceOf[InternalRow] && { - val row = v.asInstanceOf[InternalRow] - st.fields.map(_.dataType).zipWithIndex.forall { - case (dt, i) => doValidate(row.get(i, dt), dt) - } - } - case at: ArrayType => - v.isInstanceOf[ArrayData] && { - val ar = 
v.asInstanceOf[ArrayData] - ar.numElements() == 0 || doValidate(ar.get(0, at.elementType), at.elementType) - } - case mt: MapType => - v.isInstanceOf[MapData] && { - val map = v.asInstanceOf[MapData] - doValidate(map.keyArray(), ArrayType(mt.keyType)) && - doValidate(map.valueArray(), ArrayType(mt.valueType)) - } case ObjectType(cls) => cls.isInstance(v) case udt: UserDefinedType[_] => doValidate(v, udt.sqlType) - case _ => false + case dt => dataType.physicalDataType match { + case PhysicalArrayType(et, _) => + v.isInstanceOf[ArrayData] && { + val ar = v.asInstanceOf[ArrayData] + ar.numElements() == 0 || doValidate(ar.get(0, et), et) + } + case PhysicalBinaryType => v.isInstanceOf[Array[Byte]] + case PhysicalBooleanType => v.isInstanceOf[Boolean] + case PhysicalByteType => v.isInstanceOf[Byte] + case PhysicalCalendarIntervalType => v.isInstanceOf[CalendarInterval] + case PhysicalIntegerType => v.isInstanceOf[Int] + case _: PhysicalDecimalType => v.isInstanceOf[Decimal] + case PhysicalDoubleType => v.isInstanceOf[Double] + case PhysicalFloatType => v.isInstanceOf[Float] + case PhysicalLongType => v.isInstanceOf[Long] + case PhysicalMapType(kt, vt, _) => + v.isInstanceOf[MapData] && { + val map = v.asInstanceOf[MapData] + doValidate(map.keyArray(), ArrayType(kt)) && + doValidate(map.valueArray(), ArrayType(vt)) + } + case PhysicalNullType => true + case PhysicalShortType => v.isInstanceOf[Short] + case PhysicalStringType => v.isInstanceOf[UTF8String] + case st: PhysicalStructType => + v.isInstanceOf[InternalRow] && { + val row = v.asInstanceOf[InternalRow] + st.fields.map(_.dataType).zipWithIndex.forall { + case (fieldDataType, i) => doValidate(row.get(i, fieldDataType), fieldDataType) + } + } + case _ => false + } } require(doValidate(value, dataType), s"Literal must have a corresponding value to ${dataType.catalogString}, " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala new file mode 100644 index 0000000000000..26096e85b3571 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.types + +import org.apache.spark.sql.types._ + +sealed abstract class PhysicalDataType + +case class PhysicalArrayType(elementType: DataType, containsNull: Boolean) extends PhysicalDataType + +class PhysicalBinaryType() extends PhysicalDataType +case object PhysicalBinaryType extends PhysicalBinaryType + +class PhysicalBooleanType() extends PhysicalDataType +case object PhysicalBooleanType extends PhysicalBooleanType + +class PhysicalByteType() extends PhysicalDataType +case object PhysicalByteType extends PhysicalByteType + +class PhysicalCalendarIntervalType() extends PhysicalDataType +case object PhysicalCalendarIntervalType extends PhysicalCalendarIntervalType + +case class PhysicalDecimalType(precision: Int, scale: Int) extends PhysicalDataType + +class PhysicalDoubleType() extends PhysicalDataType +case object PhysicalDoubleType extends PhysicalDoubleType + +class PhysicalFloatType() extends PhysicalDataType +case object PhysicalFloatType extends PhysicalFloatType + +class PhysicalIntegerType() extends PhysicalDataType +case object PhysicalIntegerType extends PhysicalIntegerType + +class PhysicalLongType() extends PhysicalDataType +case object PhysicalLongType extends PhysicalLongType + +case class PhysicalMapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) + extends PhysicalDataType + +class PhysicalNullType() extends PhysicalDataType +case object PhysicalNullType extends PhysicalNullType + +class PhysicalShortType() extends PhysicalDataType +case object PhysicalShortType extends PhysicalShortType + +class PhysicalStringType() extends PhysicalDataType +case object PhysicalStringType extends PhysicalStringType + +case class PhysicalStructType(fields: Array[StructField]) extends PhysicalDataType + +object UninitializedPhysicalType extends PhysicalDataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index e139823b2bd01..3e5f447a7621c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -22,6 +22,7 @@ import scala.math.Ordering import org.json4s.JsonDSL._ import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalDataType} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat @@ -90,6 +91,9 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT */ override def defaultSize: Int = 1 * elementType.defaultSize + override def physicalDataType: PhysicalDataType = + PhysicalArrayType(elementType, containsNull) + override def simpleString: String = s"array<${elementType.simpleString}>" override def catalogString: String = s"array<${elementType.catalogString}>" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index c3fa54c1767de..d2998f533de2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalDataType} import 
org.apache.spark.unsafe.types.ByteArray /** @@ -44,6 +45,8 @@ class BinaryType private() extends AtomicType { */ override def defaultSize: Int = 100 + override def physicalDataType: PhysicalDataType = PhysicalBinaryType + private[spark] override def asNullable: BinaryType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index 5e3de71caa37e..d8766e95e200c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -21,6 +21,7 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalBooleanType, PhysicalDataType} /** * The data type representing `Boolean` values. Please use the singleton `DataTypes.BooleanType`. @@ -41,6 +42,8 @@ class BooleanType private() extends AtomicType { */ override def defaultSize: Int = 1 + override def physicalDataType: PhysicalDataType = PhysicalBooleanType + private[spark] override def asNullable: BooleanType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 0df9518045f07..7c361fc78e2da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -21,6 +21,7 @@ import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalByteType, PhysicalDataType} /** * The data type representing `Byte` values. Please use the singleton `DataTypes.ByteType`. @@ -44,6 +45,8 @@ class ByteType private() extends IntegralType { */ override def defaultSize: Int = 1 + override def physicalDataType: PhysicalDataType = PhysicalByteType + override def simpleString: String = "tinyint" private[spark] override def asNullable: ByteType = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index d506a1521e183..6073aacb03ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalCalendarIntervalType, PhysicalDataType} /** * The data type representing calendar intervals. 
The calendar interval is stored internally in @@ -37,6 +38,8 @@ class CalendarIntervalType private() extends DataType { override def defaultSize: Int = 16 + override def physicalDataType: PhysicalDataType = PhysicalCalendarIntervalType + override def typeName: String = "interval" private[spark] override def asNullable: CalendarIntervalType = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala index 67ab1cc2f3321..6bc6d39f14337 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala @@ -21,6 +21,7 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalStringType} import org.apache.spark.unsafe.types.UTF8String @Experimental @@ -32,6 +33,7 @@ case class CharType(length: Int) extends AtomicType { private[sql] val ordering = implicitly[Ordering[InternalType]] override def defaultSize: Int = length + override def physicalDataType: PhysicalDataType = PhysicalStringType override def typeName: String = s"char($length)" override def toString: String = s"CharType($length)" private[spark] override def asNullable: CharType = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index ef7f1553be9da..d29593f1e2e74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -31,6 +31,7 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer} import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.errors.QueryCompilationErrors @@ -116,6 +117,8 @@ abstract class DataType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) + + def physicalDataType: PhysicalDataType = UninitializedPhysicalType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 700e95bc75946..0a794266acdfd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -21,6 +21,7 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalIntegerType} /** * The date type represents a valid date in the proleptic Gregorian calendar. 
@@ -46,6 +47,8 @@ class DateType private() extends DatetimeType { */ override def defaultSize: Int = 4 + override def physicalDataType: PhysicalDataType = PhysicalIntegerType + private[spark] override def asNullable: DateType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala index ca8a1f71bdd88..802c8a766374f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala @@ -21,6 +21,7 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalLongType} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DayTimeIntervalType.fieldToString @@ -60,6 +61,8 @@ case class DayTimeIntervalType(startField: Byte, endField: Byte) extends AnsiInt */ override def defaultSize: Int = 8 + override def physicalDataType: PhysicalDataType = PhysicalLongType + private[spark] override def asNullable: DayTimeIntervalType = this override val typeName: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index ec7dc62d0dc73..7d0b4a0904787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -24,6 +24,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalDecimalType} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -110,6 +111,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { */ override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16 + override def physicalDataType: PhysicalDataType = PhysicalDecimalType(precision, scale) + override def simpleString: String = s"decimal($precision,$scale)" private[spark] override def asNullable: DecimalType = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index ea4f39d4b19d2..cef0681e88df2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag import scala.util.Try import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalDoubleType} import org.apache.spark.sql.catalyst.util.SQLOrderingUtil /** @@ -49,6 +50,8 @@ class DoubleType private() extends FractionalType { */ override def defaultSize: Int = 8 + override def physicalDataType: PhysicalDataType = PhysicalDoubleType + private[spark] override def asNullable: DoubleType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index f00046facf693..2e3992546d0ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag import scala.util.Try import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalFloatType} import org.apache.spark.sql.catalyst.util.SQLOrderingUtil /** @@ -49,6 +50,8 @@ class FloatType private() extends FractionalType { */ override def defaultSize: Int = 4 + override def physicalDataType: PhysicalDataType = PhysicalFloatType + private[spark] override def asNullable: FloatType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index c344523bdcb89..d58a4b6355457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -21,6 +21,7 @@ import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalIntegerType} /** * The data type representing `Int` values. Please use the singleton `DataTypes.IntegerType`. @@ -44,6 +45,8 @@ class IntegerType private() extends IntegralType { */ override def defaultSize: Int = 4 + override def physicalDataType: PhysicalDataType = PhysicalIntegerType + override def simpleString: String = "int" private[spark] override def asNullable: IntegerType = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index f030920db4517..be0560657c787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -21,6 +21,7 @@ import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalLongType} /** * The data type representing `Long` values. Please use the singleton `DataTypes.LongType`. 
@@ -44,6 +45,8 @@ class LongType private() extends IntegralType { */ override def defaultSize: Int = 8 + override def physicalDataType: PhysicalDataType = PhysicalLongType + override def simpleString: String = "bigint" private[spark] override def asNullable: LongType = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 2e5c7f731dcc7..df7c18edc8a61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -21,6 +21,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalMapType} import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat /** @@ -67,6 +68,9 @@ case class MapType( */ override def defaultSize: Int = 1 * (keyType.defaultSize + valueType.defaultSize) + override def physicalDataType: PhysicalDataType = + PhysicalMapType(keyType, valueType, valueContainsNull) + override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" override def catalogString: String = s"map<${keyType.catalogString},${valueType.catalogString}>" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index d211fac70c641..171c9a6a67d82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalNullType} /** * The data type representing `NULL` values. Please use the singleton `DataTypes.NullType`. @@ -31,6 +32,8 @@ class NullType private() extends DataType { // Defined with a private constructor so the companion object is the only possible instantiation. override def defaultSize: Int = 1 + override def physicalDataType: PhysicalDataType = PhysicalNullType + private[spark] override def asNullable: NullType = this override def typeName: String = "void" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index 8252689958531..3d40610c168a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -21,6 +21,7 @@ import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalShortType} /** * The data type representing `Short` values. Please use the singleton `DataTypes.ShortType`. 
@@ -44,6 +45,8 @@ class ShortType private() extends IntegralType { */ override def defaultSize: Int = 2 + override def physicalDataType: PhysicalDataType = PhysicalShortType + override def simpleString: String = "smallint" private[spark] override def asNullable: ShortType = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 8ce1cd078e312..9ab40d3d89e1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -21,6 +21,7 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalStringType} import org.apache.spark.unsafe.types.UTF8String /** @@ -42,6 +43,8 @@ class StringType private() extends AtomicType { */ override def defaultSize: Int = 20 + override def physicalDataType: PhysicalDataType = PhysicalStringType + private[spark] override def asNullable: StringType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d5f32aac55a4e..6d7a948fc997f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalStructType} import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat @@ -431,6 +432,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum + override def physicalDataType: PhysicalDataType = PhysicalStructType(fields) + override def simpleString: String = { val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}").toSeq truncatedString( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala index 508a1b03bde9a..a554a0bcfa3b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalLongType} /** * The timestamp without time zone type represents a local time in microsecond precision, @@ -47,6 +48,8 @@ class TimestampNTZType private() extends DatetimeType { */ override def defaultSize: Int = 8 + override def physicalDataType: PhysicalDataType = PhysicalLongType + override def typeName: String = "timestamp_ntz" private[spark] override def asNullable: TimestampNTZType = this diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index d52de414861f6..b3a45275f2fa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -21,6 +21,7 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalLongType} /** * The timestamp type represents a time instant in microsecond precision. @@ -48,6 +49,8 @@ class TimestampType private() extends DatetimeType { */ override def defaultSize: Int = 8 + override def physicalDataType: PhysicalDataType = PhysicalLongType + private[spark] override def asNullable: TimestampType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala index 2e30820ef0a05..eab9be096ff02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala @@ -20,6 +20,7 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalStringType} import org.apache.spark.unsafe.types.UTF8String @Experimental @@ -27,6 +28,7 @@ case class VarcharType(length: Int) extends AtomicType { require(length >= 0, "The length of varchar type cannot be negative.") private[sql] type InternalType = UTF8String + override def physicalDataType: PhysicalDataType = PhysicalStringType @transient private[sql] lazy val tag = typeTag[InternalType] private[sql] val ordering = implicitly[Ordering[InternalType]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala index 4d9168f6ec86a..5ed3b5574ef9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala @@ -21,6 +21,7 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalIntegerType} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.YearMonthIntervalType.fieldToString @@ -58,6 +59,8 @@ case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AnsiI */ override def defaultSize: Int = 4 + override def physicalDataType: PhysicalDataType = PhysicalIntegerType + private[spark] override def asNullable: YearMonthIntervalType = this override val typeName: String = {
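
A minimal sketch (not part of the patch) of the logical-to-physical mapping the files above introduce, using only the overrides shown in this diff; it illustrates why readers and writers such as SpecializedGettersReader and InternalRow.getAccessor can now dispatch on the physical type alone instead of enumerating DateType, TimestampType, TimestampNTZType and the ANSI interval types separately:

    import org.apache.spark.sql.catalyst.types._
    import org.apache.spark.sql.types._

    // Logical types that share a storage representation resolve to the same PhysicalDataType.
    assert(DateType.physicalDataType == PhysicalIntegerType)          // dates are stored as Int days
    assert(TimestampType.physicalDataType == PhysicalLongType)        // timestamps as Long microseconds
    assert(TimestampNTZType.physicalDataType == PhysicalLongType)
    assert(CharType(10).physicalDataType == PhysicalStringType)       // char/varchar share the UTF8String representation
    assert(DecimalType(10, 2).physicalDataType == PhysicalDecimalType(10, 2))

This is only a sanity check of the mapping under the assumption that spark-catalyst is on the classpath; the patch itself defines the physical types in org.apache.spark.sql.catalyst.types.PhysicalDataType.scala and wires them in via DataType.physicalDataType.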