Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zjffdu committed Apr 14, 2017
1 parent 0daa041 commit 7f20fc8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
// $example on$
import org.apache.spark.ml.classification.LinearSVC;
import org.apache.spark.ml.classification.LinearSVCModel;
import org.apache.spark.ml.classification.LinearSVCTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
Expand Down Expand Up @@ -47,6 +48,15 @@ public static void main(String[] args) {
// Print the coefficients and intercept for LinearSVC
System.out.println("Coefficients: "
+ lsvcModel.coefficients() + " Intercept: " + lsvcModel.intercept());

LinearSVCTrainingSummary trainingSummary = lsvcModel.summary();
System.out.println("Total Iteration: " + trainingSummary.totalIterations());
// Obtain the objective per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
System.out.println("objectiveHistory:");
for (double lossPerIteration : objectiveHistory) {
System.out.println(lossPerIteration);
}
// $example off$

spark.stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class LinearSVC @Since("2.2.0") (
val model = copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector))
val trainingSummary = LinearSVCTrainingSummary($(labelCol), $(featuresCol),
objectiveHistory, objectiveHistory.length)
model.setSummary(trainingSummary)
model.setSummary(Some(trainingSummary))
instr.logSuccess(model)
model
}
Expand Down Expand Up @@ -290,15 +290,26 @@ class LinearSVCModel private[classification] (
@Since("2.2.0")
def setWeightCol(value: Double): this.type = set(threshold, value)

private var trainingSummary: LinearSVCTrainingSummary = _
private var trainingSummary: Option[LinearSVCTrainingSummary] = None

private[classification]
def setSummary(summary: LinearSVCTrainingSummary): this.type = {
def setSummary(summary: Option[LinearSVCTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

def summary: LinearSVCTrainingSummary = trainingSummary
/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.2.0")
def summary: LinearSVCTrainingSummary = trainingSummary.getOrElse(
throw new SparkException("No training summary available for this LinearSVCModel")
)

/** Indicates whether a training summary exists for this model instance. */
@Since("2.2.0")
def hasSummary: Boolean = trainingSummary.isDefined

private val margin: Vector => Double = (features) => {
BLAS.dot(features, coefficients) + intercept
Expand Down Expand Up @@ -368,7 +379,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
}

/**
* Abstraction for Linear SVC Training results.
* Linear SVC Training results.
* Currently, the training summary ignores the training weights except
* for the objective trace.
*/
Expand Down
20 changes: 16 additions & 4 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,24 +172,36 @@ def intercept(self):
"""
return self._call_java("intercept")

@property
@since("2.2.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model
instance.
"""
return self._call_java("hasSummary")

@property
@since("2.2.0")
def summary(self):
"""
Gets summary (e.g. objective history, total iterations) of model
trained on the training set.
"""

java_blrt_summary = self._call_java("summary")
return LinearSVCTrainingSummary(java_blrt_summary)
if self.hasSummary:
java_blrt_summary = self._call_java("summary")
return LinearSVCTrainingSummary(java_blrt_summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)


@inherit_doc
class LinearSVCTrainingSummary(JavaWrapper):
"""
.. note:: Experimental
Abstraction for LinearSVC Training results.
Linear SVC Training results.
Currently, the training summary ignores the training weights except
for the objective trace.
Expand Down

0 comments on commit 7f20fc8

Please sign in to comment.