/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.jdbc

import java.sql.Connection
import java.util.Properties

import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType}
import org.apache.spark.tags.DockerTest

/**
 * To run this test suite for a specific version (e.g., postgres:13.0):
 * {{{
 *   POSTGRES_DOCKER_IMAGE_NAME=postgres:13.0
 *     ./build/sbt -Pdocker-integration-tests
 *     "testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite"
 * }}}
 */
@DockerTest
class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
  override val db = new DatabaseOnDocker {
    override val imageName =
      sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:13.0-alpine")
    override val env = Map(
      "POSTGRES_PASSWORD" -> "rootpass"
    )
    override val usesIpc = false
    override val jdbcPort = 5432

    override def getJdbcUrl(ip: String, port: Int): String =
      s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass"
  }
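
  // Seeds the fixture tables (bar, ts_with_timezone, st_with_array, char_types,
  // char_array_types) that the tests below read back and assert against.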
  override def dataPreparation(conn: Connection): Unit = {
    conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
    conn.setCatalog("foo")
    conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate()
    conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
      + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
      + "c10 integer[], c11 text[], c12 real[], c13 numeric(2,2)[], c14 enum_type, "
      + "c15 float4, c16 smallint, c17 numeric[])").executeUpdate()
    conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
      + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
      + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', '{0.11, 0.22}', 'd1', 1.01, 1, """
      + "'{111.2222, 333.4444}')"
    ).executeUpdate()
    conn.prepareStatement("INSERT INTO bar VALUES (null, null, null, null, null, "
      + "null, null, null, null, null, "
      + "null, null, null, null, null, null, null, null)"
    ).executeUpdate()

    conn.prepareStatement("CREATE TABLE ts_with_timezone " +
      "(id integer, tstz TIMESTAMP WITH TIME ZONE, ttz TIME WITH TIME ZONE)")
      .executeUpdate()
    conn.prepareStatement("INSERT INTO ts_with_timezone VALUES " +
      "(1, TIMESTAMP WITH TIME ZONE '2016-08-12 10:22:31.949271-07', " +
      "TIME WITH TIME ZONE '17:22:31.949271+00')")
      .executeUpdate()

    conn.prepareStatement("CREATE TABLE st_with_array (c0 uuid, c1 inet, c2 cidr, " +
      "c3 json, c4 jsonb, c5 uuid[], c6 inet[], c7 cidr[], c8 json[], c9 jsonb[])")
      .executeUpdate()
    conn.prepareStatement("INSERT INTO st_with_array VALUES ( " +
      "'0a532531-cdf1-45e3-963d-5de90b6a30f1', '172.168.22.1', '192.168.100.128/25', " +
      """'{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}', """ +
      "ARRAY['7be8aaf8-650e-4dbb-8186-0a749840ecf2'," +
      "'205f9bfc-018c-4452-a605-609c0cfad228']::uuid[], ARRAY['172.16.0.41', " +
      "'172.16.0.42']::inet[], ARRAY['192.168.0.0/24', '10.1.0.0/16']::cidr[], " +
      """ARRAY['{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}']::json[], """ +
      """ARRAY['{"a": 1, "b": 2, "c": 3}']::jsonb[])"""
    ).executeUpdate()

    conn.prepareStatement("CREATE TABLE char_types (" +
      "c0 char(4), c1 character(4), c2 character varying(4), c3 varchar(4), c4 bpchar)"
    ).executeUpdate()
    conn.prepareStatement("INSERT INTO char_types VALUES " +
      "('abcd', 'efgh', 'ijkl', 'mnop', 'q')").executeUpdate()

    conn.prepareStatement("CREATE TABLE char_array_types (" +
      "c0 char(4)[], c1 character(4)[], c2 character varying(4)[], c3 varchar(4)[], c4 bpchar[])"
    ).executeUpdate()
    conn.prepareStatement("INSERT INTO char_array_types VALUES " +
      """('{"a", "bcd"}', '{"ef", "gh"}', '{"i", "j", "kl"}', '{"mnop"}', '{"q", "r"}')"""
    ).executeUpdate()
  }
test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect().sortBy(_.toString())
assert(rows.length == 2)
// Test the types, and values using the first row.
val types = rows(0).toSeq.map(x => x.getClass)
assert(types.length == 18)
assert(classOf[String].isAssignableFrom(types(0)))
assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
assert(classOf[String].isAssignableFrom(types(8)))
assert(classOf[String].isAssignableFrom(types(9)))
assert(classOf[scala.collection.Seq[Int]].isAssignableFrom(types(10)))
assert(classOf[scala.collection.Seq[String]].isAssignableFrom(types(11)))
assert(classOf[scala.collection.Seq[Double]].isAssignableFrom(types(12)))
assert(classOf[scala.collection.Seq[BigDecimal]].isAssignableFrom(types(13)))
assert(classOf[String].isAssignableFrom(types(14)))
assert(classOf[java.lang.Float].isAssignableFrom(types(15)))
assert(classOf[java.lang.Short].isAssignableFrom(types(16)))
assert(classOf[scala.collection.Seq[BigDecimal]].isAssignableFrom(types(17)))
assert(rows(0).getString(0).equals("hello"))
assert(rows(0).getInt(1) == 42)
assert(rows(0).getDouble(2) == 1.25)
assert(rows(0).getLong(3) == 123456789012345L)
assert(!rows(0).getBoolean(4))
// BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's...
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5),
Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49)))
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6),
Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte)))
assert(rows(0).getBoolean(7))
assert(rows(0).getString(8) == "172.16.0.42")
assert(rows(0).getString(9) == "192.168.0.0/16")
assert(rows(0).getSeq(10) == Seq(1, 2))
assert(rows(0).getSeq(11) == Seq("a", null, "b"))
assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f))
assert(rows(0).getSeq(13) == Seq("0.11", "0.22").map(BigDecimal(_).bigDecimal))
assert(rows(0).getString(14) == "d1")
assert(rows(0).getFloat(15) == 1.01f)
assert(rows(0).getShort(16) == 1)
assert(rows(0).getSeq(17) ==
Seq("111.222200000000000000", "333.444400000000000000").map(BigDecimal(_).bigDecimal))
// Test reading null values using the second row.
assert(0.until(16).forall(rows(1).isNullAt(_)))
}
test("Basic write test") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
// Test only that it doesn't crash.
df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test that written numeric type has same DataType as input
assert(sqlContext.read.jdbc(jdbcUrl, "public.barcopy", new Properties).schema(13).dataType ==
ArrayType(DecimalType(2, 2), true))
// Test write null values.
df.select(df.queryExecution.analyzed.output.map { a =>
Column(Literal.create(null, a.dataType)).as(a.name)
}: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
}
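
  // The Postgres dialect maps ShortType to smallint and FloatType to float4, so both
  // should survive a write/read round trip without being widened to integer/double.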
test("Creating a table with shorts and floats") {
sqlContext.createDataFrame(Seq((1.0f, 1.toShort)))
.write.jdbc(jdbcUrl, "shortfloat", new Properties)
val schema = sqlContext.read.jdbc(jdbcUrl, "shortfloat", new Properties).schema
assert(schema(0).dataType == FloatType)
assert(schema(1).dataType == ShortType)
}
test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " +
"should be recognized") {
// When using JDBC to read the columns of TIMESTAMP with TIME ZONE and TIME with TIME ZONE
// the actual types are java.sql.Types.TIMESTAMP and java.sql.Types.TIME
val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties)
val rows = dfRead.collect()
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types(1).equals("class java.sql.Timestamp"))
assert(types(2).equals("class java.sql.Timestamp"))
}
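
  // uuid, inet, cidr, json and jsonb have no direct Catalyst equivalent, so the dialect
  // surfaces them as strings, and their array forms as arrays of strings.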
test("SPARK-22291: Conversion error when transforming array types of " +
"uuid, inet and cidr to StingType in PostgreSQL") {
val df = sqlContext.read.jdbc(jdbcUrl, "st_with_array", new Properties)
val rows = df.collect()
assert(rows(0).getString(0) == "0a532531-cdf1-45e3-963d-5de90b6a30f1")
assert(rows(0).getString(1) == "172.168.22.1")
assert(rows(0).getString(2) == "192.168.100.128/25")
assert(rows(0).getString(3) == "{\"a\": \"foo\", \"b\": \"bar\"}")
assert(rows(0).getString(4) == "{\"a\": 1, \"b\": 2}")
assert(rows(0).getSeq(5) == Seq("7be8aaf8-650e-4dbb-8186-0a749840ecf2",
"205f9bfc-018c-4452-a605-609c0cfad228"))
assert(rows(0).getSeq(6) == Seq("172.16.0.41", "172.16.0.42"))
assert(rows(0).getSeq(7) == Seq("192.168.0.0/24", "10.1.0.0/16"))
assert(rows(0).getSeq(8) == Seq("""{"a": "foo", "b": "bar"}""", """{"a": 1, "b": 2}"""))
assert(rows(0).getSeq(9) == Seq("""{"a": 1, "b": 2, "c": 3}"""))
}
test("query JDBC option") {
val expectedResult = Set(
(42, 123456789012345L)
).map { case (c1, c3) =>
Row(Integer.valueOf(c1), java.lang.Long.valueOf(c3))
}
val query = "SELECT c1, c3 FROM bar WHERE c1 IS NOT NULL"
// query option to pass on the query string.
val df = spark.read.format("jdbc")
.option("url", jdbcUrl)
.option("query", query)
.load()
assert(df.collect.toSet === expectedResult)
// query option in the create table path.
sql(
s"""
|CREATE OR REPLACE TEMPORARY VIEW queryOption
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$jdbcUrl', query '$query')
""".stripMargin.replaceAll("\n", " "))
assert(sql("select c1, c3 from queryOption").collect.toSet == expectedResult)
}
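
  // PostgreSQL has no 1-byte integer type, so ByteType is written as smallint and
  // reads back as ShortType.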
test("write byte as smallint") {
sqlContext.createDataFrame(Seq((1.toByte, 2.toShort)))
.write.jdbc(jdbcUrl, "byte_to_smallint_test", new Properties)
val df = sqlContext.read.jdbc(jdbcUrl, "byte_to_smallint_test", new Properties)
val schema = df.schema
assert(schema.head.dataType == ShortType)
assert(schema(1).dataType == ShortType)
val rows = df.collect()
assert(rows.length === 1)
assert(rows(0).getShort(0) === 1)
assert(rows(0).getShort(1) === 2)
}
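
  // The scalar fixture values exactly fill their declared char(4) widths, so no
  // blank padding from PostgreSQL is visible here.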
test("character type tests") {
val df = sqlContext.read.jdbc(jdbcUrl, "char_types", new Properties)
val row = df.collect()
assert(row.length == 1)
assert(row(0).length === 5)
assert(row(0).getString(0) === "abcd")
assert(row(0).getString(1) === "efgh")
assert(row(0).getString(2) === "ijkl")
assert(row(0).getString(3) === "mnop")
assert(row(0).getString(4) === "q")
}
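
  // Here the char(4)/character(4) elements are shorter than their declared width, so
  // PostgreSQL's blank padding shows up in the expected values.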
test("SPARK-32576: character array type tests") {
val df = sqlContext.read.jdbc(jdbcUrl, "char_array_types", new Properties)
val row = df.collect()
assert(row.length == 1)
assert(row(0).length === 5)
assert(row(0).getSeq[String](0) === Seq("a ", "bcd "))
assert(row(0).getSeq[String](1) === Seq("ef ", "gh "))
assert(row(0).getSeq[String](2) === Seq("i", "j", "kl"))
assert(row(0).getSeq[String](3) === Seq("mnop"))
assert(row(0).getSeq[String](4) === Seq("q", "r"))
}
}