Skip to content

Commit

Permalink
[SPARK-21306][ML] OneVsRest should support setWeightCol
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

add `setWeightCol` method for OneVsRest.

`weightCol` is ignored if classifier doesn't inherit HasWeightCol trait.

## How was this patch tested?

+ [x] add an unit test.

Author: Yan Facai (颜发才) <facai.yan@gmail.com>

Closes #18554 from facaiy/BUG/oneVsRest_missing_weightCol.

(cherry picked from commit a5a3189)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
  • Loading branch information
facaiy authored and yanboliang committed Jul 28, 2017
1 parent 9498798 commit 8520d7c
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 9 deletions.
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
Expand All @@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
private[ml] trait OneVsRestParams extends PredictorParams
with ClassifierTypeTrait with HasWeightCol {

/**
* param for the base binary classifier that we reduce multiclass classification into.
Expand Down Expand Up @@ -299,6 +301,18 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/**
* Sets the value of param [[weightCol]].
*
* This is ignored if weight is not supported by [[classifier]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("2.3.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
Expand All @@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") (
}
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)

val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
getClassifier match {
case _: HasWeightCol => true
case c =>
logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}

val multiclassLabeled = if (weightColIsUsed) {
dataset.select($(labelCol), $(featuresCol), $(weightCol))
} else {
dataset.select($(labelCol), $(featuresCol))
}

// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
Expand All @@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
if (weightColIsUsed) {
val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
paramMap.put(classifier_.weightCol -> getWeightCol)
classifier_.fit(trainingDataset, paramMap)
} else {
classifier.fit(trainingDataset, paramMap)
}
}.toArray[ClassificationModel[_, _]]

if (handlePersistence) {
Expand Down
Expand Up @@ -157,6 +157,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}

test("SPARK-21306: OneVsRest should support setWeightCol") {
val dataset2 = dataset.withColumn("weight", lit(1))

This comment has been minimized.

Copy link
@facaiy

facaiy Jul 28, 2017

Author Contributor

@yanboliang @srowen Sigh. The exception is resolved by lit(1.0) on branch-2.1.

// classifier inherits hasWeightCol
val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression())
assert(ova.fit(dataset2) !== null)
// classifier doesn't inherit hasWeightCol
val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier())
assert(ova2.fit(dataset2) !== null)
}

test("OneVsRest.copy and OneVsRestModel.copy") {
val lr = new LogisticRegression()
.setMaxIter(1)
Expand Down
27 changes: 21 additions & 6 deletions python/pyspark/ml/classification.py
Expand Up @@ -1331,7 +1331,7 @@ def weights(self):
return self._call_java("weights")


class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
"""
Parameters for OneVsRest and OneVsRestModel.
"""
Expand Down Expand Up @@ -1394,20 +1394,22 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
classifier=None):
classifier=None, weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
classifier=None)
classifier=None, weightCol=None)
"""
super(OneVsRest, self).__init__()
kwargs = self._input_kwargs
self._set(**kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
classifier=None, weightCol=None):
"""
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
classifier=None, weightCol=None):
Sets params for OneVsRest.
"""
kwargs = self._input_kwargs
Expand All @@ -1423,7 +1425,18 @@ def _fit(self, dataset):

numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

multiclassLabeled = dataset.select(labelCol, featuresCol)
weightCol = None
if (self.isDefined(self.weightCol) and self.getWeightCol()):
if isinstance(classifier, HasWeightCol):
weightCol = self.getWeightCol()
else:
warnings.warn("weightCol is ignored, "
"as it is not supported by {} now.".format(classifier))

if weightCol:
multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
else:
multiclassLabeled = dataset.select(labelCol, featuresCol)

# persist if underlying dataset is not persistent.
handlePersistence = \
Expand All @@ -1439,6 +1452,8 @@ def trainSingleClass(index):
paramMap = dict([(classifier.labelCol, binaryLabelCol),
(classifier.featuresCol, featuresCol),
(classifier.predictionCol, predictionCol)])
if weightCol:
paramMap[classifier.weightCol] = weightCol
return classifier.fit(trainingDataset, paramMap)

# TODO: Parallel training for all classes.
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/ml/tests.py
Expand Up @@ -1218,6 +1218,20 @@ def test_output_columns(self):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])

def test_support_for_weightCol(self):
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
(1.0, Vectors.sparse(2, [], []), 1.0),
(2.0, Vectors.dense(0.5, 0.5), 1.0)],
["label", "features", "weight"])
# classifier inherits hasWeightCol
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, weightCol="weight")
self.assertIsNotNone(ovr.fit(df))
# classifier doesn't inherit hasWeightCol
dt = DecisionTreeClassifier()
ovr2 = OneVsRest(classifier=dt, weightCol="weight")
self.assertIsNotNone(ovr2.fit(df))


class HashingTFTest(SparkSessionTestCase):

Expand Down

6 comments on commit 8520d7c

@srowen
Copy link
Member

@srowen srowen commented on 8520d7c Jul 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yanboliang @facaiy I think this makes branch 2.1 fail:

https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.1-test-sbt-hadoop-2.7/578/

[info] - SPARK-21306: OneVsRest should support setWeightCol *** FAILED *** (180 milliseconds)
[info]   org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 382.0 failed 1 times, most recent failure: Lost task 0.0 in stage 382.0 (TID 759, localhost, executor driver): scala.MatchError: [0.0,1,[6.7885086340489185,3.4576551565453197,2.0812768587303507,0.3552148531053505]] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema)
[info] 	at org.apache.spark.ml.classification.LogisticRegression$$anonfun$12.apply(LogisticRegression.scala:330)
[info] 	at org.apache.spark.ml.classification.LogisticRegression$$anonfun$12.apply(LogisticRegression.scala:330)
[info] 	at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
[info] 	at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:216)
...

Branch 2.0 seems to not compile?

[error] /home/jenkins/workspace/spark-branch-2.0-test-maven-hadoop-2.2/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala:147: not found: value lit
[error]     val dataset2 = dataset.withColumn("weight", lit(1))
[error]                                                 ^
[error] one error found
[error] Compile failed at Jul 27, 2017 7:34:14 PM [20.414s]

@facaiy
Copy link
Contributor Author

@facaiy facaiy commented on 8520d7c Jul 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check it this weekend, @srowen.

@yanboliang
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@srowen Thanks for catching this. I will check it as well, maybe we can't directly merge it into branch before 2.1.

@yanboliang
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@srowen We can't directly merge this PR into branch-2.1/2.0, I have reverted it from these branches.
@facaiy Would you like to prepare separate PRs for branch-2.1/2.0? Then we can merge them. Please let me know if you have any problem. Thanks for all.

@facaiy
Copy link
Contributor Author

@facaiy facaiy commented on 8520d7c Jul 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem. I'd like to work on it this weekend. However, the SparkException on 2.1 is quite amazing for me, so any help / suggestion will be appreciated.

@facaiy
Copy link
Contributor Author

@facaiy facaiy commented on 8520d7c Jul 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yanboliang @srowen

Two PRs are opened:
For branch-2.1, #18763: lit(1) -> lit(1.0)
For branch 2.0, #18764: lit(1.0) and import function._

Please sign in to comment.