Address comments.
yhuai committed Jul 29, 2014
1 parent 991f860 commit bd40a33
Showing 28 changed files with 154 additions and 78 deletions.
88 changes: 52 additions & 36 deletions python/pyspark/sql.py
@@ -41,18 +41,18 @@ class StringType(object):
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "StringType"

class BinaryType(object):
"""Spark SQL BinaryType
The data type representing bytes values and bytearray values.
The data type representing bytearray values.
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "BinaryType"

class BooleanType(object):
@@ -63,14 +63,18 @@ class BooleanType(object):
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "BooleanType"

class TimestampType(object):
"""Spark SQL TimestampType"""
"""Spark SQL TimestampType
The data type representing datetime.datetime values.
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "TimestampType"

class DecimalType(object):
@@ -81,40 +85,48 @@ class DecimalType(object):
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "DecimalType"

class DoubleType(object):
"""Spark SQL DoubleType
The data type representing float values. Because a float value
The data type representing float values.
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "DoubleType"

class FloatType(object):
"""Spark SQL FloatType
For PySpark, please use L{DoubleType} instead of using L{FloatType}.
For now, please use L{DoubleType} instead of using L{FloatType}.
Because query evaluation is done in Scala, java.lang.Double will be used
for Python float numbers. Because the underlying JVM type of FloatType is
java.lang.Float (in Java) and Float (in Scala), there will be a java.lang.ClassCastException
if FloatType (Python) is used.
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "FloatType"

class ByteType(object):
"""Spark SQL ByteType
For PySpark, please use L{IntegerType} instead of using L{ByteType}.
For now, please use L{IntegerType} instead of using L{ByteType}.
Because query evaluation is done in Scala, java.lang.Integer will be used
for Python int numbers. Because the underlying JVM type of ByteType is
java.lang.Byte (in Java) and Byte (in Scala), there will be a java.lang.ClassCastException
if ByteType (Python) is used.
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "ByteType"

class IntegerType(object):
@@ -125,7 +137,7 @@ class IntegerType(object):
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "IntegerType"

class LongType(object):
@@ -137,18 +149,22 @@ class LongType(object):
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "LongType"

class ShortType(object):
"""Spark SQL ShortType
For PySpark, please use L{IntegerType} instead of using L{ShortType}.
For now, please use L{IntegerType} instead of using L{ShortType}.
Because query evaluation is done in Scala, java.lang.Integer will be used
for Python int numbers. Because the underlying JVM type of ShortType is
java.lang.Short (in Java) and Short (in Scala), there will be a java.lang.ClassCastException
if ShortType (Python) is used.
"""
__metaclass__ = PrimitiveTypeSingleton

def _get_scala_type_string(self):
def __repr__(self):
return "ShortType"

class ArrayType(object):
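
Taken together, the hunks above rename each primitive type's _get_scala_type_string method to __repr__, so repr() on a type instance now yields exactly the string that the Scala side parses. A minimal sketch of the effect (not part of the diff; it assumes the types are importable from pyspark.sql, the module this file defines):

    from pyspark.sql import StringType, IntegerType, TimestampType

    # Each primitive type is a singleton; its repr() is its Scala type name.
    print(repr(StringType()))     # StringType
    print(repr(IntegerType()))    # IntegerType
    print(repr(TimestampType()))  # TimestampType
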
@@ -157,23 +173,23 @@ class ArrayType(object):
The data type representing list values.
"""
def __init__(self, elementType, containsNull):
def __init__(self, elementType, containsNull=False):
"""Creates an ArrayType
:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains null values.
:return:
>>> ArrayType(StringType, True) == ArrayType(StringType, False)
False
>>> ArrayType(StringType, True) == ArrayType(StringType, True)
>>> ArrayType(StringType) == ArrayType(StringType, False)
True
>>> ArrayType(StringType, True) == ArrayType(StringType)
False
"""
self.elementType = elementType
self.containsNull = containsNull

def _get_scala_type_string(self):
return "ArrayType(" + self.elementType._get_scala_type_string() + "," + \
def __repr__(self):
return "ArrayType(" + self.elementType.__repr__() + "," + \
str(self.containsNull).lower() + ")"

def __eq__(self, other):
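
Besides the __repr__ rename, the ArrayType constructor above now defaults containsNull to False. A short sketch of the default and the composed string form (outside the diff, using type instances rather than the bare classes the doctest passes):

    a = ArrayType(StringType())                  # containsNull defaults to False
    print(repr(a))                               # ArrayType(StringType,false)
    print(a == ArrayType(StringType(), True))    # False: containsNull differs
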
@@ -207,9 +223,9 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
self.valueType = valueType
self.valueContainsNull = valueContainsNull

def _get_scala_type_string(self):
return "MapType(" + self.keyType._get_scala_type_string() + "," + \
self.valueType._get_scala_type_string() + "," + \
def __repr__(self):
return "MapType(" + self.keyType.__repr__() + "," + \
self.valueType.__repr__() + "," + \
str(self.valueContainsNull).lower() + ")"

def __eq__(self, other):
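
MapType follows the same pattern: its __repr__ stitches the key type, the value type, and the valueContainsNull flag into one parseable string. An illustrative sketch (not in the diff):

    m = MapType(StringType(), IntegerType())     # valueContainsNull defaults to True
    print(repr(m))                               # MapType(StringType,IntegerType,true)
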
@@ -243,9 +259,9 @@ def __init__(self, name, dataType, nullable):
self.dataType = dataType
self.nullable = nullable

def _get_scala_type_string(self):
def __repr__(self):
return "StructField(" + self.name + "," + \
self.dataType._get_scala_type_string() + "," + \
self.dataType.__repr__() + "," + \
str(self.nullable).lower() + ")"

def __eq__(self, other):
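
StructField and StructType compose the same way, which is what lets a whole schema travel to the JVM as a single string. A sketch of the composed form (illustrative field names only):

    person = StructType([
        StructField("name", StringType(), False),
        StructField("age", IntegerType(), True)])
    print(repr(person))
    # StructType(List(StructField(name,StringType,false),StructField(age,IntegerType,true)))
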
@@ -280,9 +296,9 @@ def __init__(self, fields):
"""
self.fields = fields

def _get_scala_type_string(self):
def __repr__(self):
return "StructType(List(" + \
",".join([field._get_scala_type_string() for field in self.fields]) + "))"
",".join([field.__repr__() for field in self.fields]) + "))"

def __eq__(self, other):
return (isinstance(other, self.__class__) and \
@@ -319,7 +335,7 @@ def _parse_datatype_string(datatype_string):
:return:
>>> def check_datatype(datatype):
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype._get_scala_type_string())
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__())
... python_datatype = _parse_datatype_string(scala_datatype.toString())
... return datatype == python_datatype
>>> check_datatype(StringType())
@@ -536,7 +552,7 @@ def applySchema(self, rdd, schema):
True
"""
jrdd = self._pythonToJavaMap(rdd._jrdd)
srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema._get_scala_type_string())
srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema.__repr__())
return SchemaRDD(srdd, self)

def registerRDDAsTable(self, rdd, tableName):
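
With that string form in place, applySchema above simply passes schema.__repr__() to the JVM, which parses it back into a Scala DataType before applying it to the rows. A hedged usage sketch (sc and sqlCtx are the SparkContext and SQLContext names assumed from PySpark's doctests, not defined in this hunk; an RDD of dicts matches the _pythonToJavaMap conversion shown above):

    schema = StructType([StructField("field1", IntegerType(), False),
                         StructField("field2", StringType(), False)])
    rdd = sc.parallelize([{"field1": 1, "field2": "row1"}])
    srdd = sqlCtx.applySchema(rdd, schema)
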
@@ -569,7 +585,7 @@ def parquetFile(self, path):
jschema_rdd = self._ssql_ctx.parquetFile(path)
return SchemaRDD(jschema_rdd, self)

def jsonFile(self, path, schema = None):
def jsonFile(self, path, schema=None):
"""Loads a text file storing one JSON object per line as a L{SchemaRDD}.
If the schema is provided, applies the given schema to this JSON dataset.
@@ -618,11 +634,11 @@ def jsonFile(self, path, schema = None):
if schema is None:
jschema_rdd = self._ssql_ctx.jsonFile(path)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema._get_scala_type_string())
scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__())
jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(jschema_rdd, self)

def jsonRDD(self, rdd, schema = None):
def jsonRDD(self, rdd, schema=None):
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
If the schema is provided, applies the given schema to this JSON dataset.
@@ -672,7 +688,7 @@ def func(split, iterator):
if schema is None:
jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
scala_datatype = self._ssql_ctx.parseDataType(schema._get_scala_type_string())
scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__())
jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(jschema_rdd, self)
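
jsonFile and jsonRDD above accept the same optional schema: when it is omitted the schema is inferred from the JSON records, and when it is supplied its string form is parsed on the Scala side and applied instead. A sketch of the jsonRDD call shapes (values invented for illustration; only the call pattern comes from the code above):

    json_rdd = sc.parallelize(['{"field1": 1, "field2": "row1"}'])

    # Schema inferred from the data.
    srdd1 = sqlCtx.jsonRDD(json_rdd)

    # Explicit schema applied instead of inference.
    fields = [StructField("field2", StringType(), True)]
    srdd2 = sqlCtx.jsonRDD(json_rdd, StructType(fields))
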

@@ -109,12 +109,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
object ResolveReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case q: LogicalPlan if q.childrenResolved =>
logger.trace(s"Attempting to resolve ${q.simpleString}")
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressions {
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result = q.resolve(name).getOrElse(u)
logger.debug(s"Resolving $u to $result")
logDebug(s"Resolving $u to $result")
result
}
}
@@ -75,7 +75,7 @@ trait HiveTypeCoercion {
// Leave the same if the dataTypes match.
case Some(newType) if a.dataType == newType.dataType => a
case Some(newType) =>
logger.debug(s"Promoting $a to $newType in ${q.simpleString}}")
logDebug(s"Promoting $a to $newType in ${q.simpleString}}")
newType
}
}
@@ -154,7 +154,7 @@ trait HiveTypeCoercion {
(Alias(Cast(l, StringType), l.name)(), r)

case (l, r) if l.dataType != r.dataType =>
logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
findTightestCommonType(l.dataType, r.dataType).map { widestType =>
val newLeft =
if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)()
@@ -170,15 +170,15 @@

val newLeft =
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
logger.debug(s"Widening numeric types in union $castedLeft ${left.output}")
logDebug(s"Widening numeric types in union $castedLeft ${left.output}")
Project(castedLeft, left)
} else {
left
}

val newRight =
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
logger.debug(s"Widening numeric types in union $castedRight ${right.output}")
logDebug(s"Widening numeric types in union $castedRight ${right.output}")
Project(castedRight, right)
} else {
right
@@ -17,8 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import com.typesafe.scalalogging.slf4j.Logging

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.plans.QueryPlan
@@ -80,7 +79,7 @@ object BindReferences extends Logging {
// produce new attributes that can't be bound. Likely the right thing to do is remove
// this rule and require all operators to explicitly bind to the input schema that
// they specify.
logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
logDebug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
a
} else {
BoundReference(ordinal, a)
@@ -17,8 +17,7 @@

package org.apache.spark.sql.catalyst.planning

import com.typesafe.scalalogging.slf4j.Logging

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNode

@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.planning

import scala.annotation.tailrec

import com.typesafe.scalalogging.slf4j.Logging

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -114,7 +113,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {

def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
case join @ Join(left, right, joinType, condition) =>
logger.debug(s"Considering join on: $condition")
logDebug(s"Considering join on: $condition")
// Find equi-join predicates that can be evaluated before the join, and thus can be used
// as join keys.
val (joinPredicates, otherPredicates) =
@@ -132,7 +131,7 @@
val rightKeys = joinKeys.map(_._2)

if (joinKeys.nonEmpty) {
logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}")
logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}")
Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
} else {
None
@@ -17,8 +17,7 @@

package org.apache.spark.sql.catalyst.rules

import com.typesafe.scalalogging.slf4j.Logging

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.trees.TreeNode

abstract class Rule[TreeType <: TreeNode[_]] extends Logging {