Skip to content

Commit

Permalink
removing everything except for simple class hierarchy for classification
Browse files (browse the repository at this point in the history)
  • Loading branch information
jkbradley committed Feb 5, 2015
1 parent d35bb5d commit 52f4fde
Show file tree
Hide file tree
Showing 17 changed files with 8 additions and 717 deletions.
6 changes: 2 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@ import org.apache.spark.mllib.linalg.Vector
*/
case class LabeledPoint(label: Double, features: Vector, weight: Double) {

  /** Auxiliary constructor: delegates to the primary constructor with a weight of 1.0. */
  def this(label: Double, features: Vector) = this(label, features, 1.0)

  /** Renders this point as `(label,features,weight)`. */
  override def toString: String = "(%s,%s,%s)".format(label, features, weight)
}

object LabeledPoint {

  /**
   * Factory which sets the instance weight to 1.0.
   *
   * Calls the primary three-argument constructor directly, so it does not rely on any
   * auxiliary constructor being present on the class.
   *
   * @param label the point's label
   * @param features the point's feature vector
   * @return a new [[LabeledPoint]] with the given label and features and a weight of 1.0
   */
  def apply(label: Double, features: Vector): LabeledPoint =
    new LabeledPoint(label, features, 1.0)
}
209 changes: 0 additions & 209 deletions mllib/src/main/scala/org/apache/spark/ml/classification/AdaBoost.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.evaluation.ClassificationEvaluator
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.ml._
import org.apache.spark.ml.impl.estimator.{HasDefaultEvaluator, PredictionModel, Predictor,
PredictorParams}
import org.apache.spark.rdd.RDD

@AlphaComponent
private[classification] trait ClassifierParams extends PredictorParams
Expand All @@ -33,10 +29,9 @@ private[classification] trait ClassifierParams extends PredictorParams
*/
abstract class Classifier[Learner <: Classifier[Learner, M], M <: ClassificationModel[M]]
extends Predictor[Learner, M]
with ClassifierParams
with HasDefaultEvaluator {
with ClassifierParams {

override def defaultEvaluator: Evaluator = new ClassificationEvaluator
// TODO: defaultEvaluator (follow-up PR)
}


Expand All @@ -60,14 +55,6 @@ private[ml] abstract class ClassificationModel[M <: ClassificationModel[M]]
*/
def predictRaw(features: Vector): Vector

/**
 * Compute this model's accuracy on the given dataset.
 *
 * Each point is paired with its prediction in a single `map`, so no `zip` of two
 * separately derived RDDs is needed.
 *
 * @param dataset labeled points to evaluate; instance weights are currently ignored
 * @return the value `ClassificationEvaluator.computeMetric` yields for the "accuracy" metric
 */
def accuracy(dataset: RDD[LabeledPoint]): Double = {
  // TODO: Handle instance weights.
  val predictionsAndLabels = dataset.map(lp => (predict(lp.features), lp.label))
  ClassificationEvaluator.computeMetric(predictionsAndLabels, "accuracy")
}
// TODO: accuracy(dataset: RDD[LabeledPoint]): Double (follow-up PR)

}
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ private[classification] trait LogisticRegressionParams extends ClassifierParams
class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams {

// TODO: Extend IterativeEstimator

setRegParam(0.1)
setMaxIter(100)
setThreshold(0.5)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType

Expand Down Expand Up @@ -57,16 +56,8 @@ class BinaryClassificationEvaluator extends Evaluator with Params
.map { case Row(score: Double, label: Double) =>
(score, label)
}
BinaryClassificationEvaluator.computeMetric(scoreAndLabels, map(metricName))
}

}

private[ml] object BinaryClassificationEvaluator {

def computeMetric(scoreAndLabels: RDD[(Double, Double)], metricName: String): Double = {
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val metric = metricName match {
val metric = map(metricName) match {
case "areaUnderROC" =>
metrics.areaUnderROC()
case "areaUnderPR" =>
Expand All @@ -77,5 +68,4 @@ private[ml] object BinaryClassificationEvaluator {
metrics.unpersist()
metric
}

}
Loading

0 comments on commit 52f4fde

Please sign in to comment.