-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
circle.py
421 lines (350 loc) · 15.1 KB
/
circle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
"""Functions to plot on circle as for connectivity."""
# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
# Denis Engemann <denis.engemann@gmail.com>
# Martin Luessi <mluessi@nmr.mgh.harvard.edu>
#
# License: Simplified BSD
from itertools import cycle
from functools import partial
import numpy as np
from .utils import plt_show, _get_cmap
from ..utils import _validate_type
def circular_layout(node_names, node_order, start_pos=90, start_between=True,
                    group_boundaries=None, group_sep=10):
    """Create layout arranging nodes on a circle.

    Parameters
    ----------
    node_names : list of str
        Node names.
    node_order : list of str
        Sequence with node names defining the order in which the nodes are
        arranged. Must have the same elements as node_names but the order can
        be different. The nodes are arranged clockwise starting at
        "start_pos" degrees.
    start_pos : float
        Angle in degrees that defines where the first node is plotted.
    start_between : bool
        If True, the layout starts with the position between the nodes. This
        is the same as adding "180. / len(node_names)" to start_pos.
    group_boundaries : None | array-like
        List of boundaries between groups at which point a "group_sep" will
        be inserted. E.g. "[0, len(node_names) / 2]" will create two groups.
        An empty sequence is treated like None (no groups).
    group_sep : float
        Group separation angle in degrees. See "group_boundaries".

    Returns
    -------
    node_angles : array, shape=(n_node_names,)
        Node angles in degrees.

    Raises
    ------
    ValueError
        If node_order and node_names differ in length, if a node name is
        missing from node_order, if entries are repeated, or if
        group_boundaries is out of range or not strictly increasing.
    """
    n_nodes = len(node_names)

    if len(node_order) != n_nodes:
        raise ValueError('node_order has to be the same length as node_names')

    boundaries = None
    n_group_sep = 0
    if group_boundaries is not None:
        boundaries = np.array(group_boundaries, dtype=np.int64)
        if boundaries.size == 0:
            # no boundaries means no groups; avoids indexing boundaries[0]
            # on an empty array further down
            boundaries = None
        else:
            if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
                raise ValueError('"group_boundaries" has to be between 0 and '
                                 'n_nodes - 1.')
            # the check is strict: equal consecutive boundaries are rejected
            if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
                raise ValueError('"group_boundaries" must have non-decreasing '
                                 'values.')
            n_group_sep = len(boundaries)

    # map every name to its first position in node_order (what list.index
    # would return) with a single pass instead of one scan per node
    first_position = {}
    for position, name in enumerate(node_order):
        first_position.setdefault(name, position)
    try:
        order_idx = [first_position[name] for name in node_names]
    except KeyError as err:
        raise ValueError(f'{err.args[0]!r} is in node_names but not in '
                         'node_order') from None
    order_idx = np.array(order_idx)
    if len(np.unique(order_idx)) != n_nodes:
        raise ValueError('node_order has repeated entries')

    # angular step between neighbours once the group gaps are reserved
    node_sep = (360. - n_group_sep * group_sep) / n_nodes

    if start_between:
        start_pos += node_sep / 2

        if boundaries is not None and boundaries[0] == 0:
            # special case when a group separator is at the start
            start_pos += group_sep / 2
            boundaries = boundaries[1:] if n_group_sep > 1 else None

    # per-node increments; the cumulative sum turns them into absolute angles
    node_angles = np.ones(n_nodes, dtype=np.float64) * node_sep
    node_angles[0] = start_pos
    if boundaries is not None:
        node_angles[boundaries] += group_sep

    # reorder so that entry i is the angle of node_names[i]
    node_angles = np.cumsum(node_angles)[order_idx]

    return node_angles
def _plot_connectivity_circle_onpick(event, fig=None, ax=None, indices=None,
n_nodes=0, node_angles=None,
ylim=[9, 10]):
"""Isolate connections around a single node when user left clicks a node.
On right click, resets all connections.
"""
if event.inaxes != ax:
return
if event.button == 1: # left click
# click must be near node radius
if not ylim[0] <= event.ydata <= ylim[1]:
return
# all angles in range [0, 2*pi]
node_angles = node_angles % (np.pi * 2)
node = np.argmin(np.abs(event.xdata - node_angles))
patches = event.inaxes.patches
for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
patches[ii].set_visible(node in [x, y])
fig.canvas.draw()
elif event.button == 3: # right click
patches = event.inaxes.patches
for ii in range(np.size(indices, axis=1)):
patches[ii].set_visible(True)
fig.canvas.draw()
def _plot_connectivity_circle(con, node_names, indices=None, n_lines=None,
                              node_angles=None, node_width=None,
                              node_height=None, node_colors=None,
                              facecolor='black', textcolor='white',
                              node_edgecolor='black', linewidth=1.5,
                              colormap='hot', vmin=None, vmax=None,
                              colorbar=True, title=None,
                              colorbar_size=None, colorbar_pos=None,
                              fontsize_title=12, fontsize_names=8,
                              fontsize_colorbar=8, padding=6.,
                              ax=None, interactive=True,
                              node_linewidth=2., show=True):
    """Draw connectivity values as curved lines between nodes on a circle.

    Parameters
    ----------
    con : array
        Connectivity values. Either 1D (then ``indices`` is required) or a
        square ``(n_nodes, n_nodes)`` matrix, of which only the lower
        triangle (below the diagonal) is used.
    node_names : list of str
        Node names; its length defines the number of nodes.
    indices : tuple of array | None
        Start and end node index of each entry of a 1D ``con``. Overwritten
        with the lower-triangle indices when ``con`` is 2D.
    n_lines : int | None
        If given and ``con`` has more entries, only connections whose
        absolute value reaches the ``n_lines``-th largest one are drawn.
    node_angles : array | None
        Node angles in degrees. If None, nodes are spread uniformly
        starting at angle 0.
    node_width : float | None
        Width of each node bar in degrees. If None, the minimum angular
        distance between two nodes is used.
    node_height : float | None
        Radial height of the node bars. If None, 1.0 is used.
    node_colors : list | None
        Node face colors; cycled if fewer than the number of nodes. If
        None, colors are sampled from a spectral colormap.
    facecolor : color
        Background color of figure and axes.
    textcolor : color
        Color of node labels, title and colorbar tick labels.
    node_edgecolor : color
        Edge color of the node bars.
    linewidth : float
        Line width of the connection curves.
    colormap : str | colormap
        Passed to ``_get_cmap``; used to color the connections.
    vmin, vmax : float | None
        Color scale limits. If None, the minimum / maximum of the drawn
        connection values is used.
    colorbar : bool
        Whether to add a colorbar.
    title : str | None
        Axes title.
    colorbar_size : float | None
        Forwarded to the colorbar as ``shrink``.
    colorbar_pos : tuple | None
        Forwarded to the colorbar as ``anchor``.
    fontsize_title, fontsize_names, fontsize_colorbar : int
        Font sizes of title, node labels and colorbar tick labels.
    padding : float
        Extra radial space added beyond radius 10.
    ax : PolarAxes | None
        Axes to draw into. If None, a new 8 x 8 figure with a polar axes is
        created.
    interactive : bool
        If True, connect ``_plot_connectivity_circle_onpick`` to mouse
        button presses.
    node_linewidth : float
        Line width of the node bar edges.
    show : bool
        Forwarded to ``plt_show``.

    Returns
    -------
    fig : Figure
        The figure that holds ``ax``.
    ax : PolarAxes
        The axes drawn into.
    """
    import matplotlib.pyplot as plt
    import matplotlib.path as m_path
    import matplotlib.patches as m_patches
    from matplotlib.projections.polar import PolarAxes

    _validate_type(ax, (None, PolarAxes))

    n_nodes = len(node_names)

    if node_angles is not None:
        if len(node_angles) != n_nodes:
            raise ValueError('node_angles has to be the same length '
                             'as node_names')
        # convert it to radians
        node_angles = node_angles * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)

    if node_width is None:
        # widths correspond to the minimum angle between two nodes
        dist_mat = node_angles[None, :] - node_angles[:, None]
        # a large value on the diagonal keeps self-distances (0) out of the min
        dist_mat[np.diag_indices(n_nodes)] = 1e9
        node_width = np.min(np.abs(dist_mat))
    else:
        node_width = node_width * np.pi / 180

    if node_height is None:
        node_height = 1.0

    if node_colors is not None:
        if len(node_colors) < n_nodes:
            # repeat the given colors; zip() below stops at the last bar
            node_colors = cycle(node_colors)
    else:
        # assign colors using colormap
        # (tries the lower-case attribute first, falls back to "Spectral")
        try:
            spectral = plt.cm.spectral
        except AttributeError:
            spectral = plt.cm.Spectral
        node_colors = [spectral(i / float(n_nodes))
                       for i in range(n_nodes)]

    # handle 1D and 2D connectivity information
    if con.ndim == 1:
        if indices is None:
            raise ValueError('indices has to be provided if con.ndim == 1')
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError('con has to be 1D or a square matrix')
        # we use the lower-triangular part
        indices = np.tril_indices(n_nodes, -1)
        con = con[indices]
    else:
        raise ValueError('con has to be 1D or a square matrix')

    # get the colormap
    colormap = _get_cmap(colormap)

    # Use a polar axes
    if ax is None:
        fig = plt.figure(figsize=(8, 8), facecolor=facecolor)
        ax = fig.add_subplot(polar=True)
    else:
        fig = ax.figure
    ax.set_facecolor(facecolor)

    # No ticks, we'll put our own
    ax.set_xticks([])
    ax.set_yticks([])

    # Set y axes limit, add additional space if requested
    ax.set_ylim(0, 10 + padding)

    # Remove the black axes border which may obscure the labels
    ax.spines['polar'].set_visible(False)

    # Draw lines between connected nodes, only draw the strongest connections
    if n_lines is not None and len(con) > n_lines:
        con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
    else:
        con_thresh = 0.

    # get the connections which we are drawing and sort by connection strength
    # this will allow us to draw the strongest connections first
    # (NaN entries fail the >= comparison and are therefore dropped here)
    con_abs = np.abs(con)
    con_draw_idx = np.where(con_abs >= con_thresh)[0]

    con = con[con_draw_idx]
    con_abs = con_abs[con_draw_idx]
    indices = [ind[con_draw_idx] for ind in indices]

    # now sort them
    # NOTE(review): argsort is ascending, so the weakest connections come
    # first in the arrays and are added to the axes first -- confirm this is
    # the ordering the comment above intends.
    sort_idx = np.argsort(con_abs)
    del con_abs
    con = con[sort_idx]
    indices = [ind[sort_idx] for ind in indices]

    # Get vmin vmax for color scaling
    if vmin is None:
        vmin = np.min(con[np.abs(con) >= con_thresh])
    if vmax is None:
        vmax = np.max(con)
    # NOTE(review): vrange is 0 when vmax == vmin, in which case the scaling
    # division further down yields nan/inf -- confirm callers avoid this.
    vrange = vmax - vmin

    # We want to add some "noise" to the start and end position of the
    # edges: We modulate the noise with the number of connections of the
    # node and the connection strength, such that the strongest connections
    # are closer to the node center
    nodes_n_con = np.zeros((n_nodes), dtype=np.int64)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.mtrand.RandomState(0)

    n_con = len(indices[0])
    # jitter is limited to a quarter of the node width on either side
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    # scale each jitter by the fraction of the node's connections not yet
    # visited; the last connection visited at a node gets factor 0
    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        start_noise[i] *= ((nodes_n_con[start] - nodes_n_con_seen[start]) /
                           float(nodes_n_con[start]))
        end_noise[i] *= ((nodes_n_con[end] - nodes_n_con_seen[end]) /
                         float(nodes_n_con[end]))

    # scale connectivity for colormap (vmin<=>0, vmax<=>1)
    con_val_scaled = (con - vmin) / vrange

    # Finally, we draw the connections
    # (these patches are added before the node bars, so connection ``pos``
    # is ``ax.patches[pos]`` when ``ax`` held no patches beforehand; the
    # pick callback indexes patches the same way)
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start point
        t0, r0 = node_angles[i], 10

        # End point
        t1, r1 = node_angles[j], 10

        # Some noise in start and end point
        t0 += start_noise[pos]
        t1 += end_noise[pos]

        # curve from radius 10 at the start angle to radius 10 at the end
        # angle, with both control points at radius 5
        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [m_path.Path.MOVETO, m_path.Path.CURVE4, m_path.Path.CURVE4,
                 m_path.Path.LINETO]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(path, fill=False, edgecolor=color,
                                    linewidth=linewidth, alpha=1.)
        ax.add_patch(patch)

    # Draw ring with colored nodes
    # (bars start at radius 9; with the default node_height of 1.0 they end
    # at 10, matching the pick callback's default ylim of [9, 10])
    height = np.ones(n_nodes) * node_height
    bars = ax.bar(node_angles, height, width=node_width, bottom=9,
                  edgecolor=node_edgecolor, lw=node_linewidth,
                  facecolor='.9', align='center')

    for bar, color in zip(bars, node_colors):
        bar.set_facecolor(color)

    # Draw node labels
    angles_deg = 180 * node_angles / np.pi
    for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
        # NOTE(review): angles below 270 degrees are always flipped. With the
        # uniform default layout (angles start at 0) this includes [0, 90),
        # where the flip looks like it would turn labels upside down --
        # confirm against a rendered figure.
        if angle_deg >= 270:
            ha = 'left'
        else:
            # Flip the label, so text is always upright
            angle_deg += 180
            ha = 'right'

        ax.text(angle_rad, 9.4 + node_height, name, size=fontsize_names,
                rotation=angle_deg, rotation_mode='anchor',
                horizontalalignment=ha, verticalalignment='center',
                color=textcolor)

    if title is not None:
        ax.set_title(title, color=textcolor, fontsize=fontsize_title)

    if colorbar:
        sm = plt.cm.ScalarMappable(cmap=colormap,
                                   norm=plt.Normalize(vmin, vmax))
        sm.set_array(np.linspace(vmin, vmax))
        colorbar_kwargs = dict()
        if colorbar_size is not None:
            colorbar_kwargs.update(shrink=colorbar_size)
        if colorbar_pos is not None:
            colorbar_kwargs.update(anchor=colorbar_pos)
        cb = fig.colorbar(sm, ax=ax, **colorbar_kwargs)
        cb_yticks = plt.getp(cb.ax.axes, 'yticklabels')
        cb.ax.tick_params(labelsize=fontsize_colorbar)
        plt.setp(cb_yticks, color=textcolor)

    # Add callback for interaction
    if interactive:
        callback = partial(_plot_connectivity_circle_onpick, fig=fig,
                           ax=ax, indices=indices, n_nodes=n_nodes,
                           node_angles=node_angles)

        fig.canvas.mpl_connect('button_press_event', callback)

    plt_show(show)
    return fig, ax
def plot_channel_labels_circle(labels, colors=None, picks=None, **kwargs):
    """Plot labels for each channel in a circle plot.

    .. note:: This primarily makes sense for sEEG channels where each
              channel can be assigned an anatomical label as the electrode
              passes through various brain areas.

    Parameters
    ----------
    labels : dict
        Lists of labels (values) associated with each channel (keys).
    colors : dict
        The color (value) for each label (key). If None, no label colormap
        is built and the default ``colormap`` and node colors of the
        underlying circle plot are used unless given in ``kwargs``.
    picks : list | tuple
        The channels to consider.
    **kwargs : kwargs
        Keyword arguments for
        :func:`mne_connectivity.viz.plot_connectivity_circle`.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure handle.
    axes : instance of matplotlib.projections.polar.PolarAxes
        The subplot handle.

    Raises
    ------
    ValueError
        If ``colors`` is given but lacks a color for a label in use.
    """
    from matplotlib.colors import LinearSegmentedColormap

    _validate_type(labels, dict, 'labels')
    _validate_type(colors, (dict, None), 'colors')
    _validate_type(picks, (list, tuple, None), 'picks')
    if picks is not None:
        labels = {k: v for k, v in labels.items() if k in picks}
    ch_names = list(labels.keys())
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # layout is the same on every run (a set would give arbitrary order)
    all_labels = list(dict.fromkeys(
        label for val in labels.values() for label in val))

    # Defaults for the case without colors; previously label_cmap was left
    # unassigned here and the colormap default below raised an error.
    node_colors = None
    label_cmap = None
    if colors is not None:
        for label in all_labels:
            if label not in colors:
                raise ValueError(f'No color provided for {label} in `colors`')
        # update all_labels, there may be unconnected labels in colors
        all_labels = list(colors.keys())
        # make colormap
        label_colors = [colors[label] for label in all_labels]
        node_colors = ['black'] * len(ch_names) + label_colors
        label_cmap = LinearSegmentedColormap.from_list(
            'label_cmap', label_colors, N=len(label_colors))
    n_labels = len(all_labels)

    # channels come first, then the labels
    node_names = ch_names + all_labels
    n_channels = len(ch_names)
    label_rank = {label: rank for rank, label in enumerate(all_labels)}

    # NaN marks "no connection"; only channel-label pairs get a value
    con = np.full((len(node_names), len(node_names)), np.nan)
    for idx, ch_name in enumerate(ch_names):
        for label in labels[ch_name]:
            rank = label_rank[label]
            node_idx = n_channels + rank
            # the value encodes the label's position so that the discrete
            # label colormap (with vmin=0, vmax=1) picks the label's color
            label_color = rank / n_labels
            con[idx, node_idx] = con[node_idx, idx] = label_color  # symmetric

    # plot
    node_order = ch_names + all_labels[::-1]
    node_angles = circular_layout(node_names, node_order, start_pos=90,
                                  group_boundaries=[0, len(ch_names)])

    # provide defaults but don't overwrite
    kwargs.setdefault('node_angles', node_angles)
    kwargs.setdefault('colorbar', False)
    kwargs.setdefault('node_colors', node_colors)
    kwargs.setdefault('vmin', 0)
    kwargs.setdefault('vmax', 1)
    if label_cmap is not None:
        kwargs.setdefault('colormap', label_cmap)

    return _plot_connectivity_circle(con, node_names, **kwargs)