Skip to content

Commit

Permalink
Made some updates to runtime type conformance and implemented sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobfi committed Mar 20, 2024
1 parent ebf43fb commit ba27bc7
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 74 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Expand Up @@ -4,7 +4,7 @@ name := "crossbow"

version := "0.2.0"

scalaVersion := "3.3.1"
scalaVersion := "3.3.3"

scalacOptions ++= Seq("-deprecation", "-feature", "-language:implicitConversions")

Expand Down
32 changes: 14 additions & 18 deletions src/main/scala/com/audienceproject/crossbow/DataFrame.scala
@@ -1,14 +1,14 @@
package com.audienceproject.crossbow

import com.audienceproject.crossbow.algorithms.{GroupBy, SortMergeJoin}
import com.audienceproject.crossbow.exceptions.{IncorrectTypeException, JoinException}
import com.audienceproject.crossbow.exceptions.JoinException
import com.audienceproject.crossbow.expr.*
import com.audienceproject.crossbow.schema.{Column, Schema}

import scala.util.Sorting

class DataFrame private(
private val columnData: Vector[Array[_]], val schema: Schema,
private val columnData: Vector[Array[?]], val schema: Schema,
private val sortKey: Option[Expr] = None):

val rowCount: Int = if (columnData.isEmpty) 0 else columnData.head.length
Expand Down Expand Up @@ -51,11 +51,7 @@ class DataFrame private(
* @tparam T the type of a row in this DataFrame
* @return [[TypedView]] on the contents of this DataFrame
*/
def as[T: TypeTag]: TypedView[T] =
val schemaType = schema.toRuntimeType
val expectedType = summon[TypeTag[T]].runtimeType
if (expectedType == schemaType) TypedView[T]
else throw IncorrectTypeException(expectedType, schemaType)
def as[T]: TypedView[T] = TypedView[T]

/**
* Add a column to the DataFrame, evaluating to 'expr' at each individual row index.
Expand Down Expand Up @@ -126,9 +122,9 @@ class DataFrame private(
*/
def explode(expr: DataFrame ?=> Expr): DataFrame =
given DataFrame = this
val eval = expr.typecheckAs[Seq[_]]
val eval = expr.typecheckAs[List[?]]
val RuntimeType.List(innerType) = expr.typeOf: @unchecked // This unapply is safe due to typecheck.
val nestedCol = fillArray[Seq[_]](0 until rowCount, eval.apply)
val nestedCol = fillArray[Seq[?]](0 until rowCount, eval.apply)
val reps = nestedCol.map(_.size)
val colData = for (i <- 0 until numColumns) yield repeatColumn(columnData(i), schema.columns(i).columnType, reps)
val explodedColSchema = expr match
Expand All @@ -150,23 +146,23 @@ class DataFrame private(
/**
* Sort this DataFrame by the evaluation of 'expr'. If a natural ordering exists on this value, it will be used.
* User-defined orderings on other types or for overwriting the natural orderings with an explicit ordering can be
* supplied through the 'givenOrderings' argument.
* supplied through the 'order' argument in the using clause.
*
* @param expr the [[Expr]] to evaluate as a sort key
* @param givenOrderings explicit [[Order]] to use on the sort key, or list of [[Order]] if the key is a tuple
* @param stable whether the sorting should be stable or not - quicksort is used if not, else mergesort
* @param expr the [[Expr]] to evaluate as a sort key
* @param stable whether the sorting should be stable or not - quicksort is used if not, else mergesort
* @param order explicit [[Order]] to use on the sort key
* @return new DataFrame
*/
def sortBy(expr: DataFrame ?=> Expr, givenOrderings: Seq[Order] = Seq.empty, stable: Boolean = false): DataFrame =
def sortBy(expr: DataFrame ?=> Expr, stable: Boolean = false)(using order: Order = Order.Implicit): DataFrame =
given DataFrame = this
if (sortKey.contains(expr) && givenOrderings.isEmpty) this
if (sortKey.contains(expr) && order == Order.Implicit) this
else
val ord = Order.getOrdering(expr.typeOf, givenOrderings)
val indices = Array.tabulate(rowCount)(identity)
val ord = order.getOrdering(expr.typeOf).asInstanceOf[Ordering[Any]]
given Ordering[Int] = (x: Int, y: Int) => ord.compare(expr.eval(x), expr.eval(y))
if (stable) Sorting.stableSort[Int](indices)
else Sorting.quickSort[Int](indices)
slice(indices.toIndexedSeq, if (givenOrderings.isEmpty) Some(expr) else None)
slice(indices.toIndexedSeq, Option.when(order == Order.Implicit)(expr))

/**
* Join this DataFrame on another DataFrame, with the key evaluated by 'joinExpr'.
Expand All @@ -183,7 +179,7 @@ class DataFrame private(
val left = joinExpr(using this)
val right = joinExpr(using other)
if (left.typeOf != right.typeOf) throw JoinException(left)
val ordering = Order.getOrdering(left.typeOf)
val ordering = Order.Implicit.getOrdering(left.typeOf).asInstanceOf[Ordering[Any]]
SortMergeJoin(this.sortBy(joinExpr), other.sortBy(joinExpr), joinExpr, joinType, ordering)

/**
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/com/audienceproject/crossbow/dsl.scala
Expand Up @@ -56,4 +56,4 @@ def collect(expr: Expr): Expr = Expr.Aggregate[Any, Seq[Any]](expr, (e, seq) =>
def one(expr: Expr): Expr = Expr.Aggregate[Any, Any](expr, (elem, _) => elem, null)

// Custom aggregator.
def reducer[T, U](seed: U)(f: (T, U) => U): Expr => Expr = Expr.Aggregate[T, U](_, f, seed)
def reducer[T: TypeTag, U: TypeTag](seed: U)(f: (T, U) => U): Expr => Expr = Expr.Aggregate[T, U](_, f, seed)
4 changes: 2 additions & 2 deletions src/main/scala/com/audienceproject/crossbow/expr/Expr.scala
Expand Up @@ -10,8 +10,8 @@ sealed trait Expr extends BaseOps, ArithmeticOps, BooleanOps, ComparisonOps:

private[crossbow] def typecheckAs[T: TypeTag]: Int => T =
val expectedType = summon[TypeTag[T]].runtimeType
if (expectedType == typeOf) eval.asInstanceOf[Int => T]
else throw new IncorrectTypeException(expectedType, typeOf)
if expectedType.compatible(typeOf) then eval.asInstanceOf[Int => T]
else throw IncorrectTypeException(expectedType, typeOf)

private[crossbow] object Expr:

Expand Down
63 changes: 26 additions & 37 deletions src/main/scala/com/audienceproject/crossbow/expr/Order.scala
@@ -1,42 +1,31 @@
package com.audienceproject.crossbow.expr

import com.audienceproject.crossbow.exceptions.NoOrderingException
import com.audienceproject.crossbow.exceptions.{IncorrectTypeException, NoOrderingException}

class Order private(private val ord: Ordering[_], private val internalType: RuntimeType)
enum Order:
case Implicit
case Explicit[T](private val ordering: Ordering[T], private val onType: RuntimeType)

object Order {
private[crossbow] def getOrdering(runtimeType: RuntimeType): Ordering[?] = this match
case Order.Implicit => runtimeType match
case RuntimeType.Int => Ordering.Int
case RuntimeType.Long => Ordering.Long
case RuntimeType.Double => Ordering.Double.TotalOrdering
case RuntimeType.Boolean => Ordering.Boolean
case RuntimeType.String => Ordering.String
case RuntimeType.Product(elementTypes*) =>
elementTypes.map(getOrdering) match
case Seq(o1, o2) => Ordering.Tuple2(o1, o2)
case Seq(o1, o2, o3) => Ordering.Tuple3(o1, o2, o3)
case Seq(o1, o2, o3, o4) => Ordering.Tuple4(o1, o2, o3, o4)
case Seq(o1, o2, o3, o4, o5) => Ordering.Tuple5(o1, o2, o3, o4, o5)
case Seq(o1, o2, o3, o4, o5, o6) => Ordering.Tuple6(o1, o2, o3, o4, o5, o6)
case _ => throw NoOrderingException(runtimeType)
case RuntimeType.List(elementType) => Ordering.Implicits.seqOrdering(getOrdering(elementType))
case _ => throw NoOrderingException(runtimeType)
case Order.Explicit(ordering, onType) =>
if onType.compatible(runtimeType) then ordering
else throw IncorrectTypeException(onType, runtimeType)

def by[T: TypeTag](ord: Ordering[T]): Order = new Order(ord, summon[TypeTag[T]].runtimeType)

private[crossbow] def getOrdering(internalType: RuntimeType, givens: Seq[Order] = Seq.empty): Ordering[Any] = {
givens.find(_.internalType == internalType) match {
case Some(explicitOrdering) => explicitOrdering.ord.asInstanceOf[Ordering[Any]]
case None =>
val implicitOrdering = internalType match {
case RuntimeType.Int => Ordering.Int
case RuntimeType.Long => Ordering.Long
case RuntimeType.Double => DoubleOrdering
case RuntimeType.Boolean => Ordering.Boolean
case RuntimeType.Product(elementTypes*) =>
val tupleTypes = elementTypes.map(getOrdering(_, givens))
tupleTypes match {
case Seq(o1, o2) => Ordering.Tuple2(o1, o2)
case Seq(o1, o2, o3) => Ordering.Tuple3(o1, o2, o3)
case Seq(o1, o2, o3, o4) => Ordering.Tuple4(o1, o2, o3, o4)
case Seq(o1, o2, o3, o4, o5) => Ordering.Tuple5(o1, o2, o3, o4, o5)
case Seq(o1, o2, o3, o4, o5, o6) => Ordering.Tuple6(o1, o2, o3, o4, o5, o6)
case _ => throw new NoOrderingException(internalType)
}
case _: RuntimeType.List => throw new NoOrderingException(internalType)
//case AnyType(runtimeType) if runtimeType =:= ru.typeOf[String] => Ordering.String // TODO
case _ => throw new NoOrderingException(internalType)
}
implicitOrdering.asInstanceOf[Ordering[Any]]
}
}

private object DoubleOrdering extends Ordering[Double] {
override def compare(x: Double, y: Double): Int = java.lang.Double.compare(x, y)
}

}
object Order:
def by[T: TypeTag](ord: Ordering[T]): Order = Explicit(ord, summon[TypeTag[T]].runtimeType)
Expand Up @@ -2,16 +2,23 @@ package com.audienceproject.crossbow.expr

enum RuntimeType:
case Int, Long, Double, Boolean, String
case Any(typeName: String)
case Product(elementTypes: RuntimeType*)
case List(elementType: RuntimeType)
case Generic(typeName: String)

def compatible(actual: RuntimeType): Boolean = (this, actual) match
case (Product(xs*), Product(ys*)) if xs.length == ys.length => xs.zip(ys).forall(_ compatible _)
case (List(x), List(y)) => x compatible y
case (Generic("scala.Any"), _) => true // TODO: Make an explicit RuntimeType.Any - the macro isn't obvious though
case (Generic(_), Generic(_)) => true
case (x, y) => x == y

override def toString: String = this match
case Int => "int"
case Long => "long"
case Double => "double"
case Boolean => "boolean"
case String => "string"
case Any(typeName) => typeName
case Product(elementTypes*) => s"(${elementTypes.mkString(",")})"
case List(elementType) => s"List($elementType)"
case Generic(typeName) => typeName
Expand Up @@ -21,7 +21,7 @@ object TypeTag:
case '[String] => '{ RuntimeType.String }
case '[head *: tail] => '{ RuntimeType.Product(${ Expr.ofList(tupleToList[head *: tail]) } *) }
case '[Seq[t]] => '{ RuntimeType.List(${ getRuntimeTypeImpl[t] }) }
case _ => '{ RuntimeType.Any(${ Expr(Type.show[T]) }) }
case _ => '{ RuntimeType.Generic(${ Expr(Type.show[T]) }) }

private def tupleToList[T: Type](using Quotes): List[Expr[RuntimeType]] =
Type.of[T] match
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/com/audienceproject/crossbow/slicing.scala
Expand Up @@ -4,15 +4,15 @@ import com.audienceproject.crossbow.expr.RuntimeType

import scala.reflect.ClassTag

private def sliceColumn(eval: Int => ?, ofType: RuntimeType, indices: IndexedSeq[Int]): Array[_] =
private def sliceColumn(eval: Int => ?, ofType: RuntimeType, indices: IndexedSeq[Int]): Array[?] =
ofType match
case RuntimeType.Int => fillArray[Int](indices, eval.asInstanceOf[Int => Int])
case RuntimeType.Long => fillArray[Long](indices, eval.asInstanceOf[Int => Long])
case RuntimeType.Double => fillArray[Double](indices, eval.asInstanceOf[Int => Double])
case RuntimeType.Boolean => fillArray[Boolean](indices, eval.asInstanceOf[Int => Boolean])
case _ => fillArray[Any](indices, eval)

private def padColumn(data: Array[_], ofType: RuntimeType, padding: Int): Array[_] =
private def padColumn(data: Array[?], ofType: RuntimeType, padding: Int): Array[?] =
ofType match
case RuntimeType.Int => fillArray[Int](data.indices, data.asInstanceOf[Array[Int]], padding)
case RuntimeType.Long => fillArray[Long](data.indices, data.asInstanceOf[Array[Long]], padding)
Expand All @@ -26,7 +26,7 @@ private def fillArray[T: ClassTag](indices: IndexedSeq[Int], getValue: Int => T,
for (i <- indices.indices if indices(i) >= 0) arr(i + indexOffset) = getValue(indices(i))
arr

private def spliceColumns(data: Seq[Array[_]], ofType: RuntimeType): Array[_] =
private def spliceColumns(data: Seq[Array[?]], ofType: RuntimeType): Array[?] =
ofType match
case RuntimeType.Int => fillNArray[Int](data.map(_.asInstanceOf[Array[Int]]))
case RuntimeType.Long => fillNArray[Long](data.map(_.asInstanceOf[Array[Long]]))
Expand All @@ -42,7 +42,7 @@ private def fillNArray[T: ClassTag](nData: Seq[Array[T]]): Array[T] =
data.length + offset
arr

private def repeatColumn(data: Array[_], ofType: RuntimeType, reps: Array[Int]): Array[_] =
private def repeatColumn(data: Array[?], ofType: RuntimeType, reps: Array[Int]): Array[?] =
ofType match
case RuntimeType.Int => fillRepeatArray[Int](data.asInstanceOf[Array[Int]], reps)
case RuntimeType.Long => fillRepeatArray[Long](data.asInstanceOf[Array[Long]], reps)
Expand All @@ -59,7 +59,7 @@ private def fillRepeatArray[T: ClassTag](data: Array[T], reps: Array[Int]): Arra
next
arr

private def convert(data: Seq[Any], dataType: RuntimeType): Array[_] =
private def convert(data: Seq[Any], dataType: RuntimeType): Array[?] =
dataType match
case RuntimeType.Int => data.asInstanceOf[Seq[Int]].toArray
case RuntimeType.Long => data.asInstanceOf[Seq[Long]].toArray
Expand Down
20 changes: 20 additions & 0 deletions src/test/scala/com/audienceproject/crossbow/core/ExprTest.scala
@@ -0,0 +1,20 @@
package com.audienceproject.crossbow.core

import com.audienceproject.crossbow.*
import org.scalatest.funsuite.AnyFunSuite

class ExprTest extends AnyFunSuite:

test("Typecheck generic lambda"):
val df = DataFrame.fromSeq(Seq(("a", Some(1)), ("b", None), ("c", Some(2))))
val getOrElse = lambda[Option[Int], Int](_.getOrElse(0))
val result = df.select(getOrElse($"_1")).as[Int].toSeq
assertResult(Seq(1, 0, 2))(result)

val df2 = DataFrame.fromSeq(Seq(List(Some(1)), List(None), List(Some(2))))
val unwrap = lambda[List[Option[Int]], Int](_.head.getOrElse(0))
val result2 = df2.select(unwrap($"_0")).as[Int].toSeq
assertResult(Seq(1, 0, 2))(result)

val unwrapInvalidType = lambda[List[Some[Int]], Int](_.head.get)
assertThrows[ClassCastException](df2.select(unwrapInvalidType($"_0")).as[Int].toSeq)
Expand Up @@ -15,7 +15,7 @@ class SortByTest extends AnyFunSuite:
assert(result == expected)

test("sortBy on single column with explicit ordering"):
val result = df.sortBy($"v", Seq(Order.by(Ordering.Int.reverse))).select($"k").as[String].toSeq
val result = df.sortBy($"v")(using Order.by(Ordering.Int.reverse)).select($"k").as[String].toSeq
val expected = Seq("d", "c", "b", "a")
assert(result == expected)

Expand All @@ -26,15 +26,15 @@ class SortByTest extends AnyFunSuite:

private case class Custom(x: Int)

test("sortBy with explicit ordering on custom type"):
test("No ordering on custom type"):
val makeCustom = lambda[Int, Custom](Custom.apply)
val customDf = df.addColumn(makeCustom($"v") as "custom")

assertThrows[NoOrderingException](customDf.sortBy($"custom"))

val customOrdering = new Ordering[Custom] {
override def compare(c1: Custom, c2: Custom): Int = c1.x - c2.x
}
val result = customDf.sortBy($"custom", Seq(Order.by(customOrdering))).select($"k").as[String].toSeq
test("sortBy with explicit ordering on custom type"):
val makeCustom = lambda[Int, Custom](Custom.apply)
val customDf = df.addColumn(makeCustom($"v") as "custom")
given customOrdering: Order = Order.by((c1: Custom, c2: Custom) => c1.x - c2.x)
val result = customDf.sortBy($"custom").select($"k").as[String].toSeq
val expected = Seq("a", "b", "c", "d")
assert(result == expected)

0 comments on commit ba27bc7

Please sign in to comment.