Skip to content

Commit

Permalink
Prefix comparators for float and double
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 6, 2015
1 parent b310c88 commit 66a813e
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ private PrefixComparators() {}

public static final IntPrefixComparator INTEGER = new IntPrefixComparator();
public static final LongPrefixComparator LONG = new LongPrefixComparator();
public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();

public static final class IntPrefixComparator extends PrefixComparator {
@Override
Expand All @@ -45,4 +47,30 @@ public int compare(long a, long b) {
return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}

public static final class FloatPrefixComparator extends PrefixComparator {
@Override
public int compare(long aPrefix, long bPrefix) {
float a = Float.intBitsToFloat((int) aPrefix);
float b = Float.intBitsToFloat((int) bPrefix);
return (a < b) ? -1 : (a > b) ? 1 : 0;
}

public long computePrefix(float value) {
return Float.floatToIntBits(value) & 0xffffffffL;
}
}

public static final class DoublePrefixComparator extends PrefixComparator {
@Override
public int compare(long aPrefix, long bPrefix) {
double a = Double.longBitsToDouble(aPrefix);
double b = Double.longBitsToDouble(bPrefix);
return (a < b) ? -1 : (a > b) ? 1 : 0;
}

public long computePrefix(double value) {
return Double.doubleToLongBits(value);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression}

import scala.runtime.AbstractFunction1

object GenerateExpression extends CodeGenerator[Expression, InternalRow => Any] {

override protected def canonicalize(in: Expression): Expression = {
ExpressionCanonicalizer.execute(in)
}

override protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = {
BindReferences.bindReference(in, inputSchema)
}

override protected def create(expr: Expression): InternalRow => Any = {
val ctx = newCodeGenContext()
val eval = expr.gen(ctx)
val code =
s"""
|class SpecificExpression extends
| ${classOf[AbstractFunction1[InternalRow, Any]].getName}<${classOf[InternalRow].getName}, Object> {
|
| @Override
| public SpecificExpression generate($exprType[] expr) {
| return new SpecificExpression(expr);
| }
|
| @Override
| public Object apply(InternalRow i) {
| ${eval.code}
| return ${eval.isNull} ? null : ${eval.primitive};
| }
| }
""".stripMargin
logDebug(s"Generated expression '$expr':\n$code")
println(code)
compile(code).generate(ctx.references.toArray).asInstanceOf[InternalRow => Any]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
*/
class CodeGenerationSuite extends SparkFunSuite {

test("generate expression") {
GenerateExpression.generate(Add(Literal(1), Literal(1)))
}

test("multithreaded eval") {
import scala.concurrent._
import ExecutionContext.Implicits.global
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.types.{LongType, IntegerType}
import org.apache.spark.sql.types.{DoubleType, FloatType, LongType, IntegerType}
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}


Expand All @@ -38,6 +38,8 @@ object SortPrefixUtils {
sortOrder.dataType match {
case IntegerType => PrefixComparators.INTEGER
case LongType => PrefixComparators.LONG
case FloatType => PrefixComparators.FLOAT
case DoubleType => PrefixComparators.DOUBLE
case _ => NoOpPrefixComparator
}
}
Expand All @@ -47,6 +49,10 @@ object SortPrefixUtils {
case IntegerType => (row: InternalRow) =>
PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int])
case LongType => (row: InternalRow) => sortOrder.child.eval(row).asInstanceOf[Long]
case FloatType => (row: InternalRow) =>
PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float])
case DoubleType => (row: InternalRow) =>
PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double])
case _ => (row: InternalRow) => 0L
}
}
Expand Down

0 comments on commit 66a813e

Please sign in to comment.