Skip to content

Commit

Permalink
fix style in test
Browse files Browse the repository at this point in the history
  • Loading branch information
actuaryzhang committed Jan 26, 2017
1 parent 9c320ee commit e183c08
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -798,13 +798,13 @@ class GeneralizedLinearRegressionModel private[ml] (
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) }
val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) }
val off = if (!isSet(offsetCol) || $(offsetCol).isEmpty) lit(0.0) else col($(offsetCol))
val offset = if (!isSet(offsetCol) || $(offsetCol).isEmpty) lit(0.0) else col($(offsetCol))
var output = dataset
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), off))
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset))
}
if (hasLinkPredictionCol) {
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), off))
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), offset))
}
output.toDF()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,28 +623,28 @@ class GeneralizedLinearRegressionSuite
var idx = 0
for (fitIntercept <- Seq(false, true)) {
for (family <- Seq("gaussian", "poisson", "gamma")) {
val trainer = new GeneralizedLinearRegression().setFamily(family)
.setFitIntercept(fitIntercept).setOffsetCol("offset")
.setWeightCol("weight").setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, s"Model mismatch: GLM with family = $family," +
s" and fitIntercept = $fitIntercept.")

val familyObj = Family.fromName(family)
val familyLink = new FamilyAndLink(familyObj, familyObj.defaultLink)
model.transform(dataset).select("features", "offset", "prediction", "linkPrediction")
.collect().foreach {
case Row(features: DenseVector, offset: Double, prediction1: Double,
linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"family = $family, and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with family = $family, and fitIntercept = $fitIntercept.")
}
val trainer = new GeneralizedLinearRegression().setFamily(family)
.setFitIntercept(fitIntercept).setOffsetCol("offset")
.setWeightCol("weight").setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, s"Model mismatch: GLM with family = $family," +
s" and fitIntercept = $fitIntercept.")

val familyObj = Family.fromName(family)
val familyLink = new FamilyAndLink(familyObj, familyObj.defaultLink)
model.transform(dataset).select("features", "offset", "prediction", "linkPrediction")
.collect().foreach {
case Row(features: DenseVector, offset: Double, prediction1: Double,
linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"family = $family, and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with family = $family, and fitIntercept = $fitIntercept.")
}

idx += 1
}
Expand Down

0 comments on commit e183c08

Please sign in to comment.