Permalink
Cannot retrieve contributors at this time
Fetching contributors…
| /* | |
| * Licensed to the Apache Software Foundation (ASF) under one or more | |
| * contributor license agreements. See the NOTICE file distributed with | |
| * this work for additional information regarding copyright ownership. | |
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |
| * (the "License"); you may not use this file except in compliance with | |
| * the License. You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| package org.apache.spark.ml.regression | |
| import scala.collection.mutable | |
| import breeze.linalg.{DenseVector => BDV} | |
| import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} | |
| import org.apache.hadoop.fs.Path | |
| import org.apache.spark.SparkException | |
| import org.apache.spark.annotation.{Experimental, Since} | |
| import org.apache.spark.broadcast.Broadcast | |
| import org.apache.spark.internal.Logging | |
| import org.apache.spark.ml.{Estimator, Model} | |
| import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} | |
| import org.apache.spark.ml.param._ | |
| import org.apache.spark.ml.param.shared._ | |
| import org.apache.spark.ml.util._ | |
| import org.apache.spark.mllib.linalg.VectorImplicits._ | |
| import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer | |
| import org.apache.spark.mllib.util.MLUtils | |
| import org.apache.spark.rdd.RDD | |
| import org.apache.spark.sql.{DataFrame, Dataset, Row} | |
| import org.apache.spark.sql.functions._ | |
| import org.apache.spark.sql.types.{DoubleType, StructType} | |
| import org.apache.spark.storage.StorageLevel | |
/**
 * Params for accelerated failure time (AFT) regression.
 *
 * Shared by the estimator and the fitted model: declares the censor, quantile-probabilities and
 * quantiles-column params and the schema validation used by both.
 */
private[regression] trait AFTSurvivalRegressionParams extends Params
  with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
  with HasTol with HasFitIntercept with HasAggregationDepth with Logging {

  /**
   * Param for censor column name.
   * The value of this column could be 0 or 1.
   * If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored.
   * @group param
   */
  @Since("1.6.0")
  final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name")

  /** @group getParam */
  @Since("1.6.0")
  def getCensorCol: String = $(censorCol)
  setDefault(censorCol -> "censor")

  /**
   * Param for quantile probabilities array.
   * Values of the quantile probabilities array should be in the range (0, 1)
   * and the array should be non-empty.
   * @group param
   */
  @Since("1.6.0")
  final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this,
    "quantileProbabilities", "quantile probabilities array",
    (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1, false, false)) && t.length > 0)

  /** @group getParam */
  @Since("1.6.0")
  def getQuantileProbabilities: Array[Double] = $(quantileProbabilities)
  setDefault(quantileProbabilities -> Array(0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99))

  /**
   * Param for quantiles column name.
   * This column will output quantiles of corresponding quantileProbabilities if it is set.
   * @group param
   */
  @Since("1.6.0")
  final val quantilesCol: Param[String] = new Param(this, "quantilesCol", "quantiles column name")

  /** @group getParam */
  @Since("1.6.0")
  def getQuantilesCol: String = $(quantilesCol)

  /** Checks whether the quantiles column name is defined and non-empty. */
  private[regression] def hasQuantilesCol: Boolean = {
    isDefined(quantilesCol) && $(quantilesCol).nonEmpty
  }

  /**
   * Validates and transforms the input schema with the provided param map.
   *
   * The features column must be a vector column. When fitting, the censor column must be of
   * `DoubleType` and the label column must be numeric. The returned schema is the input schema
   * with the prediction column appended, followed by the quantiles column when
   * [[hasQuantilesCol]] is true.
   *
   * @param schema input schema
   * @param fitting whether this is in fitting or prediction
   * @return output schema
   */
  protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean): StructType = {
    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    if (fitting) {
      SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
      SchemaUtils.checkNumericType(schema, $(labelCol))
    }
    // Thread the schema through both appends. Previously the schema carrying the quantiles
    // column was computed and then discarded, so the declared output schema never listed it.
    // The order (prediction first, then quantiles) mirrors the order in which the model's
    // transform adds the columns.
    val schemaWithPrediction = SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
    if (hasQuantilesCol) {
      SchemaUtils.appendColumn(schemaWithPrediction, $(quantilesCol), new VectorUDT)
    } else {
      schemaWithPrediction
    }
  }
}
/**
 * :: Experimental ::
 * Fit a parametric survival regression model named accelerated failure time (AFT) model
 * (see <a href="https://en.wikipedia.org/wiki/Accelerated_failure_time_model">
 * Accelerated failure time model (Wikipedia)</a>)
 * based on the Weibull distribution of the survival time.
 */
@Experimental
@Since("1.6.0")
class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String)
  extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
  with DefaultParamsWritable with Logging {

  @Since("1.6.0")
  def this() = this(Identifiable.randomUID("aftSurvReg"))

  /** @group setParam */
  @Since("1.6.0")
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  @Since("1.6.0")
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /** @group setParam */
  @Since("1.6.0")
  def setCensorCol(value: String): this.type = set(censorCol, value)

  /** @group setParam */
  @Since("1.6.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  @Since("1.6.0")
  def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)

  /** @group setParam */
  @Since("1.6.0")
  def setQuantilesCol(value: String): this.type = set(quantilesCol, value)

  /**
   * Set if we should fit the intercept
   * Default is true.
   * @group setParam
   */
  @Since("1.6.0")
  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
  setDefault(fitIntercept -> true)

  /**
   * Set the maximum number of iterations.
   * Default is 100.
   * @group setParam
   */
  @Since("1.6.0")
  def setMaxIter(value: Int): this.type = set(maxIter, value)
  setDefault(maxIter -> 100)

  /**
   * Set the convergence tolerance of iterations.
   * Smaller value will lead to higher accuracy with the cost of more iterations.
   * Default is 1E-6.
   * @group setParam
   */
  @Since("1.6.0")
  def setTol(value: Double): this.type = set(tol, value)
  setDefault(tol -> 1E-6)

  /**
   * Suggested depth for treeAggregate (greater than or equal to 2).
   * If the dimensions of features or the number of partitions are large,
   * this param could be adjusted to a larger size.
   * Default is 2.
   * @group expertSetParam
   */
  @Since("2.1.0")
  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
  setDefault(aggregationDepth -> 2)

  /**
   * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
   * and put it in an RDD with strong types.
   * The label column is cast to `DoubleType`; the censor column is read as-is.
   */
  protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
    dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
      .rdd.map {
        case Row(features: Vector, label: Double, censor: Double) =>
          AFTPoint(features, label, censor)
      }
  }

  /**
   * Fits an AFT model to `dataset`.
   *
   * Steps: validate the schema, extract the training points, compute per-feature standard
   * deviations, run L-BFGS on the standardized problem, then scale the coefficients back to
   * the original feature space. Columns with zero standard deviation get a zero coefficient.
   *
   * @param dataset training data containing the features, label and censor columns
   * @return the fitted model, with this estimator as parent and its params copied over
   */
  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
    transformSchema(dataset.schema, logging = true)
    val instances = extractAFTPoints(dataset)
    // Ask the Dataset itself whether it is cached. `dataset.rdd` yields a derived RDD whose
    // storage level does not reflect the Dataset's cache state, so checking it would cache
    // already-cached input a second time.
    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

    // Release the cached instances even when summarizing or optimizing fails.
    val (featuresStd, parameters) = try {
      val featuresSummarizer = {
        val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
        val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
          c1.merge(c2)
        }
        instances.treeAggregate(
          new MultivariateOnlineSummarizer
        )(seqOp, combOp, $(aggregationDepth))
      }

      val std = featuresSummarizer.variance.toArray.map(math.sqrt)

      if (!$(fitIntercept) && std.indices.exists { i =>
          std(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) {
        logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " +
          "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " +
          "columns. This behavior is different from R survival::survreg.")
      }

      (std, optimize(instances, std))
    } finally {
      if (handlePersistence) instances.unpersist()
    }

    // Map the coefficients learned on standardized features back to the original scale.
    val numFeatures = featuresStd.length
    val rawCoefficients = parameters.slice(2, parameters.length)
    var i = 0
    while (i < numFeatures) {
      rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
      i += 1
    }
    val coefficients = Vectors.dense(rawCoefficients)
    val intercept = parameters(1)
    val scale = math.exp(parameters(0))
    val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
    copyValues(model.setParent(this))
  }

  /**
   * Runs L-BFGS over the AFT cost function and returns the final parameter vector.
   *
   * The parameters vector has three parts:
   *  - the first element: log(sigma), the log of the scale parameter
   *  - the second element: the intercept of the beta parameter
   *  - the third to the last elements: the regression coefficients of the beta parameter
   *
   * @param instances training points
   * @param featuresStd per-feature standard deviations, broadcast for the cost function
   * @return a copy of the parameter vector of the last optimizer state
   */
  private def optimize(instances: RDD[AFTPoint], featuresStd: Array[Double]): Array[Double] = {
    val bcFeaturesStd = instances.context.broadcast(featuresStd)
    // Destroy the broadcast even when the optimizer throws.
    try {
      val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd, $(aggregationDepth))
      val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))

      val initialParameters = Vectors.zeros(featuresStd.length + 2)

      val states = optimizer.iterations(new CachedDiffFunction(costFun),
        initialParameters.asBreeze.toDenseVector)

      // Per-iteration objective values; collected here but not read within this method.
      val arrayBuilder = mutable.ArrayBuilder.make[Double]
      var state: optimizer.State = null
      while (states.hasNext) {
        state = states.next()
        arrayBuilder += state.adjustedValue
      }
      if (state == null) {
        val msg = s"${optimizer.getClass.getName} failed."
        throw new SparkException(msg)
      }

      state.x.toArray.clone()
    } finally {
      bcFeaturesStd.destroy(blocking = false)
    }
  }

  /** Validates the schema in fitting mode (censor and label columns are checked as well). */
  @Since("1.6.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = true)
  }

  @Since("1.6.0")
  override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra)
}
@Since("1.6.0")
object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] {

  /**
   * Loads an [[AFTSurvivalRegression]] estimator from `path` by delegating to the
   * inherited `load`.
   */
  @Since("1.6.0")
  override def load(path: String): AFTSurvivalRegression = {
    super.load(path)
  }
}
/**
 * :: Experimental ::
 * Model produced by [[AFTSurvivalRegression]].
 *
 * Holds the fitted regression coefficients, intercept and scale, and uses them to predict
 * survival time and, optionally, quantiles of the survival time.
 */
@Experimental
@Since("1.6.0")
class AFTSurvivalRegressionModel private[ml] (
    @Since("1.6.0") override val uid: String,
    @Since("2.0.0") val coefficients: Vector,
    @Since("1.6.0") val intercept: Double,
    @Since("1.6.0") val scale: Double)
  extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable {

  /** @group setParam */
  @Since("1.6.0")
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  @Since("1.6.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  @Since("1.6.0")
  def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)

  /** @group setParam */
  @Since("1.6.0")
  def setQuantilesCol(value: String): this.type = set(quantilesCol, value)

  /**
   * Computes one quantile of the lifetime per entry of `quantileProbabilities`.
   *
   * @param features feature vector of a single instance
   * @return dense vector of quantiles, in the order of `quantileProbabilities`
   */
  @Since("2.0.0")
  def predictQuantiles(features: Vector): Vector = {
    // scale parameter for the Weibull distribution of lifetime
    val weibullScale = math.exp(BLAS.dot(coefficients, features) + intercept)
    // shape parameter for the Weibull distribution of lifetime
    val weibullShape = 1 / scale
    val probabilities = $(quantileProbabilities)
    Vectors.dense(probabilities.map { q =>
      weibullScale * math.exp(math.log(-math.log(1 - q)) / weibullShape)
    })
  }

  /**
   * Predicts the survival time for one instance as exp(coefficients . features + intercept).
   *
   * @param features feature vector of a single instance
   */
  @Since("2.0.0")
  def predict(features: Vector): Double = {
    val linearPredictor = BLAS.dot(coefficients, features) + intercept
    math.exp(linearPredictor)
  }

  /**
   * Appends the prediction column and, when a quantiles column name is set, the quantiles
   * column to `dataset`.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val featuresColumn = col($(featuresCol))
    val predictUDF = udf { features: Vector => predict(features) }
    val withPrediction = dataset.withColumn($(predictionCol), predictUDF(featuresColumn))
    if (!hasQuantilesCol) {
      withPrediction
    } else {
      val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features) }
      withPrediction.withColumn($(quantilesCol), predictQuantilesUDF(featuresColumn))
    }
  }

  /** Validates the schema in prediction mode (only the features column is checked). */
  @Since("1.6.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = false)
  }

  /** Creates a copy with the same uid, model values and parent, applying `extra` params. */
  @Since("1.6.0")
  override def copy(extra: ParamMap): AFTSurvivalRegressionModel = {
    val duplicate = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
    copyValues(duplicate, extra).setParent(parent)
  }

  @Since("1.6.0")
  override def write: MLWriter = {
    new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this)
  }
}
@Since("1.6.0")
object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] {

  @Since("1.6.0")
  override def read: MLReader[AFTSurvivalRegressionModel] = {
    new AFTSurvivalRegressionModelReader
  }

  @Since("1.6.0")
  override def load(path: String): AFTSurvivalRegressionModel = {
    super.load(path)
  }

  /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */
  private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter (
      instance: AFTSurvivalRegressionModel
    ) extends MLWriter with Logging {

    /** Single-row record holding the model values written to the "data" directory. */
    private case class Data(coefficients: Vector, intercept: Double, scale: Double)

    /** Writes metadata and params first, then the model values as one parquet partition. */
    override protected def saveImpl(path: String): Unit = {
      // Save metadata and Params
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      // Save model data: coefficients, intercept, scale
      val modelValues = Data(instance.coefficients, instance.intercept, instance.scale)
      val outputPath = new Path(path, "data").toString
      val singleRow = sparkSession.createDataFrame(Seq(modelValues))
      singleRow.repartition(1).write.parquet(outputPath)
    }
  }

  private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[AFTSurvivalRegressionModel].getName

    /** Reads the metadata and the stored model values, then restores the saved params. */
    override def load(path: String): AFTSurvivalRegressionModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val inputPath = new Path(path, "data").toString
      val stored = sparkSession.read.parquet(inputPath)
      val firstRow = MLUtils.convertVectorColumnsToML(stored, "coefficients")
        .select("coefficients", "intercept", "scale")
        .head()
      val Row(coefficients: Vector, intercept: Double, scale: Double) = firstRow

      val restored = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
      DefaultParamsReader.getAndSetParams(restored, metadata)
      restored
    }
  }
}
/**
 * AFTAggregator computes the gradient and loss for a AFT loss function,
 * as used in AFT survival regression for samples in sparse or dense vector in an online fashion.
 *
 * The loss function and likelihood function under the AFT model based on:
 * Lawless, J. F., Statistical Models and Methods for Lifetime Data,
 * New York: John Wiley & Sons, Inc. 2003.
 *
 * Two AFTAggregator can be merged together to have a summary of loss and gradient of
 * the corresponding joint dataset.
 *
 * Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of subjects i = 1,..,n,
 * with possible right-censoring, the likelihood function under the AFT model is given as
 *
 * <blockquote>
 *    $$
 *    L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}
 *      (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}
 *    (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}}
 *    $$
 * </blockquote>
 *
 * Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not.
 * Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function
 * assumes the form
 *
 * <blockquote>
 *    $$
 *    \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+
 *    \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}]
 *    $$
 * </blockquote>
 * Where $S_{0}(\epsilon_{i})$ is the baseline survivor function,
 * and $f_{0}(\epsilon_{i})$ is corresponding density function.
 *
 * The most commonly used log-linear survival regression method is based on the Weibull
 * distribution of the survival time. The Weibull distribution for lifetime corresponding
 * to extreme value distribution for log of the lifetime,
 * and the $S_{0}(\epsilon)$ function is
 *
 * <blockquote>
 *    $$
 *    S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}})
 *    $$
 * </blockquote>
 *
 * and the $f_{0}(\epsilon_{i})$ function is
 *
 * <blockquote>
 *    $$
 *    f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}})
 *    $$
 * </blockquote>
 *
 * The log-likelihood function for Weibull distribution of lifetime is
 *
 * <blockquote>
 *    $$
 *    \iota(\beta,\sigma)=
 *    -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}]
 *    $$
 * </blockquote>
 *
 * Because minimizing the negative log-likelihood is equivalent to maximizing the likelihood,
 * the loss function we use to optimize is $-\iota(\beta,\sigma)$.
 * The gradient functions for $\beta$ and $\log\sigma$ respectively are
 *
 * <blockquote>
 *    $$
 *    \frac{\partial (-\iota)}{\partial \beta}=
 *    \sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} \\
 *
 *    \frac{\partial (-\iota)}{\partial (\log\sigma)}=
 *    \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}]
 *    $$
 * </blockquote>
 *
 * @param bcParameters The broadcasted value includes three part: The log of scale parameter,
 *                     the intercept and regression coefficients corresponding to the features.
 * @param fitIntercept Whether to fit an intercept term.
 * @param bcFeaturesStd The broadcast standard deviation values of the features.
 */
private class AFTAggregator(
    bcParameters: Broadcast[BDV[Double]],
    fitIntercept: Boolean,
    bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable {

  private val length = bcParameters.value.length
  // make transient so we do not serialize between aggregation stages
  @transient private lazy val parameters = bcParameters.value
  // the regression coefficients to the covariates
  @transient private lazy val coefficients = parameters.slice(2, length)
  @transient private lazy val intercept = parameters(1)
  // sigma is the scale parameter of the AFT model
  @transient private lazy val sigma = math.exp(parameters(0))

  private var totalCnt: Long = 0L
  private var lossSum = 0.0
  // Here we optimize loss function over log(sigma), intercept and coefficients
  private val gradientSumArray = Array.ofDim[Double](length)

  /** Number of data points added to (or merged into) this aggregator. */
  def count: Long = totalCnt

  /** Average loss over all aggregated points; requires at least one point. */
  def loss: Double = {
    require(totalCnt > 0.0, s"The number of instances should be " +
      s"greater than 0.0, but got $totalCnt.")
    lossSum / totalCnt
  }

  /** Average gradient over all aggregated points; requires at least one point. */
  def gradient: BDV[Double] = {
    require(totalCnt > 0.0, s"The number of instances should be " +
      s"greater than 0.0, but got $totalCnt.")
    new BDV(gradientSumArray.map(_ / totalCnt.toDouble))
  }

  /**
   * Add a new training data to this AFTAggregator, and update the loss and gradient
   * of the objective function.
   *
   * Features whose standard deviation is zero are skipped, and the remaining feature values
   * are divided by their standard deviation before use.
   *
   * @param data The AFTPoint representation for one data point to be added into this aggregator.
   * @return This AFTAggregator object.
   * @throws IllegalArgumentException if the label (lifetime) is not strictly positive
   */
  def add(data: AFTPoint): this.type = {
    val xi = data.features
    val ti = data.label
    val delta = data.censor

    // log(ti) is undefined for non-positive lifetimes. Fail fast rather than let NaN or
    // Infinity silently contaminate the accumulated loss and gradient.
    require(ti > 0.0, s"The lifetime or label should be greater than 0.0, but got $ti.")

    val localFeaturesStd = bcFeaturesStd.value

    val margin = {
      var sum = 0.0
      xi.foreachActive { (index, value) =>
        if (localFeaturesStd(index) != 0.0 && value != 0.0) {
          sum += coefficients(index) * (value / localFeaturesStd(index))
        }
      }
      sum + intercept
    }

    val epsilon = (math.log(ti) - margin) / sigma
    // exp(epsilon) appears in both the loss and the gradient; compute it once.
    val expEpsilon = math.exp(epsilon)

    lossSum += delta * math.log(sigma) - delta * epsilon + expEpsilon

    val multiplier = (delta - expEpsilon) / sigma

    // Slot 0 is d/d(log sigma), slot 1 the intercept, slots 2.. the coefficients.
    gradientSumArray(0) += delta + multiplier * sigma * epsilon
    gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
    xi.foreachActive { (index, value) =>
      if (localFeaturesStd(index) != 0.0 && value != 0.0) {
        gradientSumArray(index + 2) += multiplier * (value / localFeaturesStd(index))
      }
    }

    totalCnt += 1
    this
  }

  /**
   * Merge another AFTAggregator, and update the loss and gradient
   * of the objective function.
   * (Note that it's in place merging; as a result, `this` object will be modified.)
   *
   * @param other The other AFTAggregator to be merged.
   * @return This AFTAggregator object.
   */
  def merge(other: AFTAggregator): this.type = {
    // An aggregator that has seen no points contributes nothing.
    if (other.count != 0) {
      totalCnt += other.totalCnt
      lossSum += other.lossSum

      var i = 0
      while (i < length) {
        this.gradientSumArray(i) += other.gradientSumArray(i)
        i += 1
      }
    }
    this
  }
}
/**
 * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost.
 * Given a parameter vector, it aggregates the loss and gradient over the whole dataset
 * and hands both back to Breeze's optimization routines.
 */
private class AFTCostFun(
    data: RDD[AFTPoint],
    fitIntercept: Boolean,
    bcFeaturesStd: Broadcast[Array[Double]],
    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {

  /**
   * Evaluates the AFT objective at `parameters`.
   *
   * @return the pair (loss, gradient) reported by the aggregator
   */
  override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
    // Broadcast the current point so each task reads it instead of shipping it per record.
    val bcParameters = data.context.broadcast(parameters)

    val zero = new AFTAggregator(bcParameters, fitIntercept, bcFeaturesStd)
    val addPoint: (AFTAggregator, AFTPoint) => AFTAggregator =
      (aggregator, instance) => aggregator.add(instance)
    val mergeAggregators: (AFTAggregator, AFTAggregator) => AFTAggregator =
      (left, right) => left.merge(right)

    val aftAggregator = data.treeAggregate(zero)(addPoint, mergeAggregators, aggregationDepth)

    // The broadcast is only needed for this single evaluation.
    bcParameters.destroy(blocking = false)

    val lossValue = aftAggregator.loss
    val gradientValue = aftAggregator.gradient
    (lossValue, gradientValue)
  }
}
/**
 * A single training record for AFT regression: (features, label, censor).
 *
 * Construction fails unless `censor` is exactly 1.0 or 0.0.
 *
 * @param features List of features for this data point.
 * @param label Label for this data point.
 * @param censor Indicator of the event has occurred or not. If the value is 1, it means
 *                 the event has occurred i.e. uncensored; otherwise censored.
 */
private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) {
  require(
    censor == 1.0 || censor == 0.0,
    "censor of class AFTPoint must be 1.0 or 0.0")
}