/
CodeGenerator.scala
1864 lines (1689 loc) · 73.4 KB
/
CodeGenerator.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.codegen
import java.io.ByteArrayInputStream
import java.util.{Map => JavaMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, InternalCompilerException, SimpleCompiler}
import org.codehaus.janino.util.ClassFile
import org.apache.spark.{TaskContext, TaskKilledException}
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types._
import org.apache.spark.util.{ParentClassLoader, Utils}
/**
* Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
*
* @param code The sequence of statements required to evaluate the expression.
* It should be empty string, if `isNull` and `value` are already existed, or no code
* needed to evaluate them (literals).
* @param isNull A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)

object ExprCode {
  /** Builds an [[ExprCode]] that carries no statements, only the given `isNull` and `value`. */
  def apply(isNull: ExprValue, value: ExprValue): ExprCode =
    new ExprCode(EmptyBlock, isNull, value)

  /**
   * Builds an [[ExprCode]] with no statements whose `isNull` is the `true` literal and whose
   * value is the default literal of `dataType`.
   */
  def forNullValue(dataType: DataType): ExprCode = {
    val defaultValue = JavaCode.defaultLiteral(dataType)
    new ExprCode(EmptyBlock, TrueLiteral, defaultValue)
  }

  /** Builds an [[ExprCode]] with no statements that wraps `value` and is never null. */
  def forNonNullValue(value: ExprValue): ExprCode =
    new ExprCode(EmptyBlock, FalseLiteral, value)
}
/**
 * State used for subexpression elimination.
 *
 * @param isNull A term that holds a boolean value representing whether the common
 *               sub-expression evaluated to null.
 * @param value A term for the value of a common sub-expression. Not valid if `isNull`
 *              is set to `true`.
 */
case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)
/**
 * Codes and common subexpressions mapping used for subexpression elimination.
 *
 * @param codes Strings representing the codes that evaluate common subexpressions.
 * @param states For each expression that is participating in subexpression elimination,
 *               the state to use.
 */
case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
/**
 * The main information about a newly added function.
 *
 * @param functionName String representing the name of the function
 * @param innerClassName Optional value which is empty if the function is added to
 *                       the outer class, otherwise it contains the name of the
 *                       inner class in which the function has been added.
 * @param innerClassInstance Optional value which is empty if the function is added to
 *                           the outer class, otherwise it contains the name of the
 *                           instance of the inner class in the outer class.
 */
private[codegen] case class NewFunctionSpec(
functionName: String,
innerClassName: Option[String],
innerClassInstance: Option[String])
/**
* A context for codegen, tracking a list of objects that could be passed into generated Java
* function.
*/
class CodegenContext extends Logging {
import CodeGenerator._
/**
 * Holds the list of objects that can be passed into the generated class.
 * New entries are appended by `addReferenceObj`.
 */
val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
/**
* Add an object to `references`.
*
* Returns the code to access it.
*
* This does not store the object into a field but refers to it from the `references` field at
* the time of use, because the number of fields in a class is limited so we should reduce it.
*/
def addReferenceObj(objName: String, obj: Any, className: String = null): String = {
  // The object is appended, so its slot is the length of `references` before the append.
  val slot = references.size
  references.append(obj)
  // Without an explicit class name, cast to the runtime class of the object itself.
  val castType = if (className != null) className else obj.getClass.getName
  s"(($castType) references[$slot] /* $objName */)"
}
/**
 * Holds the variable name of the input row of the current operator; it is used by
 * `BoundReference` to generate code.
 *
 * Note that if `currentVars` is not null, `BoundReference` prefers `currentVars` over `INPUT_ROW`
 * to generate code. If you want to make sure the generated code uses `INPUT_ROW`, you need to set
 * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
 * `Expression.genCode`.
 */
var INPUT_ROW = "i"
/**
 * Holds the list of generated columns that are the input of the current operator; it is used by
 * `BoundReference` to generate code. It is `null` until explicitly assigned.
 */
var currentVars: Seq[ExprCode] = null
/**
 * Holds expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a
 * 2-tuple: java type, variable name.
 * As an example, ("int", "count") will produce code:
 * {{{
 * private int count;
 * }}}
 * as a member variable.
 *
 * They will be kept as member variables in generated classes like `SpecificProjection`.
 *
 * Exposed for tests only.
 */
private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
mutable.ArrayBuffer.empty[(String, String)]
/**
 * The mapping between mutable state types and the corresponding compacted arrays.
 * The keys are Java type strings. The values are [[MutableStateArrays]] which encapsulate
 * the compacted arrays for the mutable states with the same Java type.
 *
 * Exposed for tests only.
 */
private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
mutable.Map.empty[String, MutableStateArrays]
// A buffer holding the code that will initialize each state.
// Exposed for tests only.
private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] =
mutable.ArrayBuffer.empty[String]
// Tracks the names of all the mutable states (inlined variables and compacted arrays).
private val mutableStateNames: mutable.HashSet[String] = mutable.HashSet.empty
/**
* This class holds a set of names of mutableStateArrays that is used for compacting mutable
* states for a certain type, and holds the next available slot of the current compacted array.
*/
class MutableStateArrays {
  // Names of all compacted arrays created so far; the last entry is the one being filled.
  val arrayNames = mutable.ListBuffer.empty[String]
  createNewArray()

  // Index of the next free slot in the array named by `arrayNames.last`.
  private[this] var currentIndex = 0

  // Registers a freshly named array and makes it the current one.
  private def createNewArray(): Unit = {
    val freshArrayName = freshName("mutableStateArray")
    mutableStateNames += freshArrayName
    arrayNames += freshArrayName
  }

  def getCurrentIndex: Int = currentIndex

  /**
   * Hands out the reference of the next free slot. When the current array already holds
   * `MUTABLESTATEARRAY_SIZE_LIMIT` entries, a new array is started and its first slot is
   * returned instead.
   */
  def getNextSlot(): String = {
    if (currentIndex >= MUTABLESTATEARRAY_SIZE_LIMIT) {
      createNewArray()
      currentIndex = 0
    }
    val slot = s"${arrayNames.last}[$currentIndex]"
    currentIndex += 1
    slot
  }
}
/**
 * A map containing the states which have been defined so far using
 * `addImmutableStateIfNotExists`. Each entry has the name of the state as key and a pair of
 * its Java type and its init code as value.
 */
private val immutableStates: mutable.Map[String, (String, String)] =
mutable.Map.empty[String, (String, String)]
/**
* Add a mutable state as a field to the generated class. c.f. the comments above.
*
* @param javaType Java type of the field. Note that short names can be used for some types,
* e.g. InternalRow, UnsafeRow, UnsafeArrayData, etc. Other types will have to
* specify the fully-qualified Java type name. See the code in doCompile() for
* the list of default imports available.
* Also, generic type arguments are accepted but ignored.
* @param variableName Name of the field.
* @param initFunc Function includes statement(s) to put into the init() method to initialize
* this field. The argument is the name of the mutable state variable.
* If left blank, the field will be default-initialized.
* @param forceInline whether the declaration and initialization code may be inlined rather than
* compacted. Please set `true` into forceInline for one of the followings:
* 1. use the original name of the status
* 2. expect to non-frequently generate the status
* (e.g. not much sort operators in one stage)
* @param useFreshName If this is false and the mutable state ends up inlining in the outer
* class, the name is not changed
* @return the name of the mutable state variable, which is the original name or fresh name if
* the variable is inlined to the outer class, or an array access if the variable is to
* be stored in an array of variables of the same type.
* A variable will be inlined into the outer class when one of the following conditions
* are satisfied:
* 1. forceInline is true
* 2. its type is primitive type and the total number of the inlined mutable variables
* is less than `OUTER_CLASS_VARIABLES_THRESHOLD`
* 3. its type is multi-dimensional array
* When a variable is compacted into an array, the max size of the array for compaction
* is given by `MUTABLESTATEARRAY_SIZE_LIMIT`.
*/
def addMutableState(
    javaType: String,
    variableName: String,
    initFunc: String => String = _ => "",
    forceInline: Boolean = false,
    useFreshName: Boolean = true): String = {
  // Primitive variables are kept on the outer class for performance while there is room left.
  val primitiveHasRoom = isPrimitiveType(javaType) &&
    (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD)
  val isMultiDimArray = javaType.contains("[][]")
  val inlineToOuterClass = forceInline || primitiveHasRoom || isMultiDimArray

  if (!inlineToOuterClass) {
    // Compacted: the state lives in a slot of a per-type array instead of its own field.
    val stateArrays =
      arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays)
    val slot = stateArrays.getNextSlot()
    val initStatements = initFunc(slot)
    mutableStateInitCode += initStatements
    slot
  } else {
    // Inlined: the state becomes a dedicated member variable of the outer class.
    val fieldName = if (useFreshName) freshName(variableName) else variableName
    val initStatements = initFunc(fieldName)
    inlinedMutableStates += (javaType -> fieldName)
    mutableStateInitCode += initStatements
    mutableStateNames += fieldName
    fieldName
  }
}
/**
* Add an immutable state as a field to the generated class only if a field with that name does
* not exist yet. This helps reduce the number of the generated class' fields, since the same
* variable can be reused by many functions.
*
* Even though the added variables are not declared as final, they should never be reassigned in
* the generated code to prevent errors and unexpected behaviors.
*
* Internally, this method calls `addMutableState`.
*
* @param javaType Java type of the field.
* @param variableName Name of the field.
* @param initFunc Function includes statement(s) to put into the init() method to initialize
* this field. The argument is the name of the mutable state variable.
*/
def addImmutableStateIfNotExists(
    javaType: String,
    variableName: String,
    initFunc: String => String = _ => ""): Unit = {
  immutableStates.get(variableName) match {
    case None =>
      // First definition: declare the field under its exact name and remember how it was set up.
      addMutableState(javaType, variableName, initFunc, useFreshName = false, forceInline = true)
      immutableStates(variableName) = (javaType, initFunc(variableName))

    case Some((knownJavaType, knownInitCode)) =>
      // Redefinition: only legal when both the type and the init code are unchanged.
      assert(knownJavaType == javaType, s"$variableName has already been defined with type " +
        s"$knownJavaType and now it is tried to define again with type $javaType.")
      assert(knownInitCode == initFunc(variableName), s"$variableName has already been defined " +
        s"with different initialization statements.")
  }
}
/**
* Add buffer variable which stores data coming from an [[InternalRow]]. This method guarantees
* that the variable is safely stored, which is important for (potentially) byte array backed
* data types like: UTF8String, ArrayData, MapData & InternalRow.
*/
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
  val bufferField = addMutableState(javaType(dataType), variableName)
  // Strings and complex types are duplicated on assignment; everything else is assigned as is.
  val assignment = dataType match {
    case _: StructType | _: ArrayType | _: MapType => code"$bufferField = $initCode.copy();"
    case StringType => code"$bufferField = $initCode.clone();"
    case _ => code"$bufferField = $initCode;"
  }
  val bufferedValue = JavaCode.global(bufferField, dataType)
  ExprCode(assignment, FalseLiteral, bufferedValue)
}
/**
 * Returns the Java member declarations for every registered mutable state: one field per
 * inlined state, followed by one array field per compacted array.
 */
def declareMutableStates(): String = {
  // The same mutable state may be added twice, e.g. by the `mergeExpressions` in
  // `TypedAggregateExpression`, so duplicates are dropped with `distinct`.
  val inlinedDeclarations = inlinedMutableStates.distinct.map { state =>
    s"private ${state._1} ${state._2};"
  }

  val compactedDeclarations = arrayCompactedMutableStates.flatMap {
    case (javaType, stateArrays) =>
      val names = stateArrays.arrayNames
      val lastPosition = names.size - 1
      val elementIsArray = javaType.contains("[]")
      names.zipWithIndex.map { case (arrayName, position) =>
        // Only the most recently created array can be partially filled.
        val length =
          if (position == lastPosition) stateArrays.getCurrentIndex
          else MUTABLESTATEARRAY_SIZE_LIMIT
        if (elementIsArray) {
          // The state itself is a one-dimensional array: strip its trailing "[]".
          val baseType = javaType.dropRight(2)
          s"private $javaType[] $arrayName = new $baseType[$length][];"
        } else {
          // The state is a scalar variable.
          s"private $javaType[] $arrayName = new $javaType[$length];"
        }
      }
  }

  (inlinedDeclarations ++ compactedDeclarations).mkString("\n")
}
/**
 * Returns the code that initializes all registered mutable states, produced by handing the
 * individual init statements to `splitExpressions`.
 */
def initMutableStates(): String = {
  // The same mutable state may be added twice, e.g. by the `mergeExpressions` in
  // `TypedAggregateExpression`, so duplicated init statements are removed first.
  val distinctInitCode = mutableStateInitCode.distinct
  val newlineTerminated = distinctInitCode.map(statement => s"$statement\n")
  // With too many mutable states the initialization code may exceed the JVM's 64kb function
  // size limit, so it is split into multiple functions.
  splitExpressions(expressions = newlineTerminated, funcName = "init", arguments = Nil)
}
/**
* Code statements to initialize states that depend on the partition index.
* An integer `partitionIndex` will be made available within the scope.
*/
val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty

/** Registers `statement` so that it becomes part of the code returned by `initPartition()`. */
def addPartitionInitializationStatement(statement: String): Unit =
  partitionInitializationStatements.append(statement)

/** Returns all registered partition initialization statements, one per line. */
def initPartition(): String = {
  val statements = partitionInitializationStatements
  statements.mkString("\n")
}
/**
 * Holds expressions that are equivalent. Used to perform subexpression elimination
 * during codegen.
 *
 * For expressions that appear more than once, generate additional code to prevent
 * recomputing the value.
 *
 * For example, consider two expressions generated from this SQL statement:
 * SELECT (col1 + col2), (col1 + col2) / col3.
 *
 * equivalentExpressions will match the tree containing `col1 + col2` and it will only
 * be evaluated once.
 */
private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
// For each expression that is participating in subexpression elimination, the state to use.
// Visible for testing.
private[expressions] var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState]
// The collection of sub-expression result resetting methods that need to be called on each row.
private val subexprFunctions = mutable.ArrayBuffer.empty[String]
// Placeholder name under which the outermost generated class is tracked by this context.
val outerClassName = "OuterClass"
/**
 * Holds the class and instance names to be generated, where `OuterClass` is a placeholder
 * standing for whichever class is generated as the outermost class and which will contain any
 * inner sub-classes. All other class and instance names in this list represent private,
 * inner sub-classes. The `OuterClass` entry has a `null` instance name.
 */
private val classes: mutable.ListBuffer[(String, String)] =
mutable.ListBuffer[(String, String)](outerClassName -> null)
// A map holding the current size of each class to be generated. `addNewFunctionToClass` grows
// it by the character length of every function's code.
private val classSize: mutable.Map[String, Int] =
mutable.Map[String, Int](outerClassName -> 0)
// Nested maps holding function names and their code belonging to each class.
private val classFunctions: mutable.Map[String, mutable.Map[String, String]] =
mutable.Map(outerClassName -> mutable.Map.empty[String, String])
// Verbatim extra code to be added to the OuterClass.
private val extraClasses: mutable.ListBuffer[String] = mutable.ListBuffer[String]()
// Size recorded for the most recently added class, i.e. the one at the head of `classes`.
private def currClassSize(): Int = {
  val (newestClassName, _) = classes.head
  classSize(newestClassName)
}

// The (class name, instance name) pair of the most recently added class.
private def currClass(): (String, String) = classes.head

// Registers a new class under the given class and instance names. It becomes the current
// class and starts with size 0 and no functions.
private def addClass(className: String, classInstance: String): Unit = {
  (className -> classInstance) +=: classes
  classSize(className) = 0
  classFunctions(className) = mutable.Map.empty[String, String]
}
/**
* Adds a function to the generated class. If the code for the `OuterClass` grows too large, the
* function will be inlined into a new private, inner class, and a class-qualified name for the
* function will be returned. Otherwise, the function will be inlined to the `OuterClass` and the
* simple `funcName` will be returned.
*
* @param funcName the class-unqualified name of the function
* @param funcCode the body of the function
* @param inlineToOuterClass whether the given code must be inlined to the `OuterClass`. This
* can be necessary when a function is declared outside of the context
* it is eventually referenced and a returned qualified function name
* cannot otherwise be accessed.
* @return the name of the function, qualified by class if it will be inlined to a private,
* inner class
*/
def addNewFunction(
    funcName: String,
    funcCode: String,
    inlineToOuterClass: Boolean = false): String = {
  addNewFunctionInternal(funcName, funcCode, inlineToOuterClass) match {
    // Placed in a nested class: callers must go through the nested class instance.
    case NewFunctionSpec(name, Some(_), Some(instanceName)) => s"$instanceName.$name"
    // Placed in the outer class: the plain name is enough.
    case NewFunctionSpec(name, None, None) => name
    case _ =>
      throw new IllegalArgumentException(s"$funcName is not matched at addNewFunction")
  }
}
/**
 * Picks the class that will host the function, records the function there and reports where
 * it ended up.
 *
 * @param funcName the class-unqualified name of the function
 * @param funcCode the body of the function
 * @param inlineToOuterClass when true the function always goes to the `OuterClass`
 * @return a [[NewFunctionSpec]] whose class name and instance are empty when the function was
 *         added to the `OuterClass`, and set to the hosting nested class otherwise
 */
private[this] def addNewFunctionInternal(
    funcName: String,
    funcCode: String,
    inlineToOuterClass: Boolean): NewFunctionSpec = {
  val (className, classInstance) = if (inlineToOuterClass) {
    outerClassName -> ""
  } else if (currClassSize() > GENERATED_CLASS_SIZE_THRESHOLD) {
    // The current class is over the size threshold: open a new nested class for this and
    // the following functions. `currClassSize` is declared with an empty parameter list, so
    // it is called with explicit parentheses (auto-application is deprecated in Scala 2.13).
    val className = freshName("NestedClass")
    val classInstance = freshName("nestedClassInstance")
    addClass(className, classInstance)
    className -> classInstance
  } else {
    currClass()
  }

  addNewFunctionToClass(funcName, funcCode, className)

  if (className == outerClassName) {
    NewFunctionSpec(funcName, None, None)
  } else {
    NewFunctionSpec(funcName, Some(className), Some(classInstance))
  }
}
// Records `funcCode` under `funcName` in the given class and grows that class' tracked size
// by the length of the code.
private[this] def addNewFunctionToClass(
    funcName: String,
    funcCode: String,
    className: String): Unit = {
  classSize(className) = classSize(className) + funcCode.length
  classFunctions(className)(funcName) = funcCode
}
/**
* Declares all function code. If the added functions are too many, split them into nested
* sub-classes to avoid hitting Java compiler constant pool limitation.
*/
def declareAddedFunctions(): String = {
  val outerClassFunctions = classFunctions(outerClassName).values

  // Nested, private sub-classes have no mutable state (though they do reference the outer
  // class' mutable state), so they are declared and instantiated inline in the OuterClass.
  val nestedClassInstances = classes.collect {
    case (className, classInstance) if className != outerClassName =>
      s"private $className $classInstance = new $className();"
  }

  val nestedClassBodies = classFunctions.collect {
    case (className, functions) if className != outerClassName =>
      s"""
         |private class $className {
         | ${functions.values.mkString("\n")}
         |}
         |""".stripMargin
  }

  (outerClassFunctions ++ nestedClassInstances ++ nestedClassBodies).mkString("\n")
}
/**
* Emits extra inner classes added with `addInnerClass`.
*/
def emitExtraCode(): String =
  // Each registered snippet is emitted verbatim, separated by a newline.
  extraClasses.mkString("\n")
/**
* Add extra source code to the outermost generated class.
* @param code verbatim source code of the inner class to be added.
*/
def addInnerClass(code: String): Unit = {
  // Snippets are kept in registration order and emitted later by `emitExtraCode()`.
  extraClasses += code
}
/**
 * The map from a variable name to its next ID.
 */
private val freshNameIds = new mutable.HashMap[String, Int]
// Pre-seed the counter for the initial input row name, so the first fresh name requested for
// it gets id 1 rather than 0.
freshNameIds += INPUT_ROW -> 1
/**
 * A prefix used to generate fresh names. When it is the empty string, names are not prefixed.
 */
var freshNamePrefix = ""
/**
 * The map from a placeholder to the corresponding comment.
 */
private val placeHolderToComments = new mutable.HashMap[String, String]
/**
* Returns a term name that is unique within this instance of a `CodegenContext`.
*/
def freshName(name: String): String = synchronized {
  // Qualify the requested name with the current prefix when one is set.
  val qualifiedName = if (freshNamePrefix != "") s"${freshNamePrefix}_$name" else name
  // The map stores the next id to hand out for each name; unseen names start at 0.
  val currentId = freshNameIds.getOrElse(qualifiedName, 0)
  freshNameIds.update(qualifiedName, currentId + 1)
  s"${qualifiedName}_$currentId"
}
/**
* Creates an `ExprValue` representing a local java variable of required data type.
*/
def freshVariable(name: String, dt: DataType): VariableValue = {
  val uniqueName = freshName(name)
  JavaCode.variable(uniqueName, dt)
}

/**
 * Creates an `ExprValue` for a local java variable of the given Java class, named with a
 * fresh name derived from `name`.
 */
def freshVariable(name: String, javaClass: Class[_]): VariableValue = {
  val uniqueName = freshName(name)
  JavaCode.variable(uniqueName, javaClass)
}
/**
* Generates code for equal expression in Java.
*/
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
  // Binary values are compared through `java.util.Arrays.equals`.
  case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
  // For floating point types two NaN values are treated as equal.
  case FloatType =>
    s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)"
  case DoubleType =>
    s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
  case primitive: DataType if isPrimitiveType(primitive) => s"$c1 == $c2"
  case _: AtomicType => s"$c1.equals($c2)"
  // Arrays and structs are equal when their generated comparison yields 0.
  case _: ArrayType | _: StructType => s"${genComp(dataType, c1, c2)} == 0"
  // User defined types are compared through their underlying SQL type.
  case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
  case NullType => "false"
  case _ =>
    throw new IllegalArgumentException(
      s"cannot generate equality code for un-comparable type: ${dataType.catalogString}")
}
/**
* Generates code for comparing two expressions.
*
* @param dataType data type of the expressions
* @param c1 name of the variable of expression 1's output
* @param c2 name of the variable of expression 2's output
*/
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
// Java booleans do not support the > or < operators, so the result is built from == and ?:.
case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
// Floating point types delegate to the NaN-safe comparison helpers in `Utils`.
case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)"
case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)"
// Computing `c1 - c2` may overflow, so the remaining primitives are compared explicitly.
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
// NullType values always compare as 0.
case NullType => "0"
// Arrays: a dedicated Java function is generated and registered through `addNewFunction`, and
// the returned code is a call to it. The function compares the arrays element by element over
// their common length: a null element is smaller than a non-null one, non-null elements are
// compared with the code generated recursively for the element type, and when all common
// elements are equal the shorter array is the smaller one.
case array: ArrayType =>
val elementType = array.elementType
val elementA = freshName("elementA")
val isNullA = freshName("isNullA")
val elementB = freshName("elementB")
val isNullB = freshName("isNullB")
val compareFunc = freshName("compareArray")
val minLength = freshName("minLength")
val jt = javaType(elementType)
val funcCode: String =
s"""
public int $compareFunc(ArrayData a, ArrayData b) {
// when comparing unsafe arrays, try equals first as it compares the binary directly
// which is very fast.
if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) {
return 0;
}
int lengthA = a.numElements();
int lengthB = b.numElements();
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
for (int i = 0; i < $minLength; i++) {
boolean $isNullA = a.isNullAt(i);
boolean $isNullB = b.isNullAt(i);
if ($isNullA && $isNullB) {
// Nothing
} else if ($isNullA) {
return -1;
} else if ($isNullB) {
return 1;
} else {
$jt $elementA = ${getValue("a", elementType, "i")};
$jt $elementB = ${getValue("b", elementType, "i")};
int comp = ${genComp(elementType, elementA, elementB)};
if (comp != 0) {
return comp;
}
}
}
if (lengthA < lengthB) {
return -1;
} else if (lengthA > lengthB) {
return 1;
}
return 0;
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
// Structs: the per-field comparison code comes from `GenerateOrdering.genComparisons` and is
// wrapped in a generated Java function that is registered through `addNewFunction`; the
// returned code is a call to that function.
case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
val funcCode: String =
s"""
public int $compareFunc(InternalRow a, InternalRow b) {
// when comparing unsafe rows, try equals first as it compares the binary directly
// which is very fast.
if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) {
return 0;
}
$comparisons
return 0;
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
// Any other atomic type is compared by calling `compare` on the first value.
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
// User defined types are compared through their underlying SQL type.
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
throw new IllegalArgumentException(
"cannot generate compare code for un-comparable type: " + dataType.catalogString)
}
/**
* Generates code for greater of two expressions.
*
* @param dataType data type of the expressions
* @param c1 name of the variable of expression 1's output
* @param c2 name of the variable of expression 2's output
*/
def genGreater(dataType: DataType, c1: String, c2: String): String = {
  val jt = javaType(dataType)
  // byte, short, int and long use Java's `>` directly; all other types go through `genComp`.
  val usesNativeGreater =
    jt == JAVA_BYTE || jt == JAVA_SHORT || jt == JAVA_INT || jt == JAVA_LONG
  if (usesNativeGreater) {
    s"$c1 > $c2"
  } else {
    s"(${genComp(dataType, c1, c2)}) > 0"
  }
}
/**
* Generates code for updating `partialResult` if `item` is smaller than it.
*
* @param dataType data type of the expressions
* @param partialResult `ExprCode` representing the partial result which has to be updated
* @param item `ExprCode` representing the new expression to evaluate for the result
*/
def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
  // `item` is smaller exactly when the partial result is greater than `item`.
  val partialIsGreater = genGreater(dataType, partialResult.value, item.value)
  s"""
     |if (!${item.isNull} && (${partialResult.isNull} ||
     | $partialIsGreater)) {
     | ${partialResult.isNull} = false;
     | ${partialResult.value} = ${item.value};
     |}
     |""".stripMargin
}
/**
* Generates code for updating `partialResult` if `item` is greater than it.
*
* @param dataType data type of the expressions
* @param partialResult `ExprCode` representing the partial result which has to be updated
* @param item `ExprCode` representing the new expression to evaluate for the result
*/
def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
  // The update happens when `item` is greater than the current partial result.
  val itemIsGreater = genGreater(dataType, item.value, partialResult.value)
  s"""
     |if (!${item.isNull} && (${partialResult.isNull} ||
     | $itemIsGreater)) {
     | ${partialResult.isNull} = false;
     | ${partialResult.value} = ${item.value};
     |}
     |""".stripMargin
}
/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
*
* @param nullable used to decide whether we should add null check or not.
* @param isNull the code to check if the input is null.
* @param execute the code that should only be executed when the input is not null.
*/
def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = {
  if (!nullable) {
    // The input can never be null: emit the code as is, on a fresh line.
    s"\n$execute"
  } else {
    // The input may be null: run the code only when the null check fails.
    s"\nif (!$isNull) {\n$execute\n}\n"
  }
}
/**
* Generates code to do null safe execution when accessing properties of complex
* ArrayData elements.
*
* @param nullElements used to decide whether the ArrayData might contain null or not.
* @param isNull a variable indicating whether the result will be evaluated to null or not.
* @param arrayData a variable name representing the ArrayData.
* @param execute the code that should be executed only if the ArrayData doesn't contain
* any null.
*/
def nullArrayElementsSaveExec(
    nullElements: Boolean,
    isNull: String,
    arrayData: String)(
    execute: String): String = {
  // The loop index name is reserved up front, whether or not the loop ends up being emitted.
  val idx = freshName("idx")
  if (!nullElements) {
    // No element can be null, so no guard is needed.
    execute
  } else {
    // Scan the array until a null element is found (which sets `isNull`), then run the
    // code only if none was found.
    s"""
       |for (int $idx = 0; !$isNull && $idx < $arrayData.numElements(); $idx++) {
       | $isNull |= $arrayData.isNullAt($idx);
       |}
       |if (!$isNull) {
       | $execute
       |}
       |""".stripMargin
  }
}
/**
* Splits the generated code of expressions into multiple functions, because a function has a
* 64kb code size limit in the JVM. If the class to which the function would be inlined would
* grow beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
* instead, because classes have a constant pool limit of 65,536 named values.
*
* Note that different from `splitExpressions`, we will extract the current inputs of this
* context and pass them to the generated functions. The input is `INPUT_ROW` for normal codegen
* path, and `currentVars` for whole stage codegen path. Whole stage codegen path is not
* supported yet.
*
* @param expressions the codes to evaluate expressions.
* @param funcName the split function name base.
* @param extraArguments the list of (type, name) of the arguments of the split function,
* except for the current inputs like `ctx.INPUT_ROW`.
* @param returnType the return type of the split function.
* @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
* @param foldFunctions folds the split function calls.
*/
def splitExpressionsWithCurrentInputs(
    expressions: Seq[String],
    funcName: String = "apply",
    extraArguments: Seq[(String, String)] = Nil,
    returnType: String = "void",
    makeSplitFunction: String => String = identity,
    foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
  // TODO: support whole stage codegen
  // Splitting is only done when an input row is set and `currentVars` is not.
  val canSplit = INPUT_ROW != null && currentVars == null
  if (canSplit) {
    // The current input row is always passed as the first argument of the split functions.
    val rowArgument = ("InternalRow", INPUT_ROW)
    splitExpressions(
      expressions = expressions,
      funcName = funcName,
      arguments = rowArgument +: extraArguments,
      returnType = returnType,
      makeSplitFunction = makeSplitFunction,
      foldFunctions = foldFunctions)
  } else {
    // Otherwise the code is simply concatenated, one expression per line.
    expressions.mkString("\n")
  }
}
/**
* Splits the generated code of expressions into multiple functions, because a function has a
* 64kb code size limit in the JVM. If the class to which the function would be inlined would
* grow beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
* instead, because classes have a constant pool limit of 65,536 named values.
*
* @param expressions the codes to evaluate expressions.
* @param funcName the split function name base.
* @param arguments the list of (type, name) of the arguments of the split function.
* @param returnType the return type of the split function.
* @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
* @param foldFunctions folds the split function calls.
*/
def splitExpressions(
    expressions: Seq[String],
    funcName: String,
    arguments: Seq[(String, String)],
    returnType: String = "void",
    makeSplitFunction: String => String = identity,
    foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
  buildCodeBlocks(expressions) match {
    case Seq(onlyBlock) =>
      // A single block needs no split function: it is executed inline.
      onlyBlock

    case blocks =>
      if (Utils.isTesting) {
        // Handing a global variable to a split method is dangerous: any mutation of it is
        // ignored, which may lead to unexpected behavior.
        for ((_, argName) <- arguments) {
          assert(!mutableStateNames.contains(argName),
            s"split function argument $argName cannot be a global variable.")
        }
      }

      val baseName = freshName(funcName)
      val paramList = arguments.map { case (tpe, argName) => s"$tpe $argName" }.mkString(", ")

      // Register one function per block, named `<baseName>_<index>`.
      val splitFunctions = blocks.zipWithIndex.map { case (body, idx) =>
        val splitName = s"${baseName}_$idx"
        val source =
          s"""
             |private $returnType $splitName($paramList) {
             | ${makeSplitFunction(body)}
             |}
             |""".stripMargin
        addNewFunctionInternal(splitName, source, inlineToOuterClass = false)
      }

      // Functions without an inner class name are called directly; the remaining ones are
      // invoked through their inner class instances.
      val (inOuterClass, inInnerClasses) = splitFunctions.partition(_.innerClassName.isEmpty)
      val callArgs = arguments.map(_._2).mkString(", ")
      val outerCalls = inOuterClass.map(f => s"${f.functionName}($callArgs)")
      val innerCalls = generateInnerClassesFunctionCalls(
        inInnerClasses,
        baseName,
        arguments,
        returnType,
        makeSplitFunction,
        foldFunctions)

      foldFunctions(outerCalls ++ innerCalls)
  }
}
/**
* Groups the generated code of expressions into multiple code blocks (one String each),
* starting a new block once the accumulated source length exceeds a threshold.
*
* @param expressions the codes to evaluate expressions.
*/
private def buildCodeBlocks(expressions: Seq[String]): Seq[String] = {
  val finished = new ArrayBuffer[String]()
  val current = new StringBuilder()
  var currentLength = 0
  val threshold = SQLConf.get.methodSplitThreshold
  expressions.foreach { snippet =>
    // The amount of bytecode that will be generated is unknown, so the length of the source
    // code is used as the metric. A method should not go beyond 8K, otherwise it will not be
    // JITted; it should also not be too small, or there will be many function calls (for
    // wide tables), see the results in BenchmarkWideTable.
    if (currentLength > threshold) {
      finished += current.toString()
      current.clear()
      currentLength = 0
    }
    current.append(snippet)
    // Only the code left after stripping extra new lines and comments counts.
    currentLength += CodeFormatter.stripExtraNewLinesAndComments(snippet).length
  }
  // Whatever is left in the builder (possibly nothing) becomes the last block.
  finished += current.toString()
  finished
}
/**
* Here we handle all the methods which have been added to the inner classes and
* not to the outer class.
* Since they can be many, their direct invocation in the outer class adds many entries
* to the outer class' constant pool. This can cause the constant pool to exceed the JVM limit.
* Moreover, this can also cause the outer class method where all the invocations are
* performed to grow beyond the 64k limit.
* To avoid these problems, we group them and we call only the grouping methods in the
* outer class.
*
* @param functions a [[Seq]] of [[NewFunctionSpec]] defined in the inner classes
* @param funcName the split function name base.
* @param arguments the list of (type, name) of the arguments of the split function.
* @param returnType the return type of the split function.
* @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
* @param foldFunctions folds the split function calls.
* @return an [[Iterable]] containing the methods' invocations
*/
private def generateInnerClassesFunctionCalls(
    functions: Seq[NewFunctionSpec],
    funcName: String,
    arguments: Seq[(String, String)],
    returnType: String,
    makeSplitFunction: String => String,
    foldFunctions: Seq[String] => String): Iterable[String] = {
  // Group the function names by (inner class name, inner class instance), keeping the
  // classes in first-seen order. Names are prepended for performance reasons, so each
  // group is stored in reverse order.
  val namesByClass = mutable.LinkedHashMap.empty[(String, String), Seq[String]]
  for (spec <- functions) {
    val classKey = (spec.innerClassName.get, spec.innerClassInstance.get)
    val existing = namesByClass.getOrElse(classKey, Seq.empty[String])
    namesByClass.put(classKey, spec.functionName +: existing)
  }

  val paramList = arguments.map { case (tpe, argName) => s"$tpe $argName" }.mkString(", ")
  val callArgs = arguments.map(_._2).mkString(", ")

  namesByClass.flatMap { case ((className, instanceName), reversedNames) =>
    // Restore the order in which the functions were registered.
    val names = reversedNames.reverse
    if (names.size <= MERGE_SPLIT_METHODS_THRESHOLD) {
      // Few enough functions: call each of them through the inner class instance.
      names.map(fn => s"$instanceName.$fn($callArgs)")
    } else {
      // Too many functions: add one more function to the inner class that invokes all of
      // them, and call only that one from the outside. For example,
      //   private class NestedClass {
      //     private void apply_862(InternalRow i) { ... }
      //     private void apply_863(InternalRow i) { ... }
      //     ...
      //     private void apply(InternalRow i) {
      //       apply_862(i);
      //       apply_863(i);
      //       ...
      //     }
      //   }
      val calls = names.map(fn => s"$fn($callArgs)")
      val body = foldFunctions(calls)
      val source =
        s"""
           |private $returnType $funcName($paramList) {
           | ${makeSplitFunction(body)}
           |}
           |""".stripMargin
      addNewFunctionToClass(funcName, source, className)
      Seq(s"$instanceName.$funcName($callArgs)")
    }
  }
}
/**
* Returns the code for subexpression elimination after splitting it if necessary.
*/
def subexprFunctionsCode: String = {