[SPARK-16542][SQL][PYSPARK] Fix bugs about types that result in an array of null when creating a DataFrame using Python #14198

Closed
wants to merge 10 commits
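For context, a minimal sketch of the reported symptom (illustrative only, not part of the patch; it assumes a local `SparkSession`):

```python
# Hypothetical repro for SPARK-16542: before this fix, values from an
# array.array could be collected back as null because the inferred Spark SQL
# type did not match the JVM type of the unpickled elements.
import array
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([Row(v=array.array('f', [1.0, 2.0]))])
df.show()  # before the fix this could show [null,null] instead of [1.0,2.0]
```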
56 changes: 56 additions & 0 deletions python/pyspark/sql/tests.py
@@ -30,6 +30,8 @@
import functools
import time
import datetime
import array
import math

import py4j
try:
@@ -1735,6 +1737,60 @@ def test_BinaryType_serialization(self):
        df = self.spark.createDataFrame(data, schema=schema)
        df.collect()

    # test for SPARK-16542
    def test_array_types(self):
        int_types = set(['b', 'h', 'i', 'l'])
        float_types = set(['f', 'd'])
        unsupported_types = set(array.typecodes) - int_types - float_types

        def collected(a):
            row = Row(myarray=a)
            rdd = self.sc.parallelize([row])
            df = self.spark.createDataFrame(rdd)
            return df.collect()[0]["myarray"][0]

        # test whether pyspark can correctly handle int types
        for t in int_types:
            # test positive numbers
            a = array.array(t, [1])
            while True:
                try:
                    self.assertEqual(collected(a), a[0])
                    a[0] *= 2
                except OverflowError:
                    break
            # test negative numbers
            a = array.array(t, [-1])
            while True:
                try:
                    self.assertEqual(collected(a), a[0])
                    a[0] *= 2
                except OverflowError:
                    break

        # test whether pyspark can correctly handle float types
        for t in float_types:
            # test upper bound and precision
            a = array.array(t, [1.0])
            while not math.isinf(a[0]):
                self.assertEqual(collected(a), a[0])
                a[0] *= 2
                a[0] += 1
            # test lower bound
            a = array.array(t, [1.0])
            while a[0] != 0:
                self.assertEqual(collected(a), a[0])
                a[0] /= 2
        # test whether pyspark can correctly handle unsupported types:
        # creating a DataFrame from such an array must raise TypeError
        for t in unsupported_types:
            with self.assertRaises(TypeError):
                collected(array.array(t))


class HiveSparkSubmitTests(SparkSubmitTests):

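A side note on the integer loops in `test_array_types` above: they rely on the standard-library guarantee that assigning an out-of-range value to a typed `array` element raises `OverflowError`, which is what terminates each doubling loop. A standalone sketch:

```python
import array

a = array.array('b', [1])  # 'b' is a signed char, range -128..127
a[0] = 127                 # in range, fine
try:
    a[0] = 128             # out of range for 'b'
except OverflowError as exc:
    print("OverflowError:", exc)
```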
17 changes: 16 additions & 1 deletion python/pyspark/sql/types.py
@@ -929,6 +929,16 @@ def _parse_datatype_json_value(json_value):
    datetime.time: TimestampType,
}

# Mapping Python array types to Spark SQL DataType
_array_type_mappings = {
    'b': ByteType,
    'h': ShortType,
    'i': IntegerType,
    'l': LongType,
    'f': FloatType,
    'd': DoubleType
}

if sys.version < "3":
    _type_mappings.update({
        unicode: StringType,
Expand Down Expand Up @@ -958,12 +968,17 @@ def _infer_type(obj):
                return MapType(_infer_type(key), _infer_type(value), True)
        else:
            return MapType(NullType(), NullType(), True)
-    elif isinstance(obj, (list, array)):
+    elif isinstance(obj, list):
        for v in obj:
            if v is not None:
                return ArrayType(_infer_type(obj[0]), True)
        else:
            return ArrayType(NullType(), True)
    elif isinstance(obj, array):
        if obj.typecode in _array_type_mappings:
            return ArrayType(_array_type_mappings[obj.typecode](), True)
        else:
            raise TypeError("not supported type: array(%s)" % obj.typecode)
    else:
        try:
            return _infer_schema(obj)
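To illustrate the new inference path (a sketch against `_infer_type`, a private helper in `pyspark.sql.types`, so subject to change; `'u'` is just one example of an unsupported typecode):

```python
import array
from pyspark.sql.types import _infer_type  # private helper, for illustration

print(_infer_type(array.array('d', [1.0])))  # ArrayType(DoubleType,true)
print(_infer_type(array.array('h', [1])))    # ArrayType(ShortType,true)
_infer_type(array.array('u', u'x'))  # raises TypeError: not supported type: array(u)
```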
sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -91,20 +91,30 @@ object EvaluatePython {

    case (c: Boolean, BooleanType) => c

    case (c: Byte, ByteType) => c
    case (c: Short, ByteType) => c.toByte
    case (c: Int, ByteType) => c.toByte
    case (c: Long, ByteType) => c.toByte

    case (c: Byte, ShortType) => c.toShort
    case (c: Short, ShortType) => c
    case (c: Int, ShortType) => c.toShort
    case (c: Long, ShortType) => c.toShort

    case (c: Byte, IntegerType) => c.toInt
    case (c: Short, IntegerType) => c.toInt
    case (c: Int, IntegerType) => c
    case (c: Long, IntegerType) => c.toInt

    case (c: Byte, LongType) => c.toLong
    case (c: Short, LongType) => c.toLong
    case (c: Int, LongType) => c.toLong
    case (c: Long, LongType) => c

    case (c: Float, FloatType) => c
    case (c: Double, FloatType) => c.toFloat

    case (c: Float, DoubleType) => c.toDouble
    case (c: Double, DoubleType) => c

    case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
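Taken together, the Python-side inference and the widened `EvaluatePython` conversions should let every supported typecode round-trip with the matching Spark SQL element type. A hedged end-to-end sketch (assumes a local `SparkSession`; the expected names use Spark's `simpleString` forms):

```python
import array
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
expected = {'b': 'tinyint', 'h': 'smallint', 'i': 'int',
            'l': 'bigint', 'f': 'float', 'd': 'double'}
for code, sql_type in expected.items():
    df = spark.createDataFrame([Row(v=array.array(code, [1]))])
    field = df.schema.fields[0]
    assert field.dataType.simpleString() == 'array<%s>' % sql_type
    assert df.first()['v'][0] == 1  # the value survives instead of becoming null
```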