[SPARK-34356][ML] OVR transform fix potential column conflict #31472

Closed

zhengruifeng
Contributor

What changes were proposed in this pull request?

1. Clear predictionCol and probabilityCol on each sub-model and use a temporary rawPrediction column, to avoid potential column conflicts;
2. Use an array instead of a map, to stay in line with the Python side;
3. Simplify transform.
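
As a rough illustration of the array-based approach in item 2, here is a minimal plain-Scala sketch with no Spark dependency. It is not the code in this PR: `OvrArraySketch`, `BinaryModel`, `accumulate`, and `predict` are hypothetical names, and each sub-model is assumed to return a two-element raw prediction whose index 1 is the positive-class score.

```scala
object OvrArraySketch {
  // Hypothetical stand-in for a fitted binary sub-model: maps features to a
  // two-element raw prediction (index 0 = negative class, index 1 = positive class).
  type BinaryModel = Array[Double] => Array[Double]

  // Append each sub-model's positive-class score to an array, in class order,
  // so the array index doubles as the class label (no map keyed by class needed).
  def accumulate(models: Seq[BinaryModel], features: Array[Double]): Array[Double] =
    models.foldLeft(Array.emptyDoubleArray) { (acc, model) => acc :+ model(features)(1) }

  // The predicted label is the index of the largest accumulated score.
  // Assumes at least one sub-model; `max` throws on an empty array.
  def predict(models: Seq[BinaryModel], features: Array[Double]): Double = {
    val scores = accumulate(models, features)
    scores.indexOf(scores.max).toDouble
  }

  def main(args: Array[String]): Unit = {
    // Three toy sub-models; model i simply scores feature i.
    val models: Seq[BinaryModel] = Seq(
      f => Array(-f(0), f(0)),
      f => Array(-f(1), f(1)),
      f => Array(-f(2), f(2))
    )
    println(predict(models, Array(0.1, 0.7, 0.2)))
  }
}
```

Because the scores are appended in class order, the accumulator needs no keys, which is the simplification an array offers over a map here.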

Why are the changes needed?

If the input dataset has a column with the same name as predictionCol, probabilityCol, or rawPredictionCol, transform will fail.
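
The failure mode can be modelled without Spark. The sketch below is a simplified stand-in, not Spark's actual code: a schema is reduced to a vector of column names, and `SchemaConflictSketch`, `appendColumn`, and `subModelSchema` are illustrative names. It assumes, as Spark's `SchemaUtils.appendColumn` does to my understanding, that an empty column name means no column is appended.

```scala
object SchemaConflictSketch {
  // Simplified model of a schema: just the ordered column names.
  type Schema = Vector[String]

  // Appending a column whose name already exists fails a `require`.
  // An empty name is treated as "this output column is disabled".
  def appendColumn(schema: Schema, colName: String): Schema =
    if (colName.isEmpty) schema
    else {
      require(!schema.contains(colName), s"Column $colName already exists.")
      schema :+ colName
    }

  // Hypothetical schema transform of one binary sub-model: it tries to append
  // its three output columns to whatever the input dataset already contains.
  def subModelSchema(
      input: Schema,
      rawPredictionCol: String,
      probabilityCol: String,
      predictionCol: String): Schema =
    Seq(rawPredictionCol, probabilityCol, predictionCol).foldLeft(input)(appendColumn)

  def main(args: Array[String]): Unit = {
    val input: Schema = Vector("label", "features", "probability")

    // Default output names collide with the user's "probability" column.
    val conflict = scala.util.Try(
      subModelSchema(input, "rawPrediction", "probability", "prediction"))
    println(conflict.isFailure)

    // Clearing two output names and using a unique temporary name for the
    // third leaves nothing that can collide with the input columns.
    val tmpRaw = "tmpRaw_" + java.util.UUID.randomUUID().toString
    println(subModelSchema(input, tmpRaw, "", ""))
  }
}
```

With the default names the call fails on the existing `probability` column. Once the prediction and probability names are cleared and the raw prediction is written under a unique temporary name, the same input passes.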

Does this PR introduce any user-facing change?

No

How was this patch tested?

Added a test case to OneVsRestSuite.

@@ -223,6 +223,13 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
assert(oldCols === newCols)
}

test("SPARK-34356: OneVsRestModel.transform should avoid potential column conflict") {
Contributor Author

@zhengruifeng zhengruifeng Feb 4, 2021

This test fails on master and (maybe) on all versions of OVR,
but I think fixing it in master may be enough.

zhengruifeng commented Feb 4, 2021

The failure reproduces in 3.0.1 and master:

scala> val df = spark.read.format("libsvm").load("/d0/Dev/Opensource/spark/data/mllib/sample_multiclass_classification_data.txt").withColumn("probability", lit(0.0))
21/02/04 18:06:36 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.
df: org.apache.spark.sql.DataFrame = [label: double, features: vector ... 1 more field]

scala> 

scala> val classifier = new LogisticRegression().setMaxIter(1).setTol(1E-6).setFitIntercept(true)
classifier: org.apache.spark.ml.classification.LogisticRegression = logreg_5900509aa825

scala> val ovr = new OneVsRest().setClassifier(classifier)
ovr: org.apache.spark.ml.classification.OneVsRest = oneVsRest_dd2b3e9da4e3

scala> val ovrm = ovr.fit(df)
ovrm: org.apache.spark.ml.classification.OneVsRestModel = OneVsRestModel: uid=oneVsRest_dd2b3e9da4e3, classifier=logreg_5900509aa825, numClasses=3, numFeatures=4

scala> ovrm.transform(df)
java.lang.IllegalArgumentException: requirement failed: Column probability already exists.
  at scala.Predef$.require(Predef.scala:281)
  at org.apache.spark.ml.util.SchemaUtils$.appendColumn(SchemaUtils.scala:106)
  at org.apache.spark.ml.util.SchemaUtils$.appendColumn(SchemaUtils.scala:96)
  at org.apache.spark.ml.classification.ProbabilisticClassifierParams.validateAndTransformSchema(ProbabilisticClassifier.scala:38)
  at org.apache.spark.ml.classification.ProbabilisticClassifierParams.validateAndTransformSchema$(ProbabilisticClassifier.scala:33)
  at org.apache.spark.ml.classification.LogisticRegressionModel.org$apache$spark$ml$classification$LogisticRegressionParams$$super$validateAndTransformSchema(LogisticRegression.scala:917)
  at org.apache.spark.ml.classification.LogisticRegressionParams.validateAndTransformSchema(LogisticRegression.scala:268)
  at org.apache.spark.ml.classification.LogisticRegressionParams.validateAndTransformSchema$(LogisticRegression.scala:255)
  at org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema(LogisticRegression.scala:917)
  at org.apache.spark.ml.PredictionModel.transformSchema(Predictor.scala:222)
  at org.apache.spark.ml.classification.ClassificationModel.transformSchema(Classifier.scala:182)
  at org.apache.spark.ml.classification.ProbabilisticClassificationModel.transformSchema(ProbabilisticClassifier.scala:88)
  at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:71)
  at org.apache.spark.ml.classification.ProbabilisticClassificationModel.transform(ProbabilisticClassifier.scala:107)
  at org.apache.spark.ml.classification.OneVsRestModel.$anonfun$transform$4(OneVsRest.scala:215)
  at scala.collection.IndexedSeqOptimized.foldLeft(IndexedSeqOptimized.scala:60)
  at scala.collection.IndexedSeqOptimized.foldLeft$(IndexedSeqOptimized.scala:68)
  at scala.collection.mutable.ArrayOps$ofRef.foldLeft(ArrayOps.scala:198)
  at org.apache.spark.ml.classification.OneVsRestModel.transform(OneVsRest.scala:203)
  ... 49 elided

scala> 

SparkQA commented Feb 4, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/39456/

SparkQA commented Feb 4, 2021

Kubernetes integration test status failure
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/39456/

SparkQA commented Feb 4, 2021

Test build #134869 has finished for PR 31472 at commit 8a47b6e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@github-actions github-actions bot added the ML label Feb 4, 2021
tmpModel.setPredictionCol("")
tmpModel match {
case m: ProbabilisticClassificationModel[_, _] => m.setProbabilityCol("")
case _ =>
Member

Should this case be silently ignored? If it's always a ProbabilisticClassificationModel, then just cast?

Contributor Author

ok

SparkQA commented Feb 5, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/39501/

SparkQA commented Feb 5, 2021

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/39501/

SparkQA commented Feb 5, 2021

Test build #134918 has finished for PR 31472 at commit 02725b0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen srowen closed this in 178dc50 Feb 6, 2021
@zhengruifeng zhengruifeng deleted the ovr_submodel_skip_pred_prob branch February 7, 2021 01:59
@zhengruifeng
Contributor Author

thanks @srowen for reviewing!
