[SPARK-29644][SQL] Corrected ShortType and ByteType mapping to SmallInt and TinyInt in JDBCUtils

### What changes were proposed in this pull request?
Corrected the ShortType and ByteType mappings to SMALLINT and TINYINT, and corrected the setter methods to write ShortType and ByteType values via setShort() and setByte(). The changes are in JdbcUtils.scala.
Fixed unit tests where applicable and added new E2E test cases that read and write tables with ShortType and ByteType columns.

#### Problems

- In master, lines 547 and 551 of JdbcUtils.scala set ShortType and ByteType values as Integer rather than as Short and Byte respectively (the issue was pointed out by maropu):

```
case ShortType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setInt(pos + 1, row.getShort(pos))

case ByteType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setInt(pos + 1, row.getByte(pos))
```

- At line 247 of JdbcUtils.scala, TINYINT is wrongly mapped to IntegerType in getCatalystType():

```
case java.sql.Types.TINYINT       => IntegerType
```

- At line 172, ShortType is wrongly mapped to the INTEGER database type:

```
case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
```

- Throughout the tests, ShortType and ByteType columns were being asserted as IntegerType; a minimal sketch of the resulting user-visible symptom follows this list.
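
Before this patch, the net effect is that a ShortType or ByteType column does not round-trip through a JDBC table. The following is a minimal sketch of the symptom, not code from the patch; the in-memory H2 URL and table name are illustrative placeholders and assume an H2 driver on the classpath:

```
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1" // placeholder URL
val df = Seq(1.toShort, 2.toShort).toDF("a")     // single ShortType column

df.write.format("jdbc")
  .mode("overwrite")
  .option("url", url)
  .option("dbtable", "shorttable")
  .save()

val back = spark.read.format("jdbc")
  .option("url", url)
  .option("dbtable", "shorttable")
  .load()

// Before the fix this prints "class java.lang.Integer": the column is
// created as INTEGER and SMALLINT is read back as IntegerType.
// After the fix it prints "class java.lang.Short".
println(back.head.get(0).getClass)
```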

### Why are the changes needed?
A value of a given type should be written with the JDBC setter for that type, and SMALLINT/TINYINT columns should be read back as the matching Catalyst types rather than widened to IntegerType.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Corrected unit tests where applicable; validated in CI.
Added test cases in MsSqlServerIntegrationSuite.scala, PostgresIntegrationSuite.scala, and MySQLIntegrationSuite.scala that write and read back tables from a DataFrame with ShortType and ByteType columns. Validated manually as follows:
```
./build/mvn install -DskipTests
./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.12
```

Closes #26301 from shivsood/shorttype_fix_maropu.

Authored-by: shivsood <shivsood@microsoft.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
shivsood authored and dongjoon-hyun committed Nov 14, 2019
1 parent 15a72f3 commit 32d44b1
Showing 5 changed files with 97 additions and 13 deletions.
MsSqlServerIntegrationSuite.scala:

```
@@ -59,7 +59,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
       """
         |INSERT INTO numbers VALUES (
         |0,
-        |255, 32767, 2147483647, 9223372036854775807,
+        |127, 32767, 2147483647, 9223372036854775807,
         |123456789012345.123456789012345, 123456789012345.123456789012345,
         |123456789012345.123456789012345,
         |123, 12345.12,
@@ -119,7 +119,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     val types = row.toSeq.map(x => x.getClass.toString)
     assert(types.length == 12)
     assert(types(0).equals("class java.lang.Boolean"))
-    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(1).equals("class java.lang.Byte"))
     assert(types(2).equals("class java.lang.Short"))
     assert(types(3).equals("class java.lang.Integer"))
     assert(types(4).equals("class java.lang.Long"))
@@ -131,7 +131,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types(10).equals("class java.math.BigDecimal"))
     assert(types(11).equals("class java.math.BigDecimal"))
     assert(row.getBoolean(0) == false)
-    assert(row.getInt(1) == 255)
+    assert(row.getByte(1) == 127)
     assert(row.getShort(2) == 32767)
     assert(row.getInt(3) == 2147483647)
     assert(row.getLong(4) == 9223372036854775807L)
@@ -202,4 +202,46 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
     df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
   }
+
+  test("SPARK-29644: Write tables with ShortType") {
+    import testImplicits._
+    val df = Seq(-32768.toShort, 0.toShort, 1.toShort, 38.toShort, 32768.toShort).toDF("a")
+    val tablename = "shorttable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Short")
+  }
+
+  test("SPARK-29644: Write tables with ByteType") {
+    import testImplicits._
+    val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 128.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Byte")
+  }
 }
```
MySQLIntegrationSuite.scala:

```
@@ -84,7 +84,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types.length == 9)
     assert(types(0).equals("class java.lang.Boolean"))
     assert(types(1).equals("class java.lang.Long"))
-    assert(types(2).equals("class java.lang.Integer"))
+    assert(types(2).equals("class java.lang.Short"))
     assert(types(3).equals("class java.lang.Integer"))
     assert(types(4).equals("class java.lang.Integer"))
     assert(types(5).equals("class java.lang.Long"))
@@ -93,7 +93,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types(8).equals("class java.lang.Double"))
     assert(rows(0).getBoolean(0) == false)
     assert(rows(0).getLong(1) == 0x225)
-    assert(rows(0).getInt(2) == 17)
+    assert(rows(0).getShort(2) == 17)
     assert(rows(0).getInt(3) == 77777)
     assert(rows(0).getInt(4) == 123456789)
     assert(rows(0).getLong(5) == 123456789012345L)
```
JdbcUtils.scala:

```
@@ -170,8 +170,8 @@ object JdbcUtils extends Logging {
     case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
     case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
     case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
-    case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
-    case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+    case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
+    case ByteType => Option(JdbcType("TINYINT", java.sql.Types.TINYINT))
     case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
     case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
     case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
@@ -235,7 +235,7 @@ object JdbcUtils extends Logging {
       case java.sql.Types.REF => StringType
       case java.sql.Types.REF_CURSOR => null
       case java.sql.Types.ROWID => LongType
-      case java.sql.Types.SMALLINT => IntegerType
+      case java.sql.Types.SMALLINT => ShortType
       case java.sql.Types.SQLXML => StringType
       case java.sql.Types.STRUCT => StringType
       case java.sql.Types.TIME => TimestampType
@@ -244,7 +244,7 @@ object JdbcUtils extends Logging {
       case java.sql.Types.TIMESTAMP => TimestampType
       case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
                                             => null
-      case java.sql.Types.TINYINT => IntegerType
+      case java.sql.Types.TINYINT => ByteType
       case java.sql.Types.VARBINARY => BinaryType
       case java.sql.Types.VARCHAR => StringType
       case _ =>
@@ -546,11 +546,11 @@ object JdbcUtils extends Logging {
 
     case ShortType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
-        stmt.setInt(pos + 1, row.getShort(pos))
+        stmt.setShort(pos + 1, row.getShort(pos))
 
     case ByteType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
-        stmt.setInt(pos + 1, row.getByte(pos))
+        stmt.setByte(pos + 1, row.getByte(pos))
 
     case BooleanType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
```
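
These defaults can still be overridden per database: JdbcUtils.getJdbcType consults the registered JdbcDialect first and only falls back to getCommonJDBCType. As a hedged sketch of that extension point (the dialect object and URL prefix below are hypothetical, not part of this patch), a database without a native TINYINT type could keep mapping ByteType to SMALLINT:

```
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types.{ByteType, DataType}

// Hypothetical dialect for a database that has no TINYINT column type.
object NoTinyIntDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:notiny")

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    // Override the new ByteType -> TINYINT default with SMALLINT.
    case ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
    case _ => None // defer to JdbcUtils.getCommonJDBCType
  }
}

JdbcDialects.registerDialect(NoTinyIntDialect)
```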
JDBCSuite.scala:

```
@@ -578,8 +578,8 @@ class JDBCSuite extends QueryTest
     assert(rows.length === 1)
     assert(rows(0).getInt(0) === 1)
     assert(rows(0).getBoolean(1) === false)
-    assert(rows(0).getInt(2) === 3)
-    assert(rows(0).getInt(3) === 4)
+    assert(rows(0).getByte(2) === 3.toByte)
+    assert(rows(0).getShort(3) === 4.toShort)
     assert(rows(0).getLong(4) === 1234567890123L)
   }
 
```
JDBCWriteSuite.scala:

```
@@ -574,6 +574,48 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
     }
   }
 
+  test("SPARK-29644: Write tables with ShortType") {
+    import testImplicits._
+    val df = Seq(-32768.toShort, 0.toShort, 1.toShort, 38.toShort, 32768.toShort).toDF("a")
+    val tablename = "shorttable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Short")
+  }
+
+  test("SPARK-29644: Write tables with ByteType") {
+    import testImplicits._
+    val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 128.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Byte")
+  }
+
   private def runAndVerifyRecordsWritten(expected: Long)(job: => Unit): Unit = {
     assert(expected === runAndReturnMetrics(job, _.taskMetrics.outputMetrics.recordsWritten))
   }
```
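
One detail worth noting about the test data (an observation, not part of the patch): Scala's narrowing numeric conversions wrap around instead of failing, so the apparently out-of-range literals above still produce in-range values:

```
scala> 32768.toShort   // 0x8000 truncated to 16 bits
res0: Short = -32768

scala> 128.toByte      // 0x80 truncated to 8 bits
res1: Byte = -128
```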
