Skip to content

Commit

Permalink
add HuberCostFun to LinearRegression.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
fjiang6 committed Aug 4, 2015
1 parent c980a1f commit cff7ecb
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 609 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
import scala.math.pow

/**
* Params for linear regression.
Expand Down Expand Up @@ -591,3 +592,57 @@ private class LeastSquaresCostFun(
(loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
}
}

/**
* HuberCostFun implements Breeze's DiffFunction[T] for Huber cost as used in Robust regression.
* The Huber M-estimator corresponds to a probability distribution for the errors which is normal
* in the centre but like a double exponential distribution in the tails (Hogg 1979: 109).
* L = 1/2 ||A weights-y||^2 if |A weights-y| <= k
* L = k |A weights-y| - 1/2 K^2 if |A weights-y| > k
* where k = 1.345 which produce 95% efficiency when the errors are normal and
* substantial resistance to outliers otherwise.
* See also the documentation for the precise formulation.
* It's used in Breeze's convex optimization routines.
*/
private class HuberCostFun(
data: RDD[(Double, Vector)],
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double],
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {

override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
val w = Vectors.fromBreeze(weights)

val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
labelMean, fitIntercept, featuresStd, featuresMean))(
seqOp = (c, v) => (c, v) match {
case (aggregator, (label, features)) => aggregator.add(label, features)
},
combOp = (c1, c2) => (c1, c2) match {
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
})

val k = 1.345
val bcW = data.context.broadcast(w)
val diff = dot(bcW.value, w) - labelMean
val norm = brzNorm(weights, 2.0)
var regVal = 0.0
if(diff < -k){
regVal = -k * diff - 0.5 * pow(k, 2)
} else if (diff >= -k && diff <= k){
regVal = 0.5 * norm * norm
} else {
regVal = k * diff - 0.5 * pow(k, 2)
}

val loss = leastSquaresAggregator.loss + regVal
val gradient = leastSquaresAggregator.gradient
axpy(effectiveL2regParam, w, gradient)

(loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
}
}

Loading

0 comments on commit cff7ecb

Please sign in to comment.