Skip to content

Commit

Permalink
[SPARK-21306][ML] OneVsRest should support setWeightCol
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

add `setWeightCol` method for OneVsRest.

`weightCol` is ignored if classifier doesn't inherit HasWeightCol trait.

## How was this patch tested?

+ [x] add an unit test.

Author: Yan Facai (颜发才) <facai.yan@gmail.com>

Closes #18554 from facaiy/BUG/oneVsRest_missing_weightCol.
  • Loading branch information
facaiy authored and yanboliang committed Jul 28, 2017
1 parent f44ead8 commit a5a3189
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
Expand All @@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
private[ml] trait OneVsRestParams extends PredictorParams
with ClassifierTypeTrait with HasWeightCol {

/**
* param for the base binary classifier that we reduce multiclass classification into.
Expand Down Expand Up @@ -294,6 +296,18 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/**
* Sets the value of param [[weightCol]].
*
* This is ignored if weight is not supported by [[classifier]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("2.3.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
Expand All @@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") (
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
instr.logNumClasses(numClasses)

val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
getClassifier match {
case _: HasWeightCol => true
case c =>
logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}

val multiclassLabeled = if (weightColIsUsed) {
dataset.select($(labelCol), $(featuresCol), $(weightCol))
} else {
dataset.select($(labelCol), $(featuresCol))
}

// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
Expand All @@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
if (weightColIsUsed) {
val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
paramMap.put(classifier_.weightCol -> getWeightCol)
classifier_.fit(trainingDataset, paramMap)
} else {
classifier.fit(trainingDataset, paramMap)
}
}.toArray[ClassificationModel[_, _]]
instr.logNumFeatures(models.head.numFeatures)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}

test("SPARK-21306: OneVsRest should support setWeightCol") {
val dataset2 = dataset.withColumn("weight", lit(1))
// classifier inherits hasWeightCol
val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression())
assert(ova.fit(dataset2) !== null)
// classifier doesn't inherit hasWeightCol
val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier())
assert(ova2.fit(dataset2) !== null)
}

test("OneVsRest.copy and OneVsRestModel.copy") {
val lr = new LogisticRegression()
.setMaxIter(1)
Expand Down
27 changes: 21 additions & 6 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@ def weights(self):
return self._call_java("weights")


class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
"""
Parameters for OneVsRest and OneVsRestModel.
"""
Expand Down Expand Up @@ -1517,20 +1517,22 @@ class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable):

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
classifier=None):
classifier=None, weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
classifier=None)
classifier=None, weightCol=None)
"""
super(OneVsRest, self).__init__()
kwargs = self._input_kwargs
self._set(**kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
classifier=None, weightCol=None):
"""
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
classifier=None, weightCol=None):
Sets params for OneVsRest.
"""
kwargs = self._input_kwargs
Expand All @@ -1546,7 +1548,18 @@ def _fit(self, dataset):

numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

multiclassLabeled = dataset.select(labelCol, featuresCol)
weightCol = None
if (self.isDefined(self.weightCol) and self.getWeightCol()):
if isinstance(classifier, HasWeightCol):
weightCol = self.getWeightCol()
else:
warnings.warn("weightCol is ignored, "
"as it is not supported by {} now.".format(classifier))

if weightCol:
multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
else:
multiclassLabeled = dataset.select(labelCol, featuresCol)

# persist if underlying dataset is not persistent.
handlePersistence = \
Expand All @@ -1562,6 +1575,8 @@ def trainSingleClass(index):
paramMap = dict([(classifier.labelCol, binaryLabelCol),
(classifier.featuresCol, featuresCol),
(classifier.predictionCol, predictionCol)])
if weightCol:
paramMap[classifier.weightCol] = weightCol
return classifier.fit(trainingDataset, paramMap)

# TODO: Parallel training for all classes.
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,20 @@ def test_output_columns(self):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])

def test_support_for_weightCol(self):
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
(1.0, Vectors.sparse(2, [], []), 1.0),
(2.0, Vectors.dense(0.5, 0.5), 1.0)],
["label", "features", "weight"])
# classifier inherits hasWeightCol
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, weightCol="weight")
self.assertIsNotNone(ovr.fit(df))
# classifier doesn't inherit hasWeightCol
dt = DecisionTreeClassifier()
ovr2 = OneVsRest(classifier=dt, weightCol="weight")
self.assertIsNotNone(ovr2.fit(df))


class HashingTFTest(SparkSessionTestCase):

Expand Down

0 comments on commit a5a3189

Please sign in to comment.