@@ -41,7 +41,7 @@ public interface SpecializedGetters {
 
   double getDouble(int ordinal);
 
-  Decimal getDecimal(int ordinal);
+  Decimal getDecimal(int ordinal, int precision, int scale);
 
   UTF8String getUTF8String(int ordinal);
 
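The extra parameters are needed because the row does not record how a decimal was encoded: values whose precision is at most Decimal.MAX_LONG_DIGITS() (18) are stored as an unscaled long, wider ones as bytes (see the UnsafeRow changes below). A minimal caller-side sketch, not part of this patch, assuming SpecializedGetters lives in org.apache.spark.sql.catalyst.expressions like the other Java classes in this diff; ReadDecimalSketch is a hypothetical helper:

import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;

// Sketch: generic readers must thread the declared DecimalType through to the getter,
// since precision decides whether the slot holds an unscaled long or a byte array.
final class ReadDecimalSketch {
  static Decimal readDecimal(SpecializedGetters row, int ordinal, DecimalType dt) {
    return row.getDecimal(ordinal, dt.precision(), dt.scale());
  }
}
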
@@ -20,6 +20,8 @@
 import java.util.Iterator;
 
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -61,26 +63,18 @@ public final class UnsafeFixedWidthAggregationMap {
 
   private final boolean enablePerfMetrics;
 
-  /**
-   * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
-   * false otherwise.
-   */
-  public static boolean supportsGroupKeySchema(StructType schema) {
-    for (StructField field: schema.fields()) {
-      if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
-        return false;
-      }
-    }
-    return true;
-  }
-
   /**
    * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
    * schema, false otherwise.
    */
   public static boolean supportsAggregationBufferSchema(StructType schema) {
     for (StructField field: schema.fields()) {
-      if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+      if (field.dataType() instanceof DecimalType) {
+        DecimalType dt = (DecimalType) field.dataType();
+        if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
+          return false;
+        }
+      } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
         return false;
       }
     }
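The net effect is that aggregation buffers may now contain decimal columns, but only while each DecimalType column's precision fits in a long (at most Decimal.MAX_LONG_DIGITS(), i.e. 18 digits); wider decimals still make the schema unsupported. A hedged usage sketch, not part of this patch; the schemas are invented and the import assumes the class sits in org.apache.spark.sql.execution:

import java.util.Arrays;

import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap; // assumed package
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Sketch: which aggregation-buffer schemas the map now accepts.
final class BufferSchemaSketch {
  public static void main(String[] args) {
    StructType compact = DataTypes.createStructType(Arrays.asList(
      DataTypes.createStructField("sum", DataTypes.createDecimalType(18, 2), true)));
    StructType wide = DataTypes.createStructType(Arrays.asList(
      DataTypes.createStructField("sum", DataTypes.createDecimalType(38, 18), true)));

    // precision <= 18: the unscaled value fits in the row's 8-byte slot.
    System.out.println(UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(compact)); // true
    // precision > 18: would need the 16-byte variable-length encoding, which is not settable in place.
    System.out.println(UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(wide));    // false
  }
}
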
@@ -19,6 +19,8 @@
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.math.BigDecimal;
+import java.math.BigInteger;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
@@ -65,12 +67,7 @@ public static int calculateBitSetWidthInBytes(int numFields) {
    */
   public static final Set<DataType> settableFieldTypes;
 
-  /**
-   * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException).
-   */
-  public static final Set<DataType> readableFieldTypes;
-
-  // TODO: support DecimalType
+  // DecimalType(precision <= 18) is settable
   static {
     settableFieldTypes = Collections.unmodifiableSet(
       new HashSet<>(
@@ -86,16 +83,6 @@ public static int calculateBitSetWidthInBytes(int numFields) {
           DateType,
           TimestampType
         })));
-
-    // We support get() on a superset of the types for which we support set():
-    final Set<DataType> _readableFieldTypes = new HashSet<>(
-      Arrays.asList(new DataType[]{
-        StringType,
-        BinaryType,
-        CalendarIntervalType
-      }));
-    _readableFieldTypes.addAll(settableFieldTypes);
-    readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
   }
 
   //////////////////////////////////////////////////////////////////////////////
@@ -232,6 +219,21 @@ public void setFloat(int ordinal, float value) {
     PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
   }
 
+  @Override
+  public void setDecimal(int ordinal, Decimal value, int precision) {
+    assertIndexIsValid(ordinal);
+    if (value == null) {
+      setNullAt(ordinal);
+    } else {
+      if (precision <= Decimal.MAX_LONG_DIGITS()) {
+        setLong(ordinal, value.toUnscaledLong());
+      } else {
+        // TODO(davies): support update decimal (hold a bounded space even it's null)
+        throw new UnsupportedOperationException();
+      }
+    }
+  }
+
   @Override
   public Object get(int ordinal) {
     throw new UnsupportedOperationException();
@@ -256,7 +258,8 @@ public Object get(int ordinal, DataType dataType) {
     } else if (dataType instanceof DoubleType) {
       return getDouble(ordinal);
     } else if (dataType instanceof DecimalType) {
-      return getDecimal(ordinal);
+      DecimalType dt = (DecimalType) dataType;
+      return getDecimal(ordinal, dt.precision(), dt.scale());
     } else if (dataType instanceof DateType) {
       return getInt(ordinal);
     } else if (dataType instanceof TimestampType) {
@@ -322,6 +325,22 @@ public double getDouble(int ordinal) {
     return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal));
   }
 
+  @Override
+  public Decimal getDecimal(int ordinal, int precision, int scale) {
+    assertIndexIsValid(ordinal);
+    if (isNullAt(ordinal)) {
+      return null;
+    }
+    if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      return Decimal.apply(getLong(ordinal), precision, scale);
+    } else {
+      byte[] bytes = getBinary(ordinal);
+      BigInteger bigInteger = new BigInteger(bytes);
+      BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+      return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
+    }
+  }
+
   @Override
   public UTF8String getUTF8String(int ordinal) {
     assertIndexIsValid(ordinal);
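For the compact case the row stores only the unscaled long, so the declared precision and scale must be supplied again at read time to rebuild the same value. A small round-trip sketch, not part of this patch, using only Decimal calls that appear above (the literal value is illustrative):

import java.math.BigDecimal;

import org.apache.spark.sql.types.Decimal;

// Sketch: the compact encoding is just the unscaled value, so precision/scale are
// not stored in the row and must come from the schema when decoding.
final class CompactDecimalRoundTripSketch {
  public static void main(String[] args) {
    Decimal in = Decimal.apply(new BigDecimal("12345.67"));
    long unscaled = in.toUnscaledLong();          // 1234567, what setDecimal writes
    Decimal out = Decimal.apply(unscaled, 7, 2);  // what getDecimal reconstructs
    System.out.println(out);                      // 12345.67
  }
}
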
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions;
 
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.ByteArray;
@@ -30,6 +31,47 @@
  */
 public class UnsafeRowWriters {
 
+  /** Writer for Decimal with precision under 18. */
+  public static class CompactDecimalWriter {
+
+    public static int getSize(Decimal input) {
+      return 0;
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
+      target.setLong(ordinal, input.toUnscaledLong());
+      return 0;
+    }
+  }
+
+  /** Writer for Decimal with precision larger than 18. */
+  public static class DecimalWriter {
+
+    public static int getSize(Decimal input) {
+      // bounded size
+      return 16;
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
+      final long offset = target.getBaseOffset() + cursor;
+      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+      final int numBytes = bytes.length;
+      assert(numBytes <= 16);
+
+      // zero-out the bytes
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L);
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L);
+
+      // Write the bytes to the variable length portion.
+      PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET,
+        target.getBaseObject(), offset, numBytes);
+
+      // Set the fixed length portion.
+      target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      return 16;
+    }
+  }
+
   /** Writer for UTF8String. */
   public static class UTF8StringWriter {
 
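DecimalWriter can reserve a fixed 16 bytes because the unscaled value of any decimal with at most 38 digits of precision fits in 16 two's-complement bytes (10^38 < 2^127), and it packs the cursor and byte count into the fixed-width slot as a single long. A standalone sketch of that arithmetic, not part of this patch; the pack helper mirrors the expression above, while the unpacking shown is an assumption about how the read side would decode the slot:

import java.math.BigDecimal;
import java.math.BigInteger;

// Sketch: why 16 bytes bound the variable-length decimal payload, and how the
// fixed-width slot encodes (cursor, length) as a single long.
final class DecimalWriterSketch {
  static long pack(int cursor, int numBytes) {
    return (((long) cursor) << 32) | ((long) numBytes);
  }

  public static void main(String[] args) {
    // Largest unscaled value for a precision-38 decimal: 38 nines.
    BigInteger maxUnscaled = BigInteger.TEN.pow(38).subtract(BigInteger.ONE);
    byte[] bytes = new BigDecimal(maxUnscaled, 18).unscaledValue().toByteArray();
    System.out.println(bytes.length);          // 16, hence the fixed reservation

    long offsetAndSize = pack(64, bytes.length);
    System.out.println(offsetAndSize >>> 32);  // 64, cursor into the variable-length region
    System.out.println((int) offsetAndSize);   // 16, number of payload bytes
  }
}
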
@@ -68,7 +68,7 @@ object CatalystTypeConverters {
       case StringType => StringConverter
       case DateType => DateConverter
       case TimestampType => TimestampConverter
-      case dt: DecimalType => BigDecimalConverter
+      case dt: DecimalType => new DecimalConverter(dt)
       case BooleanType => BooleanConverter
       case ByteType => ByteConverter
       case ShortType => ShortConverter
@@ -306,17 +306,20 @@
       DateTimeUtils.toJavaTimestamp(row.getLong(column))
   }
 
-  private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
+  private class DecimalConverter(dataType: DecimalType)
+    extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
     override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
       case d: BigDecimal => Decimal(d)
       case d: JavaBigDecimal => Decimal(d)
       case d: Decimal => d
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
     override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
-      row.getDecimal(column).toJavaBigDecimal
+      row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal
   }
 
+  private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT)
+
   private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
     final override def toScala(catalystValue: Any): Any = catalystValue
     final override def toCatalystImpl(scalaValue: T): Any = scalaValue
 
@@ -58,8 +58,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
 
   override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType)
 
-  override def getDecimal(ordinal: Int): Decimal =
-    getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
+    getAs[Decimal](ordinal, DecimalType(precision, scale))
 
   override def getInterval(ordinal: Int): CalendarInterval =
     getAs[CalendarInterval](ordinal, CalendarIntervalType)
 
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
-import org.apache.spark.sql.types.{StructType, DataType}
+import org.apache.spark.sql.types.{Decimal, StructType, DataType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -225,6 +225,11 @@ class JoinedRow extends InternalRow {
   override def getFloat(i: Int): Float =
     if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
 
+  override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
+    if (i < row1.numFields) row1.getDecimal(i, precision, scale)
+    else row2.getDecimal(i - row1.numFields, precision, scale)
+  }
+
   override def getStruct(i: Int, numFields: Int): InternalRow = {
     if (i < row1.numFields) {
       row1.getStruct(i, numFields)
 
@@ -106,6 +106,7 @@ class CodeGenContext {
     val jt = javaType(dataType)
     dataType match {
       case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)"
+      case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})"
       case StringType => s"$getter.getUTF8String($ordinal)"
       case BinaryType => s"$getter.getBinary($ordinal)"
       case CalendarIntervalType => s"$getter.getInterval($ordinal)"
@@ -120,10 +121,10 @@
    */
   def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
     val jt = javaType(dataType)
-    if (isPrimitiveType(jt)) {
-      s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
-    } else {
-      s"$row.update($ordinal, $value)"
+    dataType match {
+      case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
+      case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
+      case _ => s"$row.update($ordinal, $value)"
     }
   }
 
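With these two cases, generated code reads decimals through getDecimal(ordinal, precision, scale) and writes compact ones through setDecimal(ordinal, value, precision) instead of the generic update(...) call. A toy rendering of the strings the two templates would produce, not part of this patch, for a hypothetical DecimalType(30, 10) column at ordinal 0; the variable names are invented:

// Sketch: the accessor strings the templates above would emit for DecimalType(30, 10), ordinal 0.
final class GeneratedDecimalAccessSketch {
  public static void main(String[] args) {
    int ordinal = 0, precision = 30, scale = 10;
    String getter = "i";       // the generator's input-row variable, name is illustrative
    String row = "mutableRow"; // the generator's output-row variable, name is illustrative

    String getValue  = getter + ".getDecimal(" + ordinal + ", " + precision + ", " + scale + ")";
    String setColumn = row + ".setDecimal(" + ordinal + ", value, " + precision + ")";

    System.out.println(getValue);   // i.getDecimal(0, 30, 10)
    System.out.println(setColumn);  // mutableRow.setDecimal(0, value, 30)
  }
}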