From 7282becddfa23e5316610b07f4543792d0edf11a Mon Sep 17 00:00:00 2001 From: Sina Kordestanchi Date: Thu, 8 Jun 2017 17:12:50 +0200 Subject: [PATCH] Modify JavaModelSelectionViaCrossValidationExample I assume this is a better way of using Cross Validation! I assume CrossValidator gets the whole "data", and breaks it into "training" and "test" data. --- ...delSelectionViaCrossValidationExample.java | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java index 975c65edc0ca6..83dab51543b5b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java @@ -48,8 +48,8 @@ public static void main(String[] args) { .getOrCreate(); // $example on$ - // Prepare training documents, which are labeled. - Dataset training = spark.createDataFrame(Arrays.asList( + // Prepare data documents, which are labeled. + Dataset data = spark.createDataFrame(Arrays.asList( new JavaLabeledDocument(0L, "a b c d e spark", 1.0), new JavaLabeledDocument(1L, "b d", 0.0), new JavaLabeledDocument(2L,"spark f g h", 1.0), @@ -97,18 +97,14 @@ public static void main(String[] args) { .setEstimatorParamMaps(paramGrid).setNumFolds(2); // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. - CrossValidatorModel cvModel = cv.fit(training); - - // Prepare test documents, which are unlabeled. - Dataset test = spark.createDataFrame(Arrays.asList( - new JavaDocument(4L, "spark i j k"), - new JavaDocument(5L, "l m n"), - new JavaDocument(6L, "mapreduce spark"), - new JavaDocument(7L, "apache hadoop") - ), JavaDocument.class); + CrossValidatorModel cvModel = cv.fit(data); // Make predictions on test documents. cvModel uses the best model found (lrModel). - Dataset predictions = cvModel.transform(test); + cvModel.avgMetrics(); + Model bestModel = cvModel.bestModel(); + + Dataset predictions = bestModel.transform(data); + for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3));