This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

[HIVEMALL-182][SPARK][WIP] Add an optimizer rule to filter out columns with low variances #139

Open · wants to merge 1 commit into master
7 changes: 7 additions & 0 deletions spark/spark-2.3/pom.xml
@@ -46,6 +46,12 @@
<groupId>org.apache.hivemall</groupId>
<artifactId>hivemall-core</artifactId>
<scope>compile</scope>
<exclusions>
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hivemall</groupId>
@@ -105,6 +111,7 @@
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.3</version>
<scope>test</scope>
</dependency>
<dependency>
@@ -70,7 +70,7 @@ private object ExtractJoinTopKKeys extends Logging with PredicateHelper {
}
}

private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy {
private[sql] class UserProvidedLogicalPlans(val conf: SQLConf) extends Strategy {

override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractJoinTopKKeys(
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.hive

import scala.language.implicitConversions

import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry, ConfigReader}
import org.apache.spark.sql.internal.SQLConf

object HivemallConf {

/**
* Implicitly injects the [[HivemallConf]] into [[SQLConf]].
*/
implicit def SQLConfToHivemallConf(conf: SQLConf): HivemallConf = new HivemallConf(conf)

private val sqlConfEntries = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, ConfigEntry[_]]())

private def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized {
require(!sqlConfEntries.containsKey(entry.key),
s"Duplicate SQLConfigEntry. ${entry.key} has been registered")
sqlConfEntries.put(entry.key, entry)
}

// For testing only
// TODO: Need to add tests for the configurations
private[sql] def unregister(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized {
sqlConfEntries.remove(entry.key)
}

def buildConf(key: String): ConfigBuilder = ConfigBuilder(key).onCreate(register)

val FEATURE_SELECTION_ENABLED =
buildConf("spark.sql.optimizer.featureSelection.enabled")
.doc("Whether feature selections are applied in the optimizer")
.booleanConf
.createWithDefault(false)

val FEATURE_SELECTION_VARIANCE_THRESHOLD =
buildConf("spark.sql.optimizer.featureSelection.varianceThreshold")
.doc("Specifies the threshold of variances to filter out features")
.doubleConf
.createWithDefault(0.05)
}

class HivemallConf(conf: SQLConf) {
import HivemallConf._

private val reader = new ConfigReader(conf.settings)

def featureSelectionEnabled: Boolean = getConf(FEATURE_SELECTION_ENABLED)

def featureSelectionVarianceThreshold: Double = getConf(FEATURE_SELECTION_VARIANCE_THRESHOLD)

/** ********************** SQLConf functionality methods ************ */

/**
* Returns the value of the Hivemall configuration property for the given key. If the key is
* not set yet, returns the `defaultValue` in [[ConfigEntry]].
*/
private def getConf[T](entry: ConfigEntry[T]): T = {
require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
entry.readFrom(reader)
}
}
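
Note: a minimal usage sketch (hypothetical, not part of this diff) of how the implicit `SQLConfToHivemallConf` conversion above is meant to be exercised; `setConfString` is the standard `SQLConf` setter:

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.hive.HivemallConf._

val conf = new SQLConf()
conf.setConfString("spark.sql.optimizer.featureSelection.enabled", "true")
conf.setConfString("spark.sql.optimizer.featureSelection.varianceThreshold", "0.1")

// `conf` is implicitly wrapped in a HivemallConf at these call sites.
assert(conf.featureSelectionEnabled)
assert(conf.featureSelectionVarianceThreshold == 0.1)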
@@ -32,9 +32,11 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, Generate, JoinTopK, LogicalPlan}
import org.apache.spark.sql.execution.UserProvidedPlanner
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.UserProvidedLogicalPlans
import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.optimizer.VarianceThreshold
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -97,7 +99,11 @@ final class HivemallOps(df: DataFrame) extends Logging {
import internal.HivemallOpsImpl._

private lazy val _sparkSession = df.sparkSession
private lazy val _strategy = new UserProvidedPlanner(_sparkSession.sqlContext.conf)

private lazy val _userDefinedStrategies: Seq[Strategy] = Seq(
new UserProvidedLogicalPlans(_sparkSession.sqlContext.conf))
private lazy val _userDefinedOptimizations: Seq[Rule[LogicalPlan]] = Seq(
new VarianceThreshold(_sparkSession.sqlContext.conf))

/**
* @see [[hivemall.regression.GeneralRegressorUDTF]]
@@ -1151,8 +1157,8 @@ final class HivemallOps(df: DataFrame) extends Logging {
@inline private def withTypedPlanInCustomStrategy(logicalPlan: => LogicalPlan)
: DataFrame = {
// Inject custom strategies
if (!_sparkSession.experimental.extraStrategies.contains(_strategy)) {
_sparkSession.experimental.extraStrategies = Seq(_strategy)
if (_sparkSession.experimental.extraStrategies == Nil) {
_sparkSession.experimental.extraStrategies = _userDefinedStrategies
}
withTypedPlan(logicalPlan)
}
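
Note: the hunk that registers the new `_userDefinedOptimizations` is not visible in this diff view; presumably it mirrors the guarded strategy injection above, roughly (hypothetical sketch):

// Hypothetical sketch: register the optimizer rules the same guarded way.
if (_sparkSession.experimental.extraOptimizations.isEmpty) {
  _sparkSession.experimental.extraOptimizations = _userDefinedOptimizations
}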
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.optimizer

import org.apache.spark.sql.catalyst.plans.logical.{Histogram, LogicalPlan, Project, Statistics}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf


/**
* This optimizer rule removes features with low variance: any feature whose variance
* does not meet a given threshold is filtered out. The threshold is controlled by
* `spark.sql.optimizer.featureSelection.varianceThreshold` (0.05 by default).
*/
class VarianceThreshold(conf: SQLConf) extends Rule[LogicalPlan] {
import org.apache.spark.sql.hive.HivemallConf._

private def featureSelectionEnabled: Boolean = conf.featureSelectionEnabled
private def varianceThreshold: Double = conf.featureSelectionVarianceThreshold

private def hasColumnHistogram(s: Statistics): Boolean = {
s.attributeStats.exists { case (_, stat) =>
stat.histogram.isDefined
}
}

private def satisfyVarianceThreshold(histogramOption: Option[Histogram]): Boolean = {
// TODO: Since binary types are not supported in histograms but they could frequently appear
// in user schemas, it would be better to handle that case here.
histogramOption.forall { hist =>
// TODO: Make the value more precise by using `HistogramBin.ndv`
val dataSeq = hist.bins.map { bin => (bin.hi + bin.lo) / 2 }
val avg = dataSeq.sum / dataSeq.length
val variance = dataSeq.map { d => Math.pow(avg - d, 2.0) }.sum / dataSeq.length
varianceThreshold < variance
}
}

override def apply(plan: LogicalPlan): LogicalPlan = plan match {
case p if featureSelectionEnabled && hasColumnHistogram(p.stats) =>
val attributeStats = p.stats.attributeStats
val outputAttrs = p.output
val projectList = outputAttrs.map { a => (a, attributeStats.get(a)) }.flatMap {
case (_, Some(stat)) if !satisfyVarianceThreshold(stat.histogram) => None
case (attr, _) => Some(attr)
}
if (projectList != outputAttrs) {
Project(projectList, p)
} else {
p
}
case p => p
}
}
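
Note: to make the bin-midpoint variance estimate in `satisfyVarianceThreshold` concrete, here is a standalone sketch with made-up bin bounds (the real code reads `lo`/`hi` from Spark's `HistogramBin`):

// Made-up (lo, hi) bin bounds, for illustration only.
val bins = Seq((0.9, 1.0), (1.0, 1.1), (1.1, 1.2))
val midpoints = bins.map { case (lo, hi) => (lo + hi) / 2 }
val avg = midpoints.sum / midpoints.length
val variance = midpoints.map { d => math.pow(avg - d, 2.0) }.sum / midpoints.length
// variance ≈ 0.0067 here, so the default threshold (0.05) would prune this feature.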
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.optimizer.VarianceThreshold
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._


class FeatureSelectionRuleSuite extends SQLTestUtils with TestHiveSingleton {

import hiveContext.implicits._

protected override def beforeAll(): Unit = {
super.beforeAll()
// Sets user-defined optimization rules for feature selection
hiveContext.experimental.extraOptimizations = Seq(
new VarianceThreshold(hiveContext.conf))
}

test("filter out features with low variances") {
withSQLConf(
HivemallConf.FEATURE_SELECTION_ENABLED.key -> "true",
HivemallConf.FEATURE_SELECTION_VARIANCE_THRESHOLD.key -> "0.1",
SQLConf.CBO_ENABLED.key -> "true",
SQLConf.HISTOGRAM_ENABLED.key -> "true") {
withTable("t") {
withTempDir { dir =>
Seq((1, "one", 1.0, 1.0),
(1, "two", 1.1, 2.3),
(1, "three", 0.9, 3.5),
(1, "one", 0.9, 10.3))
.toDF("c0", "c1", "c2", "c3")
.write
.parquet(s"${dir.getAbsolutePath}/t")

spark.read.parquet(s"${dir.getAbsolutePath}/t").write.saveAsTable("t")

sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c0, c1, c2, c3")

// Filters out `c0` and `c2` because of their low variance
val optimizedPlan = sql("SELECT c0, * FROM t").queryExecution.optimizedPlan
assert(optimizedPlan.output.map(_.name) === Seq("c1", "c3"))
}
}
}
}
}
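
Note: for completeness, a hypothetical end-to-end sketch of wiring the rule into a session outside this test harness; it assumes code living under org.apache.spark.sql (as this suite does) so the session's SQLConf is accessible:

import org.apache.spark.sql.optimizer.VarianceThreshold

// Hypothetical wiring; this suite obtains the conf via hiveContext.conf instead.
spark.experimental.extraOptimizations = Seq(new VarianceThreshold(spark.sessionState.conf))
spark.sql("SET spark.sql.cbo.enabled=true")
spark.sql("SET spark.sql.statistics.histogram.enabled=true")
spark.sql("SET spark.sql.optimizer.featureSelection.enabled=true")
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c0, c1, c2, c3")
spark.sql("SELECT * FROM t").explain()  // low-variance columns pruned from the plan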
@@ -20,13 +20,10 @@
package org.apache.spark.sql.hive.test

import scala.collection.mutable.Seq
import scala.reflect.runtime.universe.TypeTag

import hivemall.tools.RegressionDatagen

import org.apache.spark.sql.{Column, QueryTest}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SQLTestUtils

/**