From fa611d9f0634c1029e2dce93a16a1faed0651333 Mon Sep 17 00:00:00 2001
From: vectorijk
Date: Thu, 24 Sep 2015 01:59:09 -0700
Subject: [PATCH 1/5] [SPARK-10688][ML][PySpark] Python API for AFTSurvivalRegression

---
 python/pyspark/ml/regression.py | 138 +++++++++++++++++++++++++++++++-
 1 file changed, 137 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 21d454f9003bb..e7079f744f77d 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -24,7 +24,8 @@
 
 __all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor',
            'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel',
-           'RandomForestRegressor', 'RandomForestRegressionModel']
+           'RandomForestRegressor', 'RandomForestRegressionModel',
+           'AFTSurvivalRegression', 'AFTSurvivalRegressionModel']
 
 
 @inherit_doc
@@ -608,6 +609,141 @@ class GBTRegressionModel(TreeEnsembleModels):
     .. versionadded:: 1.4.0
     """
 
+@inherit_doc
+class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
+                            HasFitIntercept, HasMaxIter, HasTol):
+    """
+    `https://en.wikipedia.org/wiki/Accelerated_failure_time_model`
+    Fit a parametric survival regression model named accelerated failure time (AFT) model
+    based on the Weibull distribution of the survival time.
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame([
+    ...     (1.0, Vectors.dense(1.0), 1.0),
+    ...     (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
+    >>> aftsr = AFTSurvivalRegression()
+    >>> model = aftsr.fit(df)
+    >>> model.transform(df).show()
+    +-----+---------+------+----------+
+    |label| features|censor|prediction|
+    +-----+---------+------+----------+
+    |  1.0|    [1.0]|   1.0|       1.0|
+    |  0.0|(1,[],[])|   0.0|       1.0|
+    +-----+---------+------+----------+
+    ...
+
+    .. versionadded:: 1.6.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    censorCol = Param(Params._dummy(), "censorCol",
+                      "censor column name")
+    quantileProbabilities = \
+        Param(Params._dummy(), "quantileProbabilities",
+              "quantile probabilities array" +
+              ", array is not empty and every probability is in range [0,1]")
+    quantilesCol = Param(Params._dummy(), "quantilesCol",
+                         "quantiles column name")
+
+    @keyword_only
+    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
+                 quantilesCol=None):
+        """
+        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
+                 quantilesCol=None):
+        """
+        super(AFTSurvivalRegression, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
+        self.censorCol = \
+            Param(self, "censorCol",
+                  "censor column name")
+        self.quantileProbabilities = \
+            Param(self, "quantileProbabilities",
+                  "quantile probabilities array" +
+                  ", array is not empty and every probability is in range [0,1]")
+        self.quantilesCol = Param(self, "quantilesCol",
+                                  "quantiles column name")
+        self._setDefault(censorCol="censor",
+                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("1.6.0")
+    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
+                  quantilesCol=None):
+        """
+        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
+                  quantilesCol=None):
+        """
+        kwargs = self.__init__._input_kwargs
+        return self._set(**kwargs)
+
+    def _create_model(self, java_model):
+        return AFTSurvivalRegressionModel(java_model)
+
+    @since("1.6.0")
+    def setCensorCol(self, value):
+        """
+        Sets the value of :py:attr:`censorCol`.
+        """
+        self._paramMap[self.censorCol] = value
+        return self
+
+    @since("1.6.0")
+    def getCensorCol(self):
+        """
+        Gets the value of censorCol or its default value.
+        """
+        return self.getOrDefault(self.censorCol)
+
+    @since("1.6.0")
+    def setQuantileProbabilities(self, value):
+        """
+        Sets the value of :py:attr:`quantileProbabilities`.
+        """
+        self._paramMap[self.quantileProbabilities] = value
+        return self
+
+    @since("1.6.0")
+    def getQuantileProbabilities(self):
+        """
+        Gets the value of quantileProbabilities or its default value.
+        """
+        return self.getOrDefault(self.quantileProbabilities)
+
+    @since("1.6.0")
+    def setQuantilesCol(self, value):
+        """
+        Sets the value of :py:attr:`quantilesCol`.
+        """
+        self._paramMap[self.quantilesCol] = value
+        return self
+
+    @since("1.6.0")
+    def getQuantilesCol(self):
+        """
+        Gets the value of quantilesCol or its default value.
+        """
+        return self.getOrDefault(self.quantilesCol)
+
+
+
+class AFTSurvivalRegressionModel(JavaModel):
+    """
+    Model fitted by AFTSurvivalRegression.
+
+    .. versionadded:: 1.6.0
+    """
 
 if __name__ == "__main__":
     import doctest

From b07efa33bf50ac8eb63d356c7e976e3a46dff40d Mon Sep 17 00:00:00 2001
From: vectorijk
Date: Tue, 29 Sep 2015 00:36:35 -0700
Subject: [PATCH 2/5] clean up

---
 python/pyspark/ml/regression.py | 69 ++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 31 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index e7079f744f77d..60b75e9afec3c 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -22,10 +22,11 @@
 from pyspark.mllib.common import inherit_doc
 
 
-__all__ = ['DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor',
-           'GBTRegressionModel', 'LinearRegression', 'LinearRegressionModel',
-           'RandomForestRegressor', 'RandomForestRegressionModel',
-           'AFTSurvivalRegression', 'AFTSurvivalRegressionModel']
+__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
+           'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
+           'GBTRegressor', 'GBTRegressionModel',
+           'LinearRegression', 'LinearRegressionModel',
+           'RandomForestRegressor', 'RandomForestRegressionModel']
 
 
 @inherit_doc
@@ -609,13 +610,16 @@ class GBTRegressionModel(TreeEnsembleModels):
     .. versionadded:: 1.4.0
     """
 
+
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                             HasFitIntercept, HasMaxIter, HasTol):
     """
-    `https://en.wikipedia.org/wiki/Accelerated_failure_time_model`
-    Fit a parametric survival regression model named accelerated failure time (AFT) model
-    based on the Weibull distribution of the survival time.
+    Accelerated Failure Time(AFT) Model Survival Regression
+
+    Fit a parametric AFT survival regression model based on the Weibull distribution
+    of the survival time.
+    see also, `https://en.wikipedia.org/wiki/Accelerated_failure_time_model`
 
     >>> from pyspark.mllib.linalg import Vectors
     >>> df = sqlContext.createDataFrame([
@@ -637,37 +641,43 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
 
     # a placeholder to make it appear in the generated doc
     censorCol = Param(Params._dummy(), "censorCol",
-                      "censor column name")
+                      "censor column name. The value of this column could be 0 or 1. " +
+                      "If the value is 1, it means the event has occurred i.e. " +
+                      "uncensored; otherwise censored.")
     quantileProbabilities = \
         Param(Params._dummy(), "quantileProbabilities",
-              "quantile probabilities array" +
-              ", array is not empty and every probability is in range [0,1]")
+              "quantile probabilities array. Values of the quantile probabilities array " +
+              "should be in the range [0, 1] and the array should be non-empty.")
     quantilesCol = Param(Params._dummy(), "quantilesCol",
-                         "quantiles column name")
+                         "quantiles column name. This column will output quantiles of " +
+                         "corresponding quantileProbabilities if it is set.")
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
-                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
-                 quantilesCol=None):
+                 quantileProbabilities=None, quantilesCol=None):
         """
-        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
-                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
-                 quantilesCol=None):
+        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
+                 quantileProbabilities=None, quantilesCol=None):
         """
         super(AFTSurvivalRegression, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
-        self.censorCol = \
-            Param(self, "censorCol",
-                  "censor column name")
+        #: Param for censor column name
+        self.censorCol = Param(self, "censorCol",
+                               "censor column name. The value of this column could be 0 or 1. " +
+                               "If the value is 1, it means the event has occurred i.e. " +
+                               "uncensored; otherwise censored.")
+        #: Param for quantile probabilities array
         self.quantileProbabilities = \
             Param(self, "quantileProbabilities",
-                  "quantile probabilities array" +
-                  ", array is not empty and every probability is in range [0,1]")
+                  "quantile probabilities array. Values of the quantile probabilities array " +
+                  "should be in the range [0, 1] and the array should be non-empty.")
+        #: Param for quantiles column name
         self.quantilesCol = Param(self, "quantilesCol",
-                                  "quantiles column name")
+                                  "quantiles column name. This column will output quantiles of " +
+                                  "corresponding quantileProbabilities if it is set.")
         self._setDefault(censorCol="censor",
                          quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
         kwargs = self.__init__._input_kwargs
@@ -676,14 +686,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
     @keyword_only
     @since("1.6.0")
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
-                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
-                  quantilesCol=None):
+                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
+                  quantileProbabilities=None, quantilesCol=None):
         """
-        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
-                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
-                  quantilesCol=None):
+        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
+                  quantileProbabilities=None, quantilesCol=None):
         """
         kwargs = self.__init__._input_kwargs
         return self._set(**kwargs)
@@ -737,7 +745,6 @@ def getQuantilesCol(self):
         return self.getOrDefault(self.quantilesCol)
 
 
-
 class AFTSurvivalRegressionModel(JavaModel):
     """
     Model fitted by AFTSurvivalRegression.
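
The snippet below is a quick usage sketch of the estimator as it stands after PATCH 2; it is not part of the patch itself. It assumes the same doctest globals used above (a running SparkContext `sc` and a SQLContext `sqlContext`); the toy data and the output column name "quantiles" are illustrative only.

    from pyspark.mllib.linalg import Vectors
    from pyspark.ml.regression import AFTSurvivalRegression

    # rows of (survival time, features, censor indicator); censor == 1.0 means the
    # event was observed (uncensored), censor == 0.0 means the row is censored
    df = sqlContext.createDataFrame([
        (1.0, Vectors.dense(1.0), 1.0),
        (2.0, Vectors.dense(2.0), 0.0),
        (3.0, Vectors.dense(3.0), 1.0)], ["label", "features", "censor"])

    aft = AFTSurvivalRegression(censorCol="censor", quantilesCol="quantiles") \
        .setQuantileProbabilities([0.25, 0.5, 0.75])
    model = aft.fit(df)

    # "prediction" holds the expected survival time; because quantilesCol is set,
    # "quantiles" holds a vector with the 25th/50th/75th percentile predictions
    model.transform(df).show(truncate=False)
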
From f76ec8ae3ede3efa7446b22b307a09087f4e4012 Mon Sep 17 00:00:00 2001
From: vectorijk
Date: Wed, 30 Sep 2015 00:08:05 -0700
Subject: [PATCH 3/5] typos

---
 python/pyspark/ml/regression.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 60b75e9afec3c..9af62529d556d 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -615,11 +615,12 @@ class GBTRegressionModel(TreeEnsembleModels):
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                             HasFitIntercept, HasMaxIter, HasTol):
     """
-    Accelerated Failure Time(AFT) Model Survival Regression
+    Accelerated Failure Time (AFT) Model Survival Regression
 
     Fit a parametric AFT survival regression model based on the Weibull distribution
     of the survival time.
-    see also, `https://en.wikipedia.org/wiki/Accelerated_failure_time_model`
+
+    .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_
 
     >>> from pyspark.mllib.linalg import Vectors
     >>> df = sqlContext.createDataFrame([
@@ -659,7 +660,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
-                 quantileProbabilities=None, quantilesCol=None):
+                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
+                 quantilesCol=None):
         """
         super(AFTSurvivalRegression, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -691,7 +693,8 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
-                  quantileProbabilities=None, quantilesCol=None):
+                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
+                  quantilesCol=None):
         """
         kwargs = self.__init__._input_kwargs
         return self._set(**kwargs)

From 25ac136a80bbd37c0fbc01392831ffadb4803309 Mon Sep 17 00:00:00 2001
From: vectorijk
Date: Wed, 30 Sep 2015 15:47:15 -0700
Subject: [PATCH 4/5] implement AFTSurvivalRegressionModel methods

- set quantileProbabilites by array in setParams
- two methods in AFTSurvivalRegressionModel
---
 python/pyspark/ml/regression.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 9af62529d556d..83ab72da17059 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -628,6 +628,10 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
     ...     (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
     >>> aftsr = AFTSurvivalRegression()
     >>> model = aftsr.fit(df)
+    >>> model.predict(Vectors.dense(6.3))
+    1.0
+    >>> model.predictQuantiles(Vectors.dense(6.3))
+    DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
     >>> model.transform(df).show()
     +-----+---------+------+----------+
     |label| features|censor|prediction|
@@ -696,8 +700,12 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                   quantilesCol=None):
         """
-        kwargs = self.__init__._input_kwargs
-        return self._set(**kwargs)
+        kwargs = self.setParams._input_kwargs
+        if quantileProbabilities is None:
+            return self._set(**kwargs).setQuantileProbabilities([0.01, 0.05, 0.1, 0.25, 0.5,
+                                                                  0.75, 0.9, 0.95, 0.99])
+        else:
+            return self._set(**kwargs)
 
     def _create_model(self, java_model):
         return AFTSurvivalRegressionModel(java_model)
@@ -755,6 +763,19 @@ class AFTSurvivalRegressionModel(JavaModel):
     .. versionadded:: 1.6.0
     """
 
+    def predictQuantiles(self, features):
+        """
+        Predicted Quantiles
+        """
+        return self._call_java("predictQuantiles", features)
+
+    def predict(self, features):
+        """
+        Predicted value
+        """
+        return self._call_java("predict", features)
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext

From 4f1e8a9e008e251f1da50fccd313db72f3b87c18 Mon Sep 17 00:00:00 2001
From: vectorijk
Date: Wed, 30 Sep 2015 23:18:17 -0700
Subject: [PATCH 5/5] change docs in quantileProbabilities

---
 python/pyspark/ml/regression.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 83ab72da17059..a0f7f54e65213 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -652,7 +652,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
     quantileProbabilities = \
         Param(Params._dummy(), "quantileProbabilities",
               "quantile probabilities array. Values of the quantile probabilities array " +
-              "should be in the range [0, 1] and the array should be non-empty.")
+              "should be in the range (0, 1) and the array should be non-empty.")
     quantilesCol = Param(Params._dummy(), "quantilesCol",
                          "quantiles column name. This column will output quantiles of " +
                          "corresponding quantileProbabilities if it is set.")
@@ -679,7 +679,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
         self.quantileProbabilities = \
             Param(self, "quantileProbabilities",
                   "quantile probabilities array. Values of the quantile probabilities array " +
-                  "should be in the range [0, 1] and the array should be non-empty.")
+                  "should be in the range (0, 1) and the array should be non-empty.")
         #: Param for quantiles column name
         self.quantilesCol = Param(self, "quantilesCol",
                                   "quantiles column name. This column will output quantiles of " +
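
As a follow-up to PATCH 4, the two new model-level methods can be exercised directly on a fitted model. This is only a sketch under the same assumptions as the doctests (a `sqlContext` is available); the printed numbers depend on the fitted coefficients and are not asserted here.

    from pyspark.mllib.linalg import Vectors
    from pyspark.ml.regression import AFTSurvivalRegression

    # the same two-row toy dataset used in the class doctest
    df = sqlContext.createDataFrame([
        (1.0, Vectors.dense(1.0), 1.0),
        (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
    model = AFTSurvivalRegression().fit(df)

    # expected survival time for a single feature vector
    print(model.predict(Vectors.dense(6.3)))
    # one predicted quantile per entry of quantileProbabilities, returned as a DenseVector
    print(model.predictQuantiles(Vectors.dense(6.3)))
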