-
Notifications
You must be signed in to change notification settings - Fork 28k
/
DataTypeWriteCompatibilitySuite.scala
501 lines (407 loc) · 20.8 KB
/
DataTypeWriteCompatibilitySuite.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.types
import scala.collection.mutable
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
/**
 * Runs the shared write-compatibility checks under the STRICT store assignment policy, using
 * [[Cast.canUpCast]] as the reference for which atomic writes are acceptable, and adds cases
 * where a narrowing conversion inside a container must be rejected.
 */
class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite {
  override def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value =
    StoreAssignmentPolicy.STRICT

  override def canCast: (DataType, DataType) => Boolean = Cast.canUpCast

  test("Check struct types: unsafe casts are not allowed") {
    // Both fields narrow (double -> float); one message per field is expected, in field order.
    assertNumErrors(widerPoint2, point2, "t",
      "Should fail because types require unsafe casts", 2) { messages =>
      Seq("'t.x'", "'t.y'").zipWithIndex.foreach { case (fieldRef, i) =>
        assert(messages(i).contains(fieldRef), "Should include the nested field name context")
        assert(messages(i).contains("Cannot safely cast"))
      }
    }
  }

  test("Check array types: unsafe casts are not allowed") {
    val longElements = ArrayType(LongType)
    val intElements = ArrayType(IntegerType)
    assertSingleError(longElements, intElements, "arr",
      "Should not allow array of longs to array of ints") { message =>
      assert(message.contains("'arr.element'"),
        "Should identify problem with named array's element type")
      assert(message.contains("Cannot safely cast"))
    }
  }

  test("Check map value types: casting Long to Integer is not allowed") {
    val longValues = MapType(StringType, LongType)
    val intValues = MapType(StringType, IntegerType)
    assertSingleError(longValues, intValues, "m",
      "Should not allow map of longs to map of ints") { message =>
      assert(message.contains("'m.value'"), "Should identify problem with named map's value type")
      assert(message.contains("Cannot safely cast"))
    }
  }

  test("Check map key types: unsafe casts are not allowed") {
    val longKeys = MapType(LongType, StringType)
    val intKeys = MapType(IntegerType, StringType)
    assertSingleError(longKeys, intKeys, "m",
      "Should not allow map of long keys to map of int keys") { message =>
      assert(message.contains("'m.key'"), "Should identify problem with named map's key type")
      assert(message.contains("Cannot safely cast"))
    }
  }
}
/**
 * Runs the shared write-compatibility checks under the ANSI store assignment policy, using
 * [[Cast.canANSIStoreAssign]] as the reference for which atomic writes are acceptable, and adds
 * cases (string -> int, timestamp <-> long) that ANSI store assignment must reject.
 */
class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBaseSuite {
  override protected def storeAssignmentPolicy: SQLConf.StoreAssignmentPolicy.Value =
    StoreAssignmentPolicy.ANSI

  override def canCast: (DataType, DataType) => Boolean = Cast.canANSIStoreAssign

  // A two-field struct shaped like `point2`, but with string-typed fields.
  private val stringXY = StructType(Seq(
    StructField("x", StringType, nullable = false),
    StructField("y", StringType, nullable = false)))

  test("Check map value types: unsafe casts are not allowed") {
    val stringValues = MapType(StringType, StringType)
    val intValues = MapType(StringType, IntegerType)
    assertSingleError(stringValues, intValues, "m",
      "Should not allow map of strings to map of ints") { message =>
      assert(message.contains("'m.value'"), "Should identify problem with named map's value type")
      assert(message.contains("Cannot safely cast"))
    }
  }

  test("Check struct types: unsafe casts are not allowed") {
    // Each string field fails to store-assign into float; one message per field, in order.
    assertNumErrors(stringXY, point2, "t",
      "Should fail because types require unsafe casts", 2) { messages =>
      Seq("'t.x'", "'t.y'").zipWithIndex.foreach { case (fieldRef, i) =>
        assert(messages(i).contains(fieldRef), "Should include the nested field name context")
        assert(messages(i).contains("Cannot safely cast"))
      }
    }
  }

  test("Check array types: unsafe casts are not allowed") {
    val stringElements = ArrayType(StringType)
    val intElements = ArrayType(IntegerType)
    assertSingleError(stringElements, intElements, "arr",
      "Should not allow array of strings to array of ints") { message =>
      assert(message.contains("'arr.element'"),
        "Should identify problem with named array's element type")
      assert(message.contains("Cannot safely cast"))
    }
  }

  test("Check map key types: unsafe casts are not allowed") {
    val stringKeys = MapType(StringType, StringType)
    val intKeys = MapType(IntegerType, StringType)
    assertSingleError(stringKeys, intKeys, "m",
      "Should not allow map of string keys to map of int keys") { message =>
      assert(message.contains("'m.key'"), "Should identify problem with named map's key type")
      assert(message.contains("Cannot safely cast"))
    }
  }

  test("Conversions between timestamp and long are not allowed") {
    // Asserts a single rejection whose message contains `expected`.
    def expectRejected(
        from: DataType,
        to: DataType,
        column: String,
        desc: String,
        expected: String): Unit = {
      assertSingleError(from, to, column, desc) { message =>
        assert(message.contains(expected))
      }
    }

    expectRejected(LongType, TimestampType, "longToTimestamp",
      "Should not allow long to timestamp",
      "Cannot safely cast 'longToTimestamp': LongType to TimestampType")
    expectRejected(TimestampType, LongType, "timestampToLong",
      "Should not allow timestamp to long",
      "Cannot safely cast 'timestampToLong': TimestampType to LongType")
  }
}
/**
 * Shared checks for [[DataType.canWrite]], executed once per store assignment policy by the
 * concrete subclasses. Each subclass supplies the policy under test and the cast predicate that
 * decides which atomic-type writes are expected to be accepted under that policy.
 */
abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite {
  /** The store assignment policy that every helper passes to [[DataType.canWrite]]. */
  protected def storeAssignmentPolicy: StoreAssignmentPolicy.Value

  /**
   * Reference predicate for atomic types: `canCast(w, r)` is true when writing `w` into a
   * column of type `r` is expected to be allowed.
   */
  protected def canCast: (DataType, DataType) => Boolean

  protected val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
    FloatType, DoubleType, DateType, TimestampType, StringType, BinaryType)

  protected val point2 = StructType(Seq(
    StructField("x", FloatType, nullable = false),
    StructField("y", FloatType, nullable = false)))

  protected val widerPoint2 = StructType(Seq(
    StructField("x", DoubleType, nullable = false),
    StructField("y", DoubleType, nullable = false)))

  protected val point3 = StructType(Seq(
    StructField("x", FloatType, nullable = false),
    StructField("y", FloatType, nullable = false),
    StructField("z", FloatType)))

  private val simpleContainerTypes = Seq(
    ArrayType(LongType), ArrayType(LongType, containsNull = false), MapType(StringType, DoubleType),
    MapType(StringType, DoubleType, valueContainsNull = false), point2, point3)

  private val nestedContainerTypes = Seq(ArrayType(point2, containsNull = false),
    MapType(StringType, point3, valueContainsNull = false))

  private val allNonNullTypes = Seq(
    atomicTypes, simpleContainerTypes, nestedContainerTypes, Seq(CalendarIntervalType)).flatten

  test("Check NullType is incompatible with all other types") {
    allNonNullTypes.foreach { t =>
      assertSingleError(NullType, t, "nulls", s"Should not allow writing None to type $t") { err =>
        assert(err.contains(s"incompatible with $t"))
      }
    }
  }

  test("Check each type with itself") {
    allNonNullTypes.foreach { t =>
      assertAllowed(t, t, "t", s"Should allow writing type to itself $t")
    }
  }

  test("Check atomic types: write allowed only when casting is safe") {
    atomicTypes.foreach { w =>
      atomicTypes.foreach { r =>
        if (canCast(w, r)) {
          assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe")
        } else {
          assertSingleError(w, r, "t",
            s"Should not allow writing $w to $r because cast is not safe") { err =>
            assert(err.contains("'t'"), "Should include the field name context")
            assert(err.contains("Cannot safely cast"), "Should identify unsafe cast")
            assert(err.contains(s"$w"), "Should include write type")
            assert(err.contains(s"$r"), "Should include read type")
          }
        }
      }
    }
  }

  test("Check struct types: missing required field") {
    val missingRequiredField = StructType(Seq(StructField("x", FloatType, nullable = false)))
    assertSingleError(missingRequiredField, point2, "t",
      "Should fail because required field 'y' is missing") { err =>
      assert(err.contains("'t'"), "Should include the struct name for context")
      assert(err.contains("'y'"), "Should include the nested field name")
      assert(err.contains("missing field"), "Should call out field missing")
    }
  }

  test("Check struct types: missing starting field, matched by position") {
    val missingRequiredField = StructType(Seq(StructField("y", FloatType, nullable = false)))

    // should have 2 errors: names x and y don't match, and field y is missing
    assertNumErrors(missingRequiredField, point2, "t",
      "Should fail because field 'x' is matched to field 'y' and required field 'y' is missing", 2)
    { errs =>
      assert(errs(0).contains("'t'"), "Should include the struct name for context")
      assert(errs(0).contains("expected 'x', found 'y'"), "Should detect name mismatch")
      assert(errs(0).contains("field name does not match"), "Should identify name problem")

      assert(errs(1).contains("'t'"), "Should include the struct name for context")
      assert(errs(1).contains("'y'"), "Should include the _last_ nested fields of the read schema")
      assert(errs(1).contains("missing field"), "Should call out field missing")
    }
  }

  test("Check struct types: missing middle field, matched by position") {
    val missingMiddleField = StructType(Seq(
      StructField("x", FloatType, nullable = false),
      StructField("z", FloatType, nullable = false)))

    val expectedStruct = StructType(Seq(
      StructField("x", FloatType, nullable = false),
      StructField("y", FloatType, nullable = false),
      StructField("z", FloatType, nullable = true)))

    // types are compatible: (req float, req float) => (req float, req float, opt float)
    // but this should still fail because the names do not match.

    assertNumErrors(missingMiddleField, expectedStruct, "t",
      "Should fail because field 'y' is matched to field 'z'", 2) { errs =>
      assert(errs(0).contains("'t'"), "Should include the struct name for context")
      assert(errs(0).contains("expected 'y', found 'z'"), "Should detect name mismatch")
      assert(errs(0).contains("field name does not match"), "Should identify name problem")

      assert(errs(1).contains("'t'"), "Should include the struct name for context")
      assert(errs(1).contains("'z'"), "Should include the nested field name")
      assert(errs(1).contains("missing field"), "Should call out field missing")
    }
  }

  test("Check struct types: generic colN names are ignored") {
    val generatedNames = StructType(Seq(
      StructField("col1", FloatType, nullable = false),
      StructField("col2", FloatType, nullable = false)))

    val expectedStruct = StructType(Seq(
      StructField("x", FloatType, nullable = false),
      StructField("y", FloatType, nullable = false)))

    // types are compatible: (req float, req float) => (req float, req float)
    // names don't match, but match the naming convention used by Spark to fill in names

    assertAllowed(generatedNames, expectedStruct, "t",
      "Should succeed because column names are ignored")
  }

  test("Check struct types: required field is optional") {
    val requiredFieldIsOptional = StructType(Seq(
      StructField("x", FloatType),
      StructField("y", FloatType, nullable = false)))

    assertSingleError(requiredFieldIsOptional, point2, "t",
      "Should fail because required field 'x' is optional") { err =>
      assert(err.contains("'t.x'"), "Should include the nested field name context")
      assert(err.contains("Cannot write nullable values to non-null field"))
    }
  }

  test("Check struct types: data field would be dropped") {
    assertSingleError(point3, point2, "t",
      "Should fail because field 'z' would be dropped") { err =>
      assert(err.contains("'t'"), "Should include the struct name for context")
      assert(err.contains("'z'"), "Should include the extra field name")
      assert(err.contains("Cannot write extra fields"))
    }
  }

  test("Check struct types: type promotion is allowed") {
    assertAllowed(point2, widerPoint2, "t",
      "Should allow widening float fields x and y to double")
  }

  test("Check struct type: ignore field name mismatch with byPosition mode") {
    val nameMismatchFields = StructType(Seq(
      StructField("a", FloatType, nullable = false),
      StructField("b", FloatType, nullable = false)))
    assertAllowed(nameMismatchFields, point2, "t",
      "Should allow field name mismatch with byPosition mode", byName = false)
  }

  ignore("Check struct types: missing optional field is allowed") {
    // built-in data sources do not yet support missing fields when optional
    assertAllowed(point2, point3, "t",
      "Should allow writing point (x,y) to point(x,y,z=null)")
  }

  test("Check array types: type promotion is allowed") {
    val arrayOfLong = ArrayType(LongType)
    val arrayOfInt = ArrayType(IntegerType)
    assertAllowed(arrayOfInt, arrayOfLong, "arr",
      "Should allow array of int written to array of long column")
  }

  test("Check array types: cannot write optional to required elements") {
    val arrayOfRequired = ArrayType(LongType, containsNull = false)
    val arrayOfOptional = ArrayType(LongType)

    assertSingleError(arrayOfOptional, arrayOfRequired, "arr",
      "Should not allow array of optional elements to array of required elements") { err =>
      assert(err.contains("'arr'"), "Should include type name context")
      assert(err.contains("Cannot write nullable elements to array of non-nulls"))
    }
  }

  test("Check array types: writing required to optional elements is allowed") {
    val arrayOfRequired = ArrayType(LongType, containsNull = false)
    val arrayOfOptional = ArrayType(LongType)

    assertAllowed(arrayOfRequired, arrayOfOptional, "arr",
      "Should allow array of required elements to array of optional elements")
  }

  test("Check map value types: type promotion is allowed") {
    val mapOfLong = MapType(StringType, LongType)
    val mapOfInt = MapType(StringType, IntegerType)

    assertAllowed(mapOfInt, mapOfLong, "m", "Should allow map of int written to map of long column")
  }

  test("Check map value types: cannot write optional to required values") {
    val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false)
    val mapOfOptional = MapType(StringType, LongType)

    assertSingleError(mapOfOptional, mapOfRequired, "m",
      "Should not allow map of optional values to map of required values") { err =>
      assert(err.contains("'m'"), "Should include type name context")
      assert(err.contains("Cannot write nullable values to map of non-nulls"))
    }
  }

  test("Check map value types: writing required to optional values is allowed") {
    val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false)
    val mapOfOptional = MapType(StringType, LongType)

    assertAllowed(mapOfRequired, mapOfOptional, "m",
      "Should allow map of required elements to map of optional elements")
  }

  test("Check map key types: type promotion is allowed") {
    val mapKeyLong = MapType(LongType, StringType)
    val mapKeyInt = MapType(IntegerType, StringType)

    assertAllowed(mapKeyInt, mapKeyLong, "m",
      "Should allow map of int written to map of long column")
  }

  test("Check types with multiple errors") {
    val readType = StructType(Seq(
      StructField("a", ArrayType(DoubleType, containsNull = false)),
      StructField("arr_of_structs", ArrayType(point2, containsNull = false)),
      StructField("bad_nested_type", ArrayType(StringType)),
      StructField("m", MapType(LongType, FloatType, valueContainsNull = false)),
      StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)),
      StructField("x", IntegerType, nullable = false),
      StructField("missing1", StringType, nullable = false),
      StructField("missing2", StringType)
    ))

    val missingMiddleField = StructType(Seq(
      StructField("x", FloatType, nullable = false),
      StructField("z", FloatType, nullable = false)))

    val writeType = StructType(Seq(
      StructField("a", ArrayType(StringType)),
      StructField("arr_of_structs", ArrayType(point3)),
      StructField("bad_nested_type", point3),
      StructField("m", MapType(StringType, BooleanType)),
      StructField("map_of_structs", MapType(StringType, missingMiddleField)),
      StructField("y", StringType)
    ))

    assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs =>
      assert(errs(0).contains("'top.a.element'"), "Should identify bad type")
      assert(errs(0).contains("Cannot safely cast"))
      assert(errs(0).contains("StringType to DoubleType"))

      assert(errs(1).contains("'top.a'"), "Should identify bad type")
      assert(errs(1).contains("Cannot write nullable elements to array of non-nulls"))

      assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type")
      assert(errs(2).contains("'z'"), "Should identify bad field")
      assert(errs(2).contains("Cannot write extra fields to struct"))

      assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type")
      assert(errs(3).contains("Cannot write nullable elements to array of non-nulls"))

      assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type")
      assert(errs(4).contains("is incompatible with"))

      assert(errs(5).contains("'top.m.key'"), "Should identify bad type")
      assert(errs(5).contains("Cannot safely cast"))
      assert(errs(5).contains("StringType to LongType"))

      assert(errs(6).contains("'top.m.value'"), "Should identify bad type")
      assert(errs(6).contains("Cannot safely cast"))
      assert(errs(6).contains("BooleanType to FloatType"))

      assert(errs(7).contains("'top.m'"), "Should identify bad type")
      assert(errs(7).contains("Cannot write nullable values to map of non-nulls"))

      assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type")
      assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch")
      assert(errs(8).contains("field name does not match"), "Should identify name problem")

      assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type")
      assert(errs(9).contains("'z'"), "Should identify missing field")
      assert(errs(9).contains("missing fields"), "Should detect missing field")

      assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type")
      assert(errs(10).contains("Cannot write nullable values to map of non-nulls"))

      assert(errs(11).contains("'top.x'"), "Should identify bad type")
      assert(errs(11).contains("Cannot safely cast"))
      assert(errs(11).contains("StringType to IntegerType"))

      assert(errs(12).contains("'top'"), "Should identify bad type")
      assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch")
      assert(errs(12).contains("field name does not match"), "Should identify name problem")

      assert(errs(13).contains("'top'"), "Should identify bad type")
      assert(errs(13).contains("'missing1'"), "Should identify missing field")
      assert(errs(13).contains("missing fields"), "Should detect missing field")
    }
  }

  // Helper functions

  /**
   * Asserts that `DataType.canWrite` accepts writing `writeType` into `readType`. The test fails
   * immediately if any error message is reported.
   *
   * @param writeType type of the data being written
   * @param readType  type of the target column
   * @param name      column name passed to `canWrite` as the error-message context
   * @param desc      assertion clue shown if `canWrite` returns false
   * @param byName    forwarded to `canWrite`; `false` is used by this suite for by-position
   *                  matching
   */
  def assertAllowed(
      writeType: DataType,
      readType: DataType,
      name: String,
      desc: String,
      byName: Boolean = true): Unit = {
    assert(
      DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name,
        storeAssignmentPolicy,
        errMsg => fail(s"Should not produce errors but was called with: $errMsg")), desc)
  }

  /**
   * Asserts that writing `writeType` into `readType` is rejected with exactly one error message.
   * That message is then handed to `errFunc` for content checks.
   */
  def assertSingleError(
      writeType: DataType,
      readType: DataType,
      name: String,
      desc: String)
      (errFunc: String => Unit): Unit = {
    assertNumErrors(writeType, readType, name, desc, 1) { errs =>
      errFunc(errs.head)
    }
  }

  /**
   * Asserts that writing `writeType` into `readType` is rejected with exactly `numErrs` error
   * messages. The messages are then passed, in the order they were reported, to `checkErrors`.
   *
   * @param numErrs     expected number of reported error messages
   * @param byName      forwarded to `DataType.canWrite`
   * @param checkErrors callback run on the collected messages once the count matches
   */
  def assertNumErrors(
      writeType: DataType,
      readType: DataType,
      name: String,
      desc: String,
      numErrs: Int,
      byName: Boolean = true)
      (checkErrors: Seq[String] => Unit): Unit = {
    val errs = new mutable.ArrayBuffer[String]()
    assert(
      DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name,
        storeAssignmentPolicy, errMsg => errs += errMsg) === false, desc)
    // Include the collected messages so a count mismatch can be diagnosed from the log alone.
    assert(errs.size === numErrs,
      s"Should produce $numErrs error messages, but got ${errs.size}: ${errs.mkString("; ")}")
    // `scala.Seq` is `immutable.Seq` on Scala 2.13, so a mutable buffer must be converted;
    // on 2.12 `toSeq` simply returns the buffer itself.
    checkErrors(errs.toSeq)
  }
}