From f3d635d58060a1efaa5df1116917a1985886bcf2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 24 Apr 2023 16:10:13 +0800 Subject: [PATCH 1/2] Add ToPrettyString expression for Dataset.show --- .../spark/sql/catalyst/expressions/Cast.scala | 375 +--------------- .../catalyst/expressions/ToPrettyString.scala | 76 ++++ .../catalyst/expressions/ToStringBase.scala | 414 ++++++++++++++++++ .../spark/sql/catalyst/util/StringUtils.scala | 2 + .../catalyst/expressions/CastSuiteBase.scala | 10 +- .../scala/org/apache/spark/sql/Dataset.scala | 17 +- 6 files changed, 514 insertions(+), 380 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 3de31b1ed28df..88b2718789e32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -33,13 +33,11 @@ import org.apache.spark.sql.catalyst.types.{PhysicalFractionalType, PhysicalInte import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ -import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort} import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import 
org.apache.spark.unsafe.UTF8StringBuilder -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} object Cast extends QueryErrorsBase { @@ -496,6 +494,7 @@ case class Cast( evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends UnaryExpression with TimeZoneAwareExpression + with ToStringBase with NullIntolerant with SupportQueryContext with QueryErrorsBase { @@ -591,133 +590,22 @@ case class Cast( // [[func]] assumes the input is no longer null because eval already does the null check. @inline protected[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) - private lazy val dateFormatter = DateFormatter() - private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId) - private lazy val timestampNTZFormatter = - TimestampFormatter.getFractionFormatter(ZoneOffset.UTC) - private val legacyCastToStr = SQLConf.get.getConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING) - // The brackets that are used in casting structs and maps to strings - private val (leftBracket, rightBracket) = if (legacyCastToStr) ("[", "]") else ("{", "}") + + protected val (leftBracket, rightBracket) = if (legacyCastToStr) ("[", "]") else ("{", "}") + + override protected def nullString: String = if (legacyCastToStr) "" else "null" + + // In ANSI mode, Spark always use plain string representation on casting Decimal values + // as strings. Otherwise, the casting is using `BigDecimal.toString` which may use scientific + // notation if an exponent is needed. 
+ override protected def useDecimalPlainString: Boolean = ansiEnabled + + override protected def useHexFormatForBinary: Boolean = false // The class name of `DateTimeUtils` protected def dateTimeUtilsCls: String = DateTimeUtils.getClass.getName.stripSuffix("$") - // UDFToString - private[this] def castToString(from: DataType): Any => Any = from match { - case CalendarIntervalType => - buildCast[CalendarInterval](_, i => UTF8String.fromString(i.toString)) - case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes) - case DateType => buildCast[Int](_, d => UTF8String.fromString(dateFormatter.format(d))) - case TimestampType => buildCast[Long](_, - t => UTF8String.fromString(timestampFormatter.format(t))) - case TimestampNTZType => buildCast[Long](_, - t => UTF8String.fromString(timestampNTZFormatter.format(t))) - case ArrayType(et, _) => - buildCast[ArrayData](_, array => { - val builder = new UTF8StringBuilder - builder.append("[") - if (array.numElements > 0) { - val toUTF8String = castToString(et) - if (array.isNullAt(0)) { - if (!legacyCastToStr) builder.append("NULL") - } else { - builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) - } - var i = 1 - while (i < array.numElements) { - builder.append(",") - if (array.isNullAt(i)) { - if (!legacyCastToStr) builder.append(" NULL") - } else { - builder.append(" ") - builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) - } - i += 1 - } - } - builder.append("]") - builder.build() - }) - case MapType(kt, vt, _) => - buildCast[MapData](_, map => { - val builder = new UTF8StringBuilder - builder.append(leftBracket) - if (map.numElements > 0) { - val keyArray = map.keyArray() - val valueArray = map.valueArray() - val keyToUTF8String = castToString(kt) - val valueToUTF8String = castToString(vt) - builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) - builder.append(" ->") - if (valueArray.isNullAt(0)) { - if (!legacyCastToStr) builder.append(" NULL") - 
} else { - builder.append(" ") - builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) - } - var i = 1 - while (i < map.numElements) { - builder.append(", ") - builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) - builder.append(" ->") - if (valueArray.isNullAt(i)) { - if (!legacyCastToStr) builder.append(" NULL") - } else { - builder.append(" ") - builder.append(valueToUTF8String(valueArray.get(i, vt)) - .asInstanceOf[UTF8String]) - } - i += 1 - } - } - builder.append(rightBracket) - builder.build() - }) - case StructType(fields) => - buildCast[InternalRow](_, row => { - val builder = new UTF8StringBuilder - builder.append(leftBracket) - if (row.numFields > 0) { - val st = fields.map(_.dataType) - val toUTF8StringFuncs = st.map(castToString) - if (row.isNullAt(0)) { - if (!legacyCastToStr) builder.append("NULL") - } else { - builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) - } - var i = 1 - while (i < row.numFields) { - builder.append(",") - if (row.isNullAt(i)) { - if (!legacyCastToStr) builder.append(" NULL") - } else { - builder.append(" ") - builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) - } - i += 1 - } - } - builder.append(rightBracket) - builder.build() - }) - case pudt: PythonUserDefinedType => castToString(pudt.sqlType) - case udt: UserDefinedType[_] => - buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) - case YearMonthIntervalType(startField, endField) => - buildCast[Int](_, i => UTF8String.fromString( - IntervalUtils.toYearMonthIntervalString(i, ANSI_STYLE, startField, endField))) - case DayTimeIntervalType(startField, endField) => - buildCast[Long](_, i => UTF8String.fromString( - IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, endField))) - // In ANSI mode, Spark always use plain string representation on casting Decimal values - // as strings. 
Otherwise, the casting is using `BigDecimal.toString` which may use scientific - // notation if an exponent is needed. - case _: DecimalType if ansiEnabled => - buildCast[Decimal](_, d => UTF8String.fromString(d.toPlainString)) - case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) - } - // BinaryConverter private[this] def castToBinary(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, _.getBytes) @@ -1342,7 +1230,7 @@ case class Cast( case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;" case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;" - case StringType => castToStringCode(from, ctx) + case StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) case decimal: DecimalType => castToDecimalCode(from, decimal, ctx) @@ -1394,241 +1282,6 @@ case class Cast( """ } - private def appendIfNotLegacyCastToStr(buffer: ExprValue, s: String): Block = { - if (!legacyCastToStr) code"""$buffer.append("$s");""" else EmptyBlock - } - - private def writeArrayToStringBuilder( - et: DataType, - array: ExprValue, - buffer: ExprValue, - ctx: CodegenContext): Block = { - val elementToStringCode = castToStringCode(et, ctx) - val funcName = ctx.freshName("elementToString") - val element = JavaCode.variable("element", et) - val elementStr = JavaCode.variable("elementStr", StringType) - val elementToStringFunc = inline"${ctx.addNewFunction(funcName, - s""" - |private UTF8String $funcName(${CodeGenerator.javaType(et)} $element) { - | UTF8String $elementStr = null; - | ${elementToStringCode(element, elementStr, null /* resultIsNull won't be used */)} - | return elementStr; - |} - """.stripMargin)}" - - val loopIndex = ctx.freshVariable("loopIndex", IntegerType) - code""" - |$buffer.append("["); - |if ($array.numElements() > 0) { - | if ($array.isNullAt(0)) { - | ${appendIfNotLegacyCastToStr(buffer, 
"NULL")} - | } else { - | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")})); - | } - | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { - | $buffer.append(","); - | if ($array.isNullAt($loopIndex)) { - | ${appendIfNotLegacyCastToStr(buffer, " NULL")} - | } else { - | $buffer.append(" "); - | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)})); - | } - | } - |} - |$buffer.append("]"); - """.stripMargin - } - - private def writeMapToStringBuilder( - kt: DataType, - vt: DataType, - map: ExprValue, - buffer: ExprValue, - ctx: CodegenContext): Block = { - - def dataToStringFunc(func: String, dataType: DataType) = { - val funcName = ctx.freshName(func) - val dataToStringCode = castToStringCode(dataType, ctx) - val data = JavaCode.variable("data", dataType) - val dataStr = JavaCode.variable("dataStr", StringType) - val functionCall = ctx.addNewFunction(funcName, - s""" - |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) { - | UTF8String $dataStr = null; - | ${dataToStringCode(data, dataStr, null /* resultIsNull won't be used */)} - | return dataStr; - |} - """.stripMargin) - inline"$functionCall" - } - - val keyToStringFunc = dataToStringFunc("keyToString", kt) - val valueToStringFunc = dataToStringFunc("valueToString", vt) - val loopIndex = ctx.freshVariable("loopIndex", IntegerType) - val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) - val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) - val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType)) - val getMapFirstValue = CodeGenerator.getValue(mapValueArray, vt, - JavaCode.literal("0", IntegerType)) - val getMapKeyArray = CodeGenerator.getValue(mapKeyArray, kt, loopIndex) - val getMapValueArray = CodeGenerator.getValue(mapValueArray, vt, loopIndex) - code""" - |$buffer.append("$leftBracket"); - |if 
($map.numElements() > 0) { - | $buffer.append($keyToStringFunc($getMapFirstKey)); - | $buffer.append(" ->"); - | if ($map.valueArray().isNullAt(0)) { - | ${appendIfNotLegacyCastToStr(buffer, " NULL")} - | } else { - | $buffer.append(" "); - | $buffer.append($valueToStringFunc($getMapFirstValue)); - | } - | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { - | $buffer.append(", "); - | $buffer.append($keyToStringFunc($getMapKeyArray)); - | $buffer.append(" ->"); - | if ($map.valueArray().isNullAt($loopIndex)) { - | ${appendIfNotLegacyCastToStr(buffer, " NULL")} - | } else { - | $buffer.append(" "); - | $buffer.append($valueToStringFunc($getMapValueArray)); - | } - | } - |} - |$buffer.append("$rightBracket"); - """.stripMargin - } - - private def writeStructToStringBuilder( - st: Seq[DataType], - row: ExprValue, - buffer: ExprValue, - ctx: CodegenContext): Block = { - val structToStringCode = st.zipWithIndex.map { case (ft, i) => - val fieldToStringCode = castToStringCode(ft, ctx) - val field = ctx.freshVariable("field", ft) - val fieldStr = ctx.freshVariable("fieldStr", StringType) - val javaType = JavaCode.javaType(ft) - code""" - |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} - |if ($row.isNullAt($i)) { - | ${appendIfNotLegacyCastToStr(buffer, if (i == 0) "NULL" else " NULL")} - |} else { - | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} - | - | // Append $i field into the string buffer - | $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")}; - | UTF8String $fieldStr = null; - | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} - | $buffer.append($fieldStr); - |} - """.stripMargin - } - - val writeStructCode = ctx.splitExpressions( - expressions = structToStringCode.map(_.code), - funcName = "fieldToString", - arguments = ("InternalRow", row.code) :: - (classOf[UTF8StringBuilder].getName, buffer.code) :: Nil) - - code""" - |$buffer.append("$leftBracket"); - 
|$writeStructCode - |$buffer.append("$rightBracket"); - """.stripMargin - } - - @scala.annotation.tailrec - private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { - from match { - case BinaryType => - (c, evPrim, evNull) => code"$evPrim = UTF8String.fromBytes($c);" - case DateType => - val df = JavaCode.global( - ctx.addReferenceObj("dateFormatter", dateFormatter), - dateFormatter.getClass) - (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString(${df}.format($c));""" - case TimestampType => - val tf = JavaCode.global( - ctx.addReferenceObj("timestampFormatter", timestampFormatter), - timestampFormatter.getClass) - (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString($tf.format($c));""" - case TimestampNTZType => - val tf = JavaCode.global( - ctx.addReferenceObj("timestampNTZFormatter", timestampNTZFormatter), - timestampNTZFormatter.getClass) - (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString($tf.format($c));""" - case CalendarIntervalType => - (c, evPrim, _) => code"""$evPrim = UTF8String.fromString($c.toString());""" - case ArrayType(et, _) => - (c, evPrim, evNull) => { - val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) - val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) - val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) - code""" - |$bufferClass $buffer = new $bufferClass(); - |$writeArrayElemCode; - |$evPrim = $buffer.build(); - """.stripMargin - } - case MapType(kt, vt, _) => - (c, evPrim, evNull) => { - val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) - val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) - val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) - code""" - |$bufferClass $buffer = new $bufferClass(); - |$writeMapElemCode; - |$evPrim = $buffer.build(); - """.stripMargin - } - case StructType(fields) => - (c, evPrim, evNull) => { - val row = ctx.freshVariable("row", 
classOf[InternalRow]) - val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) - val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) - val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) - code""" - |InternalRow $row = $c; - |$bufferClass $buffer = new $bufferClass(); - |$writeStructCode - |$evPrim = $buffer.build(); - """.stripMargin - } - case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) - case udt: UserDefinedType[_] => - val udtRef = JavaCode.global(ctx.addReferenceObj("udt", udt), udt.sqlType) - (c, evPrim, evNull) => { - code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" - } - case i: YearMonthIntervalType => - val iu = IntervalUtils.getClass.getName.stripSuffix("$") - val iss = IntervalStringStyles.getClass.getName.stripSuffix("$") - val style = s"$iss$$.MODULE$$.ANSI_STYLE()" - (c, evPrim, _) => - code""" - $evPrim = UTF8String.fromString($iu.toYearMonthIntervalString($c, $style, - (byte)${i.startField}, (byte)${i.endField})); - """ - case i: DayTimeIntervalType => - val iu = IntervalUtils.getClass.getName.stripSuffix("$") - val iss = IntervalStringStyles.getClass.getName.stripSuffix("$") - val style = s"$iss$$.MODULE$$.ANSI_STYLE()" - (c, evPrim, _) => - code""" - $evPrim = UTF8String.fromString($iu.toDayTimeIntervalString($c, $style, - (byte)${i.startField}, (byte)${i.endField})); - """ - // In ANSI mode, Spark always use plain string representation on casting Decimal values - // as strings. Otherwise, the casting is using `BigDecimal.toString` which may use scientific - // notation if an exponent is needed. 
- case _: DecimalType if ansiEnabled => - (c, evPrim, _) => code"$evPrim = UTF8String.fromString($c.toPlainString());" - case _ => - (c, evPrim, evNull) => code"$evPrim = UTF8String.fromString(String.valueOf($c));" - } - } - private[this] def castToBinaryCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala new file mode 100644 index 0000000000000..3b2868286dddf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * An internal expression which is used to generate pretty strings for all kinds of values. It has + * several differences from casting a value to string: + * - It prints null values (either from column or struct field) as "NULL". + * - It prints binary values (either from column or struct field) using the hex format. + */ +case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ToStringBase { + + override def dataType: DataType = StringType + + override def nullable: Boolean = false + + override def withTimeZone(timeZoneId: String): ToPrettyString = + copy(timeZoneId = Some(timeZoneId)) + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + override protected def leftBracket: String = "{" + override protected def rightBracket: String = "}" + + override protected def nullString: String = "NULL" + + override protected def useDecimalPlainString: Boolean = true + + override protected def useHexFormatForBinary: Boolean = true + + private[this] lazy val castFunc: Any => Any = castToString(child.dataType) + + override def eval(input: InternalRow): Any = { + val v = child.eval(input) + if (v == null) UTF8String.fromString("NULL") else castFunc(v) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childCode = child.genCode(ctx) + val toStringCode = castToStringCode(child.dataType, ctx).apply(childCode.value, ev.value) + val finalCode = + code""" + |${childCode.code} + |UTF8String ${ev.value}; + |if
(${childCode.isNull}) { + | ${ev.value} = UTF8String.fromString("NULL"); + |} else { + | $toStringCode + |} + |""".stripMargin + ev.copy(code = finalCode, isNull = FalseLiteral) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala new file mode 100644 index 0000000000000..7304d6739e880 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.time.ZoneOffset + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, IntervalStringStyles, IntervalUtils, MapData, StringUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.UTF8StringBuilder +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => + + private lazy val dateFormatter = DateFormatter() + private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId) + private lazy val timestampNTZFormatter = TimestampFormatter.getFractionFormatter(ZoneOffset.UTC) + + // The brackets that are used in casting structs and maps to strings + protected def leftBracket: String + protected def rightBracket: String + + // The string value to use to represent null elements in array/struct/map. + protected def nullString: String + + protected def useDecimalPlainString: Boolean + + protected def useHexFormatForBinary: Boolean + + // Makes the function accept Any type input by doing `asInstanceOf[T]`. + @inline private def acceptAny[T](func: T => Any): Any => Any = i => func(i.asInstanceOf[T]) + + // Returns a function to convert a value to pretty string. The function assumes input is not null. 
+ protected final def castToString(from: DataType): Any => Any = from match { + case CalendarIntervalType => + acceptAny[CalendarInterval](i => UTF8String.fromString(i.toString)) + case BinaryType if useHexFormatForBinary => + acceptAny[Array[Byte]](binary => UTF8String.fromString(StringUtils.getHexString(binary))) + case BinaryType => + acceptAny[Array[Byte]](UTF8String.fromBytes) + case DateType => + acceptAny[Int](d => UTF8String.fromString(dateFormatter.format(d))) + case TimestampType => + acceptAny[Long](t => UTF8String.fromString(timestampFormatter.format(t))) + case TimestampNTZType => + acceptAny[Long](t => UTF8String.fromString(timestampNTZFormatter.format(t))) + case ArrayType(et, _) => + acceptAny[ArrayData](array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (array.isNullAt(0)) { + if (nullString.nonEmpty) builder.append(nullString) + } else { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (array.isNullAt(i)) { + if (nullString.nonEmpty) builder.append(" " + nullString) + } else { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) + case MapType(kt, vt, _) => + acceptAny[MapData](map => { + val builder = new UTF8StringBuilder + builder.append(leftBracket) + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (valueArray.isNullAt(0)) { + if (nullString.nonEmpty) builder.append(" " + nullString) + } else { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1
+ while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (valueArray.isNullAt(i)) { + if (nullString.nonEmpty) builder.append(" " + nullString) + } else { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append(rightBracket) + builder.build() + }) + case StructType(fields) => + acceptAny[InternalRow](row => { + val builder = new UTF8StringBuilder + builder.append(leftBracket) + if (row.numFields > 0) { + val st = fields.map(_.dataType) + val toUTF8StringFuncs = st.map(castToString) + if (row.isNullAt(0)) { + if (nullString.nonEmpty) builder.append(nullString) + } else { + builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < row.numFields) { + builder.append(",") + if (row.isNullAt(i)) { + if (nullString.nonEmpty) builder.append(" " + nullString) + } else { + builder.append(" ") + builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append(rightBracket) + builder.build() + }) + case pudt: PythonUserDefinedType => castToString(pudt.sqlType) + case udt: UserDefinedType[_] => + o => UTF8String.fromString(udt.deserialize(o).toString) + case YearMonthIntervalType(startField, endField) => + acceptAny[Int](i => UTF8String.fromString( + IntervalUtils.toYearMonthIntervalString(i, ANSI_STYLE, startField, endField))) + case DayTimeIntervalType(startField, endField) => + acceptAny[Long](i => UTF8String.fromString( + IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, endField))) + case _: DecimalType if useDecimalPlainString => + acceptAny[Decimal](d => UTF8String.fromString(d.toPlainString)) + case StringType => identity + case _ => o => UTF8String.fromString(o.toString) + } + + // Returns a function to generate code to convert a value to pretty string.
It assumes the input + // is not null. + @scala.annotation.tailrec + protected final def castToStringCode( + from: DataType, ctx: CodegenContext): (ExprValue, ExprValue) => Block = { + from match { + case BinaryType if useHexFormatForBinary => + (c, evPrim) => + val utilCls = StringUtils.getClass.getName.stripSuffix("$") + code"$evPrim = UTF8String.fromString($utilCls.getHexString($c));" + case BinaryType => + (c, evPrim) => code"$evPrim = UTF8String.fromBytes($c);" + case DateType => + val df = JavaCode.global( + ctx.addReferenceObj("dateFormatter", dateFormatter), + dateFormatter.getClass) + (c, evPrim) => code"$evPrim = UTF8String.fromString($df.format($c));" + case TimestampType => + val tf = JavaCode.global( + ctx.addReferenceObj("timestampFormatter", timestampFormatter), + timestampFormatter.getClass) + (c, evPrim) => code"$evPrim = UTF8String.fromString($tf.format($c));" + case TimestampNTZType => + val tf = JavaCode.global( + ctx.addReferenceObj("timestampNTZFormatter", timestampNTZFormatter), + timestampNTZFormatter.getClass) + (c, evPrim) => code"$evPrim = UTF8String.fromString($tf.format($c));" + case CalendarIntervalType => + (c, evPrim) => code"$evPrim = UTF8String.fromString($c.toString());" + case ArrayType(et, _) => + (c, evPrim) => { + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + code""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } + case MapType(kt, vt, _) => + (c, evPrim) => { + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + code""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + 
} + case StructType(fields) => + (c, evPrim) => { + val row = ctx.freshVariable("row", classOf[InternalRow]) + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) + val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) + code""" + |InternalRow $row = $c; + |$bufferClass $buffer = new $bufferClass(); + |$writeStructCode + |$evPrim = $buffer.build(); + """.stripMargin + } + case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) + case udt: UserDefinedType[_] => + val udtRef = JavaCode.global(ctx.addReferenceObj("udt", udt), udt.sqlType) + (c, evPrim) => + code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + case i: YearMonthIntervalType => + val iu = IntervalUtils.getClass.getName.stripSuffix("$") + val iss = IntervalStringStyles.getClass.getName.stripSuffix("$") + val style = s"$iss$$.MODULE$$.ANSI_STYLE()" + (c, evPrim) => + // scalastyle:off line.size.limit + code"$evPrim = UTF8String.fromString($iu.toYearMonthIntervalString($c, $style, (byte)${i.startField}, (byte)${i.endField}));" + // scalastyle:on line.size.limit + case i: DayTimeIntervalType => + val iu = IntervalUtils.getClass.getName.stripSuffix("$") + val iss = IntervalStringStyles.getClass.getName.stripSuffix("$") + val style = s"$iss$$.MODULE$$.ANSI_STYLE()" + (c, evPrim) => + // scalastyle:off line.size.limit + code"$evPrim = UTF8String.fromString($iu.toDayTimeIntervalString($c, $style, (byte)${i.startField}, (byte)${i.endField}));" + // scalastyle:on line.size.limit + // In ANSI mode, Spark always use plain string representation on casting Decimal values + // as strings. Otherwise, the casting is using `BigDecimal.toString` which may use scientific + // notation if an exponent is needed. 
+ case _: DecimalType if useDecimalPlainString => + (c, evPrim) => code"$evPrim = UTF8String.fromString($c.toPlainString());" + case StringType => + (c, evPrim) => code"$evPrim = $c;" + case _ => + (c, evPrim) => code"$evPrim = UTF8String.fromString(String.valueOf($c));" + } + } + + private def appendNull(buffer: ExprValue, isFirstElement: Boolean): Block = { + if (nullString.isEmpty) { + EmptyBlock + } else if (isFirstElement) { + code"""$buffer.append("$nullString");""" + } else { + code"""$buffer.append(" $nullString");""" + } + } + + private def writeArrayToStringBuilder( + et: DataType, + array: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val element = JavaCode.variable("element", et) + val elementStr = JavaCode.variable("elementStr", StringType) + val elementToStringFunc = inline"${ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${CodeGenerator.javaType(et)} $element) { + | UTF8String $elementStr = null; + | ${elementToStringCode(element, elementStr)} + | return elementStr; + |} + """.stripMargin)}" + + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) + code""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if ($array.isNullAt(0)) { + | ${appendNull(buffer, isFirstElement = true)} + | } else { + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if ($array.isNullAt($loopIndex)) { + | ${appendNull(buffer, isFirstElement = false)} + | } else { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: ExprValue, + buffer: ExprValue, + ctx: 
CodegenContext): Block = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + val data = JavaCode.variable("data", dataType) + val dataStr = JavaCode.variable("dataStr", StringType) + val functionCall = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) { + | UTF8String $dataStr = null; + | ${dataToStringCode(data, dataStr)} + | return dataStr; + |} + """.stripMargin) + inline"$functionCall" + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) + val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) + val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) + val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType)) + val getMapFirstValue = CodeGenerator.getValue(mapValueArray, vt, + JavaCode.literal("0", IntegerType)) + val getMapKeyArray = CodeGenerator.getValue(mapKeyArray, kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(mapValueArray, vt, loopIndex) + code""" + |$buffer.append("$leftBracket"); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc($getMapFirstKey)); + | $buffer.append(" ->"); + | if ($map.valueArray().isNullAt(0)) { + | ${appendNull(buffer, isFirstElement = true)} + | } else { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc($getMapFirstValue)); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc($getMapKeyArray)); + | $buffer.append(" ->"); + | if ($map.valueArray().isNullAt($loopIndex)) { + | ${appendNull(buffer, isFirstElement = false)} + | } else { + | $buffer.append(" "); + | 
$buffer.append($valueToStringFunc($getMapValueArray)); + | } + | } + |} + |$buffer.append("$rightBracket"); + """.stripMargin + } + + private def writeStructToStringBuilder( + st: Seq[DataType], + row: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { + val structToStringCode = st.zipWithIndex.map { case (ft, i) => + val fieldToStringCode = castToStringCode(ft, ctx) + val field = ctx.freshVariable("field", ft) + val fieldStr = ctx.freshVariable("fieldStr", StringType) + val javaType = JavaCode.javaType(ft) + code""" + |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} + |if ($row.isNullAt($i)) { + | ${appendNull(buffer, isFirstElement = i == 0)} + |} else { + | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} + | + | // Append $i field into the string buffer + | $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")}; + | UTF8String $fieldStr = null; + | ${fieldToStringCode(field, fieldStr)} + | $buffer.append($fieldStr); + |} + """.stripMargin + } + + val writeStructCode = ctx.splitExpressions( + expressions = structToStringCode.map(_.code), + funcName = "fieldToString", + arguments = ("InternalRow", row.code) :: + (classOf[UTF8StringBuilder].getName, buffer.code) :: Nil) + + code""" + |$buffer.append("$leftBracket"); + |$writeStructCode + |$buffer.append("$rightBracket"); + """.stripMargin + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index ad3cb449b5364..71936b71b53d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -67,6 +67,8 @@ object StringUtils extends Logging { "(?s)" + out.result() // (?s) enables dotall mode, causing "." 
to match new lines } + def getHexString(bytes: Array[Byte]): String = bytes.map("%02X".format(_)).mkString("[", " ", "]") + private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 504fb8648b6e4..bad85ca4176b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -784,7 +784,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { Seq(false, true).foreach { omitNull => withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) { val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) - checkEvaluation(ret3, s"[ab,${if (omitNull) "" else " NULL"}, c]") + checkEvaluation(ret3, s"[ab,${if (omitNull) "" else " null"}, c]") } } val ret4 = @@ -813,7 +813,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val ret1 = cast(Literal.create(Array(null, null)), StringType) checkEvaluation( ret1, - s"[${if (omitNull) "" else "NULL"},${if (omitNull) "" else " NULL"}]") + s"[${if (omitNull) "" else "null"},${if (omitNull) "" else " null"}]") } } } @@ -828,7 +828,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val ret2 = cast( Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), StringType) - checkEvaluation(ret2, s"${lb}1 -> a, 2 ->${if (legacyCast) "" else " NULL"}, 3 -> c$rb") + checkEvaluation(ret2, s"${lb}1 -> a, 2 ->${if (legacyCast) "" else " null"}, 3 -> c$rb") val ret3 = cast( Literal.create(Map( 1 -> Date.valueOf("2014-12-03"), @@ -860,7 +860,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val ret1 = 
cast(Literal.create((1, "a", 0.1)), StringType) checkEvaluation(ret1, s"${lb}1, a, 0.1$rb") val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType) - checkEvaluation(ret2, s"${lb}1,${if (legacyCast) "" else " NULL"}, a$rb") + checkEvaluation(ret2, s"${lb}1,${if (legacyCast) "" else " null"}, a$rb") val ret3 = cast(Literal.create( (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType) checkEvaluation(ret3, s"${lb}2014-12-03, 2014-12-03 15:05:00$rb") @@ -882,7 +882,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val ret1 = cast(Literal.create(Tuple2[String, String](null, null)), StringType) checkEvaluation( ret1, - s"$lb${if (legacyCast) "" else "NULL"},${if (legacyCast) "" else " NULL"}$rb") + s"$lb${if (legacyCast) "" else "null"},${if (legacyCast) "" else " null"}$rb") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d33a36a8380fe..1445af65f9f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -280,13 +280,7 @@ class Dataset[T] private[sql]( case _ => toDF() } val castCols = newDf.logicalPlan.output.map { col => - // Since binary types in top-level schema fields have a specific format to print, - // so we do not cast them to strings here. - if (col.dataType == BinaryType) { - Column(col) - } else { - Column(col).cast(StringType) - } + Column(ToPrettyString(col)) } val data = newDf.select(castCols: _*).take(numRows + 1) @@ -295,13 +289,8 @@ class Dataset[T] private[sql]( // first `truncate-3` and "..." 
schema.fieldNames.map(SchemaUtils.escapeMetaCharacters).toSeq +: data.map { row => row.toSeq.map { cell => - val str = cell match { - case null => "NULL" - case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") - case _ => - // Escapes meta-characters not to break the `showString` format - SchemaUtils.escapeMetaCharacters(cell.toString) - } + // Escapes meta-characters not to break the `showString` format + val str = SchemaUtils.escapeMetaCharacters(cell.toString) if (truncate > 0 && str.length > truncate) { // do not show ellipses for strings shorter than 4 characters. if (truncate < 4) str.substring(0, truncate) From 8e29c4334c6e4ddd0e4ff170f947c7ef43239f89 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 26 Apr 2023 20:43:22 +0800 Subject: [PATCH 2/2] address comments --- .../spark/sql/catalyst/expressions/ToPrettyString.scala | 4 ++-- .../org/apache/spark/sql/catalyst/util/StringUtils.scala | 4 ++++ sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 1 + 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala index 3b2868286dddf..aea704d4b7888 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala @@ -55,7 +55,7 @@ case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None) override def eval(input: InternalRow): Any = { val v = child.eval(input) - if (v == null) UTF8String.fromString("NULL") else castFunc(v) + if (v == null) UTF8String.fromString(nullString) else castFunc(v) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -66,7 +66,7 @@ case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None) |${childCode.code} 
|UTF8String ${ev.value}; |if (${childCode.isNull}) { - | ${ev.value} = UTF8String.fromString("NULL"); + | ${ev.value} = UTF8String.fromString("$nullString"); |} else { | $toStringCode |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 71936b71b53d9..8a05616cac779 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -67,6 +67,10 @@ object StringUtils extends Logging { "(?s)" + out.result() // (?s) enables dotall mode, causing "." to match new lines } + /** + * Returns a pretty string of the byte array which prints each byte as a hex digit and add spaces + * between them. For example, [1A C0]. + */ def getHexString(bytes: Array[Byte]): String = bytes.map("%02X".format(_)).mkString("[", " ", "]") private[this] val trueStrings = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1445af65f9f1a..7973ba38b1a79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -289,6 +289,7 @@ class Dataset[T] private[sql]( // first `truncate-3` and "..." schema.fieldNames.map(SchemaUtils.escapeMetaCharacters).toSeq +: data.map { row => row.toSeq.map { cell => + assert(cell != null, "ToPrettyString is not nullable and should not return null value") // Escapes meta-characters not to break the `showString` format val str = SchemaUtils.escapeMetaCharacters(cell.toString) if (truncate > 0 && str.length > truncate) {