
Add an optimizer rule to filter out columns with low variances
maropu committed Apr 2, 2018
1 parent bd14314 commit e79981e
Showing 7 changed files with 241 additions and 9 deletions.
7 changes: 7 additions & 0 deletions spark/spark-2.3/pom.xml
@@ -46,6 +46,12 @@
      <groupId>org.apache.hivemall</groupId>
      <artifactId>hivemall-core</artifactId>
      <scope>compile</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>io.netty</groupId>
+          <artifactId>netty-all</artifactId>
+        </exclusion>
+      </exclusions>
    </dependency>
    <dependency>
      <groupId>org.apache.hivemall</groupId>
@@ -105,6 +111,7 @@
    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <version>3.0.3</version>
      <scope>test</scope>
    </dependency>
    <dependency>
@@ -70,7 +70,7 @@ private object ExtractJoinTopKKeys extends Logging with PredicateHelper {
  }
}

-private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy {
+private[sql] class UserProvidedLogicalPlans(val conf: SQLConf) extends Strategy {

  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case ExtractJoinTopKKeys(
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.hive

import scala.language.implicitConversions

import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry, ConfigReader}
import org.apache.spark.sql.internal.SQLConf

object HivemallConf {

  /**
   * Implicitly injects the [[HivemallConf]] into [[SQLConf]].
   */
  implicit def SQLConfToHivemallConf(conf: SQLConf): HivemallConf = new HivemallConf(conf)

  private val sqlConfEntries = java.util.Collections.synchronizedMap(
    new java.util.HashMap[String, ConfigEntry[_]]())

  private def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized {
    require(!sqlConfEntries.containsKey(entry.key),
      s"Duplicate SQLConfigEntry. ${entry.key} has been registered")
    sqlConfEntries.put(entry.key, entry)
  }

  // For testing only
  // TODO: Need to add tests for the configurations
  private[sql] def unregister(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized {
    sqlConfEntries.remove(entry.key)
  }

  def buildConf(key: String): ConfigBuilder = ConfigBuilder(key).onCreate(register)

  val FEATURE_SELECTION_ENABLED =
    buildConf("spark.sql.optimizer.featureSelection.enabled")
      .doc("Whether feature selection is applied in the optimizer")
      .booleanConf
      .createWithDefault(false)

  val FEATURE_SELECTION_VARIANCE_THRESHOLD =
    buildConf("spark.sql.optimizer.featureSelection.varianceThreshold")
      .doc("Specifies the variance threshold used to filter out features")
      .doubleConf
      .createWithDefault(0.05)
}

class HivemallConf(conf: SQLConf) {
  import HivemallConf._

  private val reader = new ConfigReader(conf.settings)

  def featureSelectionEnabled: Boolean = getConf(FEATURE_SELECTION_ENABLED)

  def featureSelectionVarianceThreshold: Double = getConf(FEATURE_SELECTION_VARIANCE_THRESHOLD)

  /** ********************** SQLConf functionality methods ************ */

  /**
   * Returns the value of the Hivemall configuration property for the given key. If the key is
   * not set yet, returns the `defaultValue` in [[ConfigEntry]].
   */
  private def getConf[T](entry: ConfigEntry[T]): T = {
    require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
    entry.readFrom(reader)
  }
}
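
As a usage sketch (not part of this commit's diff): the object above registers its entries through ConfigBuilder.onCreate(register), and the implicit SQLConfToHivemallConf conversion then exposes typed readers on any SQLConf. A minimal, hedged example, assuming Spark 2.3 classes on the classpath; the standalone `new SQLConf()` is for illustration only, where a real session's conf would normally be used:

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.hive.HivemallConf._

val conf = new SQLConf()
conf.setConfString("spark.sql.optimizer.featureSelection.enabled", "true")
conf.setConfString("spark.sql.optimizer.featureSelection.varianceThreshold", "0.1")

// `featureSelectionEnabled` is not a member of SQLConf itself; the implicit
// conversion wraps `conf` in a HivemallConf, which reads the registered entry.
assert(conf.featureSelectionEnabled)
assert(conf.featureSelectionVarianceThreshold == 0.1)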
@@ -32,9 +32,11 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, Generate, JoinTopK, LogicalPlan}
-import org.apache.spark.sql.execution.UserProvidedPlanner
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.UserProvidedLogicalPlans
import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.optimizer.VarianceThreshold
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -97,7 +99,11 @@ final class HivemallOps(df: DataFrame) extends Logging {
  import internal.HivemallOpsImpl._

  private lazy val _sparkSession = df.sparkSession
-  private lazy val _strategy = new UserProvidedPlanner(_sparkSession.sqlContext.conf)
+
+  private lazy val _userDefinedStrategies: Seq[Strategy] = Seq(
+    new UserProvidedLogicalPlans(_sparkSession.sqlContext.conf))
+  private lazy val _userDefinedOptimizations: Seq[Rule[LogicalPlan]] = Seq(
+    new VarianceThreshold(_sparkSession.sqlContext.conf))

/**
* @see [[hivemall.regression.GeneralRegressorUDTF]]
@@ -1151,8 +1157,8 @@ final class HivemallOps(df: DataFrame) extends Logging {
  @inline private def withTypedPlanInCustomStrategy(logicalPlan: => LogicalPlan)
    : DataFrame = {
    // Inject custom strategies
-    if (!_sparkSession.experimental.extraStrategies.contains(_strategy)) {
-      _sparkSession.experimental.extraStrategies = Seq(_strategy)
+    if (_sparkSession.experimental.extraStrategies == Nil) {
+      _sparkSession.experimental.extraStrategies = _userDefinedStrategies
    }
    withTypedPlan(logicalPlan)
  }
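
For context (an illustration, not part of the diff): extraStrategies and extraOptimizations are the two public hooks on SparkSession.experimental that this commit targets. A hedged sketch of wiring both in follows; the helper name installHivemallExtensions is hypothetical, and the code is assumed to live in an org.apache.spark.sql subpackage (as HivemallOps does), since SQLContext.conf is private[sql]:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.UserProvidedLogicalPlans
import org.apache.spark.sql.optimizer.VarianceThreshold

def installHivemallExtensions(spark: SparkSession): Unit = {
  val conf = spark.sqlContext.conf  // private[sql]; visible only from org.apache.spark.sql packages
  // Strategies extend physical planning; optimizer rules rewrite the logical plan.
  spark.experimental.extraStrategies = Seq(new UserProvidedLogicalPlans(conf))
  spark.experimental.extraOptimizations = Seq(new VarianceThreshold(conf))
}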
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.optimizer

import org.apache.spark.sql.catalyst.plans.logical.{Histogram, LogicalPlan, Project, Statistics}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf


/**
 * This optimizer rule removes features whose variance does not meet a given threshold.
 * The threshold can be controlled via
 * `spark.sql.optimizer.featureSelection.varianceThreshold` (0.05 by default).
 */
class VarianceThreshold(conf: SQLConf) extends Rule[LogicalPlan] {
  import org.apache.spark.sql.hive.HivemallConf._

  private def featureSelectionEnabled: Boolean = conf.featureSelectionEnabled
  private def varianceThreshold: Double = conf.featureSelectionVarianceThreshold

  private def hasColumnHistogram(s: Statistics): Boolean = {
    s.attributeStats.exists { case (_, stat) =>
      stat.histogram.isDefined
    }
  }

  private def satisfyVarianceThreshold(histogramOption: Option[Histogram]): Boolean = {
    // TODO: Binary types are not supported in histograms, but they can frequently appear
    // in user schemas, so it would be better to handle that case here.
    histogramOption.forall { hist =>
      // TODO: Make the estimate more precise by using `HistogramBin.ndv`
      val dataSeq = hist.bins.map { bin => (bin.hi + bin.lo) / 2 }
      val avg = dataSeq.sum / dataSeq.length
      val variance = dataSeq.map { d => Math.pow(avg - d, 2.0) }.sum / dataSeq.length
      varianceThreshold < variance
    }
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
    case p if featureSelectionEnabled && hasColumnHistogram(p.stats) =>
      val attributeStats = p.stats.attributeStats
      val outputAttrs = p.output
      val projectList = outputAttrs.zip(outputAttrs.map { a => attributeStats.get(a) }).flatMap {
        case (_, Some(stat)) if !satisfyVarianceThreshold(stat.histogram) => None
        case (attr, _) => Some(attr)
      }
      if (projectList != outputAttrs) {
        Project(projectList, p)
      } else {
        p
      }
    case p => p
  }
}
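
To make the estimate concrete, here is a self-contained sketch (not part of the diff) of the same midpoint-variance computation over a hypothetical three-bin histogram, with bin bounds invented for illustration:

// Each bin is reduced to its midpoint; the rule keeps a column only if the
// population variance of the midpoints exceeds the configured threshold.
val bins = Seq((0.9, 1.0), (1.0, 1.1), (1.1, 1.2))  // hypothetical (lo, hi) bounds
val mids = bins.map { case (lo, hi) => (lo + hi) / 2 }  // 0.95, 1.05, 1.15
val avg = mids.sum / mids.length  // 1.05
val variance = mids.map { m => Math.pow(avg - m, 2.0) }.sum / mids.length  // ~0.0067
// 0.0067 < 0.05 (the default threshold), so this column would be projected away.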
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.optimizer.VarianceThreshold
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._


class FeatureSelectionRuleSuite extends SQLTestUtils with TestHiveSingleton {

  import hiveContext.implicits._

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    // Sets user-defined optimization rules for feature selection
    hiveContext.experimental.extraOptimizations = Seq(
      new VarianceThreshold(hiveContext.conf))
  }

  test("filter out features with low variances") {
    withSQLConf(
        HivemallConf.FEATURE_SELECTION_ENABLED.key -> "true",
        HivemallConf.FEATURE_SELECTION_VARIANCE_THRESHOLD.key -> "0.1",
        SQLConf.CBO_ENABLED.key -> "true",
        SQLConf.HISTOGRAM_ENABLED.key -> "true") {
      withTable("t") {
        withTempDir { dir =>
          Seq((1, "one", 1.0, 1.0),
            (1, "two", 1.1, 2.3),
            (1, "three", 0.9, 3.5),
            (1, "one", 0.9, 10.3))
            .toDF("c0", "c1", "c2", "c3")
            .write
            .parquet(s"${dir.getAbsolutePath}/t")

          spark.read.parquet(s"${dir.getAbsolutePath}/t").write.saveAsTable("t")

          sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c0, c1, c2, c3")

          // Filters out `c0` and `c2` because of low variances
          val optimizedPlan = sql("SELECT c0, * FROM t").queryExecution.optimizedPlan
          assert(optimizedPlan.output.map(_.name) === Seq("c1", "c3"))
        }
      }
    }
  }
}
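
Outside the test harness, roughly the same flow looks as follows; a hedged sketch that assumes a Hive-enabled SparkSession named `spark`, an existing analyzed table `t`, and that the VarianceThreshold rule has already been installed via extraOptimizations as in beforeAll above:

// Histograms require CBO statistics; both flags ship with Spark 2.3.
spark.sql("SET spark.sql.cbo.enabled=true")
spark.sql("SET spark.sql.statistics.histogram.enabled=true")
spark.sql("SET spark.sql.optimizer.featureSelection.enabled=true")
spark.sql("SET spark.sql.optimizer.featureSelection.varianceThreshold=0.1")
// Column histograms must exist before the rule can estimate variances.
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c0, c1, c2, c3")
// Low-variance columns should no longer appear in the optimized plan's output.
spark.sql("SELECT * FROM t").queryExecution.optimizedPlan.output.map(_.name)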
@@ -20,13 +20,10 @@
package org.apache.spark.sql.hive.test

import scala.collection.mutable.Seq
-import scala.reflect.runtime.universe.TypeTag

import hivemall.tools.RegressionDatagen

-import org.apache.spark.sql.{Column, QueryTest}
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SQLTestUtils

/**
