Skip to content

Commit

Permalink
Use fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
utkbansal committed Jun 27, 2017
1 parent ea608e8 commit 853f646
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 135 deletions.
250 changes: 128 additions & 122 deletions testsuite/MDAnalysisTests/auxiliary/base.py
Expand Up @@ -20,6 +20,8 @@
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#
from __future__ import absolute_import

import pytest
from six.moves import range
import numpy as np
from numpy.testing import (assert_equal, assert_raises, assert_almost_equal,
Expand All @@ -34,6 +36,7 @@ def test_get_bad_auxreader_format_raises_ValueError():
# should raise a ValueError when no AuxReaders with match the specified format
mda.auxiliary.core.get_auxreader_for(format='bad-format')


class BaseAuxReference(object):
## assumes the reference auxiliary data has 5 steps, with three values
## for each step: i, 2*i and 2^i, where i is the step number.
Expand Down Expand Up @@ -126,243 +129,246 @@ def format_data(self, data):
return np.array(data)




class BaseAuxReaderTest(object):
def __init__(self, reference):
self.ref = reference
self.reader = self.ref.reader(self.ref.testdata, initial_time=self.ref.initial_time,
dt=self.ref.dt, auxname=self.ref.name,
time_selector=None, data_selector=None)
# def __init__(self, reference):
# self.ref = reference
# self.reader = self.ref.reader(self.ref.testdata, initial_time=self.ref.initial_time,
# dt=self.ref.dt, auxname=self.ref.name,
# time_selector=None, data_selector=None)

def tearDown(self):
del self.reader

def test_n_steps(self):
assert_equal(len(self.reader), self.ref.n_steps,
def test_n_steps(self, ref, reader):
assert_equal(len(reader), ref.n_steps,
"number of steps does not match")

def test_dt(self):
assert_equal(self.reader.dt, self.ref.dt,
def test_dt(self, ref, reader):
assert_equal(reader.dt, ref.dt,
"dt does not match")

def test_initial_time(self):
assert_equal(self.reader.initial_time, self.ref.initial_time,
def test_initial_time(self, ref, reader):
assert_equal(reader.initial_time, ref.initial_time,
"initial time does not match")

def test_first_step(self):
def test_first_step(self, ref, reader):
# on first loading we should start at step 0
assert_auxstep_equal(self.reader.auxstep, self.ref.auxsteps[0])
assert_auxstep_equal(reader.auxstep, ref.auxsteps[0])

def test_next(self):
def test_next(self, ref, reader):
# should take us to step 1
next(self.reader)
assert_auxstep_equal(self.reader.auxstep, self.ref.auxsteps[1])
next(reader)
assert_auxstep_equal(reader.auxstep, ref.auxsteps[1])

def test_rewind(self):
def test_rewind(self, ref, reader):
# move to step 1...
self.reader.next()
reader.next()
# now rewind should read step 0
self.reader.rewind()
assert_auxstep_equal(self.reader.auxstep, self.ref.auxsteps[0])
reader.rewind()
assert_auxstep_equal(reader.auxstep, ref.auxsteps[0])

def test_move_to_step(self):
def test_move_to_step(self, ref, reader):
# should take us to step 3
self.reader[3]
assert_auxstep_equal(self.reader.auxstep, self.ref.auxsteps[3])
reader[3]
assert_auxstep_equal(reader.auxstep, ref.auxsteps[3])

def test_last_step(self):
def test_last_step(self, ref, reader):
# should take us to the last step
self.reader[-1]
assert_auxstep_equal(self.reader.auxstep, self.ref.auxsteps[-1])
reader[-1]
assert_auxstep_equal(reader.auxstep, ref.auxsteps[-1])

def test_next_past_last_step_raises_StopIteration(self):
def test_next_past_last_step_raises_StopIteration(self, ref, reader):
# should take us to the last step
self.reader[-1]
reader[-1]
# if we try to move to next step from here, should raise StopIteration
assert_raises(StopIteration, self.reader.next)
assert_raises(StopIteration, reader.next)

@raises(IndexError)
def test_move_to_invalid_step_raises_IndexError(self):
# last step is number n_steps -1 ; if we try move to step number
def test_move_to_invalid_step_raises_IndexError(self, ref, reader):
# last step is number n_steps -1 ; if we try move to step number
# n_steps we should get a ValueError
self.reader[self.ref.n_steps]
with pytest.raises(IndexError):
reader[ref.n_steps]

@raises(ValueError)
def test_invalid_step_to_time_raises_ValueError(self):
# last step is number n_steps-1; if we try to run step_to_time on
def test_invalid_step_to_time_raises_ValueError(self, ref, reader):
# last step is number n_steps-1; if we try to run step_to_time on
# step n_steps we should get a ValueError
self.reader.step_to_time(self.reader.n_steps)
with pytest.raises(ValueError):
reader.step_to_time(reader.n_steps)

def test_iter(self):
for i, val in enumerate(self.reader):
assert_auxstep_equal(val, self.ref.auxsteps[i])
def test_iter(self,ref, reader):
for i, val in enumerate(reader):
assert_auxstep_equal(val, ref.auxsteps[i])

def test_iter_list(self):
def test_iter_list(self, ref, reader):
# test using __getitem__ with a list
for i, val in enumerate(self.reader[self.ref.iter_list]):
assert_auxstep_equal(val, self.ref.iter_list_auxsteps[i])
for i, val in enumerate(reader[ref.iter_list]):
assert_auxstep_equal(val, ref.iter_list_auxsteps[i])


def test_iter_slice(self):
def test_iter_slice(self, ref, reader):
# test using __getitem__ with a slice
for i, val in enumerate(self.reader[self.ref.iter_slice]):
assert_auxstep_equal(val, self.ref.iter_slice_auxsteps[i])
for i, val in enumerate(reader[ref.iter_slice]):
assert_auxstep_equal(val, ref.iter_slice_auxsteps[i])

@raises(IndexError)
def test_slice_start_after_stop_raises_IndexError(self):
def test_slice_start_after_stop_raises_IndexError(self, ref, reader):
#should raise IndexError if start frame after end frame
self.reader[2:1]
with pytest.raises(IndexError):
reader[2:1]

@raises(IndexError)
def test_slice_out_of_range_raises_IndexError(self):
def test_slice_out_of_range_raises_IndexError(self, ref, reader):
# should raise IndexError if indices our of range
self.reader[self.ref.n_steps:]
with pytest.raises(IndexError):
reader[ref.n_steps:]


@raises(TypeError)
def test_slice_non_int_raises_TypeError(self):
def test_slice_non_int_raises_TypeError(self, ref, reader):
# should raise TypeError if try pass in non-integer to slice
self.reader['a':]
with pytest.raises(TypeError):
reader['a':]

@raises(ValueError)
def test_bad_represent_raises_ValueError(self):
# if we try to set represent_ts_as to something not listed as a
def test_bad_represent_raises_ValueError(self, ref, reader):
# if we try to set represent_ts_as to something not listed as a
# valid option, we should get a ValueError
self.reader.represent_ts_as = 'invalid-option'
with pytest.raises(ValueError):
reader.represent_ts_as = 'invalid-option'

def test_time_selector(self):
def test_time_selector(self, ref, reader):
# reload the reader, passing a time selector
self.reader = self.ref.reader(self.ref.testdata,
time_selector = self.ref.time_selector)
reader = ref.reader(ref.testdata,
time_selector = ref.time_selector)
# time should still match reference time for each step
for i, val in enumerate(self.reader):
assert_equal(val.time, self.ref.select_time_ref[i],
for i, val in enumerate(reader):
assert_equal(val.time, ref.select_time_ref[i],
"time for step {} does not match".format(i))

def test_data_selector(self):
def test_data_selector(self, ref, reader):
# reload reader, passing in a data selector
self.reader = self.ref.reader(self.ref.testdata,
data_selector=self.ref.data_selector)
reader = ref.reader(ref.testdata,
data_selector=ref.data_selector)
# data should match reference data for each step
for i, val in enumerate(self.reader):
assert_equal(val.data, self.ref.select_data_ref[i],
for i, val in enumerate(reader):
assert_equal(val.data, ref.select_data_ref[i],
"data for step {0} does not match".format(i))

def test_no_constant_dt(self):
def test_no_constant_dt(self, ref, reader):
## assume we can select time...
# reload reader, without assuming constant dt
self.reader = self.ref.reader(self.ref.testdata,
time_selector=self.ref.time_selector,
reader = ref.reader(ref.testdata,
time_selector=ref.time_selector,
constant_dt=False)
# time should match reference for selecting time, for each step
for i, val in enumerate(self.reader):
assert_equal(val.time, self.ref.select_time_ref[i],
for i, val in enumerate(reader):
assert_equal(val.time, ref.select_time_ref[i],
"data for step {} does not match".format(i))

@raises(ValueError)
def test_update_ts_without_auxname_raises_ValueError(self):
def test_update_ts_without_auxname_raises_ValueError(self, ref, reader):
# reload reader without auxname
self.reader = self.ref.reader(self.ref.testdata)
ts = self.ref.lower_freq_ts
self.reader.update_ts(ts)
with pytest.raises(ValueError):
reader = ref.reader(ref.testdata)
ts = ref.lower_freq_ts
reader.update_ts(ts)

def test_read_lower_freq_timestep(self):
def test_read_lower_freq_timestep(self, ref, reader):
# test reading a timestep with lower frequency
ts = self.ref.lower_freq_ts
self.reader.update_ts(ts)
ts = ref.lower_freq_ts
reader.update_ts(ts)
# check the value set in ts is as we expect
assert_almost_equal(ts.aux.test, self.ref.lowf_closest_rep,
assert_almost_equal(ts.aux.test, ref.lowf_closest_rep,
err_msg="Representative value in ts.aux does not match")

def test_represent_as_average(self):
def test_represent_as_average(self, ref, reader):
# test the 'average' option for 'represent_ts_as'
# reset the represent method to 'average'...
self.reader.represent_ts_as = 'average'
reader.represent_ts_as = 'average'
# read timestep; use the low freq timestep
ts = self.ref.lower_freq_ts
self.reader.update_ts(ts)
ts = ref.lower_freq_ts
reader.update_ts(ts)
# check the representative value set in ts is as expected
assert_almost_equal(ts.aux.test, self.ref.lowf_average_rep,
assert_almost_equal(ts.aux.test, ref.lowf_average_rep,
err_msg="Representative value does not match when "
"using with option 'average'")

def test_represent_as_average_with_cutoff(self):
def test_represent_as_average_with_cutoff(self, ref, reader):
# test the 'represent_ts_as' 'average' option when we have a cutoff set
# set the cutoff...
self.reader.cutoff = self.ref.cutoff
reader.cutoff = ref.cutoff
# read timestep; use the low frequency timestep
ts = self.ref.lower_freq_ts
self.reader.update_ts(ts)
ts = ref.lower_freq_ts
reader.update_ts(ts)
# check representative value set in ts is as expected
assert_almost_equal(ts.aux.test, self.ref.lowf_cutoff_average_rep,
assert_almost_equal(ts.aux.test, ref.lowf_cutoff_average_rep,
err_msg="Representative value does not match when "
"applying cutoff")

def test_read_offset_timestep(self):
def test_read_offset_timestep(self, ref, reader):
# try reading a timestep offset from auxiliary
ts = self.ref.offset_ts
self.reader.update_ts(ts)
assert_almost_equal(ts.aux.test, self.ref.offset_closest_rep,
ts = ref.offset_ts
reader.update_ts(ts)
assert_almost_equal(ts.aux.test, ref.offset_closest_rep,
err_msg="Representative value in ts.aux does not match")

def test_represent_as_closest_with_cutoff(self):
def test_represent_as_closest_with_cutoff(self, ref, reader):
# test the 'represent_ts_as' 'closest' option when we have a cutoff set
# set the cutoff...
self.reader.cutoff = self.ref.cutoff
reader.cutoff = ref.cutoff
# read timestep; use the offset timestep
ts = self.ref.offset_ts
self.reader.update_ts(ts)
ts = ref.offset_ts
reader.update_ts(ts)
# check representative value set in ts is as expected
assert_almost_equal(ts.aux.test, self.ref.offset_cutoff_closest_rep,
assert_almost_equal(ts.aux.test, ref.offset_cutoff_closest_rep,
err_msg="Representative value does not match when "
"applying cutoff")

def test_read_higher_freq_timestep(self):
def test_read_higher_freq_timestep(self, ref, reader):
# try reading a timestep with higher frequency
ts = self.ref.higher_freq_ts
self.reader.update_ts(ts)
assert_almost_equal(ts.aux.test, self.ref.highf_rep,
ts = ref.higher_freq_ts
reader.update_ts(ts)
assert_almost_equal(ts.aux.test, ref.highf_rep,
err_msg="Representative value in ts.aux does not match")

def test_get_auxreader_for(self):
def test_get_auxreader_for(self, ref, reader):
# check guesser gives us right reader
reader = mda.auxiliary.core.get_auxreader_for(self.ref.testdata)
assert_equal(reader, self.ref.reader)
reader = mda.auxiliary.core.get_auxreader_for(ref.testdata)
assert_equal(reader, ref.reader)

def test_iterate_through_trajectory(self):
def test_iterate_through_trajectory(self, ref, reader):
# add to trajectory
u = mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_XTC)
u.trajectory.add_auxiliary('test', self.ref.testdata)
u.trajectory.add_auxiliary('test', ref.testdata)
# check the representative values of aux for each frame are as expected
# trajectory here has same dt, offset; so there's a direct correspondence
# between frames and steps
for i, ts in enumerate(u.trajectory):
assert_equal(ts.aux.test, self.ref.auxsteps[i].data,
assert_equal(ts.aux.test, ref.auxsteps[i].data,
"representative value does not match when iterating through "
"all trajectory timesteps")
u.trajectory.close()

def test_iterate_as_auxiliary_from_trajectory(self):
def test_iterate_as_auxiliary_from_trajectory(self, ref, reader):
# add to trajectory
u = mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_XTC)
u.trajectory.add_auxiliary('test', self.ref.testdata)
u.trajectory.add_auxiliary('test', ref.testdata)
# check representative values of aux for each frame are as expected
# trahectory here has same dt, offset, so there's a direct correspondence
# between frames and steps, and iter_as_aux will run through all frames
for i, ts in enumerate(u.trajectory.iter_as_aux('test')):
assert_equal(ts.aux.test, self.ref.auxsteps[i].data,
assert_equal(ts.aux.test, ref.auxsteps[i].data,
"representative value does not match when iterating through "
"all trajectory timesteps")
u.trajectory.close()

def test_get_description(self):
description = self.reader.get_description()
for attr in self.ref.description:
assert_equal(description[attr], self.ref.description[attr],
def test_get_description(self, ref, reader):
description = reader.get_description()
for attr in ref.description:
assert_equal(description[attr], ref.description[attr],
"'Description' does not match for {}".format(attr))

def test_load_from_description(self):
description = self.reader.get_description()
def test_load_from_description(self, ref, reader):
description = reader.get_description()
new = mda.auxiliary.core.auxreader(**description)
assert_equal(new, self.reader,
assert_equal(new, reader,
"AuxReader reloaded from description does not match")


Expand Down

0 comments on commit 853f646

Please sign in to comment.