Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addressing issue #590:Adding train alpha and test alpha to residual #806

Merged
merged 7 commits into from
Apr 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 7 additions & 4 deletions tests/test_regressor/test_residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,14 @@ def test_alpha_param(self, mock_sca):
"""
# Instantiate a prediction error plot, provide custom alpha
visualizer = ResidualsPlot(
Ridge(random_state=8893), alpha=0.3, hist=False
Ridge(random_state=8893), train_alpha=0.3,test_alpha=0.75, hist=False
)

alphas = {
'train_point': 0.3,
'test_point': 0.75
}
# Test param gets set correctly
assert visualizer.alpha == 0.3
assert visualizer.alphas == alphas

visualizer.ax = mock.MagicMock()
visualizer.fit(self.data.X.train, self.data.y.train)
Expand All @@ -369,4 +372,4 @@ def test_alpha_param(self, mock_sca):
# Test that alpha was passed to internal matplotlib scatterplot
_, scatter_kwargs = visualizer.ax.scatter.call_args
assert "alpha" in scatter_kwargs
assert scatter_kwargs["alpha"] == 0.3
assert scatter_kwargs["alpha"] == 0.75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, this shows a good understanding that alpha is 0.3 after visualizer.fit then changes to 0.75 after visualizer.score

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

49 changes: 35 additions & 14 deletions yellowbrick/regressor/residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def score(self, X, y=None, **kwargs):
return self.score_

def draw(self, y, y_pred):

"""
Parameters
----------
Expand Down Expand Up @@ -370,9 +371,15 @@ class ResidualsPlot(RegressionScoreVisualizer):
line_color : color, default: dark grey
Defines the color of the zero error line, can be any matplotlib color.

alpha : float, default: 0.75
Specify a transparency where 1 is completely opaque and 0 is completely
transparent. This property makes densely clustered points more visible.
train_alpha : float, default: 0.75
Specify a transparency for traininig data, where 1 is completely opaque
and 0 is completely transparent. This property makes densely clustered
points more visible.

test_alpha : float, default: 0.75
Specify a transparency for test data, where 1 is completely opaque
and 0 is completely transparent. This property makes densely clustered
points more visible.

kwargs : dict
Keyword arguments that are passed to the base class and may influence
Expand All @@ -396,8 +403,8 @@ class ResidualsPlot(RegressionScoreVisualizer):
The residuals histogram feature requires matplotlib 2.0.2 or greater.
"""
def __init__(self, model, ax=None, hist=True, train_color='b',
test_color='g', line_color=LINE_COLOR, alpha=0.75,
**kwargs):
test_color='g', line_color=LINE_COLOR, train_alpha=0.75,
test_alpha=0.75,**kwargs):

super(ResidualsPlot, self).__init__(model, ax=ax, **kwargs)

Expand All @@ -422,7 +429,10 @@ def __init__(self, model, ax=None, hist=True, train_color='b',
# Store labels and colors for the legend ordered by call
self._labels, self._colors = [], []

self.alpha = alpha
self.alphas = {
'train_point': train_alpha,
'test_point': test_alpha
}

@memoized
def hax(self):
Expand Down Expand Up @@ -496,7 +506,7 @@ def score(self, X, y=None, train=False, **kwargs):
y_pred = self.predict(X)
scores = y_pred - y
self.draw(y_pred, scores, train=train)

return score

def draw(self, y_pred, residuals, train=False, **kwargs):
Expand Down Expand Up @@ -528,17 +538,19 @@ def draw(self, y_pred, residuals, train=False, **kwargs):
if train:
color = self.colors['train_point']
label = "Train $R^2 = {:0.3f}$".format(self.train_score_)
alpha = self.alphas['train_point']
else:
color = self.colors['test_point']
label = "Test $R^2 = {:0.3f}$".format(self.test_score_)

alpha = self.alphas['test_point']

# Update the legend information
self._labels.append(label)
self._colors.append(color)

# Draw the residuals scatter plot
self.ax.scatter(
y_pred, residuals, c=color, alpha=self.alpha, label=label
y_pred, residuals, c=color, alpha=alpha, label=label
)

# Add residuals histogram
Expand Down Expand Up @@ -593,7 +605,8 @@ def residuals_plot(model,
test_color='g',
line_color=LINE_COLOR,
random_state=None,
alpha=0.75,
train_alpha=0.75,
test_alpha=0.75,
**kwargs):
"""Quick method:

Expand Down Expand Up @@ -648,9 +661,15 @@ def residuals_plot(model,
random_state : int, RandomState instance or None, optional
Passed to the train_test_split function.

alpha : float, default: 0.75
Specify a transparency where 1 is completely opaque and 0 is completely
transparent. This property makes densely clustered points more visible.
train_alpha : float, default: 0.75
Specify a transparency for traininig data, where 1 is completely opaque
and 0 is completely transparent. This property makes densely clustered
points more visible.

test_alpha : float, default: 0.75
Specify a transparency for test data, where 1 is completely opaque and
0 is completely transparent. This property makes densely clustered
points more visible.

kwargs : dict
Keyword arguments that are passed to the base class and may influence
Expand All @@ -662,9 +681,11 @@ def residuals_plot(model,
Returns the axes that the residuals plot was drawn on.
"""
# Instantiate the visualizer

visualizer = ResidualsPlot(
model=model, ax=ax, hist=hist, train_color=train_color,
test_color=test_color, line_color=line_color, alpha=alpha,
test_color=test_color, line_color=line_color,
train_alpha=train_alpha,test_alpha=test_alpha,
**kwargs
)

Expand Down