Skip to content

Commit

Permalink
closed #62
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort committed Nov 5, 2016
1 parent 89469e4 commit 1af97f3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 14 deletions.
12 changes: 8 additions & 4 deletions yellowbrick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def __init__(self, model, ax=None, **kwargs):
These parameters can be influenced later on in the visualization
process, but can and should be set as early as possible.
"""
self.estimator = model
super(ScoreVisualizer, self).__init__(ax=ax, **kwargs)

self.estimator = model
self.name = get_model_name(self.estimator)

def fit(self, X, y=None, **kwargs):
Expand Down Expand Up @@ -222,8 +222,9 @@ def __init__(self, ax=None, **kwargs):
These parameters can be influenced later on in the visualization
process, but can and should be set as early as possible.
"""
self.estimator = model
super(ScoreVisualizer, self).__init__(ax=ax, **kwargs)
self.estimator = model


def fit(self, X, y=None, **kwargs):
"""
Expand Down Expand Up @@ -252,7 +253,7 @@ class MultiModelMixin(object):
Does predict for each of the models and generates subplots.
"""

def __init__(self, models, **kwargs):
def __init__(self, models, ax=None, **kwargs):
# Ensure models is a collection, if it's a single estimator then we
# wrap it in a list so that the API doesn't break during render.
"""
Expand All @@ -265,7 +266,10 @@ def __init__(self, models, **kwargs):
These parameters can be influenced later on in the visualization
process, but can and should be set as early as possible.
"""
if isestimator(models):
# TODO: How to handle the axes in this mixin?
self.ax = ax

if all(isestimator, models):
models = [models]

# Keep track of the models
Expand Down
10 changes: 0 additions & 10 deletions yellowbrick/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ def __init__(self, model, ax=None, classes=None, **kwargs):
"""
super(ClassificationReport, self).__init__(model, ax=ax, **kwargs)

## hoisted to Visualizer base class
# self.ax = ax

## hoisted to ScoreVisualizer base class
self.estimator = model
self.name = get_model_name(self.estimator)
Expand Down Expand Up @@ -228,9 +225,6 @@ def __init__(self, model, ax=None, **kwargs):
"""
super(ROCAUC, self).__init__(model, ax=ax, **kwargs)

## hoisted to Visualizer base class
# self.ax = ax

## hoisted to ScoreVisualizer base class
self.name = get_model_name(self.estimator)

Expand Down Expand Up @@ -350,10 +344,6 @@ def __init__(self, model, ax=None, classes=None, **kwargs):
"""
super(ClassBalance, self).__init__(model, ax=ax, **kwargs)

## hoisted to ScoreVisualizer base class
self.estimator = model
self.name = get_model_name(self.estimator)

self.colors = kwargs.pop('colors', YELLOWBRICK_PALETTES['paired'])
self.classes_ = classes

Expand Down

0 comments on commit 1af97f3

Please sign in to comment.