diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 87e5a89c19658..0fb33dd5a15a0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -24,7 +24,7 @@
 import java.util.HashSet;
 import java.util.Set;
 
-import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
@@ -235,6 +235,41 @@ public Object get(int ordinal) {
     throw new UnsupportedOperationException();
   }
 
+  @Override
+  public Object get(int ordinal, DataType dataType) {
+    if (dataType instanceof NullType) {
+      return null;
+    } else if (dataType instanceof BooleanType) {
+      return getBoolean(ordinal);
+    } else if (dataType instanceof ByteType) {
+      return getByte(ordinal);
+    } else if (dataType instanceof ShortType) {
+      return getShort(ordinal);
+    } else if (dataType instanceof IntegerType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof LongType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof FloatType) {
+      return getFloat(ordinal);
+    } else if (dataType instanceof DoubleType) {
+      return getDouble(ordinal);
+    } else if (dataType instanceof DecimalType) {
+      return getDecimal(ordinal);
+    } else if (dataType instanceof DateType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof TimestampType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof BinaryType) {
+      return getBinary(ordinal);
+    } else if (dataType instanceof StringType) {
+      return getUTF8String(ordinal);
+    } else if (dataType instanceof StructType) {
+      return getStruct(ordinal, ((StructType) dataType).size());
+    } else {
+      throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
+    }
+  }
+
   @Override
   public boolean isNullAt(int ordinal) {
     assertIndexIsValid(ordinal);
@@ -436,4 +471,19 @@ public String toString() {
   public boolean anyNull() {
     return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
   }
+
+  /**
+   * Writes the content of this row into a memory address, identified by an object and an offset.
+   * The target memory address must already be allocated, and have enough space to hold all the
+   * bytes in this row.
+   */
+  public void writeToMemory(Object target, long targetOffset) {
+    PlatformDependent.copyMemory(
+      baseObject,
+      baseOffset,
+      target,
+      targetOffset,
+      sizeInBytes
+    );
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 385d9671386dc..ad3977281d1a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -30,11 +30,11 @@ abstract class InternalRow extends Serializable {
 
   def numFields: Int
 
-  def get(ordinal: Int): Any
+  def get(ordinal: Int): Any = get(ordinal, null)
 
   def genericGet(ordinal: Int): Any = get(ordinal, null)
 
-  def get(ordinal: Int, dataType: DataType): Any = get(ordinal)
+  def get(ordinal: Int, dataType: DataType): Any
 
   def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index cc89d74146b34..27d6ff587ab71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -198,7 +198,7 @@ class JoinedRow extends InternalRow {
     if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
   }
 
-  override def get(i: Int): Any =
-    if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
+  override def get(i: Int, dataType: DataType): Any =
+    if (i < row1.numFields) row1.get(i, dataType) else row2.get(i - row1.numFields, dataType)
 
   override def isNullAt(i: Int): Boolean =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 5953a093dc684..b877ce47c083f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -219,7 +219,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
     values(i).isNull = true
   }
 
-  override def get(i: Int): Any = values(i).boxed
+  override def get(i: Int, dataType: DataType): Any = values(i).boxed
 
   override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
     values(ordinal).boxed.asInstanceOf[InternalRow]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index a361b216eb472..35920147105ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -183,7 +183,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
     public void setNullAt(int i) { nullBits[i] = true; }
     public boolean isNullAt(int i) { return nullBits[i]; }
 
-    public Object get(int i) {
+    public Object get(int i, ${classOf[DataType].getName} dataType) {
       if (isNullAt(i)) return null;
       switch (i) {
         $getCases
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index daeabe8e90f1d..b7c4ece4a16fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -99,7 +99,7 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) extends Internal
 
   override def numFields: Int = values.length
 
-  override def get(i: Int): Any = values(i)
+  override def get(i: Int, dataType: DataType): Any = values(i)
 
   override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
     values(ordinal).asInstanceOf[InternalRow]
@@ -130,7 +130,7 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow {
 
   override def numFields: Int = values.length
 
-  override def get(i: Int): Any = values(i)
+  override def get(i: Int, dataType: DataType): Any = values(i)
 
   override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
     values(ordinal).asInstanceOf[InternalRow]
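Note on the new accessor: making get(ordinal, dataType) the abstract method moves type dispatch into the row implementations, so callers supply a field's type from the schema instead of relying on boxed per-row storage. A minimal caller-side sketch of the pattern (the dumpRow helper below is illustrative only, not part of this patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.StructType

    // Pair each ordinal with its schema type so that get(ordinal, dataType)
    // can route to the specialized getter (getInt, getUTF8String, ...).
    def dumpRow(row: InternalRow, schema: StructType): Seq[Any] = {
      schema.fields.toSeq.zipWithIndex.map { case (field, i) =>
        if (row.isNullAt(i)) null else row.get(i, field.dataType)
      }
    }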
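The new writeToMemory method exposes the row's backing bytes as a single bulk copy. A hedged sketch of copying an UnsafeRow into a fresh byte array (this assumes the pre-existing getSizeInBytes accessor and PlatformDependent.BYTE_ARRAY_OFFSET, neither of which is introduced by this patch):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.unsafe.PlatformDependent

    // Size the destination from the row, then copy the fixed-length region
    // and trailing variable-length data in one memory copy.
    def toBytes(row: UnsafeRow): Array[Byte] = {
      val buffer = new Array[Byte](row.getSizeInBytes)
      row.writeToMemory(buffer, PlatformDependent.BYTE_ARRAY_OFFSET)
      buffer
    }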