Fixed bug in PipelineModel.transform* with usage of params. Updated CrossValidatorExample to use more training examples so it is less likely to get a 0-size fold.
jkbradley committed Dec 3, 2014
1 parent ea34dc6 commit 99f88c2
Showing 4 changed files with 34 additions and 22 deletions.
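
Why the extra training examples matter: with only 4 documents and numFolds = 2 (set later in the example via crossval.setNumFolds(2)), a random split can easily leave one fold empty, which breaks evaluation. A minimal sketch of the failure mode, in plain Scala, assuming a simplified per-example uniform fold assignment rather than MLlib's actual kFold/randomSplit logic:

// Sketch only: estimates how often a uniformly random fold assignment
// leaves some fold empty. Not the actual MLlib kFold implementation.
object EmptyFoldSketch {
  def main(args: Array[String]): Unit = {
    val rng = new scala.util.Random(42)
    def emptyFoldRate(n: Int, k: Int, trials: Int = 100000): Double = {
      val bad = (1 to trials).count { _ =>
        val folds = Array.fill(n)(rng.nextInt(k))   // assign each example to a fold
        (0 until k).exists(f => !folds.contains(f)) // did some fold get no examples?
      }
      bad.toDouble / trials
    }
    println(f"n = 4,  k = 2: ${emptyFoldRate(4, 2)}%.4f")  // ~0.125, i.e. 2 * (1/2)^4
    println(f"n = 12, k = 2: ${emptyFoldRate(12, 2)}%.4f") // ~0.0005: 12 examples are far safer
  }
}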
examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.examples.ml;
 
-import java.util.ArrayList;
 import java.util.List;
 
 import com.google.common.collect.Lists;
@@ -28,7 +27,6 @@
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineStage;
 import org.apache.spark.ml.classification.LogisticRegression;
-import org.apache.spark.ml.classification.LogisticRegressionModel;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.HashingTF;
 import org.apache.spark.ml.feature.Tokenizer;
@@ -65,7 +63,15 @@ public static void main(String[] args) {
       new LabeledDocument(0L, "a b c d e spark", 1.0),
       new LabeledDocument(1L, "b d", 0.0),
       new LabeledDocument(2L, "spark f g h", 1.0),
-      new LabeledDocument(3L, "hadoop mapreduce", 0.0));
+      new LabeledDocument(3L, "hadoop mapreduce", 0.0),
+      new LabeledDocument(4L, "b spark who", 1.0),
+      new LabeledDocument(5L, "g d a y", 0.0),
+      new LabeledDocument(6L, "spark fly", 1.0),
+      new LabeledDocument(7L, "was mapreduce", 0.0),
+      new LabeledDocument(8L, "e spark program", 1.0),
+      new LabeledDocument(9L, "a e c l", 0.0),
+      new LabeledDocument(10L, "spark compile", 1.0),
+      new LabeledDocument(11L, "hadoop software", 0.0));
     JavaSchemaRDD training =
         jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
 
@@ -112,8 +118,8 @@ public static void main(String[] args) {
       new Document(7L, "apache hadoop"));
     JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
 
-    // Make predictions on test documents.
-    lrModel.transform(test).registerAsTable("prediction");
+    // Make predictions on test documents. cvModel uses the best model found (lrModel).
+    cvModel.transform(test).registerAsTable("prediction");
     JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
     for (Row r: predictions.collect()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.examples.ml
 
 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
 import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
@@ -50,7 +51,15 @@ object CrossValidatorExample {
       LabeledDocument(0L, "a b c d e spark", 1.0),
       LabeledDocument(1L, "b d", 0.0),
       LabeledDocument(2L, "spark f g h", 1.0),
-      LabeledDocument(3L, "hadoop mapreduce", 0.0)))
+      LabeledDocument(3L, "hadoop mapreduce", 0.0),
+      LabeledDocument(4L, "b spark who", 1.0),
+      LabeledDocument(5L, "g d a y", 0.0),
+      LabeledDocument(6L, "spark fly", 1.0),
+      LabeledDocument(7L, "was mapreduce", 0.0),
+      LabeledDocument(8L, "e spark program", 1.0),
+      LabeledDocument(9L, "a e c l", 0.0),
+      LabeledDocument(10L, "spark compile", 1.0),
+      LabeledDocument(11L, "hadoop software", 0.0)))
 
     // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
     val tokenizer = new Tokenizer()
@@ -81,16 +90,7 @@ object CrossValidatorExample {
     crossval.setNumFolds(2)
 
     // Run cross-validation, and choose the best set of parameters.
-    val cvModel = try {
-      crossval.fit(training)
-    } catch {
-      case e: Exception =>
-        println("\nSTACK TRACE\n")
-        println(e.getStackTraceString)
-        println("\nSTACK TRACE OF CAUSE\n")
-        println(e.getCause.getStackTraceString)
-        throw e
-    }
+    val cvModel = crossval.fit(training)
     // Get the best LogisticRegression model (with the best set of parameters from paramGrid).
     val lrModel = cvModel.bestModel
 
@@ -101,8 +101,8 @@ object CrossValidatorExample {
       Document(6L, "mapreduce spark"),
       Document(7L, "apache hadoop")))
 
-    // Make predictions on test documents using the best LogisticRegression model.
-    lrModel.transform(test)
+    // Make predictions on test documents. cvModel uses the best model found (lrModel).
+    cvModel.transform(test)
       .select('id, 'text, 'score, 'prediction)
      .collect()
      .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
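
The switch from lrModel.transform to cvModel.transform works because a CrossValidatorModel simply delegates to the best model it found during fitting; combined with the Pipeline fix below, that model now applies its own fitting parameters when transforming. A minimal sketch of that delegation shape, using plain Scala stand-ins rather than the Spark API:

// Sketch only: a CrossValidatorModel-style wrapper forwarding to its best model.
trait SimpleTransform { def transform(input: Seq[String]): Seq[String] }

final class CvModelSketch(val bestModel: SimpleTransform) extends SimpleTransform {
  // Delegation: predicting through the wrapper is the same as using bestModel.
  def transform(input: Seq[String]): Seq[String] = bestModel.transform(input)
}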
10 changes: 7 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -162,11 +162,15 @@ class PipelineModel private[ml] (
   }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
-    transformSchema(dataset.schema, paramMap, logging = true)
-    stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
+    // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
+    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+    transformSchema(dataset.schema, map, logging = true)
+    stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
+    // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
+    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+    stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
   }
 }
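
This is the core bug fix: transform previously used only the caller-supplied paramMap, ignoring the parameters the pipeline was fitted with. The fix builds one merged map before transforming. In ParamMap (as with Scala's immutable Map), a ++ b lets b win on key collisions, so composing the lowest-precedence map first produces the documented order. A minimal sketch of the precedence, with plain Scala Maps standing in for org.apache.spark.ml.param.ParamMap:

// Sketch only: later operands of ++ override earlier ones,
// so compose lowest precedence first.
object ParamPrecedenceSketch {
  def main(args: Array[String]): Unit = {
    val fittingParamMap = Map("maxIter" -> "10", "regParam" -> "0.1") // set at fit time
    val modelParamMap   = Map("regParam" -> "0.01")                   // this.paramMap
    val callParamMap    = Map("maxIter" -> "100")                     // paramMap argument

    // Precedence: paramMap > this.paramMap > fittingParamMap
    val merged = (fittingParamMap ++ modelParamMap) ++ callParamMap
    println(merged) // Map(maxIter -> 100, regParam -> 0.01)
  }
}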
mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -92,7 +92,9 @@ private[spark] object BLAS extends Serializable with Logging {
   * dot(x, y)
   */
  def dot(x: Vector, y: Vector): Double = {
-    require(x.size == y.size)
+    require(x.size == y.size,
+      "BLAS.dot(x: Vector, y: Vector) was given Vectors with non-matching sizes:" +
+      " x.size = " + x.size + ", y.size = " + y.size)
     (x, y) match {
       case (dx: DenseVector, dy: DenseVector) =>
         dot(dx, dy)
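
The BLAS change only improves the failure message: a bare require(x.size == y.size) throws an IllegalArgumentException with no detail, while the new message reports both sizes. A minimal sketch of the behavior, with plain arrays standing in for MLlib Vectors since BLAS is private[spark]:

// Sketch only: a descriptive require message makes size mismatches diagnosable.
object RequireMessageSketch {
  def main(args: Array[String]): Unit = {
    val x = Array(1.0, 2.0, 3.0)
    val y = Array(1.0, 2.0)
    require(x.length == y.length,
      s"dot(x, y) was given arrays with non-matching sizes: x.size = ${x.length}, y.size = ${y.length}")
    // Throws: java.lang.IllegalArgumentException: requirement failed:
    //   dot(x, y) was given arrays with non-matching sizes: x.size = 3, y.size = 2
  }
}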
