Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 58 additions & 70 deletions testsuite/MDAnalysisTests/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def test_forces_remove(self):
with pytest.raises(NoDataError):
getattr(ts, 'forces')

def _empty_ts(self):
def test_check_ts(self):
with pytest.raises(ValueError):
self.Timestep.from_coordinates(None, None, None)

Expand All @@ -714,7 +714,10 @@ def _from_coords(self, p, v, f):

return ts

def _check_from_coordinates(self, p, v, f):
@pytest.mark.parametrize('p, v, f', filter(any,
itertools.product([True, False],
repeat=3)))
def test_from_coordinates(self, p, v, f):
ts = self._from_coords(p, v, f)

if p:
Expand All @@ -733,15 +736,6 @@ def _check_from_coordinates(self, p, v, f):
with pytest.raises(NoDataError):
getattr(ts, 'forces')

def test_from_coordinates(self):
# Check all combinations of creating a Timestep from data
# 8 possibilites of with and without 3 data categories
for p, v, f in itertools.product([True, False], repeat=3):
if not any([p, v, f]):
yield self._empty_ts
else:
yield self._check_from_coordinates, p, v, f

def test_from_coordinates_mismatch(self):
velo = self.refvel[:2]

Expand All @@ -752,18 +746,6 @@ def test_from_coordinates_nodata(self):
with pytest.raises(ValueError):
self.Timestep.from_coordinates()

def _check_from_timestep(self, p, v, f):
ts = self._from_coords(p, v, f)
ts2 = self.Timestep.from_timestep(ts)

assert_timestep_almost_equal(ts, ts2)

def test_from_timestep(self):
for p, v, f in itertools.product([True, False], repeat=3):
if not any([p, v, f]):
continue
yield self._check_from_timestep, p, v, f

# Time related tests
def test_supply_dt(self):
# Check that this gets stored in data properly
Expand Down Expand Up @@ -942,61 +924,60 @@ def test_copy(self, func, ts):
ts = u.trajectory.ts
func(self, self.name, ts)

def test_copy_slice(self):
for p, v, f in itertools.product([True, False], repeat=3):
if not any([p, v, f]):
continue
ts = self._from_coords(p, v, f)
yield self._check_copy, self.name, ts
yield self._check_independent, self.name, ts
yield self._check_copy_slice_indices, self.name, ts
yield self._check_copy_slice_slice, self.name, ts

def _check_bad_slice(self, p, v, f):
ts = self._from_coords(p, v, f)
@pytest.fixture(params=filter(any,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this filter do the same as the other itertools.product call above? If so can we use just one way of doing this (don't care which)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filter(any, itertools.product([True, False], repeat=3))

is equivalent to

[pvf for pvf in itertools.product([True, False], repeat=3) if any(pvf)]

I personally like the filter call here because it is the clearest to read with the indentations due to being a decorator argument. But the problem is that we can't use tuple unpacking like in the list-comprehension above. I found the tuple unpacking more readable in the above example so I left it.

So yeah they are slightly differently written but personally this gives the best readability in each case. I can unify it though if you like. What would you prefer the filter or list-comprehension?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, lets use filter then. It's just confusing to see two ways for this, implies that maybe there's a difference

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed.

itertools.product([True, False], repeat=3)))
def some_ts(self, request):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a better name for this fixture. Also this can only be a function scope fixture.

p, v, f = request.param
return self._from_coords(p, v, f)

@pytest.mark.parametrize('func', [
_check_copy,
_check_independent,
_check_copy_slice_indices,
_check_copy_slice_slice,
_check_npint_slice
])
def test_copy_slice(self, func, some_ts):
func(self, self.name, some_ts)

def test_bad_slice(self, some_ts):
sl = ['this', 'is', 'silly']
with pytest.raises(TypeError):
ts.copy_slice(sl)
some_ts.copy_slice(sl)

def test_bad_copy_slice(self):
for p, v, f in itertools.product([True, False], repeat=3):
if not any([p, v, f]):
continue
yield self._check_bad_slice, p, v, f
def test_from_timestep(self, some_ts):
ts = some_ts
ts2 = self.Timestep.from_timestep(ts)

assert_timestep_almost_equal(ts, ts2)

def _get_pos(self):
# Get generic reference positions
return np.arange(30).reshape(10, 3) * 1.234

def _check_ts_equal(self, a, b, err_msg):
assert_(a == b, err_msg)
assert_(b == a, err_msg)

def test_check_equal(self):
for p, v, f in itertools.product([True, False], repeat=3):
if not any([p, v, f]):
continue

ts1 = self.Timestep(self.size,
positions=p,
velocities=v,
forces=f)
ts2 = self.Timestep(self.size,
positions=p,
velocities=v,
forces=f)
if p:
ts1.positions = self.refpos.copy()
ts2.positions = self.refpos.copy()
if v:
ts1.velocities = self.refvel.copy()
ts2.velocities = self.refvel.copy()
if f:
ts1.forces = self.reffor.copy()
ts2.forces = self.reffor.copy()

yield (self._check_ts_equal, ts1, ts2,
'Failed on {0}'.format(self.name))
@pytest.mark.parametrize('p, v, f', filter(any,
itertools.product([True, False],
repeat=3)))
def test_check_equal(self, p, v, f):
ts1 = self.Timestep(self.size,
positions=p,
velocities=v,
forces=f)
ts2 = self.Timestep(self.size,
positions=p,
velocities=v,
forces=f)
if p:
ts1.positions = self.refpos.copy()
ts2.positions = self.refpos.copy()
if v:
ts1.velocities = self.refvel.copy()
ts2.velocities = self.refvel.copy()
if f:
ts1.forces = self.reffor.copy()
ts2.forces = self.reffor.copy()

assert_timestep_equal(ts1, ts2)

def test_wrong_class_equality(self):
ts1 = self.Timestep(self.size)
Expand Down Expand Up @@ -1096,6 +1077,13 @@ def test_check_wrong_forces_equality(self):
assert_(ts2 != ts1)


def assert_timestep_equal(A, B, msg=''):
""" assert that two timesteps are exactly equal and commutative
"""
assert A == B, msg
assert B == A, msg


def assert_timestep_almost_equal(A, B, decimal=6, verbose=True):
if not isinstance(A, Timestep):
raise AssertionError('A is not of type Timestep')
Expand Down
4 changes: 2 additions & 2 deletions testsuite/MDAnalysisTests/coordinates/test_timestep_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
PDBQT_input, PQR, PRM, TRJ, PRMncdf,
NCDF, TRZ_psf, TRZ)

from MDAnalysisTests.coordinates.base import BaseTimestepTest
from MDAnalysisTests.coordinates.base import BaseTimestepTest, assert_timestep_equal
import pytest

# Can add in custom tests for a given Timestep here!
Expand All @@ -54,7 +54,7 @@ def test_other_timestep(self, otherTS):
ts1.positions = self._get_pos()
ts2 = otherTS(10)
ts2.positions = self._get_pos()
self._check_ts_equal(ts1, ts2, "Failed on {0}".format(otherTS))
assert_timestep_equal(ts1, ts2, "Failed on {0}".format(otherTS))


# TODO: Merge this into generic Reader tests
Expand Down