Refactor each_top_k
maropu committed Jan 26, 2017
1 parent 6e6b184 commit 98ff0b9
Showing 2 changed files with 30 additions and 54 deletions.
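
In short, the two previous `each_top_k` overloads (the String-based `each_top_k(k: Int, group: String, score: String, args: String*)` and the deprecated Column-based variant) are collapsed into a single `each_top_k(k: Column, group: Column, score: Column)`. The new method resolves `group` and `score` through the session's analyzer and returns a `rank` column followed by every column of the input DataFrame, rather than a caller-supplied projection; the tests are updated to match.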
@@ -60,6 +60,9 @@ final class HivemallOps(df: DataFrame) extends Logging {
import HivemallOps._
import HivemallUtils._

private[this] val _sparkSession = df.sparkSession
private[this] val _analyzer = _sparkSession.sessionState.analyzer
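// The analyzer above is reused in each_top_k below to resolve the group and
// score expressions against a freshly built logical plan.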

/**
* @see hivemall.regression.AdaDeltaUDTF
* @group regression
@@ -788,37 +791,24 @@ final class HivemallOps(df: DataFrame) extends Logging {
/**
* Returns `top-k` records for each `group`.
* @group misc
* @since 0.5.0
*/
def each_top_k(k: Int, group: String, score: String, args: String*)
: DataFrame = withTypedPlan {
val clusterDf = df.repartition(group).sortWithinPartitions(group)
val childrenAttributes = clusterDf.logicalPlan.output
val generator = Generate(
EachTopK(
k,
clusterDf.resolve(group),
clusterDf.resolve(score),
childrenAttributes
),
join = false, outer = false, None,
(Seq("rank") ++ childrenAttributes.map(_.name)).map(UnresolvedAttribute(_)),
clusterDf.logicalPlan)
val attributes = generator.generatedSet
val projectList = (Seq("rank") ++ args).map(s => attributes.find(_.name == s).get)
Project(projectList, generator)
}

@deprecated("use each_top_k(Int, String, String, String*) instead", "0.5.0")
def each_top_k(k: Column, group: Column, value: Column, args: Column*): DataFrame = {
def each_top_k(k: Column, group: Column, score: Column): DataFrame = withTypedPlan {
val kInt = k.expr match {
case Literal(v: Any, IntegerType) => v.asInstanceOf[Int]
case e => throw new AnalysisException("`k` must be integer, however " + e)
}
val groupStr = usePrettyExpression(group.expr).sql
val valueStr = usePrettyExpression(value.expr).sql
val argStrs = args.map(c => usePrettyExpression(c.expr).sql)
each_top_k(kInt, groupStr, valueStr, argStrs: _*)
val clusterDf = df.repartition(group).sortWithinPartitions(group)
val child = clusterDf.logicalPlan
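// Prepend the unresolved `group` and `score` expressions to the child's own
// output and run the analyzer, so both are resolved against `child`; pattern
// matching on the analyzed Project then recovers the resolved attributes
// that EachTopK needs.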
val logicalPlan = Project(group.named +: score.named +: child.output, child)
_analyzer.execute(logicalPlan) match {
case Project(group :: score :: origCols, c) =>
Generate(
EachTopK(kInt, group, score, c.output),
join = false, outer = false, None,
(Seq("rank") ++ origCols.map(_.name)).map(UnresolvedAttribute(_)),
clusterDf.logicalPlan
)
}
}
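
For reference, a minimal usage sketch of the refactored API, assuming the implicit conversions from `HivemallOps._` and `spark.implicits._` are in scope as in the test suite below; the column names `key` and `score` are placeholders:

import org.apache.spark.sql.functions.lit

// Top-2 rows per `key`, ordered by `score`. The result prepends a `rank`
// column and keeps every column of `df`; the refactored method no longer
// takes a projection list.
val top2 = df.each_top_k(lit(2), $"key", $"score")

// A negative k computes the reverse top-k (the tail of the ordering) per group.
val bottom1 = df.each_top_k(lit(-1), $"key", $"score")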

/**
@@ -19,7 +19,7 @@
package org.apache.spark.sql.hive

import org.apache.spark.sql.{AnalysisException, Column, Row}
import org.apache.spark.sql.functions
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HivemallGroupedDataset._
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.hive.HivemallUtils._
@@ -169,7 +169,8 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {

test("ftvec.hash") {
assert(DummyInputData.select(mhash("test")).count == DummyInputData.count)
assert(DummyInputData.select(sha1("test")).count == DummyInputData.count)
assert(DummyInputData.select(org.apache.spark.sql.hive.HivemallOps.sha1("test")).count ==
DummyInputData.count)
// assert(DummyInputData.select(array_hash_values(Seq("aaa", "bbb"))).count
// == DummyInputData.count)
// assert(DummyInputData.select(prefixed_hash_values(Seq("ccc", "ddd"), "prefix")).count
@@ -317,40 +318,26 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {

// Compute top-1 rows for each group
checkAnswer(
testDf.each_top_k(1, "key", "score", "key", "value"),
Row(1, "a", "3") ::
Row(1, "b", "4") ::
Row(1, "c", "6") ::
Nil
)
checkAnswer(
testDf.each_top_k(1, $"key", $"score", $"key", $"value"),
Row(1, "a", "3") ::
Row(1, "b", "4") ::
Row(1, "c", "6") ::
testDf.each_top_k(lit(1), $"key", $"score"),
Row(1, "a", "3", 0.8, Array(2, 5)) ::
Row(1, "b", "4", 0.3, Array(2)) ::
Row(1, "c", "6", 0.3, Array(1, 3)) ::
Nil
)

// Compute reverse top-1 rows for each group
checkAnswer(
testDf.each_top_k(-1, "key", "score", "key", "value"),
Row(1, "a", "1") ::
Row(1, "b", "5") ::
Row(1, "c", "6") ::
Nil
)
checkAnswer(
testDf.each_top_k(-1, $"key", $"score", $"key", $"value"),
Row(1, "a", "1") ::
Row(1, "b", "5") ::
Row(1, "c", "6") ::
testDf.each_top_k(-1, $"key", $"score"),
Row(1, "a", "1", 0.5, Array(0, 1, 2)) ::
Row(1, "b", "5", 0.1, Array(3)) ::
Row(1, "c", "6", 0.3, Array(1, 3)) ::
Nil
)
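
The expected rows above can be cross-checked against a plain window-function query. The following is only a behavioral sketch using stock Spark APIs, not the `EachTopK` generator plan the method actually builds:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

// Behavioral equivalent of each_top_k(lit(1), $"key", $"score"): rank rows
// within each `key` group by descending `score` and keep the best one.
val w = Window.partitionBy($"key").orderBy($"score".desc)
val top1 = testDf.withColumn("rank", row_number().over(w)).where($"rank" <= 1)

// For the reverse (k = -1) case, order ascending instead: $"score".asc.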

// Check that exceptions are thrown for invalid arguments
assert(intercept[AnalysisException] { testDf.each_top_k(0.1, $"key", $"score") }
assert(intercept[AnalysisException] { testDf.each_top_k(lit(0.1), $"key", $"score") }
.getMessage contains "`k` must be integer, however")
assert(intercept[AnalysisException] { testDf.each_top_k(1, "key", "data") }
assert(intercept[AnalysisException] { testDf.each_top_k(lit(1), $"key", $"data") }
.getMessage contains "must have a comparable type")
}

@@ -392,7 +379,6 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {

test("tools.array - select_k_best") {
import hiveContext.implicits._
import org.apache.spark.sql.functions._

val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
val df = data.map(d => (d, Seq(3, 1, 2))).toDF("features", "importance_list")
@@ -577,7 +563,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
* WARN Column: Constructing trivially true equals predicate, 'rowid#1323 = rowid#1323'.
* Perhaps you need to use aliases.
*/
.select($"rowid", functions.when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0))
.select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0))
.as("rowid", "predicted")

// Evaluation