/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Date, Timestamp}

import org.apache.commons.lang3.StringUtils

import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types._

/**
 * :: DeveloperApi ::
 * A database type definition coupled with the jdbc type needed to send null
 * values to the database.
 * @param databaseTypeDefinition The database type definition
 * @param jdbcNullType The jdbc type (as defined in java.sql.Types) used to
 *                     send a null value to the database.
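 *
 * For example, an unbounded text column could be described by the following
 * sketch (the concrete type name is database-specific and illustrative only):
 * {{{
 *   JdbcType("TEXT", java.sql.Types.CLOB)
 * }}}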
 */
@DeveloperApi
@InterfaceStability.Evolving
case class JdbcType(databaseTypeDefinition: String, jdbcNullType: Int)

/**
 * :: DeveloperApi ::
 * Encapsulates everything (extensions, workarounds, quirks) to handle the
 * SQL dialect of a certain database or jdbc driver.
 * Lots of databases define types that aren't explicitly supported
 * by the JDBC spec. Some JDBC drivers also report inaccurate
 * information; for instance, BIT(n{@literal >}1) being reported as a BIT type is quite
 * common, even though BIT in JDBC is meant for single-bit values. Also, there
 * does not appear to be a standard name for an unbounded string or binary
 * type; we use BLOB and CLOB by default but override with database-specific
 * alternatives when these are absent or do not behave correctly.
 *
 * Currently, the only thing done by the dialect is type mapping.
 * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
 * is used when writing to a JDBC table. If `getCatalystType` returns `None`,
 * the default type handling is used for the given JDBC type. Similarly,
 * if `getJDBCType` returns `None`, the default type handling is used
 * for the given Catalyst type.
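 *
 * A minimal dialect might look like the following sketch (the url prefix and
 * the type name are hypothetical, for illustration only):
 * {{{
 *   case object MyDialect extends JdbcDialect {
 *     override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
 *     override def getCatalystType(
 *         sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
 *       if (typeName == "MYTEXT") Some(StringType) else None
 *     }
 *   }
 * }}}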
 */
@DeveloperApi
@InterfaceStability.Evolving
abstract class JdbcDialect extends Serializable {
  /**
   * Check if this dialect instance can handle a certain jdbc url.
   * @param url the jdbc url.
   * @return True if the dialect can be applied on the given jdbc url.
   * @throws NullPointerException if the url is null.
   */
  def canHandle(url: String): Boolean

  /**
   * Get the custom datatype mapping for the given jdbc meta information.
   * @param sqlType The sql type (see java.sql.Types)
   * @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
   * @param size The size of the type.
   * @param md Result metadata associated with this type.
   * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
   *         or None if the default type mapping should be used.
   */
  def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None

  /**
   * Retrieve the jdbc / sql type for a given datatype.
   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
   * @return The new JdbcType if there is an override for this DataType
   */
  def getJDBCType(dt: DataType): Option[JdbcType] = None

  /**
   * Quotes the identifier. This is used to put quotes around the identifier in case the column
   * name is a reserved keyword, or in case it contains characters that require quotes
   * (e.g. space).
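   *
   * For example, with the default implementation:
   * {{{
   *   quoteIdentifier("order")   // returns "order", with the double quotes included
   * }}}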
   */
  def quoteIdentifier(colName: String): String = {
    s""""$colName""""
  }

  /**
   * Get the SQL query that should be used to find if the given table exists. Dialects can
   * override this method to return a query that works best in a particular database.
   * @param table The name of the table.
   * @return The SQL query to use for checking the table.
   */
  def getTableExistsQuery(table: String): String = {
    s"SELECT * FROM $table WHERE 1=0"
  }

  /**
   * The SQL query that should be used to discover the schema of a table. It only needs to
   * ensure that the result set has the same schema as the table, such as by calling
   * "SELECT * ...". Dialects can override this method to return a query that works best in a
   * particular database.
   * @param table The name of the table.
   * @return The SQL query to use for discovering the schema.
   */
  @Since("2.1.0")
  def getSchemaQuery(table: String): String = {
    s"SELECT * FROM $table WHERE 1=0"
  }

  /**
   * The SQL query that should be used to truncate a table. Dialects can override this method to
   * return a query that is suitable for a particular database. For PostgreSQL, for instance,
   * a different query is used to prevent "TRUNCATE" affecting other tables.
   * @param table The table to truncate
   * @return The SQL query to use for truncating a table
   */
  @Since("2.3.0")
  def getTruncateQuery(table: String): String = {
    getTruncateQuery(table, isCascadingTruncateTable)
  }

  /**
   * The SQL query that should be used to truncate a table. Dialects can override this method to
   * return a query that is suitable for a particular database. For PostgreSQL, for instance,
   * a different query is used to prevent "TRUNCATE" affecting other tables.
   * @param table The table to truncate
   * @param cascade Whether or not to cascade the truncation
   * @return The SQL query to use for truncating a table
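   *
   * As an illustrative sketch (not the exact built-in code), a PostgreSQL-style
   * override that restricts truncation to the named table could look like:
   * {{{
   *   override def getTruncateQuery(table: String, cascade: Option[Boolean]): String =
   *     cascade match {
   *       case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE"
   *       case _ => s"TRUNCATE TABLE ONLY $table"
   *     }
   * }}}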
   */
  @Since("2.4.0")
  def getTruncateQuery(
      table: String,
      cascade: Option[Boolean] = isCascadingTruncateTable): String = {
    s"TRUNCATE TABLE $table"
  }

  /**
   * Override connection specific properties to run before a select is made. This is in place to
   * allow dialects that need special treatment to optimize behavior.
   * @param connection The connection object
   * @param properties The connection properties. This is passed through from the relation.
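   *
   * For example, a dialect could disable autocommit so that the driver honors
   * the configured fetch size (a sketch of what the PostgreSQL dialect does;
   * illustrative only):
   * {{{
   *   override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
   *     connection.setAutoCommit(false)
   *   }
   * }}}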
   */
  def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
  }

  /**
   * Escape special characters in SQL string literals.
   * @param value The string to be escaped.
   * @return Escaped string.
   */
  @Since("2.3.0")
  protected[jdbc] def escapeSql(value: String): String =
    if (value == null) null else StringUtils.replace(value, "'", "''")

  /**
   * Converts value to SQL expression.
   * @param value The value to be converted.
   * @return Converted value.
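   *
   * For example:
   * {{{
   *   compileValue("O'Brien")        // the SQL literal 'O''Brien'
   *   compileValue(Array("a", "b"))  // the literal list 'a', 'b'
   * }}}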
   */
  @Since("2.3.0")
  def compileValue(value: Any): Any = value match {
    case stringValue: String => s"'${escapeSql(stringValue)}'"
    case timestampValue: Timestamp => "'" + timestampValue + "'"
    case dateValue: Date => "'" + dateValue + "'"
    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
    case _ => value
  }

  /**
   * Return Some[true] iff `TRUNCATE TABLE` causes cascading by default.
   * Some[true] : TRUNCATE TABLE causes cascading.
   * Some[false] : TRUNCATE TABLE does not cause cascading.
   * None: The behavior of TRUNCATE TABLE is unknown (default).
   */
  def isCascadingTruncateTable(): Option[Boolean] = None
}

/**
 * :: DeveloperApi ::
 * Registry of dialects that apply to every new jdbc `org.apache.spark.sql.DataFrame`.
 *
 * If multiple matching dialects are registered then all matching ones will be
 * tried in reverse order. A user-added dialect will thus be applied first,
 * overwriting the defaults.
 *
 * @note All new dialects are applied to new jdbc DataFrames only. Make
 *       sure to register your dialects first.
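 *
 * For instance, assuming a user-defined dialect `MyDialect` (hypothetical), an
 * active SparkSession `spark`, and the usual jdbc arguments `url`, `table` and
 * `properties`:
 * {{{
 *   JdbcDialects.registerDialect(MyDialect)
 *   val df = spark.read.jdbc(url, table, properties)  // MyDialect is now tried first
 * }}}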
 */
@DeveloperApi
@InterfaceStability.Evolving
object JdbcDialects {

  /**
   * Register a dialect for use on all new matching jdbc `org.apache.spark.sql.DataFrame`.
   * Re-registering an already registered dialect moves it to the front of the list.
   *
   * @param dialect The new dialect.
   */
  def registerDialect(dialect: JdbcDialect): Unit = {
    dialects = dialect :: dialects.filterNot(_ == dialect)
  }

  /**
   * Unregister a dialect. Does nothing if the dialect is not registered.
   *
   * @param dialect The jdbc dialect.
   */
  def unregisterDialect(dialect: JdbcDialect): Unit = {
    dialects = dialects.filterNot(_ == dialect)
  }

  private[this] var dialects = List[JdbcDialect]()

  registerDialect(MySQLDialect)
  registerDialect(PostgresDialect)
  registerDialect(DB2Dialect)
  registerDialect(MsSqlServerDialect)
  registerDialect(DerbyDialect)
  registerDialect(OracleDialect)
  registerDialect(TeradataDialect)

  /**
   * Fetch the JdbcDialect object corresponding to a given database url.
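   *
   * For example:
   * {{{
   *   JdbcDialects.get("jdbc:postgresql://host/db")  // the built-in PostgresDialect
   *   JdbcDialects.get("jdbc:unknown://host/db")     // a no-op fallback dialect
   * }}}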
   */
  def get(url: String): JdbcDialect = {
    val matchingDialects = dialects.filter(_.canHandle(url))
    matchingDialects.length match {
      case 0 => NoopDialect
      case 1 => matchingDialects.head
      case _ => new AggregatedDialect(matchingDialects)
    }
  }
}

/**
 * NOOP dialect object, always returning the neutral element.
 */
private object NoopDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = true
}