Address comments.
yhuai committed Jul 29, 2014
1 parent 03bfd95 commit 122d1e7
Showing 6 changed files with 78 additions and 73 deletions.
54 changes: 38 additions & 16 deletions python/pyspark/sql.py
@@ -26,13 +26,16 @@
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
"SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class PrimitiveTypeSingleton(type):
_instances = {}

def __call__(cls):
if cls not in cls._instances:
cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
return cls._instances[cls]
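
The metaclass above caches one instance per primitive type class, so repeated construction always returns the same object. A minimal, self-contained sketch of that behaviour; DemoType is a hypothetical class used only for illustration, created dynamically so the snippet runs on both Python 2 and 3:

class PrimitiveTypeSingleton(type):
    """Metaclass that hands out one cached instance per class."""
    _instances = {}

    def __call__(cls):
        if cls not in cls._instances:
            cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
        return cls._instances[cls]


# Hypothetical class wired to the metaclass dynamically (portable illustration only).
DemoType = PrimitiveTypeSingleton("DemoType", (object,), {})

print(DemoType() is DemoType())  # True: every call returns the same cached instance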


class StringType(object):
"""Spark SQL StringType
@@ -44,6 +47,7 @@ class StringType(object):
def __repr__(self):
return "StringType"


class BinaryType(object):
"""Spark SQL BinaryType
@@ -55,6 +59,7 @@ class BinaryType(object):
def __repr__(self):
return "BinaryType"


class BooleanType(object):
"""Spark SQL BooleanType
@@ -66,6 +71,7 @@ class BooleanType(object):
def __repr__(self):
return "BooleanType"


class TimestampType(object):
"""Spark SQL TimestampType
@@ -77,6 +83,7 @@ class TimestampType(object):
def __repr__(self):
return "TimestampType"


class DecimalType(object):
"""Spark SQL DecimalType
@@ -88,6 +95,7 @@ class DecimalType(object):
def __repr__(self):
return "DecimalType"


class DoubleType(object):
"""Spark SQL DoubleType
@@ -99,13 +107,15 @@ class DoubleType(object):
def __repr__(self):
return "DoubleType"


class FloatType(object):
"""Spark SQL FloatType
For now, please use L{DoubleType} instead of using L{FloatType}.
Because query evaluation is done in Scala, java.lang.Double will be used
for Python float numbers. Because the underlying JVM type of FloatType is
java.lang.Float (in Java) and Float (in scala), there will be a java.lang.ClassCastException
java.lang.Float (in Java) and Float (in scala), and we are trying to cast the type,
there will be a java.lang.ClassCastException
if FloatType (Python) is used.
"""
@@ -114,13 +124,15 @@ class FloatType(object):
def __repr__(self):
return "FloatType"


class ByteType(object):
"""Spark SQL ByteType
For now, please use L{IntegerType} instead of using L{ByteType}.
Because query evaluation is done in Scala, java.lang.Integer will be used
for Python int numbers. Because the underlying JVM type of ByteType is
java.lang.Byte (in Java) and Byte (in scala), there will be a java.lang.ClassCastException
java.lang.Byte (in Java) and Byte (in scala), and we are trying to cast the type,
there will be a java.lang.ClassCastException
if ByteType (Python) is used.
"""
@@ -129,6 +141,7 @@ class ByteType(object):
def __repr__(self):
return "ByteType"


class IntegerType(object):
"""Spark SQL IntegerType
@@ -140,6 +153,7 @@ class IntegerType(object):
def __repr__(self):
return "IntegerType"


class LongType(object):
"""Spark SQL LongType
@@ -152,13 +166,15 @@ class LongType(object):
def __repr__(self):
return "LongType"


class ShortType(object):
"""Spark SQL ShortType
For now, please use L{IntegerType} instead of using L{ShortType}.
Because query evaluation is done in Scala, java.lang.Integer will be used
for Python int numbers. Because the underlying JVM type of ShortType is
java.lang.Short (in Java) and Short (in scala), there will be a java.lang.ClassCastException
java.lang.Short (in Java) and Short (in scala), and we are trying to cast the type,
there will be a java.lang.ClassCastException
if ShortType (Python) is used.
"""
@@ -167,6 +183,7 @@ class ShortType(object):
def __repr__(self):
return "ShortType"


class ArrayType(object):
"""Spark SQL ArrayType
@@ -196,9 +213,9 @@ def __repr__(self):
str(self.containsNull).lower() + ")"

def __eq__(self, other):
return (isinstance(other, self.__class__) and \
self.elementType == other.elementType and \
self.containsNull == other.containsNull)
return (isinstance(other, self.__class__) and
self.elementType == other.elementType and
self.containsNull == other.containsNull)

def __ne__(self, other):
return not self.__eq__(other)
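
The __eq__ above makes ArrayType compare structurally rather than by identity. A hypothetical check, assuming the constructor takes (elementType, containsNull) in the order the __repr__ suggests:

from pyspark.sql import ArrayType, StringType

a = ArrayType(StringType(), True)
b = ArrayType(StringType(), True)
c = ArrayType(StringType(), False)  # differs only in containsNull

print(a == b)  # True: same element type and same containsNull flag
print(a == c)  # False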
@@ -238,14 +255,15 @@ def __repr__(self):
str(self.valueContainsNull).lower() + ")"

def __eq__(self, other):
return (isinstance(other, self.__class__) and \
self.keyType == other.keyType and \
self.valueType == other.valueType and \
self.valueContainsNull == other.valueContainsNull)
return (isinstance(other, self.__class__) and
self.keyType == other.keyType and
self.valueType == other.valueType and
self.valueContainsNull == other.valueContainsNull)

def __ne__(self, other):
return not self.__eq__(other)


class StructField(object):
"""Spark SQL StructField
@@ -278,14 +296,15 @@ def __repr__(self):
str(self.nullable).lower() + ")"

def __eq__(self, other):
return (isinstance(other, self.__class__) and \
self.name == other.name and \
self.dataType == other.dataType and \
self.nullable == other.nullable)
return (isinstance(other, self.__class__) and
self.name == other.name and
self.dataType == other.dataType and
self.nullable == other.nullable)

def __ne__(self, other):
return not self.__eq__(other)


class StructType(object):
"""Spark SQL StructType
@@ -315,12 +334,13 @@ def __repr__(self):
",".join([field.__repr__() for field in self.fields]) + "))"

def __eq__(self, other):
return (isinstance(other, self.__class__) and \
self.fields == other.fields)
return (isinstance(other, self.__class__) and
self.fields == other.fields)

def __ne__(self, other):
return not self.__eq__(other)
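
The same structural equality applies to MapType, StructField and StructType, so two independently built schemas compare equal field by field. A hypothetical example (field names invented):

from pyspark.sql import StructType, StructField, MapType, StringType, IntegerType

schema1 = StructType([
    StructField("name", StringType(), False),
    StructField("scores", MapType(StringType(), IntegerType(), True), True),
])
schema2 = StructType([
    StructField("name", StringType(), False),
    StructField("scores", MapType(StringType(), IntegerType(), True), True),
])

print(schema1 == schema2)  # True: compares the field lists, not object identity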


def _parse_datatype_list(datatype_list_string):
"""Parses a list of comma separated data types.
@@ -348,6 +368,7 @@ def _parse_datatype_list(datatype_list_string):
datatype_list.append(_parse_datatype_string(datatype_string))
return datatype_list


def _parse_datatype_string(datatype_string):
"""Parses the given data type string.
@@ -472,6 +493,7 @@ def _parse_datatype_string(datatype_string):
fields = _parse_datatype_list(field_list_string)
return StructType(fields)
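
_parse_datatype_string is the counterpart of the __repr__ methods above: it turns a Scala-style data type string back into Python DataType objects. A hedged round-trip sketch, assuming the parser accepts exactly the strings those __repr__ methods produce (the underscore-prefixed helper is internal, imported here only for illustration):

from pyspark.sql import StructType, StructField, IntegerType, _parse_datatype_string

schema = StructType([StructField("age", IntegerType(), True)])
parsed = _parse_datatype_string(str(schema))

print(parsed == schema)  # expected: True, via the structural __eq__ defined above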


class SQLContext:
"""Main entry point for SparkSQL functionality.
@@ -87,9 +87,6 @@ public abstract class DataType {
/**
* Creates an ArrayType by specifying the data type of elements ({@code elementType}).
* The field of {@code containsNull} is set to {@code false}.
*
* @param elementType
* @return
*/
public static ArrayType createArrayType(DataType elementType) {
if (elementType == null) {
@@ -102,9 +99,6 @@ public static ArrayType createArrayType(DataType elementType) {
/**
* Creates an ArrayType by specifying the data type of elements ({@code elementType}) and
* whether the array contains null values ({@code containsNull}).
* @param elementType
* @param containsNull
* @return
*/
public static ArrayType createArrayType(DataType elementType, boolean containsNull) {
if (elementType == null) {
@@ -117,10 +111,6 @@ public static ArrayType createArrayType(DataType elementType, boolean containsNu
/**
* Creates a MapType by specifying the data type of keys ({@code keyType}) and values
* ({@code keyType}). The field of {@code valueContainsNull} is set to {@code true}.
*
* @param keyType
* @param valueType
* @return
*/
public static MapType createMapType(DataType keyType, DataType valueType) {
if (keyType == null) {
@@ -137,10 +127,6 @@ public static MapType createMapType(DataType keyType, DataType valueType) {
* Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of
* values ({@code keyType}), and whether values contain any null value
* ({@code valueContainsNull}).
* @param keyType
* @param valueType
* @param valueContainsNull
* @return
*/
public static MapType createMapType(
DataType keyType,
@@ -159,10 +145,6 @@ public static MapType createMapType(
/**
* Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and
* whether values of this field can be null values ({@code nullable}).
* @param name
* @param dataType
* @param nullable
* @return
*/
public static StructField createStructField(String name, DataType dataType, boolean nullable) {
if (name == null) {
@@ -177,17 +159,13 @@ public static StructField createStructField(String name, DataType dataType, bool

/**
* Creates a StructType with the given list of StructFields ({@code fields}).
* @param fields
* @return
*/
public static StructType createStructType(List<StructField> fields) {
return createStructType(fields.toArray(new StructField[0]));
}

/**
* Creates a StructType with the given StructField array ({@code fields}).
* @param fields
* @return
*/
public static StructType createStructType(StructField[] fields) {
if (fields == null) {
20 changes: 20 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -94,6 +94,26 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
* It is important to make sure that the structure of every [[Row]] of the provided RDD matches
* the provided schema. Otherwise, there will be a runtime exception.
* Example:
* {{{
* import org.apache.spark.sql._
* val sqlContext = new org.apache.spark.sql.SQLContext(sc)
*
* val schema =
* StructType(
* StructField("name", StringType, false) ::
* StructField("age", IntegerType, true) :: Nil)
*
* val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Row(p(0), p(1).trim.toInt))
* val peopleSchemaRDD = sqlContext.applySchema(people, schema)
* peopleSchemaRDD.printSchema
* // root
* // |-- name: string (nullable = false)
* // |-- age: integer (nullable = true)
*
* peopleSchemaRDD.registerAsTable("people")
* sqlContext.sql("select name from people").collect.foreach(println)
* }}}
*
* @group userf
*/
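
For comparison, a hypothetical PySpark version of the same flow, assuming the Python applySchema introduced by this change mirrors the Scala signature (an RDD of tuples plus a StructType) and that SchemaRDD exposes registerAsTable as in the Scala example above:

from pyspark.sql import SQLContext, StructType, StructField, StringType, IntegerType

sqlContext = SQLContext(sc)  # sc: an existing SparkContext

schema = StructType([
    StructField("name", StringType(), False),
    StructField("age", IntegerType(), True),
])

lines = sc.textFile("examples/src/main/resources/people.txt")
people = lines.map(lambda l: l.split(",")).map(lambda p: (p[0], int(p[1].strip())))

peopleSchemaRDD = sqlContext.applySchema(people, schema)
peopleSchemaRDD.registerAsTable("people")
print(sqlContext.sql("SELECT name FROM people").collect())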
@@ -268,10 +268,9 @@ private[sql] object JsonRDD extends Logging {
// the ObjectMapper will take the last value associated with this duplicate key.
// For example: for {"key": 1, "key":2}, we will get "key"->2.
val mapper = new ObjectMapper()
iter.map {
record =>
val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]]))
parsed.asInstanceOf[Map[String, Any]]
iter.map { record =>
val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]]))
parsed.asInstanceOf[Map[String, Any]]
}
})
}
@@ -116,7 +116,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
case ParquetOriginalType.LIST => { // TODO: check enums!
assert(groupType.getFieldCount == 1)
val field = groupType.getFields.apply(0)
ArrayType(toDataType(field), false)
ArrayType(toDataType(field), containsNull = false)
}
case ParquetOriginalType.MAP => {
assert(
@@ -147,7 +147,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
MapType(keyType, valueType)
} else if (correspondsToArray(groupType)) { // ArrayType
val elementType = toDataType(groupType.getFields.apply(0))
ArrayType(elementType, false)
ArrayType(elementType, containsNull = false)
} else { // everything else: StructType
val fields = groupType
.getFields