From 1a9b34b92e9d71732bed698fb17adaf9bdd362fa Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sat, 10 Nov 2018 14:31:52 +0800 Subject: [PATCH] Reduce memory copy when writing decimal --- .../sql/catalyst/expressions/UnsafeRow.java | 10 +++++--- .../expressions/codegen/UnsafeRowWriter.java | 11 +++++--- .../codegen/UnsafeRowWriterSuite.scala | 25 +++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a76e6ef8c91c1..ee7f81605f479 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -281,9 +281,6 @@ public void setDecimal(int ordinal, Decimal value, int precision) { // fixed length long cursor = getLong(ordinal) >>> 32; assert cursor > 0 : "invalid cursor " + cursor; - // zero-out the bytes - Platform.putLong(baseObject, baseOffset + cursor, 0L); - Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); if (value == null) { setNullAt(ordinal); @@ -295,6 +292,13 @@ public void setDecimal(int ordinal, Decimal value, int precision) { byte[] bytes = integer.toByteArray(); assert(bytes.length <= 16); + // always zero-out the 8-byte to 16-byte buffer + Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); + if (bytes.length < 8) { + // need zero-out the 8-byte buffer when bytes.length less than 8-byte + Platform.putLong(baseObject, baseOffset + cursor, 0L); + } + // Write the bytes to the variable length portion. Platform.copyMemory( bytes, Platform.BYTE_ARRAY_OFFSET, baseObject, baseOffset + cursor, bytes.length); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 3960d6d520476..b82883bf24fb6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -185,10 +185,6 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // grow the global buffer before writing data. holder.grow(16); - // always zero-out the 16-byte buffer - Platform.putLong(getBuffer(), cursor(), 0L); - Platform.putLong(getBuffer(), cursor() + 8, 0L); - // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { @@ -200,6 +196,13 @@ public void write(int ordinal, Decimal input, int precision, int scale) { final int numBytes = bytes.length; assert numBytes <= 16; + // always zero-out the 8-byte to 16-byte buffer + Platform.putLong(getBuffer(), cursor() + 8, 0L); + if (numBytes < 8) { + // need zero-out the 8-byte buffer when numBytes less than 8-byte + Platform.putLong(getBuffer(), cursor(), 0L); + } + // Write the bytes to the variable length portion. Platform.copyMemory( bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala index fb651b76fc16d..afa57c16b6043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -50,4 +50,29 @@ class UnsafeRowWriterSuite extends SparkFunSuite { assert(res1 == res2) } + test("SPARK-26001: write a decimal with 16 bytes and then one with less than 8") { + val decimal1 = Decimal(3.431) + decimal1.changePrecision(38, 10) + checkDecimalSizeInBytes(decimal1, 5) + + val decimal2 = Decimal(123456789.1232456789) + decimal2.changePrecision(38, 18) + checkDecimalSizeInBytes(decimal2, 11) + // On an UnsafeRowWriter we write decimal2 first and then decimal1 + val unsafeRowWriter1 = new UnsafeRowWriter(1) + unsafeRowWriter1.resetRowWriter() + unsafeRowWriter1.write(0, decimal2, decimal2.precision, decimal2.scale) + unsafeRowWriter1.reset() + unsafeRowWriter1.write(0, decimal1, decimal1.precision, decimal1.scale) + val res1 = unsafeRowWriter1.getRow + + // On a second UnsafeRowWriter we write directly decimal1 + val unsafeRowWriter2 = new UnsafeRowWriter(1) + unsafeRowWriter2.resetRowWriter() + unsafeRowWriter2.write(0, decimal1, decimal1.precision, decimal1.scale) + val res2 = unsafeRowWriter2.getRow + // The two rows should be the equal + assert(res1 == res2) + } + }