-
Notifications
You must be signed in to change notification settings - Fork 93
/
plotting.py
553 lines (443 loc) · 18.1 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
# coding: utf-8
"""
Utilities for generating matplotlib plots.
.. note::
Avoid importing matplotlib in the module namespace otherwise startup is very slow.
"""
from __future__ import print_function, division, unicode_literals, absolute_import
import os
import collections
import time
import numpy as np
from monty.string import list_strings
from monty.functools import lazy_property
from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig_plt, get_ax3d_fig_plt, get_axarray_fig_plt
# Public names exported by ``from <this module> import *``.
# get_ax_fig_plt and get_ax3d_fig_plt are not defined here: they are imported
# from pymatgen.util.plotting and re-exported.
__all__ = [
    "set_axlims",
    "get_ax_fig_plt",
    "get_ax3d_fig_plt",
    "plot_array",
    "ArrayPlotter",
    "data_from_cplx_mode",
    "Marker",
    "plot_unit_cell",
]
def ax_append_title(ax, title, loc="center", fontsize=None):
    """
    Append ``title`` to the title currently shown by ``ax`` at position ``loc``.

    Args:
        ax: matplotlib Axes whose title is extended.
        title: String appended to the existing title.
        loc: Title position forwarded to ``get_title``/``set_title`` (e.g. "center").
        fontsize: Font size forwarded to ``ax.set_title``.

    Returns: The new (concatenated) title string.
    """
    combined = "".join((ax.get_title(loc=loc), title))
    ax.set_title(combined, loc=loc, fontsize=fontsize)
    return combined
#def set_grid(fig, boolean):
# if hasattr(fig, "axes"):
# for ax in fig.axes:
# if ax.grid: ax.grid.set_visible(boolean)
# else:
# if ax.grid: ax.grid.set_visible(boolean)
def set_axlims(ax, lims, axname):
    """
    Set the data limits of the x- or y-axis of ``ax``.

    Args:
        ax: matplotlib Axes.
        lims: None (nothing is done), a sequence of length 2 for (left, right),
            a sequence of length 1 or a scalar for left only.
            Sequences of any other length result in (None, None).
        axname: "x" for x-axis, "y" for y-axis.

    Return: (left, right)
    """
    if lims is None:
        return None, None

    try:
        size = len(lims)
    except TypeError:
        # Not a sequence: treat it as a scalar lower bound.
        bounds = (float(lims), None)
    else:
        if size == 2:
            bounds = (lims[0], lims[1])
        elif size == 1:
            bounds = (lims[0], None)
        else:
            bounds = (None, None)

    setter_name = {"x": "set_xlim", "y": "set_ylim"}[axname]
    getattr(ax, setter_name)(*bounds)
    return bounds
def set_visible(ax, boolean, *args):
    """
    Hide/Show the artists of axis ``ax`` listed in ``args``.

    Args:
        ax: matplotlib Axes.
        boolean: True to show the artists, False to hide them.
        args: Names of the artists to toggle. Recognized values are
            "legend", "title", "xlabel" and "ylabel"; other names are ignored.
    """
    if "legend" in args:
        # Query the existing legend instead of calling ax.legend(): that call
        # builds a brand-new legend (dropping the current one and its options)
        # and warns when no labeled artists are present.
        legend = ax.get_legend()
        if legend is not None:
            legend.set_visible(boolean)
    if "title" in args and ax.title:
        ax.title.set_visible(boolean)
    if "xlabel" in args and ax.xaxis.label:
        ax.xaxis.label.set_visible(boolean)
    if "ylabel" in args and ax.yaxis.label:
        ax.yaxis.label.set_visible(boolean)
def rotate_ticklabels(ax, rotation, axname="x"):
    """
    Rotate the tick labels of axis ``ax``.

    Args:
        ax: matplotlib Axes.
        rotation: Value forwarded to ``set_rotation`` of every tick label.
        axname: String selecting the axes to process: x tick labels are rotated
            if it contains "x", y tick labels if it contains "y" (e.g. "xy" for both).
    """
    for letter, getter_name in (("x", "get_xticklabels"), ("y", "get_yticklabels")):
        if letter not in axname:
            continue
        for label in getattr(ax, getter_name)():
            label.set_rotation(rotation)
def data_from_cplx_mode(cplx_mode, arr):
    """
    Extract the data from the numpy array ``arr`` selected by ``cplx_mode``.

    Args:
        cplx_mode: One of
            "re" for the real part,
            "im" for the imaginary part,
            "abs" for the absolute value of the complex number,
            "angle" for the phase of the complex number in radians.
        arr: numpy array (possibly complex).

    Raises:
        ValueError: if ``cplx_mode`` is not one of the values above.
    """
    converters = (
        ("re", lambda a: a.real),
        ("im", lambda a: a.imag),
        ("abs", np.abs),
        ("angle", lambda a: np.angle(a, deg=False)),
    )
    for mode, convert in converters:
        if cplx_mode == mode:
            return convert(arr)
    raise ValueError("Unsupported mode `%s`" % str(cplx_mode))
@add_fig_kwargs
def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None,
                     xlims=None, ylims=None, fontsize=12, **kwargs):
    """
    Plot y = f(x) with one line for each distinct value of the ``hue`` column.

    Useful for convergence tests done with respect to two parameters.

    Args:
        data: |pandas-DataFrame| containing columns ``x``, ``y`` and ``hue``.
        x: Name of the column used for the x-values.
        y: Name of the column used for the y-values.
        hue: Name of the column whose values split the data into separate lines.
        decimals: If not None, round the ``hue`` column to this number of decimal
            places before grouping.
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xlims, ylims: Data limits for the x (y) axis. Accept a tuple e.g. ``(left, right)``
            or a scalar e.g. ``left``. If left (right) is None, default values are used.
        fontsize: Legend fontsize.
        kwargs: Keyword arguments passed to ``ax.plot``. When empty, lines are drawn
            with the "o-" format.

    Returns: |matplotlib-Figure|
    """
    # Validate the column names here because pandas error messages are rather cryptic.
    missing = [name for name in (x, y, hue) if name not in data]
    if missing:
        raise ValueError("Cannot find `%s` in dataframe.\nAvailable keys are: %s" % (str(missing), str(data.keys())))

    if decimals is not None:
        # Round the hue values so that nearly-equal values fall into the same group.
        data = data.round({hue: decimals})

    ax, fig, plt = get_ax_fig_plt(ax=ax)
    for hue_value, group in data.groupby(hue):
        # Sort the (x, y) pairs by x so that each line is drawn left to right.
        pairs = sorted(zip(group[x], group[y]), key=lambda pair: pair[0])
        xy = np.array(pairs)
        label = "{} = {}".format(hue, hue_value)
        if kwargs:
            ax.plot(xy[:, 0], xy[:, 1], label=label, **kwargs)
        else:
            ax.plot(xy[:, 0], xy[:, 1], 'o-', label=label)

    ax.grid(True)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    set_axlims(ax, xlims, "x")
    set_axlims(ax, ylims, "y")
    ax.legend(loc="best", fontsize=fontsize, shadow=True)

    return fig
@add_fig_kwargs
def plot_array(array, color_map=None, cplx_mode="abs", **kwargs):
    """
    Plot a 1D or 2D array with ``imshow`` on the current pyplot axes.

    Example::

        plot_array(np.random.rand(10,10))

    See <http://stackoverflow.com/questions/7229971/2d-grid-data-visualization-in-python>

    Args:
        array: Array-like object (1D or 2D). A 1D input is shown as a single row.
        color_map: color map. If None, a blue-black-red map with 256 levels is used.
        cplx_mode:
            Flag defining how to handle complex arrays. Possible values in ("re", "im", "abs", "angle")
            "re" for the real part, "im" for the imaginary part.
            "abs" means that the absolute value of the complex number is shown.
            "angle" will display the phase of the complex number in radians.

    Returns: |matplotlib-Figure| (the current pyplot figure).
    """
    # atleast_2d promotes vectors to a (1, n) array so that imshow accepts them.
    values = data_from_cplx_mode(cplx_mode, np.atleast_2d(array))

    import matplotlib as mpl
    from matplotlib import pyplot as plt

    cmap = color_map
    if cmap is None:
        cmap = mpl.colors.LinearSegmentedColormap.from_list('my_colormap', ['blue', 'black', 'red'], 256)

    image = plt.imshow(values, interpolation='nearest', cmap=cmap, origin='lower')
    plt.colorbar(image, cmap=cmap)
    plt.grid(True, color='white')

    return plt.gcf()
class ArrayPlotter(object):
    """
    Collect labeled arrays and plot them in a grid of ``matshow`` panels,
    one panel (with its own colorbar) per array.
    """

    def __init__(self, *labels_and_arrays):
        """
        Args:
            labels_and_arrays: (label, array) pairs e.g. ("label1", arr1), ("label2", arr2).
        """
        self._arr_dict = collections.OrderedDict()
        for label, array in labels_and_arrays:
            self.add_array(label, array)

    def __len__(self):
        return len(self._arr_dict)

    def __iter__(self):
        """Iterate over the labels in insertion order."""
        return self._arr_dict.__iter__()

    def keys(self):
        return self._arr_dict.keys()

    def items(self):
        return self._arr_dict.items()

    def add_array(self, label, array):
        """Add array with the given name. Raise ValueError if the label is already present."""
        if label in self._arr_dict:
            raise ValueError("%s is already in %s" % (label, list(self._arr_dict.keys())))
        self._arr_dict[label] = array

    def add_arrays(self, labels, arr_list):
        """
        Add a list of arrays.

        Args:
            labels: List of labels.
            arr_list: List of arrays.

        Raises:
            ValueError: if ``labels`` and ``arr_list`` have different lengths
                or a label is already present.
        """
        # Explicit check instead of assert: asserts are stripped under ``python -O``.
        if len(labels) != len(arr_list):
            raise ValueError("labels and arr_list should have the same length but got %d and %d" %
                             (len(labels), len(arr_list)))
        for label, arr in zip(labels, arr_list):
            self.add_array(label, arr)

    @add_fig_kwargs
    def plot(self, cplx_mode="abs", colormap="jet", fontsize=8, **kwargs):
        """
        Plot all the arrays, one panel per array, arranged on two columns.

        Args:
            cplx_mode: "abs" for absolute value, "re", "im", "angle"
            colormap: matplotlib colormap.
            fontsize: legend and label fontsize.

        Returns: |matplotlib-Figure|
        """
        # One column for a single array, two columns otherwise.
        num_plots, ncols, nrows = len(self), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots + ncols - 1) // ncols  # ceil(num_plots / ncols)

        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        fig, ax_mat = plt.subplots(nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False)
        # The last panel is unused when num_plots is not a multiple of ncols: hide it.
        if num_plots % ncols != 0:
            ax_mat[-1, -1].axis("off")

        for ax, (label, arr) in zip(ax_mat.flat, self.items()):
            data = data_from_cplx_mode(cplx_mode, arr)
            # origin="lower" places the [0, 0] element in the lower left corner of the axes.
            img = ax.matshow(data, interpolation='nearest', cmap=colormap, origin='lower', aspect="auto")
            ax.set_title("(%s) %s" % (cplx_mode, label), fontsize=fontsize)

            # Colorbar drawn in a new axes appended to the right of ax (10% of its width).
            # http://stackoverflow.com/questions/18266642/multiple-imshow-subplots-each-with-colorbar
            cax = make_axes_locatable(ax).append_axes("right", size="10%", pad=0.05)
            # Tick locations are left to matplotlib: a fixed spacing (e.g. 0.2) generates
            # thousands of ticks for large values and almost none for small ones.
            plt.colorbar(img, cax=cax, format="%.2f")

            # Hide the x axis of the panel.
            ax.xaxis.set_visible(False)
            ax.grid(True, color='white')

        fig.tight_layout()
        return fig
class Marker(collections.namedtuple("Marker", "x y s")):
    """
    Stores the position and the size of the marker.

    A marker is a list of tuple(x, y, s) where x, and y are the position
    in the graph and s is the size of the marker.
    Used for plotting purpose e.g. QP data, energy derivatives...

    Values passed to the constructor are stored as numpy arrays, which cannot
    grow in place. A Marker built without arguments stores empty lists and can
    be filled with ``extend``.

    Example::

        x, y, s = [1, 2, 3], [4, 5, 6], [0.1, 0.2, -0.3]
        marker = Marker(x, y, s)

        marker = Marker()
        marker.extend((x, y, s))
    """

    def __new__(cls, *xys):
        """
        Build the marker from (x, y, s), converting the three entries to numpy arrays.

        Raises:
            TypeError: if the number of entries is not 0 or 3.
            ValueError: if a size has a nonzero imaginary part.
        """
        # super(Marker, cls) (not super(cls, Marker)) so that subclasses can be instantiated.
        if not xys:
            return super(Marker, cls).__new__(cls, [], [], [])

        if len(xys) != 3:
            raise TypeError("Expecting 3 entries in xys got %d" % len(xys))

        x, y, s = (np.asarray(values) for values in xys)

        # A complex size is ambiguous (sign and magnitude are ill-defined).
        complex_mask = np.iscomplex(s)
        if np.any(complex_mask):
            raise ValueError("Found ambiguous complex entry %s" % str(s[complex_mask][0]))

        return super(Marker, cls).__new__(cls, x, y, s)

    def __bool__(self):
        return bool(len(self.s))

    __nonzero__ = __bool__

    def extend(self, xys):
        """
        Extend the marker values in place.

        Args:
            xys: Sequence (x, y, s) with the new positions and sizes.

        Raises:
            TypeError: if ``xys`` does not have 3 entries, if the stored values
                cannot grow in place (numpy arrays), or if x, y, s would end up
                with different lengths. The marker is left unchanged on error.
        """
        if len(xys) != 3:
            raise TypeError("Expecting 3 entries in xys got %d" % len(xys))

        fields = (self.x, self.y, self.s)
        if not all(hasattr(field, "extend") for field in fields):
            raise TypeError("Cannot extend a Marker storing numpy arrays: build a new Marker instead.")

        new_values = [list(values) for values in xys]
        # Validate the final lengths before mutating anything.
        lens = np.array([len(field) + len(values) for field, values in zip(fields, new_values)])
        if np.any(lens != lens[0]):
            raise TypeError("x, y, s vectors should have same lengths but got %s" % str(lens))

        for field, values in zip(fields, new_values):
            field.extend(values)

    def posneg_marker(self):
        """
        Split data into two sets: the first one contains all the points with
        non-negative size, the second one contains all the points with negative size.

        Returns: (positive Marker, negative Marker)
        """
        pos_x, pos_y, pos_s = [], [], []
        neg_x, neg_y, neg_s = [], [], []

        for x, y, s in zip(self.x, self.y, self.s):
            if s >= 0.0:
                pos_x.append(x)
                pos_y.append(y)
                pos_s.append(s)
            else:
                neg_x.append(x)
                neg_y.append(y)
                neg_s.append(s)

        return Marker(pos_x, pos_y, pos_s), Marker(neg_x, neg_y, neg_s)
class MplExpose(object): # pragma: no cover
    """
    Collect matplotlib figures and show them, either all at once or one at a time.

    Example:

        with MplExpose() as e:
            e(obj.plot1(show=False))
            e(obj.plot2(show=False))
    """

    def __init__(self, slide_mode=False, slide_timeout=None, verbose=1):
        """
        Args:
            slide_mode: If true, show each figure as soon as it is added.
                Default: collect all figures and expose them at once.
            slide_timeout: In slide mode, close each figure after slide_timeout seconds.
                Block if None.
            verbose: verbosity level. 0 disables the informative messages.

        Raises:
            ValueError: if slide_timeout is negative.
        """
        self.figures = []
        self.slide_mode = bool(slide_mode)
        self.timeout_ms = slide_timeout
        self.verbose = verbose

        if self.timeout_ms is not None:
            self.timeout_ms = int(self.timeout_ms * 1000)
            # Explicit check instead of assert: asserts are stripped under ``python -O``.
            if self.timeout_ms < 0:
                raise ValueError("slide_timeout should be >= 0 but got %s" % slide_timeout)

        if self.verbose:
            if self.slide_mode:
                print("\nSliding matplotlib figures with slide timeout: %s [s]" % slide_timeout)
            else:
                print("\nLoading all mpl figures before showing them. Could take some time...")

        self.start_time = time.time()

    def __call__(self, obj):
        """
        Add an object to MplExpose. Support mpl figure, list/tuple of figures or
        generator yielding figures. None entries are ignored.
        """
        import types
        if isinstance(obj, (types.GeneratorType, list, tuple)):
            for fig in obj:
                self.add_fig(fig)
        else:
            self.add_fig(obj)

    def add_fig(self, fig):
        """
        Add a matplotlib figure. None is ignored.

        In slide mode the figure is shown immediately (closed by a timer after
        timeout_ms milliseconds if a timeout was given) and then cleared.
        """
        if fig is None:
            return

        if not self.slide_mode:
            self.figures.append(fig)
        else:
            import matplotlib.pyplot as plt
            if self.timeout_ms is not None:
                # The timer calls plt.close(fig) after interval milliseconds to close the window.
                timer = fig.canvas.new_timer(interval=self.timeout_ms)
                timer.add_callback(plt.close, fig)
                timer.start()

            plt.show()
            fig.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Activated at the end of the with statement. """
        self.expose()

    def expose(self):
        """Show all the collected figures (outside slide mode), then clear them."""
        if not self.slide_mode:
            # Honor verbose like __init__ does.
            if self.verbose:
                print("All figures in memory, elapsed time: %.3f s" % (time.time() - self.start_time))
            import matplotlib.pyplot as plt
            plt.show()

        for fig in self.figures:
            fig.clear()
def plot_unit_cell(lattice, ax=None, **kwargs):
    """
    Adds the unit cell of the lattice to a matplotlib Axes3D

    Args:
        lattice: Lattice object
        ax: matplotlib :class:`Axes3D` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black
            and linewidth to 3.

    Returns:
        matplotlib figure and matplotlib ax
    """
    ax, fig, plt = get_ax3d_fig_plt(ax)
    kwargs.setdefault("color", "k")
    kwargs.setdefault("linewidth", 3)

    # Corners of the cell in fractional coordinates: 0-3 lie on the z=0 face,
    # 4-7 on the z=1 face.
    frac_corners = (
        [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0],
        [0.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 0.0, 1.0], [0.0, 0.0, 1.0],
    )
    v = [lattice.get_cartesian_coords(frac) for frac in frac_corners]

    # Each of the 12 edges of the parallelepiped is listed exactly once,
    # so no segment is drawn twice.
    edges = ((0, 1), (1, 2), (2, 3), (0, 3), (3, 4), (4, 5), (5, 6),
             (6, 7), (7, 4), (0, 7), (1, 6), (2, 5))
    for i, j in edges:
        ax.plot(*zip(v[i], v[j]), **kwargs)

    return fig, ax
def plot_structure(structure, ax=None, to_unit_cell=False, alpha=0.7,
                   style="points+labels", color_scheme="VESTA", **kwargs):
    """
    Plot structure with matplotlib (minimalistic version)

    Draws the unit cell, then (depending on ``style``) a scatter point per site
    and/or the element symbol next to each site.

    Args:
        structure: Structure object
        ax: matplotlib :class:`Axes3D` or None if a new figure should be created.
        to_unit_cell: True if sites should be wrapped into the first unit cell.
        alpha: The alpha blending value, between 0 (transparent) and 1 (opaque)
        style: Sites are drawn if the string contains "points",
            element symbols if it contains "labels" (default: "points+labels").
        color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA")
        kwargs: Accepted but currently not used.

    Returns: |matplotlib-Figure|
    """
    fig, ax = plot_unit_cell(structure.lattice, ax=ax, linewidth=1)

    from pymatgen.analysis.molecule_structure_comparator import CovalentRadius
    from pymatgen.vis.structure_vtk import EL_COLORS

    # One row (x, y, z, covalent radius) per site, in Cartesian coordinates.
    rows = np.empty((len(structure), 4))
    colors = []
    for idx, site in enumerate(structure):
        symbol = site.specie.symbol
        rgb = EL_COLORS[color_scheme][symbol]
        radius = CovalentRadius.radius[symbol]
        if to_unit_cell and hasattr(site, "to_unit_cell"):
            # NOTE(review): to_unit_cell is accessed as an attribute, not called;
            # correct only if it is a property in the pymatgen version in use.
            site = site.to_unit_cell

        x, y, z = site.coords
        rows[idx] = (x, y, z, radius)
        colors.append(tuple(channel / 255 for channel in rgb))
        if "labels" in style:
            ax.text(x, y, z, symbol)

    if "points" in style:
        # matplotlib marker sizes are in points, not data units; 5000 * radius**2
        # is an empirical factor that gives reasonable plots. For other approaches see
        # https://stackoverflow.com/questions/9081553/python-scatter-plot-size-and-style-of-the-marker/24567352#24567352
        # https://gist.github.com/syrte/592a062c562cd2a98a83
        xs, ys, zs, radii = rows.T.copy()
        ax.scatter(xs, ys, zs=zs, s=5000 * radii ** 2, c=colors, alpha=alpha)

    ax.set_title(structure.composition.formula)
    ax.set_axis_off()

    return fig