forked from statsmodels/statsmodels
-
Notifications
You must be signed in to change notification settings - Fork 0
/
factorplots.py
212 lines (180 loc) · 7.42 KB
/
factorplots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# -*- coding: utf-8 -*-
"""
Authors: Josef Perktold, Skipper Seabold, Denis A. Engemann
"""
from statsmodels.compat.python import get_function_name, iterkeys, lrange, zip, iteritems
import numpy as np
from statsmodels.graphics.plottools import rainbow
import statsmodels.graphics.utils as utils
def interaction_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
                     xlabel=None, ylabel=None, colors=None, markers=None,
                     linestyles=None, legendloc='best', legendtitle=None,
                     **kwargs):
    """
    Interaction plot for factor level statistics.

    Note. If a categorical (string) factor is supplied for `x`, its levels
    are internally recoded to integers and the original level names are used
    as tick labels. This ensures matplotlib compatibility.

    Uses a pandas.DataFrame to calculate an aggregate statistic of
    `response` for each combination of `trace` and `x` levels.

    Parameters
    ----------
    x : array-like
        The `x` factor levels constitute the x-axis. If a `pandas.Series` is
        given its name will be used in `xlabel` if `xlabel` is None.
    trace : array-like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array-like
        The response or dependent variable. If a `pandas.Series` is given
        its name will be used in `ylabel` if `ylabel` is None.
    func : function
        Aggregation passed to ``DataFrame.groupby(...).aggregate``. It is
        applied to the response grouped by the trace and x levels. Its name,
        as reported by `get_function_name`, is used in the y-axis label.
    ax : axes, optional
        Matplotlib axes instance
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can be 'l', 's', or 'b'
    xlabel : str, optional
        Label to use for `x`. Default is 'X'. If `x` is a named
        `pandas.Series` the series name is used.
    ylabel : str, optional
        Name to show for `response` in the y-axis label, which always has
        the form '<func name> of <name>'. Default is 'response'. If
        `response` is a named `pandas.Series` the series name is used.
    colors : list, optional
        If given, must have length == number of levels in trace.
    markers : list, optional
        If given, must have length == number of levels in trace.
    linestyles : list, optional
        If given, must have length == number of levels in trace.
    legendloc : str, optional
        Passed as `loc` to `ax.legend`.
    legendtitle : str, optional
        Title of the legend. Default is 'Trace'. If `trace` is a named
        `pandas.Series` the series name is used.
    kwargs
        These will be passed to the plot command used either plot or scatter.
        If you want to control the overall plotting options, use kwargs.

    Returns
    -------
    fig : Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `plottype` is not understood, or if `colors`, `markers` or
        `linestyles` is given with a length different from the number of
        trace levels.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """
    from pandas import DataFrame

    def name_of(obj, fallback):
        # An unnamed pandas.Series has ``name is None``; use the documented
        # default in that case instead of producing a ``None`` label.
        found = getattr(obj, 'name', None)
        return fallback if found is None else found

    # Validate the plot type before any figure is created so that a bad
    # argument does not leave an empty figure behind.
    if plottype in ('both', 'b'):
        mode = 'both'
    elif plottype in ('line', 'l'):
        mode = 'line'
    elif plottype in ('scatter', 's'):
        mode = 'scatter'
    else:
        raise ValueError("Plot type %s not understood" % plottype)

    # Labels are derived before `x` is recoded below.
    response_name = ylabel or name_of(response, 'response')
    ylabel = '%s of %s' % (get_function_name(func), response_name)
    xlabel = xlabel or name_of(x, 'X')
    legendtitle = legendtitle or name_of(trace, 'Trace')

    # Recode string-valued x levels to integer positions. The first element
    # is inspected positionally through an ndarray: ``x[0]`` on a
    # pandas.Series is a label lookup and fails when the index has no 0.
    x_values = x_levels = None
    x_array = np.asarray(x)
    if x_array.size and isinstance(x_array.flat[0], str):
        x_levels = np.unique(x_array).tolist()
        x_values = list(range(len(x_levels)))
        x = _recode(x, dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

    # check plot args
    n_trace = len(plot_data['trace'].unique())

    if linestyles:
        if len(linestyles) != n_trace:
            raise ValueError("Must be a linestyle for each trace level")
    else:  # set a default
        linestyles = ['-'] * n_trace

    if markers:
        if len(markers) != n_trace:
            raise ValueError("Must be a marker for each trace level")
    else:  # set a default
        markers = ['.'] * n_trace

    if colors:
        if len(colors) != n_trace:
            raise ValueError("Must be a color for each trace level")
    else:  # set a default
        # TODO: how to get n_trace different colors?
        colors = rainbow(n_trace)

    fig, ax = utils.create_mpl_ax(ax)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    # Grouping by the scalar key (not a one-element list) avoids the pandas
    # warning about tuple keys; the key itself is not used.
    for i, (_, group) in enumerate(plot_data.groupby('trace')):
        # trace label
        label = str(group['trace'].values[0])
        if mode == 'scatter':
            ax.scatter(group['x'], group['response'], color=colors[i],
                       label=label, marker=markers[i], **kwargs)
        elif mode == 'line':
            ax.plot(group['x'], group['response'], color=colors[i],
                    label=label, linestyle=linestyles[i], **kwargs)
        else:
            ax.plot(group['x'], group['response'], color=colors[i],
                    marker=markers[i], label=label,
                    linestyle=linestyles[i], **kwargs)

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(.1)

    # Restore the original level names on the x-axis if x was recoded.
    if x_levels:
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)
    return fig
def _recode(x, levels):
    """Recode categorical data to an integer factor.

    Parameters
    ----------
    x : array-like
        Categorically coded data of string (or object) dtype; assumed to be
        one-dimensional. Lists and `pandas.Series` are accepted.
    levels : dict
        Mapping of labels to integer codings. Its keys must be exactly the
        distinct values occurring in `x`.

    Returns
    -------
    out : numpy.ndarray or pandas.Series
        Integer codes, positionally aligned with `x`. If `x` is a
        `pandas.Series` with a (truthy) name, a `pandas.Series` carrying the
        same index and name is returned; otherwise a `numpy.ndarray`.

    Raises
    ------
    ValueError
        If `x` is not of string/object dtype, if `levels` is not a dict, or
        if the keys of `levels` do not match the distinct values of `x`.
    """
    from pandas import Series

    name = None
    index = None
    if isinstance(x, Series):
        name = x.name
        index = x.index
    # np.asarray also accepts plain lists and pandas extension arrays, which
    # do not expose a NumPy ``dtype.type``.
    values = np.asarray(x)

    if values.dtype.type not in (np.str_, np.object_):
        raise ValueError('This is not a categorial factor.'
                         ' Array of str type required.')
    if not isinstance(levels, dict):
        raise ValueError('This is not a valid value for levels.'
                         ' Dict required.')
    # array_equal returns False on a length mismatch, whereas an elementwise
    # ``==`` would fail to broadcast.
    if not np.array_equal(np.unique(values), np.unique(list(levels))):
        raise ValueError('The levels do not match the array values.')

    # The builtin ``int`` replaces the ``np.int`` alias removed from NumPy.
    out = np.empty(values.shape[0], dtype=int)
    for level, coding in levels.items():
        out[values == level] = coding

    if name:
        # Keep the original index so the result still aligns with other
        # Series when combined in a DataFrame.
        out = Series(out, index=index, name=name)
    return out