From e79981ec6f83785d9bc6a28ee5e1eb5945daf202 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 30 Mar 2018 07:26:40 +0900
Subject: [PATCH] Add an optimizer rule to filter out columns with low variances

---
 spark/spark-2.3/pom.xml                                          |  7 ++
 ...r.scala => UserProvidedLogicalPlans.scala}                    |  2 +-
 .../apache/spark/sql/hive/HivemallConf.scala                     | 83 +++++++++++++++
 .../apache/spark/sql/hive/HivemallOps.scala                      | 14 +++-
 .../sql/optimizer/VarianceThreshold.scala                        | 71 ++++++++++++
 .../sql/hive/FeatureSelectionRuleSuite.scala                     | 68 +++++++++++
 .../hive/test/HivemallFeatureQueryTest.scala                     |  5 +-
 7 files changed, 241 insertions(+), 9 deletions(-)
 rename spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/{UserProvidedPlanner.scala => UserProvidedLogicalPlans.scala} (97%)
 create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallConf.scala
 create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/optimizer/VarianceThreshold.scala
 create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/FeatureSelectionRuleSuite.scala

diff --git a/spark/spark-2.3/pom.xml b/spark/spark-2.3/pom.xml
index cfa64579d..46a22bdcf 100644
--- a/spark/spark-2.3/pom.xml
+++ b/spark/spark-2.3/pom.xml
@@ -46,6 +46,12 @@
       <groupId>org.apache.hivemall</groupId>
       <artifactId>hivemall-core</artifactId>
       <scope>compile</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>io.netty</groupId>
+          <artifactId>netty-all</artifactId>
+        </exclusion>
+      </exclusions>
     </dependency>
     <dependency>
       <groupId>org.apache.hivemall</groupId>
@@ -105,6 +111,7 @@
     <dependency>
       <groupId>org.scalatest</groupId>
       <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <version>3.0.3</version>
       <scope>test</scope>
     </dependency>
diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedLogicalPlans.scala
similarity index 97%
rename from spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala
rename to spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedLogicalPlans.scala
index 09d60a645..c12aeca0c 100644
--- a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala
+++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedLogicalPlans.scala
@@ -70,7 +70,7 @@ private object ExtractJoinTopKKeys extends Logging with PredicateHelper {
   }
 }

-private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy {
+private[sql] class UserProvidedLogicalPlans(val conf: SQLConf) extends Strategy {

   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case ExtractJoinTopKKeys(
diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallConf.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallConf.scala
new file mode 100644
index 000000000..dad42f981
--- /dev/null
+++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallConf.scala
@@ -0,0 +1,83 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.hive

import scala.language.implicitConversions

import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry, ConfigReader}
import org.apache.spark.sql.internal.SQLConf

object HivemallConf {

  /**
   * Implicitly injects the [[HivemallConf]] into [[SQLConf]].
   */
  implicit def SQLConfToHivemallConf(conf: SQLConf): HivemallConf = new HivemallConf(conf)

  private val sqlConfEntries = java.util.Collections.synchronizedMap(
    new java.util.HashMap[String, ConfigEntry[_]]())

  private def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized {
    require(!sqlConfEntries.containsKey(entry.key),
      s"Duplicate SQLConfigEntry. ${entry.key} has been registered")
    sqlConfEntries.put(entry.key, entry)
  }

  // For testing only
  // TODO: Need to add tests for the configurations
  private[sql] def unregister(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized {
    sqlConfEntries.remove(entry.key)
  }

  def buildConf(key: String): ConfigBuilder = ConfigBuilder(key).onCreate(register)

  val FEATURE_SELECTION_ENABLED =
    buildConf("spark.sql.optimizer.featureSelection.enabled")
      .doc("Whether feature selection is applied in the optimizer")
      .booleanConf
      .createWithDefault(false)

  val FEATURE_SELECTION_VARIANCE_THRESHOLD =
    buildConf("spark.sql.optimizer.featureSelection.varianceThreshold")
      .doc("The variance threshold used to filter out low-variance features")
      .doubleConf
      .createWithDefault(0.05)
}

class HivemallConf(conf: SQLConf) {
  import HivemallConf._

  private val reader = new ConfigReader(conf.settings)

  def featureSelectionEnabled: Boolean = getConf(FEATURE_SELECTION_ENABLED)

  def featureSelectionVarianceThreshold: Double = getConf(FEATURE_SELECTION_VARIANCE_THRESHOLD)

  /** ********************** SQLConf functionality methods ************ */

  /**
   * Returns the value of the Hivemall configuration property for the given key. If the key is
   * not set yet, returns the `defaultValue` in [[ConfigEntry]].
+ */ + private def getConf[T](entry: ConfigEntry[T]): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + entry.readFrom(reader) + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index 94bcfd62b..6164464ca 100644 --- a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -32,9 +32,11 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, Generate, JoinTopK, LogicalPlan} -import org.apache.spark.sql.execution.UserProvidedPlanner +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.UserProvidedLogicalPlans import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.optimizer.VarianceThreshold import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -97,7 +99,11 @@ final class HivemallOps(df: DataFrame) extends Logging { import internal.HivemallOpsImpl._ private lazy val _sparkSession = df.sparkSession - private lazy val _strategy = new UserProvidedPlanner(_sparkSession.sqlContext.conf) + + private lazy val _userDefinedStrategies: Seq[Strategy] = Seq( + new UserProvidedLogicalPlans(_sparkSession.sqlContext.conf)) + private lazy val _userDefinedOptimizations: Seq[Rule[LogicalPlan]] = Seq( + new VarianceThreshold(_sparkSession.sqlContext.conf)) /** * @see [[hivemall.regression.GeneralRegressorUDTF]] @@ -1151,8 +1157,8 @@ final class HivemallOps(df: DataFrame) extends Logging { @inline private def withTypedPlanInCustomStrategy(logicalPlan: => LogicalPlan) : DataFrame = { // Inject custom strategies - if (!_sparkSession.experimental.extraStrategies.contains(_strategy)) { - _sparkSession.experimental.extraStrategies = Seq(_strategy) + if (_sparkSession.experimental.extraStrategies == Nil) { + _sparkSession.experimental.extraStrategies = _userDefinedStrategies } withTypedPlan(logicalPlan) } diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/optimizer/VarianceThreshold.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/optimizer/VarianceThreshold.scala new file mode 100644 index 000000000..8457979c6 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/optimizer/VarianceThreshold.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.catalyst.plans.logical.{Histogram, LogicalPlan, Project, Statistics} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + + +/** + * This optimizer rule removes features with low variance; it removes all features whose + * variance doesn't meet some threshold. You can control this threshold by + * `spark.sql.optimizer.featureSelection.varianceThreshold` (0.05 by default). + */ +class VarianceThreshold(conf: SQLConf) extends Rule[LogicalPlan] { + import org.apache.spark.sql.hive.HivemallConf._ + + private def featureSelectionEnabled: Boolean = conf.featureSelectionEnabled + private def varianceThreshold: Double = conf.featureSelectionVarianceThreshold + + private def hasColumnHistogram(s: Statistics): Boolean = { + s.attributeStats.exists { case (_, stat) => + stat.histogram.isDefined + } + } + + private def satisfyVarianceThreshold(histgramOption: Option[Histogram]): Boolean = { + // TODO: Since binary types are not supported in histograms but they could frequently appear + // in user schemas, we would be better to handle the case here. + histgramOption.forall { hist => + // TODO: Make the value more precise by using `HistogramBin.ndv` + val dataSeq = hist.bins.map { bin => (bin.hi + bin.lo) / 2 } + val avg = dataSeq.sum / dataSeq.length + val variance = dataSeq.map { d => Math.pow(avg - d, 2.0) }.sum / dataSeq.length + varianceThreshold < variance + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case p if featureSelectionEnabled && hasColumnHistogram(p.stats) => + val attributeStats = p.stats.attributeStats + val outputAttrs = p.output + val projectList = outputAttrs.zip(outputAttrs.map { a => attributeStats.get(a)}).flatMap { + case (_, Some(stat)) if !satisfyVarianceThreshold(stat.histogram) => None + case (attr, _) => Some(attr) + } + if (projectList != outputAttrs) { + Project(projectList, p) + } else { + p + } + case p => p + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/FeatureSelectionRuleSuite.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/FeatureSelectionRuleSuite.scala new file mode 100644 index 000000000..2a39158d4 --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/FeatureSelectionRuleSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.optimizer.VarianceThreshold +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + + +class FeatureSelectionRuleSuite extends SQLTestUtils with TestHiveSingleton { + + import hiveContext.implicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + // Sets user-defined optimization rules for feature selection + hiveContext.experimental.extraOptimizations = Seq( + new VarianceThreshold(hiveContext.conf)) + } + + test("filter out features with low variances") { + withSQLConf( + HivemallConf.FEATURE_SELECTION_ENABLED.key -> "true", + HivemallConf.FEATURE_SELECTION_VARIANCE_THRESHOLD.key -> "0.1", + SQLConf.CBO_ENABLED.key -> "true", + SQLConf.HISTOGRAM_ENABLED.key -> "true") { + withTable("t") { + withTempDir { dir => + Seq((1, "one", 1.0, 1.0), + (1, "two", 1.1, 2.3), + (1, "three", 0.9, 3.5), + (1, "one", 0.9, 10.3)) + .toDF("c0", "c1", "c2", "c3") + .write + .parquet(s"${dir.getAbsolutePath}/t") + + spark.read.parquet(s"${dir.getAbsolutePath}/t").write.saveAsTable("t") + + sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c0, c1, c2, c3") + + // Filters out `c0` and `c2` because of low variances + val optimizedPlan = sql("SELECT c0, * FROM t").queryExecution.optimizedPlan + assert(optimizedPlan.output.map(_.name) === Seq("c1", "c3")) + } + } + } + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala index bc656d100..625e71d3f 100644 --- a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala @@ -20,13 +20,10 @@ package org.apache.spark.sql.hive.test import scala.collection.mutable.Seq -import scala.reflect.runtime.universe.TypeTag import hivemall.tools.RegressionDatagen -import org.apache.spark.sql.{Column, QueryTest} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SQLTestUtils /**