
[SPARK-10026] [ML] [PySpark] Implement some common Params for regression in PySpark #8508

Closed · wants to merge 4 commits

Conversation

yanboliang (Contributor)

LinearRegression and LogisticRegression lack some Params on the Python side, and some Params are not shared classes, which means we have to rewrite them for each class. These Params are listed here:

HasElasticNetParam 
HasFitIntercept
HasStandardization
HasThresholds

Here we implement them as shared params on the Python side and make the LinearRegression/LogisticRegression parameters consistent with the Scala ones.
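For readers unfamiliar with the shared-Param pattern this PR extends, here is a minimal self-contained sketch (not the actual pyspark.ml source; the tiny `Param`/`Params` classes below are stand-ins for the real ones) of the idea: each mixin owns exactly one `Param` plus its getter/setter, and an estimator picks up every parameter it supports simply by inheriting the mixins.

```python
class Param:
    """A named parameter with a description, owned by a Params instance."""
    def __init__(self, parent, name, doc):
        self.parent = parent
        self.name = name
        self.doc = doc


class Params:
    """Base class holding the param -> value map."""
    def __init__(self):
        self._paramMap = {}


class HasFitIntercept(Params):
    """Mixin providing the shared 'fitIntercept' Param."""
    def __init__(self):
        super().__init__()
        self.fitIntercept = Param(self, "fitIntercept",
                                  "whether to fit an intercept term")
        self._paramMap[self.fitIntercept] = True  # default

    def setFitIntercept(self, value):
        self._paramMap[self.fitIntercept] = value
        return self  # return self to allow chained setters

    def getFitIntercept(self):
        return self._paramMap[self.fitIntercept]


class HasElasticNetParam(Params):
    """Mixin providing the shared 'elasticNetParam' Param (0 = L2, 1 = L1)."""
    def __init__(self):
        super().__init__()
        self.elasticNetParam = Param(self, "elasticNetParam",
                                     "the ElasticNet mixing parameter, in [0, 1]")
        self._paramMap[self.elasticNetParam] = 0.0

    def getElasticNetParam(self):
        return self._paramMap[self.elasticNetParam]


class LinearRegression(HasFitIntercept, HasElasticNetParam):
    """An estimator gains shared Params by inheriting the mixins."""
    pass


lr = LinearRegression().setFitIntercept(False)
print(lr.getFitIntercept(), lr.getElasticNetParam())  # False 0.0
```

Once a Param lives in a shared mixin like this, LinearRegression and LogisticRegression no longer need their own copies of the definition, which is the duplication the PR description complains about.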

@SparkQA commented Aug 28, 2015

Test build #41744 has finished for PR 8508 at commit 730b0a7.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
    • class HasElasticNetParam(Params):
    • class HasFitIntercept(Params):
    • class HasStandardization(Params):
    • class HasThresholds(Params):
    • thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
    • self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")

@SparkQA commented Aug 28, 2015

Test build #41745 has finished for PR 8508 at commit d44ac06.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
    • class HasElasticNetParam(Params):
    • class HasFitIntercept(Params):
    • class HasStandardization(Params):
    • class HasThresholds(Params):
    • thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
    • self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")

" to adjust the probability of predicting each class." +
" Array must have length equal to the number of classes, with values >= 0." +
" The class with largest value p/t is predicted, where p is the original" +
" probability of that class and t is the class' threshold.")
threshold = Param(Params._dummy(), "threshold",
Contributor
Perhaps we should also extract a HasThreshold mixin for binary classifier thresholds

Contributor Author

threshold is a deprecated parameter; it has been replaced by thresholds. LogisticRegression still keeps threshold only for backward compatibility, so I think we don't need to extract HasThreshold as a shared Param. @jkbradley

Contributor

My understanding is that the HasThresholds trait mixed into ml.LogisticRegression is actually an artifact of a transitive dependency through ProbabilisticClassifier. We don't actually support multi-class classification in ml.LogisticRegression at the moment, and quite a bit of work went into making the API less confusing.

After multi-class is supported I think it makes sense to use HasThresholds, but for the time being I would prefer we only use HasThreshold in the Python API.
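The two parameters are closely related in the binary case. A quick self-contained check (hypothetical helper names, not pyspark.ml code) that the multi-class `thresholds` rule quoted in the Param doc above, "predict the class with the largest p/t", reduces to the binary `threshold` rule "predict class 1 iff p1 > t" when thresholds is set to [1 - t, t]:

```python
def predict_with_thresholds(probs, thresholds):
    """Multi-class rule: predict the class maximizing p/t."""
    return max(range(len(probs)), key=lambda i: probs[i] / thresholds[i])


def predict_with_threshold(p1, threshold):
    """Binary rule: predict class 1 iff its probability exceeds the threshold."""
    return 1 if p1 > threshold else 0


t = 0.3
for p1 in [0.1, 0.25, 0.3001, 0.5, 0.9]:
    multi = predict_with_thresholds([1 - p1, p1], [1 - t, t])
    binary = predict_with_threshold(p1, t)
    assert multi == binary, (p1, multi, binary)
print("thresholds=[1-t, t] matches the binary threshold rule")
```

The algebra behind the assertion: with thresholds [1-t, t], class 1 wins when p1/t > (1-p1)/(1-t), which rearranges to p1 > t.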

@feynmanliang (Contributor)

LGTM 👍, some minor formatting comments and a suggestion.

"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, \
probabilityCol="probability", rawPredictionCol="rawPrediction")
threshold=0.5, thresholds=None, probabilityCol="probability",
Contributor

add \ at the end of the line, matching the other continuation lines
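A minimal illustration (not code from the PR) of why the trailing backslash matters here: in a Python string or docstring, a \ at the end of a line removes the newline, so a wrapped setParams(...) example renders as one logical line in the generated docs; without it, a literal newline stays in the text.

```python
# Line continuation inside a string literal: the backslash swallows the newline.
with_backslash = "setParams(threshold=0.5, thresholds=None, \
probabilityCol='probability')"

# Without the backslash, the newline survives inside the string.
without_backslash = """setParams(threshold=0.5, thresholds=None,
probabilityCol='probability')"""

print("\n" in with_backslash)     # False
print("\n" in without_backslash)  # True
```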

@SparkQA commented Sep 11, 2015

Test build #42319 has finished for PR 8508 at commit 962692b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
    • case class ExecutorLostFailure(execId: String, isNormalExit: Boolean = false)
    • class CoGroupedRDD[K: ClassTag](
    • class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
    • class ExecutorLossReason(val message: String) extends Serializable
    • case class ExecutorExited(exitCode: Int, isNormalExit: Boolean, reason: String)
    • case class RemoveExecutor(executorId: String, reason: ExecutorLossReason)
    • case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage
    • class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid):
    • ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
    • class HasHandleInvalid(Params):
    • class HasElasticNetParam(Params):
    • class HasFitIntercept(Params):
    • class HasStandardization(Params):
    • class HasThresholds(Params):
    • thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
    • self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.")
    • case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf)
    • case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf)
    • case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode)
    • case class HashJoinNode(
    • case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf)
    • abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
    • abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf)
    • abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf)
    • abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf)
    • case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode)
    • case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow])
    • case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf)

@mengxr (Contributor) commented Sep 11, 2015

Merged into master. Thanks!

@asfgit asfgit closed this in b656e61 Sep 11, 2015
@yanboliang yanboliang deleted the spark-10026 branch May 5, 2016 07:34