Skip to content
Permalink
Browse files

[SPARK-23054][SQL][PYSPARK][FOLLOWUP] Use sqlType casting when castin…

…g PythonUserDefinedType to String.

## What changes were proposed in this pull request?

This is a follow-up of #20246.

If a UDT in Python doesn't have its corresponding Scala UDT, cast to string will be the raw string of the internal value, e.g. `"org.apache.spark.sql.catalyst.expressions.UnsafeArrayDataxxxxxxxx"` if the internal type is `ArrayType`.

This pr fixes it by using its `sqlType` casting.

## How was this patch tested?

Added a test and existing tests.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #20306 from ueshin/issues/SPARK-23054/fup1.

(cherry picked from commit 568055d)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information...
ueshin authored and cloud-fan committed Jan 19, 2018
1 parent 225b1af commit 541dbc00b24f17d83ea2531970f2e9fe57fe3718
@@ -1189,6 +1189,17 @@ def test_union_with_udt(self):
]
)

def test_cast_to_string_with_udt(self):
from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
from pyspark.sql.functions import col
row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
schema = StructType([StructField("point", ExamplePointUDT(), False),
StructField("pypoint", PythonOnlyUDT(), False)])
df = self.spark.createDataFrame([row], schema)

result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))

def test_column_operators(self):
ci = self.df.key
cs = self.df.value
@@ -282,6 +282,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
builder.append("]")
builder.build()
})
case pudt: PythonUserDefinedType => castToString(pudt.sqlType)
case udt: UserDefinedType[_] =>
buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString))
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
@@ -838,6 +839,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|$evPrim = $buffer.build();
""".stripMargin
}
case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx)
case udt: UserDefinedType[_] =>
val udtRef = ctx.addReferenceObj("udt", udt)
(c, evPrim, evNull) => {
@@ -34,6 +34,8 @@ private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializab
case that: ExamplePoint => this.x == that.x && this.y == that.y
case _ => false
}

override def toString(): String = s"($x, $y)"
}

/**

0 comments on commit 541dbc0

Please sign in to comment.
You can’t perform that action at this time.