diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java index a18ed1d0b48fa..ddca5318b04c7 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearSVCExample.java @@ -20,6 +20,7 @@ // $example on$ import org.apache.spark.ml.classification.LinearSVC; import org.apache.spark.ml.classification.LinearSVCModel; +import org.apache.spark.ml.classification.LinearSVCTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -47,6 +48,15 @@ public static void main(String[] args) { // Print the coefficients and intercept for LinearSVC System.out.println("Coefficients: " + lsvcModel.coefficients() + " Intercept: " + lsvcModel.intercept()); + + LinearSVCTrainingSummary trainingSummary = lsvcModel.summary(); + System.out.println("Total Iteration: " + trainingSummary.totalIterations()); + // Obtain the objective per iteration. + double[] objectiveHistory = trainingSummary.objectiveHistory(); + System.out.println("objectiveHistory:"); + for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); + } // $example off$ spark.stop(); diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 38bd048bec559..4561fe3d9dca7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -252,7 +252,7 @@ class LinearSVC @Since("2.2.0") ( val model = copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) val trainingSummary = LinearSVCTrainingSummary($(labelCol), $(featuresCol), objectiveHistory, objectiveHistory.length) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) instr.logSuccess(model) model } @@ -290,15 +290,26 @@ class LinearSVCModel private[classification] ( @Since("2.2.0") def setWeightCol(value: Double): this.type = set(threshold, value) - private var trainingSummary: LinearSVCTrainingSummary = _ + private var trainingSummary: Option[LinearSVCTrainingSummary] = None private[classification] - def setSummary(summary: LinearSVCTrainingSummary): this.type = { + def setSummary(summary: Option[LinearSVCTrainingSummary]): this.type = { this.trainingSummary = summary this } - def summary: LinearSVCTrainingSummary = trainingSummary + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.2.0") + def summary: LinearSVCTrainingSummary = trainingSummary.getOrElse( + throw new SparkException("No training summary available for this LinearSVCModel") + ) + + /** Indicates whether a training summary exists for this model instance. */ + @Since("2.2.0") + def hasSummary: Boolean = trainingSummary.isDefined private val margin: Vector => Double = (features) => { BLAS.dot(features, coefficients) + intercept @@ -368,7 +379,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { } /** - * Abstraction for Linear SVC Training results. + * Linear SVC Training results. * Currently, the training summary ignores the training weights except * for the objective trace. */ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 77fe162ca11e3..6ab7a6c9b828a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -172,6 +172,15 @@ def intercept(self): """ return self._call_java("intercept") + @property + @since("2.2.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + @property @since("2.2.0") def summary(self): @@ -179,9 +188,12 @@ def summary(self): Gets summary (e.g. objective history, total iterations) of model trained on the training set. """ - - java_blrt_summary = self._call_java("summary") - return LinearSVCTrainingSummary(java_blrt_summary) + if self.hasSummary: + java_blrt_summary = self._call_java("summary") + return LinearSVCTrainingSummary(java_blrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @inherit_doc @@ -189,7 +201,7 @@ class LinearSVCTrainingSummary(JavaWrapper): """ .. note:: Experimental - Abstraction for LinearSVC Training results. + Linear SVC Training results. Currently, the training summary ignores the training weights except for the objective trace.