diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 6de40629ff27e..1a14a7a449342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -392,12 +392,13 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput /** * Decodes the first argument into a String using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. (As of Hive 0.12.0.). + * If either argument is null, the result will also be null. */ -case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes { - override def children: Seq[Expression] = bin :: charset :: Nil - override def foldable: Boolean = bin.foldable && charset.foldable - override def nullable: Boolean = bin.nullable || charset.nullable +case class Decode(bin: Expression, charset: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = bin + override def right: Expression = charset override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType) @@ -420,13 +421,13 @@ case class Decode(bin: Expression, charset: Expression) extends Expression with /** * Encodes the first argument into a BINARY using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. (As of Hive 0.12.0.) + * If either argument is null, the result will also be null. */ case class Encode(value: Expression, charset: Expression) - extends Expression with ExpectsInputTypes { - override def children: Seq[Expression] = value :: charset :: Nil - override def foldable: Boolean = value.foldable && charset.foldable - override def nullable: Boolean = value.nullable || charset.nullable + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = value + override def right: Expression = charset override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType, StringType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index abcfc0b65020c..f80291776f335 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1666,18 +1666,19 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr) + def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) /** * Computes the first argument into a binary from a string using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. + * NOTE: charset represents the string value of the character set, not the column name. * * @group string_funcs * @since 1.5.0 */ - def encode(columnName: String, charsetColumnName: String): Column = - encode(Column(columnName), Column(charsetColumnName)) + def encode(columnName: String, charset: String): Column = + encode(Column(columnName), charset) /** * Computes the first argument into a string from a binary using the provided character set @@ -1687,18 +1688,19 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr) + def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) /** * Computes the first argument into a string from a binary using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. + * NOTE: charset represents the string value of the character set, not the column name. * * @group string_funcs * @since 1.5.0 */ - def decode(columnName: String, charsetColumnName: String): Column = - decode(Column(columnName), Column(charsetColumnName)) + def decode(columnName: String, charset: String): Column = + decode(Column(columnName), charset) ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bc455a922d154..afba28515e032 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -261,11 +261,15 @@ class DataFrameFunctionsSuite extends QueryTest { // non ascii characters are not allowed in the code, so we disable the scalastyle here. val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") checkAnswer( - df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")), + df.select( + encode($"a", "utf-8"), + encode("a", "utf-8"), + decode($"c", "utf-8"), + decode("c", "utf-8")), Row(bytes, bytes, "大千世界", "大千世界")) checkAnswer( - df.selectExpr("encode(a, b)", "decode(c, b)"), + df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), Row(bytes, "大千世界")) // scalastyle:on }