Merge branch 'master' into SPARK-13579
Marcelo Vanzin committed Mar 31, 2016
2 parents ef19b3c + 8b207f3 commit 9275ea6
Showing 80 changed files with 1,158 additions and 8,466 deletions.
4 changes: 3 additions & 1 deletion dev/deps/spark-deps-hadoop-2.2
@@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar
 RoaringBitmap-0.5.11.jar
 ST4-4.0.4.jar
 activation-1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
 antlr4-runtime-4.5.2-1.jar
 aopalliance-1.0.jar
 apache-log4j-extras-1.2.17.jar
@@ -172,6 +173,7 @@ spire_2.11-0.7.4.jar
 stax-api-1.0-2.jar
 stax-api-1.0.1.jar
 stream-2.7.0.jar
+stringtemplate-3.2.1.jar
 super-csv-2.2.0.jar
 univocity-parsers-1.5.6.jar
 xbean-asm5-shaded-4.4.jar
4 changes: 3 additions & 1 deletion dev/deps/spark-deps-hadoop-2.3
@@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar
 RoaringBitmap-0.5.11.jar
 ST4-4.0.4.jar
 activation-1.1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
 antlr4-runtime-4.5.2-1.jar
 aopalliance-1.0.jar
 apache-log4j-extras-1.2.17.jar
@@ -163,6 +164,7 @@ spire_2.11-0.7.4.jar
 stax-api-1.0-2.jar
 stax-api-1.0.1.jar
 stream-2.7.0.jar
+stringtemplate-3.2.1.jar
 super-csv-2.2.0.jar
 univocity-parsers-1.5.6.jar
 xbean-asm5-shaded-4.4.jar
4 changes: 3 additions & 1 deletion dev/deps/spark-deps-hadoop-2.4
@@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar
 RoaringBitmap-0.5.11.jar
 ST4-4.0.4.jar
 activation-1.1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
 antlr4-runtime-4.5.2-1.jar
 aopalliance-1.0.jar
 apache-log4j-extras-1.2.17.jar
@@ -164,6 +165,7 @@ spire_2.11-0.7.4.jar
 stax-api-1.0-2.jar
 stax-api-1.0.1.jar
 stream-2.7.0.jar
+stringtemplate-3.2.1.jar
 super-csv-2.2.0.jar
 univocity-parsers-1.5.6.jar
 xbean-asm5-shaded-4.4.jar
4 changes: 3 additions & 1 deletion dev/deps/spark-deps-hadoop-2.6
@@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar
 RoaringBitmap-0.5.11.jar
 ST4-4.0.4.jar
 activation-1.1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
 antlr4-runtime-4.5.2-1.jar
 aopalliance-1.0.jar
 apache-log4j-extras-1.2.17.jar
@@ -170,6 +171,7 @@ spire_2.11-0.7.4.jar
 stax-api-1.0-2.jar
 stax-api-1.0.1.jar
 stream-2.7.0.jar
+stringtemplate-3.2.1.jar
 super-csv-2.2.0.jar
 univocity-parsers-1.5.6.jar
 xbean-asm5-shaded-4.4.jar
4 changes: 3 additions & 1 deletion dev/deps/spark-deps-hadoop-2.7
@@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar
 RoaringBitmap-0.5.11.jar
 ST4-4.0.4.jar
 activation-1.1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
 antlr4-runtime-4.5.2-1.jar
 aopalliance-1.0.jar
 apache-log4j-extras-1.2.17.jar
@@ -171,6 +172,7 @@ spire_2.11-0.7.4.jar
 stax-api-1.0-2.jar
 stax-api-1.0.1.jar
 stream-2.7.0.jar
+stringtemplate-3.2.1.jar
 super-csv-2.2.0.jar
 univocity-parsers-1.5.6.jar
 xbean-asm5-shaded-4.4.jar
mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -43,8 +43,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
     "Sizes of layers from input layer to output layer" +
       " E.g., Array(780, 100, 10) means 780 inputs, " +
       "one hidden layer with 100 neurons and output layer of 10 neurons.",
-    // TODO: how to check ALSO that all elements are greater than 0?
-    ParamValidators.arrayLengthGt(1)
+    (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1
   )
 
   /** @group getParam */
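The new validator above is strictly stronger than the `ParamValidators.arrayLengthGt(1)` it replaces: it still requires at least two layers (input and output), and it now also rejects any non-positive layer size, resolving the removed TODO. A minimal standalone sketch of the same check in plain Scala (no Spark dependency; `LayerValidation` and `isValidLayers` are names made up for illustration):

    object LayerValidation {
      // Mirrors the new check: at least two layers (input and output),
      // and every layer size strictly greater than zero.
      val isValidLayers: Array[Int] => Boolean =
        t => t.forall(_ > 0) && t.length > 1

      def main(args: Array[String]): Unit = {
        println(isValidLayers(Array(780, 100, 10))) // true
        println(isValidLayers(Array(780)))          // false: only one layer
        println(isValidLayers(Array(780, 0, 10)))   // false: a zero-size layer
      }
    }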
165 changes: 154 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,29 +21,37 @@ import java.util.UUID
 
 import scala.language.existentials
 
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject, _}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
-/**
- * Params for [[OneVsRest]].
- */
-private[ml] trait OneVsRestParams extends PredictorParams {
-
+private[ml] trait ClassifierTypeTrait {
   // scalastyle:off structural.type
   type ClassifierType = Classifier[F, E, M] forSome {
     type F
     type M <: ClassificationModel[F, M]
     type E <: Classifier[F, E, M]
   }
   // scalastyle:on structural.type
+}
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
 
   /**
    * param for the base binary classifier that we reduce multiclass classification into.
@@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams {
   def getClassifier: ClassifierType = $(classifier)
 }
 
+private[ml] object OneVsRestParams extends ClassifierTypeTrait {
+
+  def validateParams(instance: OneVsRestParams): Unit = {
+    def checkElement(elem: Params, name: String): Unit = elem match {
+      case stage: MLWritable => // good
+      case other =>
+        throw new UnsupportedOperationException("OneVsRest write will fail " +
+          s" because it contains $name which does not implement MLWritable." +
+          s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+    }
+
+    instance match {
+      case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model"))
+      case _ => // no need to check OneVsRest here
+    }
+
+    checkElement(instance.getClassifier, "classifier")
+  }
+
+  def saveImpl(
+      path: String,
+      instance: OneVsRestParams,
+      sc: SparkContext,
+      extraMetadata: Option[JObject] = None): Unit = {
+
+    val params = instance.extractParamMap().toSeq
+    val jsonParams = render(params
+      .filter { case ParamPair(p, v) => p.name != "classifier" }
+      .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
+      .toList)
+
+    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+    val classifierPath = new Path(path, "classifier").toString
+    instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
+  }
+
+  def loadImpl(
+      path: String,
+      sc: SparkContext,
+      expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {
+
+    val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+    val classifierPath = new Path(path, "classifier").toString
+    val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
+    (metadata, estimator)
+  }
+}
+
 /**
  * :: Experimental ::
  * Model produced by [[OneVsRest]].
@@ -73,10 +130,10 @@ private[ml] trait OneVsRestParams extends PredictorParams {
 @Since("1.4.0")
 @Experimental
 final class OneVsRestModel private[ml] (
-    @Since("1.4.0") override val uid: String,
-    @Since("1.4.0") labelMetadata: Metadata,
+    @Since("1.4.0") override val uid: String,
+    private[ml] val labelMetadata: Metadata,
     @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
-  extends Model[OneVsRestModel] with OneVsRestParams {
+  extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
@@ -143,6 +200,56 @@ final class OneVsRestModel (
       uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
 }
+
+@Since("2.0.0")
+object OneVsRestModel extends MLReadable[OneVsRestModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): OneVsRestModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[OneVsRestModel]] */
+  private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter {
+
+    OneVsRestParams.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit = {
+      val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
+        ("numClasses" -> instance.models.length)
+      OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
+      instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) =>
+        val modelPath = new Path(path, s"model_$idx").toString
+        model.save(modelPath)
+      }
+    }
+  }
+
+  private class OneVsRestModelReader extends MLReader[OneVsRestModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[OneVsRestModel].getName
+
+    override def load(path: String): OneVsRestModel = {
+      implicit val format = DefaultFormats
+      val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+      val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
+      val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+      val models = Range(0, numClasses).toArray.map { idx =>
+        val modelPath = new Path(path, s"model_$idx").toString
+        DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
+      }
+      val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
+      DefaultParamsReader.getAndSetParams(ovrModel, metadata)
+      ovrModel.set("classifier", classifier)
+      ovrModel
+    }
+  }
+}
 
 /**
@@ -158,7 +265,7 @@ final class OneVsRestModel (
 @Experimental
 final class OneVsRest @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
-  extends Estimator[OneVsRestModel] with OneVsRestParams {
+  extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("oneVsRest"))
@@ -243,4 +350,40 @@ final class OneVsRest @Since("1.4.0") (
     }
     copied
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new OneVsRest.OneVsRestWriter(this)
 }
+
+@Since("2.0.0")
+object OneVsRest extends MLReadable[OneVsRest] {
+
+  @Since("2.0.0")
+  override def read: MLReader[OneVsRest] = new OneVsRestReader
+
+  @Since("2.0.0")
+  override def load(path: String): OneVsRest = super.load(path)
+
+  /** [[MLWriter]] instance for [[OneVsRest]] */
+  private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter {
+
+    OneVsRestParams.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit = {
+      OneVsRestParams.saveImpl(path, instance, sc)
+    }
+  }
+
+  private class OneVsRestReader extends MLReader[OneVsRest] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[OneVsRest].getName
+
+    override def load(path: String): OneVsRest = {
+      val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+      val ovr = new OneVsRest(metadata.uid)
+      DefaultParamsReader.getAndSetParams(ovr, metadata)
+      ovr.setClassifier(classifier)
+    }
+  }
+}
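Together, the writer/reader pairs above give both the estimator and the model standard ML persistence: `saveImpl` writes the base classifier under <path>/classifier and, for the model, one sub-model directory per class (model_0, model_1, ...) plus labelMetadata/numClasses as extra JSON metadata. A minimal round-trip sketch of the resulting public API, assuming a local Spark context of this era; the toy DataFrame and the /tmp paths are illustrative, not part of the change:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}

    val sc = new SparkContext(new SparkConf().setAppName("ovr-demo").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Tiny three-class toy set; features are mllib Vectors at this point in history.
    val df = sc.parallelize(Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (2.0, Vectors.dense(1.0, 1.0))
    )).toDF("label", "features")

    // Estimator round trip: params plus the nested classifier are persisted.
    val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
    ovr.write.overwrite().save("/tmp/spark-ovr")
    val restoredOvr = OneVsRest.load("/tmp/spark-ovr")

    // Model round trip: one sub-model directory per class.
    val model: OneVsRestModel = ovr.fit(df)
    model.write.overwrite().save("/tmp/spark-ovr-model")
    val restoredModel = OneVsRestModel.load("/tmp/spark-ovr-model")

Note that the writers call OneVsRestParams.validateParams first, so saving fails fast with an UnsupportedOperationException if the base classifier (or any sub-model) does not itself implement MLWritable.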