Skip to content

Commit

Permalink
neater pattern for assert statements
Browse files Browse the repository at this point in the history
  • Loading branch information
espenhgn committed Nov 11, 2020
1 parent 963f9a1 commit 95c8f6c
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 175 deletions.
43 changes: 14 additions & 29 deletions lfpykit/cellgeometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,35 +115,20 @@ def __init__(self, x, y, z, d):
array of compartment surface areas in units of um^2
'''
# check input
try:
assert np.all([type(x) is np.ndarray,
type(y) is np.ndarray,
type(z) is np.ndarray,
type(d) is np.ndarray])
except AssertionError:
raise AssertionError('x, y, z and d must be of type numpy.ndarray')
try:
assert x.ndim == y.ndim == z.ndim == 2
except AssertionError:
raise AssertionError('x, y and z must be of shape (n_seg x 2)')
try:
assert x.shape == y.shape == z.shape
except AssertionError:
raise AssertionError('x, y and z must all be the same shape')
try:
assert x.shape[1] == 2
except AssertionError:
raise AssertionError(
'the last axis of x, y and z must be of length 2')
try:
try:
assert d.shape == x.shape
except AssertionError:
assert d.ndim == 1 and d.size == x.shape[0]
except AssertionError:
raise AssertionError('d must either be 1-dimensional with '
'size == n_seg or 2-dimensional with '
'd.shape == x.shape')
assert np.all([type(x) is np.ndarray,
type(y) is np.ndarray,
type(z) is np.ndarray,
type(d) is np.ndarray]), \
'x, y, z and d must be of type numpy.ndarray'
assert x.ndim == y.ndim == z.ndim == 2, \
'x, y and z must be of shape (n_seg x 2)'
assert x.shape == y.shape == z.shape, \
'x, y and z must all be the same shape'
assert x.shape[1] == 2, \
'the last axis of x, y and z must be of length 2'
assert d.shape == x.shape or (d.ndim == 1 and d.size == x.shape[0]), \
'd must either be 1-dimensional with size == n_seg ' + \
'or 2-dimensional with d.shape == x.shape'

# set attributes
self.x = x
Expand Down
35 changes: 11 additions & 24 deletions lfpykit/eegmegcalc.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,14 +977,8 @@ def __init__(self, sensor_locations, mu=4 * np.pi * 1E-7):
"""
Initialize class MEG
"""
try:
assert sensor_locations.ndim == 2
except AssertionError:
raise AssertionError('sensor_locations.ndim != 2')
try:
assert sensor_locations.shape[1] == 3
except AssertionError:
raise AssertionError('sensor_locations.shape[1] != 3')
assert sensor_locations.ndim == 2, 'sensor_locations.ndim != 2'
assert sensor_locations.shape[1] == 3, 'sensor_locations.shape[1] != 3'

# set attributes
self.sensor_locations = sensor_locations
Expand Down Expand Up @@ -1033,18 +1027,12 @@ def calculate_H(self, current_dipole_moment, dipole_location):
If dimensionality of current_dipole_moment and/or dipole_location
is wrong
"""
try:
assert current_dipole_moment.ndim == 2
except AssertionError:
raise AssertionError('current_dipole_moment.ndim != 2')
try:
assert current_dipole_moment.shape[0] == 3
except AssertionError:
raise AssertionError('current_dipole_moment.shape[0] != 3')
try:
assert dipole_location.shape == (3, )
except AssertionError:
raise AssertionError('dipole_location.shape != (3, )')
assert current_dipole_moment.ndim == 2, \
'current_dipole_moment.ndim != 2'
assert current_dipole_moment.shape[0] == 3, \
'current_dipole_moment.shape[0] != 3'
assert dipole_location.shape == (3, ), \
'dipole_location.shape != (3, )'

# container
H = np.empty((self.sensor_locations.shape[0], 3,
Expand All @@ -1053,10 +1041,9 @@ def calculate_H(self, current_dipole_moment, dipole_location):
for i, r in enumerate(self.sensor_locations):
R = r - dipole_location
assert R.ndim == 1 and R.size == 3
try:
assert not np.allclose(R, np.zeros(3))
except AssertionError:
raise AssertionError('Identical dipole and sensor location.')
assert not np.allclose(R, np.zeros(3)), \
'Identical dipole and sensor location.'

H[i, ] = np.cross(current_dipole_moment.T, R).T \
/ (4 * np.pi * np.sqrt((R**2).sum())**3)

Expand Down
174 changes: 52 additions & 122 deletions lfpykit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,26 +257,16 @@ def __init__(self, cell, x, y, z, sigma=0.3):
super().__init__(cell=cell)

# check input
try:
assert np.all([isinstance(x, np.ndarray),
isinstance(y, np.ndarray),
isinstance(z, np.ndarray)])
except AssertionError:
raise AssertionError('x, y and z must be of type numpy.ndarray')
try:
assert x.ndim == y.ndim == z.ndim == 1
except AssertionError:
raise AssertionError('x, y and z must be of shape (n_coords, )')
try:
assert x.shape == y.shape == z.shape
except AssertionError:
raise AssertionError(
'x, y and z must contain the same number of elements')
try:
assert isinstance(sigma, float) and sigma > 0
except AssertionError:
raise AssertionError(
'sigma must be a float number greater than zero')
assert np.all([isinstance(x, np.ndarray),
isinstance(y, np.ndarray),
isinstance(z, np.ndarray)]), \
'x, y and z must be of type numpy.ndarray'
assert x.ndim == y.ndim == z.ndim == 1, \
'x, y and z must be of shape (n_coords, )'
assert x.shape == y.shape == z.shape, \
'x, y and z must contain the same number of elements'
assert isinstance(sigma, float) and sigma > 0, \
'sigma must be a float number greater than zero'

# set attributes
self.x = x
Expand Down Expand Up @@ -426,26 +416,16 @@ def __init__(self, cell, x, y, z, sigma=0.3):
super().__init__(cell=cell)

# check input
try:
assert np.all([isinstance(x, np.ndarray),
isinstance(y, np.ndarray),
isinstance(z, np.ndarray)])
except AssertionError:
raise AssertionError('x, y and z must be of type numpy.ndarray')
try:
assert x.ndim == y.ndim == z.ndim == 1
except AssertionError:
raise AssertionError('x, y and z must be of shape (n_coords, )')
try:
assert x.shape == y.shape == z.shape
except AssertionError:
raise AssertionError(
'x, y and z must contain the same number of elements')
try:
assert isinstance(sigma, float) and sigma > 0
except AssertionError:
raise AssertionError(
'sigma must be a float number greater than zero')
assert np.all([isinstance(x, np.ndarray),
isinstance(y, np.ndarray),
isinstance(z, np.ndarray)]), \
'x, y and z must be of type numpy.ndarray'
assert x.ndim == y.ndim == z.ndim == 1, \
'x, y and z must be of shape (n_coords, )'
assert x.shape == y.shape == z.shape, \
'x, y and z must contain the same number of elements'
assert isinstance(sigma, float) and sigma > 0, \
'sigma must be a float number greater than zero'

# set attributes
self.x = x
Expand Down Expand Up @@ -815,12 +795,9 @@ def __init__(self, cell, sigma=0.3, probe=None,
self.z = np.array([z])
else:
self.z = np.array(z).flatten()
try:
assert ((self.x.size == self.y.size) and
(self.x.size == self.z.size))
except AssertionError:
raise AssertionError(
"The number of elements in [x, y, z] must be equal")
assert (self.x.size == self.y.size and
self.x.size == self.z.size), \
"The number of elements in [x, y, z] must be equal"

if N is not None:
if not isinstance(N, np.ndarray):
Expand Down Expand Up @@ -1354,15 +1331,10 @@ def distort_cell_geometry(self, axis='z', nu=0.0):
Poisson's ratio. Ratio between axial and transversal
compression/stretching. Default is 0.
"""
try:
assert abs(self.squeeze_cell_factor) < 1.
except AssertionError:
raise AssertionError('abs(squeeze_cell_factor) >= 1, '
+ ' squeeze_cell_factor must be in <-1, 1>')
try:
assert axis in ['x', 'y', 'z']
except AssertionError:
raise AssertionError('axis={} not "x", "y" or "z"'.format(axis))
assert abs(self.squeeze_cell_factor) < 1., \
'abs(squeeze_cell_factor) >= 1, must be in <-1, 1>'
assert axis in ['x', 'y', 'z'], \
'axis={} not "x", "y" or "z"'.format(axis)

for pos, dir_ in zip([self.cell.x[0, ].mean(),
self.cell.y[0, ].mean(),
Expand Down Expand Up @@ -1514,20 +1486,12 @@ def __init__(self,
"""initialize class OneSphereVolumeConductor"""
super().__init__(cell=cell)
# check inputs
try:
assert r.shape[0] == 3
assert r.ndim == 2
except AssertionError:
raise AssertionError('r must be a shape (3, n_points) ndarray')
try:
assert (isinstance(R, float)) or (isinstance(R, int))
except AssertionError:
raise AssertionError('sphere radius R must be a float value')
try:
assert (sigma_i > 0) & (sigma_o > 0)
except AssertionError:
raise AssertionError(
'sigma_i and sigma_o must both be positive values')
assert r.shape[0] == 3 and r.ndim == 2, \
'r must be a shape (3, n_points) ndarray'
assert (isinstance(R, float)) or (isinstance(R, int)), \
'sphere radius R must be a float value'
assert sigma_i > 0 and sigma_o > 0, \
'sigma_i and sigma_o must both be positive values'

self.r = r
self.R = R
Expand Down Expand Up @@ -1563,18 +1527,12 @@ def calc_potential(self, rs, current, min_distance=1., n_max=1000):
an 1D ndarray, and shape (n-points, I.size) ndarray is returned.
Unit [mV].
"""
try:
assert type(rs) in [int, float, np.float64]
assert abs(rs) < self.R
except AssertionError:
raise AssertionError(
'source location rs must be a float value and |rs| '
'must be less than sphere radius R')
try:
assert (min_distance is None) \
or (type(min_distance) in [float, int, np.float64])
except AssertionError:
raise AssertionError('min_distance must be None or a float')
assert type(rs) in [int, float, np.float64], \
'source location rs must be a float value '
assert abs(rs) < self.R, '|rs| must be less than sphere radius R'
assert (min_distance is None) or \
(type(min_distance) in [float, int, np.float64]), \
'min_distance must be None or a float'

r = self.r[0]
theta = self.r[1]
Expand Down Expand Up @@ -1624,23 +1582,16 @@ def calc_potential(self, rs, current, min_distance=1., n_max=1000):
phi_i[inds_i] += 1. / denom

if isinstance(current, np.ndarray):
try:
assert np.all(np.isfinite(current))
assert np.all(np.isreal(current))
assert current.ndim == 1
except AssertionError:
raise AssertionError('input argument current must be float or '
'1D ndarray with float values')
assert np.all(np.isfinite(current) & np.isreal(current)), \
'current must be finite and real'
assert current.ndim == 1, 'current must be 1D'

return np.dot((phi_i + phi_o).reshape((1, -1)).T,
current.reshape((1, -1))
) / (4. * np.pi * self.sigma_i)
else:
try:
assert np.isfinite(current) and np.shape(current) == ()
except AssertionError:
raise AssertionError('input argument I must be float or 1D '
'ndarray with float values')
assert np.isfinite(current) and np.shape(current) == (), \
'current must be float or 1D ndarray with float values'
return current / (4. * np.pi * self.sigma_i) * (phi_i + phi_o)

def get_transformation_matrix(self, n_max=1000):
Expand Down Expand Up @@ -1870,35 +1821,14 @@ def __init__(self, cell, z, r):

# check input parameters
for varname, var in zip(['z', 'r'], [z, r]):
try:
assert type(var) is np.ndarray
except AssertionError:
raise AssertionError('type({}) != np.ndarray'.format(varname))
try:
assert z.ndim == 2
except AssertionError:
raise AssertionError('z.ndim != 2')
try:
assert np.all(np.diff(z, axis=-1) > 0)
except AssertionError:
raise AssertionError('lower edge <= upper edge')
try:
assert z.shape[1] == 2
except AssertionError:
raise AssertionError('z.shape[1] != 2')

try:
assert r.ndim == 1
except AssertionError:
raise AssertionError('r.ndim != 1')
try:
assert r.shape[0] == z.shape[0]
except AssertionError:
raise AssertionError('r.shape[0] != z.shape[0]')
try:
assert np.all(r > 0)
except AssertionError:
raise AssertionError('r must be greater than 0')
assert type(var) is np.ndarray, 'type({}) != np.ndarray'.format(
varname)
assert z.ndim == 2, 'z.ndim != 2'
assert np.all(np.diff(z, axis=-1) > 0), 'lower edge <= upper edge'
assert z.shape[1] == 2, 'z.shape[1] != 2'
assert r.ndim == 1, 'r.ndim != 1'
assert r.shape[0] == z.shape[0], 'r.shape[0] != z.shape[0]'
assert np.all(r > 0), 'r must be greater than 0'

self.z = z
self.r = r
Expand Down

0 comments on commit 95c8f6c

Please sign in to comment.