Skip to content

Commit

Permalink
variable naming
Browse files Browse the repository at this point in the history
  • Loading branch information
sethah committed Dec 27, 2016
1 parent 9f2eaf3 commit f2dbdd9
Showing 1 changed file with 6 additions and 6 deletions.
Expand Up @@ -210,12 +210,12 @@ object MLTestingUtils extends SparkFunSuite {
* to assigning a sample weight proportional to the number of samples for each point.
*/
def testOversamplingVsWeighting[M <: Model[M], E <: Estimator[M]](
df: Dataset[LabeledPoint],
data: Dataset[LabeledPoint],
estimator: E with HasWeightCol,
modelEquals: (M, M) => Unit,
seed: Long): Unit = {
val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances(
df, seed)
data, seed)
val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData)
modelEquals(weightedModel, overSampledModel)
Expand All @@ -227,17 +227,17 @@ object MLTestingUtils extends SparkFunSuite {
* model despite the outliers.
*/
def testOutliersWithSmallWeights[M <: Model[M], E <: Estimator[M]](
ds: Dataset[LabeledPoint],
data: Dataset[LabeledPoint],
estimator: E with HasWeightCol,
numClasses: Int,
modelEquals: (M, M) => Unit): Unit = {
import ds.sqlContext.implicits._
val outlierDS = ds.withColumn("weight", lit(1.0)).as[Instance].flatMap {
import data.sqlContext.implicits._
val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap {
case Instance(l, w, f) =>
val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1
List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
}
val trueModel = estimator.set(estimator.weightCol, "").fit(ds)
val trueModel = estimator.set(estimator.weightCol, "").fit(data)
val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS)
modelEquals(trueModel, outlierModel)
}
Expand Down

0 comments on commit f2dbdd9

Please sign in to comment.