-
-
Notifications
You must be signed in to change notification settings - Fork 380
/
jointplot.py
104 lines (92 loc) · 2.84 KB
/
jointplot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Bokeh jointplot."""
import bokeh.plotting as bkp
import numpy as np
from bokeh.layouts import gridplot
from ...distplot import plot_dist
from ...kdeplot import plot_kde
from ...plot_utils import make_label
from ....rcparams import rcParams
def plot_joint(
    ax,
    figsize,
    plotters,
    xt_labelsize,
    kind,
    contour,
    fill_last,
    joint_kwargs,
    gridsize,
    marginal_kwargs,
    show,
):
    """Bokeh joint plot.

    Draw the joint distribution of two variables on a central figure, with
    the marginal distribution of the first variable on a figure above it and
    of the second variable (rotated) on a figure to its right.

    Parameters
    ----------
    ax : None or 2x2 nested sequence of bokeh figures
        If None, three new figures are created from ``figsize`` and the
        ``plot.bokeh.*`` rcParams; the top marginal shares the joint figure's
        x range and the right marginal shares its y range. Otherwise
        ``ax[0][0]`` is the top marginal, ``ax[1][0]`` the joint figure and
        ``ax[1][1]`` the right marginal; ``ax[0][1]`` is ignored.
    figsize : tuple of float
        ``(width, height)``, multiplied by ``rcParams["plot.bokeh.figure.dpi"]``
        to obtain pixel sizes. The joint figure gets 80% of each dimension;
        each marginal matches it along the shared side and gets 20% of the
        other side.
    plotters : sequence of two items
        For each variable, ``plotters[i][0]`` and ``plotters[i][1]`` are passed
        to ``make_label`` to build the axis label and ``plotters[i][2]`` holds
        the values (anything with ``.flatten()``). Presumably
        ``(var_name, selection, values)`` tuples -- confirm against the caller.
    xt_labelsize : float
        Forwarded to ``plot_dist`` as ``textsize``.
    kind : str
        ``"scatter"`` or ``"kde"``; any other value draws a hexbin plot.
    contour, fill_last : bool
        Forwarded to ``plot_kde`` when ``kind == "kde"``.
    joint_kwargs : dict or None
        Extra keyword arguments for the joint glyph / ``plot_kde`` call.
    gridsize : "auto" or number
        Only used for the hexbin plot. ``"auto"`` becomes
        ``int(len(x) ** 0.35)``; the value is then divided by 10 and passed as
        bokeh's hexbin ``size``.
    marginal_kwargs : dict or None
        Extra keyword arguments for ``plot_dist``. Its ``"plot_kwargs"`` entry
        gets ``line_color="black"`` unless already set. The caller's
        dictionaries are not modified.
    show : bool
        If True, lay the figures out in a grid and display it with
        ``bokeh.plotting.show``.

    Returns
    -------
    numpy.ndarray
        2x2 object array ``[[ax_hist_x, None], [axjoin, ax_hist_y]]``.

    Raises
    ------
    ValueError
        If ``ax`` is given but is not a 2x2 nested sequence.
    """
    if ax is None:
        tools = rcParams["plot.bokeh.tools"]
        output_backend = rcParams["plot.bokeh.output_backend"]
        dpi = rcParams["plot.bokeh.figure.dpi"]
        # Joint panel: 80% of each dimension. Each marginal shares one 80%
        # side with the joint panel and takes the remaining 20% on the other.
        wide = int(figsize[0] * dpi * 0.8)
        narrow = int(figsize[0] * dpi * 0.2)
        tall = int(figsize[1] * dpi * 0.8)
        short = int(figsize[1] * dpi * 0.2)
        axjoin = bkp.figure(
            width=wide,
            height=tall,
            output_backend=output_backend,
            tools=tools,
        )
        # Linked ranges keep the marginals aligned with the joint panel on pan/zoom.
        ax_hist_x = bkp.figure(
            width=wide,
            height=short,
            output_backend=output_backend,
            tools=tools,
            x_range=axjoin.x_range,
        )
        ax_hist_y = bkp.figure(
            width=narrow,
            height=tall,
            output_backend=output_backend,
            tools=tools,
            y_range=axjoin.y_range,
        )
    elif len(ax) == 2 and len(ax[0]) == 2 and len(ax[1]) == 2:
        # The top-right cell of a user-supplied grid is not used.
        ax_hist_x, _ = ax[0]
        axjoin, ax_hist_y = ax[1]
    else:
        raise ValueError(
            "ax must be a 2x2 nested sequence of figures "
            "([[ax_hist_x, _], [axjoin, ax_hist_y]]) but found outer length {}".format(len(ax))
        )

    # Set labels for axes
    axjoin.xaxis.axis_label = make_label(plotters[0][0], plotters[0][1])
    axjoin.yaxis.axis_label = make_label(plotters[1][0], plotters[1][1])

    # Flatten data
    x = plotters[0][2].flatten()
    y = plotters[1][2].flatten()

    if joint_kwargs is None:
        joint_kwargs = {}

    if kind == "scatter":
        axjoin.circle(x, y, **joint_kwargs)
    elif kind == "kde":
        plot_kde(
            x,
            y,
            contour=contour,
            fill_last=fill_last,
            ax=axjoin,
            backend="bokeh",
            show=False,
            **joint_kwargs
        )
    else:
        if gridsize == "auto":
            gridsize = int(len(x) ** 0.35)
        # NOTE(review): the /10 rescaling turns gridsize into bokeh's hexbin
        # ``size`` (hex size in data units); the factor's origin is not
        # documented here.
        gridsize = gridsize / 10
        axjoin.hexbin(x, y, size=gridsize, **joint_kwargs)

    # Work on copies so the caller's (possibly reused) dicts are left untouched,
    # and tolerate a missing/None "plot_kwargs" entry.
    marginal_kwargs = dict(marginal_kwargs) if marginal_kwargs is not None else {}
    marginal_plot_kwargs = dict(marginal_kwargs.get("plot_kwargs") or {})
    marginal_plot_kwargs.setdefault("line_color", "black")
    marginal_kwargs["plot_kwargs"] = marginal_plot_kwargs

    # x marginal is drawn normally above; y marginal is rotated to the right.
    for val, ax_, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)):
        plot_dist(
            val,
            textsize=xt_labelsize,
            rotated=rotate,
            ax=ax_,
            backend="bokeh",
            backend_kwargs={"show": False},
            **marginal_kwargs
        )

    layout = [[ax_hist_x, None], [axjoin, ax_hist_y]]
    if show:
        grid = gridplot(layout, toolbar_location="above")
        bkp.show(grid)

    return np.array(layout, dtype=object)