[SPARK-18194][ML] Log instrumentation in OneVsRest, CrossValidator, TrainValidationSplit #16480
@@ -344,6 +344,10 @@ final class OneVsRest @Since("1.4.0") (
multiclassLabeled.unpersist()
}

val instrLog = Instrumentation.create(this, multiclassLabeled)
I think multiclassLabeled is the dataset we want, but I'm not sure what it's used for in Instrumentation.scala, so I could be wrong.
I'd actually use the input "dataset" since it has more information (columns), though either should work.
Btw, can you please rename this to "instr" to match other classes? I see ALS also uses the name instrLog, but it's the only one. Could you change ALS to "instr" as well in this PR?
Thanks for the PR! I added some comments.
Also, for all of these, I'd move the logging earlier in the train/fit methods so that you log info as soon as it's available.
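To illustrate the "log as soon as it's available" point, here is a minimal sketch in plain Scala. SketchInstr and EarlyLogging.fit are hypothetical stand-ins, not Spark's Instrumentation or any real fit method; the only idea shown is that values known up front get logged before training starts, so they still reach the log if training fails.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for an instrumentation logger; NOT Spark's Instrumentation class.
final class SketchInstr(prefix: String) {
  val lines = ArrayBuffer.empty[String]
  def log(msg: String): Unit = lines += s"$prefix: $msg"
}

object EarlyLogging {
  // Hypothetical fit: params are logged before the (possibly failing) training step.
  def fit(instr: SketchInstr, trainRatio: Double)(train: () => Int): Option[Int] = {
    instr.log(s"trainRatio=$trainRatio") // known before any work is done
    try {
      val numFeatures = train()
      instr.log(s"numFeatures=$numFeatures") // only known after training
      Some(numFeatures)
    } catch {
      case _: RuntimeException => None // the early log line above is already recorded
    }
  }
}
```

If the logging sat at the end of fit instead, a failed training run would leave no record of the params it was started with.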
@@ -344,6 +344,10 @@ final class OneVsRest @Since("1.4.0") (
multiclassLabeled.unpersist()
}

val instrLog = Instrumentation.create(this, multiclassLabeled)
instrLog.logParams(labelCol, featuresCol, predictionCol)
instrLog.logNumClasses(numClasses)
Also log numFeatures, which you can get from models.head.numFeatures
val instrLog = Instrumentation.create(this, multiclassLabeled)
instrLog.logParams(labelCol, featuresCol, predictionCol)
instrLog.logNumClasses(numClasses)
instrLog.logNamedValue("classifier", $(classifier).getClass.getSimpleName)
Use getCanonicalName instead
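For context on that suggestion, a small sketch using two JDK classes (chosen only for illustration; they have nothing to do with this PR). getSimpleName drops the package, so different classes can log the same string, while getCanonicalName keeps the fully qualified name.

```scala
object ClassNames {
  // Two JDK classes that share a simple name but live in different packages.
  val simpleUtil: String = classOf[java.util.Date].getSimpleName
  val simpleSql: String = classOf[java.sql.Date].getSimpleName

  // The canonical name includes the package, so the two stay distinguishable.
  val canonicalUtil: String = classOf[java.util.Date].getCanonicalName
  val canonicalSql: String = classOf[java.sql.Date].getCanonicalName
}
```

Here both simple names are "Date", while the canonical names are "java.util.Date" and "java.sql.Date". One caveat: getCanonicalName returns null for anonymous and local classes, so a caller logging arbitrary user classes may want a fallback such as getName.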
@@ -116,13 +116,17 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
}
validationDataset.unpersist()

val instrLog = Instrumentation.create(this, dataset)
instrLog.logParams(trainRatio)
Also log the seed.
* Logs the value with customized name field.
*/
def logNamedValue(name: String, num: Long): Unit = {
def logNamedValue(name: String, num: JValue): Unit = {
Since this class doesn't expose json4s APIs, let's stick with basic types in the public API. Just do String for now (since the Long version is not yet used).
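A rough sketch of what a String-only signature could look like from the caller's side. NamedValueLogger is a hypothetical class written for this example, not Spark's Instrumentation, and the class name passed in is made up. The point is only that the public method takes basic types, and callers with numeric values convert them explicitly.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical logger, NOT Spark's Instrumentation: the public method exposes
// only basic types, so no JSON library type leaks into the API.
final class NamedValueLogger {
  val entries = ArrayBuffer.empty[(String, String)]
  def logNamedValue(name: String, value: String): Unit = entries += (name -> value)
}

object NamedValueDemo {
  def run(): List[(String, String)] = {
    val instr = new NamedValueLogger
    instr.logNamedValue("classifier", "org.example.MyClassifier") // made-up class name
    instr.logNamedValue("numFolds", 3L.toString) // numeric callers convert explicitly
    instr.entries.toList
  }
}
```

How the logger serializes an entry (JSON or otherwise) then stays an internal detail that can change without touching callers.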
Also rename "num" to "value".
* @param estimatorParamMaps different params tried by the tuning estimator
* @param evaluator evaluator used to compute the metric for each estimator param value
*/
def logTuningParams(
I like that this is separated out for all tuning algorithms, but it belongs in ml.tuning. How about as a static (object) method in ValidatorParams?
moved into ValidatorParams as a class method -- let me know if you feel strongly about it being static. thanks!
Unit test runs & logs:
Now the logs show the full class path for the estimator/evaluator/classifier:
Thanks! Just a few comments
/**
* Instrumentation logging for tuning params including the inner estimator and evaluator info.
*
* @param instrumentation instrumentation logger
I'd remove this comment since it doesn't add info.
* @param instrumentation instrumentation logger
*/
protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
instrumentation.log(compact(render(map2jvalue(Map[String, JValue](
I'd say just use instrumentation.logNamedValue for each of these, rather than handling JSON here.
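One way that suggestion might look, sketched in plain Scala without Spark. SketchLogger, TuningParamsSketch and CrossValidatorSketch are hypothetical names invented here, and the logged field names are assumptions rather than what the PR ended up with. Each value gets its own logNamedValue call, and the shared trait never builds JSON itself.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-ins; none of these are Spark classes.
final class SketchLogger {
  val lines = ArrayBuffer.empty[String]
  def logNamedValue(name: String, value: String): Unit = lines += s"$name=$value"
}

trait TuningParamsSketch {
  def estimator: AnyRef
  def evaluator: AnyRef
  def estimatorParamMaps: Seq[Map[String, Any]]

  // One logNamedValue call per field; the logger owns the output format.
  protected def logTuningParams(instr: SketchLogger): Unit = {
    instr.logNamedValue("estimator", estimator.getClass.getCanonicalName)
    instr.logNamedValue("evaluator", evaluator.getClass.getCanonicalName)
    instr.logNamedValue("estimatorParamMapsLength", estimatorParamMaps.length.toString)
  }
}

final class CrossValidatorSketch(
    val estimator: AnyRef,
    val evaluator: AnyRef,
    val estimatorParamMaps: Seq[Map[String, Any]]) extends TuningParamsSketch {
  // Any tuning class mixing in the trait can reuse the same logging helper.
  def fit(instr: SketchLogger): Unit = logTuningParams(instr)
}
```

Keeping the helper in a trait shared by the tuning classes matches the earlier request to house it with the tuning code rather than in the generic instrumentation class.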

if (handlePersistence) {
multiclassLabeled.unpersist()
}


remove extra newline
ok to test
add to whitelist
LGTM pending Jenkins tests
Test build #70970 has started for PR 16480 at commit
Test build #70972 has started for PR 16480 at commit
Test build #3522 has finished for PR 16480 at commit
Test build #70996 has finished for PR 16480 at commit
Merging with master.
…rainValidationSplit

## What changes were proposed in this pull request?

Added instrumentation logging for OneVsRest classifier, CrossValidator, TrainValidationSplit fit() functions.

## How was this patch tested?

Ran unit tests and checked the log file (see output in comments).

Author: sueann <sueann@databricks.com>

Closes apache#16480 from sueann/SPARK-18194.
What changes were proposed in this pull request?
Added instrumentation logging for OneVsRest classifier, CrossValidator, TrainValidationSplit fit() functions.
How was this patch tested?
Ran unit tests and checked the log file (see output in comments).