Skip to content

Commit

Permalink
Progress on implementing the DSL according to Scala 3 idioms
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobfi committed Feb 21, 2024
1 parent 7d79bf3 commit 0cc2ef6
Show file tree
Hide file tree
Showing 21 changed files with 228 additions and 283 deletions.
6 changes: 4 additions & 2 deletions README.md
Expand Up @@ -13,20 +13,22 @@ The library is available through Maven Central.
SBT style dependency: `"com.audienceproject" %% "crossbow" % "latest"`

# API

```scala
import com.audienceproject.crossbow.DataFrame
import com.audienceproject.crossbow.Implicits._
import com.audienceproject.crossbow.*

val data = Seq(("a", 1), ("b", 2), ("c", 3))
val df = DataFrame.fromSeq(data)

df.printSchema()

/**
* _0: String
* _1: int
*/

df.as[(String, Int)].foreach(println)

/**
* ("a", 1)
* ("b", 2)
Expand Down
2 changes: 1 addition & 1 deletion build.sbt
Expand Up @@ -6,7 +6,7 @@ version := "0.2.0"

scalaVersion := "3.3.1"

scalacOptions ++= Seq("-deprecation", "-feature")
scalacOptions ++= Seq("-deprecation", "-feature", "-language:implicitConversions")

libraryDependencies += "org.scalatest" %% "scalatest-funsuite" % "3.2.17" % "test"

Expand Down
193 changes: 80 additions & 113 deletions src/main/scala/com/audienceproject/crossbow/DataFrame.scala

Large diffs are not rendered by default.

44 changes: 0 additions & 44 deletions src/main/scala/com/audienceproject/crossbow/Implicits.scala

This file was deleted.

4 changes: 0 additions & 4 deletions src/main/scala/com/audienceproject/crossbow/JoinType.scala

This file was deleted.

51 changes: 51 additions & 0 deletions src/main/scala/com/audienceproject/crossbow/dsl.scala
@@ -0,0 +1,51 @@
package com.audienceproject.crossbow

import com.audienceproject.crossbow.exceptions.InvalidExpressionException
import com.audienceproject.crossbow.expr.*

enum JoinType:
case Inner, FullOuter, LeftOuter, RightOuter

// Column expression.
extension (sc: StringContext)
def $(args: Any*): DataFrame ?=> Expr = Expr.Cell(sc.s(args: _*))

extension [T: TypeTag](seq: Seq[T])
def toDataFrame(names: String*): DataFrame =
if (names.isEmpty) DataFrame.fromSeq[T](seq)
else DataFrame.fromSeq[T](seq).renameColumns(names *)

// Literal value.
def lit[T: TypeTag](value: T): Expr = Expr.Literal(value)

given Conversion[Int, Expr] = lit[Int](_)
given Conversion[Long, Expr] = lit[Long](_)
given Conversion[Double, Expr] = lit[Double](_)
given Conversion[Boolean, Expr] = lit[Boolean](_)
given Conversion[String, Expr] = lit[String](_)

// Index function.
def index(): DataFrame ?=> Expr = Expr.Index()

// Lambda function.
def lambda[T, R](f: T => R): Expr => Expr =
(expr: Expr) => Expr.Unary(expr, f)

// Sequence of values.
def seq(exprs: Expr*): Expr = Expr.List(exprs)

// Aggregators.
def sum(expr: Expr): Expr = expr.typeOf match
case RuntimeType.Int => Expr.Aggregate[Int, Int](expr, _ + _, 0)
case RuntimeType.Long => Expr.Aggregate[Long, Long](expr, _ + _, 0L)
case RuntimeType.Double => Expr.Aggregate[Double, Double](expr, _ + _, 0d)
case _ => throw new InvalidExpressionException("sum", expr)

def count(): Expr = Expr.Aggregate[Any, Int](lit(1), (_, x) => x + 1, 0)

def collect(expr: Expr): Expr = Expr.Aggregate[Any, Seq[Any]](expr, (e, seq) => seq :+ e, Vector.empty)

def one(expr: Expr): Expr = Expr.Aggregate[Any, Any](expr, (elem, _) => elem, null)

// Custom aggregator.
def reducer[T, U](seed: U)(f: (T, U) => U): Expr => Expr = Expr.Aggregate[T, U](_, f, seed)
Expand Up @@ -3,4 +3,4 @@ package com.audienceproject.crossbow.exceptions
import com.audienceproject.crossbow.expr.RuntimeType

class IncorrectTypeException(expected: RuntimeType, actual: RuntimeType)
extends RuntimeException(s"Expected $expected, but Expr type was $actual")
extends RuntimeException(s"Expected $expected, but runtime type was $actual")
Expand Up @@ -4,7 +4,9 @@ import com.audienceproject.crossbow.exceptions.InvalidExpressionException

import scala.annotation.targetName

extension (x: Expr)
trait ArithmeticOps:

x: Expr =>

@targetName("plus")
def +(y: Expr): Expr = (x.typeOf, y.typeOf) match
Expand Down
Expand Up @@ -2,7 +2,9 @@ package com.audienceproject.crossbow.expr

import scala.annotation.targetName

extension (x: Expr)
trait BaseOps:

x: Expr =>

@targetName("equals")
def ===(y: Expr): Expr = Expr.Binary[Any, Any, Boolean](x, y, _ == _)
Expand Down
Expand Up @@ -4,7 +4,9 @@ import com.audienceproject.crossbow.exceptions.InvalidExpressionException

import scala.annotation.targetName

extension (x: Expr)
trait BooleanOps:

x: Expr =>

def not: Expr = x.typeOf match
case RuntimeType.Boolean => Expr.Unary[Boolean, Boolean](x, !_)
Expand Down
Expand Up @@ -4,7 +4,9 @@ import com.audienceproject.crossbow.exceptions.InvalidExpressionException

import scala.annotation.targetName

extension (x: Expr)
trait ComparisonOps:

x: Expr =>

@targetName("gt")
def >(y: Expr): Expr = (x.typeOf, y.typeOf) match
Expand All @@ -20,7 +22,7 @@ extension (x: Expr)
case (RuntimeType.Double, RuntimeType.Long) => Expr.Binary[Double, Long, Boolean](x, y, _ > _)
case (RuntimeType.Double, RuntimeType.Int) => Expr.Binary[Double, Int, Boolean](x, y, _ > _)
case (RuntimeType.Double, RuntimeType.Double) => Expr.Binary[Double, Double, Boolean](x, y, _ > _)
case _ => throw new InvalidExpressionException("gt", x, y)
case _ => throw InvalidExpressionException("gt", x, y)

@targetName("lt")
def <(y: Expr): Expr = y > x
Expand Down
6 changes: 2 additions & 4 deletions src/main/scala/com/audienceproject/crossbow/expr/Expr.scala
Expand Up @@ -3,16 +3,14 @@ package com.audienceproject.crossbow.expr
import com.audienceproject.crossbow.DataFrame
import com.audienceproject.crossbow.exceptions.{IncorrectTypeException, InvalidExpressionException}

sealed trait Expr:
sealed trait Expr extends BaseOps, ArithmeticOps, BooleanOps, ComparisonOps:
private[crossbow] val typeOf: RuntimeType

private[crossbow] def eval(i: Int): Any

private[crossbow] def as[T]: Int => T = eval.asInstanceOf[Int => T]

private[crossbow] def typecheckAs[T: TypeTag]: Int => T =
val expectedType = summon[TypeTag[T]].runtimeType
if (expectedType == typeOf) as[T]
if (expectedType == typeOf) eval.asInstanceOf[Int => T]
else throw new IncorrectTypeException(expectedType, typeOf)

private[crossbow] object Expr:
Expand Down
15 changes: 7 additions & 8 deletions src/main/scala/com/audienceproject/crossbow/expr/TypeTag.scala
Expand Up @@ -7,24 +7,23 @@ trait TypeTag[T]:

object TypeTag:

inline def getRuntimeType[T]: RuntimeType = ${ getRuntimeTypeImpl[T] }
inline given of[T]: TypeTag[T] = new TypeTag[T]:
override val runtimeType: RuntimeType = getRuntimeType[T]

private inline def getRuntimeType[T]: RuntimeType = ${ getRuntimeTypeImpl[T] }

private def getRuntimeTypeImpl[T: Type](using Quotes): Expr[RuntimeType] =
Type.of[T] match
case '[Int] => '{ RuntimeType.Int }
case '[Long] => '{ RuntimeType.Long }
case '[Double] => '{ RuntimeType.Double }
case '[Boolean] => '{ RuntimeType.Boolean }
case '[java.lang.String] => '{ RuntimeType.String }
case '[head *: tail] =>
Expr.ofList(tupleToList[head *: tail]) match
case '{ $productTypes } => '{ RuntimeType.Product(${ productTypes } *) }
case '[String] => '{ RuntimeType.String }
case '[head *: tail] => '{ RuntimeType.Product(${ Expr.ofList(tupleToList[head *: tail]) } *) }
case '[Seq[t]] => '{ RuntimeType.List(${ getRuntimeTypeImpl[t] }) }
case _ => '{ RuntimeType.Any(${ Expr(Type.show[T]) }) }

private def tupleToList[T: Type](using Quotes): List[Expr[RuntimeType]] =
Type.of[T] match
case '[head *: tail] => getRuntimeTypeImpl[head] :: tupleToList[tail]
case _ => Nil

inline given typeTag[T]: TypeTag[T] = new TypeTag[T]:
override val runtimeType: RuntimeType = TypeTag.getRuntimeType[T]
26 changes: 15 additions & 11 deletions src/main/scala/com/audienceproject/crossbow/schema/Schema.scala
@@ -1,38 +1,42 @@
package com.audienceproject.crossbow.schema

case class Schema(columns: Seq[Column] = Seq.empty) {
import com.audienceproject.crossbow.expr.RuntimeType

case class Schema(columns: Seq[Column] = Seq.empty):

val size: Int = columns.length

def add(column: Column): Schema = Schema(columns :+ column)

def get(columnName: String): Column = columns.find(_.name == columnName).getOrElse(
throw new NoSuchElementException(s"Schema does not contain a column with name '$columnName''")
)
def get(columnName: String): Column =
columns.find(_.name == columnName).getOrElse:
throw new NoSuchElementException(s"Schema does not contain a column with name '$columnName''")

def indexOf(columnName: String): Int = {
def indexOf(columnName: String): Int =
val index = columns.indexWhere(_.name == columnName)
if (index >= 0) index
else throw new NoSuchElementException(s"Schema does not contain a column with name '$columnName''")
}

private[crossbow] def matchColumns(other: Schema): Seq[(Column, Either[(Int, Int), Either[Int, Int]])] = {
def toRuntimeType: RuntimeType = columns match
case Seq(single) => single.columnType
case _ => RuntimeType.Product(columns.map(_.columnType) *)

private[crossbow] def matchColumns(other: Schema): Seq[(Column, Either[(Int, Int), Either[Int, Int]])] =
val leftMatches = columns.zipWithIndex.map({
case (thisCol, i) =>
val j = other.columns.indexWhere(_.name == thisCol.name)
if (j >= 0) {
val otherCol = other.columns(j)
if (thisCol.columnType != otherCol.columnType)
throw new IllegalArgumentException(s"Columns $thisCol and $otherCol do not match. Please check schemas.")
throw IllegalArgumentException(s"Columns $thisCol and $otherCol do not match. Please check schemas.")
else (thisCol, Left(i, j))
} else (thisCol, Right(Left(i)))
})
val rightMatches = other.columns.zipWithIndex.collect({
case (otherCol, j) if !columns.exists(_.name == otherCol.name) => (otherCol, Right(Right(j)))
})
leftMatches ++ rightMatches
}
end matchColumns

override def toString: String = columns.map(_.toString).mkString("\n")

}
end Schema
Expand Up @@ -4,53 +4,53 @@ import com.audienceproject.crossbow.expr.RuntimeType

import scala.reflect.ClassTag

private[crossbow] def sliceColumn(eval: Int => ?, ofType: RuntimeType, indices: IndexedSeq[Int]): Array[_] =
private def sliceColumn(eval: Int => ?, ofType: RuntimeType, indices: IndexedSeq[Int]): Array[_] =
ofType match
case RuntimeType.Int => fillArray[Int](indices, eval.asInstanceOf[Int => Int])
case RuntimeType.Long => fillArray[Long](indices, eval.asInstanceOf[Int => Long])
case RuntimeType.Double => fillArray[Double](indices, eval.asInstanceOf[Int => Double])
case RuntimeType.Boolean => fillArray[Boolean](indices, eval.asInstanceOf[Int => Boolean])
case _ => fillArray[Any](indices, eval)

private[crossbow] def padColumn(data: Array[_], ofType: RuntimeType, padding: Int): Array[_] =
private def padColumn(data: Array[_], ofType: RuntimeType, padding: Int): Array[_] =
ofType match
case RuntimeType.Int => fillArray[Int](data.indices, data.asInstanceOf[Array[Int]], padding)
case RuntimeType.Long => fillArray[Long](data.indices, data.asInstanceOf[Array[Long]], padding)
case RuntimeType.Double => fillArray[Double](data.indices, data.asInstanceOf[Array[Double]], padding)
case RuntimeType.Boolean => fillArray[Boolean](data.indices, data.asInstanceOf[Array[Boolean]], padding)
case _ => fillArray[Any](data.indices, data.asInstanceOf[Array[Any]], padding)

private[crossbow] def fillArray[T: ClassTag](indices: IndexedSeq[Int], getValue: Int => T, padding: Int = 0): Array[T] =
private def fillArray[T: ClassTag](indices: IndexedSeq[Int], getValue: Int => T, padding: Int = 0): Array[T] =
val arr = new Array[T](indices.size + math.abs(padding))
val indexOffset = math.max(padding, 0)
for (i <- indices.indices if indices(i) >= 0) arr(i + indexOffset) = getValue(indices(i))
arr

private[crossbow] def spliceColumns(data: Seq[Array[_]], ofType: RuntimeType): Array[_] =
private def spliceColumns(data: Seq[Array[_]], ofType: RuntimeType): Array[_] =
ofType match
case RuntimeType.Int => fillNArray[Int](data.map(_.asInstanceOf[Array[Int]]))
case RuntimeType.Long => fillNArray[Long](data.map(_.asInstanceOf[Array[Long]]))
case RuntimeType.Double => fillNArray[Double](data.map(_.asInstanceOf[Array[Double]]))
case RuntimeType.Boolean => fillNArray[Boolean](data.map(_.asInstanceOf[Array[Boolean]]))
case _ => fillNArray[Any](data.map(_.asInstanceOf[Array[Any]]))

private[crossbow] def fillNArray[T: ClassTag](nData: Seq[Array[T]]): Array[T] =
private def fillNArray[T: ClassTag](nData: Seq[Array[T]]): Array[T] =
val arr = new Array[T](nData.map(_.length).sum)
nData.foldLeft(0):
case (offset, data) =>
for (i <- data.indices) arr(i + offset) = data(i)
data.length + offset
arr

private[crossbow] def repeatColumn(data: Array[_], ofType: RuntimeType, reps: Array[Int]): Array[_] =
private def repeatColumn(data: Array[_], ofType: RuntimeType, reps: Array[Int]): Array[_] =
ofType match
case RuntimeType.Int => fillRepeatArray[Int](data.asInstanceOf[Array[Int]], reps)
case RuntimeType.Long => fillRepeatArray[Long](data.asInstanceOf[Array[Long]], reps)
case RuntimeType.Double => fillRepeatArray[Double](data.asInstanceOf[Array[Double]], reps)
case RuntimeType.Boolean => fillRepeatArray[Boolean](data.asInstanceOf[Array[Boolean]], reps)
case _ => fillRepeatArray[Any](data.asInstanceOf[Array[Any]], reps)

private[crossbow] def fillRepeatArray[T: ClassTag](data: Array[T], reps: Array[Int]): Array[T] =
private def fillRepeatArray[T: ClassTag](data: Array[T], reps: Array[Int]): Array[T] =
val arr = new Array[T](reps.sum)
reps.indices.foldLeft(0):
case (offset, i) =>
Expand All @@ -59,7 +59,7 @@ private[crossbow] def fillRepeatArray[T: ClassTag](data: Array[T], reps: Array[I
next
arr

private[crossbow] def convert(data: Seq[Any], dataType: RuntimeType): Array[_] =
private def convert(data: Seq[Any], dataType: RuntimeType): Array[_] =
dataType match
case RuntimeType.Int => data.asInstanceOf[Seq[Int]].toArray
case RuntimeType.Long => data.asInstanceOf[Seq[Long]].toArray
Expand Down

0 comments on commit 0cc2ef6

Please sign in to comment.