1239 visualize decision trees #1511
… formats, model checking, etc.
… regular text or graphviz
… root node and immediate child nodes have the right properties
Codecov Report
@@           Coverage Diff           @@
##             main   #1511    +/-   ##
=======================================
+ Coverage   100.0%  100.0%  +0.1%
=======================================
  Files         232     232
  Lines       16430   16639   +209
=======================================
+ Hits        16422   16631   +209
  Misses          8       8
# Conflicts: # docs/source/release_notes.rst
@ParthivNaresh I'm excited to review this! Could you please include an example of the output in the PR description? Will make it quick for reviewers to understand what's up.
@ParthivNaresh Looks great! And tests look solid. I left some comments, but I think the only thing holding merge is the discussion about how to display the column names in the clean_format_tree.
evalml/model_understanding/prediction_explanations/explainers.py
num_nodes = est.tree_.node_count
children_left = est.tree_.children_left
children_right = est.tree_.children_right
features = est.tree_.feature
I think we should save the actual feature names in the output. These feature names range from 0 to n_col-1 because we convert from pandas to sklearn right before fitting. The problem is that the feature names are not saved in the tree object.
I think we have a couple of options:
1. Add an option to this function for passing in the feature names.
2. Change the input type from a tree estimator to a pipeline with a tree estimator. This would allow us to use input_feature_names[estimator.name].
3. File an issue for saving the feature names to the tree estimator and leave this function as-is for now.
I think I prefer 2, but what do you think? I believe our model understanding methods take in a pipeline instead of an estimator, so it'd be more consistent with what we already have. I guess we can also do both 1 and 2: add the feature names as a parameter to this function, then add another function that accepts a pipeline and calls this one.
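To make options 1 and 2 concrete, here is a minimal sketch, not the implementation in this PR. It walks hand-written lists laid out like scikit-learn's `tree_` arrays (a child index of -1 marks a leaf), so it runs without any dependencies. The function name `tree_to_dict`, the `feature_names` parameter, and the dictionary keys are all hypothetical.

```python
from collections import OrderedDict

TREE_LEAF = -1  # scikit-learn stores -1 as the child index of a leaf node


def tree_to_dict(children_left, children_right, feature, threshold, value,
                 feature_names=None, node_id=0):
    """Recursively turn parallel tree arrays into a nested OrderedDict.

    feature_names is an assumed optional parameter (option 1). When it is
    omitted, the integer column index is reported instead.
    """
    if children_left[node_id] == TREE_LEAF:
        return OrderedDict([("Value", value[node_id])])

    index = feature[node_id]
    name = feature_names[index] if feature_names is not None else index
    arrays = (children_left, children_right, feature, threshold, value)
    return OrderedDict([
        ("Feature", name),
        ("Threshold", threshold[node_id]),
        ("Left_Child", tree_to_dict(*arrays, feature_names=feature_names,
                                    node_id=children_left[node_id])),
        ("Right_Child", tree_to_dict(*arrays, feature_names=feature_names,
                                     node_id=children_right[node_id])),
    ])


# Hand-written stand-ins for est.tree_.children_left, .children_right,
# .feature, .threshold and .value (flattened to one number per node here,
# which is a simplification of the real value array).
children_left = [1, -1, -1]
children_right = [2, -1, -1]
feature = [1, -2, -2]
threshold = [0.5, -2.0, -2.0]
value = [15.0, 10.0, 20.0]
arrays = (children_left, children_right, feature, threshold, value)

unnamed = tree_to_dict(*arrays)
named = tree_to_dict(*arrays, feature_names=["age", "income"])
print(unnamed["Feature"], named["Feature"])
```

A pipeline-based wrapper (option 2) could look up the names, for example through input_feature_names[estimator.name] as mentioned above, and pass them through as `feature_names`.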
Looks great! I added a few suggestions on additional tests, but nothing blocking.
@ParthivNaresh this is awesome!
I left a few suggestions and questions. Ones I'd like us to address before merge:
- Method naming
- Use data method in graph method
- Split graph unit tests into checking the returned graph content vs checking the filepath image saving
I also left a note about #1535, which could be cool to look at next!
Looks really good! Just added some tiny nit-picky comments about breaking down tests more and docstrings.
…, and release notes. Also fixed typo in Decision Tree Regressor name
# Conflicts: # docs/source/release_notes.rst
… test to cover list casting of passed feature names
Fixes #1239
The output of visualize_decision_tree() will be a graphviz.files.Source object. For example:
visualize_decision_tree(clf=regression_estimator, filled=True, max_depth=2).view()
The output of clean_format_tree will be an OrderedDict.
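As a rough illustration of how the two outputs relate, the sketch below renders a nested OrderedDict as Graphviz DOT text using only the standard library. It is not the code in this PR: the helper name `dict_to_dot` and the dictionary keys ("Feature", "Threshold", "Value", "Left_Child", "Right_Child") are assumptions made for the example, and the real clean_format_tree output may be shaped differently.

```python
from collections import OrderedDict
from itertools import count


def dict_to_dot(tree):
    """Render a nested tree dictionary (hypothetical layout) as DOT text."""
    lines = ["digraph Tree {"]
    ids = count()  # hands out 0, 1, 2, ... in visiting order

    def visit(node):
        node_id = next(ids)
        if "Value" in node:  # leaf node
            label = "value = {}".format(node["Value"])
        else:  # split node
            label = "{} <= {}".format(node["Feature"], node["Threshold"])
        lines.append('  {} [label="{}"];'.format(node_id, label))
        for key in ("Left_Child", "Right_Child"):
            if key in node:
                child_id = visit(node[key])
                lines.append("  {} -> {};".format(node_id, child_id))
        return node_id

    visit(tree)
    lines.append("}")
    return "\n".join(lines)


example_tree = OrderedDict([
    ("Feature", "income"),
    ("Threshold", 0.5),
    ("Left_Child", OrderedDict([("Value", 10.0)])),
    ("Right_Child", OrderedDict([("Value", 20.0)])),
])

dot_text = dict_to_dot(example_tree)
print(dot_text)
# With the graphviz Python package installed, graphviz.Source(dot_text)
# would wrap this text in a Source object that can be viewed or saved.
```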