[SPARK-1636][MLLIB] Move main methods to examples
* `NaiveBayes` -> `SparseNaiveBayes`
* `KMeans` -> `DenseKMeans`
* `SVMWithSGD` and `LogisticRegressionWithSGD` -> `BinaryClassification`
* `ALS` -> `MovieLensALS`
* `LinearRegressionWithSGD`, `LassoWithSGD`, and `RidgeRegressionWithSGD` -> `LinearRegression`
* `DecisionTree` -> `DecisionTreeRunner`

`scopt` is used for parsing command-line parameters; it is MIT-licensed and depends only on `scala-library`. A minimal sketch of its builder API follows the example help message below.

Example help message:

~~~
BinaryClassification: an example app for binary classification.
Usage: BinaryClassification [options] <input>

  --numIterations <value>
        number of iterations
  --stepSize <value>
        initial step size, default: 1.0
  --algorithm <value>
        algorithm (SVM,LR), default: LR
  --regType <value>
        regularization type (L1,L2), default: L2
  --regParam <value>
        regularization parameter, default: 0.1
  <input>
        input paths to labeled examples in LIBSVM format
~~~
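
For readers unfamiliar with `scopt`, here is a minimal, self-contained sketch of the 3.x builder API that produces help output like the above. The `Config` case class and `ScoptSketch` object are illustrative only and not part of this commit; the apps below use per-app `Params` case classes instead:

~~~
import scopt.OptionParser

// Illustrative config; defaults double as the values shown in help text.
case class Config(input: String = null, numIterations: Int = 100)

object ScoptSketch {
  def main(args: Array[String]) {
    val parser = new OptionParser[Config]("ScoptSketch") {
      head("ScoptSketch: a minimal scopt example.")
      opt[Int]("numIterations")
        .text("number of iterations")
        .action((x, c) => c.copy(numIterations = x))
      arg[String]("<input>")
        .required()
        .text("input path")
        .action((x, c) => c.copy(input = x))
    }
    // parse returns None on bad arguments, after scopt prints the usage text.
    parser.parse(args, Config()).map { config =>
      println(s"Running with $config")
    } getOrElse {
      sys.exit(1)
    }
  }
}
~~~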

Author: Xiangrui Meng <meng@databricks.com>

Closes #584 from mengxr/mllib-main and squashes the following commits:

7b58c60 [Xiangrui Meng] minor
6e35d7e [Xiangrui Meng] make imports explicit and fix code style
c6178c9 [Xiangrui Meng] update TS PCA/SVD to use new spark-submit
6acff75 [Xiangrui Meng] use scopt for DecisionTreeRunner
be86069 [Xiangrui Meng] use main instead of extending App
b3edf68 [Xiangrui Meng] move DecisionTree's main method to examples
8bfaa5a [Xiangrui Meng] change NaiveBayesParams to Params
fe23dcb [Xiangrui Meng] remove main from KMeans and add DenseKMeans as an example
67f4448 [Xiangrui Meng] remove main methods from linear regression algorithms and add LinearRegression example
b066bbc [Xiangrui Meng] remove main from ALS and add MovieLensALS example
b040f3b [Xiangrui Meng] change BinaryClassificationParams to Params
577945b [Xiangrui Meng] remove unused imports from NB
3d299bc [Xiangrui Meng] remove main from LR/SVM and add an example app for binary classification
f70878e [Xiangrui Meng] remove main from NaiveBayes and add an example NaiveBayes app
01ec2cd [Xiangrui Meng] Merge branch 'master' into mllib-main
9420692 [Xiangrui Meng] add scopt to examples dependencies
mengxr authored and rxin committed Apr 29, 2014
1 parent 497be3c commit 3f38334
Showing 19 changed files with 795 additions and 321 deletions.
5 changes: 5 additions & 0 deletions examples/pom.xml
@@ -166,6 +166,11 @@
        </exclusion>
      </exclusions>
    </dependency>
    <dependency>
      <groupId>com.github.scopt</groupId>
      <artifactId>scopt_${scala.binary.version}</artifactId>
      <version>3.2.0</version>
    </dependency>
  </dependencies>

  <build>
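For reference, an sbt build would pull in the same library with a single line (version matching the Maven dependency above; shown for illustration only):

~~~
libraryDependencies += "com.github.scopt" %% "scopt" % "3.2.0"
~~~

The `%%` operator appends the Scala binary version, mirroring `scopt_${scala.binary.version}` in the Maven artifact name.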
145 changes: 145 additions & 0 deletions examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
@@ -0,0 +1,145 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import org.apache.log4j.{Level, Logger}
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
import org.apache.spark.mllib.evaluation.binary.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater}

/**
 * An example app for binary classification. Run with
 * {{{
 * ./bin/run-example org.apache.spark.examples.mllib.BinaryClassification
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object BinaryClassification {

  object Algorithm extends Enumeration {
    type Algorithm = Value
    val SVM, LR = Value
  }

  object RegType extends Enumeration {
    type RegType = Value
    val L1, L2 = Value
  }

  import Algorithm._
  import RegType._

  case class Params(
      input: String = null,
      numIterations: Int = 100,
      stepSize: Double = 1.0,
      algorithm: Algorithm = LR,
      regType: RegType = L2,
      regParam: Double = 0.1)

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("BinaryClassification") {
      head("BinaryClassification: an example app for binary classification.")
      opt[Int]("numIterations")
        .text("number of iterations")
        .action((x, c) => c.copy(numIterations = x))
      opt[Double]("stepSize")
        .text(s"initial step size, default: ${defaultParams.stepSize}")
        .action((x, c) => c.copy(stepSize = x))
      opt[String]("algorithm")
        .text(s"algorithm (${Algorithm.values.mkString(",")}), " +
          s"default: ${defaultParams.algorithm}")
        .action((x, c) => c.copy(algorithm = Algorithm.withName(x)))
      opt[String]("regType")
        .text(s"regularization type (${RegType.values.mkString(",")}), " +
          s"default: ${defaultParams.regType}")
        .action((x, c) => c.copy(regType = RegType.withName(x)))
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaultParams.regParam}")
        .action((x, c) => c.copy(regParam = x))
      arg[String]("<input>")
        .required()
        .text("input paths to labeled examples in LIBSVM format")
        .action((x, c) => c.copy(input = x))
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    } getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params) {
    val conf = new SparkConf().setAppName(s"BinaryClassification with $params")
    val sc = new SparkContext(conf)

    Logger.getRootLogger.setLevel(Level.WARN)

    val examples = MLUtils.loadLibSVMData(sc, params.input).cache()

    val splits = examples.randomSplit(Array(0.8, 0.2))
    val training = splits(0).cache()
    val test = splits(1).cache()

    val numTraining = training.count()
    val numTest = test.count()
    println(s"Training: $numTraining, test: $numTest.")

    examples.unpersist(blocking = false)

    val updater = params.regType match {
      case L1 => new L1Updater()
      case L2 => new SquaredL2Updater()
    }

    val model = params.algorithm match {
      case LR =>
        val algorithm = new LogisticRegressionWithSGD()
        algorithm.optimizer
          .setNumIterations(params.numIterations)
          .setStepSize(params.stepSize)
          .setUpdater(updater)
          .setRegParam(params.regParam)
        algorithm.run(training).clearThreshold()
      case SVM =>
        val algorithm = new SVMWithSGD()
        algorithm.optimizer
          .setNumIterations(params.numIterations)
          .setStepSize(params.stepSize)
          .setUpdater(updater)
          .setRegParam(params.regParam)
        algorithm.run(training).clearThreshold()
    }

    val prediction = model.predict(test.map(_.features))
    val predictionAndLabel = prediction.zip(test.map(_.label))

    val metrics = new BinaryClassificationMetrics(predictionAndLabel)

    println(s"Test areaUnderPR = ${metrics.areaUnderPR()}.")
    println(s"Test areaUnderROC = ${metrics.areaUnderROC()}.")

    sc.stop()
  }
}
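
A sample invocation of the new app (the dataset path is illustrative; any LIBSVM-format file works):

~~~
./bin/run-example org.apache.spark.examples.mllib.BinaryClassification \
  --algorithm SVM --regType L1 --regParam 0.01 \
  path/to/sample_libsvm_data.txt
~~~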
161 changes: 161 additions & 0 deletions examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -0,0 +1,161 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

/**
 * An example runner for decision trees. Run with
 * {{{
 * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object DecisionTreeRunner {

  object ImpurityType extends Enumeration {
    type ImpurityType = Value
    val Gini, Entropy, Variance = Value
  }

  import ImpurityType._

  case class Params(
      input: String = null,
      algo: Algo = Classification,
      maxDepth: Int = 5,
      impurity: ImpurityType = Gini,
      maxBins: Int = 20)

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("DecisionTreeRunner") {
      head("DecisionTreeRunner: an example decision tree app.")
      opt[String]("algo")
        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
        .action((x, c) => c.copy(algo = Algo.withName(x)))
      opt[String]("impurity")
        .text(s"impurity type (${ImpurityType.values.mkString(",")}), " +
          s"default: ${defaultParams.impurity}")
        .action((x, c) => c.copy(impurity = ImpurityType.withName(x)))
      opt[Int]("maxDepth")
        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
        .action((x, c) => c.copy(maxDepth = x))
      opt[Int]("maxBins")
        .text(s"max number of bins, default: ${defaultParams.maxBins}")
        .action((x, c) => c.copy(maxBins = x))
      arg[String]("<input>")
        .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
        .required()
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        if (params.algo == Classification &&
            (params.impurity == Gini || params.impurity == Entropy)) {
          success
        } else if (params.algo == Regression && params.impurity == Variance) {
          success
        } else {
          failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
        }
      }
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params) {
    val conf = new SparkConf().setAppName("DecisionTreeRunner")
    val sc = new SparkContext(conf)

    // Load training data and cache it.
    val examples = MLUtils.loadLabeledData(sc, params.input).cache()

    val splits = examples.randomSplit(Array(0.8, 0.2))
    val training = splits(0).cache()
    val test = splits(1).cache()

    val numTraining = training.count()
    val numTest = test.count()

    println(s"numTraining = $numTraining, numTest = $numTest.")

    examples.unpersist(blocking = false)

    val impurityCalculator = params.impurity match {
      case Gini => impurity.Gini
      case Entropy => impurity.Entropy
      case Variance => impurity.Variance
    }

    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
    val model = DecisionTree.train(training, strategy)

    if (params.algo == Classification) {
      val accuracy = accuracyScore(model, test)
      println(s"Test accuracy = $accuracy.")
    }

    if (params.algo == Regression) {
      val mse = meanSquaredError(model, test)
      println(s"Test mean squared error = $mse.")
    }

    sc.stop()
  }

  /**
   * Calculates the classifier accuracy.
   */
  private def accuracyScore(
      model: DecisionTreeModel,
      data: RDD[LabeledPoint],
      threshold: Double = 0.5): Double = {
    def predictedValue(features: Vector): Double = {
      if (model.predict(features) < threshold) 0.0 else 1.0
    }
    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
    val count = data.count()
    correctCount.toDouble / count
  }

  /**
   * Calculates the mean squared error for regression.
   */
  private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
    data.map { y =>
      val err = tree.predict(y.features) - y.label
      err * err
    }.mean()
  }
}
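
A sample invocation (the dataset path is illustrative; the file must use the dense format described above):

~~~
./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner \
  --algo Classification --impurity Entropy --maxDepth 4 \
  path/to/dense_labeled_data.txt
~~~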