-
-
Notifications
You must be signed in to change notification settings - Fork 554
/
pcolor.py
177 lines (137 loc) · 5.47 KB
/
pcolor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# yellowbrick.gridsearch.pcolor
# Colorplot visualizer for gridsearch results.
#
# Author: Phillip Schafer
# Created: Sat Feb 3 10:18:33 2018 -0500
#
# Copyright (C) 2018 The scikit-yb developers
# For license information, see LICENSE.txt
#
# ID: pcolor.py [03724ed] pbs929@users.noreply.github.com $
"""
Colorplot visualizer for gridsearch results.
"""
import numpy as np
from .base import GridSearchVisualizer
## Packages for export
__all__ = ["GridSearchColorPlot", "gridsearch_color_plot"]
##########################################################################
## Quick method
##########################################################################
def gridsearch_color_plot(model, x_param, y_param, X=None, y=None, ax=None, **kwargs):
    """Quick method: draw a grid search color plot in a single call.

    Builds a GridSearchColorPlot for ``model`` showing the best grid search
    scores across two parameters, then either fits it or just draws it.

    When ``X`` is supplied, ``fit(X, y)`` is called on the visualizer. When
    ``X`` is omitted, the grid search is assumed to have been fit already and
    only ``draw()`` is called, so results can be explored without waiting for
    the search to re-run.

    Parameters
    ----------
    model : Scikit-Learn grid search object
        Should be an instance of GridSearchCV. The model may be fit or unfit.

    x_param : string
        The name of the parameter to be visualized on the horizontal axis.

    y_param : string
        The name of the parameter to be visualized on the vertical axis.

    X : ndarray or DataFrame of shape n x m or None (default None)
        A matrix of n instances with m features. If not None, the visualizer
        is fit with this data.

    y : ndarray or Series of length n or None (default None)
        An array or series of target or class values.

    ax : matplotlib axes
        The axes to plot the figure on.

    kwargs : dict
        Additional keyword arguments forwarded unchanged to the
        GridSearchColorPlot constructor (for example ``metric`` or
        ``colormap``).

    Returns
    -------
    ax : matplotlib axes
        The axes object held by the visualizer, i.e. the axes that the color
        plot was drawn on.
    """
    oz = GridSearchColorPlot(model, x_param, y_param, ax=ax, **kwargs)

    if X is None:
        # No data supplied: treat the search as already fit and only render.
        oz.draw()
    else:
        oz.fit(X, y)

    # Hand back the axes rather than the visualizer itself.
    return oz.ax
class GridSearchColorPlot(GridSearchVisualizer):
    """
    Color plot (heatmap) of the best grid search scores across two
    hyperparameters.

    Parameters
    ----------
    model : Scikit-Learn grid search object
        Should be an instance of GridSearchCV.

    x_param : string
        The name of the parameter to be visualized on the horizontal axis.

    y_param : string
        The name of the parameter to be visualized on the vertical axis.

    metric : string (default 'mean_test_score')
        The field from the grid search's `cv_results` that we want to display.

    ax : matplotlib Axes, default: None
        The axes to plot the figure on. Passed through to the base class.

    colormap : string or cmap, default: 'RdBu_r'
        Name of a matplotlib colormap, or a colormap object, used to color
        the cells of the plot according to their score.

    kwargs : dict
        Keyword arguments that are passed to the base class and may influence
        the visualization as defined in other Visualizers.

    Examples
    --------

    >>> from yellowbrick.gridsearch import GridSearchColorPlot
    >>> from sklearn.model_selection import GridSearchCV
    >>> from sklearn.svm import SVC
    >>> gridsearch = GridSearchCV(SVC(),
    ...                           {'kernel': ['rbf', 'linear'], 'C': [1, 10]})
    >>> model = GridSearchColorPlot(gridsearch, x_param='kernel', y_param='C')
    >>> model.fit(X, y)
    >>> model.show()
    """

    def __init__(
        self,
        model,
        x_param,
        y_param,
        metric="mean_test_score",
        colormap="RdBu_r",
        ax=None,
        **kwargs
    ):
        super(GridSearchColorPlot, self).__init__(model, ax=ax, **kwargs)

        # Which two parameters to plot, which result field to read, and how
        # to color it.
        self.x_param, self.y_param = x_param, y_param
        self.metric, self.colormap = metric, colormap

    def draw(self):
        """
        Render the color plot, its tick labels and its colorbar on ``self.ax``.
        """
        # Collapse the grid search results onto the two chosen parameters.
        x_labels, y_labels, scores = self.param_projection(
            self.x_param, self.y_param, metric=self.metric
        )

        # Invalid entries (e.g. nan) are masked so pcolor leaves them unpainted.
        grid = np.ma.masked_invalid(scores)

        axes = self.ax
        lowest = np.nanmin(grid)
        highest = np.nanmax(grid)
        mesh = axes.pcolor(grid, cmap=self.colormap, vmin=lowest, vmax=highest)

        # The hatched axes background shows through wherever a cell is masked.
        axes.patch.set(hatch="x", edgecolor="black")

        # Cell (i, j) spans [i, i + 1], so its center sits at i + 0.5.
        axes.set_xticks(np.arange(len(x_labels)) + 0.5)
        axes.set_xticklabels(x_labels, rotation=45)
        axes.set_yticks(np.arange(len(y_labels)) + 0.5)
        axes.set_yticklabels(y_labels, rotation=45)

        # Colorbar attached to the same axes, drawn without an outline.
        colorbar = axes.figure.colorbar(mesh, cax=None, ax=axes)
        colorbar.outline.set_linewidth(0)

        axes.set_aspect("equal")

    def finalize(self):
        """
        Set the plot title and label each axis with its parameter name.
        """
        self.set_title("Grid Search Scores")

        axes = self.ax
        axes.set_xlabel(self.x_param)
        axes.set_ylabel(self.y_param)