Skip to content

Commit

Permalink
[Review Needed]Pytest Style coordinates/test_chainreader.py (#1640)
Browse files Browse the repository at this point in the history
* Pytest Style coordinates/test_chainreader.py

* Drop trajectory fixture
  • Loading branch information
utkbansal authored and kain88-de committed Aug 24, 2017
1 parent 56776ef commit f08aa4f
Showing 1 changed file with 80 additions and 91 deletions.
171 changes: 80 additions & 91 deletions testsuite/MDAnalysisTests/coordinates/test_chainreader.py
Expand Up @@ -21,147 +21,136 @@
#
from __future__ import division, absolute_import

import numpy as np
import os
from six.moves import zip

from numpy.testing import (assert_equal, assert_array_equal,
assert_almost_equal, dec)
from unittest import TestCase
import numpy as np

import pytest

from numpy.testing import (assert_equal, assert_almost_equal)

import MDAnalysis as mda
from MDAnalysisTests.datafiles import (PDB, PSF, CRD, DCD,
GRO, XTC, TRR, PDB_small, PDB_closed)
from MDAnalysisTests import tempdir



class TestChainReader(TestCase):
def setUp(self):
self.universe = mda.Universe(PSF,
[DCD, CRD, DCD, CRD, DCD, CRD, CRD])
self.trajectory = self.universe.trajectory
self.prec = 3
# dummy output DCD file
self.tmpdir = tempdir.TempDir()
self.outfile = os.path.join(self.tmpdir.name, 'chain-reader.dcd')

def tearDown(self):
try:
os.unlink(self.outfile)
except OSError:
pass
del self.universe
del self.tmpdir

def test_next_trajectory(self):
self.trajectory.rewind()
self.trajectory.next()
assert_equal(self.trajectory.ts.frame, 1, "loading frame 2")

def test_n_atoms(self):
assert_equal(self.universe.trajectory.n_atoms, 3341,


class TestChainReader(object):
prec = 3

@pytest.fixture()
def universe(self):
return mda.Universe(PSF,
[DCD, CRD, DCD, CRD, DCD, CRD, CRD])

def test_next_trajectory(self, universe):
universe.trajectory.rewind()
universe.trajectory.next()
assert_equal(universe.trajectory.ts.frame, 1, "loading frame 2")

def test_n_atoms(self, universe):
assert_equal(universe.trajectory.n_atoms, 3341,
"wrong number of atoms")

def test_n_frames(self):
assert_equal(self.universe.trajectory.n_frames, 3 * 98 + 4,
def test_n_frames(self, universe):
assert_equal(universe.trajectory.n_frames, 3 * 98 + 4,
"wrong number of frames in chained dcd")

def test_iteration(self):
for ts in self.trajectory:
def test_iteration(self, universe):
for ts in universe.trajectory:
pass # just forward to last frame
assert_equal(
self.trajectory.n_frames - 1, ts.frame,
universe.trajectory.n_frames - 1, ts.frame,
"iteration yielded wrong number of frames ({0:d}), "
"should be {1:d}".format(ts.frame, self.trajectory.n_frames))
"should be {1:d}".format(ts.frame, universe.trajectory.n_frames))

def test_jump_lastframe_trajectory(self):
self.trajectory[-1]
assert_equal(self.trajectory.ts.frame + 1, self.trajectory.n_frames,
def test_jump_lastframe_trajectory(self, universe):
universe.trajectory[-1]
assert_equal(universe.trajectory.ts.frame + 1, universe.trajectory.n_frames,
"indexing last frame with trajectory[-1]")

def test_slice_trajectory(self):
frames = [ts.frame for ts in self.trajectory[5:17:3]]
def test_slice_trajectory(self, universe):
frames = [ts.frame for ts in universe.trajectory[5:17:3]]
assert_equal(frames, [5, 8, 11, 14], "slicing dcd [5:17:3]")

def test_full_slice(self):
trj_iter = self.universe.trajectory[:]
def test_full_slice(self, universe):
trj_iter = universe.trajectory[:]
frames = [ts.frame for ts in trj_iter]
assert_equal(frames, np.arange(self.universe.trajectory.n_frames))
assert_equal(frames, np.arange(universe.trajectory.n_frames))

def test_frame_numbering(self):
self.trajectory[98] # index is 0-based and frames are 0-based
assert_equal(self.universe.trajectory.frame, 98, "wrong frame number")
def test_frame_numbering(self, universe):
universe.trajectory[98] # index is 0-based and frames are 0-based
assert_equal(universe.trajectory.frame, 98, "wrong frame number")

def test_frame(self):
self.trajectory[0]
coord0 = self.universe.atoms.positions.copy()
def test_frame(self, universe):
universe.trajectory[0]
coord0 = universe.atoms.positions.copy()
# forward to frame where we repeat original dcd again:
# dcd:0..97 crd:98 dcd:99..196
self.trajectory[99]
assert_array_equal(
self.universe.atoms.positions, coord0,
universe.trajectory[99]
assert_equal(
universe.atoms.positions, coord0,
"coordinates at frame 1 and 100 should be the same!")

def test_time(self):
self.trajectory[30] # index and frames 0-based
assert_almost_equal(self.universe.trajectory.time,
def test_time(self, universe):
universe.trajectory[30] # index and frames 0-based
assert_almost_equal(universe.trajectory.time,
30.0,
5,
err_msg="Wrong time of frame")


def test_write_dcd(self):
def test_write_dcd(self, universe, tmpdir):
"""test that ChainReader written dcd (containing crds) is correct
(Issue 81)"""
with mda.Writer(self.outfile, self.universe.atoms.n_atoms) as W:
for ts in self.universe.trajectory:
W.write(self.universe)
self.universe.trajectory.rewind()
u = mda.Universe(PSF, self.outfile)
for (ts_orig, ts_new) in zip(self.universe.trajectory,
outfile = str(tmpdir) + "chain-reader.dcd"
with mda.Writer(outfile, universe.atoms.n_atoms) as W:
for ts in universe.trajectory:
W.write(universe)
universe.trajectory.rewind()
u = mda.Universe(PSF, outfile)
for (ts_orig, ts_new) in zip(universe.trajectory,
u.trajectory):
assert_almost_equal(
ts_orig._pos,
ts_new._pos,
self.prec,
err_msg="Coordinates disagree at frame {0:d}".format(ts_orig.frame))
err_msg="Coordinates disagree at frame {0:d}".format(
ts_orig.frame))


class TestChainReaderCommonDt(object):
common_dt = 100.0
prec = 3

class TestChainReaderCommonDt(TestCase):
def setUp(self):
self.common_dt = 100.0
self.universe = mda.Universe(PSF,
[DCD, CRD, DCD, CRD, DCD, CRD, CRD],
dt=self.common_dt)
self.trajectory = self.universe.trajectory
self.prec = 3
@pytest.fixture()
def trajectory(self):
universe = mda.Universe(PSF,
[DCD, CRD, DCD, CRD, DCD, CRD, CRD],
dt=self.common_dt)
return universe.trajectory

def test_time(self):
def test_time(self, trajectory):
# We test this for the beginning, middle and end of the trajectory.
for frame_n in (0, self.trajectory.n_frames // 2, -1):
self.trajectory[frame_n]
assert_almost_equal(self.trajectory.time,
self.trajectory.frame*self.common_dt,
5,
err_msg="Wrong time for frame {0:d}".format(frame_n) )
for frame_n in (0, trajectory.n_frames // 2, -1):
trajectory[frame_n]
assert_almost_equal(trajectory.time,
trajectory.frame * self.common_dt,
5,
err_msg="Wrong time for frame {0:d}".format(
frame_n))


class TestChainReaderFormats(TestCase):
class TestChainReaderFormats(object):
"""Test of ChainReader with explicit formats (Issue 76)."""

@staticmethod
def test_set_all_format_tuples():
def test_set_all_format_tuples(self):
universe = mda.Universe(GRO, [(PDB, 'pdb'), (XTC, 'xtc'),
(TRR, 'trr')])
assert_equal(universe.trajectory.n_frames, 21)

@staticmethod
def test_set_one_format_tuple():
def test_set_one_format_tuple(self):
universe = mda.Universe(PSF, [(PDB_small, 'pdb'), DCD])
assert_equal(universe.trajectory.n_frames, 99)

@staticmethod
def test_set_all_formats():
def test_set_all_formats(self):
universe = mda.Universe(PSF, [PDB_small, PDB_closed], format='pdb')
assert_equal(universe.trajectory.n_frames, 2)

0 comments on commit f08aa4f

Please sign in to comment.