Skip to content

Commit

Permalink
CLI: Fix bug in verdi data core.trajectory show for various formats (
Browse files Browse the repository at this point in the history
…#5394)

These minor bugs went unnoticed because the methods are wholly untested.
This is partly because they rely on additional Python modules or external
executables. For the formats that rely on external executables, i.e.,
`jmol` and `xcrysden`, the `subprocess.check_output` function is
monkeypatched to prevent the actual executable from being called. This
tests all code except for the actual external executable, which at least
gives coverage of our code.

The test for `mpl_pos` needed to be monkeypatched as well. This is
because the `_show_mpl_pos` method calls `plot_positions_xyz` which
imports `matplotlib.pyplot` and for some completely unknown reason, this
causes `tests/storage/psql_dos/test_backend.py::test_unload_profile` to
fail. For some reason, merely importing `matplotlib` (even here directly
in the test) will cause that test to claim that there still is something
holding on to a reference of an sqlalchemy session that it keeps track
of in the `sqlalchemy.orm.session._sessions` weak ref dictionary. Since
it is impossible to figure out why the hell importing matplotlib would
interact with sqlalchemy sessions, the function that does the import is
simply mocked out for now.

Co-authored-by: Sebastiaan Huber <mail@sphuber.net>
  • Loading branch information
ltalirz and sphuber committed Jul 11, 2023
1 parent 03c86d5 commit fd4c126
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 68 deletions.
74 changes: 22 additions & 52 deletions aiida/cmdline/commands/cmd_data/cmd_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,96 +12,66 @@
"""
import pathlib

import click

from aiida.cmdline.params import options
from aiida.cmdline.params.options.multivalue import MultipleValueOption
from aiida.cmdline.utils import echo
from aiida.common.exceptions import MultipleObjectsError

SHOW_OPTIONS = [
options.TRAJECTORY_INDEX(),
options.WITH_ELEMENTS(),
click.option('-c', '--contour', type=click.FLOAT, cls=MultipleValueOption, default=None, help='Isovalues to plot'),
click.option(
'--sampling-stepsize',
type=click.INT,
default=None,
help='Sample positions in plot every sampling_stepsize timestep'
),
click.option(
'--stepsize',
type=click.INT,
default=None,
help='The stepsize for the trajectory, set it higher to reduce number of points'
),
click.option('--mintime', type=click.INT, default=None, help='The time to plot from'),
click.option('--maxtime', type=click.INT, default=None, help='The time to plot to'),
click.option('--indices', type=click.INT, cls=MultipleValueOption, default=None, help='Show only these indices'),
click.option(
'--dont-block', 'block', is_flag=True, default=True, help="Don't block interpreter when showing plot."
),
]


def show_options(func):
for option in reversed(SHOW_OPTIONS):
func = option(func)

return func


def _show_jmol(exec_name, trajectory_list, **kwargs):

def has_executable(exec_name):
"""
:return: True if executable can be found in PATH, False otherwise.
"""
import shutil
return shutil.which(exec_name) is not None


def _show_jmol(exec_name, trajectory_list, **_kwargs):
"""
Plugin for jmol
"""
import subprocess
import tempfile

if not has_executable(exec_name):
echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.")

# pylint: disable=protected-access
with tempfile.NamedTemporaryFile(mode='w+b') as handle:
for trajectory in trajectory_list:
handle.write(trajectory._exportcontent('cif', **kwargs)[0])
handle.write(trajectory._exportcontent('cif')[0])
handle.flush()

try:
subprocess.check_output([exec_name, handle.name])
except subprocess.CalledProcessError:
# The program died: just print a message
echo.echo_error(f'the call to {exec_name} ended with an error.')
except OSError as err:
if err.errno == 2:
echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.")
else:
raise


def _show_xcrysden(exec_name, object_list, **kwargs):
def _show_xcrysden(exec_name, trajectory_list, **_kwargs):
"""
Plugin for xcrysden
"""
import subprocess
import tempfile

if len(object_list) > 1:
if len(trajectory_list) > 1:
raise MultipleObjectsError('Visualization of multiple trajectories is not implemented')
obj = object_list[0]
obj = trajectory_list[0]

if not has_executable(exec_name):
echo.echo_critical(f"No executable '{exec_name}' found.")

# pylint: disable=protected-access
with tempfile.NamedTemporaryFile(mode='w+b', suffix='.xsf') as tmpf:
tmpf.write(obj._exportcontent('xsf', **kwargs)[0])

tmpf.write(obj._exportcontent('xsf')[0])
tmpf.flush()

try:
subprocess.check_output([exec_name, '--xsf', tmpf.name])
except subprocess.CalledProcessError:
# The program died: just print a message
echo.echo_error(f'the call to {exec_name} ended with an error.')
except OSError as err:
if err.errno == 2:
echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.")
else:
raise


# pylint: disable=unused-argument
Expand Down
29 changes: 25 additions & 4 deletions aiida/cmdline/commands/cmd_data/cmd_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from aiida.cmdline.commands.cmd_data import cmd_show, verdi_data
from aiida.cmdline.commands.cmd_data.cmd_export import data_export, export_options
from aiida.cmdline.commands.cmd_data.cmd_list import data_list, list_options
from aiida.cmdline.commands.cmd_data.cmd_show import show_options
from aiida.cmdline.params import arguments, options, types
from aiida.cmdline.utils import decorators, echo

Expand Down Expand Up @@ -66,16 +65,38 @@ def trajectory_list(raw, past_days, groups, all_users):
@trajectory.command('show')
@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.array.trajectory',)))
@options.VISUALIZATION_FORMAT(type=click.Choice(VISUALIZATION_FORMATS), default='jmol')
@show_options
@options.TRAJECTORY_INDEX()
@options.WITH_ELEMENTS()
@click.option(
'-c', '--contour', type=click.FLOAT, cls=options.MultipleValueOption, default=None, help='Isovalues to plot'
)
@click.option(
'--sampling-stepsize',
type=click.INT,
default=None,
help='Sample positions in plot every sampling_stepsize timestep'
)
@click.option(
'--stepsize',
type=click.INT,
default=None,
help='The stepsize for the trajectory, set it higher to reduce number of points'
)
@click.option('--mintime', type=click.INT, default=None, help='The time to plot from')
@click.option('--maxtime', type=click.INT, default=None, help='The time to plot to')
@click.option(
'--indices', type=click.INT, cls=options.MultipleValueOption, default=None, help='Show only these indices'
)
@click.option('--dont-block', 'block', is_flag=True, default=True, help="Don't block interpreter when showing plot.")
@decorators.with_dbenv()
def trajectory_show(data, fmt):
def trajectory_show(data, fmt, **kwargs):
"""Visualize a trajectory."""
try:
show_function = getattr(cmd_show, f'_show_{fmt}')
except AttributeError:
echo.echo_critical(f'visualization format {fmt} is not supported')

show_function(fmt, data)
show_function(exec_name=fmt, trajectory_list=data, **kwargs)


@trajectory.command('export')
Expand Down
19 changes: 7 additions & 12 deletions aiida/orm/nodes/data/array/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,6 @@ def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals
from ase.data.colors import cpk_colors as colors
else:
raise ValueError(f'Unknown color spec {colors}')
if kwargs:
raise ValueError(f'Unrecognized keyword {kwargs.keys()}')

if element_list is None:
# If not all elements are allowed
Expand Down Expand Up @@ -703,12 +701,8 @@ def show_mpl_heatmap(self, **kwargs): # pylint: disable=invalid-name,too-many-a
from mayavi import mlab
except ImportError:
raise ImportError(
'Unable to import the mayavi package, that is required to'
'use the plotting feature you requested. '
'Please install it first and then call this command again '
'(note that the installation of mayavi is quite complicated '
'and requires that you already installed the python numpy '
'package, as well as the vtk package'
'The plotting feature you requested requires the mayavi package.'
'Try `pip install mayavi` or consult the documentation.'
)
from ase.data import atomic_numbers
from ase.data.colors import jmol_colors
Expand Down Expand Up @@ -847,7 +841,7 @@ def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,in
dont_block=False,
mintime=None,
maxtime=None,
label_sparsity=10):
n_labels=10):
"""
Plot with matplotlib the positions of the coordinates of the atoms
over time for a trajectory
Expand All @@ -862,14 +856,14 @@ def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,in
:param dont_block: passed to plt.show() as ``block=not dont_block``
:param mintime: if specified, cut the time axis at the specified min value
:param maxtime: if specified, cut the time axis at the specified max value
:param label_sparsity: how often to put a label with the pair (t, coord)
:param n_labels: how many labels (t, coord) to put
"""
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np

tlim = [times[0], times[-1]]
index_range = [0, len(times)]
index_range = [0, len(times) - 1]
if mintime is not None:
tlim[0] = mintime
index_range[0] = np.argmax(times > mintime)
Expand All @@ -896,7 +890,8 @@ def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,in
plt.ylabel(r'Z Position $\left[{}\right]$'.format(positions_unit))
plt.xlabel(f'Time [{times_unit}]')
plt.xlim(*tlim)
sparse_indices = np.linspace(*index_range, num=label_sparsity, dtype=int)
n_labels = np.minimum(n_labels, len(times)) # don't need more labels than times
sparse_indices = np.linspace(*index_range, num=n_labels, dtype=int)

for index, traj in enumerate(trajectories):
if index not in indices_to_show:
Expand Down
52 changes: 52 additions & 0 deletions tests/cmdline/commands/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
cmd_cif,
cmd_dict,
cmd_remote,
cmd_show,
cmd_singlefile,
cmd_structure,
cmd_trajectory,
Expand All @@ -37,6 +38,15 @@
from tests.static import STATIC_DIR


def has_mayavi() -> bool:
"""Return whether the ``mayavi`` module can be imported."""
try:
import mayavi # pylint: disable=unused-import
except ImportError:
return False
return True


class DummyVerdiDataExportable:
"""Test exportable data objects."""

Expand Down Expand Up @@ -517,6 +527,48 @@ def test_export(self, output_flag, tmp_path):
new_supported_formats = list(cmd_trajectory.EXPORT_FORMATS)
self.data_export_test(TrajectoryData, self.pks, new_supported_formats, output_flag, tmp_path)

@pytest.mark.parametrize(
'fmt', (
pytest.param(
'jmol', marks=pytest.mark.skipif(not cmd_show.has_executable('jmol'), reason='No jmol executable.')
),
pytest.param(
'xcrysden',
marks=pytest.mark.skipif(not cmd_show.has_executable('xcrysden'), reason='No xcrysden executable.')
),
pytest.param(
'mpl_heatmap', marks=pytest.mark.skipif(not has_mayavi(), reason='Package `mayavi` not installed.')
), pytest.param('mpl_pos')
)
)
def test_trajectoryshow(self, fmt, monkeypatch, run_cli_command):
"""Test showing the trajectory data in different formats"""
trajectory_pk = self.pks[DummyVerdiDataListable.NODE_ID_STR]
options = ['--format', fmt, str(trajectory_pk), '--dont-block']

def mock_check_output(options):
assert isinstance(options, list)
assert options[0] == fmt

if fmt in ['jmol', 'xcrysden']:
# This is called by the ``_show_jmol`` and ``_show_xcrysden`` implementations. We want to test just the
# function but not the actual commands through a sub process. Note that this mock needs to happen only for
# these specific formats, because ``matplotlib`` used in the others _also_ calls ``subprocess.check_output``
monkeypatch.setattr(sp, 'check_output', mock_check_output)

if fmt in ['mpl_pos']:
# This has to be mocked because ``plot_positions_xyz`` imports ``matplotlib.pyplot`` and for some completely
# unknown reason, causes ``tests/storage/psql_dos/test_backend.py::test_unload_profile`` to fail. For some
# reason, merely importing ``matplotlib`` (even here directly in the test) will cause that test to claim
# that there still is something holding on to a reference of an sqlalchemy session that it keeps track of
# in the ``sqlalchemy.orm.session._sessions`` weak ref dictionary. Since it is impossible to figure out why
# the hell importing matplotlib would interact with sqlalchemy sessions, the function that does the import
# is simply mocked out for now.
from aiida.orm.nodes.data.array import trajectory
monkeypatch.setattr(trajectory, 'plot_positions_XYZ', lambda *args, **kwargs: None)

run_cli_command(cmd_trajectory.trajectory_show, options, use_subprocess=False)


class TestVerdiDataStructure(DummyVerdiDataListable, DummyVerdiDataExportable):
"""Test verdi data core.structure."""
Expand Down

0 comments on commit fd4c126

Please sign in to comment.