10 changes: 5 additions & 5 deletions python/pyspark/sql/dataframe.py
@@ -456,7 +456,7 @@ def join(self, other, joinExprs=None, joinType=None):
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.

>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
[Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
"""

if joinExprs is None:
@@ -637,9 +637,9 @@ def groupBy(self, *cols):
>>> df.groupBy().avg().collect()
[Row(AVG(age)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
[Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(df.name).avg().collect()
[Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
@@ -867,11 +867,11 @@ def agg(self, *exprs):

>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
[Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
[Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]

>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
[Row(MIN(age)=5), Row(MIN(age)=2)]
[Row(MIN(age)=2), Row(MIN(age)=5)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
3 changes: 2 additions & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.{StructType, DateUtils}
import org.apache.spark.sql.types.StructType

object Row {
/**
@@ -257,6 +257,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
// TODO(davies): This is not the right default implementation, we use Int as Date internally
def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
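The TODO above points at a real mismatch: this PR stores dates internally as days since the epoch (an Int), so the blind asInstanceOf will throw for internally produced rows. A hedged sketch of a tolerant implementation, assuming the DateUtils helpers used elsewhere in this diff:

```scala
// Sketch only: accept both the internal Int encoding and an external Date.
def getDate(i: Int): java.sql.Date = apply(i) match {
  case days: Int        => DateUtils.toJavaDate(days) // internal representation
  case d: java.sql.Date => d                          // already a java.sql.Date
  case null             => null
}
```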

@@ -56,11 +56,12 @@ trait ScalaReflection {
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
case (s: Array[_], arrayType: ArrayType) => if (arrayType.elementType.isPrimitive) {
s.toSeq
} else {
s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
}
case (s: Array[_], arrayType: ArrayType) =>
if (arrayType.elementType.isPrimitive) {
s.toSeq
} else {
s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
}
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
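The restructured Array branch above only reindents the original logic, but the behavior it preserves is worth spelling out: primitive elements are already in catalyst form, so only non-primitive arrays pay the per-element conversion. A standalone sketch of that rule, with a hypothetical per-element converter:

```scala
// Minimal sketch of the Array branch: primitives pass through, everything
// else is converted element by element.
def convertArray(arr: Array[_], elementIsPrimitive: Boolean)(convert: Any => Any): Seq[Any] =
  if (elementIsPrimitive) arr.toSeq
  else arr.toSeq.map(convert)

convertArray(Array(1, 2, 3), elementIsPrimitive = true)(identity)
// => Seq(1, 2, 3), untouched
convertArray(Array("a", "b"), elementIsPrimitive = false)(s => s.toString.toUpperCase)
// => Seq("A", "B"), converted per element
```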
@@ -72,6 +73,7 @@
case (d: BigDecimal, _) => Decimal(d)
case (d: java.math.BigDecimal, _) => Decimal(d)
case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
case (s: String, _) => UTF8String(s)
case (r: Row, structType: StructType) =>
new GenericRow(
r.toSeq.zip(structType.fields).map { case (elem, field) =>
@@ -80,6 +82,24 @@
case (other, _) => other
}

/**
* Converts Scala objects to catalyst rows / types.
* Note: This should be called before doing evaluation on a Row
* (it does not support UDTs)
*/
def convertToCatalyst(a: Any): Any = a match {
(inline review) Contributor: was this function somewhere before?
(inline review) Contributor Author: No, we have a similar one: convertToCatalyst(a: Any, dt: DataType)

case s: String => UTF8String(s)
case d: java.sql.Date => DateUtils.fromJavaDate(d)
case d: BigDecimal => Decimal(d)
case d: java.math.BigDecimal => Decimal(d)
case seq: Seq[Any] => seq.map(convertToCatalyst)
case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
case m: Map[Any, Any] =>
m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
case other => other
}
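A usage sketch of the new overload; the expected results follow from the cases above, and UTF8String, Decimal, and the Int date encoding are internals of this PR:

```scala
val external = Row("alice", java.sql.Date.valueOf("2015-04-01"), BigDecimal("3.14"))
val internal = convertToCatalyst(external)
// => Row(UTF8String("alice"), DateUtils.fromJavaDate(<2015-04-01>): Int, Decimal(3.14))
```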

/** Converts Catalyst types used internally in rows to standard Scala types */
def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
// Check UDT first since UDTs can override other types
Expand All @@ -91,6 +111,7 @@ trait ScalaReflection {
case (r: Row, s: StructType) => convertRowToScala(r, s)
case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal
case (i: Int, DateType) => DateUtils.toJavaDate(i)
case (s: UTF8String, StringType) => s.toString()
case (other, _) => other
}
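The two added cases here are the inverse of convertToCatalyst: internal values come back out as plain Scala types. A sketch (the day count is illustrative, not computed):

```scala
convertToScala(UTF8String("alice"), StringType) // => "alice": java.lang.String
convertToScala(16526, DateType)                 // => java.sql.Date via DateUtils.toJavaDate
```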

@@ -193,6 +214,7 @@ trait ScalaReflection {
// The data type can be determined without ambiguity.
case obj: BooleanType.JvmType => BooleanType
case obj: BinaryType.JvmType => BinaryType
case obj: String => StringType
case obj: StringType.JvmType => StringType
case obj: ByteType.JvmType => ByteType
case obj: ShortType.JvmType => ShortType
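The extra case obj: String => StringType is needed because this PR presumably repoints StringType.JvmType at UTF8String, so a raw Scala String would otherwise fall through typeOfObject. A toy version of the same pattern:

```scala
// Both the external and the (assumed) internal representation map to StringType.
def dataTypeOfValue(v: Any): DataType = v match {
  case _: String     => StringType // external java.lang.String
  case _: UTF8String => StringType // internal value, i.e. StringType.JvmType
  case _: Boolean    => BooleanType
  case other         => sys.error(s"unhandled: ${other.getClass}")
}
```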
@@ -115,7 +115,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
val stringNaN = Literal.create("NaN", StringType)
val stringNaN = Literal("NaN")

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
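The ConvertNaNs change from Literal.create("NaN", StringType) to Literal("NaN") leans on type inference; presumably Literal.apply now also routes the String through the catalyst conversion, so the literal's stored value becomes a UTF8String. An equivalence sketch under that assumption:

```scala
val inferred = Literal("NaN")                    // StringType inferred from the Scala type
val explicit = Literal.create("NaN", StringType) // type spelled out
// After this PR, both are expected to hold UTF8String("NaN") internally.
```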
@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.types._

/** Cast the child expression to the target data type. */
@@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w

// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
case DateType => buildCast[Int](_, d => DateUtils.toString(d))
case TimestampType => buildCast[Timestamp](_, timestampToString)
case _ => buildCast[Any](_, _.toString)
case BinaryType => buildCast[Array[Byte]](_, UTF8String(_))
case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d)))
case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t)))
case _ => buildCast[Any](_, o => UTF8String(o.toString))
}
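Every branch above now produces a UTF8String rather than a java.lang.String. The wrapper's API is not shown in this diff; inferred from the call sites in this PR, it looks roughly like the following sketch (an assumption, not the actual implementation):

```scala
// Sketch of the assumed UTF8String surface, inferred from call sites only.
final class UTF8String(private var bytes: Array[Byte]) extends Serializable {
  def set(newBytes: Array[Byte]): Unit = { bytes = newBytes } // reuse; see MutableString below
  def getBytes: Array[Byte] = bytes
  def length(): Int = toString.length // the real type likely counts over the raw UTF-8 bytes
  override def toString: String = new String(bytes, "UTF-8")
  override def clone(): UTF8String = new UTF8String(bytes.clone())
}

object UTF8String {
  def apply(s: String): UTF8String = new UTF8String(s.getBytes("UTF-8"))
  def apply(bytes: Array[Byte]): UTF8String = new UTF8String(bytes)
}
```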

// BinaryConverter
private[this] def castToBinary(from: DataType): Any => Any = from match {
case StringType => buildCast[String](_, _.getBytes("UTF-8"))
case StringType => buildCast[UTF8String](_, _.getBytes)
}

// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, _.length() != 0)
buildCast[UTF8String](_, _.length() != 0)
case TimestampType =>
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
case DateType =>
@@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// TimestampConverter
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => {
buildCast[UTF8String](_, utfs => {
// Throw away extra if more than 9 decimal places
val s = utfs.toString
val periodIdx = s.indexOf(".")
var n = s
if (periodIdx != -1 && n.length() - periodIdx > 9) {
@@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DateConverter
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s =>
try DateUtils.fromJavaDate(Date.valueOf(s))
buildCast[UTF8String](_, s =>
try DateUtils.fromJavaDate(Date.valueOf(s.toString))
catch { case _: java.lang.IllegalArgumentException => null }
)
case TimestampType =>
@@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toLong catch {
buildCast[UTF8String](_, s => try s.toString.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toInt catch {
buildCast[UTF8String](_, s => try s.toString.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toShort catch {
buildCast[UTF8String](_, s => try s.toString.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toByte catch {
buildCast[UTF8String](_, s => try s.toString.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w

private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
buildCast[UTF8String](_, s => try {
changePrecision(Decimal(s.toString.toDouble), target)
} catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toDouble catch {
buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
buildCast[String](_, s => try s.toFloat catch {
buildCast[UTF8String](_, s => try s.toString.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
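All of the numeric converters in this file change the same way: the match binds a UTF8String, goes through toString, and keeps the existing null-on-bad-input behavior instead of throwing. A standalone sketch of the shared pattern:

```scala
// Shared shape of the string-to-numeric casts: never throw, return null.
def castStringToLong(s: UTF8String): Any =
  try s.toString.toLong catch { case _: NumberFormatException => null }

castStringToLong(UTF8String("42"))   // => 42
castStringToLong(UTF8String("4.2x")) // => null
```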
@@ -170,6 +170,25 @@ final class MutableByte extends MutableValue {
}
}

final class MutableString extends MutableValue {
var value: UTF8String = _
override def boxed: Any = if (isNull) null else value
override def update(v: Any): Unit = {
isNull = false
if (value == null) {
value = v.asInstanceOf[UTF8String]
} else {
value.set(v.asInstanceOf[UTF8String].getBytes)
}
}
override def copy(): MutableString = {
val newCopy = new MutableString
newCopy.isNull = isNull
newCopy.value = value.clone()
newCopy.asInstanceOf[MutableString]
}
}
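The update path above is written to reuse the existing wrapper: the first update stores the reference, and later updates copy bytes into it via set. A usage sketch (note that copy() assumes value has been set at least once):

```scala
val cell = new MutableString
cell.update(UTF8String("a"))  // first call: store the reference
cell.update(UTF8String("bb")) // later calls: reuse the wrapper via set()
val snapshot = cell.copy()    // deep-copies the current bytes
```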

final class MutableAny extends MutableValue {
var value: Any = _
override def boxed: Any = if (isNull) null else value
@@ -202,6 +221,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
case DoubleType => new MutableDouble
case BooleanType => new MutableBoolean
case LongType => new MutableLong
// TODO(davies): enable this
// case StringType => new MutableString
case _ => new MutableAny
}.toArray)

@@ -230,13 +251,17 @@
new GenericRow(newValues)
}

override def update(ordinal: Int, value: Any): Unit = {
if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
override def update(ordinal: Int, value: Any): Unit = value match {
case null => setNullAt(ordinal)
case s: String =>
// for tests
throw new Exception("String should be converted into UTF8String")
case other => values(ordinal).update(value)
}

override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))

override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
override def getString(ordinal: Int): String = apply(ordinal).toString
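Combined with the update guard above, these two overrides keep plain String at the row's public boundary while UTF8String lives inside. A hedged usage sketch (the Seq[DataType] constructor is assumed from the factory code shown earlier):

```scala
val row = new SpecificMutableRow(Seq(StringType))
row.setString(0, "alice") // wrapped into UTF8String on the way in
row.getString(0)          // => "alice", unwrapped via toString on the way out
row.update(0, "alice")    // throws: raw Strings must not reach update
```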

override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
@@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val $primitiveTerm: ${termForType(dataType)} = $value
""".children

case expressions.Literal(value: String, dataType) =>
case expressions.Literal(value: UTF8String, dataType) =>
q"""
val $nullTerm = ${value == null}
val $primitiveTerm: ${termForType(dataType)} = $value
val $primitiveTerm: ${termForType(dataType)} =
org.apache.spark.sql.types.UTF8String(${value.toString})
""".children

case expressions.Literal(value: Int, dataType) =>
@@ -243,11 +244,14 @@
if($nullTerm)
${defaultPrimitive(StringType)}
else
new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
""".children

case Cast(child @ DateType(), StringType) =>
child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType)
child.castOrNull(c =>
q"""org.apache.spark.sql.types.UTF8String(
org.apache.spark.sql.types.DateUtils.toString($c))""",
StringType)

case Cast(child @ NumericType(), IntegerType) =>
child.castOrNull(c => q"$c.toInt", IntegerType)
@@ -272,7 +276,7 @@
if($nullTerm)
${defaultPrimitive(StringType)}
else
${eval.primitiveTerm}.toString
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
""".children

case EqualTo(e1, e2) =>
@@ -573,7 +577,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val localLogger = log
val localLoggerTree = reify { localLogger }
q"""
$localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm))
$localLoggerTree.debug(
${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
""" :: Nil
} else {
Nil
@@ -584,6 +589,7 @@

protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
dataType match {
case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
}
@@ -595,6 +601,7 @@
ordinal: Int,
value: TermName) = {
dataType match {
case StringType => q"$destinationRow.update($ordinal, $value)"
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => q"$destinationRow.update($ordinal, $value)"
}
@@ -618,13 +625,13 @@
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
case StringType => "String"
case StringType => "org.apache.spark.sql.types.UTF8String"
}

protected def defaultPrimitive(dt: DataType) = dt match {
case BooleanType => ru.Literal(Constant(false))
case FloatType => ru.Literal(Constant(-1.0.toFloat))
case StringType => ru.Literal(Constant("<uninit>"))
case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
case ShortType => ru.Literal(Constant(-1.toShort))
case LongType => ru.Literal(Constant(-1L))
case ByteType => ru.Literal(Constant(-1.toByte))
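The codegen changes all enforce one rule: generated code never sees java.lang.String, only the fully-qualified UTF8String. Hand-expanding the quasiquotes, a string literal and a string default now come out roughly as the following sketch (an approximation of the generated fragments, not their exact output):

```scala
// Approximate hand-expansion of the generated fragments above.
val literalTerm: org.apache.spark.sql.types.UTF8String =
  org.apache.spark.sql.types.UTF8String("NaN")      // expressions.Literal(value: UTF8String, _)
val defaultTerm: org.apache.spark.sql.types.UTF8String =
  org.apache.spark.sql.types.UTF8String("<uninit>") // defaultPrimitive(StringType)
```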