diff --git a/src/main/scala/com/audienceproject/crossbow/DataFrame.scala b/src/main/scala/com/audienceproject/crossbow/DataFrame.scala index 171e7ac..0dfcdb6 100644 --- a/src/main/scala/com/audienceproject/crossbow/DataFrame.scala +++ b/src/main/scala/com/audienceproject/crossbow/DataFrame.scala @@ -39,7 +39,7 @@ class DataFrame private(private val columnData: Vector[Array[_]], * @return new DataFrame */ def apply(columnNames: String*): DataFrame = { - val colExprs = columnNames.map(Expr.Column) + val colExprs = columnNames.map(Expr.Cell) select(colExprs: _*) } @@ -100,8 +100,8 @@ class DataFrame private(private val columnData: Vector[Array[_]], */ def select(exprs: Expr*): DataFrame = { val (colData, colSchemas) = exprs.zipWithIndex.map({ - case (Expr.Named(newName, Expr.Column(colName)), _) => (getColumnData(colName), schema.get(colName).renamed(newName)) - case (Expr.Column(colName), _) => (getColumnData(colName), schema.get(colName)) + case (Expr.Named(newName, Expr.Cell(colName)), _) => (getColumnData(colName), schema.get(colName).renamed(newName)) + case (Expr.Cell(colName), _) => (getColumnData(colName), schema.get(colName)) case (expr, i) => val eval = expr.compile(this) val newColSchema = expr match { @@ -225,7 +225,7 @@ class DataFrame private(private val columnData: Vector[Array[_]], private[crossbow] def slice(indices: IndexedSeq[Int]): DataFrame = { val newData = schema.columns.map(col => { - val eval = Expr.Column(col.name).compile(this) + val eval = Expr.Cell(col.name).compile(this) sliceColumn(eval, indices) }) new DataFrame(newData.toVector, schema) diff --git a/src/main/scala/com/audienceproject/crossbow/Implicits.scala b/src/main/scala/com/audienceproject/crossbow/Implicits.scala index faaf570..c998d5b 100644 --- a/src/main/scala/com/audienceproject/crossbow/Implicits.scala +++ b/src/main/scala/com/audienceproject/crossbow/Implicits.scala @@ -9,7 +9,7 @@ object Implicits { // Column expression. implicit class ColumnByName(val sc: StringContext) extends AnyVal { - def $(args: Any*): Expr = Expr.Column(sc.s(args: _*)) + def $(args: Any*): Expr = Expr.Cell(sc.s(args: _*)) } // Literal value. diff --git a/src/main/scala/com/audienceproject/crossbow/algorithms/GroupBy.scala b/src/main/scala/com/audienceproject/crossbow/algorithms/GroupBy.scala index 98d93e2..46bda98 100644 --- a/src/main/scala/com/audienceproject/crossbow/algorithms/GroupBy.scala +++ b/src/main/scala/com/audienceproject/crossbow/algorithms/GroupBy.scala @@ -14,8 +14,8 @@ private[crossbow] object GroupBy { val selectExprs = aggExprs.map(Traversal.transform(_, { case agg: Aggregator => aggregators += agg - Expr.Column(s"_res${aggregators.size}") - case col: Expr.Column => throw new AggregationException(col) + Expr.Cell(s"_res${aggregators.size}") + case col: Expr.Cell => throw new AggregationException(col) })) val keyEvals = keyExprs.map(_.compile(dataFrame)).toList @@ -34,14 +34,14 @@ private[crossbow] object GroupBy { val keyNames = keyExprs.zipWithIndex.map({ case (Expr.Named(name, _), _) => name - case (Expr.Column(name), _) => name + case (Expr.Cell(name), _) => name case (_, i) => s"_key$i" }) val keySchemaCols = keyNames.zip(keyEvals).map({ case (name, eval) => Column(name, eval.typeOf) }).toList val dataSchemaCols = reducers.zipWithIndex.map({ case (reducer, i) => Column(s"_res${i + 1}", reducer.typeOf) }) val temp = DataFrame.fromColumns(newKeyCols ++ newDataCols, Schema(keySchemaCols ++ dataSchemaCols)) - temp.select(keySchemaCols.map(c => Expr.Column(c.name)) ++ selectExprs: _*) + temp.select(keySchemaCols.map(c => Expr.Cell(c.name)) ++ selectExprs: _*) } } diff --git a/src/main/scala/com/audienceproject/crossbow/algorithms/SortMergeJoin.scala b/src/main/scala/com/audienceproject/crossbow/algorithms/SortMergeJoin.scala index 2fb444e..dfd06be 100644 --- a/src/main/scala/com/audienceproject/crossbow/algorithms/SortMergeJoin.scala +++ b/src/main/scala/com/audienceproject/crossbow/algorithms/SortMergeJoin.scala @@ -17,8 +17,8 @@ private[crossbow] object SortMergeJoin { val ordering = Order.getOrdering(internalType) - val leftSorted = left.addColumn(joinExpr as joinColName).sortBy(Expr.Column(joinColName)) - val rightSorted = right.addColumn(joinExpr as joinColName).sortBy(Expr.Column(joinColName)) + val leftSorted = left.addColumn(joinExpr as joinColName).sortBy(Expr.Cell(joinColName)) + val rightSorted = right.addColumn(joinExpr as joinColName).sortBy(Expr.Cell(joinColName)) val leftKey = leftSorted(joinColName).as[Any] val rightKey = rightSorted(joinColName).as[Any] diff --git a/src/main/scala/com/audienceproject/crossbow/expr/BinaryExpr.scala b/src/main/scala/com/audienceproject/crossbow/expr/BinaryExpr.scala index 21b81a0..65ab44f 100644 --- a/src/main/scala/com/audienceproject/crossbow/expr/BinaryExpr.scala +++ b/src/main/scala/com/audienceproject/crossbow/expr/BinaryExpr.scala @@ -20,8 +20,8 @@ protected abstract class BinaryExpr(private val lhs: Expr, private val rhs: Expr private[crossbow] object BinaryExpr { - case class BinaryOp[T, U, V: ru.TypeTag](lhs: Specialized[T], rhs: Specialized[U], - op: (T, U) => V) + private[BinaryExpr] case class BinaryOp[T, U, V: ru.TypeTag](lhs: Specialized[T], rhs: Specialized[U], + op: (T, U) => V) extends Specialized[V] { override def apply(i: Int): V = op(lhs(i), rhs(i)) } diff --git a/src/main/scala/com/audienceproject/crossbow/expr/Expr.scala b/src/main/scala/com/audienceproject/crossbow/expr/Expr.scala index d67e9b8..43d6acc 100644 --- a/src/main/scala/com/audienceproject/crossbow/expr/Expr.scala +++ b/src/main/scala/com/audienceproject/crossbow/expr/Expr.scala @@ -13,7 +13,7 @@ private[crossbow] object Expr { override private[crossbow] def compile(context: DataFrame) = expr.compile(context) } - case class Column(columnName: String) extends Expr { + case class Cell(columnName: String) extends Expr { override private[crossbow] def compile(context: DataFrame) = { val columnType = context.schema.get(columnName).columnType columnType match { diff --git a/src/main/scala/com/audienceproject/crossbow/expr/UnaryExpr.scala b/src/main/scala/com/audienceproject/crossbow/expr/UnaryExpr.scala index edf616d..33969cf 100644 --- a/src/main/scala/com/audienceproject/crossbow/expr/UnaryExpr.scala +++ b/src/main/scala/com/audienceproject/crossbow/expr/UnaryExpr.scala @@ -18,7 +18,7 @@ protected abstract class UnaryExpr(private val expr: Expr) extends Expr { private[crossbow] object UnaryExpr { - case class UnaryOp[T, U: ru.TypeTag](operand: Specialized[T], op: T => U) + private[UnaryExpr] case class UnaryOp[T, U: ru.TypeTag](operand: Specialized[T], op: T => U) extends Specialized[U] { override def apply(i: Int): U = op(operand(i)) } diff --git a/src/test/scala/com/audienceproject/crossbow/algorithms/TraversalTest.scala b/src/test/scala/com/audienceproject/crossbow/algorithms/TraversalTest.scala index 07224af..83f7453 100644 --- a/src/test/scala/com/audienceproject/crossbow/algorithms/TraversalTest.scala +++ b/src/test/scala/com/audienceproject/crossbow/algorithms/TraversalTest.scala @@ -12,7 +12,7 @@ class TraversalTest extends AnyFunSuite { test("Transform expression tree") { val transformedExpr = Traversal.transform(aggExpr, { - case _: Aggregator => Expr.Column("42") + case _: Aggregator => Expr.Cell("42") }) val expected = $"42" + $"42" / 2 * $"42" assert(transformedExpr == expected) @@ -22,7 +22,7 @@ class TraversalTest extends AnyFunSuite { val list = mutable.ListBuffer.empty[Expr] val transformedExpr = Traversal.transform(aggExpr, { case _: Aggregator => - val newCol = Expr.Column(s"_${list.size}") + val newCol = Expr.Cell(s"_${list.size}") list += newCol newCol }) diff --git a/src/test/scala/com/audienceproject/crossbow/core/TypedViewTest.scala b/src/test/scala/com/audienceproject/crossbow/core/TypedViewTest.scala new file mode 100644 index 0000000..ee4a17a --- /dev/null +++ b/src/test/scala/com/audienceproject/crossbow/core/TypedViewTest.scala @@ -0,0 +1,23 @@ +package com.audienceproject.crossbow.core + +import com.audienceproject.crossbow.DataFrame +import com.audienceproject.crossbow.Implicits._ +import org.scalatest.funsuite.AnyFunSuite + +class TypedViewTest extends AnyFunSuite { + + private val df = DataFrame.fromSeq(Seq(("a", 1), ("b", 2), ("c", 3))) + + test("Cast DataFrame to TypedView of single type") { + val result = df.select($"_0").as[String].toSeq + val expected = Seq("a", "b", "c") + assert(result == expected) + } + + test("Cast DataFrame to TypedView of tuple") { + val result = df.as[(String, Int)].toSeq + val expected = Seq(("a", 1), ("b", 2), ("c", 3)) + assert(result == expected) + } + +} diff --git a/wercker.yml b/wercker.yml index e3894f6..a140b34 100644 --- a/wercker.yml +++ b/wercker.yml @@ -11,7 +11,7 @@ build: code: sbt clean +compile - script: name: Test - code: sbt +test + code: sbt test - script: name: Clean again code: sbt clean