Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tengpeng committed Apr 23, 2018
1 parent 3c6a4da commit da53b1a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine

private[regression] val epsilon: Double = 1E-16

/**
 * Computes `y * log(y / mu)`, returning 0.0 when `y` equals zero so that the
 * logarithm is never evaluated for a zero numerator.
 *
 * @param y  the multiplier and numerator of the log ratio
 * @param mu the denominator of the log ratio
 * @return 0.0 if `y == 0`, otherwise `y * math.log(y / mu)`
 */
private[regression] def ylogy(y: Double, mu: Double): Double = {
  if (y != 0) {
    y * math.log(y / mu)
  } else {
    0.0
  }
}

/**
* Wrapper of family and link combination used in the model.
*/
Expand Down Expand Up @@ -725,10 +729,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine

/**
 * Variance function evaluated at `mu`: the product of `mu` and `1.0 - mu`.
 *
 * @param mu the value at which the variance is evaluated
 * @return `mu * (1.0 - mu)`
 */
override def variance(mu: Double): Double = {
  val complement = 1.0 - mu
  mu * complement
}

/**
 * Evaluates `y * log(y / mu)`; a zero `y` short-circuits to 0.0 without
 * touching the logarithm.
 */
private def ylogy(y: Double, mu: Double): Double = {
  if (y != 0) {
    val ratio = y / mu
    y * math.log(ratio)
  } else {
    0.0
  }
}

/**
 * Weighted deviance contribution of a single observation.
 *
 * Adds the `ylogy` term for `(y, mu)` to the `ylogy` term for the complements
 * `(1.0 - y, 1.0 - mu)`, then scales the sum by `2.0 * weight`.
 *
 * @param y      the observed value
 * @param mu     the fitted mean
 * @param weight the observation weight
 * @return `2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu))`
 */
override def deviance(y: Double, mu: Double, weight: Double): Double = {
  val observedTerm = ylogy(y, mu)
  val complementTerm = ylogy(1.0 - y, 1.0 - mu)
  2.0 * weight * (observedTerm + complementTerm)
}
Expand Down Expand Up @@ -782,10 +782,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine

/**
 * Variance function evaluated at `mu`: the variance equals `mu` itself.
 *
 * @param mu the value at which the variance is evaluated
 * @return `mu`, unchanged
 */
override def variance(mu: Double): Double = {
  mu
}

/**
 * Returns `y * log(y / mu)`, with the zero-`y` case defined as 0.0 so the
 * logarithm is skipped for it.
 */
private def ylogy(y: Double, mu: Double): Double = {
  val yIsZero = y == 0
  if (yIsZero) 0.0
  else y * math.log(y / mu)
}

/**
 * Weighted deviance contribution of a single observation.
 *
 * Subtracts the raw difference `y - mu` from the `ylogy` term for `(y, mu)`,
 * then scales the result by `2.0 * weight`.
 *
 * @param y      the observed value
 * @param mu     the fitted mean
 * @param weight the observation weight
 * @return `2.0 * weight * (ylogy(y, mu) - (y - mu))`
 */
override def deviance(y: Double, mu: Double, weight: Double): Double = {
  val logTerm = ylogy(y, mu)
  val rawDifference = y - mu
  2.0 * weight * (logTerm - rawDifference)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,10 +493,19 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
}
[1] -0.0457441 -0.6833928
[1] 1.8121235 -0.1747493 -0.5815417
R code for deviance calculation:
data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1))
summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance
[1] 3.70055
summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance
[1] 3.809296
*/
val expected = Seq(
Vectors.dense(0.0, -0.0457441, -0.6833928, 3.8093),
Vectors.dense(1.8121235, -0.1747493, -0.5815417, 3.7006))
Vectors.dense(0.0, -0.0457441, -0.6833928),
Vectors.dense(1.8121235, -0.1747493, -0.5815417))

val residualDeviancesR = Array(3.809296, 3.70055)

import GeneralizedLinearRegression._

Expand All @@ -507,10 +516,10 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1),
model.summary.deviance)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
s"$link link and fitIntercept = $fitIntercept (with zero values).")
assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3)
idx += 1
}
}
Expand Down

0 comments on commit da53b1a

Please sign in to comment.