Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

Commit

Permalink
Refactor each_top_k
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Jan 26, 2017
1 parent 79f92f4 commit 4039f9e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{EachTopK, Expression, Literal, NamedExpression, UserDefinedGenerator}
import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.types._
Expand All @@ -56,6 +55,9 @@ import org.apache.spark.unsafe.types.UTF8String
*/
final class HivemallOps(df: DataFrame) extends Logging {

// Session that owns the DataFrame wrapped by this class.
private[this] val _sparkSession = df.sparkSession
// Analyzer taken from that session's state; each_top_k runs it over a hand-built projection
// to obtain resolved group/score expressions.
private[this] val _analyzer = _sparkSession.sessionState.analyzer

/**
* @see hivemall.regression.AdaDeltaUDTF
* @group regression
Expand Down Expand Up @@ -788,37 +790,24 @@ final class HivemallOps(df: DataFrame) extends Logging {
/**
* Returns `top-k` records for each `group`.
* @group misc
* @since 0.5.0
*/
def each_top_k(k: Int, group: String, score: String, args: String*)
  : DataFrame = withTypedPlan {
  // k     : passed straight to EachTopK (the test suite uses a negative k for "reverse" top-k)
  // group : name of the column rows are partitioned and sorted by
  // score : name of the column handed to EachTopK as the ranking key
  // args  : names of the generated columns to keep; "rank" is always emitted first
  //
  // Rows are repartitioned by `group` and sorted inside each partition before the generator
  // runs (presumably so that EachTopK sees every group contiguously -- confirm in EachTopK).
  val clusterDf = df.repartition(df(group)).sortWithinPartitions(group)
  val child = clusterDf.logicalPlan
  val childrenAttributes = child.output
  val generator = Generate(
    EachTopK(
      k,
      clusterDf.resolve(group),
      clusterDf.resolve(score),
      childrenAttributes
    ),
    join = false, outer = false, None,
    (Seq("rank") ++ childrenAttributes.map(_.name)).map(UnresolvedAttribute(_)),
    child)
  val attributes = generator.generatedSet
  // Look every requested column up by name; an unknown name is reported as an analysis
  // error that names the column instead of surfacing as a bare `None.get`.
  val projectList = (Seq("rank") ++ args).map { name =>
    attributes.find(_.name == name).getOrElse {
      throw new AnalysisException(
        s"each_top_k cannot find output column `$name`; available columns: " +
          attributes.toSeq.map(_.name).mkString(", "))
    }
  }
  Project(projectList, generator)
}

@deprecated("use each_top_k(Int, String, String, String*) instead", "0.5.0")
def each_top_k(k: Column, group: Column, value: Column, args: Column*): DataFrame = {
*/
def each_top_k(k: Column, group: Column, score: Column): DataFrame = withTypedPlan {
  // k     : must be an integer literal (e.g. `lit(1)`); anything else is rejected below
  // group : column rows are partitioned and sorted by
  // score : column handed to EachTopK as the ranking key
  // The generated output is "rank" followed by every column of the input DataFrame.
  val kInt = k.expr match {
    // Typed pattern instead of a cast: a null literal of IntegerType now falls through to
    // the error below rather than being silently coerced to 0.
    case Literal(v: Int, IntegerType) => v
    case e => throw new AnalysisException("`k` must be integer, however " + e)
  }
  val clusterDf = df.repartition(group).sortWithinPartitions(group)
  val child = clusterDf.logicalPlan
  // Put group and score in front of the child's own output and run the analyzer, so both
  // expressions come back resolved against `child`.
  val logicalPlan = Project(group.named +: score.named +: child.output, child)
  _analyzer.execute(logicalPlan) match {
    // A sequence pattern matches whatever Seq implementation the analyzer hands back,
    // whereas `::` would only match a List.
    case Project(Seq(groupExpr, scoreExpr, origCols @ _*), c) =>
      Generate(
        EachTopK(kInt, groupExpr, scoreExpr, c.output),
        join = false, outer = false, None,
        (Seq("rank") ++ origCols.map(_.name)).map(UnresolvedAttribute(_)),
        child
      )
    case other =>
      // Report an unexpected plan shape as an analysis error instead of a MatchError.
      throw new AnalysisException(
        s"each_top_k could not resolve the `group` and `score` columns; analyzed plan: $other")
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,40 +305,26 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {

// Compute top-1 rows for each group
checkAnswer(
testDf.each_top_k(1, "key", "score", "key", "value"),
Row(1, "a", "3") ::
Row(1, "b", "4") ::
Row(1, "c", "6") ::
Nil
)
checkAnswer(
testDf.each_top_k(lit(1), $"key", $"score", $"key", $"value"),
Row(1, "a", "3") ::
Row(1, "b", "4") ::
Row(1, "c", "6") ::
testDf.each_top_k(lit(1), $"key", $"score"),
Row(1, "a", "3", 0.8, Array(2, 5)) ::
Row(1, "b", "4", 0.3, Array(2)) ::
Row(1, "c", "6", 0.3, Array(1, 3)) ::
Nil
)

// Compute reverse top-1 rows for each group
checkAnswer(
testDf.each_top_k(-1, "key", "score", "key", "value"),
Row(1, "a", "1") ::
Row(1, "b", "5") ::
Row(1, "c", "6") ::
Nil
)
checkAnswer(
testDf.each_top_k(lit(-1), $"key", $"score", $"key", $"value"),
Row(1, "a", "1") ::
Row(1, "b", "5") ::
Row(1, "c", "6") ::
testDf.each_top_k(lit(-1), $"key", $"score"),
Row(1, "a", "1", 0.5, Array(0, 1, 2)) ::
Row(1, "b", "5", 0.1, Array(3)) ::
Row(1, "c", "6", 0.3, Array(1, 3)) ::
Nil
)

// Check if some exceptions thrown in case of some conditions
assert(intercept[AnalysisException] { testDf.each_top_k(lit(0.1), $"key", $"score") }
.getMessage contains "`k` must be integer, however")
assert(intercept[AnalysisException] { testDf.each_top_k(1, "key", "data") }
assert(intercept[AnalysisException] { testDf.each_top_k(lit(1), $"key", $"data") }
.getMessage contains "must have a comparable type")
}

Expand Down

0 comments on commit 4039f9e

Please sign in to comment.