Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into memory-tracking-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 29, 2015
2 parents ed25d3b + e78ec1a commit 57c9b4e
Show file tree
Hide file tree
Showing 44 changed files with 927 additions and 388 deletions.
2 changes: 1 addition & 1 deletion R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~' and '+'.
#' operators are supported, including '~', '+', '-', and '.'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
Expand Down
8 changes: 8 additions & 0 deletions R/pkg/inst/tests/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ test_that("predictions match with native glm", {
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})

# Checks that a formula combining '.', '-' and '+ 0' yields the same predictions from glm
# fitted on a DataFrame built from iris as from native R glm fitted on iris itself.
test_that("dot minus and intercept vs native glm", {
# DataFrame copy of iris, used both for fitting and for prediction.
training <- createDataFrame(sqlContext, iris)
# NOTE(review): the label is spelled Sepal_Width here but Sepal.Width in the native call below;
# presumably createDataFrame renames dotted column names -- confirm.
model <- glm(Sepal_Width ~ . - Species + 0, data = training)
vals <- collect(select(predict(model, training), "prediction"))
# Reference predictions from native R using the equivalent formula.
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
# Every prediction must agree with the reference within an absolute tolerance of 1e-6.
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
52 changes: 18 additions & 34 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
/** @group getParam */
def getFormula: String = $(formula)

/**
 * Reports whether the configured formula asks for an intercept to be fitted.
 * A formula must already have been set; otherwise the precondition below fails.
 */
private[ml] def hasIntercept: Boolean = {
  require(parsedFormula.nonEmpty, "Must call setFormula() first.")
  // The precondition guarantees a value is present, so this just reads its flag.
  parsedFormula.exists(_.hasIntercept)
}

override def fit(dataset: DataFrame): RFormulaModel = {
require(parsedFormula.isDefined, "Must call setFormula() first.")
val resolvedFormula = parsedFormula.get.resolve(dataset.schema)
// StringType terms and terms representing interactions need to be encoded before assembly.
// TODO(ekl) add support for feature interactions
var encoderStages = ArrayBuffer[PipelineStage]()
var tempColumns = ArrayBuffer[String]()
val encodedTerms = parsedFormula.get.terms.map { term =>
val encoderStages = ArrayBuffer[PipelineStage]()
val tempColumns = ArrayBuffer[String]()
val encodedTerms = resolvedFormula.terms.map { term =>
dataset.schema(term) match {
case column if column.dataType == StringType =>
val indexCol = term + "_idx_" + uid
Expand All @@ -103,7 +110,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
.setOutputCol($(featuresCol))
encoderStages += new ColumnPruner(tempColumns.toSet)
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this))
copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
}

// optimistic schema; does not contain any ML attributes
Expand All @@ -124,13 +131,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
/**
* :: Experimental ::
* A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
* @param parsedFormula a pre-parsed R formula.
* @param resolvedFormula the fitted R formula.
* @param pipelineModel the fitted feature model, including factor to index mappings.
*/
@Experimental
class RFormulaModel private[feature](
override val uid: String,
parsedFormula: ParsedRFormula,
resolvedFormula: ResolvedRFormula,
pipelineModel: PipelineModel)
extends Model[RFormulaModel] with RFormulaBase {

Expand All @@ -144,8 +151,8 @@ class RFormulaModel private[feature](
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(schema)) {
withFeatures
} else if (schema.exists(_.name == parsedFormula.label)) {
val nullable = schema(parsedFormula.label).dataType match {
} else if (schema.exists(_.name == resolvedFormula.label)) {
val nullable = schema(resolvedFormula.label).dataType match {
case _: NumericType | BooleanType => false
case _ => true
}
Expand All @@ -158,12 +165,12 @@ class RFormulaModel private[feature](
}

override def copy(extra: ParamMap): RFormulaModel = copyValues(
new RFormulaModel(uid, parsedFormula, pipelineModel))
new RFormulaModel(uid, resolvedFormula, pipelineModel))

override def toString: String = s"RFormulaModel(${parsedFormula})"
override def toString: String = s"RFormulaModel(${resolvedFormula})"

private def transformLabel(dataset: DataFrame): DataFrame = {
val labelName = parsedFormula.label
val labelName = resolvedFormula.label
if (hasLabelCol(dataset.schema)) {
dataset
} else if (dataset.schema.exists(_.name == labelName)) {
Expand Down Expand Up @@ -207,26 +214,3 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {

override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
}

/**
* Represents a parsed R formula.
*/
private[ml] case class ParsedRFormula(label: String, terms: Seq[String])

/**
* Limited implementation of R formula parsing. Currently supports: '~', '+'.
*/
private[ml] object RFormulaParser extends RegexParsers {
def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r

def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }

def formula: Parser[ParsedRFormula] =
(term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) }

def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
case Success(result, _) => result
case failure: NoSuccess => throw new IllegalArgumentException(
"Could not parse formula: " + value)
}
}
129 changes: 129 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import scala.util.parsing.combinator.RegexParsers

import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types._

/**
 * Represents a parsed R formula.
 *
 * @param label reference to the response column (left-hand side of '~')
 * @param terms right-hand side terms, in the order they appear in the formula
 */
private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
  /**
   * Resolves formula terms into column names. A schema is necessary for inferring the meaning
   * of the special '.' term. Duplicate terms will be removed during resolution.
   *
   * Terms are applied left to right, so a deletion only affects terms added before it.
   *
   * @param schema schema of the dataset the formula will be applied to
   * @return the label name together with the de-duplicated list of included column names
   */
  def resolve(schema: StructType): ResolvedRFormula = {
    // Columns that '.' stands for: every simple-typed column except the label.
    // Lazy so the schema is only inspected when the formula actually uses '.'.
    lazy val dotTerms = simpleTypes(schema).filter(_ != label.value)
    var includedTerms = Seq[String]()
    terms.foreach {
      case Dot =>
        includedTerms ++= dotTerms
      case ColumnRef(value) =>
        includedTerms :+= value
      case Deletion(term: Term) =>
        term match {
          case ColumnRef(value) =>
            includedTerms = includedTerms.filter(_ != value)
          case Dot =>
            // e.g. "- .", which removes all first-order terms, i.e. exactly the columns
            // that "+ ." would have added. (Previously this kept them and dropped the rest.)
            includedTerms = includedTerms.filterNot(name => dotTerms.contains(name))
          case _: Deletion =>
            assert(false, "Deletion terms cannot be nested")
          case _: Intercept =>
            // intercept toggles do not affect the set of columns; see hasIntercept
        }
      case _: Intercept =>
        // intercept toggles do not affect the set of columns; see hasIntercept
    }
    ResolvedRFormula(label.value, includedTerms.distinct)
  }

  /**
   * Whether this formula specifies fitting with an intercept term.
   *
   * The intercept is on by default; the last intercept toggle in the formula wins.
   * Deleting an intercept term inverts it (e.g. "- 1" disables, "- 0" enables).
   */
  def hasIntercept: Boolean = terms.foldLeft(true) {
    case (_, Intercept(enabled)) => enabled
    case (_, Deletion(Intercept(enabled))) => !enabled
    case (current, _) => current
  }

  // the dot operator excludes complex column types
  private def simpleTypes(schema: StructType): Seq[String] = {
    schema.fields.filter(_.dataType match {
      case _: NumericType | StringType | BooleanType | _: VectorUDT => true
      case _ => false
    }).map(_.name)
  }
}

/**
 * Represents a fully evaluated and simplified R formula.
 *
 * @param label name of the label column
 * @param terms names of the columns included as terms; ParsedRFormula.resolve passes a
 *              de-duplicated list here
 */
private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])

/**
 * R formula terms. See the R formula docs here for more information:
 * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
 *
 * The trait is sealed, so Dot, ColumnRef, Intercept and Deletion are the only kinds of term.
 */
private[ml] sealed trait Term

/** R formula reference to all available columns, e.g. "." in a formula */
private[ml] case object Dot extends Term

/**
 * R formula reference to a column, e.g. "+ Species" in a formula
 *
 * @param value the column name as written in the formula
 */
private[ml] case class ColumnRef(value: String) extends Term

/**
 * R formula intercept toggle, e.g. "+ 0" in a formula
 *
 * @param enabled true when the term was written as "1", false when written as "0"
 */
private[ml] case class Intercept(enabled: Boolean) extends Term

/**
 * R formula deletion of a variable, e.g. "- Species" in a formula
 *
 * @param term the term being removed; the parser wraps the operand of each "-" in a Deletion
 */
private[ml] case class Deletion(term: Term) extends Term

/**
 * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'.
 */
private[ml] object RFormulaParser extends RegexParsers {
  /** Intercept toggle: the literal "1" enables the intercept, "0" disables it. */
  def intercept: Parser[Intercept] =
    "([01])".r ^^ { digit => Intercept(digit == "1") }

  /** A column name: starts with a letter, or with '.' followed by a letter or underscore. */
  def columnRef: Parser[ColumnRef] =
    "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { name => ColumnRef(name) }

  /**
   * A single operand. columnRef is tried before the bare '.' so that a name such as ".x"
   * is read as a column rather than as Dot followed by unparsed input.
   */
  def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { _ => Dot }

  /**
   * One operand followed by any number of "+ operand" / "- operand" pairs. Operands joined
   * with "-" are wrapped in Deletion; source order is preserved.
   */
  def terms: Parser[List[Term]] = (term ~ rep(("+" | "-") ~ term)) ^^ {
    case first ~ rest =>
      // Single pass over the parsed pairs; avoids repeatedly appending to the end of a List.
      first :: rest.map {
        case "-" ~ operand => Deletion(operand)
        case _ ~ operand => operand
      }
  }

  /** A complete formula: response column, '~', then the right-hand side terms. */
  def formula: Parser[ParsedRFormula] =
    ((columnRef <~ "~") ~ terms) ^^ { case response ~ rhs => ParsedRFormula(response, rhs) }

  /**
   * Parses a formula string.
   *
   * @param value the formula text, e.g. "y ~ . - x + 0"
   * @return the parsed formula
   * @throws IllegalArgumentException if the whole input is not a valid formula; the message
   *                                  includes the parser's own diagnostic
   */
  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
    case Success(result, _) => result
    case failure: NoSuccess =>
      // Keep the original message prefix and append the parser's diagnostic so callers can
      // see why the formula was rejected.
      throw new IllegalArgumentException(
        "Could not parse formula: " + value + " (" + failure.msg + ")")
  }
}
10 changes: 8 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ private[r] object SparkRWrappers {
alpha: Double): PipelineModel = {
val formula = new RFormula().setFormula(value)
val estimator = family match {
case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha)
case "gaussian" => new LinearRegression()
.setRegParam(lambda)
.setElasticNetParam(alpha)
.setFitIntercept(formula.hasIntercept)
case "binomial" => new LogisticRegression()
.setRegParam(lambda)
.setElasticNetParam(alpha)
.setFitIntercept(formula.hasIntercept)
}
val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.api.python

import java.util.{List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
import org.apache.spark.mllib.clustering.GaussianMixtureModel

/**
 * Adapter around a GaussianMixtureModel that exposes its contents through helper members
 * intended for use from Python.
 */
private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
  val weights: Vector = Vectors.dense(model.weights)
  val k: Int = weights.size

  /**
   * The mixture components as a two-element Java list: first an array holding the `mu` of
   * each of the k components, then an array holding the `sigma` of each component.
   */
  val gaussians: JList[Object] = {
    val components = model.gaussians
    val means: Array[Vector] = Array.tabulate[Vector](k)(idx => components(idx).mu)
    val covariances: Array[Matrix] = Array.tabulate[Matrix](k)(idx => components(idx).sigma)
    List[Object](means, covariances).asJava
  }

  /** Delegates to the wrapped model's own save. */
  def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
}
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable {
seed: java.lang.Long,
initialModelWeights: java.util.ArrayList[Double],
initialModelMu: java.util.ArrayList[Vector],
initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = {
val gmmAlg = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
Expand All @@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable {
if (seed != null) gmmAlg.setSeed(seed)

try {
val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
var wt = ArrayBuffer.empty[Double]
var mu = ArrayBuffer.empty[Vector]
var sigma = ArrayBuffer.empty[Matrix]
for (i <- 0 until model.k) {
wt += model.weights(i)
mu += model.gaussians(i).mu
sigma += model.gaussians(i).sigma
}
List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
} finally {
data.rdd.unpersist(blocking = false)
}
Expand Down
Loading

0 comments on commit 57c9b4e

Please sign in to comment.