From ed413bcc78d8d97a1a0cd0871d7a20f7170476d0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 29 Jun 2015 11:41:26 -0700 Subject: [PATCH 01/39] [SPARK-8692] [SQL] re-order the case statements that handling catalyst data types use same order: boolean, byte, short, int, date, long, timestamp, float, double, string, binary, decimal. Then we can easily check whether some data types are missing by just one glance, and make sure we handle data/timestamp just as int/long. Author: Wenchen Fan Closes #7073 from cloud-fan/fix-date and squashes the following commits: 463044d [Wenchen Fan] fix style 51cd347 [Wenchen Fan] refactor handling of date and timestmap --- .../expressions/SpecificMutableRow.scala | 12 +-- .../expressions/UnsafeRowConverter.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 6 +- .../spark/sql/columnar/ColumnAccessor.scala | 42 +++++----- .../spark/sql/columnar/ColumnBuilder.scala | 30 +++---- .../spark/sql/columnar/ColumnStats.scala | 74 ++++++++--------- .../spark/sql/columnar/ColumnType.scala | 10 +-- .../sql/execution/SparkSqlSerializer2.scala | 82 ++++++------------- .../sql/parquet/ParquetTableSupport.scala | 34 ++++---- .../spark/sql/parquet/ParquetTypes.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 9 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 54 ++++++------ .../sql/columnar/ColumnarTestUtils.scala | 8 +- .../NullableColumnAccessorSuite.scala | 6 +- .../columnar/NullableColumnBuilderSuite.scala | 6 +- 15 files changed, 174 insertions(+), 209 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 53fedb531cfb2..3928c0f2ffdaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -196,15 +196,15 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this(dataTypes: Seq[DataType]) = this( dataTypes.map { - case IntegerType => new MutableInt + case BooleanType => new MutableBoolean case ByteType => new MutableByte - case FloatType => new MutableFloat case ShortType => new MutableShort + // We use INT for DATE internally + case IntegerType | DateType => new MutableInt + // We use Long for Timestamp internally + case LongType | TimestampType => new MutableLong + case FloatType => new MutableFloat case DoubleType => new MutableDouble - case BooleanType => new MutableBoolean - case LongType => new MutableLong - case DateType => new MutableInt // We use INT for DATE internally - case TimestampType => new MutableLong // We use Long for Timestamp internally case _ => new MutableAny }.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 89adaf053b1a4..b61d490429e4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -128,14 +128,12 @@ private object UnsafeColumnWriter { case BooleanType => BooleanUnsafeColumnWriter case ByteType => ByteUnsafeColumnWriter case ShortType => ShortUnsafeColumnWriter - case IntegerType => IntUnsafeColumnWriter - case LongType => LongUnsafeColumnWriter + case IntegerType | DateType => IntUnsafeColumnWriter + case LongType | TimestampType => LongUnsafeColumnWriter case FloatType => FloatUnsafeColumnWriter case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter case BinaryType => BinaryUnsafeColumnWriter - case DateType => IntUnsafeColumnWriter - case TimestampType => LongUnsafeColumnWriter case t => throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e20e3a9dca502..57e0bede5db20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -120,15 +120,13 @@ class CodeGenContext { case BooleanType => JAVA_BOOLEAN case ByteType => JAVA_BYTE case ShortType => JAVA_SHORT - case IntegerType => JAVA_INT - case LongType => JAVA_LONG + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType - case DateType => JAVA_INT - case TimestampType => JAVA_LONG case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 64449b2659b4b..931469bed634a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -71,44 +71,44 @@ private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) +private[sql] class ByteColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BYTE) private[sql] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) +private[sql] class IntColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, INT) + private[sql] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) - -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) - private[sql] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) - extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) +private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DOUBLE) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class DateColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DATE) - -private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, TIMESTAMP) - private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) with NullableColumnAccessor +private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) + extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) + private[sql] class GenericColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) with NullableColumnAccessor +private[sql] class DateColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DATE) + +private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, TIMESTAMP) + private[sql] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val dup = buffer.duplicate().order(ByteOrder.nativeOrder) @@ -118,17 +118,17 @@ private[sql] object ColumnAccessor { dup.getInt() dataType match { + case BooleanType => new BooleanColumnAccessor(dup) + case ByteType => new ByteColumnAccessor(dup) + case ShortType => new ShortColumnAccessor(dup) case IntegerType => new IntColumnAccessor(dup) + case DateType => new DateColumnAccessor(dup) case LongType => new LongColumnAccessor(dup) + case TimestampType => new TimestampColumnAccessor(dup) case FloatType => new FloatColumnAccessor(dup) case DoubleType => new DoubleColumnAccessor(dup) - case BooleanType => new BooleanColumnAccessor(dup) - case ByteType => new ByteColumnAccessor(dup) - case ShortType => new ShortColumnAccessor(dup) case StringType => new StringColumnAccessor(dup) case BinaryType => new BinaryColumnAccessor(dup) - case DateType => new DateColumnAccessor(dup) - case TimestampType => new TimestampColumnAccessor(dup) case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnAccessor(dup, precision, scale) case _ => new GenericColumnAccessor(dup) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 1949625699ca8..087c52239713d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -94,17 +94,21 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) + private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) + +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) private[sql] class FixedDecimalColumnBuilder( precision: Int, @@ -113,19 +117,15 @@ private[sql] class FixedDecimalColumnBuilder( new FixedDecimalColumnStats, FIXED_DECIMAL(precision, scale)) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +// TODO (lian) Add support for array, struct and map +private[sql] class GenericColumnBuilder + extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) - -// TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) - private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 @@ -151,17 +151,17 @@ private[sql] object ColumnBuilder { columnName: String = "", useCompression: Boolean = false): ColumnBuilder = { val builder: ColumnBuilder = dataType match { + case BooleanType => new BooleanColumnBuilder + case ByteType => new ByteColumnBuilder + case ShortType => new ShortColumnBuilder case IntegerType => new IntColumnBuilder + case DateType => new DateColumnBuilder case LongType => new LongColumnBuilder + case TimestampType => new TimestampColumnBuilder case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder - case BooleanType => new BooleanColumnBuilder - case ByteType => new ByteColumnBuilder - case ShortType => new ShortColumnBuilder case StringType => new StringColumnBuilder case BinaryType => new BinaryColumnBuilder - case DateType => new DateColumnBuilder - case TimestampType => new TimestampColumnBuilder case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnBuilder(precision, scale) case _ => new GenericColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 1bce214d1d6c3..00374d1fa3ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -132,17 +132,17 @@ private[sql] class ShortColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class LongColumnStats extends ColumnStats { - protected var upper = Long.MinValue - protected var lower = Long.MaxValue +private[sql] class IntColumnStats extends ColumnStats { + protected var upper = Int.MinValue + protected var lower = Int.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getLong(ordinal) + val value = row.getInt(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += LONG.defaultSize + sizeInBytes += INT.defaultSize } } @@ -150,17 +150,17 @@ private[sql] class LongColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DoubleColumnStats extends ColumnStats { - protected var upper = Double.MinValue - protected var lower = Double.MaxValue +private[sql] class LongColumnStats extends ColumnStats { + protected var upper = Long.MinValue + protected var lower = Long.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDouble(ordinal) + val value = row.getLong(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += DOUBLE.defaultSize + sizeInBytes += LONG.defaultSize } } @@ -186,35 +186,17 @@ private[sql] class FloatColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class FixedDecimalColumnStats extends ColumnStats { - protected var upper: Decimal = null - protected var lower: Decimal = null - - override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Decimal] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += FIXED_DECIMAL.defaultSize - } - } - - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) -} - -private[sql] class IntColumnStats extends ColumnStats { - protected var upper = Int.MinValue - protected var lower = Int.MaxValue +private[sql] class DoubleColumnStats extends ColumnStats { + protected var upper = Double.MinValue + protected var lower = Double.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getInt(ordinal) + val value = row.getDouble(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += INT.defaultSize + sizeInBytes += DOUBLE.defaultSize } } @@ -240,10 +222,6 @@ private[sql] class StringColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DateColumnStats extends IntColumnStats - -private[sql] class TimestampColumnStats extends LongColumnStats - private[sql] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) @@ -256,6 +234,24 @@ private[sql] class BinaryColumnStats extends ColumnStats { InternalRow(null, null, nullCount, count, sizeInBytes) } +private[sql] class FixedDecimalColumnStats extends ColumnStats { + protected var upper: Decimal = null + protected var lower: Decimal = null + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Decimal] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += FIXED_DECIMAL.defaultSize + } + } + + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) +} + private[sql] class GenericColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) @@ -267,3 +263,7 @@ private[sql] class GenericColumnStats extends ColumnStats { override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, sizeInBytes) } + +private[sql] class DateColumnStats extends IntColumnStats + +private[sql] class TimestampColumnStats extends LongColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 8bf2151e4de68..fc72360c88fe1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -447,17 +447,17 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { private[sql] object ColumnType { def apply(dataType: DataType): ColumnType[_, _] = { dataType match { + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => SHORT case IntegerType => INT + case DateType => DATE case LongType => LONG + case TimestampType => TIMESTAMP case FloatType => FLOAT case DoubleType => DOUBLE - case BooleanType => BOOLEAN - case ByteType => BYTE - case ShortType => SHORT case StringType => STRING case BinaryType => BINARY - case DateType => DATE - case TimestampType => TIMESTAMP case DecimalType.Fixed(precision, scale) if precision < 19 => FIXED_DECIMAL(precision, scale) case _ => GENERIC diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 74a22353b1d27..056d435eecd23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -237,7 +237,7 @@ private[sql] object SparkSqlSerializer2 { out.writeShort(row.getShort(i)) } - case IntegerType => + case IntegerType | DateType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -245,7 +245,7 @@ private[sql] object SparkSqlSerializer2 { out.writeInt(row.getInt(i)) } - case LongType => + case LongType | TimestampType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -269,55 +269,39 @@ private[sql] object SparkSqlSerializer2 { out.writeDouble(row.getDouble(i)) } - case decimal: DecimalType => + case StringType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.apply(i).asInstanceOf[Decimal] - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + val bytes = row.getAs[UTF8String](i).getBytes out.writeInt(bytes.length) out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) } - case DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getAs[Int](i)) - } - - case TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeLong(row.getAs[Long](i)) - } - - case StringType => + case BinaryType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[UTF8String](i).getBytes + val bytes = row.getAs[Array[Byte]](i) out.writeInt(bytes.length) out.write(bytes) } - case BinaryType => + case decimal: DecimalType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[Array[Byte]](i) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray out.writeInt(bytes.length) out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) } } i += 1 @@ -364,14 +348,14 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setShort(i, in.readShort()) } - case IntegerType => + case IntegerType | DateType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { mutableRow.setInt(i, in.readInt()) } - case LongType => + case LongType | TimestampType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { @@ -392,53 +376,39 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setDouble(i, in.readDouble()) } - case decimal: DecimalType => + case StringType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. - mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) - } - - case DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.update(i, in.readInt()) - } - - case TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.update(i, in.readLong()) + mutableRow.update(i, UTF8String.fromBytes(bytes)) } - case StringType => + case BinaryType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, UTF8String.fromBytes(bytes)) + mutableRow.update(i, bytes) } - case BinaryType => + case decimal: DecimalType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { + // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) } } i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 0d96a1e8070b1..df2a96dfeb619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -198,19 +198,18 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { if (value != null) { schema match { - case StringType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(value.asInstanceOf[Int]) + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case ByteType => writer.addInteger(value.asInstanceOf[Byte]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) + case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int]) case LongType => writer.addLong(value.asInstanceOf[Long]) case TimestampType => writeTimestamp(value.asInstanceOf[Long]) - case ByteType => writer.addInteger(value.asInstanceOf[Byte]) - case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) - case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) - case DateType => writer.addInteger(value.asInstanceOf[Int]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) + case StringType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) + case BinaryType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -353,19 +352,18 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { record: InternalRow, index: Int): Unit = { ctype match { + case BooleanType => writer.addBoolean(record.getBoolean(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case IntegerType | DateType => writer.addInteger(record.getInt(index)) + case LongType => writer.addLong(record.getLong(index)) + case TimestampType => writeTimestamp(record.getLong(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) case StringType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(record.getInt(index)) - case ShortType => writer.addInteger(record.getShort(index)) - case LongType => writer.addLong(record.getLong(index)) - case ByteType => writer.addInteger(record.getByte(index)) - case DoubleType => writer.addDouble(record.getDouble(index)) - case FloatType => writer.addFloat(record.getFloat(index)) - case BooleanType => writer.addBoolean(record.getBoolean(index)) - case DateType => writer.addInteger(record.getInt(index)) - case TimestampType => writeTimestamp(record.getLong(index)) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 4d5199a140344..e748bd7857bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.types._ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = ctype match { - case _: NumericType | BooleanType | StringType | BinaryType => true - case _: DataType => false + case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true + case _ => false } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 1f37455dd0bc4..9bd7b221e93f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -22,19 +22,20 @@ import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, + InternalRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, InternalRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 6daddfb2c4804..4d46a657056e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -36,9 +36,9 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8, - BINARY -> 16, GENERIC -> 16) + BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, + LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -60,27 +60,24 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(BYTE, Byte.MaxValue, 1) checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(DATE, Int.MaxValue, 4) checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(TIMESTAMP, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) - checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - checkActualSize(BOOLEAN, true, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, 0, 4) - checkActualSize(TIMESTAMP, 0L, 8) - - val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) + checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) + checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) val generic = Map(1 -> "a") checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } - testNativeColumnType[BooleanType.type]( - BOOLEAN, + testNativeColumnType(BOOLEAN)( (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -88,18 +85,23 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { buffer.get() == 1 }) - testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) + testNativeColumnType(BYTE)(_.put(_), _.get) + + testNativeColumnType(SHORT)(_.putShort(_), _.getShort) + + testNativeColumnType(INT)(_.putInt(_), _.getInt) + + testNativeColumnType(DATE)(_.putInt(_), _.getInt) - testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) + testNativeColumnType(LONG)(_.putLong(_), _.getLong) - testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) + testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - testNativeColumnType[DecimalType]( - FIXED_DECIMAL(15, 10), + testNativeColumnType(FIXED_DECIMAL(15, 10))( (buffer: ByteBuffer, decimal: Decimal) => { buffer.putLong(decimal.toUnscaledLong) }, @@ -107,10 +109,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { Decimal(buffer.getLong(), 15, 10) }) - testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) - testNativeColumnType[StringType.type]( - STRING, + testNativeColumnType(STRING)( (buffer: ByteBuffer, string: UTF8String) => { val bytes = string.getBytes buffer.putInt(bytes.length) @@ -197,8 +197,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T], - putter: (ByteBuffer, T#InternalType) => Unit, + columnType: NativeColumnType[T]) + (putter: (ByteBuffer, T#InternalType) => Unit, getter: (ByteBuffer) => T#InternalType): Unit = { testColumnType[T, T#InternalType](columnType, putter, getter) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 7c86eae3f77fd..d9861339739c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -39,18 +39,18 @@ object ColumnarTestUtils { } (columnType match { + case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() + case DATE => Random.nextInt() case LONG => Random.nextLong() + case TIMESTAMP => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) - case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) - case DATE => Random.nextInt() - case TIMESTAMP => Random.nextLong() + case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 2a6e0c376551a..9eaa769846088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -42,9 +42,9 @@ class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index cb4e9f1eb7f46..17e9ae464bcc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -38,9 +38,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnBuilder(_) } From 3664ee25f0a67de5ba76e9487a55a55216ae589f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 29 Jun 2015 11:53:17 -0700 Subject: [PATCH 02/39] [SPARK-8066, SPARK-8067] [hive] Add support for Hive 1.0, 1.1 and 1.2. Allow HiveContext to connect to metastores of those versions; some new shims had to be added to account for changing internal APIs. A new test was added to exercise the "reset()" path which now also requires a shim; and the test code was changed to use a directory under the build's target to store ivy dependencies. Without that, at least I consistently run into issues with Ivy messing up (or being confused) by my existing caches. Author: Marcelo Vanzin Closes #7026 from vanzin/SPARK-8067 and squashes the following commits: 3e2e67b [Marcelo Vanzin] [SPARK-8066, SPARK-8067] [hive] Add support for Hive 1.0, 1.1 and 1.2. --- .../spark/sql/hive/client/ClientWrapper.scala | 5 +- .../spark/sql/hive/client/HiveShim.scala | 70 ++++++++++++++++++- .../hive/client/IsolatedClientLoader.scala | 13 ++-- .../spark/sql/hive/client/package.scala | 33 +++++++-- .../spark/sql/hive/client/VersionsSuite.scala | 25 +++++-- 5 files changed, 131 insertions(+), 15 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 2f771d76793e5..4c708cec572ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -97,6 +97,9 @@ private[hive] class ClientWrapper( case hive.v12 => new Shim_v0_12() case hive.v13 => new Shim_v0_13() case hive.v14 => new Shim_v0_14() + case hive.v1_0 => new Shim_v1_0() + case hive.v1_1 => new Shim_v1_1() + case hive.v1_2 => new Shim_v1_2() } // Create an internal session state for this ClientWrapper. @@ -456,7 +459,7 @@ private[hive] class ClientWrapper( logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).foreach { index => - client.dropIndex("default", t, index.getIndexName, true) + shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index e7c1779f80ce6..1fa9d278e2a57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.client -import java.lang.{Boolean => JBoolean, Integer => JInteger} +import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{Method, Modifier} import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} @@ -94,6 +94,8 @@ private[client] sealed abstract class Shim { holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit + def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit + protected def findStaticMethod(klass: Class[_], name: String, args: Class[_]*): Method = { val method = findMethod(klass, name, args: _*) require(Modifier.isStatic(method.getModifiers()), @@ -166,6 +168,14 @@ private[client] class Shim_v0_12 extends Shim { JInteger.TYPE, JBoolean.TYPE, JBoolean.TYPE) + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE) override def setCurrentSessionState(state: SessionState): Unit = { // Starting from Hive 0.13, setCurrentSessionState will internally override @@ -234,6 +244,10 @@ private[client] class Shim_v0_12 extends Shim { numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) } + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + } + } private[client] class Shim_v0_13 extends Shim_v0_12 { @@ -379,3 +393,57 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { TimeUnit.MILLISECONDS).asInstanceOf[Long] } } + +private[client] class Shim_v1_0 extends Shim_v0_14 { + +} + +private[client] class Shim_v1_1 extends Shim_v1_0 { + + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + } + +} + +private[client] class Shim_v1_2 extends Shim_v1_1 { + + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + 0: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 0934ad5034671..3d609a66f3664 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -41,9 +41,11 @@ private[hive] object IsolatedClientLoader { */ def forVersion( version: String, - config: Map[String, String] = Map.empty): IsolatedClientLoader = synchronized { + config: Map[String, String] = Map.empty, + ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion)) + val files = resolvedVersions.getOrElseUpdate(resolvedVersion, + downloadVersion(resolvedVersion, ivyPath)) new IsolatedClientLoader(hiveVersion(version), files, config) } @@ -51,9 +53,12 @@ private[hive] object IsolatedClientLoader { case "12" | "0.12" | "0.12.0" => hive.v12 case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13 case "14" | "0.14" | "0.14.0" => hive.v14 + case "1.0" | "1.0.0" => hive.v1_0 + case "1.1" | "1.1.0" => hive.v1_1 + case "1.2" | "1.2.0" => hive.v1_2 } - private def downloadVersion(version: HiveVersion): Seq[URL] = { + private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ @@ -64,7 +69,7 @@ private[hive] object IsolatedClientLoader { SparkSubmitUtils.resolveMavenCoordinates( hiveArtifacts.mkString(","), Some("http://www.datanucleus.org/downloads/maven2"), - None, + ivyPath, exclusions = version.exclusions) } val allFiles = classpath.split(",").map(new File(_)).toSet diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 27a3d8f5896cc..b48082fe4b363 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -32,13 +32,36 @@ package object client { // Hive 0.14 depends on calcite 0.9.2-incubating-SNAPSHOT which does not exist in // maven central anymore, so override those with a version that exists. // - // org.pentaho:pentaho-aggdesigner-algorithm is also nowhere to be found, so exclude - // it explicitly. If it's needed by the metastore client, users will have to dig it - // out of somewhere and use configuration to point Spark at the correct jars. + // The other excluded dependencies are also nowhere to be found, so exclude them explicitly. If + // they're needed by the metastore client, users will have to dig them out of somewhere and use + // configuration to point Spark at the correct jars. case object v14 extends HiveVersion("0.14.0", - Seq("org.apache.calcite:calcite-core:1.3.0-incubating", + extraDeps = Seq("org.apache.calcite:calcite-core:1.3.0-incubating", "org.apache.calcite:calcite-avatica:1.3.0-incubating"), - Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + exclusions = Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v1_0 extends HiveVersion("1.0.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + // The curator dependency was added to the exclusions here because it seems to confuse the ivy + // library. org.apache.curator:curator is a pom dependency but ivy tries to find the jar for it, + // and fails. + case object v1_1 extends HiveVersion("1.1.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + case object v1_2 extends HiveVersion("1.2.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) } // scalastyle:on diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9a571650b6e25..d52e162acbd04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.client +import java.io.File + import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils @@ -28,6 +30,12 @@ import org.apache.spark.util.Utils * is not fully tested. */ class VersionsSuite extends SparkFunSuite with Logging { + + // Do not use a temp path here to speed up subsequent executions of the unit test during + // development. + private val ivyPath = Some( + new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() @@ -38,7 +46,7 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client + val badClient = IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -67,19 +75,21 @@ class VersionsSuite extends SparkFunSuite with Logging { // TODO: currently only works on mysql where we manually create the schema... ignore("failure sanity check") { val e = intercept[Throwable] { - val badClient = quietly { IsolatedClientLoader.forVersion("13", buildConf()).client } + val badClient = quietly { + IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("12", "13", "14") + private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") private var client: ClientInterface = null versions.foreach { version => test(s"$version: create client") { client = null - client = IsolatedClientLoader.forVersion(version, buildConf()).client + client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client } test(s"$version: createDatabase") { @@ -170,5 +180,12 @@ class VersionsSuite extends SparkFunSuite with Logging { false, false) } + + test(s"$version: create index and reset") { + client.runSqlHive("CREATE TABLE indexed_table (key INT)") + client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + + "as 'COMPACT' WITH DEFERRED REBUILD") + client.reset() + } } } From a5c2961caaafd751f11bdd406bb6885443d7572e Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 29 Jun 2015 11:57:19 -0700 Subject: [PATCH 03/39] [SPARK-8235] [SQL] misc function sha / sha1 Jira: https://issues.apache.org/jira/browse/SPARK-8235 I added the support for sha1. If I understood rxin correctly, sha and sha1 should execute the same algorithm, shouldn't they? Please take a close look on the Python part. This is adopted from #6934 Author: Tarek Auel Author: Tarek Auel Closes #6963 from tarekauel/SPARK-8235 and squashes the following commits: f064563 [Tarek Auel] change to shaHex 7ce3cdc [Tarek Auel] rely on automatic cast a1251d6 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-8235 68eb043 [Tarek Auel] added docstring be5aff1 [Tarek Auel] improved error message 7336c96 [Tarek Auel] added type check cf23a80 [Tarek Auel] simplified example ebf75ef [Tarek Auel] [SPARK-8301] updated the python documentation. Removed sha in python and scala 6d6ff0d [Tarek Auel] [SPARK-8233] added docstring ea191a9 [Tarek Auel] [SPARK-8233] fixed signatureof python function. Added expected type to misc e3fd7c3 [Tarek Auel] SPARK[8235] added sha to the list of __all__ e5dad4e [Tarek Auel] SPARK[8235] sha / sha1 --- python/pyspark/sql/functions.py | 14 +++++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 ++ .../spark/sql/catalyst/expressions/misc.scala | 30 ++++++++++++++++++- .../expressions/MiscFunctionsSuite.scala | 8 +++++ .../org/apache/spark/sql/functions.scala | 16 ++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 12 ++++++++ 6 files changed, 81 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7d3d0361610b7..45ecd826bd3bd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -42,6 +42,7 @@ 'monotonicallyIncreasingId', 'rand', 'randn', + 'sha1', 'sha2', 'sparkPartitionId', 'struct', @@ -382,6 +383,19 @@ def sha2(col, numBits): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 457948a800a17..b24064d061533 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -136,6 +136,8 @@ object FunctionRegistry { // misc functions expression[Md5]("md5"), expression[Sha2]("sha2"), + expression[Sha1]("sha1"), + expression[Sha1]("sha"), // aggregate functions expression[Average]("avg"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e80706fc65aff..9a39165a1ff05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,8 +21,9 @@ import java.security.MessageDigest import java.security.NoSuchAlgorithmException import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -140,3 +141,30 @@ case class Sha2(left: Expression, right: Expression) """ } } + +/** + * A function that calculates a sha1 hash value and returns it as a hex string + * For input of type [[BinaryType]] or [[StringType]] + */ +case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + "org.apache.spark.unsafe.types.UTF8String.fromString" + + s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + ) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 38482c54c61db..36e636b5da6b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -31,6 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(null, BinaryType)), null) } + test("sha1") { + checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + "5d211bad8f4ee70e16c7d343a838fc344a1ed961") + checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) + checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + } + test("sha2") { checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 355ce0e3423cf..ef92801548a13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1414,6 +1414,22 @@ object functions { */ def md5(columnName: String): Column = md5(Column(columnName)) + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(columnName: String): Column = sha1(Column(columnName)) + /** * Calculates the SHA-2 family of hash functions and returns the value as a hex string. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8baed57a7f129..abfd47c811ed9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -144,6 +144,18 @@ class DataFrameFunctionsSuite extends QueryTest { Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) } + test("misc sha1 function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha1($"a"), sha1("b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha1(a)", "sha1(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + test("misc sha2 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( From 492dca3a73e70705b5d5639e8fe4640b80e78d31 Mon Sep 17 00:00:00 2001 From: Vladimir Vladimirov Date: Mon, 29 Jun 2015 12:03:41 -0700 Subject: [PATCH 04/39] [SPARK-8528] Expose SparkContext.applicationId in PySpark Use case - we want to log applicationId (YARN in hour case) to request help with troubleshooting from the DevOps Author: Vladimir Vladimirov Closes #6936 from smartkiwi/master and squashes the following commits: 870338b [Vladimir Vladimirov] this would make doctest to run in python3 0eae619 [Vladimir Vladimirov] Scala doesn't use u'...' for unicode literals 14d77a8 [Vladimir Vladimirov] stop using ELLIPSIS b4ebfc5 [Vladimir Vladimirov] addressed PR feedback - updated docstring 223a32f [Vladimir Vladimirov] fixed test - applicationId is property that returns the string 3221f5a [Vladimir Vladimirov] [SPARK-8528] added documentation for Scala 2cff090 [Vladimir Vladimirov] [SPARK-8528] add applicationId property for SparkContext object in pyspark --- .../scala/org/apache/spark/SparkContext.scala | 8 ++++++++ python/pyspark/context.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c7a7436462083..b3c3bf3746e18 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _dagScheduler = ds } + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 90b2fffbb9c7c..d7466729b8f36 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -291,6 +291,21 @@ def version(self): """ return self._jsc.version() + @property + @ignore_unicode_prefix + def applicationId(self): + """ + A unique identifier for the Spark application. + Its format depends on the scheduler implementation. + (i.e. + in case of local spark app something like 'local-1433865536131' + in case of YARN something like 'application_1433865536131_34483' + ) + >>> sc.applicationId # doctest: +ELLIPSIS + u'local-...' + """ + return self._jsc.sc().applicationId() + @property def startTime(self): """Return the epoch time when the Spark Context was started.""" From 94e040d05996111b2b448bcdee1cda184c6d039b Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 29 Jun 2015 12:16:12 -0700 Subject: [PATCH 05/39] [SQL][DOCS] Remove wrong example from DataFrame.scala In DataFrame.scala, there are examples like as follows. ``` * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) * peopleDf($"age" > 15) ``` But, I think the last example doesn't work. Author: Kousuke Saruta Closes #6977 from sarutak/fix-dataframe-example and squashes the following commits: 46efbd7 [Kousuke Saruta] Removed wrong example --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d75d88307562e..986e59133919f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -682,7 +682,6 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 @@ -707,7 +706,6 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 From 637b4eedad84dcff1769454137a64ac70c7f2397 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Mon, 29 Jun 2015 12:25:16 -0700 Subject: [PATCH 06/39] [SPARK-8214] [SQL] Add function hex cc chenghao-intel adrian-wang Author: zhichao.li Closes #6976 from zhichao-li/hex and squashes the following commits: e218d1b [zhichao.li] turn off scalastyle for non-ascii de3f5ea [zhichao.li] non-ascii char cf9c936 [zhichao.li] give separated buffer for each hex method 967ec90 [zhichao.li] Make 'value' as a feild of Hex 3b2fa13 [zhichao.li] tiny fix a647641 [zhichao.li] remove duplicate null check 7cab020 [zhichao.li] tiny refactoring 35ecfe5 [zhichao.li] add function hex --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 86 ++++++++++++++++++- .../expressions/MathFunctionsSuite.scala | 14 ++- .../org/apache/spark/sql/functions.scala | 16 ++++ .../spark/sql/MathExpressionsSuite.scala | 13 +++ 5 files changed, 125 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b24064d061533..b17457d3094c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -113,6 +113,7 @@ object FunctionRegistry { expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Hypot]("hypot"), + expression[Hex]("hex"), expression[Logarithm]("log"), expression[Log]("ln"), expression[Log10]("log10"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 5694afc61be05..4b57ddd9c5768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} +import java.util.Arrays +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -273,9 +275,6 @@ case class Atan2(left: Expression, right: Expression) } } -case class Hypot(left: Expression, right: Expression) - extends BinaryMathExpression(math.hypot, "HYPOT") - case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -287,6 +286,85 @@ case class Pow(left: Expression, right: Expression) } } +/** + * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. + * Otherwise if the number is a STRING, + * it converts each character into its hexadecimal representation and returns the resulting STRING. + * Negative numbers would be treated as two's complement. + */ +case class Hex(child: Expression) + extends UnaryExpression with Serializable { + + override def dataType: DataType = StringType + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[StringType] + || child.dataType.isInstanceOf[IntegerType] + || child.dataType.isInstanceOf[LongType] + || child.dataType.isInstanceOf[BinaryType] + || child.dataType == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type") + } + } + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + child.dataType match { + case LongType => hex(num.asInstanceOf[Long]) + case IntegerType => hex(num.asInstanceOf[Integer].toLong) + case BinaryType => hex(num.asInstanceOf[Array[Byte]]) + case StringType => hex(num.asInstanceOf[UTF8String]) + } + } + } + + /** + * Converts every character in s to two hex digits. + */ + private def hex(str: UTF8String): UTF8String = { + hex(str.getBytes) + } + + private def hex(bytes: Array[Byte]): UTF8String = { + doHex(bytes, bytes.length) + } + + private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + val value = new Array[Byte](length * 2) + var i = 0 + while(i < length) { + value(i * 2) = Character.toUpperCase(Character.forDigit( + (bytes(i) & 0xF0) >>> 4, 16)).toByte + value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( + bytes(i) & 0x0F, 16)).toByte + i += 1 + } + UTF8String.fromBytes(value) + } + + private def hex(num: Long): UTF8String = { + // Extract the hex digits of num into value[] from right to left + val value = new Array[Byte](16) + var numBuf = num + var len = 0 + do { + len += 1 + value(value.length - len) = Character.toUpperCase(Character + .forDigit((numBuf & 0xF).toInt, 16)).toByte + numBuf >>>= 4 + } while (numBuf != 0) + UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) + } +} + +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") + case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 0d1d5ebdff2d5..b932d4ab850c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.{DataType, DoubleType, LongType} @@ -226,6 +225,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) } + test("hex") { + checkEvaluation(Hex(Literal(28)), "1C") + checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") + checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") + checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") + checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Hex(Literal("三重的")), "E4B889E9878DE79A84") + // scalastyle:on + } + test("hypot") { testBinary(Hypot, math.hypot) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ef92801548a13..5422e066afcb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1046,6 +1046,22 @@ object functions { */ def floor(columnName: String): Column = floor(Column(columnName)) + /** + * Computes hex value of the given column + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(column: Column): Column = Hex(column.expr) + + /** + * Computes hex value of the given input + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(colName: String): Column = hex(Column(colName)) + /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 2768d7dfc8030..d6331aa4ff09e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -212,6 +212,19 @@ class MathExpressionsSuite extends QueryTest { ) } + test("hex") { + val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") + checkAnswer(data.select(hex('a)), Seq(Row("1C"))) + checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) + checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) + checkAnswer(data.selectExpr("hex(d)"), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) + } + test("hypot") { testTwoToOneMathFunction(hypot, hypot, math.hypot) } From c6ba2ea341ad23de265d870669b25e6a41f461e5 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 29 Jun 2015 12:46:33 -0700 Subject: [PATCH 07/39] [SPARK-7862] [SQL] Disable the error message redirect to stderr This is a follow up of #6404, the ScriptTransformation prints the error msg into stderr directly, probably be a disaster for application log. Author: Cheng Hao Closes #6882 from chenghao-intel/verbose and squashes the following commits: bfedd77 [Cheng Hao] revert the write 76ff46b [Cheng Hao] update the CircularBuffer 692b19e [Cheng Hao] check the process exitValue for ScriptTransform 47e0970 [Cheng Hao] Use the RedirectThread instead 1de771d [Cheng Hao] naming the threads in ScriptTransformation 8536e81 [Cheng Hao] disable the error message redirection for stderr --- .../scala/org/apache/spark/util/Utils.scala | 33 ++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 8 +++ .../spark/sql/hive/client/ClientWrapper.scala | 29 ++--------- .../hive/execution/ScriptTransformation.scala | 51 ++++++++++++------- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 5 files changed, 77 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 19157af5b6f4d..a7fc749a2b0c6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2333,3 +2333,36 @@ private[spark] class RedirectThread( } } } + +/** + * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it + * in a circular buffer. The current contents of the buffer can be accessed using + * the toString method. + */ +private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](sizeInBytes) + + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.length + } + + override def toString: String = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next() else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while (line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a61ea3918f46a..baa4c661cc21e 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -673,4 +673,12 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(!Utils.isInDirectory(nullFile, parentDir)) assert(!Utils.isInDirectory(nullFile, childFile3)) } + + test("circular buffer") { + val buffer = new CircularBuffer(25) + val stream = new java.io.PrintStream(buffer, true, "UTF-8") + + stream.println("test circular test circular test circular test circular test circular") + assert(buffer.toString === "t circular test circular\n") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 4c708cec572ae..cbd2bf6b5eede 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -22,6 +22,8 @@ import java.net.URI import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} import javax.annotation.concurrent.GuardedBy +import org.apache.spark.util.CircularBuffer + import scala.collection.JavaConversions._ import scala.language.reflectiveCalls @@ -66,32 +68,7 @@ private[hive] class ClientWrapper( with Logging { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. - private val outputBuffer = new java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](10240) - def write(i: Int): Unit = { - buffer(pos) = i - pos = (pos + 1) % buffer.size - } - - override def toString: String = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator - - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while(line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() - } - stringBuilder.toString() - } - } + private val outputBuffer = new CircularBuffer() private val shim = version match { case hive.v12 => new Shim_v0_12() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 611888055d6cf..b967e191c5855 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} /** * Transforms the input by forking and running the specified script. @@ -59,15 +59,13 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) - // redirectError(Redirect.INHERIT) would consume the error output from buffer and - // then print it to stderr (inherit the target from the current Scala process). - // If without this there would be 2 issues: + // We need to start threads connected to the process pipeline: // 1) The error msg generated by the script process would be hidden. // 2) If the error msg is too big to chock up the buffer, the input logic would be hung - builder.redirectError(Redirect.INHERIT) val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream val reader = new BufferedReader(new InputStreamReader(inputStream)) val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) @@ -152,29 +150,43 @@ case class ScriptTransformation( val dataOutputStream = new DataOutputStream(outputStream) val outputProjection = new InterpretedProjection(input, child.output) + // TODO make the 2048 configurable? + val stderrBuffer = new CircularBuffer(2048) + // Consume the error stream from the pipeline, otherwise it will be blocked if + // the pipeline is full. + new RedirectThread(errorStream, // input stream from the pipeline + stderrBuffer, // output to a circular buffer + "Thread-ScriptTransformation-STDERR-Consumer").start() + // Put the write(output to the pipeline) into a single thread // and keep the collector as remain in the main thread. // otherwise it will causes deadlock if the data size greater than // the pipeline / buffer capacity. new Thread(new Runnable() { override def run(): Unit = { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize( - row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + Utils.tryWithSafeFinally { + iter + .map(outputProjection) + .foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + + outputStream.write(data) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + } { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer } } - outputStream.close() } - }).start() + }, "Thread-ScriptTransformation-Feed").start() iterator } @@ -278,3 +290,4 @@ case class HiveScriptIOSchema ( } } } + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f0aad8dbbe64d..9f7e58f890241 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -653,7 +653,7 @@ class SQLQuerySuite extends QueryTest { .queryExecution.toRdd.count()) } - ignore("test script transform for stderr") { + test("test script transform for stderr") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === From be7ef067620408859144e0244b0f1b8eb56faa86 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 13:15:04 -0700 Subject: [PATCH 08/39] [SPARK-8681] fixed wrong ordering of columns in crosstab I specifically randomized the test. What crosstab does is equivalent to a countByKey, therefore if this test fails again for any reason, we will know that we hit a corner case or something. cc rxin marmbrus Author: Burak Yavuz Closes #7060 from brkyvz/crosstab-fixes and squashes the following commits: 0a65234 [Burak Yavuz] addressed comments v1 d96da7e [Burak Yavuz] fixed wrong ordering of columns in crosstab --- .../sql/execution/stat/StatFunctions.scala | 8 ++++-- .../apache/spark/sql/DataFrameStatSuite.scala | 28 ++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 042e2c9cbb22e..b624ef7e8fa1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -111,7 +111,7 @@ private[sql] object StatFunctions extends Logging { "the pairs. Please try reducing the amount of distinct items in your columns.") } // get the distinct values of column 2, so that we can make them the column names - val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap + val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") @@ -120,14 +120,16 @@ private[sql] object StatFunctions extends Logging { rows.foreach { (row: Row) => // row.get(0) is column 1 // row.get(1) is column 2 - // row.get(3) is the frequency + // row.get(2) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts countsRow.update(0, UTF8String.fromString(col1Item.toString)) countsRow }.toSeq - val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq + // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in + // SPARK-8681. We need to explicitly sort by the column index and assign the column names. + val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType)) val schema = StructType(StructField(tableName, StringType) +: headerNames) new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0d3ff899dad72..64ec1a70c47e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Random + import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite @@ -65,22 +67,22 @@ class DataFrameStatSuite extends SparkFunSuite { } test("crosstab") { - val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b") + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") - assert(columnNames(1) === "0") - assert(columnNames(2) === "1") - val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) - assert(rows(0).get(0).toString === "0") - assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === 0L) - assert(rows(1).get(0).toString === "1") - assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === 0L) - assert(rows(2).get(0).toString === "2") - assert(rows(2).getLong(1) === 2L) - assert(rows(2).getLong(2) === 1L) + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 to 9) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } + } } test("Frequent Items") { From afae9766f28d2e58297405c39862d20a04267b62 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Jun 2015 13:20:55 -0700 Subject: [PATCH 09/39] [SPARK-8070] [SQL] [PYSPARK] avoid spark jobs in createDataFrame Avoid the unnecessary jobs when infer schema from list. cc yhuai mengxr Author: Davies Liu Closes #6606 from davies/improve_create and squashes the following commits: a5928bf [Davies Liu] Update MimaExcludes.scala 62da911 [Davies Liu] fix mima bab4d7d [Davies Liu] Merge branch 'improve_create' of github.com:davies/spark into improve_create eee44a8 [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create 8d9292d [Davies Liu] Update context.py eb24531 [Davies Liu] Update context.py c969997 [Davies Liu] bug fix d5a8ab0 [Davies Liu] fix tests 8c3f10d [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create 6ea5925 [Davies Liu] address comments 6ceaeff [Davies Liu] avoid spark jobs in createDataFrame --- python/pyspark/sql/context.py | 64 +++++++++++++++++++++++++---------- python/pyspark/sql/types.py | 48 +++++++++++++++----------- 2 files changed, 75 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index dc239226e6d3c..4dda3b430cfbf 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -203,7 +203,37 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ + if not data: + raise ValueError("can not infer schema from empty dataset") + first = data[0] + if type(first) is dict: + warnings.warn("inferring schema from dict is deprecated," + "please use pyspark.sql.Row instead") + schema = _infer_schema(first) + if _has_nulltype(schema): + for r in data: + schema = _merge_type(schema, _infer_schema(r)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined after inferring") + return schema + def _inferSchema(self, rdd, samplingRatio=None): + """ + Infer schema from an RDD of Row or tuple. + + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " @@ -322,6 +352,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): + if not isinstance(data, list): + data = list(data) try: # data could be list, tuple, generator ... rdd = self._sc.parallelize(data) @@ -330,28 +362,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): else: rdd = data - if schema is None: - schema = self._inferSchema(rdd, samplingRatio) + if schema is None or isinstance(schema, (list, tuple)): + if isinstance(data, RDD): + struct = self._inferSchema(rdd, samplingRatio) + else: + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + schema = struct converter = _create_converter(schema) rdd = rdd.map(converter) - if isinstance(schema, (list, tuple)): - first = rdd.first() - if not isinstance(first, (list, tuple)): - raise TypeError("each row in `rdd` should be list or tuple, " - "but got %r" % type(first)) - row_cls = Row(*schema) - schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) + elif isinstance(schema, StructType): + # take the first few rows to verify schema rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) - for row in rows: - _verify_type(row, schema) + else: + raise TypeError("schema should be StructType or list or None") # convert python objects to sql data converter = _python_to_sql_converter(schema) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 23d9adb0daea1..932686e5e4b01 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -635,7 +635,7 @@ def _need_python_to_sql_conversion(dataType): >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), ... StructField("values", ArrayType(DoubleType(), False), False)]) >>> _need_python_to_sql_conversion(schema0) - False + True >>> _need_python_to_sql_conversion(ExamplePointUDT()) True >>> schema1 = ArrayType(ExamplePointUDT(), False) @@ -647,7 +647,8 @@ def _need_python_to_sql_conversion(dataType): True """ if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + # convert namedtuple or Row into tuple + return True elif isinstance(dataType, ArrayType): return _need_python_to_sql_conversion(dataType.elementType) elif isinstance(dataType, MapType): @@ -688,21 +689,25 @@ def _python_to_sql_converter(dataType): if isinstance(dataType, StructType): names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) + if any(_need_python_to_sql_conversion(t) for t in types): + converters = [_python_to_sql_converter(t) for t in types] + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): + return tuple(c(v) for c, v in zip(converters, obj)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + elif obj is not None: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + else: + def converter(obj): + if isinstance(obj, dict): + return tuple(obj.get(n) for n in names) else: - return tuple(c(v) for c, v in zip(converters, obj)) - elif obj is not None: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return tuple(obj) return converter elif isinstance(dataType, ArrayType): element_converter = _python_to_sql_converter(dataType.elementType) @@ -1027,10 +1032,13 @@ def _verify_type(obj, dataType): _type = type(dataType) assert _type in _acceptable_types, "unknown datatype: %s" % dataType - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) + if _type is StructType: + if not isinstance(obj, (tuple, list)): + raise TypeError("StructType can not accept object in type %s" % type(obj)) + else: + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): for i in obj: From 27ef85451cd237caa7016baa69957a35ab365aa8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 14:07:55 -0700 Subject: [PATCH 10/39] [SPARK-8709] Exclude hadoop-client's mockito-all dependency This patch excludes `hadoop-client`'s dependency on `mockito-all`. As of #7061, Spark depends on `mockito-core` instead of `mockito-all`, so the dependency from Hadoop was leading to test compilation failures for some of the Hadoop 2 SBT builds. Author: Josh Rosen Closes #7090 from JoshRosen/SPARK-8709 and squashes the following commits: e190122 [Josh Rosen] [SPARK-8709] Exclude hadoop-client's mockito-all dependency. --- LICENSE | 2 +- core/pom.xml | 10 ---------- launcher/pom.xml | 6 ------ pom.xml | 8 ++++++++ 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/LICENSE b/LICENSE index 8672be55eca3e..f9e412cade345 100644 --- a/LICENSE +++ b/LICENSE @@ -948,6 +948,6 @@ The following components are provided under the MIT License. See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/core/pom.xml b/core/pom.xml index 565437c4861a4..aee0d92620606 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -69,16 +69,6 @@ org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - org.apache.spark diff --git a/launcher/pom.xml b/launcher/pom.xml index a853e67f5cf78..2fd768d8119c4 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -68,12 +68,6 @@ org.apache.hadoop hadoop-client test - - - org.codehaus.jackson - jackson-mapper-asl - - diff --git a/pom.xml b/pom.xml index 4c18bd5e42c87..94dd512cfb618 100644 --- a/pom.xml +++ b/pom.xml @@ -747,6 +747,10 @@ asm asm + + org.codehaus.jackson + jackson-mapper-asl + org.ow2.asm asm @@ -759,6 +763,10 @@ commons-logging commons-logging + + org.mockito + mockito-all + org.mortbay.jetty servlet-api-2.5 From f6fc254ec4ce5f103d45da6d007b4066ce751236 Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Mon, 29 Jun 2015 14:15:15 -0700 Subject: [PATCH 11/39] [SPARK-8056][SQL] Design an easier way to construct schema for both Scala and Python I've added functionality to create new StructType similar to how we add parameters to a new SparkContext. I've also added tests for this type of creation. Author: Ilya Ganelin Closes #6686 from ilganeli/SPARK-8056B and squashes the following commits: 27c1de1 [Ilya Ganelin] Rename 467d836 [Ilya Ganelin] Removed from_string in favor of _parse_Datatype_json_value 5fef5a4 [Ilya Ganelin] Updates for type parsing 4085489 [Ilya Ganelin] Style errors 3670cf5 [Ilya Ganelin] added string to DataType conversion 8109e00 [Ilya Ganelin] Fixed error in tests 41ab686 [Ilya Ganelin] Fixed style errors e7ba7e0 [Ilya Ganelin] Moved some python tests to tests.py. Added cleaner handling of null data type and added test for correctness of input format 15868fa [Ilya Ganelin] Fixed python errors b79b992 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-8056B a3369fc [Ilya Ganelin] Fixing space errors e240040 [Ilya Ganelin] Style bab7823 [Ilya Ganelin] Constructor error 73d4677 [Ilya Ganelin] Style 4ed00d9 [Ilya Ganelin] Fixed default arg 67df57a [Ilya Ganelin] Removed Foo 04cbf0c [Ilya Ganelin] Added comments for single object 0484d7a [Ilya Ganelin] Restored second method 6aeb740 [Ilya Ganelin] Style 689e54d [Ilya Ganelin] Style f497e9e [Ilya Ganelin] Got rid of old code e3c7a88 [Ilya Ganelin] Fixed doctest failure a62ccde [Ilya Ganelin] Style 966ac06 [Ilya Ganelin] style checks dabb7e6 [Ilya Ganelin] Added Python tests a3f4152 [Ilya Ganelin] added python bindings and better comments e6e536c [Ilya Ganelin] Added extra space 7529a2e [Ilya Ganelin] Fixed formatting d388f86 [Ilya Ganelin] Fixed small bug c4e3bf5 [Ilya Ganelin] Reverted to using parse. Updated parse to support long d7634b6 [Ilya Ganelin] Reverted to fromString to properly support types 22c39d5 [Ilya Ganelin] replaced FromString with DataTypeParser.parse. Replaced empty constructor initializing a null to have it instead create a new array to allow appends to it. faca398 [Ilya Ganelin] [SPARK-8056] Replaced default argument usage. Updated usage and code for DataType.fromString 1acf76e [Ilya Ganelin] Scala style e31c674 [Ilya Ganelin] Fixed bug in test 8dc0795 [Ilya Ganelin] Added tests for creation of StructType object with new methods fdf7e9f [Ilya Ganelin] [SPARK-8056] Created add methods to facilitate building new StructType objects. --- python/pyspark/sql/tests.py | 29 +++++ python/pyspark/sql/types.py | 52 ++++++++- .../spark/sql/types/DataTypeParser.scala | 2 +- .../apache/spark/sql/types/StructType.scala | 104 +++++++++++++++++- .../spark/sql/types/DataTypeSuite.scala | 31 ++++++ 5 files changed, 212 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ffee43a94baba..34f397d0ffef0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -516,6 +516,35 @@ def test_between_function(self): self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()) + def test_struct_type(self): + from pyspark.sql.types import StructType, StringType, StructField + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + # Catch exception raised during improper construction + try: + struct1 = StructType().add("name") + self.assertEqual(1, 0) + except ValueError: + self.assertEqual(1, 1) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 932686e5e4b01..ae9344e6106a4 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -355,8 +355,7 @@ class StructType(DataType): This is the data type representing a :class:`Row`. """ - - def __init__(self, fields): + def __init__(self, fields=None): """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) @@ -368,8 +367,53 @@ def __init__(self, fields): >>> struct1 == struct2 False """ - assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" - self.fields = fields + if not fields: + self.fields = [] + else: + self.fields = fields + assert all(isinstance(f, StructField) for f in fields),\ + "fields should be a list of StructField" + + def add(self, field, data_type=None, nullable=True, metadata=None): + """ + Construct a StructType by adding new elements to it to define the schema. The method accepts + either: + a) A single parameter which is a StructField object. + b) Between 2 and 4 parameters as (name, data_type, nullable (optional), + metadata(optional). The data_type parameter may be either a String or a DataType object + + >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + >>> struct2 = StructType([StructField("f1", StringType(), True),\ + StructField("f2", StringType(), True, None)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add(StructField("f1", StringType(), True)) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add("f1", "string", True) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + + :param field: Either the name of the field or a StructField object + :param data_type: If present, the DataType of the StructField to create + :param nullable: Whether the field to add should be nullable (default True) + :param metadata: Any additional metadata (default None) + :return: a new updated StructType + """ + if isinstance(field, StructField): + self.fields.append(field) + else: + if isinstance(field, str) and data_type is None: + raise ValueError("Must specify DataType if passing name of struct_field to create.") + + if isinstance(data_type, str): + data_type_f = _parse_datatype_json_value(data_type) + else: + data_type_f = data_type + self.fields.append(StructField(field, data_type_f, nullable, metadata)) + return self def simpleString(self): return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 04f3379afb38d..6b43224feb1f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -44,7 +44,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)tinyint".r ^^^ ByteType | "(?i)smallint".r ^^^ ShortType | "(?i)double".r ^^^ DoubleType | - "(?i)bigint".r ^^^ LongType | + "(?i)(?:bigint|long)".r ^^^ LongType | "(?i)binary".r ^^^ BinaryType | "(?i)boolean".r ^^^ BooleanType | fixedDecimalType | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 193c08a4d0df7..2db0a359e9db5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -94,7 +94,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ - protected def this() = this(null) + def this() = this(Array.empty[StructField]) /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -103,6 +103,108 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + /** + * Creates a new [[StructType]] by adding a new field. + * {{{ + * val struct = (new StructType) + * .add(StructField("a", IntegerType, true)) + * .add(StructField("b", LongType, false)) + * .add(StructField("c", StringType, true)) + *}}} + */ + def add(field: StructField): StructType = { + StructType(fields :+ field) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType) + * .add("b", LongType) + * .add("c", StringType) + */ + def add(name: String, dataType: DataType): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable = true, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType, true) + * .add("b", LongType, false) + * .add("c", StringType, true) + */ + def add(name: String, dataType: DataType, nullable: Boolean): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata. + * {{{ + * val struct = (new StructType) + * .add("a", IntegerType, true, Metadata.empty) + * .add("b", LongType, false, Metadata.empty) + * .add("c", StringType, true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, metadata)) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int") + * .add("b", "long") + * .add("c", "string") + * }}} + */ + def add(name: String, dataType: String): StructType = { + add(name, DataTypeParser.parse(dataType), nullable = true, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int", true) + * .add("b", "long", false) + * .add("c", "string", true) + * }}} + */ + def add(name: String, dataType: String, nullable: Boolean): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata where the + * dataType is specified as a String. + * {{{ + * val struct = (new StructType) + * .add("a", "int", true, Metadata.empty) + * .add("b", "long", false, Metadata.empty) + * .add("c", "string", true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: String, + nullable: Boolean, + metadata: Metadata): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, metadata) + } + /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a name matching the given name, `null` will be returned. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 077c0ad70ac4f..14e7b4a9561b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -33,6 +33,37 @@ class DataTypeSuite extends SparkFunSuite { assert(MapType(StringType, IntegerType, true) === map) } + test("construct with add") { + val struct = (new StructType) + .add("a", IntegerType, true) + .add("b", LongType, false) + .add("c", StringType, true) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with add from StructField") { + // Test creation from StructField type + val struct = (new StructType) + .add(StructField("a", IntegerType, true)) + .add(StructField("b", LongType, false)) + .add(StructField("c", StringType, true)) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with String DataType") { + // Test creation with DataType as String + val struct = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + + assert(StructField("a", IntegerType, true) === struct("a")) + assert(StructField("b", LongType, false) === struct("b")) + assert(StructField("c", StringType, true) === struct("c")) + } + test("extract fields from a StructType") { val struct = StructType( StructField("a", IntegerType, true) :: From ecd3aacf2805bb231cfb44bab079319cfe73c3f1 Mon Sep 17 00:00:00 2001 From: Ai He Date: Mon, 29 Jun 2015 14:36:26 -0700 Subject: [PATCH 12/39] [SPARK-7810] [PYSPARK] solve python rdd socket connection problem Method "_load_from_socket" in rdd.py cannot load data from jvm socket when ipv6 is used. The current method only works well with ipv4. New modification should work around both two protocols. Author: Ai He Author: AiHe Closes #6338 from AiHe/pyspark-networking-issue and squashes the following commits: d4fc9c4 [Ai He] handle code review 2 e75c5c8 [Ai He] handle code review 5644953 [AiHe] solve python rdd socket connection problem to jvm --- python/pyspark/rdd.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1b64be23a667e..cb20bc8b54027 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -121,10 +121,22 @@ def _parse_memory(s): def _load_from_socket(port, serializer): - sock = socket.socket() - sock.settimeout(3) + sock = None + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(3) + sock.connect(sa) + except socket.error: + sock = None + continue + break + if not sock: + raise Exception("could not open socket") try: - sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) for item in serializer.load_stream(rf): yield item From c8ae887ef02b8f7e2ad06841719fb12eacf1f7f9 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Mon, 29 Jun 2015 14:45:08 -0700 Subject: [PATCH 13/39] [SPARK-8660][ML] Convert JavaDoc style comments inLogisticRegressionSuite.scala to regular multiline comments, to make copy-pasting R commands easier Converted JavaDoc style comments in mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala to regular multiline comments, to make copy-pasting R commands easier. Author: Rosstin Closes #7096 from Rosstin/SPARK-8660 and squashes the following commits: 242aedd [Rosstin] SPARK-8660, changed comment style from JavaDoc style to normal multiline comment in order to make copypaste into R easier, in file classification/LogisticRegressionSuite.scala 2cd2985 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 21ac1e5 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- .../LogisticRegressionSuite.scala | 342 +++++++++--------- 1 file changed, 171 insertions(+), 171 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 5a6265ea992c6..bc6eeac1db5da 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -36,19 +36,19 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) - /** - * Here is the instruction describing how to export the test data into CSV format - * so we can validate the training accuracy compared with R's glmnet package. - * - * import org.apache.spark.mllib.classification.LogisticRegressionSuite - * val nPoints = 10000 - * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) - * val xMean = Array(5.843, 3.057, 3.758, 1.199) - * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( - * weights, xMean, xVariance, true, nPoints, 42), 1) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " - * + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + /* + Here is the instruction describing how to export the test data into CSV format + so we can validate the training accuracy compared with R's glmnet package. + + import org.apache.spark.mllib.classification.LogisticRegressionSuite + val nPoints = 10000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 1) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") */ binaryDataset = { val nPoints = 10000 @@ -211,22 +211,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(true) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 2.8366423 - * data.V2 -0.5895848 - * data.V3 0.8931147 - * data.V4 -0.3925051 - * data.V5 -0.7996864 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.8366423 + data.V2 -0.5895848 + data.V3 0.8931147 + data.V4 -0.3925051 + data.V5 -0.7996864 */ val interceptR = 2.8366423 val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864) @@ -242,23 +242,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(false) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = - * coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.3534996 - * data.V3 1.2964482 - * data.V4 -0.3571741 - * data.V5 -0.7407946 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = + coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.3534996 + data.V3 1.2964482 + data.V4 -0.3571741 + data.V5 -0.7407946 */ val interceptR = 0.0 val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946) @@ -275,22 +275,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.05627428 - * data.V2 . - * data.V3 . - * data.V4 -0.04325749 - * data.V5 -0.02481551 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.05627428 + data.V2 . + data.V3 . + data.V4 -0.04325749 + data.V5 -0.02481551 */ val interceptR = -0.05627428 val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551) @@ -307,23 +307,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 . - * data.V3 . - * data.V4 -0.05189203 - * data.V5 -0.03891782 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.05189203 + data.V5 -0.03891782 */ val interceptR = 0.0 val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782) @@ -340,22 +340,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.15021751 - * data.V2 -0.07251837 - * data.V3 0.10724191 - * data.V4 -0.04865309 - * data.V5 -0.10062872 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.15021751 + data.V2 -0.07251837 + data.V3 0.10724191 + data.V4 -0.04865309 + data.V5 -0.10062872 */ val interceptR = 0.15021751 val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872) @@ -372,23 +372,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.06099165 - * data.V3 0.12857058 - * data.V4 -0.04708770 - * data.V5 -0.09799775 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.06099165 + data.V3 0.12857058 + data.V4 -0.04708770 + data.V5 -0.09799775 */ val interceptR = 0.0 val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775) @@ -405,22 +405,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.57734851 - * data.V2 -0.05310287 - * data.V3 . - * data.V4 -0.08849250 - * data.V5 -0.15458796 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.57734851 + data.V2 -0.05310287 + data.V3 . + data.V4 -0.08849250 + data.V5 -0.15458796 */ val interceptR = 0.57734851 val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796) @@ -437,23 +437,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.001005743 - * data.V3 0.072577857 - * data.V4 -0.081203769 - * data.V5 -0.142534158 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.001005743 + data.V3 0.072577857 + data.V4 -0.081203769 + data.V5 -0.142534158 */ val interceptR = 0.0 val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158) @@ -480,16 +480,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { classSummarizer1.merge(classSummarizer2) }).histogram - /** - * For binary logistic regression with strong L1 regularization, all the weights will be zeros. - * As a result, - * {{{ - * P(0) = 1 / (1 + \exp(b)), and - * P(1) = \exp(b) / (1 + \exp(b)) - * }}}, hence - * {{{ - * b = \log{P(1) / P(0)} = \log{count_1 / count_0} - * }}} + /* + For binary logistic regression with strong L1 regularization, all the weights will be zeros. + As a result, + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} */ val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) val weightsTheory = Array(0.0, 0.0, 0.0, 0.0) @@ -500,22 +500,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6) assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.2480643 - * data.V2 0.0000000 - * data.V3 . - * data.V4 . - * data.V5 . + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.2480643 + data.V2 0.0000000 + data.V3 . + data.V4 . + data.V5 . */ val interceptR = -0.248065 val weightsR = Array(0.0, 0.0, 0.0, 0.0) From 931da5c8ab271ff2ee04419c7e3c6b0012459694 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 29 Jun 2015 15:27:13 -0700 Subject: [PATCH 14/39] [SPARK-8478] [SQL] Harmonize UDF-related code to use uniformly UDF instead of Udf Follow-up of #6902 for being coherent between ```Udf``` and ```UDF``` Author: BenFradet Closes #6920 from BenFradet/SPARK-8478 and squashes the following commits: c500f29 [BenFradet] renamed a few variables in functions to use UDF 8ab0f2d [BenFradet] renamed idUdf to idUDF in SQLQuerySuite 98696c2 [BenFradet] renamed originalUdfs in TestHive to originalUDFs 7738f74 [BenFradet] modified HiveUDFSuite to use only UDF c52608d [BenFradet] renamed HiveUdfSuite to HiveUDFSuite e51b9ac [BenFradet] renamed ExtractPythonUdfs to ExtractPythonUDFs 8c756f1 [BenFradet] renamed Hive UDF related code 2a1ca76 [BenFradet] renamed pythonUdfs to pythonUDFs 261e6fb [BenFradet] renamed ScalaUdf to ScalaUDF --- .../{ScalaUdf.scala => ScalaUDF.scala} | 4 +- .../org/apache/spark/sql/SQLContext.scala | 4 +- .../apache/spark/sql/UDFRegistration.scala | 96 +++++++++--------- .../spark/sql/UserDefinedFunction.scala | 4 +- .../{pythonUdfs.scala => pythonUDFs.scala} | 2 +- .../org/apache/spark/sql/functions.scala | 34 +++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../hive/{hiveUdfs.scala => hiveUDFs.scala} | 26 ++--- .../apache/spark/sql/hive/test/TestHive.scala | 4 +- .../files/{testUdf => testUDF}/part-00000 | Bin ...{HiveUdfSuite.scala => HiveUDFSuite.scala} | 24 ++--- 13 files changed, 104 insertions(+), 104 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{ScalaUdf.scala => ScalaUDF.scala} (99%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{pythonUdfs.scala => pythonUDFs.scala} (99%) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{hiveUdfs.scala => hiveUDFs.scala} (96%) rename sql/hive/src/test/resources/data/files/{testUdf => testUDF}/part-00000 (100%) rename sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/{HiveUdfSuite.scala => HiveUDFSuite.scala} (93%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala similarity index 99% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 55df72f102295..dbb4381d54c4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.DataType * User-defined function. * @param dataType Return type of function. */ -case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) +case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { override def nullable: Boolean = true @@ -957,6 +957,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) - // TODO(davies): make ScalaUdf work with codegen + // TODO(davies): make ScalaUDF work with codegen override def isThreadSafe: Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8ed44ee141be5..fc14a77538ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -146,7 +146,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - ExtractPythonUdfs :: + ExtractPythonUDFs :: sources.PreInsertCastAndRename :: Nil @@ -257,7 +257,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * The following example registers a Scala closure as UDF: * {{{ - * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * sqlContext.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1) * }}} * * The following example registers a UDF in Java: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3cc5c2441d8a5..03dc37aa73f0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -95,7 +95,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) }""") @@ -114,7 +114,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { | functionRegistry.registerFunction( | name, - | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), returnType, e)) + | (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e)) |}""".stripMargin) } */ @@ -126,7 +126,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -138,7 +138,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -150,7 +150,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -162,7 +162,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -174,7 +174,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -186,7 +186,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -210,7 +210,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -222,7 +222,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -234,7 +234,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -246,7 +246,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -258,7 +258,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -270,7 +270,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -282,7 +282,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -294,7 +294,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -306,7 +306,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -318,7 +318,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -330,7 +330,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -342,7 +342,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -366,7 +366,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -378,7 +378,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -390,7 +390,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -405,7 +405,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF1[_, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) } /** @@ -415,7 +415,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) } /** @@ -425,7 +425,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) } /** @@ -435,7 +435,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -445,7 +445,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -455,7 +455,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -465,7 +465,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -475,7 +475,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -485,7 +485,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -495,7 +495,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -505,7 +505,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -515,7 +515,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -525,7 +525,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -535,7 +535,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -545,7 +545,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -555,7 +555,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -565,7 +565,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -575,7 +575,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -585,7 +585,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -595,7 +595,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -605,7 +605,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -615,7 +615,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index a02e202d2eebc..831eb7eb0fae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -23,7 +23,7 @@ import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) { def apply(exprs: Column*): Column = { - Column(ScalaUdf(f, dataType, exprs.map(_.expr))) + Column(ScalaUDF(f, dataType, exprs.map(_.expr))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 036f5d253e385..9e1cff06c7eea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -69,7 +69,7 @@ private[spark] case class PythonUDF( * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5422e066afcb1..4d9a019058228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1509,7 +1509,7 @@ object functions { (0 to 10).map { x => val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") println(s""" /** * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires @@ -1521,7 +1521,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, returnType, Seq($argsInUdf)) + ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } } @@ -1659,7 +1659,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUdf(f, returnType, Seq()) + ScalaUDF(f, returnType, Seq()) } /** @@ -1672,7 +1672,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr)) } /** @@ -1685,7 +1685,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** @@ -1698,7 +1698,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** @@ -1711,7 +1711,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** @@ -1724,7 +1724,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** @@ -1737,7 +1737,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** @@ -1750,7 +1750,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** @@ -1763,7 +1763,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** @@ -1776,7 +1776,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** @@ -1789,7 +1789,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } // scalastyle:on @@ -1802,8 +1802,8 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUDF("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUDF", $"value")) * }}} * * @group udf_funcs @@ -1821,8 +1821,8 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUDF", $"value")) * }}} * * @group udf_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 22c54e43c1d16..82dc0e9ce5132 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -140,9 +140,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUdf = udf(() => UUID.randomUUID().toString) + val idUDF = udf(() => UUID.randomUUID().toString) - val dfWithId = df.withColumn("id", idUdf()) + val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) val cached = dfWithId.cache() // Trigger the cache diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8021f915bb821..b91242af2d155 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.sources.DataSourceStrategy @@ -381,7 +381,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUdfs :: + ExtractPythonUDFs :: ResolveHiveWindowFunction :: sources.PreInsertCastAndRename :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7c4620952ba4b..2de7a99c122fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1638,7 +1638,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - (HiveGenericUdtf( + (HiveGenericUDTF( new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)), attributes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala similarity index 96% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4986b1ea9d906..d7827d56ca8c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -59,16 +59,16 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) + HiveUDAF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } @@ -79,7 +79,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) throw new UnsupportedOperationException } -private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = UDF @@ -146,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) override def get(): AnyRef = wrap(func(), oi) } -private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF @@ -413,7 +413,7 @@ private[hive] case class HiveWindowFunction( new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } -private[hive] case class HiveGenericUdaf( +private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -441,11 +441,11 @@ private[hive] case class HiveGenericUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUdaf( +private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -474,7 +474,7 @@ private[hive] case class HiveUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) } /** @@ -488,7 +488,7 @@ private[hive] case class HiveUdaf( * Operators that require maintaining state in between input rows should instead be implemented as * user defined aggregations, which have clean semantics even in a partitioned execution. */ -private[hive] case class HiveGenericUdtf( +private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors { @@ -553,7 +553,7 @@ private[hive] case class HiveGenericUdtf( } } -private[hive] case class HiveUdafFunction( +private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], base: AggregateExpression, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index ea325cc93cb85..7978fdacaedba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -391,7 +391,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * Records the UDFs present when the server starts, so we can delete ones that are created by * tests. */ - protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** * Resets the test instance by deleting any tables that have been created. @@ -410,7 +410,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => + FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUDF/part-00000 similarity index 100% rename from sql/hive/src/test/resources/data/files/testUdf/part-00000 rename to sql/hive/src/test/resources/data/files/testUDF/part-00000 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala similarity index 93% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index ce5985888f540..56b0bef1d0571 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -46,7 +46,7 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUdfSuite extends QueryTest { +class HiveUDFSuite extends QueryTest { import TestHive.{udf, sql} import TestHive.implicits._ @@ -73,7 +73,7 @@ class HiveUdfSuite extends QueryTest { test("hive struct udf") { sql( """ - |CREATE EXTERNAL TABLE hiveUdfTestTable ( + |CREATE EXTERNAL TABLE hiveUDFTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) @@ -82,15 +82,15 @@ class HiveUdfSuite extends QueryTest { """. stripMargin.format(classOf[PairSerDe].getName)) - val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile sql(s""" - ALTER TABLE hiveUdfTestTable - ADD IF NOT EXISTS PARTITION(partition='testUdf') + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') LOCATION '$location'""") - sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") - sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") } test("SPARK-6409 UDAFAverage test") { @@ -169,11 +169,11 @@ class HiveUdfSuite extends QueryTest { StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") TestHive.reset() } @@ -244,7 +244,7 @@ class PairSerDe extends AbstractSerDe { } } -class PairUdf extends GenericUDF { +class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), From ed359de595d5dd67b666660eddf092eaf89041c8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Jun 2015 15:59:20 -0700 Subject: [PATCH 15/39] [SPARK-8579] [SQL] support arbitrary object in UnsafeRow This PR brings arbitrary object support in UnsafeRow (both in grouping key and aggregation buffer). Two object pools will be created to hold those non-primitive objects, and put the index of them into UnsafeRow. In order to compare the grouping key as bytes, the objects in key will be stored in a unique object pool, to make sure same objects will have same index (used as hashCode). For StringType and BinaryType, we still put them as var-length in UnsafeRow when initializing for better performance. But for update, they will be an object inside object pools (there will be some garbages left in the buffer). BTW: Will create a JIRA once issue.apache.org is available. cc JoshRosen rxin Author: Davies Liu Closes #6959 from davies/unsafe_obj and squashes the following commits: 5ce39da [Davies Liu] fix comment 5e797bf [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 5803d64 [Davies Liu] fix conflict 461d304 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 2f41c90 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj b04d69c [Davies Liu] address comments 4859b80 [Davies Liu] fix comments f38011c [Davies Liu] add a test for grouping by decimal d2cf7ab [Davies Liu] add more tests for null checking 71983c5 [Davies Liu] add test for timestamp e8a1649 [Davies Liu] reuse buffer for string 39f09ca [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 035501e [Davies Liu] fix style 236d6de [Davies Liu] support arbitrary object in UnsafeRow --- .../UnsafeFixedWidthAggregationMap.java | 144 ++++++------ .../sql/catalyst/expressions/UnsafeRow.java | 218 +++++++++--------- .../spark/sql/catalyst/util/ObjectPool.java | 78 +++++++ .../sql/catalyst/util/UniqueObjectPool.java | 59 +++++ .../spark/sql/catalyst/InternalRow.scala | 5 +- .../expressions/UnsafeRowConverter.scala | 94 +++----- .../UnsafeFixedWidthAggregationMapSuite.scala | 65 ++++-- .../expressions/UnsafeRowConverterSuite.scala | 190 +++++++++++---- .../sql/catalyst/util/ObjectPoolSuite.scala | 57 +++++ .../sql/execution/GeneratedAggregate.scala | 16 +- 10 files changed, 615 insertions(+), 311 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 83f2a312972fb..1e79f4b2e88e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -19,9 +19,11 @@ import java.util.Iterator; +import scala.Function1; + import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.catalyst.util.UniqueObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -38,26 +40,48 @@ public final class UnsafeFixedWidthAggregationMap { * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the * map, we copy this buffer and use it as the value. */ - private final byte[] emptyAggregationBuffer; + private final byte[] emptyBuffer; - private final StructType aggregationBufferSchema; + /** + * An empty row used by `initProjection` + */ + private static final InternalRow emptyRow = new GenericInternalRow(); - private final StructType groupingKeySchema; + /** + * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not. + */ + private final boolean reuseEmptyBuffer; /** - * Encodes grouping keys as UnsafeRows. + * The projection used to initialize the emptyBuffer */ - private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + private final Function1 initProjection; + + /** + * Encodes grouping keys or buffers as UnsafeRows. + */ + private final UnsafeRowConverter keyConverter; + private final UnsafeRowConverter bufferConverter; /** * A hashmap which maps from opaque bytearray keys to bytearray values. */ private final BytesToBytesMap map; + /** + * An object pool for objects that are used in grouping keys. + */ + private final UniqueObjectPool keyPool; + + /** + * An object pool for objects that are used in aggregation buffers. + */ + private final ObjectPool bufferPool; + /** * Re-used pointer to the current aggregation buffer */ - private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + private final UnsafeRow currentBuffer = new UnsafeRow(); /** * Scratch space that is used when encoding grouping keys into UnsafeRow format. @@ -69,68 +93,39 @@ public final class UnsafeFixedWidthAggregationMap { private final boolean enablePerfMetrics; - /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. - */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - - /** - * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given - * schema, false otherwise. - */ - public static boolean supportsAggregationBufferSchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - /** * Create a new UnsafeFixedWidthAggregationMap. * - * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) - * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. - * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param initProjection the default value for new keys (a "zero" of the agg. function) + * @param keyConverter the converter of the grouping key, used for row conversion. + * @param bufferConverter the converter of the aggregation buffer, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( - InternalRow emptyAggregationBuffer, - StructType aggregationBufferSchema, - StructType groupingKeySchema, + Function1 initProjection, + UnsafeRowConverter keyConverter, + UnsafeRowConverter bufferConverter, TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) { - this.emptyAggregationBuffer = - convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); - this.aggregationBufferSchema = aggregationBufferSchema; - this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); - this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.initProjection = initProjection; + this.keyConverter = keyConverter; + this.bufferConverter = bufferConverter; this.enablePerfMetrics = enablePerfMetrics; - } - /** - * Convert a Java object row into an UnsafeRow, allocating it into a new byte array. - */ - private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) { - final UnsafeRowConverter converter = new UnsafeRowConverter(schema); - final byte[] unsafeRow = new byte[converter.getSizeRequirement(javaRow)]; - final int writtenLength = - converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET); - assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; - return unsafeRow; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.keyPool = new UniqueObjectPool(100); + this.bufferPool = new ObjectPool(initialCapacity); + + InternalRow initRow = initProjection.apply(emptyRow); + this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; + int writtenLength = bufferConverter.writeRow( + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; + // re-use the empty buffer only when there is no object saved in pool. + reuseEmptyBuffer = bufferPool.size() == 0; } /** @@ -138,15 +133,16 @@ private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) * return the same object. */ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { - final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey); // Make sure that the buffer is large enough to hold the key. If it's not, grow it: if (groupingKeySize > groupingKeyConversionScratchSpace.length) { groupingKeyConversionScratchSpace = new byte[groupingKeySize]; } - final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + final int actualGroupingKeySize = keyConverter.writeRow( groupingKey, groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET); + PlatformDependent.BYTE_ARRAY_OFFSET, + keyPool); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; // Probe our map using the serialized key @@ -157,25 +153,31 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: + if (!reuseEmptyBuffer) { + // There is some objects referenced by emptyBuffer, so generate a new one + InternalRow initRow = initProjection.apply(emptyRow); + bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, + bufferPool); + } loc.putNewKey( groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize, - emptyAggregationBuffer, + emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - emptyAggregationBuffer.length + emptyBuffer.length ); } // Reset the pointer to point to the value that we just stored or looked up: final MemoryLocation address = loc.getValueAddress(); - currentAggregationBuffer.pointTo( + currentBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); - return currentAggregationBuffer; + return currentBuffer; } /** @@ -211,14 +213,14 @@ public MapEntry next() { entry.key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - groupingKeySchema.length(), - groupingKeySchema + keyConverter.numFields(), + keyPool ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); return entry; } @@ -246,6 +248,8 @@ public void printPerfMetrics() { System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + System.out.println("Number of unique objects in keys: " + keyPool.size()); + System.out.println("Number of objects in buffers: " + bufferPool.size()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 11d51d90f1802..f077064a02ec0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,20 +17,12 @@ package org.apache.spark.sql.catalyst.expressions; -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -44,7 +36,20 @@ * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the * base address of the row) that points to the beginning of the variable-length field, and length - * (they are combined into a long). + * (they are combined into a long). For other objects, they are stored in a pool, the indexes of + * them are hold in the the word. + * + * In order to support fast hashing and equality checks for UnsafeRows that contain objects + * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make + * sure all the key have the same index for same object, then we can hash/compare the objects by + * hash/compare the index. + * + * For non-primitive types, the word of a field could be: + * UNION { + * [1] [offset: 31bits] [length: 31bits] // StringType + * [0] [offset: 31bits] [length: 31bits] // BinaryType + * - [index: 63bits] // StringType, Binary, index to object in pool + * } * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ @@ -53,8 +58,12 @@ public final class UnsafeRow extends MutableRow { private Object baseObject; private long baseOffset; + /** A pool to hold non-primitive objects */ + private ObjectPool pool; + Object getBaseObject() { return baseObject; } long getBaseOffset() { return baseOffset; } + ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; @@ -63,15 +72,6 @@ public final class UnsafeRow extends MutableRow { /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; - /** - * This optional schema is required if you want to call generic get() and set() methods on - * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() - * methods. This should be removed after the planned InternalRow / Row split; right now, it's only - * needed by the generic get() method, which is only called internally by code that accesses - * UTF8String-typed columns. - */ - @Nullable - private StructType schema; private long getFieldOffset(int ordinal) { return baseOffset + bitSetWidthInBytes + ordinal * 8L; @@ -81,42 +81,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; } - /** - * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) - */ - public static final Set settableFieldTypes; - - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). - */ - public static final Set readableFieldTypes; - - // TODO: support DecimalType - static { - settableFieldTypes = Collections.unmodifiableSet( - new HashSet( - Arrays.asList(new DataType[] { - NullType, - BooleanType, - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType, - DateType, - TimestampType - }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet( - Arrays.asList(new DataType[]{ - StringType, - BinaryType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); - } + public static final long OFFSET_BITS = 31L; /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, @@ -130,22 +95,15 @@ public UnsafeRow() { } * @param baseObject the base object * @param baseOffset the offset within the base object * @param numFields the number of fields in this row - * @param schema an optional schema; this is necessary if you want to call generic get() or set() - * methods on this row, but is optional if the caller will only use type-specific - * getTYPE() and setTYPE() methods. + * @param pool the object pool to hold arbitrary objects */ - public void pointTo( - Object baseObject, - long baseOffset, - int numFields, - @Nullable StructType schema) { + public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) { assert numFields >= 0 : "numFields should >= 0"; - assert schema == null || schema.fields().length == numFields; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; - this.schema = schema; + this.pool = pool; } private void assertIndexIsValid(int index) { @@ -168,9 +126,68 @@ private void setNotNullAt(int i) { BitSetMethods.unset(baseObject, baseOffset, i); } + /** + * Updates the column `i` as Object `value`, which cannot be primitive types. + */ @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); + public void update(int i, Object value) { + if (value == null) { + if (!isNullAt(i)) { + // remove the old value from pool + long idx = getLong(i); + if (idx <= 0) { + // this is the index of old value in pool, remove it + pool.replace((int)-idx, null); + } else { + // there will be some garbage left (UTF8String or byte[]) + } + setNullAt(i); + } + return; + } + + if (isNullAt(i)) { + // there is not an old value, put the new value into pool + int idx = pool.put(value); + setLong(i, (long)-idx); + } else { + // there is an old value, check the type, then replace it or update it + long v = getLong(i); + if (v <= 0) { + // it's the index in the pool, replace old value with new one + int idx = (int)-v; + pool.replace(idx, value); + } else { + // old value is UTF8String or byte[], try to reuse the space + boolean isString; + byte[] newBytes; + if (value instanceof UTF8String) { + newBytes = ((UTF8String) value).getBytes(); + isString = true; + } else { + newBytes = (byte[]) value; + isString = false; + } + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int oldLength = (int) (v & Integer.MAX_VALUE); + if (newBytes.length <= oldLength) { + // the new value can fit in the old buffer, re-use it + PlatformDependent.copyMemory( + newBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + offset, + newBytes.length); + long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L; + setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length); + } else { + // Cannot fit in the buffer + int idx = pool.put(value); + setLong(i, (long) -idx); + } + } + } + setNotNullAt(i); } @Override @@ -227,28 +244,38 @@ public int size() { return numFields; } - @Override - public StructType schema() { - return schema; - } - + /** + * Returns the object for column `i`, which should not be primitive type. + */ @Override public Object get(int i) { assertIndexIsValid(i); - assert (schema != null) : "Schema must be defined when calling generic get() method"; - final DataType dataType = schema.fields()[i].dataType(); - // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic - // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to - // separate the internal and external row interfaces, then internal code can fetch strings via - // a new getUTF8String() method and we'll be able to remove this method. if (isNullAt(i)) { return null; - } else if (dataType == StringType) { - return getUTF8String(i); - } else if (dataType == BinaryType) { - return getBinary(i); + } + long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + if (v <= 0) { + // It's an index to object in the pool. + int idx = (int)-v; + return pool.get(idx); } else { - throw new UnsupportedOperationException(); + // The column could be StingType or BinaryType + boolean isString = (v >> (OFFSET_BITS * 2)) > 0; + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int size = (int) (v & Integer.MAX_VALUE); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size + ); + if (isString) { + return UTF8String.fromBytes(bytes); + } else { + return bytes; + } } } @@ -308,31 +335,6 @@ public double getDouble(int i) { } } - public UTF8String getUTF8String(int i) { - return UTF8String.fromBytes(getBinary(i)); - } - - public byte[] getBinary(int i) { - assertIndexIsValid(i); - final long offsetAndSize = getLong(i); - final int offset = (int)(offsetAndSize >> 32); - final int size = (int)(offsetAndSize & ((1L << 32) - 1)); - final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size - ); - return bytes; - } - - @Override - public String getString(int i) { - return getUTF8String(i).toString(); - } - @Override public InternalRow copy() { throw new UnsupportedOperationException(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java new file mode 100644 index 0000000000000..97f89a7d0b758 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +/** + * A object pool stores a collection of objects in array, then they can be referenced by the + * pool plus an index. + */ +public class ObjectPool { + + /** + * An array to hold objects, which will grow as needed. + */ + private Object[] objects; + + /** + * How many objects in the pool. + */ + private int numObj; + + public ObjectPool(int capacity) { + objects = new Object[capacity]; + numObj = 0; + } + + /** + * Returns how many objects in the pool. + */ + public int size() { + return numObj; + } + + /** + * Returns the object at position `idx` in the array. + */ + public Object get(int idx) { + assert (idx < numObj); + return objects[idx]; + } + + /** + * Puts an object `obj` at the end of array, returns the index of it. + *

+ * The array will grow as needed. + */ + public int put(Object obj) { + if (numObj >= objects.length) { + Object[] tmp = new Object[objects.length * 2]; + System.arraycopy(objects, 0, tmp, 0, objects.length); + objects = tmp; + } + objects[numObj++] = obj; + return numObj - 1; + } + + /** + * Replaces the object at `idx` with new one `obj`. + */ + public void replace(int idx, Object obj) { + assert (idx < numObj); + objects[idx] = obj; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java new file mode 100644 index 0000000000000..d512392dcaacc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +import java.util.HashMap; + +/** + * An unique object pool stores a collection of unique objects in it. + */ +public class UniqueObjectPool extends ObjectPool { + + /** + * A hash map from objects to their indexes in the array. + */ + private HashMap objIndex; + + public UniqueObjectPool(int capacity) { + super(capacity); + objIndex = new HashMap(); + } + + /** + * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will + * return the index of the existing one. + */ + @Override + public int put(Object obj) { + if (objIndex.containsKey(obj)) { + return objIndex.get(obj); + } else { + int idx = super.put(obj); + objIndex.put(obj, idx); + return idx; + } + } + + /** + * The objects can not be replaced. + */ + @Override + public void replace(int idx, Object obj) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 61a29c89d8df3..57de0f26a9720 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -28,7 +28,10 @@ import org.apache.spark.unsafe.types.UTF8String abstract class InternalRow extends Row { // This is only use for test - override def getString(i: Int): String = getAs[UTF8String](i).toString + override def getString(i: Int): String = { + val str = getAs[UTF8String](i) + if (str != null) str.toString else null + } // These expensive API should not be used internally. final override def getDecimal(i: Int): java.math.BigDecimal = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index b61d490429e4f..b11fc245c4af9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -33,6 +34,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { this(schema.fields.map(_.dataType)) } + def numFields: Int = fieldTypes.length + /** Re-used pointer to the unsafe row being written */ private[this] val unsafeRow = new UnsafeRow() @@ -68,8 +71,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { * @param baseOffset the base offset of the destination address * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. */ - def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool) if (writers.length > 0) { // zero-out the bitset @@ -84,16 +87,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { } var fieldNumber = 0 - var appendCursor: Int = fixedLengthSize + var cursor: Int = fixedLengthSize while (fieldNumber < writers.length) { if (row.isNullAt(fieldNumber)) { unsafeRow.setNullAt(fieldNumber) } else { - appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor) } fieldNumber += 1 } - appendCursor + cursor } } @@ -108,11 +111,11 @@ private abstract class UnsafeColumnWriter { * @param source the row being converted * @param target a pointer to the converted unsafe row * @param column the column to write - * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * @param cursor the offset from the start of the unsafe row to the end of the row; * used for calculating where variable-length data should be written * @return the number of variable-length bytes written */ - def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int + def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int /** * Return the number of bytes that are needed to write this variable-length value. @@ -134,8 +137,7 @@ private object UnsafeColumnWriter { case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter case BinaryType => BinaryUnsafeColumnWriter - case t => - throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + case t => ObjectUnsafeColumnWriter } } } @@ -152,6 +154,7 @@ private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter +private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: @@ -159,88 +162,56 @@ private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { } private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setNullAt(column) 0 } } private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setBoolean(column, source.getBoolean(column)) 0 } } private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setByte(column, source.getByte(column)) 0 } } private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setShort(column, source.getShort(column)) 0 } } private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setInt(column, source.getInt(column)) 0 } } private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setLong(column, source.getLong(column)) 0 } } private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setFloat(column, source.getFloat(column)) 0 } } private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setDouble(column, source.getDouble(column)) 0 } @@ -255,12 +226,10 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { - val offset = target.getBaseOffset + appendCursor + protected[this] def isString: Boolean + + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val offset = target.getBaseOffset + cursor val bytes = getBytes(source, column) val numBytes = bytes.length if ((numBytes & 0x07) > 0) { @@ -274,19 +243,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { offset, numBytes ) - target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong) + val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0 + target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong) ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } } private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = true def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[UTF8String](column).getBytes } } private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = false def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[Array[Byte]](column) } } + +private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(sourceRow: InternalRow, column: Int): Int = 0 + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val obj = source.get(column) + val idx = target.getPool.put(obj) + target.setLong(column, - idx) + 0 + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 3095ccb77761b..6fafc2f86684c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,8 +23,9 @@ import scala.util.Random import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -33,10 +34,10 @@ class UnsafeFixedWidthAggregationMapSuite with Matchers with BeforeAndAfterEach { - import UnsafeFixedWidthAggregationMap._ - private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyProjection: Projection = + GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)())) private def emptyAggregationBuffer: InternalRow = InternalRow(0) private var memoryManager: TaskMemoryManager = null @@ -52,21 +53,11 @@ class UnsafeFixedWidthAggregationMapSuite } } - test("supported schemas") { - assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - - assert( - !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - } - test("empty map") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics @@ -77,9 +68,9 @@ class UnsafeFixedWidthAggregationMapSuite test("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics @@ -103,9 +94,9 @@ class UnsafeFixedWidthAggregationMapSuite test("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 128, // initial capacity false // disable perf metrics @@ -120,6 +111,36 @@ class UnsafeFixedWidthAggregationMapSuite }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) + + map.free() + } + + test("with decimal in the key and values") { + val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil) + val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil) + val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))), + Seq(AttributeReference("price", DecimalType.Unlimited)())) + val map = new UnsafeFixedWidthAggregationMap( + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), + memoryManager, + 1, // initial capacity + false // disable perf metrics + ) + + (0 until 100).foreach { i => + val groupKey = InternalRow(Decimal(i % 10)) + val row = map.getAggregationBuffer(groupKey) + row.update(0, Decimal(i)) + } + val seenKeys: Set[Int] = map.iterator().asScala.map { entry => + entry.key.getAs[Decimal](0).toInt + }.toSet + seenKeys.size should be (10) + seenKeys should be ((0 until 10).toSet) + + map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c0675f4f4dff6..94c2f3242b122 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -23,10 +23,11 @@ import java.util.Arrays import org.scalatest.Matchers import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.UTF8String class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { @@ -40,16 +41,21 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.setInt(2, 2) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (3 * 8)) + assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getLong(1) should be (1) - unsafeRow.getInt(2) should be (2) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getLong(1) === 1) + assert(unsafeRow.getInt(2) === 2) + + unsafeRow.setLong(1, 3) + assert(unsafeRow.getLong(1) === 3) + unsafeRow.setInt(2, 4) + assert(unsafeRow.getInt(2) === 4) } test("basic conversion with primitive, string and binary types") { @@ -58,22 +64,67 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) - row.setString(1, "Hello") + row.update(1, UTF8String.fromString("Hello")) row.update(2, "World".getBytes) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (8 * 3) + + assert(sizeRequired === 8 + (8 * 3) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getString(1) should be ("Hello") - unsafeRow.getBinary(2) should be ("World".getBytes) + val pool = new ObjectPool(10) + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") + assert(unsafeRow.get(2) === "World".getBytes) + + unsafeRow.update(1, UTF8String.fromString("World")) + assert(unsafeRow.getString(1) === "World") + assert(pool.size === 0) + unsafeRow.update(1, UTF8String.fromString("Hello World")) + assert(unsafeRow.getString(1) === "Hello World") + assert(pool.size === 1) + + unsafeRow.update(2, "World".getBytes) + assert(unsafeRow.get(2) === "World".getBytes) + assert(pool.size === 1) + unsafeRow.update(2, "Hello World".getBytes) + assert(unsafeRow.get(2) === "Hello World".getBytes) + assert(pool.size === 2) + } + + test("basic conversion with primitive, decimal and array") { + val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType)) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.update(1, Decimal(1)) + row.update(2, Array(2)) + + val pool = new ObjectPool(10) + val sizeRequired: Int = converter.getSizeRequirement(row) + assert(sizeRequired === 8 + (8 * 3)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + assert(numBytesWritten === sizeRequired) + assert(pool.size === 2) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.get(1) === Decimal(1)) + assert(unsafeRow.get(2) === Array(2)) + + unsafeRow.update(1, Decimal(2)) + assert(unsafeRow.get(1) === Decimal(2)) + unsafeRow.update(2, Array(3, 4)) + assert(unsafeRow.get(2) === Array(3, 4)) + assert(pool.size === 2) } test("basic conversion with primitive, string, date and timestamp types") { @@ -87,21 +138,27 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (8 * 4) + + assert(sizeRequired === 8 + (8 * 4) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getString(1) should be ("Hello") + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow - DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01")) + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01")) // Timestamp is represented as Long in unsafeRow DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be (Timestamp.valueOf("2015-05-08 08:10:25")) + + unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22")) + unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25"))) + DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be + (Timestamp.valueOf("2015-06-22 08:10:25")) } test("null handling") { @@ -113,7 +170,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { IntegerType, LongType, FloatType, - DoubleType) + DoubleType, + StringType, + BinaryType, + DecimalType.Unlimited, + ArrayType(IntegerType) + ) val converter = new UnsafeRowConverter(fieldTypes) val rowWithAllNullColumns: InternalRow = { @@ -127,8 +189,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( @@ -136,13 +198,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { for (i <- 0 to fieldTypes.length - 1) { assert(createdFromNull.isNullAt(i)) } - createdFromNull.getBoolean(1) should be (false) - createdFromNull.getByte(2) should be (0) - createdFromNull.getShort(3) should be (0) - createdFromNull.getInt(4) should be (0) - createdFromNull.getLong(5) should be (0) + assert(createdFromNull.getBoolean(1) === false) + assert(createdFromNull.getByte(2) === 0) + assert(createdFromNull.getShort(3) === 0) + assert(createdFromNull.getInt(4) === 0) + assert(createdFromNull.getLong(5) === 0) assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getString(8) === null) + assert(createdFromNull.get(9) === null) + assert(createdFromNull.get(10) === null) + assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those // columns, then the serialized row representation should be identical to what we would get by @@ -157,28 +223,68 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setLong(5, 500) r.setFloat(6, 600) r.setDouble(7, 700) + r.update(8, UTF8String.fromString("hello")) + r.update(9, "world".getBytes) + r.update(10, Decimal(10)) + r.update(11, Array(11)) r } - val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val pool = new ObjectPool(1) + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) - setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) - setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) - setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) - setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) - setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) - setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) - setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) - setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- 0 to fieldTypes.length - 1) { + if (i >= 8) { + setToNullAfterCreation.update(i, null) + } setToNullAfterCreation.setNullAt(i) } - assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + // There are some garbage left in the var-length area + assert(Arrays.equals(createdFromNullBuffer, + java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8))) + + setToNullAfterCreation.setNullAt(0) + setToNullAfterCreation.setBoolean(1, false) + setToNullAfterCreation.setByte(2, 20) + setToNullAfterCreation.setShort(3, 30) + setToNullAfterCreation.setInt(4, 400) + setToNullAfterCreation.setLong(5, 500) + setToNullAfterCreation.setFloat(6, 600) + setToNullAfterCreation.setDouble(7, 700) + setToNullAfterCreation.update(8, UTF8String.fromString("hello")) + setToNullAfterCreation.update(9, "world".getBytes) + setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.update(11, Array(11)) + + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala new file mode 100644 index 0000000000000..94764df4b9cdb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class ObjectPoolSuite extends SparkFunSuite with Matchers { + + test("pool") { + val pool = new ObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(false) === 2) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.get(2) === false) + assert(pool.size() === 3) + + pool.replace(1, "world") + assert(pool.get(1) === "world") + assert(pool.size() === 3) + } + + test("unique pool") { + val pool = new UniqueObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.size() === 2) + + intercept[UnsupportedOperationException] { + pool.replace(1, "world") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index ba2c8f53d702d..44930f82b53a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -238,11 +238,6 @@ case class GeneratedAggregate( StructType(fields) } - val schemaSupportsUnsafe: Boolean = { - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) - } - child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -283,12 +278,12 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) - } else if (unsafeEnabled && schemaSupportsUnsafe) { + } else if (unsafeEnabled) { log.info("Using Unsafe-based aggregator") val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer(EmptyRow), - aggregationBufferSchema, - groupKeySchema, + newAggregationBuffer, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggregationBufferSchema), TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity false // disable tracking of performance metrics @@ -323,9 +318,6 @@ case class GeneratedAggregate( } } } else { - if (unsafeEnabled) { - log.info("Not using Unsafe-based aggregator because it is not supported for this schema") - } val buffers = new java.util.HashMap[InternalRow, MutableRow]() var currentRow: InternalRow = null From 4e880cf5967c0933e1d098a1d1f7db34b23ca8f8 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Mon, 29 Jun 2015 16:09:29 -0700 Subject: [PATCH 16/39] [SPARK-8661][ML] for LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments, to make copy-pasting R code more simple for mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments, to make copy-pasting R code more simple Author: Rosstin Closes #7098 from Rosstin/SPARK-8661 and squashes the following commits: 5a05dee [Rosstin] SPARK-8661 for LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments to make it easier to copy-paste the R code. bb9a4b1 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8660 242aedd [Rosstin] SPARK-8660, changed comment style from JavaDoc style to normal multiline comment in order to make copypaste into R easier, in file classification/LogisticRegressionSuite.scala 2cd2985 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 21ac1e5 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- .../ml/regression/LinearRegressionSuite.scala | 192 +++++++++--------- 1 file changed, 96 insertions(+), 96 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index ad1e9da692ee2..5f39d44f37352 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -28,26 +28,26 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var datasetWithoutIntercept: DataFrame = _ - /** - * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML - * is the same as the one trained by R's glmnet package. The following instruction - * describes how to reproduce the data in R. - * - * import org.apache.spark.mllib.util.LinearDataGenerator - * val data = - * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), - * Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) - * .saveAsTextFile("path") + /* + In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + is the same as the one trained by R's glmnet package. The following instruction + describes how to reproduce the data in R. + + import org.apache.spark.mllib.util.LinearDataGenerator + val data = + sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), + Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) + .saveAsTextFile("path") */ override def beforeAll(): Unit = { super.beforeAll() dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) - /** - * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating - * training model without intercept + /* + datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating + training model without intercept */ datasetWithoutIntercept = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -59,20 +59,20 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = new LinearRegression val model = trainer.fit(dataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * library("glmnet") - * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) - * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) - * label <- as.numeric(data$V1) - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.300528 - * as.numeric.data.V2. 4.701024 - * as.numeric.data.V3. 7.198257 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + label <- as.numeric(data$V1) + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.300528 + as.numeric.data.V2. 4.701024 + as.numeric.data.V3. 7.198257 */ val interceptR = 6.298698 val weightsR = Array(4.700706, 7.199082) @@ -94,29 +94,29 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model = trainer.fit(dataset) val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, - * intercept = FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 6.995908 - * as.numeric.data.V3. 5.275131 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.995908 + as.numeric.data.V3. 5.275131 */ val weightsR = Array(6.995908, 5.275131) assert(model.intercept ~== 0 relTol 1E-3) assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights(1) ~== weightsR(1) relTol 1E-3) - /** - * Then again with the data with no intercept: - * > weightsWithoutIntercept - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data3.V2. 4.70011 - * as.numeric.data3.V3. 7.19943 + /* + Then again with the data with no intercept: + > weightsWithoutIntercept + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data3.V2. 4.70011 + as.numeric.data3.V3. 7.19943 */ val weightsWithoutInterceptR = Array(4.70011, 7.19943) @@ -129,14 +129,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.24300 - * as.numeric.data.V2. 4.024821 - * as.numeric.data.V3. 6.679841 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.24300 + as.numeric.data.V2. 4.024821 + as.numeric.data.V3. 6.679841 */ val interceptR = 6.24300 val weightsR = Array(4.024821, 6.679841) @@ -158,15 +158,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - * intercept=FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 6.299752 - * as.numeric.data.V3. 4.772913 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.299752 + as.numeric.data.V3. 4.772913 */ val interceptR = 0.0 val weightsR = Array(6.299752, 4.772913) @@ -187,14 +187,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.328062 - * as.numeric.data.V2. 3.222034 - * as.numeric.data.V3. 4.926260 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.328062 + as.numeric.data.V2. 3.222034 + as.numeric.data.V3. 4.926260 */ val interceptR = 5.269376 val weightsR = Array(3.736216, 5.712356) @@ -216,15 +216,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, - * intercept = FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 5.522875 - * as.numeric.data.V3. 4.214502 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.522875 + as.numeric.data.V3. 4.214502 */ val interceptR = 0.0 val weightsR = Array(5.522875, 4.214502) @@ -245,14 +245,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.324108 - * as.numeric.data.V2. 3.168435 - * as.numeric.data.V3. 5.200403 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.324108 + as.numeric.data.V2. 3.168435 + as.numeric.data.V3. 5.200403 */ val interceptR = 5.696056 val weightsR = Array(3.670489, 6.001122) @@ -274,15 +274,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, - * intercept=FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.dataM.V2. 5.673348 - * as.numeric.dataM.V3. 4.322251 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.dataM.V2. 5.673348 + as.numeric.dataM.V3. 4.322251 */ val interceptR = 0.0 val weightsR = Array(5.673348, 4.322251) From 4b497a724a87ef24702c2df9ec6863ee57a87c1c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 29 Jun 2015 16:26:05 -0700 Subject: [PATCH 17/39] [SPARK-8710] [SQL] Change ScalaReflection.mirror from a val to a def. jira: https://issues.apache.org/jira/browse/SPARK-8710 Author: Yin Huai Closes #7094 from yhuai/SPARK-8710 and squashes the following commits: c854baa [Yin Huai] Change ScalaReflection.mirror from a val to a def. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 90698cd572de4..21b1de1ab9cb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -28,7 +28,11 @@ import org.apache.spark.sql.types._ */ object ScalaReflection extends ScalaReflection { val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) + // Since we are creating a runtime mirror usign the class loader of current thread, + // we need to use def at here. So, every time we call mirror, it is using the + // class loader of the current thread. + override def mirror: universe.Mirror = + universe.runtimeMirror(Thread.currentThread().getContextClassLoader) } /** @@ -39,7 +43,7 @@ trait ScalaReflection { val universe: scala.reflect.api.Universe /** The mirror used to access types in the universe */ - val mirror: universe.Mirror + def mirror: universe.Mirror import universe._ From 881662e9c93893430756320f51cef0fc6643f681 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 29 Jun 2015 16:34:50 -0700 Subject: [PATCH 18/39] [SPARK-8589] [SQL] cleanup DateTimeUtils move date time related operations into `DateTimeUtils` and rename some methods to make it more clear. Author: Wenchen Fan Closes #6980 from cloud-fan/datetime and squashes the following commits: 9373a9d [Wenchen Fan] cleanup DateTimeUtil --- .../spark/sql/catalyst/expressions/Cast.scala | 43 ++---------- .../sql/catalyst/util/DateTimeUtils.scala | 70 +++++++++++++------ .../spark/sql/hive/hiveWriterContainers.scala | 2 +- 3 files changed, 58 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8d66968a2fc35..d69d490ad666a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} import java.sql.{Date, Timestamp} -import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -122,9 +121,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes) - case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.toString(d))) + case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, - t => UTF8String.fromString(timestampToString(DateTimeUtils.toJavaTimestamp(t)))) + t => UTF8String.fromString(DateTimeUtils.timestampToString(t))) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -183,7 +182,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => DateTimeUtils.toMillisSinceEpoch(d) * 10000) + buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 10000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -216,18 +215,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w ts / 10000000.0 } - // Converts Timestamp to string according to Hive TimestampWritable convention - private[this] def timestampToString(ts: Timestamp): String = { - val timestampString = ts.toString - val formatted = Cast.threadLocalTimestampFormat.get.format(ts) - - if (timestampString.length > 19 && timestampString.substring(19) != ".0") { - formatted + timestampString.substring(19) - } else { - formatted - } - } - // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => @@ -449,11 +436,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (DateType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.toString($c))""") - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. + org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") case (TimestampType, StringType) => - super.genCode(ctx, ev) + defineCodeGen(ctx, ev, c => + s"""${ctx.stringType}.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") case (_, StringType) => defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") @@ -477,19 +464,3 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } } - -object Cast { - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - } - } - - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd") - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index ff79884a44d00..640e67e2ecd76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat +import java.text.{DateFormat, SimpleDateFormat} import java.util.{Calendar, TimeZone} -import org.apache.spark.sql.catalyst.expressions.Cast - /** * Helper functions for converting between internal and external date and time representations. * Dates are exposed externally as java.sql.Date and are represented internally as the number of @@ -41,35 +39,53 @@ object DateTimeUtils { // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. - private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { + private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { override protected def initialValue: TimeZone = { Calendar.getInstance.getTimeZone } } - private def javaDateToDays(d: Date): Int = { - millisToDays(d.getTime) + // `SimpleDateFormat` is not thread-safe. + private val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + } + } + + // `SimpleDateFormat` is not thread-safe. + private val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd") + } } + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { - ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + ((millisLocal + threadLocalLocalTimeZone.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } - def toMillisSinceEpoch(days: Int): Long = { + // reverse of millisToDays + def daysToMillis(days: Int): Long = { val millisUtc = days.toLong * MILLIS_PER_DAY - millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) + millisUtc - threadLocalLocalTimeZone.get().getOffset(millisUtc) } - def fromJavaDate(date: Date): Int = { - javaDateToDays(date) - } + def dateToString(days: Int): String = + threadLocalDateFormat.get.format(toJavaDate(days)) - def toJavaDate(daysSinceEpoch: Int): Date = { - new Date(toMillisSinceEpoch(daysSinceEpoch)) - } + // Converts Timestamp to string according to Hive TimestampWritable convention. + def timestampToString(num100ns: Long): String = { + val ts = toJavaTimestamp(num100ns) + val timestampString = ts.toString + val formatted = threadLocalTimestampFormat.get.format(ts) - def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) + if (timestampString.length > 19 && timestampString.substring(19) != ".0") { + formatted + timestampString.substring(19) + } else { + formatted + } + } def stringToTime(s: String): java.util.Date = { if (!s.contains('T')) { @@ -100,7 +116,21 @@ object DateTimeUtils { } /** - * Return a java.sql.Timestamp from number of 100ns since epoch + * Returns the number of days since epoch from from java.sql.Date. + */ + def fromJavaDate(date: Date): Int = { + millisToDays(date.getTime) + } + + /** + * Returns a java.sql.Date from number of days since epoch. + */ + def toJavaDate(daysSinceEpoch: Int): Date = { + new Date(daysToMillis(daysSinceEpoch)) + } + + /** + * Returns a java.sql.Timestamp from number of 100ns since epoch. */ def toJavaTimestamp(num100ns: Long): Timestamp = { // setNanos() will overwrite the millisecond part, so the milliseconds should be @@ -118,7 +148,7 @@ object DateTimeUtils { } /** - * Return the number of 100ns since epoch from java.sql.Timestamp. + * Returns the number of 100ns since epoch from java.sql.Timestamp. */ def fromJavaTimestamp(t: Timestamp): Long = { if (t != null) { @@ -129,7 +159,7 @@ object DateTimeUtils { } /** - * Return the number of 100ns (hundred of nanoseconds) since epoch from Julian day + * Returns the number of 100ns (hundred of nanoseconds) since epoch from Julian day * and nanoseconds in a day */ def fromJulianDay(day: Int, nanoseconds: Long): Long = { @@ -139,7 +169,7 @@ object DateTimeUtils { } /** - * Return Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) + * Returns Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) */ def toJulianDay(num100ns: Long): (Int, Long) = { val seconds = num100ns / HUNDRED_NANOS_PER_SECOND + SECONDS_PER_DAY / 2 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ab75b12e2a2e7..ecc78a5f8d321 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -201,7 +201,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { - case DateType => DateTimeUtils.toString(raw.toInt) + case DateType => DateTimeUtils.dateToString(raw.toInt) case _: DecimalType => BigDecimal(raw).toString() case _ => raw } From cec98525fd2b731cb78935bf7bc6c7963411744e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 29 Jun 2015 17:19:05 -0700 Subject: [PATCH 19/39] [SPARK-8634] [STREAMING] [TESTS] Fix flaky test StreamingListenerSuite "receiver info reporting" As per the unit test log in https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/35754/ ``` 15/06/24 23:09:10.210 Thread-3495 INFO ReceiverTracker: Starting 1 receivers 15/06/24 23:09:10.270 Thread-3495 INFO SparkContext: Starting job: apply at Transformer.scala:22 ... 15/06/24 23:09:14.259 ForkJoinPool-4-worker-29 INFO StreamingListenerSuiteReceiver: Started receiver and sleeping 15/06/24 23:09:14.270 ForkJoinPool-4-worker-29 INFO StreamingListenerSuiteReceiver: Reporting error and sleeping ``` it needs at least 4 seconds to receive all receiver events in this slow machine, but `timeout` for `eventually` is only 2 seconds. This PR increases `timeout` to make this test stable. Author: zsxwing Closes #7017 from zsxwing/SPARK-8634 and squashes the following commits: 719cae4 [zsxwing] Fix flaky test StreamingListenerSuite "receiver info reporting" --- .../org/apache/spark/streaming/StreamingListenerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 1dc8960d60528..7bc7727a9fbe4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -116,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { - eventually(timeout(2000 millis), interval(20 millis)) { + eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 From fbf75738feddebb352d5cedf503b573105d4b7a7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 29 Jun 2015 17:20:05 -0700 Subject: [PATCH 20/39] [SPARK-7287] [SPARK-8567] [TEST] Add sc.stop to applications in SparkSubmitSuite Hopefully, this suite will not be flaky anymore. Author: Yin Huai Closes #7027 from yhuai/SPARK-8567 and squashes the following commits: c0167e2 [Yin Huai] Add sc.stop(). --- .../spark/deploy/SparkSubmitSuite.scala | 2 ++ .../regression-test-SPARK-8489/Main.scala | 1 + .../regression-test-SPARK-8489/test.jar | Bin 6811 -> 6828 bytes 3 files changed, 3 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 357ed90be3f5c..2e05dec99b6bf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -548,6 +548,7 @@ object JarCreationTest extends Logging { if (result.nonEmpty) { throw new Exception("Could not load user class from jar:\n" + result(0)) } + sc.stop() } } @@ -573,6 +574,7 @@ object SimpleApplicationTest { s"Master had $config=$masterValue but executor had $config=$executorValue") } } + sc.stop() } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala index e1715177e3f1b..0e428ba1d7456 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -38,6 +38,7 @@ object Main { val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) df.collect() println("Regression test for SPARK-8489 success!") + sc.stop() } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar index 4f59fba9eab558131b2587e51b7c2e2d54348bd1..5944aa6076a5fe7a8188c947fd6847c046614101 100644 GIT binary patch delta 1819 zcmY+Fdpy&N8^^!q&d6nHbB(cs!yGh4h#1w-VbXpXd4G^E~h8`RDoEK~QJrZcO7ZR4bX zovjUayHIFvIB4u~enzUU90>C?Xn`=q87-R_p_>v%U9o!iPEZU_oi;@Ykr^AW3Krfy zmcQC-)r+LqTpB$vO1<1}<6pm(VB^$y5NwOaqD4bQ%lmWmu_TxnX~D2i%w^e2>i zBOIKNgW?}+>8K5g4A!-qxZ$Kfc4x3WT+su)P9o8Af=%SWJ_dAhA>7`Spxwjw*L=tp z9;xt`sDXKE+gBR&q{>abkMI;lqA`#0&u&#K?nG~~2SL~xC$!ZoPG6>oxnO9b)VQi& zmifq4VngbH11qpDt`&zCx4@IG=F}QAi@I$SGpebd;+yvDk>Gqh^^Zr+q;E)Sv3G7< zko3JoJ{R~{d!D?tt*=?5 zjdP+JG;wKn??*dme5_x3F0Y#1zXD}L1jCYS(hO=-59%g zE|DI0pw6uQ+#7g=PC($_RmLU#d7X8H6M-h`If3=9>EzsO$7|4z{f~_9nBA9YHb{J_ zH;AnB{gkgED;+oHa?^f-XYFd9ih)l75ME~Vb-nkE0*gTj1gZr zmI6P`fni3aj^2urMdMXv_OVeSclGogc*Sfixun8dB%B@e;ocEqVN^kr(Y(%Hqgnr0 zO77z(l|MsGo-D2lwGQ^jFA#3aoyo0EBDQu)UgFF?DjR*9h3jz>wN@(iL8Gzm6Oo7%tZq;#+T# zrP|J!?{#grxk!)JZ|XD1&30Q;o4Hk|coLU}8p@aHUu;=Q2vEsne)5NF#5xd;WU8r* zE_dC^XVylao(UYVedTTOfTzCrYUpCr(Ul>0JJD_+{ez$OrfbrE(7)1Nf*`%iTeuQu zr{n6W=vbkV1j4V&+uAwQTmkkVIc1x&mcQ0~+&4k=6+{DP?lvD~= z@1okIWs-QA$NIkLJuau}Bv&O@r8V-p=6b-y$Qa$k1=B<3NjP)MY5d()&j{$#xgi-p z{bzWbyOr=KvA!kSVKDOK!h8n=xwz&c%Z_pzb6s2;JRZzDU;hA=c#t8zcQa#|Olh5q zVZB)#=fxz3HO!k#_g>DLI)x-@xWi~ZZQa%#iKTs_4zJ;)p`hQ;(85_<-|9~x)$3TU zX>^MHgbm**D`=hw8Q+wvO#gNBs0(kB)gAo~Z#$v=J2ss*#38DZ%2Lr$lZ@eyB zE(}E8e5hucwbeJ=UGdL@56FhwjB1aqh^3>;5-rH8%=+aRp7a2C9d;?&ljUHw^j7gg>qW}FNHfIq(Cen2?@--XoI+Q}z44Cx!{ z6Io{?A_I~U0zrWP+Y3zzW+(sEP8@2=8Ob}60B}nI0OV;QFjd+POyQf+c)?Ua@1-(E zAhZ-17{9Ci{z~cZU;yAI`yWpG8!r-z`2pte{m#liG%7-9zbNZ|43&mfq}Y~f zle*&8MHGq2V6AJcy6P&^K~<}moi4lk+4Y&{pZ9q`@9%w|_pk4lYr-;-cEy9l6aZi_ z7{Dcu-ji+s9WJ~tQvM(i=dk&phz6cI7O3wrm-=@?O(7CUZB1+m%(^(&nn+=G#?4k@ zcsnp1qW<-275~XT&qCw)0FRSV6c-*xIeyPr5X`poae?~!foeYw7YFQx1tkVy>IQaO zwwtH`SG)w+Kt$-x1p)xk;(){bv6)!Fu{w3wJNc6$HS`0+ss1rBY_HuIC=Q;gQ_L8W zl67<90px-N<#3Q;vwL@JkXjMVR7*X0PZw3X*}E2WA#-O|hxFv`4)M*;LOaWW&#S+? zJ$z{;X<0PX104+6jo&@!OeIdW91M3+|B&{xxDBoTbsVBoWdqhVdzu#}>CU^DvuL-v z)5s~u+Ax9C>06}1F|qrziwDkHyMoWm9fc0D^cm7<9c0&*RzmrpP1}!*f@jeS=m4Zx zsNkG#-5Ki#r864K=2^_O!t)7v)T=1sG5+TsXSKy#DFt+le1lIxL#S>vO4Bc(40S+WxssX?ShFq1rFS^&nB0 z-R9E+8jai{)p2MWccs+jNzioQwVK;G@2?J~J;ad3>x+t(A7#q*Cec26dak*9sFiT_ zm(1D>Go^pgsMEUXFbV$&&K~ob$uX4mo{bIs4Jdju$gF?_8f~~e>#Tgff12j$VbK3| z1WStw;>&6g_Qgv=us;}+qdcJ&tA-<-8>9f8=5L`n&$Z(3jCkFj){oZoT^mflnZn)@ zEML+&Wz%;W7H0Dl^PR@I1ioV~axv+RENl8sbXRvj2@F* z(8ss@z%ujdw5-`2)oDz!qh7Il9nPjve1!TcGFkRus!4+E>bLs^5-7Ls4@)Lr_FMB? zHghXb025;y!^^B21ms3Z+D2=Bxrv0;XrVm+zO zQinBNVMvM;VbLA-Ejx^64A%mf$MmA(8=3rSx*QxPckXWWyyZaa(yiJBQ;gnj@2a6? z*(>`#pYTsCc)y=XccxxlJC(qpp5g3%Xk(ChE=M!3`iJEf0#&zOn(nZ^k{?V8=qhP5SjoAQkmwvS@NTwg`VPB*`Yaf+) zkzvCZzNivX%T_`qlp}HnE*TA*eV@I&;xJLDkT7lDa{Kf!{_jbA>L!78mNTze#Su$y z34WRQx4ERU)O+znKIgn|f7mv#Hi?Xj*^4a^DWy{~@+A_+<_({XmpazumOAWc$_n3? z8_&=Rv@s6#dVj!0-}=%8+ssT?GPm;e0~=9~3&TkQ`8e&dRg} z;LyxvOxn=kTYeQ$5_26g{vRI-jx-u>J6s4ZmT;H5F>4FcBl4oV2y7R+jV77~T zwZ+7!MXE~s7#&afYSS$g9sV+32A=P4-Dcg3t6uazTeA>|T}OhpB&~0}yZx6>p+iE9 zFu_@8g3yv9@8+|5prb}yuMBDC6~BixIWF*uE=8v*y0*5u0yXj`l8rdv`Cymlr@QLr zgnG4wxn1+?-@S|ZnEp;G#c3w$y;WrN@0^{tmrzI8#=SgKaO)7+*^nb_p8zTV{$~>g z8ew-N3kDaV#K=a-FeVVtBdahXkV+~DvHyVlkmH9K|MSfVgc1;c^4kdW%Qt_k&;^^dG1;{1+ From 5d30eae56051c563a8427f330b09ef66db0a0d21 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 29 Jun 2015 17:21:35 -0700 Subject: [PATCH 21/39] [SPARK-8437] [DOCS] Using directory path without wildcard for filename slow for large number of files with wholeTextFiles and binaryFiles Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' Author: Sean Owen Closes #7036 from srowen/SPARK-8437 and squashes the following commits: 0e813ae [Sean Owen] Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b3c3bf3746e18..cb7e24c374152 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -831,6 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory + * rather than `.../path/` or `.../path` * * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @@ -878,9 +880,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory + * rather than `.../path/` or `.../path` + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental def binaryFiles( From d7f796da45d9a7c76ee4c29a9e0661ef76d8028a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 17:27:02 -0700 Subject: [PATCH 22/39] [SPARK-8410] [SPARK-8475] remove previous ivy resolution when using spark-submit This PR also includes re-ordering the order that repositories are used when resolving packages. User provided repositories will be prioritized. cc andrewor14 Author: Burak Yavuz Closes #7089 from brkyvz/delete-prev-ivy-resolution and squashes the following commits: a21f95a [Burak Yavuz] remove previous ivy resolution when using spark-submit --- .../org/apache/spark/deploy/SparkSubmit.scala | 37 ++++++++++++------- .../spark/deploy/SparkSubmitUtilsSuite.scala | 6 +-- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index abf222757a95b..b1d6ec209d62b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -756,6 +756,20 @@ private[spark] object SparkSubmitUtils { val cr = new ChainResolver cr.setName("list") + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + } + } + val localM2 = new IBiblioResolver localM2.setM2compatible(true) localM2.setRoot(m2Path.toURI.toString) @@ -786,20 +800,6 @@ private[spark] object SparkSubmitUtils { sp.setRoot("http://dl.bintray.com/spark-packages/maven") sp.setName("spark-packages") cr.add(sp) - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - } - } cr } @@ -922,6 +922,15 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file/ + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + md.setDefaultConf(ivyConfName) // Add exclusion rules for Spark and Scala Library diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 12c40f0b7d658..c9b435a9228d3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -77,9 +77,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => - if (i > 3) { - assert(resolver.getName === s"repo-${i - 3}") - assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 4)) + if (i < 3) { + assert(resolver.getName === s"repo-${i + 1}") + assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i)) } } } From 4a9e03fa850af9e4ee56d011671faa04fb601170 Mon Sep 17 00:00:00 2001 From: Michael Sannella x268 Date: Mon, 29 Jun 2015 17:28:28 -0700 Subject: [PATCH 23/39] [SPARK-8019] [SPARKR] Support SparkR spawning worker R processes with a command other then Rscript This is a simple change to add a new environment variable "spark.sparkr.r.command" that specifies the command that SparkR will use when creating an R engine process. If this is not specified, "Rscript" will be used by default. I did not add any documentation, since I couldn't find any place where environment variables (such as "spark.sparkr.use.daemon") are documented. I also did not add a unit test. The only test that would work generally would be one starting SparkR with sparkR.init(sparkEnvir=list(spark.sparkr.r.command="Rscript")), just using the default value. I think that this is a low-risk change. Likely committers: shivaram Author: Michael Sannella x268 Closes #6557 from msannell/altR and squashes the following commits: 7eac142 [Michael Sannella x268] add spark.sparkr.r.command config parameter --- core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 4dfa7325934ff..524676544d6f5 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -391,7 +391,7 @@ private[r] object RRDD { } private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { - val rCommand = "Rscript" + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) From 4c1808be4d3aaa37a5a878892e91ca73ea405ffa Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 29 Jun 2015 18:32:31 -0700 Subject: [PATCH 24/39] Revert "[SPARK-8437] [DOCS] Using directory path without wildcard for filename slow for large number of files with wholeTextFiles and binaryFiles" This reverts commit 5d30eae56051c563a8427f330b09ef66db0a0d21. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cb7e24c374152..b3c3bf3746e18 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -831,8 +831,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. - * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory - * rather than `.../path/` or `.../path` * * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @@ -880,11 +878,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @note Small files are preferred; very large files may cause bad performance. - * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory - * rather than `.../path/` or `.../path` - * * @param minPartitions A suggestion value of the minimal splitting number for input data. + * + * @note Small files are preferred; very large files may cause bad performance. */ @Experimental def binaryFiles( From 620605a4a1123afaab2674e38251f1231dea17ce Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 29 Jun 2015 18:40:30 -0700 Subject: [PATCH 25/39] [SPARK-8456] [ML] Ngram featurizer python Python API for N-gram feature transformer Author: Feynman Liang Closes #6960 from feynmanliang/ngram-featurizer-python and squashes the following commits: f9e37c9 [Feynman Liang] Remove debugging code 4dd81f4 [Feynman Liang] Fix typo and doctest 06c79ac [Feynman Liang] Style guide 26c1175 [Feynman Liang] Add python NGram API --- python/pyspark/ml/feature.py | 71 +++++++++++++++++++++++++++++++++++- python/pyspark/ml/tests.py | 11 ++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index ddb33f427ac64..8804dace849b3 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -21,7 +21,7 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder', +__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel'] @@ -265,6 +265,75 @@ class IDFModel(JavaModel): """ +@inherit_doc +@ignore_unicode_prefix +class NGram(JavaTransformer, HasInputCol, HasOutputCol): + """ + A feature transformer that converts the input array of strings into an array of n-grams. Null + values in the input array are ignored. + It returns an array of n-grams where each n-gram is represented by a space-separated string of + words. + When the input is empty, an empty array is returned. + When the input array length is less than n (number of elements per n-gram), no n-grams are + returned. + + >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) + >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams") + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) + >>> # Change n-gram length + >>> ngram.setParams(n=4).transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Temporarily modify output column. + >>> ngram.transform(df, {ngram.outputCol: "output"}).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e']) + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Must use keyword arguments to specify params. + >>> ngram.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + + # a placeholder to make it appear in the generated doc + n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") + + @keyword_only + def __init__(self, n=2, inputCol=None, outputCol=None): + """ + __init__(self, n=2, inputCol=None, outputCol=None) + """ + super(NGram, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) + self.n = Param(self, "n", "number of elements per n-gram (>=1)") + self._setDefault(n=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, n=2, inputCol=None, outputCol=None): + """ + setParams(self, n=2, inputCol=None, outputCol=None) + Sets params for this NGram. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setN(self, value): + """ + Sets the value of :py:attr:`n`. + """ + self._paramMap[self.n] = value + return self + + def getN(self): + """ + Gets the value of n or its default value. + """ + return self.getOrDefault(self.n) + + @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6adbf166f34a8..c151d21fd661a 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -252,6 +252,17 @@ def test_idf(self): output = idf0m.transform(dataset) self.assertIsNotNone(output.head().idf) + def test_ngram(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + ([["a", "b", "c", "d", "e"]])], ["input"]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + if __name__ == "__main__": unittest.main() From ecacb1e88a135c802e253793e7c863d6ca8d2408 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 18:48:28 -0700 Subject: [PATCH 26/39] [SPARK-8715] ArrayOutOfBoundsException fixed for DataFrameStatSuite.crosstab cc yhuai Author: Burak Yavuz Closes #7100 from brkyvz/ct-flakiness-fix and squashes the following commits: abc299a [Burak Yavuz] change 'to' to until 7e96d7c [Burak Yavuz] ArrayOutOfBoundsException fixed for DataFrameStatSuite.crosstab --- .../test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 64ec1a70c47e6..765094da6bda7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -78,7 +78,7 @@ class DataFrameStatSuite extends SparkFunSuite { val rows = crosstab.collect() rows.foreach { row => val i = row.getString(0).toInt - for (col <- 1 to 9) { + for (col <- 1 until columnNames.length) { val j = columnNames(col).toInt assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) } From 4915e9e3bffb57eac319ef2173b4a6ae4073d25e Mon Sep 17 00:00:00 2001 From: Steven She Date: Mon, 29 Jun 2015 18:50:09 -0700 Subject: [PATCH 27/39] [SPARK-8669] [SQL] Fix crash with BINARY (ENUM) fields with Parquet 1.7 Patch to fix crash with BINARY fields with ENUM original types. Author: Steven She Closes #7048 from stevencanopy/SPARK-8669 and squashes the following commits: 2e72979 [Steven She] [SPARK-8669] [SQL] Fix crash with BINARY (ENUM) fields with Parquet 1.7 --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 2 +- .../org/apache/spark/sql/parquet/ParquetSchemaSuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 4fd3e93b70311..2be7c64612cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -177,7 +177,7 @@ private[parquet] class CatalystSchemaConverter( case BINARY => field.getOriginalType match { - case UTF8 => StringType + case UTF8 | ENUM => StringType case null if assumeBinaryIsString => StringType case null => BinaryType case DECIMAL => makeDecimalType() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index d0bfcde7e032b..35d3c33f99a06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -161,6 +161,14 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """.stripMargin, binaryAsString = true) + testSchemaInference[Tuple1[String]]( + "binary enum as string", + """ + |message root { + | optional binary _1 (ENUM); + |} + """.stripMargin) + testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - non-standard", """ From f9b6bf2f83d9dad273aa36d65d0560d35b941cc2 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 29 Jun 2015 18:50:23 -0700 Subject: [PATCH 28/39] [SPARK-7667] [MLLIB] MLlib Python API consistency check MLlib Python API consistency check Author: Yanbo Liang Closes #6856 from yanboliang/spark-7667 and squashes the following commits: 21bae35 [Yanbo Liang] remove duplicate code eb12f95 [Yanbo Liang] fix doc inherit problem 9e7ec3c [Yanbo Liang] address comments e763d32 [Yanbo Liang] MLlib Python API consistency check --- python/pyspark/mllib/feature.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index f00bb93b7bf40..b5138773fd61b 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -111,6 +111,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): """ def transform(self, vector): + """ + Applies transformation on a vector or an RDD[Vector]. + + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param vector: Vector or RDD of Vector to be transformed. + """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: @@ -191,7 +200,7 @@ def fit(self, dataset): Computes the mean and variance and stores as a model to be used for later scaling. - :param data: The data used to compute the mean and variance + :param dataset: The data used to compute the mean and variance to build the transformation model. :return: a StandardScalarModel """ @@ -346,10 +355,6 @@ def transform(self, x): vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ - if isinstance(x, RDD): - return JavaVectorTransformer.transform(self, x) - - x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) def idf(self): From 7bbbe380c52419cd580d1c99c10131184e4ad440 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 21:32:40 -0700 Subject: [PATCH 29/39] [SPARK-5161] Parallelize Python test execution This commit parallelizes the Python unit test execution, significantly reducing Jenkins build times. Parallelism is now configurable by passing the `-p` or `--parallelism` flags to either `dev/run-tests` or `python/run-tests` (the default parallelism is 4, but I've successfully tested with higher parallelism). To avoid flakiness, I've disabled the Spark Web UI for the Python tests, similar to what we've done for the JVM tests. Author: Josh Rosen Closes #7031 from JoshRosen/parallelize-python-tests and squashes the following commits: feb3763 [Josh Rosen] Re-enable other tests f87ea81 [Josh Rosen] Only log output from failed tests d4ded73 [Josh Rosen] Logging improvements a2717e1 [Josh Rosen] Make parallelism configurable via dev/run-tests 1bacf1b [Josh Rosen] Merge remote-tracking branch 'origin/master' into parallelize-python-tests 110cd9d [Josh Rosen] Fix universal_newlines for Python 3 cd13db8 [Josh Rosen] Also log python_implementation 9e31127 [Josh Rosen] Log Python --version output for each executable. a2b9094 [Josh Rosen] Bump up parallelism. 5552380 [Josh Rosen] Python 3 fix 866b5b9 [Josh Rosen] Fix lazy logging warnings in Prospector checks 87cb988 [Josh Rosen] Skip MLLib tests for PyPy 8309bfe [Josh Rosen] Temporarily disable parallelism to debug a failure 9129027 [Josh Rosen] Disable Spark UI in Python tests 037b686 [Josh Rosen] Temporarily disable JVM tests so we can test Python speedup in Jenkins. af4cef4 [Josh Rosen] Initial attempt at parallelizing Python test execution --- dev/run-tests | 2 +- dev/run-tests.py | 24 +++++++- dev/sparktestsupport/shellutils.py | 1 + python/pyspark/java_gateway.py | 2 + python/run-tests.py | 97 +++++++++++++++++++++++------- 5 files changed, 101 insertions(+), 25 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index a00d9f0c27639..257d1e8d50bb4 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -20,4 +20,4 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -exec python -u ./dev/run-tests.py +exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests.py b/dev/run-tests.py index e5c897b94d167..4596e07014733 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -19,6 +19,7 @@ from __future__ import print_function import itertools +from optparse import OptionParser import os import re import sys @@ -360,12 +361,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): run_scala_tests_sbt(test_modules, test_profiles) -def run_python_tests(test_modules): +def run_python_tests(test_modules, parallelism): set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS") command = [os.path.join(SPARK_HOME, "python", "run-tests")] if test_modules != [modules.root]: command.append("--modules=%s" % ','.join(m.name for m in test_modules)) + command.append("--parallelism=%i" % parallelism) run_cmd(command) @@ -379,7 +381,25 @@ def run_sparkr_tests(): print("Ignoring SparkR tests as R was not found in PATH") +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts + + def main(): + opts = parse_opts() # Ensure the user home directory (HOME) is valid and is an absolute directory if not USER_HOME or not os.path.isabs(USER_HOME): print("[error] Cannot determine your home directory as an absolute path;", @@ -461,7 +481,7 @@ def main(): modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: - run_python_tests(modules_with_python_tests) + run_python_tests(modules_with_python_tests, opts.parallelism) if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index ad9b0cc89e4ab..12bd0bf3a4fe9 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -15,6 +15,7 @@ # limitations under the License. # +from __future__ import print_function import os import shutil import subprocess diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3cee4ea6e3a35..90cd342a6cf7f 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -51,6 +51,8 @@ def launch_gateway(): on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + if os.environ.get("SPARK_TESTING"): + submit_args = "--conf spark.ui.enabled=false " + submit_args command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/python/run-tests.py b/python/run-tests.py index 7d485b500ee3a..aaa35e936a806 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -18,12 +18,19 @@ # from __future__ import print_function +import logging from optparse import OptionParser import os import re import subprocess import sys +import tempfile +from threading import Thread, Lock import time +if sys.version < '3': + import Queue +else: + import queue as Queue # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -43,34 +50,44 @@ def print_red(text): LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") +FAILURE_REPORTING_LOCK = Lock() +LOGGER = logging.getLogger() def run_individual_python_test(test_name, pyspark_python): env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} - print(" Running test: %s ..." % test_name, end='') + LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() - with open(LOG_FILE, 'a') as log_file: - retcode = subprocess.call( - [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], - stderr=log_file, stdout=log_file, env=env) + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() duration = time.time() - start_time # Exit on the first failure. if retcode != 0: - with open(LOG_FILE, 'r') as log_file: - for line in log_file: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output.readlines()) + per_test_output.seek(0) + for line in per_test_output: if not re.match('[0-9]+', line): print(line, end='') - print_red("\nHad test failures in %s; see logs." % test_name) - exit(-1) + per_test_output.close() + print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) else: - print("ok (%is)" % duration) + per_test_output.close() + LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] if "python2.6" not in python_execs: - print("WARNING: Not testing against `python2.6` because it could not be found; falling" - " back to `python` instead") + LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + " back to `python` instead") python_execs.insert(0, "python") return python_execs @@ -88,16 +105,31 @@ def parse_opts(): default=",".join(sorted(python_modules.keys())), help="A comma-separated list of Python modules to test (default: %default)" ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + parser.add_option( + "--verbose", action="store_true", + help="Enable additional debug logging" + ) (opts, args) = parser.parse_args() if args: parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") return opts def main(): opts = parse_opts() - print("Running PySpark tests. Output is in python/%s" % LOG_FILE) + if (opts.verbose): + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") + LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') @@ -108,24 +140,45 @@ def main(): else: print("Error: unrecognized module %s" % module_name) sys.exit(-1) - print("Will test against the following Python executables: %s" % python_execs) - print("Will test the following Python modules: %s" % [x.name for x in modules_to_test]) + LOGGER.info("Will test against the following Python executables: %s", python_execs) + LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) - start_time = time.time() + task_queue = Queue.Queue() for python_exec in python_execs: python_implementation = subprocess.check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() - print("Testing with `%s`: " % python_exec, end='') - subprocess.call([python_exec, "--version"]) - + LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) + LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output( + [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: - print("Running %s tests ..." % module.name) for test_goal in module.python_test_goals: - run_individual_python_test(test_goal, python_exec) + task_queue.put((python_exec, test_goal)) + + def process_queue(task_queue): + while True: + try: + (python_exec, test_goal) = task_queue.get_nowait() + except Queue.Empty: + break + try: + run_individual_python_test(test_goal, python_exec) + finally: + task_queue.task_done() + + start_time = time.time() + for _ in range(opts.parallelism): + worker = Thread(target=process_queue, args=(task_queue,)) + worker.daemon = True + worker.start() + try: + task_queue.join() + except (KeyboardInterrupt, SystemExit): + print_red("Exiting due to interrupt") + sys.exit(-1) total_duration = time.time() - start_time - print("Tests passed in %i seconds" % total_duration) + LOGGER.info("Tests passed in %i seconds", total_duration) if __name__ == "__main__": From ea775b0662b952849ac7fe2026fc3fd4714c37e3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 29 Jun 2015 21:41:59 -0700 Subject: [PATCH 30/39] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #1767 (close requested by 'andrewor14') Closes #6952 (close requested by 'andrewor14') Closes #7051 (close requested by 'andrewor14') Closes #5357 (close requested by 'marmbrus') Closes #5233 (close requested by 'andrewor14') Closes #6930 (close requested by 'JoshRosen') Closes #5502 (close requested by 'andrewor14') Closes #6778 (close requested by 'andrewor14') Closes #7006 (close requested by 'andrewor14') From f79410c49b2225b2acdc58293574860230987775 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Jun 2015 22:32:43 -0700 Subject: [PATCH 31/39] [SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes. Author: Reynold Xin Closes #7109 from rxin/auto-cast and squashes the following commits: a914cc3 [Reynold Xin] [SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes. --- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 118 ++++++++---------- .../spark/sql/catalyst/expressions/misc.scala | 6 +- .../sql/catalyst/expressions/predicates.scala | 6 +- .../expressions/stringOperations.scala | 10 +- 6 files changed, 71 insertions(+), 79 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 976fa57cb98d5..c3d68197d64ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -116,7 +116,7 @@ trait HiveTypeCoercion { IfCoercion :: Division :: PropagateTypes :: - ExpectedInputConversion :: + AddCastForAutoCastInputTypes :: Nil /** @@ -709,15 +709,15 @@ trait HiveTypeCoercion { /** * Casts types according to the expected input types for Expressions that have the trait - * `ExpectsInputTypes`. + * [[AutoCastInputTypes]]. */ - object ExpectedInputConversion extends Rule[LogicalPlan] { + object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { case (child, actual, expected) => if (actual == expected) child else Cast(child, expected) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f59db3d5dfc23..e5dc7b9b5c884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -261,7 +261,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. */ -trait ExpectsInputTypes { +trait AutoCastInputTypes { self: Expression => def expectedChildTypes: Seq[DataType] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 4b57ddd9c5768..a022f3727bd58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -56,7 +56,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { + extends UnaryExpression with Serializable with AutoCastInputTypes { self: Product => override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) @@ -99,7 +99,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -211,19 +211,11 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with ExpectsInputTypes { - - val name: String = "BIN" - - override def foldable: Boolean = child.foldable - override def nullable: Boolean = true - override def toString: String = s"$name($child)" + extends UnaryExpression with Serializable with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType - def funcName: String = name.toLowerCase - override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { @@ -239,61 +231,13 @@ case class Bin(child: Expression) } } -//////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Binary math functions -//////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -case class Atan2(left: Expression, right: Expression) - extends BinaryMathExpression(math.atan2, "ATAN2") { - - override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, - evalE2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - } -} - -case class Pow(left: Expression, right: Expression) - extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - } -} /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. - * Otherwise if the number is a STRING, - * it converts each character into its hexadecimal representation and returns the resulting STRING. - * Negative numbers would be treated as two's complement. + * Otherwise if the number is a STRING, it converts each character into its hex representation + * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) - extends UnaryExpression with Serializable { +case class Hex(child: Expression) extends UnaryExpression with Serializable { override def dataType: DataType = StringType @@ -337,7 +281,7 @@ case class Hex(child: Expression) private def doHex(bytes: Array[Byte], length: Int): UTF8String = { val value = new Array[Byte](length * 2) var i = 0 - while(i < length) { + while (i < length) { value(i * 2) = Character.toUpperCase(Character.forDigit( (bytes(i) & 0xF0) >>> 4, 16)).toByte value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( @@ -362,6 +306,54 @@ case class Hex(child: Expression) } } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Binary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 9a39165a1ff05..27805bff293f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String * For input of type [[BinaryType]] */ case class Md5(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = StringType @@ -61,7 +61,7 @@ case class Md5(child: Expression) * the hash length is not one of the permitted values, the return value is NULL. */ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ExpectsInputTypes { + extends BinaryExpression with Serializable with AutoCastInputTypes { override def dataType: DataType = StringType @@ -146,7 +146,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3a12d03ba6bb9..386cf6a8df6df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -70,7 +70,7 @@ trait PredicateHelper { } -case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { +case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" @@ -123,7 +123,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -172,7 +172,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index a6225fdafedde..ce184e4f32f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ExpectsInputTypes { +trait StringRegexExpression extends AutoCastInputTypes { self: BinaryExpression => def escape(v: String): String @@ -111,7 +111,7 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } -trait CaseConversionExpression extends ExpectsInputTypes { +trait CaseConversionExpression extends AutoCastInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -158,7 +158,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ExpectsInputTypes { +trait StringComparison extends AutoCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -221,7 +221,7 @@ case class EndsWith(left: Expression, right: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with AutoCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -295,7 +295,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string expression. */ -case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = IntegerType override def expectedChildTypes: Seq[DataType] = Seq(StringType) From e6c3f7462b3fde220ec0084b52388dd4dabb75b9 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Mon, 29 Jun 2015 22:34:38 -0700 Subject: [PATCH 32/39] [SPARK-8650] [SQL] Use the user-specified app name priority in SparkSQLCLIDriver or HiveThriftServer2 When run `./bin/spark-sql --name query1.sql` [Before] ![before](https://cloud.githubusercontent.com/assets/1400819/8370336/fa20b75a-1bf8-11e5-9171-040049a53240.png) [After] ![after](https://cloud.githubusercontent.com/assets/1400819/8370189/dcc35cb4-1bf6-11e5-8796-a0694140bffb.png) Author: Yadong Qi Closes #7030 from watermen/SPARK-8650 and squashes the following commits: 51b5134 [Yadong Qi] Improve code and add comment. e3d7647 [Yadong Qi] use spark.app.name priority. --- .../apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 79eda1f5123bf..1d41c46131828 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -38,9 +38,14 @@ private[hive] object SparkSQLEnv extends Logging { val sparkConf = new SparkConf(loadDefaults = true) val maybeSerializer = sparkConf.getOption("spark.serializer") val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") + // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of + // the default appName [SparkSQLCLIDriver] in cli or beeline. + val maybeAppName = sparkConf + .getOption("spark.app.name") + .filterNot(_ == classOf[SparkSQLCLIDriver].getName) sparkConf - .setAppName(s"SparkSQL::${Utils.localHostName()}") + .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) .set( "spark.serializer", maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) From 6c5a6db4d53d6db8aa3464ea6713cf0d3a3bdfb5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 23:08:51 -0700 Subject: [PATCH 33/39] [SPARK-5161] [HOTFIX] Fix bug in Python test failure reporting This patch fixes a bug introduced in #7031 which can cause Jenkins to incorrectly report a build with failed Python tests as passing if an error occurred while printing the test failure message. Author: Josh Rosen Closes #7112 from JoshRosen/python-tests-hotfix and squashes the following commits: c3f2961 [Josh Rosen] Hotfix for bug in Python test failure reporting --- python/run-tests.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index aaa35e936a806..b7737650daa54 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -58,22 +58,33 @@ def run_individual_python_test(test_name, pyspark_python): env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() - per_test_output = tempfile.TemporaryFile() - retcode = subprocess.Popen( - [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], - stderr=per_test_output, stdout=per_test_output, env=env).wait() + try: + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() + except: + LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(1) duration = time.time() - start_time # Exit on the first failure. if retcode != 0: - with FAILURE_REPORTING_LOCK: - with open(LOG_FILE, 'ab') as log_file: + try: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output) per_test_output.seek(0) - log_file.writelines(per_test_output.readlines()) - per_test_output.seek(0) - for line in per_test_output: - if not re.match('[0-9]+', line): - print(line, end='') - per_test_output.close() + for line in per_test_output: + decoded_line = line.decode() + if not re.match('[0-9]+', decoded_line): + print(decoded_line, end='') + per_test_output.close() + except: + LOGGER.exception("Got an exception while trying to print failed test output") + finally: print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if # this code is invoked from a thread other than the main thread. From 12671dd5e468beedc2681ff2bdf95fba81f8f29c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 29 Jun 2015 23:44:11 -0700 Subject: [PATCH 34/39] [SPARK-8434][SQL]Add a "pretty" parameter to the "show" method to display long strings Sometimes the user may want to show the complete content of cells. Now `sql("set -v").show()` displays: ![screen shot 2015-06-18 at 4 34 51 pm](https://cloud.githubusercontent.com/assets/1000778/8227339/14d3c5ea-15d9-11e5-99b9-f00b7e93beef.png) The user needs to use something like `sql("set -v").collect().foreach(r => r.toSeq.mkString("\t"))` to show the complete content. This PR adds a `pretty` parameter to show. If `pretty` is false, `show` won't truncate strings or align cells right. ![screen shot 2015-06-18 at 4 21 44 pm](https://cloud.githubusercontent.com/assets/1000778/8227407/b6f8dcac-15d9-11e5-8219-8079280d76fc.png) Author: zsxwing Closes #6877 from zsxwing/show and squashes the following commits: 22e28e9 [zsxwing] pretty -> truncate e582628 [zsxwing] Add pretty parameter to the show method in R a3cd55b [zsxwing] Fix calling showString in R 923cee4 [zsxwing] Add a "pretty" parameter to show to display long strings --- R/pkg/R/DataFrame.R | 4 +- python/pyspark/sql/dataframe.py | 7 ++- .../org/apache/spark/sql/DataFrame.scala | 55 ++++++++++++++++--- .../org/apache/spark/sql/DataFrameSuite.scala | 21 +++++++ 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6feabf4189c2d..60702824acb46 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -169,8 +169,8 @@ setMethod("isLocal", #'} setMethod("showDF", signature(x = "DataFrame"), - function(x, numRows = 20) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows)) + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) cat(s) }) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 152b87351db31..4b9efa0a210fb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -247,9 +247,12 @@ def isLocal(self): return self._jdf.isLocal() @since(1.3) - def show(self, n=20): + def show(self, n=20, truncate=True): """Prints the first ``n`` rows to the console. + :param n: Number of rows to show. + :param truncate: Whether truncate long strings and align cells right. + >>> df DataFrame[age: int, name: string] >>> df.show() @@ -260,7 +263,7 @@ def show(self, n=20): | 5| Bob| +---+-----+ """ - print(self._jdf.showString(n)) + print(self._jdf.showString(n, truncate)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 986e59133919f..8fe1f7e34cb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -169,8 +169,9 @@ class DataFrame private[sql]( /** * Internal API for Python * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int): String = { + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { val numRows = _numRows.max(0) val sb = new StringBuilder val takeResult = take(numRows + 1) @@ -188,7 +189,7 @@ class DataFrame private[sql]( case seq: Seq[_] => seq.mkString("[", ", ", "]") case _ => cell.toString } - if (str.length > 20) str.substring(0, 17) + "..." else str + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str }: Seq[String] } @@ -207,7 +208,11 @@ class DataFrame private[sql]( // column names rows.head.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") sb.append(sep) @@ -215,7 +220,11 @@ class DataFrame private[sql]( // data rows.tail.map { _.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell.toString, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") } @@ -331,7 +340,8 @@ class DataFrame private[sql]( def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[DataFrame]] in a tabular form. Strings more than 20 characters will be + * truncated, and all cells will be aligned right. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -345,15 +355,46 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def show(numRows: Int): Unit = println(showString(numRows)) + def show(numRows: Int): Unit = show(numRows, true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. * @group action * @since 1.3.0 */ def show(): Unit = show(20) + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + /** * Returns a [[DataFrameNaFunctions]] for working with missing data. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d06b9c5785527..50d324c0686fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -492,6 +492,27 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show(1000) } + test("showString: truncate = [true, false]") { + val longString = Array.fill(21)("1").mkString + val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = """+---------------------+ + ||_1 | + |+---------------------+ + ||1 | + ||111111111111111111111| + |+---------------------+ + |""".stripMargin + assert(df.showString(10, false) === expectedAnswerForFalse) + val expectedAnswerForTrue = """+--------------------+ + || _1| + |+--------------------+ + || 1| + ||11111111111111111...| + |+--------------------+ + |""".stripMargin + assert(df.showString(10, true) === expectedAnswerForTrue) + } + test("showString(negative)") { val expectedAnswer = """+---+-----+ ||key|value| From 5452457410ffe881773f2f2cdcdc752467b19720 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Mon, 29 Jun 2015 23:50:34 -0700 Subject: [PATCH 35/39] [SPARK-8551] [ML] Elastic net python code example Author: Shuo Xiang Closes #6946 from coderxiang/en-java-code-example and squashes the following commits: 7a4bdf8 [Shuo Xiang] address comments cddb02b [Shuo Xiang] add elastic net python example code f4fa534 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 6ad4865 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 180b496 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' aa0717d [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 5f109b4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' c5c5bfe [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 98804c9 [Shuo Xiang] fix bug in topBykey and update test --- .../src/main/python/ml/logistic_regression.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 examples/src/main/python/ml/logistic_regression.py diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 0000000000000..55afe1b207fe0 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() From 2ed0c0ac4686ea779f98713978e37b97094edc1c Mon Sep 17 00:00:00 2001 From: Tim Ellison Date: Tue, 30 Jun 2015 13:49:52 +0100 Subject: [PATCH 36/39] [SPARK-7756] [CORE] More robust SSL options processing. Subset the enabled algorithms in an SSLOptions to the elements that are supported by the protocol provider. Update the list of ciphers in the sample config to include modern algorithms, and specify both Oracle and IBM names. In practice the user would either specify their own chosen cipher suites, or specify none, and delegate the decision to the provider. Author: Tim Ellison Closes #7043 from tellison/SSLEnhancements and squashes the following commits: 034efa5 [Tim Ellison] Ensure Java imports are grouped and ordered by package. 3797f8b [Tim Ellison] Remove unnecessary use of Option to improve clarity, and fix import style ordering. 4b5c89f [Tim Ellison] More robust SSL options processing. --- .../scala/org/apache/spark/SSLOptions.scala | 43 ++++++++++++++++--- .../org/apache/spark/SSLOptionsSuite.scala | 20 ++++++--- .../org/apache/spark/SSLSampleConfigs.scala | 24 ++++++++--- .../apache/spark/SecurityManagerSuite.scala | 21 ++++++--- 4 files changed, 85 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2cdc167f85af0..32df42d57dbd6 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,7 +17,9 @@ package org.apache.spark -import java.io.File +import java.io.{File, FileInputStream} +import java.security.{KeyStore, NoSuchAlgorithmException} +import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -38,7 +40,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java - * @param enabledAlgorithms a set of encryption algorithms to use + * @param enabledAlgorithms a set of encryption algorithms that may be used */ private[spark] case class SSLOptions( enabled: Boolean = false, @@ -48,7 +50,8 @@ private[spark] case class SSLOptions( trustStore: Option[File] = None, trustStorePassword: Option[String] = None, protocol: Option[String] = None, - enabledAlgorithms: Set[String] = Set.empty) { + enabledAlgorithms: Set[String] = Set.empty) + extends Logging { /** * Creates a Jetty SSL context factory according to the SSL settings represented by this object. @@ -63,7 +66,7 @@ private[spark] case class SSLOptions( trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) Some(sslContextFactory) } else { @@ -94,7 +97,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { @@ -102,6 +105,36 @@ private[spark] case class SSLOptions( } } + /* + * The supportedAlgorithms set is a subset of the enabledAlgorithms that + * are supported by the current Java security provider for this protocol. + */ + private val supportedAlgorithms: Set[String] = { + var context: SSLContext = null + try { + context = SSLContext.getInstance(protocol.orNull) + /* The set of supported algorithms does not depend upon the keys, trust, or + rng, although they will influence which algorithms are eventually used. */ + context.init(null, null, null) + } catch { + case npe: NullPointerException => + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } + + val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet + + // Log which algorithms we are discarding + (enabledAlgorithms &~ providerAlgorithms).foreach { cipher => + logDebug(s"Discarding unsupported cipher $cipher") + } + + enabledAlgorithms & providerAlgorithms + } + /** Returns a string representation of this SSLOptions with all the passwords masked. */ override def toString: String = s"SSLOptions{enabled=$enabled, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 376481ba541fa..25b79bce6ab98 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import javax.net.ssl.SSLContext import com.google.common.io.Files import org.apache.spark.util.Utils @@ -29,6 +30,15 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + // Pick two cipher suites that the provider knows about + val sslContext = SSLContext.getInstance("TLSv1.2") + sslContext.init(null, null, null) + val algorithms = sslContext + .getServerSocketFactory + .getDefaultCipherSuites + .take(2) + .toSet + val conf = new SparkConf conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) @@ -36,9 +46,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") - conf.set("spark.ssl.protocol", "SSLv3") + conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) + conf.set("spark.ssl.protocol", "TLSv1.2") val opts = SSLOptions.parse(conf, "spark.ssl") @@ -52,9 +61,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(opts.trustStorePassword === Some("password")) assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) - assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.protocol === Some("TLSv1.2")) + assert(opts.enabledAlgorithms === algorithms) } test("test resolving property with defaults specified ") { diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 1a099da2c6c8e..33270bec6247c 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -25,6 +25,20 @@ object SSLSampleConfigs { this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + val enabledAlgorithms = + // A reasonable set of TLSv1.2 Oracle security provider suites + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "TLS_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + + // and their equivalent names in the IBM Security provider + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "SSL_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" + def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") @@ -33,9 +47,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } @@ -47,9 +60,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e9b64aa82a17a..f34aefca4eb18 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -127,6 +127,17 @@ class SecurityManagerSuite extends SparkFunSuite { test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() + val expectedAlgorithms = Set( + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "TLS_RSA_WITH_AES_256_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "SSL_RSA_WITH_AES_256_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") val securityManager = new SecurityManager(conf) @@ -143,9 +154,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") @@ -154,9 +164,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) - assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.akkaSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) } test("ssl off setup") { From 08fab4843845136358f3a7251e8d90135126b419 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jun 2015 07:58:49 -0700 Subject: [PATCH 37/39] [SPARK-8590] [SQL] add code gen for ExtractValue TODO: use array instead of Seq as internal representation for `ArrayType` Author: Wenchen Fan Closes #6982 from cloud-fan/extract-value and squashes the following commits: e203bc1 [Wenchen Fan] address comments 4da0f0b [Wenchen Fan] some clean up f679969 [Wenchen Fan] fix bug e64f942 [Wenchen Fan] remove generic e3f8427 [Wenchen Fan] fix style and address comments fc694e8 [Wenchen Fan] add code gen for extract value --- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 46 ++++-- .../catalyst/expressions/ExtractValue.scala | 76 ++++++++-- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 15 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 13 +- .../sql/catalyst/expressions/predicates.scala | 3 - .../spark/sql/catalyst/expressions/sets.scala | 4 - .../spark/sql/catalyst/util/TypeUtils.scala | 2 +- .../expressions/ComplexTypeSuite.scala | 131 ++++++++++-------- 11 files changed, 199 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5db2fcfcb267b..dc0b4ac5cd9bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -47,7 +47,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) s""" boolean ${ev.isNull} = i.isNullAt($ordinal); ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e5dc7b9b5c884..aed48921bdeb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -179,9 +179,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe + /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * * @param f accepts two variable names and returns Java code to compute the output. @@ -190,15 +191,23 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express ctx: CodeGenContext, ev: GeneratedExpressionCode, f: (String, String) => String): String = { - // TODO: Right now some timestamp tests fail if we enforce this... - if (left.dataType != right.dataType) { - // log.warn(s"${left.dataType} != ${right.dataType}") - } + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(eval1.primitive, eval2.primitive) - + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; @@ -206,7 +215,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express if (!${ev.isNull}) { ${eval2.code} if (!${eval2.isNull}) { - ${ev.primitive} = $resultCode; + $resultCode } else { ${ev.isNull} = true; } @@ -245,13 +254,26 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio ctx: CodeGenContext, ev: GeneratedExpressionCode, f: String => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s"$result = ${f(eval)};" + }) + } + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { val eval = child.gen(ctx) - // reuse the previous isNull - ev.isNull = eval.isNull + val resultCode = f(ev.primitive, eval.primitive) eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = ${f(eval.primitive)}; + $resultCode } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4d7c95ffd1850..3020e7fc967f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -21,6 +21,7 @@ import scala.collection.Map import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ object ExtractValue { @@ -38,7 +39,7 @@ object ExtractValue { def apply( child: Expression, extraction: Expression, - resolver: Resolver): ExtractValue = { + resolver: Resolver): Expression = { (child.dataType, extraction) match { case (StructType(fields), NonNullLiteral(v, StringType)) => @@ -73,7 +74,7 @@ object ExtractValue { def unapply(g: ExtractValue): Option[(Expression, Expression)] = { g match { case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case _ => Some((g.child, null)) + case s: ExtractValueWithStruct => Some((s.child, null)) } } @@ -101,11 +102,11 @@ object ExtractValue { * Note: concrete extract value expressions are created only by `ExtractValue.apply`, * we don't need to do type check for them. */ -trait ExtractValue extends UnaryExpression { - self: Product => +trait ExtractValue { + self: Expression => } -abstract class ExtractValueWithStruct extends ExtractValue { +abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue { self: Product => def field: StructField @@ -125,6 +126,18 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) val baseValue = child.eval(input).asInstanceOf[InternalRow] if (baseValue == null) null else baseValue(ordinal) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + $result = ${ctx.getColumn(eval, dataType, ordinal)}; + } + """ + }) + } } /** @@ -137,6 +150,7 @@ case class GetArrayStructFields( containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable || containsNull || field.nullable override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] @@ -146,18 +160,39 @@ case class GetArrayStructFields( } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = "scala.collection.mutable.ArraySeq" + // TODO: consider using Array[_] for ArrayType child to avoid + // boxing of primitives + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + final int n = $eval.size(); + final $arraySeqClass values = new $arraySeqClass(n); + for (int j = 0; j < n; j++) { + InternalRow row = (InternalRow) $eval.apply(j); + if (row != null && !row.isNullAt($ordinal)) { + values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + } + } + $result = (${ctx.javaType(dataType)}) values; + """ + }) + } } -abstract class ExtractValueWithOrdinal extends ExtractValue { +abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { self: Product => def ordinal: Expression + def child: Expression + + override def left: Expression = child + override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. */ override def nullable: Boolean = true - override def foldable: Boolean = child.foldable && ordinal.foldable override def toString: String = s"$child[$ordinal]" - override def children: Seq[Expression] = child :: ordinal :: Nil override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -195,6 +230,19 @@ case class GetArrayItem(child: Expression, ordinal: Expression) baseValue(index) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + final int index = (int)$eval2; + if (index >= $eval1.size() || index < 0) { + ${ev.isNull} = true; + } else { + $result = (${ctx.boxedType(dataType)})$eval1.apply(index); + } + """ + }) + } } /** @@ -209,4 +257,16 @@ case class GetMapValue(child: Expression, ordinal: Expression) val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + if ($eval1.contains($eval2)) { + $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2); + } else { + ${ev.isNull} = true; + } + """ + }) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 3d4d9e2d798f0..ae765c1653203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -82,8 +82,6 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => - /** Name of the function for this expression on a [[Decimal]] type. */ - def decimalMethod: String = "" override def dataType: DataType = left.dataType @@ -113,6 +111,10 @@ abstract class BinaryArithmetic extends BinaryExpression { } } + /** Name of the function for this expression on a [[Decimal]] type. */ + def decimalMethod: String = + sys.error("BinaryArithmetics must override either decimalMethod or genCode") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 57e0bede5db20..bf6a6a124088e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -82,24 +82,24 @@ class CodeGenContext { /** * Returns the code to access a column in Row for a given DataType. */ - def getColumn(dataType: DataType, ordinal: Int): String = { + def getColumn(row: String, dataType: DataType, ordinal: Int): String = { val jt = javaType(dataType) if (isPrimitiveType(jt)) { - s"i.get${primitiveTypeName(jt)}($ordinal)" + s"$row.get${primitiveTypeName(jt)}($ordinal)" } else { - s"($jt)i.apply($ordinal)" + s"($jt)$row.apply($ordinal)" } } /** * Returns the code to update a column in Row for a given DataType. */ - def setColumn(dataType: DataType, ordinal: Int, value: String): String = { + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) if (isPrimitiveType(jt)) { - s"set${primitiveTypeName(jt)}($ordinal, $value)" + s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" } else { - s"update($ordinal, $value)" + s"$row.update($ordinal, $value)" } } @@ -127,6 +127,9 @@ class CodeGenContext { case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType + case _: StructType => "InternalRow" + case _: ArrayType => s"scala.collection.Seq" + case _: MapType => s"scala.collection.Map" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 64ef357a4f954..addb8023d9c0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu if(${evaluationCode.isNull}) mutableRow.setNullAt($i); else - mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a022f3727bd58..da63f2fa970cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -78,17 +78,14 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + ${ev.primitive} = java.lang.Math.${funcName}($eval); if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; } - } - """ + """ + }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 386cf6a8df6df..98cd5aa8148c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,10 +69,7 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index efc6f50b78943..daa9f4403ffab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -135,8 +135,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { - override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = left.dataType override def symbol: String = "++=" @@ -185,8 +183,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres */ case class CountSet(child: Expression) extends UnaryExpression { - override def nullable: Boolean = child.nullable - override def dataType: DataType = LongType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 8656cc334d09f..3148309a2166f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types._ /** - * Helper function to check for valid data types + * Helper functions to check for valid data types. */ object TypeUtils { def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b80911e7257fc..3515d044b2f7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -40,51 +40,42 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("GetArrayItem") { + val typeA = ArrayType(StringType) + val array = Literal.create(Seq("a", "b"), typeA) testIntegralDataTypes { convert => - val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b") } + val nullArray = Literal.create(null, typeA) + val nullInt = Literal.create(null, IntegerType) + checkEvaluation(GetArrayItem(nullArray, Literal(1)), null) + checkEvaluation(GetArrayItem(array, nullInt), null) + checkEvaluation(GetArrayItem(nullArray, nullInt), null) + + val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) + checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } - test("CreateStruct") { - val row = InternalRow(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") - checkEvaluation(CreateStruct(Seq(c1, c3)), InternalRow(1, 3), row) + test("GetMapValue") { + val typeM = MapType(StringType, StringType) + val map = Literal.create(Map("a" -> "b"), typeM) + val nullMap = Literal.create(null, typeM) + val nullString = Literal.create(null, StringType) + + checkEvaluation(GetMapValue(map, Literal("a")), "b") + checkEvaluation(GetMapValue(map, nullString), null) + checkEvaluation(GetMapValue(nullMap, nullString), null) + checkEvaluation(GetMapValue(map, nullString), null) + + val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM)) + checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c")) } - test("complex type") { - val row = create_row( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row("aa", "bb"), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - ) - - val typeS = StructType( - StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil - ) - val typeMap = MapType(StringType, StringType) - val typeArray = ArrayType(StringType) - - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal("aa")), "bb", row) - checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation( - GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal.create(null, StringType)), null, row) - - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal(1)), "bb", row) - checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation( - GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal.create(null, IntegerType)), null, row) - - def getStructField(expr: Expression, fieldName: String): ExtractValue = { + test("GetStructField") { + val typeS = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), typeS) + val nullStruct = Literal.create(null, typeS) + + def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get @@ -92,28 +83,58 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } - def quickResolve(u: UnresolvedExtractValue): ExtractValue = { - ExtractValue(u.child, u.extraction, _ == _) - } + checkEvaluation(getStructField(struct, "a"), 1) + checkEvaluation(getStructField(nullStruct, "a"), null) + + val nestedStruct = Literal.create(create_row(create_row(1)), + StructType(StructField("a", typeS) :: Nil)) + checkEvaluation(getStructField(nestedStruct, "a"), create_row(1)) + + val typeS_fieldNotNullable = StructType(StructField("a", IntegerType, false) :: Nil) + val struct_fieldNotNullable = Literal.create(create_row(1), typeS_fieldNotNullable) + val nullStruct_fieldNotNullable = Literal.create(null, typeS_fieldNotNullable) + + assert(getStructField(struct_fieldNotNullable, "a").nullable === false) + assert(getStructField(struct, "a").nullable === true) + assert(getStructField(nullStruct_fieldNotNullable, "a").nullable === true) + assert(getStructField(nullStruct, "a").nullable === true) + } - checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) + test("GetArrayStructFields") { + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) + val nullArrayStruct = Literal.create(null, typeAS) - val typeS_notNullable = StructType( - StructField("a", StringType, nullable = false) - :: StructField("b", StringType, nullable = false) :: Nil - ) + def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { + expr.dataType match { + case ArrayType(StructType(fields), containsNull) => + val field = fields.find(_.name == fieldName).get + GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + } + } + + checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1)) + checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) + } - assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) - assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable - === false) + test("CreateStruct") { + val row = create_row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + } - assert(getStructField(Literal.create(null, typeS), "a").nullable === true) - assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) + test("test dsl for complex type") { + def quickResolve(u: UnresolvedExtractValue): Expression = { + ExtractValue(u.child, u.extraction, _ == _) + } - checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) - checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) - checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) + checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")), + "b", create_row(Map("a" -> "b"))) + checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), + "b", create_row(Seq("a", "b"))) + checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + 1, create_row(create_row(1))) } test("error message of ExtractValue") { From 865a834e51ac3074811a11fd99a36d942f7f7de8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jun 2015 08:08:15 -0700 Subject: [PATCH 38/39] [SPARK-8723] [SQL] improve divide and remainder code gen We can avoid execution of both left and right expression by null and zero check. Author: Wenchen Fan Closes #7111 from cloud-fan/cg and squashes the following commits: d6b12ef [Wenchen Fan] improve divide and remainder code gen --- .../sql/catalyst/expressions/arithmetic.scala | 54 ++++++++++++------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ae765c1653203..5363b3556886a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -216,23 +216,32 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val test = if (left.dataType.isInstanceOf[DecimalType]) { + val isZero = if (dataType.isInstanceOf[DecimalType]) { s"${eval2.primitive}.isZero()" } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " - val javaType = ctx.javaType(left.dataType) - eval1.code + eval2.code + - s""" + val javaType = ctx.javaType(dataType) + val divide = if (dataType.isInstanceOf[DecimalType]) { + s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + } else { + s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + } + s""" + ${eval2.code} boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.isNull} || ${eval2.isNull} || $test) { + $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { - ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $divide; + } } - """ + """ } } @@ -273,23 +282,32 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val test = if (left.dataType.isInstanceOf[DecimalType]) { + val isZero = if (dataType.isInstanceOf[DecimalType]) { s"${eval2.primitive}.isZero()" } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " - val javaType = ctx.javaType(left.dataType) - eval1.code + eval2.code + - s""" + val javaType = ctx.javaType(dataType) + val remainder = if (dataType.isInstanceOf[DecimalType]) { + s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + } else { + s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + } + s""" + ${eval2.code} boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.isNull} || ${eval2.isNull} || $test) { + $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { - ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $remainder; + } } - """ + """ } } From a48e61915354d33fb98944a8eb5a5d48dd102041 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Jun 2015 08:17:24 -0700 Subject: [PATCH 39/39] [SPARK-8680] [SQL] Slightly improve PropagateTypes JIRA: https://issues.apache.org/jira/browse/SPARK-8680 This PR slightly improve `PropagateTypes` in `HiveTypeCoercion`. It moves `q.inputSet` outside `q transformExpressions` instead calling `inputSet` multiple times. It also builds a map of attributes for looking attribute easily. Author: Liang-Chi Hsieh Closes #7087 from viirya/improve_propagatetypes and squashes the following commits: 5c314c1 [Liang-Chi Hsieh] For comments. 913f6ad [Liang-Chi Hsieh] Slightly improve PropagateTypes. --- .../catalyst/analysis/HiveTypeCoercion.scala | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index c3d68197d64ac..e525ad623ff12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -131,20 +131,22 @@ trait HiveTypeCoercion { // Don't propagate types from unresolved children. case q: LogicalPlan if !q.childrenResolved => q - case q: LogicalPlan => q transformExpressions { - case a: AttributeReference => - q.inputSet.find(_.exprId == a.exprId) match { - // This can happen when a Attribute reference is born in a non-leaf node, for example - // due to a call to an external script like in the Transform operator. - // TODO: Perhaps those should actually be aliases? - case None => a - // Leave the same if the dataTypes match. - case Some(newType) if a.dataType == newType.dataType => a - case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") - newType - } - } + case q: LogicalPlan => + val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap + q transformExpressions { + case a: AttributeReference => + inputMap.get(a.exprId) match { + // This can happen when a Attribute reference is born in a non-leaf node, for example + // due to a call to an external script like in the Transform operator. + // TODO: Perhaps those should actually be aliases? + case None => a + // Leave the same if the dataTypes match. + case Some(newType) if a.dataType == newType.dataType => a + case Some(newType) => + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + newType + } + } } }