Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
actuaryzhang committed Feb 7, 2017
1 parent 95b7a10 commit 37c41aa
Showing 1 changed file with 17 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -746,37 +746,40 @@ class GeneralizedLinearRegressionSuite
test("generalized linear regression: intercept only") {
/*
R code:
y <- c(17, 19, 23, 29)
library(statmod)
y <- c(1.0, 0.5, 0.7, 0.3)
w <- c(1, 2, 3, 4)
model1 <- glm(y ~ 1, family = poisson)
model2 <- glm(y ~ 1, family = poisson, weights = w)
as.vector(c(coef(model1), coef(model2)))
for (fam in c("gaussian", "poisson", "Gamma")) {
for (fam in list(gaussian(), poisson(), binomial(), Gamma(), tweedie(1.6))) {
model1 <- glm(y ~ 1, family = fam)
model2 <- glm(y ~ 1, family = fam, weights = w)
print(as.vector(c(coef(model1), coef(model2))))
}
[1] 22 24
[1] 3.091042 3.178054
[1] 0.04545455 0.04166667
[1] 0.625 0.530
[1] -0.4700036 -0.6348783
[1] 0.5108256 0.1201443
[1] 1.600000 1.886792
[1] 1.325782 1.463641
*/

val dataset = Seq(
Instance(17.0, 1.0, Vectors.zeros(0)),
Instance(19.0, 2.0, Vectors.zeros(0)),
Instance(23.0, 3.0, Vectors.zeros(0)),
Instance(29.0, 4.0, Vectors.zeros(0))
Instance(1.0, 1.0, Vectors.zeros(0)),
Instance(0.5, 2.0, Vectors.zeros(0)),
Instance(0.7, 3.0, Vectors.zeros(0)),
Instance(0.3, 4.0, Vectors.zeros(0))
).toDF()

val expected = Seq(22.0, 24.0, 3.0910, 3.1781, 0.0455, 0.0417)
val expected = Seq(0.625, 0.530, -0.4700036, -0.6348783, 0.5108256, 0.1201443,
1.600000, 1.886792, 1.325782, 1.463641)

import GeneralizedLinearRegression._

var idx = 0
for (family <- Seq("gaussian", "poisson", "gamma")) {
for (family <- Seq("gaussian", "poisson", "binomial", "gamma", "tweedie")) {
for (useWeight <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily(family)
if (useWeight) trainer.setWeightCol("weight")
if (family == "tweedie") trainer.setVariancePower(1.6)
val model = trainer.fit(dataset)
val actual = model.intercept
assert(actual ~== expected(idx) absTol 1E-3, "Model mismatch: intercept only GLM with " +
Expand Down

0 comments on commit 37c41aa

Please sign in to comment.