Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20253][SQL] Remove unnecessary nullchecks of a return value from Spark runtime routines in generated Java code #17569

Closed
wants to merge 9 commits into from

Conversation

kiszk
Copy link
Member

@kiszk kiszk commented Apr 7, 2017

What changes were proposed in this pull request?

This PR elminates unnecessary nullchecks of a return value from known Spark runtime routines. We know whether a given Spark runtime routine returns null or not (e.g. ArrayData.toDoubleArray() never returns null). Thus, we can eliminate a null check for the return value from the Spark runtime routine.

When we run the following example program, now we get the Java code "Without this PR". In this code, since we know ArrayData.toDoubleArray() never returns ``null```, we can eliminate null checks at lines 90-92, and 97.

val ds = sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache
ds.count
ds.map(e => e).show

Without this PR

/* 050 */   protected void processNext() throws java.io.IOException {
/* 051 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 052 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 053 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 054 */       ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0));
/* 055 */
/* 056 */       ArrayData deserializetoobject_value1 = null;
/* 057 */
/* 058 */       if (!inputadapter_isNull) {
/* 059 */         int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 060 */
/* 061 */         Double[] deserializetoobject_convertedArray = null;
/* 062 */         deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength];
/* 063 */
/* 064 */         int deserializetoobject_loopIndex = 0;
/* 065 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 066 */           MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex));
/* 067 */           MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 068 */
/* 069 */           if (MapObjects_loopIsNull2) {
/* 070 */             throw new RuntimeException(((java.lang.String) references[0]));
/* 071 */           }
/* 072 */           if (false) {
/* 073 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 074 */           } else {
/* 075 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2;
/* 076 */           }
/* 077 */
/* 078 */           deserializetoobject_loopIndex += 1;
/* 079 */         }
/* 080 */
/* 081 */         deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/
/* 082 */       }
/* 083 */       boolean deserializetoobject_isNull = true;
/* 084 */       double[] deserializetoobject_value = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull = false;
/* 087 */         if (!deserializetoobject_isNull) {
/* 088 */           Object deserializetoobject_funcResult = null;
/* 089 */           deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray();
/* 090 */           if (deserializetoobject_funcResult == null) {
/* 091 */             deserializetoobject_isNull = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value = (double[]) deserializetoobject_funcResult;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull = deserializetoobject_value == null;
/* 098 */       }
/* 099 */
/* 100 */       boolean mapelements_isNull = true;
/* 101 */       double[] mapelements_value = null;
/* 102 */       if (!false) {
/* 103 */         mapelements_resultIsNull = false;
/* 104 */
/* 105 */         if (!mapelements_resultIsNull) {
/* 106 */           mapelements_resultIsNull = deserializetoobject_isNull;
/* 107 */           mapelements_argValue = deserializetoobject_value;
/* 108 */         }
/* 109 */
/* 110 */         mapelements_isNull = mapelements_resultIsNull;
/* 111 */         if (!mapelements_isNull) {
/* 112 */           Object mapelements_funcResult = null;
/* 113 */           mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 114 */           if (mapelements_funcResult == null) {
/* 115 */             mapelements_isNull = true;
/* 116 */           } else {
/* 117 */             mapelements_value = (double[]) mapelements_funcResult;
/* 118 */           }
/* 119 */
/* 120 */         }
/* 121 */         mapelements_isNull = mapelements_value == null;
/* 122 */       }
/* 123 */
/* 124 */       serializefromobject_resultIsNull = false;
/* 125 */
/* 126 */       if (!serializefromobject_resultIsNull) {
/* 127 */         serializefromobject_resultIsNull = mapelements_isNull;
/* 128 */         serializefromobject_argValue = mapelements_value;
/* 129 */       }
/* 130 */
/* 131 */       boolean serializefromobject_isNull = serializefromobject_resultIsNull;
/* 132 */       final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue);
/* 133 */       serializefromobject_isNull = serializefromobject_value == null;
/* 134 */       serializefromobject_holder.reset();
/* 135 */
/* 136 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 137 */
/* 138 */       if (serializefromobject_isNull) {
/* 139 */         serializefromobject_rowWriter.setNullAt(0);
/* 140 */       } else {
/* 141 */         // Remember the current cursor so that we can calculate how many bytes are
/* 142 */         // written later.
/* 143 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 144 */
/* 145 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 146 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 147 */           // grow the global buffer before writing data.
/* 148 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 149 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 150 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 151 */
/* 152 */         } else {
/* 153 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 154 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8);
/* 155 */
/* 156 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 157 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 158 */               serializefromobject_arrayWriter.setNullDouble(serializefromobject_index);
/* 159 */             } else {
/* 160 */               final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index);
/* 161 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 162 */             }
/* 163 */           }
/* 164 */         }
/* 165 */
/* 166 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 167 */       }
/* 168 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 169 */       append(serializefromobject_result);
/* 170 */       if (shouldStop()) return;
/* 171 */     }
/* 172 */   }

With this PR (removed most of lines 90-97 in the above code)

/* 050 */   protected void processNext() throws java.io.IOException {
/* 051 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 052 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 053 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 054 */       ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0));
/* 055 */
/* 056 */       ArrayData deserializetoobject_value1 = null;
/* 057 */
/* 058 */       if (!inputadapter_isNull) {
/* 059 */         int deserializetoobject_dataLength = inputadapter_value.numElements();
/* 060 */
/* 061 */         Double[] deserializetoobject_convertedArray = null;
/* 062 */         deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength];
/* 063 */
/* 064 */         int deserializetoobject_loopIndex = 0;
/* 065 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 066 */           MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex));
/* 067 */           MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex);
/* 068 */
/* 069 */           if (MapObjects_loopIsNull2) {
/* 070 */             throw new RuntimeException(((java.lang.String) references[0]));
/* 071 */           }
/* 072 */           if (false) {
/* 073 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null;
/* 074 */           } else {
/* 075 */             deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2;
/* 076 */           }
/* 077 */
/* 078 */           deserializetoobject_loopIndex += 1;
/* 079 */         }
/* 080 */
/* 081 */         deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/
/* 082 */       }
/* 083 */       boolean deserializetoobject_isNull = true;
/* 084 */       double[] deserializetoobject_value = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull = false;
/* 087 */         if (!deserializetoobject_isNull) {
/* 088 */           Object deserializetoobject_funcResult = null;
/* 089 */           deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray();
/* 090 */           deserializetoobject_value = (double[]) deserializetoobject_funcResult;
/* 091 */
/* 092 */         }
/* 093 */
/* 094 */       }
/* 095 */
/* 096 */       boolean mapelements_isNull = true;
/* 097 */       double[] mapelements_value = null;
/* 098 */       if (!false) {
/* 099 */         mapelements_resultIsNull = false;
/* 100 */
/* 101 */         if (!mapelements_resultIsNull) {
/* 102 */           mapelements_resultIsNull = deserializetoobject_isNull;
/* 103 */           mapelements_argValue = deserializetoobject_value;
/* 104 */         }
/* 105 */
/* 106 */         mapelements_isNull = mapelements_resultIsNull;
/* 107 */         if (!mapelements_isNull) {
/* 108 */           Object mapelements_funcResult = null;
/* 109 */           mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue);
/* 110 */           if (mapelements_funcResult == null) {
/* 111 */             mapelements_isNull = true;
/* 112 */           } else {
/* 113 */             mapelements_value = (double[]) mapelements_funcResult;
/* 114 */           }
/* 115 */
/* 116 */         }
/* 117 */         mapelements_isNull = mapelements_value == null;
/* 118 */       }
/* 119 */
/* 120 */       serializefromobject_resultIsNull = false;
/* 121 */
/* 122 */       if (!serializefromobject_resultIsNull) {
/* 123 */         serializefromobject_resultIsNull = mapelements_isNull;
/* 124 */         serializefromobject_argValue = mapelements_value;
/* 125 */       }
/* 126 */
/* 127 */       boolean serializefromobject_isNull = serializefromobject_resultIsNull;
/* 128 */       final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue);
/* 129 */       serializefromobject_isNull = serializefromobject_value == null;
/* 130 */       serializefromobject_holder.reset();
/* 131 */
/* 132 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 133 */
/* 134 */       if (serializefromobject_isNull) {
/* 135 */         serializefromobject_rowWriter.setNullAt(0);
/* 136 */       } else {
/* 137 */         // Remember the current cursor so that we can calculate how many bytes are
/* 138 */         // written later.
/* 139 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 140 */
/* 141 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 142 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 143 */           // grow the global buffer before writing data.
/* 144 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 145 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 146 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 147 */
/* 148 */         } else {
/* 149 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 150 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8);
/* 151 */
/* 152 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 153 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 154 */               serializefromobject_arrayWriter.setNullDouble(serializefromobject_index);
/* 155 */             } else {
/* 156 */               final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index);
/* 157 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 158 */             }
/* 159 */           }
/* 160 */         }
/* 161 */
/* 162 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 163 */       }
/* 164 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 165 */       append(serializefromobject_result);
/* 166 */       if (shouldStop()) return;
/* 167 */     }
/* 168 */   }

How was this patch tested?

Add test suites to DatasetPrimitiveSuite

@kiszk kiszk changed the title [SPARK-20254][SQL] Remove unnecessary nullchecks of a return value from Spark runtime routines in generated Java code [SPARK-20253][SQL] Remove unnecessary nullchecks of a return value from Spark runtime routines in generated Java code Apr 7, 2017
@SparkQA
Copy link

SparkQA commented Apr 7, 2017

Test build #75611 has finished for PR 17569 at commit 4482e1c.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Apr 7, 2017

Test build #75614 has finished for PR 17569 at commit ae5e232.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@kiszk
Copy link
Member Author

kiszk commented Apr 8, 2017

@cloud-fan could you please review this?

@@ -356,7 +361,8 @@ object ScalaReflection extends ScalaReflection {
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The deserialize is totally implemented by users, can we guarantee not return null?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It is UDT. I have checked deserialized only in Spark runtime.

@@ -365,7 +371,8 @@ object ScalaReflection extends ScalaReflection {
udt.getClass,
Nil,
dataType = ObjectType(udt.getClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

@@ -577,7 +584,7 @@ object ScalaReflection extends ScalaReflection {
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt, inputObject :: Nil)
Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

@@ -586,7 +593,7 @@ object ScalaReflection extends ScalaReflection {
udt.getClass,
Nil,
dataType = ObjectType(udt.getClass))
Invoke(obj, "serialize", udt, inputObject :: Nil)
Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sam here

@@ -608,7 +616,7 @@ case class MapObjects private(
$convertedArray = $arrayConstructor;
""",
genValue => s"$convertedArray[$loopIndex] = $genValue;",
s"new ${classOf[GenericArrayData].getName}($convertedArray);"
s"new ${classOf[GenericArrayData].getName}($convertedArray); /*###*/"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry. It was my debug code...

if ($funcResult == null) {
${ev.isNull} = true;
} else {
if (!returnNullable) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we have postNullCheck, can we always go to this branch?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, yes we can.

"""
}

// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
val postNullCheck = if (ctx.defaultValue(dataType) == "null" && returnNullable) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, can we embed postNullCheck to evaluate?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, done

@@ -96,6 +96,16 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(dsBoolean.map(e => !e), false, true)
}

test("mapPrimitiveArray") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do these tests fail before this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I have just added to confirm this check works well.


case StringType =>
Invoke(input, "toString", ObjectType(classOf[String]))
Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false)
Copy link
Contributor

@cloud-fan cloud-fan Apr 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we check how many places we set returnNullable to true? If it's only a few, we can change the default value of returnNullable to false.

Copy link
Member Author

@kiszk kiszk Apr 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is statistics for 59 call sites of Invoke().

18: dataType is primitive type
21: returnNullable is true (no specification at call site, as default)
19: returnNullable is false
1: set a variable to `returnNullable

What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok let's keep the default value unchanged

@SparkQA
Copy link

SparkQA commented Apr 8, 2017

Test build #75619 has finished for PR 17569 at commit fc6caac.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Apr 8, 2017

Test build #75621 has finished for PR 17569 at commit a39803a.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@viirya
Copy link
Member

viirya commented Apr 8, 2017

Seems there are places (i.e., RowEncoder) calling isNullAt which gives returnNullable as true (default value).

@@ -225,25 +225,26 @@ case class Invoke(
getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
} else {
val funcResult = ctx.freshName("funcResult")
// If the function can return null, we do an extra check to make sure our null bit is still
// set correctly.
val postNullCheck = if (!returnNullable) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename postNullCheck. It is actually not only null check but also assigning the function result.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, done

@SparkQA
Copy link

SparkQA commented Apr 8, 2017

Test build #75625 has finished for PR 17569 at commit 3080ac2.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Apr 8, 2017

Test build #75627 has finished for PR 17569 at commit 510fb53.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -225,25 +225,26 @@ case class Invoke(
getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
} else {
val funcResult = ctx.freshName("funcResult")
// If the function can return null, we do an extra check to make sure our null bit is still
// set correctly.
val postNullCheckAndAssign = if (!returnNullable) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about just assignResult?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

@cloud-fan
Copy link
Contributor

LGTM

1 similar comment
@viirya
Copy link
Member

viirya commented Apr 9, 2017

LGTM

@SparkQA
Copy link

SparkQA commented Apr 9, 2017

Test build #75628 has finished for PR 17569 at commit 10cf4be.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@asfgit asfgit closed this in 7a63f5e Apr 10, 2017
@cloud-fan
Copy link
Contributor

thanks, merging to master!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants