From ad64483e2263485f843a469310fb8d252824d09e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 22 Jul 2016 12:32:44 +0900 Subject: [PATCH 1/9] Avoid per-record type dispatch in JDBC when reading --- .../execution/datasources/jdbc/JDBCRDD.scala | 231 +++++++++--------- 1 file changed, 122 insertions(+), 109 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 24e2c1a5fd2f6..dc0eeb801997f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -322,46 +322,135 @@ private[sql] class JDBCRDD( } } - // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that - // we don't have to potentially poke around in the Metadata once for every - // row. - // Is there a better way to do this? I'd rather be using a type that - // contains only the tags I define. - abstract class JDBCConversion - case object BooleanConversion extends JDBCConversion - case object DateConversion extends JDBCConversion - case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion - case object DoubleConversion extends JDBCConversion - case object FloatConversion extends JDBCConversion - case object IntegerConversion extends JDBCConversion - case object LongConversion extends JDBCConversion - case object BinaryLongConversion extends JDBCConversion - case object StringConversion extends JDBCConversion - case object TimestampConversion extends JDBCConversion - case object BinaryConversion extends JDBCConversion - case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion + // A `JDBCConversion` is responsible for converting a value from `ResultSet` + // to a value in a field for `InternalRow`. + private type JDBCConversion = (ResultSet, Int) => Any + + // This `ArrayElementConversion` is responsible for converting elements in + // an array from `ResultSet`. + private type ArrayElementConversion = (Object) => Any /** - * Maps a StructType to a type tag list. + * Maps a StructType to conversions for each type. */ def getConversions(schema: StructType): Array[JDBCConversion] = schema.fields.map(sf => getConversions(sf.dataType, sf.metadata)) private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Fixed(p, s) => DecimalConversion(p, s) - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata)) + case BooleanType => + (rs: ResultSet, pos: Int) => rs.getBoolean(pos) + + case DateType => + (rs: ResultSet, pos: Int) => + // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. 
+ val dateVal = rs.getDate(pos) + if (dateVal != null) { + DateTimeUtils.fromJavaDate(dateVal) + } else { + null + } + + case DecimalType.Fixed(p, s) => + (rs: ResultSet, pos: Int) => + val decimalVal = rs.getBigDecimal(pos) + if (decimalVal == null) { + null + } else { + Decimal(decimalVal, p, s) + } + + case DoubleType => + (rs: ResultSet, pos: Int) => rs.getDouble(pos) + + case FloatType => + (rs: ResultSet, pos: Int) => rs.getFloat(pos) + + case IntegerType => + (rs: ResultSet, pos: Int) => rs.getInt(pos) + + case LongType if metadata.contains("binarylong") => + (rs: ResultSet, pos: Int) => + val bytes = rs.getBytes(pos) + var ans = 0L + var j = 0 + while (j < bytes.size) { + ans = 256 * ans + (255 & bytes(j)) + j = j + 1 + } + ans + + case LongType => + (rs: ResultSet, pos: Int) => rs.getLong(pos) + + case StringType => + (rs: ResultSet, pos: Int) => + // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 + UTF8String.fromString(rs.getString(pos)) + + case TimestampType => + (rs: ResultSet, pos: Int) => + val t = rs.getTimestamp(pos) + if (t != null) { + DateTimeUtils.fromJavaTimestamp(t) + } else { + null + } + + case BinaryType => + (rs: ResultSet, pos: Int) => rs.getBytes(pos) + + case ArrayType(et, _) => + val elementConversion: ArrayElementConversion = + getArrayElementConversion(et, metadata) + (rs: ResultSet, pos: Int) => + val array = rs.getArray(pos).getArray + if (array != null) { + val data = elementConversion.apply(array) + new GenericArrayData(data) + } else { + null + } + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } + private def getArrayElementConversion( + dt: DataType, + metadata: Metadata): ArrayElementConversion = { + dt match { + case TimestampType => + (array: Object) => + array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => + nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) + } + + case StringType => + (array: Object) => + array.asInstanceOf[Array[java.lang.String]] + .map(UTF8String.fromString) + + case DateType => + (array: Object) => + array.asInstanceOf[Array[java.sql.Date]].map { date => + nullSafeConvert(date, DateTimeUtils.fromJavaDate) + } + + case dt: DecimalType => + (array: Object) => + array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => + nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, dt.precision, dt.scale)) + } + + case LongType if metadata.contains("binarylong") => + throw new IllegalArgumentException(s"Unsupported array element conversion.") + + case ArrayType(_, _) => + throw new IllegalArgumentException("Nested arrays unsupported") + + case _ => (array: Object) => array.asInstanceOf[Array[Any]] + } + } + /** * Runs the SQL query against the JDBC driver. * @@ -398,7 +487,7 @@ private[sql] class JDBCRDD( stmt.setFetchSize(fetchSize) val rs = stmt.executeQuery() - val conversions = getConversions(schema) + val conversions: Array[JDBCConversion] = getConversions(schema) val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) def getNext(): InternalRow = { @@ -407,84 +496,8 @@ private[sql] class JDBCRDD( var i = 0 while (i < conversions.length) { val pos = i + 1 - conversions(i) match { - case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - case DateConversion => - // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. 
- val dateVal = rs.getDate(pos) - if (dateVal != null) { - mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal)) - } else { - mutableRow.update(i, null) - } - // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal - // object returned by ResultSet.getBigDecimal is not correctly matched to the table - // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. - // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through - // a BigDecimal object with scale as 0. But the dataframe schema has correct type as - // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then - // retrieve it, you will get wrong result 199.99. - // So it is needed to set precision and scale for Decimal based on JDBC metadata. - case DecimalConversion(p, s) => - val decimalVal = rs.getBigDecimal(pos) - if (decimalVal == null) { - mutableRow.update(i, null) - } else { - mutableRow.update(i, Decimal(decimalVal, p, s)) - } - case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) - case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) - case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) - case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) - // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos))) - case TimestampConversion => - val t = rs.getTimestamp(pos) - if (t != null) { - mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t)) - } else { - mutableRow.update(i, null) - } - case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) - case BinaryLongConversion => - val bytes = rs.getBytes(pos) - var ans = 0L - var j = 0 - while (j < bytes.size) { - ans = 256 * ans + (255 & bytes(j)) - j = j + 1 - } - mutableRow.setLong(i, ans) - case ArrayConversion(elementConversion) => - val array = rs.getArray(pos).getArray - if (array != null) { - val data = elementConversion match { - case TimestampConversion => - array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => - nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) - } - case StringConversion => - array.asInstanceOf[Array[java.lang.String]] - .map(UTF8String.fromString) - case DateConversion => - array.asInstanceOf[Array[java.sql.Date]].map { date => - nullSafeConvert(date, DateTimeUtils.fromJavaDate) - } - case DecimalConversion(p, s) => - array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => - nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s)) - } - case BinaryLongConversion => - throw new IllegalArgumentException(s"Unsupported array element conversion $i") - case _: ArrayConversion => - throw new IllegalArgumentException("Nested arrays unsupported") - case _ => array.asInstanceOf[Array[Any]] - } - mutableRow.update(i, new GenericArrayData(data)) - } else { - mutableRow.update(i, null) - } - } + val value = conversions(i).apply(rs, pos) + mutableRow.update(i, value) if (rs.wasNull) mutableRow.setNullAt(i) i = i + 1 } From 5eae0e6b7d68a7781b1849441a000cf3ff7fe804 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 22 Jul 2016 12:44:44 +0900 Subject: [PATCH 2/9] Correct indentation --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index dc0eeb801997f..ee31775c8a2a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -400,8 +400,7 @@ private[sql] class JDBCRDD( (rs: ResultSet, pos: Int) => rs.getBytes(pos) case ArrayType(et, _) => - val elementConversion: ArrayElementConversion = - getArrayElementConversion(et, metadata) + val elementConversion: ArrayElementConversion = getArrayElementConversion(et, metadata) (rs: ResultSet, pos: Int) => val array = rs.getArray(pos).getArray if (array != null) { From 53350935b476e2a30dfd03f7fbfe857e6c4316d0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 24 Jul 2016 09:57:43 +0900 Subject: [PATCH 3/9] Fix some nits --- .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index ee31775c8a2a1..6f13ba3727d7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -353,10 +353,10 @@ private[sql] class JDBCRDD( case DecimalType.Fixed(p, s) => (rs: ResultSet, pos: Int) => val decimalVal = rs.getBigDecimal(pos) - if (decimalVal == null) { - null - } else { + if (decimalVal != null) { Decimal(decimalVal, p, s) + } else { + null } case DoubleType => @@ -441,7 +441,8 @@ private[sql] class JDBCRDD( } case LongType if metadata.contains("binarylong") => - throw new IllegalArgumentException(s"Unsupported array element conversion.") + throw new IllegalArgumentException(s"Unsupported array element " + + s"type ${dt.simpleString} based on binary") case ArrayType(_, _) => throw new IllegalArgumentException("Nested arrays unsupported") From ec029af8d76421f879ef2154e9de9bb274238586 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 25 Jul 2016 12:55:00 +0900 Subject: [PATCH 4/9] Address comments --- .../execution/datasources/jdbc/JDBCRDD.scala | 104 +++++++----------- 1 file changed, 42 insertions(+), 62 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 6f13ba3727d7f..b7b0296d7f1f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -323,13 +323,10 @@ private[sql] class JDBCRDD( } // A `JDBCConversion` is responsible for converting a value from `ResultSet` - // to a value in a field for `InternalRow`. + // to a value in a field for `InternalRow`. The second argument `Int` means the index + // for the data to retrieve and convert from `ResultSet`. private type JDBCConversion = (ResultSet, Int) => Any - // This `ArrayElementConversion` is responsible for converting elements in - // an array from `ResultSet`. - private type ArrayElementConversion = (Object) => Any - /** * Maps a StructType to conversions for each type. */ @@ -343,21 +340,19 @@ private[sql] class JDBCRDD( case DateType => (rs: ResultSet, pos: Int) => // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. 
- val dateVal = rs.getDate(pos) - if (dateVal != null) { - DateTimeUtils.fromJavaDate(dateVal) - } else { - null - } - + nullSafeConvert(rs.getDate(pos), DateTimeUtils.fromJavaDate) + + // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal + // object returned by ResultSet.getBigDecimal is not correctly matched to the table + // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. + // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through + // a BigDecimal object with scale as 0. But the dataframe schema has correct type as + // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then + // retrieve it, you will get wrong result 199.99. + // So it is needed to set precision and scale for Decimal based on JDBC metadata. case DecimalType.Fixed(p, s) => (rs: ResultSet, pos: Int) => - val decimalVal = rs.getBigDecimal(pos) - if (decimalVal != null) { - Decimal(decimalVal, p, s) - } else { - null - } + nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos), d => Decimal(d, p, s)) case DoubleType => (rs: ResultSet, pos: Int) => rs.getDouble(pos) @@ -389,66 +384,51 @@ private[sql] class JDBCRDD( case TimestampType => (rs: ResultSet, pos: Int) => - val t = rs.getTimestamp(pos) - if (t != null) { - DateTimeUtils.fromJavaTimestamp(t) - } else { - null - } + nullSafeConvert(rs.getTimestamp(pos), DateTimeUtils.fromJavaTimestamp) case BinaryType => (rs: ResultSet, pos: Int) => rs.getBytes(pos) case ArrayType(et, _) => - val elementConversion: ArrayElementConversion = getArrayElementConversion(et, metadata) + val elementConversion = getArrayElementConversion(et, metadata) (rs: ResultSet, pos: Int) => - val array = rs.getArray(pos).getArray - if (array != null) { - val data = elementConversion.apply(array) - new GenericArrayData(data) - } else { - null - } + nullSafeConvert(rs.getArray(pos).getArray, elementConversion) case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } - private def getArrayElementConversion( - dt: DataType, - metadata: Metadata): ArrayElementConversion = { - dt match { - case TimestampType => - (array: Object) => - array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => - nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) - } + private def getArrayElementConversion(dt: DataType, metadata: Metadata) = dt match { + case TimestampType => + (array: Object) => + array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => + nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) + } - case StringType => - (array: Object) => - array.asInstanceOf[Array[java.lang.String]] - .map(UTF8String.fromString) + case StringType => + (array: Object) => + array.asInstanceOf[Array[java.lang.String]] + .map(UTF8String.fromString) - case DateType => - (array: Object) => - array.asInstanceOf[Array[java.sql.Date]].map { date => - nullSafeConvert(date, DateTimeUtils.fromJavaDate) - } + case DateType => + (array: Object) => + array.asInstanceOf[Array[java.sql.Date]].map { date => + nullSafeConvert(date, DateTimeUtils.fromJavaDate) + } - case dt: DecimalType => - (array: Object) => - array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => - nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, dt.precision, dt.scale)) - } + case dt: DecimalType => + (array: Object) => + array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => + nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, dt.precision, 
dt.scale)) + } - case LongType if metadata.contains("binarylong") => - throw new IllegalArgumentException(s"Unsupported array element " + - s"type ${dt.simpleString} based on binary") + case LongType if metadata.contains("binarylong") => + throw new IllegalArgumentException(s"Unsupported array element " + + s"type ${dt.simpleString} based on binary") - case ArrayType(_, _) => - throw new IllegalArgumentException("Nested arrays unsupported") + case ArrayType(_, _) => + throw new IllegalArgumentException("Nested arrays unsupported") - case _ => (array: Object) => array.asInstanceOf[Array[Any]] - } + case _ => (array: Object) => array.asInstanceOf[Array[Any]] } /** From 8ac66b17f81a2fdc0866df26889b5e2fcc634c51 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 25 Jul 2016 13:15:01 +0900 Subject: [PATCH 5/9] Avoid type-boxing --- .../execution/datasources/jdbc/JDBCRDD.scala | 72 +++++++++++-------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index b7b0296d7f1f8..b8a42869d07d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -28,8 +28,8 @@ import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -322,10 +322,10 @@ private[sql] class JDBCRDD( } } - // A `JDBCConversion` is responsible for converting a value from `ResultSet` - // to a value in a field for `InternalRow`. The second argument `Int` means the index - // for the data to retrieve and convert from `ResultSet`. - private type JDBCConversion = (ResultSet, Int) => Any + // A `JDBCConversion` is responsible for converting and setting a value from `ResultSet` + // into a field for `MutableRow`. The last argument `Int` means the index for the + // value to be set in the row and also used for the value to retrieve from `ResultSet`. + private type JDBCConversion = (ResultSet, MutableRow, Int) => Unit /** * Maps a StructType to conversions for each type. @@ -335,12 +335,18 @@ private[sql] class JDBCRDD( private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match { case BooleanType => - (rs: ResultSet, pos: Int) => rs.getBoolean(pos) + (rs: ResultSet, row: MutableRow, pos: Int) => + row.setBoolean(pos, rs.getBoolean(pos + 1)) case DateType => - (rs: ResultSet, pos: Int) => + (rs: ResultSet, row: MutableRow, pos: Int) => // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. 
- nullSafeConvert(rs.getDate(pos), DateTimeUtils.fromJavaDate) + val dateVal = rs.getDate(pos + 1) + if (dateVal != null) { + row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal)) + } else { + row.update(pos, null) + } // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal // object returned by ResultSet.getBigDecimal is not correctly matched to the table @@ -351,48 +357,60 @@ private[sql] class JDBCRDD( // retrieve it, you will get wrong result 199.99. // So it is needed to set precision and scale for Decimal based on JDBC metadata. case DecimalType.Fixed(p, s) => - (rs: ResultSet, pos: Int) => - nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos), d => Decimal(d, p, s)) + (rs: ResultSet, row: MutableRow, pos: Int) => + val decimal = + nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s)) + row.update(pos, decimal) case DoubleType => - (rs: ResultSet, pos: Int) => rs.getDouble(pos) + (rs: ResultSet, row: MutableRow, pos: Int) => + row.setDouble(pos, rs.getDouble(pos + 1)) case FloatType => - (rs: ResultSet, pos: Int) => rs.getFloat(pos) + (rs: ResultSet, row: MutableRow, pos: Int) => + row.setFloat(pos, rs.getFloat(pos + 1)) case IntegerType => - (rs: ResultSet, pos: Int) => rs.getInt(pos) + (rs: ResultSet, row: MutableRow, pos: Int) => + row.setInt(pos, rs.getInt(pos + 1)) case LongType if metadata.contains("binarylong") => - (rs: ResultSet, pos: Int) => - val bytes = rs.getBytes(pos) + (rs: ResultSet, row: MutableRow, pos: Int) => + val bytes = rs.getBytes(pos + 1) var ans = 0L var j = 0 while (j < bytes.size) { ans = 256 * ans + (255 & bytes(j)) j = j + 1 } - ans + row.setLong(pos, ans) case LongType => - (rs: ResultSet, pos: Int) => rs.getLong(pos) + (rs: ResultSet, row: MutableRow, pos: Int) => + row.setLong(pos, rs.getLong(pos + 1)) case StringType => - (rs: ResultSet, pos: Int) => + (rs: ResultSet, row: MutableRow, pos: Int) => // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - UTF8String.fromString(rs.getString(pos)) + row.update(pos, UTF8String.fromString(rs.getString(pos + 1))) case TimestampType => - (rs: ResultSet, pos: Int) => - nullSafeConvert(rs.getTimestamp(pos), DateTimeUtils.fromJavaTimestamp) + (rs: ResultSet, row: MutableRow, pos: Int) => + val t = rs.getTimestamp(pos + 1) + if (t != null) { + row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t)) + } else { + row.update(pos, null) + } case BinaryType => - (rs: ResultSet, pos: Int) => rs.getBytes(pos) + (rs: ResultSet, row: MutableRow, pos: Int) => + row.update(pos, rs.getBytes(pos + 1)) case ArrayType(et, _) => val elementConversion = getArrayElementConversion(et, metadata) - (rs: ResultSet, pos: Int) => - nullSafeConvert(rs.getArray(pos).getArray, elementConversion) + (rs: ResultSet, row: MutableRow, pos: Int) => + row.update(pos, nullSafeConvert(rs.getArray(pos + 1).getArray, elementConversion)) case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } @@ -475,9 +493,7 @@ private[sql] class JDBCRDD( inputMetrics.incRecordsRead(1) var i = 0 while (i < conversions.length) { - val pos = i + 1 - val value = conversions(i).apply(rs, pos) - mutableRow.update(i, value) + conversions(i).apply(rs, mutableRow, i) if (rs.wasNull) mutableRow.setNullAt(i) i = i + 1 } From a3853182b375539bd329fb10464c1a746d4eaa47 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 25 Jul 2016 14:00:46 +0900 Subject: [PATCH 6/9] Rename JDBCConversion to ValueSetter and make a method inline --- 
.../execution/datasources/jdbc/JDBCRDD.scala | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index b8a42869d07d0..0a835bfa14f76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -322,18 +322,18 @@ private[sql] class JDBCRDD( } } - // A `JDBCConversion` is responsible for converting and setting a value from `ResultSet` + // A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet` // into a field for `MutableRow`. The last argument `Int` means the index for the // value to be set in the row and also used for the value to retrieve from `ResultSet`. - private type JDBCConversion = (ResultSet, MutableRow, Int) => Unit + private type ValueSetter = (ResultSet, MutableRow, Int) => Unit /** - * Maps a StructType to conversions for each type. + * Creates a StructType to setters for each type. */ - def getConversions(schema: StructType): Array[JDBCConversion] = - schema.fields.map(sf => getConversions(sf.dataType, sf.metadata)) + def makeSetters(schema: StructType): Array[ValueSetter] = + schema.fields.map(sf => makeSetters(sf.dataType, sf.metadata)) - private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match { + private def makeSetters(dt: DataType, metadata: Metadata): ValueSetter = dt match { case BooleanType => (rs: ResultSet, row: MutableRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1)) @@ -408,45 +408,45 @@ private[sql] class JDBCRDD( row.update(pos, rs.getBytes(pos + 1)) case ArrayType(et, _) => - val elementConversion = getArrayElementConversion(et, metadata) - (rs: ResultSet, row: MutableRow, pos: Int) => - row.update(pos, nullSafeConvert(rs.getArray(pos + 1).getArray, elementConversion)) + val elementConversion = et match { + case TimestampType => + (array: Object) => + array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => + nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) + } - case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") - } + case StringType => + (array: Object) => + array.asInstanceOf[Array[java.lang.String]] + .map(UTF8String.fromString) - private def getArrayElementConversion(dt: DataType, metadata: Metadata) = dt match { - case TimestampType => - (array: Object) => - array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => - nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) - } + case DateType => + (array: Object) => + array.asInstanceOf[Array[java.sql.Date]].map { date => + nullSafeConvert(date, DateTimeUtils.fromJavaDate) + } - case StringType => - (array: Object) => - array.asInstanceOf[Array[java.lang.String]] - .map(UTF8String.fromString) + case dt: DecimalType => + (array: Object) => + array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => + nullSafeConvert[java.math.BigDecimal]( + decimal, d => Decimal(d, dt.precision, dt.scale)) + } - case DateType => - (array: Object) => - array.asInstanceOf[Array[java.sql.Date]].map { date => - nullSafeConvert(date, DateTimeUtils.fromJavaDate) - } + case LongType if metadata.contains("binarylong") => + throw new IllegalArgumentException(s"Unsupported array element " + + s"type ${dt.simpleString} based on binary") - case dt: 
DecimalType => - (array: Object) => - array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => - nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, dt.precision, dt.scale)) - } + case ArrayType(_, _) => + throw new IllegalArgumentException("Nested arrays unsupported") - case LongType if metadata.contains("binarylong") => - throw new IllegalArgumentException(s"Unsupported array element " + - s"type ${dt.simpleString} based on binary") + case _ => (array: Object) => array.asInstanceOf[Array[Any]] + } - case ArrayType(_, _) => - throw new IllegalArgumentException("Nested arrays unsupported") + (rs: ResultSet, row: MutableRow, pos: Int) => + row.update(pos, nullSafeConvert(rs.getArray(pos + 1).getArray, elementConversion)) - case _ => (array: Object) => array.asInstanceOf[Array[Any]] + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } /** @@ -485,15 +485,15 @@ private[sql] class JDBCRDD( stmt.setFetchSize(fetchSize) val rs = stmt.executeQuery() - val conversions: Array[JDBCConversion] = getConversions(schema) + val setters: Array[ValueSetter] = makeSetters(schema) val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) def getNext(): InternalRow = { if (rs.next()) { inputMetrics.incRecordsRead(1) var i = 0 - while (i < conversions.length) { - conversions(i).apply(rs, mutableRow, i) + while (i < setters.length) { + setters(i).apply(rs, mutableRow, i) if (rs.wasNull) mutableRow.setNullAt(i) i = i + 1 } From c336382e3df40d1eddf29b462a13b5933596d8ce Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 25 Jul 2016 15:27:29 +0900 Subject: [PATCH 7/9] Rename ValueSetter to JDBCValueSetter --- .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 0a835bfa14f76..fc27049511357 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -325,15 +325,15 @@ private[sql] class JDBCRDD( // A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet` // into a field for `MutableRow`. The last argument `Int` means the index for the // value to be set in the row and also used for the value to retrieve from `ResultSet`. - private type ValueSetter = (ResultSet, MutableRow, Int) => Unit + private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit /** * Creates a StructType to setters for each type. 
*/ - def makeSetters(schema: StructType): Array[ValueSetter] = + def makeSetters(schema: StructType): Array[JDBCValueSetter] = schema.fields.map(sf => makeSetters(sf.dataType, sf.metadata)) - private def makeSetters(dt: DataType, metadata: Metadata): ValueSetter = dt match { + private def makeSetters(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match { case BooleanType => (rs: ResultSet, row: MutableRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1)) @@ -485,7 +485,7 @@ private[sql] class JDBCRDD( stmt.setFetchSize(fetchSize) val rs = stmt.executeQuery() - val setters: Array[ValueSetter] = makeSetters(schema) + val setters: Array[JDBCValueSetter] = makeSetters(schema) val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) def getNext(): InternalRow = { From 486dabd1fa508712e77d5e67f08939d2408d047a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 25 Jul 2016 15:43:39 +0900 Subject: [PATCH 8/9] Add missing GenericArrayData --- .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index fc27049511357..10791878830da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -444,7 +444,10 @@ private[sql] class JDBCRDD( } (rs: ResultSet, row: MutableRow, pos: Int) => - row.update(pos, nullSafeConvert(rs.getArray(pos + 1).getArray, elementConversion)) + val array = nullSafeConvert[Object]( + rs.getArray(pos + 1).getArray, + array => new GenericArrayData(elementConversion.apply(array))) + row.update(pos, array) case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } From 25894f124658f3077f419051a52c33cc5d36306c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 25 Jul 2016 17:23:15 +0900 Subject: [PATCH 9/9] Fix documentation for makeSetters --- .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 10791878830da..4c98430363117 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -328,12 +328,13 @@ private[sql] class JDBCRDD( private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit /** - * Creates a StructType to setters for each type. + * Creates `JDBCValueSetter`s according to [[StructType]], which can set + * each value from `ResultSet` to each field of [[MutableRow]] correctly. 
*/ def makeSetters(schema: StructType): Array[JDBCValueSetter] = - schema.fields.map(sf => makeSetters(sf.dataType, sf.metadata)) + schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata)) - private def makeSetters(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match { + private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match { case BooleanType => (rs: ResultSet, row: MutableRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1))
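As context for the series above, here is a minimal, self-contained sketch of the approach the patches converge on: building one setter closure per column up front (the `match` on the data type happens once), then calling those closures in the per-record loop instead of dispatching on a type tag for every row. This is an illustration only, not Spark code: `SimpleRow`, `ColumnType`, `makeSetter`, and `readAll` are hypothetical stand-ins for Spark's internal `MutableRow`, Catalyst `DataType`s, and `JDBCRDD` internals, and `nullSafeConvert` is assumed to behave like the helper the diff calls but does not show.

```scala
import java.sql.ResultSet

object JdbcSetterSketch {

  // Hypothetical stand-in for Spark's internal MutableRow; just a boxed buffer here.
  final class SimpleRow(size: Int) {
    private val values = new Array[Any](size)
    def update(i: Int, v: Any): Unit = values(i) = v
    def setNullAt(i: Int): Unit = values(i) = null
    def get(i: Int): Any = values(i)
  }

  // Same shape as the patch's JDBCValueSetter: (ResultSet, row, column index) => Unit.
  type JDBCValueSetter = (ResultSet, SimpleRow, Int) => Unit

  // Assumed to mirror the `nullSafeConvert` helper referenced in the diff (not shown there).
  private def nullSafeConvert[T](input: T, f: T => Any): Any =
    if (input == null) null else f(input)

  // Minimal type model standing in for the Catalyst DataType hierarchy.
  sealed trait ColumnType
  case object IntColumn extends ColumnType
  case object LongColumn extends ColumnType
  case object StringColumn extends ColumnType
  case object DateColumn extends ColumnType

  // Decided once per column, before the per-row loop -- the core idea of the patch series.
  def makeSetter(ct: ColumnType): JDBCValueSetter = ct match {
    case IntColumn =>
      (rs, row, pos) => row.update(pos, rs.getInt(pos + 1))
    case LongColumn =>
      (rs, row, pos) => row.update(pos, rs.getLong(pos + 1))
    case StringColumn =>
      (rs, row, pos) => row.update(pos, rs.getString(pos + 1))
    case DateColumn =>
      // getDate may return null, so guard the conversion, as the patch does for dates.
      (rs, row, pos) =>
        row.update(pos, nullSafeConvert[java.sql.Date](rs.getDate(pos + 1), _.getTime))
  }

  // Per-partition read loop: the type dispatch happens once per column here,
  // not once per column per record as in the old tag-based code.
  def readAll(rs: ResultSet, schema: Seq[ColumnType]): Seq[SimpleRow] = {
    val setters = schema.map(makeSetter).toArray
    val rows = Seq.newBuilder[SimpleRow]
    while (rs.next()) {
      val row = new SimpleRow(setters.length)
      var i = 0
      while (i < setters.length) {
        setters(i)(rs, row, i)
        if (rs.wasNull()) row.setNullAt(i) // mirrors the wasNull check in getNext()
        i += 1
      }
      rows += row
    }
    rows.result()
  }
}
```

Note that this sketch boxes every value through `update`; the actual patch additionally uses the primitive setters of `SpecificMutableRow` (`setInt`, `setLong`, ...) where possible, which is what the "Avoid type-boxing" commit in the series is about.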