-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-16714][SQL] array
should create a decimal array from decimals with different precisions and scales
#14353
Conversation
Test build #62848 has finished for PR 14353 at commit
|
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") | ||
override def checkInputDataTypes(): TypeCheckResult = { | ||
if (children.map(_.dataType).forall(_.isInstanceOf[DecimalType])) { | ||
TypeCheckResult.TypeCheckSuccess |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we cannot just make the check pass. We need to need to actually cast those element to the same prevision and scale.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, if we access a single element, its data type actually may not be the one shown as the array's datatype.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for review, @yhuai .
I see. I'll check that more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, @yhuai . I checked the following.
scala> sql("select a[0], a[1] from (select array(0.001, 0.02) a) T")
res4: org.apache.spark.sql.DataFrame = [a[0]: decimal(3,3), a[1]: decimal(3,3)]
scala> sql("select a[0], a[1] from (select array(0.001, 0.02) a) T").show()
+-----+-----+
| a[0]| a[1]|
+-----+-----+
|0.001|0.020|
+-----+-----+
scala> sql("select a[0], a[1] from (select array(0.001, 0.02) a) T").explain(true)
== Parsed Logical Plan ==
'Project [unresolvedalias('a[0], None), unresolvedalias('a[1], None)]
+- 'SubqueryAlias T
+- 'Project ['array(0.001, 0.02) AS a#54]
+- OneRowRelation$
== Analyzed Logical Plan ==
a[0]: decimal(3,3), a[1]: decimal(3,3)
Project [a#54[0] AS a[0]#61, a#54[1] AS a[1]#62]
+- SubqueryAlias T
+- Project [array(0.001, 0.02) AS a#54]
+- OneRowRelation$
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scala> sql("create table d1(a DECIMAL(3,2))")
scala> sql("create table d2(a DECIMAL(2,1))")
scala> sql("insert into d1 values(1.0)")
scala> sql("insert into d2 values(1.0)")
scala> sql("select * from d1, d2").show()
+----+---+
| a| a|
+----+---+
|1.00|1.0|
+----+---+
scala> sql("select array(d1.a,d2.a),array(d2.a,d1.a),* from d1, d2")
res5: org.apache.spark.sql.DataFrame = [array(a, a): array<decimal(3,2)>, array(a, a): array<decimal(3,2)> ... 2 more fields]
scala> sql("select array(d1.a,d2.a),array(d2.a,d1.a),* from d1, d2").show()
+------------+------------+----+---+
| array(a, a)| array(a, a)| a| a|
+------------+------------+----+---+
|[1.00, 1.00]|[1.00, 1.00]|1.00|1.0|
+------------+------------+----+---+
scala> sql("select array(d1.a,d2.a)[0],array(d2.a,d1.a)[0],* from d1, d2").show()
+--------------+--------------+----+---+
|array(a, a)[0]|array(a, a)[0]| a| a|
+--------------+--------------+----+---+
| 1.00| 1.00|1.00|1.0|
+--------------+--------------+----+---+
scala> sql("select array(d1.a,d2.a)[1],array(d2.a,d1.a)[1],* from d1, d2").show()
+--------------+--------------+----+---+
|array(a, a)[1]|array(a, a)[1]| a| a|
+--------------+--------------+----+---+
| 1.00| 1.00|1.00|1.0|
+--------------+--------------+----+---+
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And Finally, the following is the codegen result. Please see the line 29.
scala> sql("explain codegen select array(0.001, 0.02)[1]").collect().foreach(println)
[Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 ==
*Project [0.02 AS array(0.001, 0.02)[1]#75]
+- Scan OneRowRelation[]
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator inputadapter_input;
/* 008 */ private UnsafeRow project_result;
/* 009 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
/* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
/* 011 */
/* 012 */ public GeneratedIterator(Object[] references) {
/* 013 */ this.references = references;
/* 014 */ }
/* 015 */
/* 016 */ public void init(int index, scala.collection.Iterator inputs[]) {
/* 017 */ partitionIndex = index;
/* 018 */ inputadapter_input = inputs[0];
/* 019 */ project_result = new UnsafeRow(1);
/* 020 */ this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 0);
/* 021 */ this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1);
/* 022 */ }
/* 023 */
/* 024 */ protected void processNext() throws java.io.IOException {
/* 025 */ while (inputadapter_input.hasNext()) {
/* 026 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 027 */ Object project_obj = ((Expression) references[0]).eval(null);
/* 028 */ Decimal project_value = (Decimal) project_obj;
/* 029 */ project_rowWriter.write(0, project_value, 3, 3);
/* 030 */ append(project_result);
/* 031 */ if (shouldStop()) return;
/* 032 */ }
/* 033 */ }
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In short, those are recognized correctly in the Analyzed Logical Plan. As a result, the codegen correctly writes it with the unified precision and scale.
== Analyzed Logical Plan ==
a[0]: decimal(3,3), a[1]: decimal(3,3)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there anything to check more?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, @yhuai .
Could you give me some advice?
Hi, @rxin . |
…ng different inferred precessions and scales
Rebased. |
Test build #62954 has finished for PR 14353 at commit
|
var elementType: DataType = children.headOption.map(_.dataType).getOrElse(NullType) | ||
if (elementType.isInstanceOf[DecimalType]) { | ||
children.foreach { child => | ||
if (elementType.asInstanceOf[DecimalType].isTighterThan(child.dataType)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think this suffers from the same issue as the map pr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, @rxin .
Yep. I've read you comment about the lose.
I'll check that and revise.
@dongjoon-hyun I created a patch here: #14389 |
Close this for the better PR #14439 |
What changes were proposed in this pull request?
In Spark 2.0, we will parse float literals as decimals. However, it introduces a side-effect, which is described below.
Before
After
How was this patch tested?
Pass the Jenkins tests with a new test case.