forked from statsmodels/statsmodels
/
plot_grids.py
166 lines (136 loc) · 5.53 KB
/
plot_grids.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
'''create scatterplot with confidence ellipses
Author: Josef Perktold
License: BSD-3
TODO: update script to use sharex, sharey, and visible=False
see http://www.scipy.org/Cookbook/Matplotlib/Multiple_Subplots_with_One_Axis_Label
for sharex I need to have the ax of the last_row when editing the earlier
rows. Or use axes_grid1, imagegrid
http://matplotlib.sourceforge.net/mpl_toolkits/axes_grid/users/overview.html
'''
import numpy as np
from scipy import stats
from . import utils
__all__ = ['scatter_ellipse']
def _make_ellipse(mean, cov, ax, level=0.95, color=None):
    """Support function for scatter_ellipse.

    Adds an unfilled ellipse patch to `ax`. The ellipse is centered at
    `mean`, oriented along the eigenvectors of `cov`, and sized with the
    chi-square quantile (2 degrees of freedom) at probability `level`.

    Parameters
    ----------
    mean : array_like
        Center of the ellipse. Only the first two entries are used.
    cov : array_like
        Symmetric covariance matrix of the two plotted variables, as
        required by ``np.linalg.eigh``.
    ax : Matplotlib axes instance
        Axes that the ellipse is added to.
    level : float, optional
        Probability passed to ``stats.chi2.ppf`` to scale the ellipse.
        Default is 0.95.
    color : Matplotlib color, optional
        Edge color of the ellipse.

    Returns
    -------
    None
    """
    from matplotlib.patches import Ellipse

    # eigh returns eigenvalues in ascending order and the matching
    # eigenvectors as the *columns* of the second return value.
    evals, evecs = np.linalg.eigh(cov)
    direction = evecs[:, 0]
    # arctan2 handles an axis-aligned eigenvector (first component zero)
    # without dividing by zero; the 180 degree ambiguity is irrelevant
    # for an ellipse.
    angle = np.degrees(np.arctan2(direction[1], direction[0]))

    # Round-off can make an eigenvalue of a degenerate covariance slightly
    # negative; clip so that the square root stays finite.
    evals = np.maximum(evals, 0)
    # full axis lengths corresponding to the requested level
    axes_len = 2 * np.sqrt(evals * stats.chi2.ppf(level, 2))

    # `angle` is passed by keyword: recent Matplotlib releases no longer
    # accept it as a positional argument.
    ell = Ellipse(mean[:2], axes_len[0], axes_len[1], angle=180 + angle,
                  facecolor='none', edgecolor=color, lw=1.5)
    ell.set_clip_box(ax.bbox)
    ell.set_alpha(0.5)
    ax.add_artist(ell)
def _grid_position(ax):
    """Return the object answering is_first_col()/is_last_row() for `ax`.

    Newer Matplotlib exposes these queries on the subplot spec, older
    releases on the subplot axes itself. Returns None if `ax` offers
    neither, i.e. it is not part of a subplot grid.
    """
    get_spec = getattr(ax, 'get_subplotspec', None)
    spec = get_spec() if callable(get_spec) else None
    if spec is not None and hasattr(spec, 'is_first_col'):
        return spec
    if hasattr(ax, 'is_first_col'):
        return ax
    return None


def scatter_ellipse(data, level=0.9, varnames=None, ell_kwds=None,
                    plot_kwds=None, add_titles=False, keep_ticks=False,
                    fig=None):
    """Create a grid of scatter plots with confidence ellipses.

    One subplot is drawn for every pair of variables, arranged as the
    lower triangle of an (nvars - 1) by (nvars - 1) grid. Each subplot
    also shows the pairwise correlation coefficient as text.

    Looks ok with 5 or 6 variables, too crowded with 8, too empty with 1.

    Parameters
    ----------
    data : array_like
        Input data, 2-dimensional with observations in rows and variables
        in columns.
    level : scalar or sequence of scalars, optional
        Probability level(s) of the ellipses; one ellipse is drawn per
        level in every subplot. Default is 0.9.
    varnames : list of str, optional
        Variable names. Used for axis labels, and if `add_titles` is True
        also for titles. If not given, 'var0', 'var1', ... are used.
    ell_kwds : dict, optional
        Keyword arguments forwarded to the ellipse helper, updating the
        default ``dict(color='k')``. The helper only accepts ``color``.
    plot_kwds : dict, optional
        Keyword arguments forwarded to ``ax.plot``, updating the defaults
        ``dict(ls='none', marker='.', color='k', alpha=0.5)``.
    add_titles : bool, optional
        Whether or not to add titles to each subplot. Default is False.
        Titles are constructed from `varnames`.
    keep_ticks : bool, optional
        If False (default), remove the ticks of interior subplots.
    fig : Matplotlib figure instance, optional
        Figure to which the subplots are added. If None, a new figure is
        created.

    Returns
    -------
    fig : Matplotlib figure instance
        If `fig` is None, the created figure. Otherwise `fig` itself.
    """
    fig = utils.create_mpl_fig(fig)
    import matplotlib.ticker as mticker

    data = np.asanyarray(data)  # needs mean and cov
    nvars = data.shape[1]
    if varnames is None:
        # assuming single digit, nvars<=10 else use 'var%2d'
        varnames = ['var%d' % i for i in range(nvars)]

    plot_kwds_ = dict(ls='none', marker='.', color='k', alpha=0.5)
    if plot_kwds:
        plot_kwds_.update(plot_kwds)

    ell_kwds_ = dict(color='k')
    if ell_kwds:
        ell_kwds_.update(ell_kwds)

    levels = [level] if np.isscalar(level) else level

    # Moments are the same for every subplot, so compute them once.
    dmean = data.mean(0)
    dcov = np.cov(data, rowvar=False)
    dcorr = np.corrcoef(data, rowvar=False) if nvars > 1 else None

    ngrid = nvars - 1
    for i in range(1, nvars):
        for j in range(i):
            # row i-1, column j of the lower triangle
            ax = fig.add_subplot(ngrid, ngrid, (i - 1) * ngrid + j + 1)
            # TODO: make sure we have same xlim and ylim

            formatter = mticker.FormatStrFormatter('% 3.1f')
            ax.yaxis.set_major_formatter(formatter)
            ax.xaxis.set_major_formatter(formatter)

            idx = np.array([j, i])
            ax.plot(*data[:, idx].T, **plot_kwds_)

            for alpha in levels:
                _make_ellipse(dmean[idx], dcov[idx[:, None], idx], ax,
                              level=alpha, **ell_kwds_)

            if add_titles:
                ax.set_title('%s-%s' % (varnames[i], varnames[j]))

            pos = _grid_position(ax)
            if pos.is_first_col():
                ax.set_ylabel(varnames[i])
            elif not keep_ticks:
                ax.set_yticks([])
            else:
                ax.yaxis.set_major_locator(mticker.MaxNLocator(3))

            if pos.is_last_row():
                ax.set_xlabel(varnames[j])
            elif not keep_ticks:
                ax.set_xticks([])
            else:
                ax.xaxis.set_major_locator(mticker.MaxNLocator(3))

            rho = dcorr[i, j]
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            yspan = ylim[1] - ylim[0]

            # Put the label near the bottom when the point cloud slopes
            # downward, or is roughly uncorrelated but centered in the
            # upper part of the axes; otherwise put it near the top.
            yrangeq = ylim[0] + 0.4 * yspan
            if rho < -0.25 or (rho < 0.25 and dmean[i] > yrangeq):
                yt = ylim[0] + 0.1 * yspan
            else:
                yt = ylim[1] - 0.2 * yspan
            xt = xlim[0] + 0.1 * (xlim[1] - xlim[0])
            ax.text(xt, yt, '$\\rho=%0.2f$' % rho)

    # Limit the number of ticks on the labelled outer axes.
    for ax in fig.axes:
        pos = _grid_position(ax)
        if pos is None:
            # not part of a subplot grid, e.g. axes already on `fig`
            continue
        if pos.is_last_row():
            ax.xaxis.set_major_locator(mticker.MaxNLocator(3))
        if pos.is_first_col():
            ax.yaxis.set_major_locator(mticker.MaxNLocator(3))

    return fig