Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/aspire/basis/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,15 @@ def expand(self, x):
if isinstance(x, Image) or isinstance(x, Volume):
x = x.asnumpy()

# ensure the first dimensions with size of self.sz
sz_roll = x.shape[: -self.ndim]

x = x.reshape((-1, *self.sz))

# check that last ndim values of input shape match
# the shape of this basis
assert (
x.shape[-self.ndim :] == self.sz
), f"Last {self.ndim} dimensions of x must match {self.sz}."
# extract number of images/volumes, or () if only one
sz_roll = x.shape[: -self.ndim]
# convert to standardized shape e.g. (L,L) to (1,L,L)
x = x.reshape((-1, *self.sz))

operator = LinearOperator(
shape=(self.count, self.count),
Expand Down
11 changes: 7 additions & 4 deletions src/aspire/basis/dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@ class DiracBasis(Basis):
Define a derived class for Dirac basis
"""

def __init__(self, sz, mask=None, dtype=np.float32):
def __init__(self, size, mask=None, dtype=np.float32):
"""
Initialize an object for Dirac basis
:param sz: The shape of the vectors for which to define the basis.
:param size: The shape of the vectors for which to define the basis.
May be a 2-tuple or an integer, in which case, a square basis is assumed.
:param mask: A boolean _mask of size sz indicating which coordinates
to include in the basis (default np.full(sz, True)).
"""
if isinstance(size, int):
size = (size, size)
if mask is None:
mask = np.full(sz, True)
mask = np.full(size, True)
self._mask = m_flatten(mask)

super().__init__(sz, dtype=dtype)
super().__init__(size, dtype=dtype)

def _build(self):
"""
Expand Down
3 changes: 3 additions & 0 deletions src/aspire/basis/fb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
Initialize an object for the 2D Fourier-Bessel basis class

:param size: The size of the vectors for which to define the basis.
May be a 2-tuple or an integer, in which case a square basis is assumed.
Currently only square images are supported.
:ell_max: The maximum order ell of the basis elements. If no input
(= None), it will be set to np.Inf and the basis includes all
ell such that the resulting basis vectors are concentrated
below the Nyquist frequency (default Inf).
"""

if isinstance(size, int):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe to cover these in testing we should initialize with an integer size in each basis test.

size = (size, size)
ndim = len(size)
assert ndim == 2, "Only two-dimensional basis functions are supported."
assert len(set(size)) == 1, "Only square domains are supported."
Expand Down
3 changes: 3 additions & 0 deletions src/aspire/basis/fb_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
Initialize an object for the 3D Fourier-Bessel basis class

:param size: The size of the vectors for which to define the basis.
May be a 3-tuple or an integer, in which case a cubic basis is assumed.
Currently only cubic images are supported.
:ell_max: The maximum order ell of the basis elements. If no input
(= None), it will be set to np.Inf and the basis includes all
ell such that the resulting basis vectors are concentrated
below the Nyquist frequency (default Inf).
"""
if isinstance(size, int):
size = (size, size, size)
ndim = len(size)
assert ndim == 3, "Only three-dimensional basis functions are supported."
assert len(set(size)) == 1, "Only cubic domains are supported."
Expand Down
3 changes: 2 additions & 1 deletion src/aspire/basis/fpswf_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self, size, gamma_truncation=1.0, beta=1.0, dtype=np.float32):
Initialize an object for 2D prolate spheroidal wave function (PSWF) basis expansion using fast method.

:param size: The size of the vectors for which to define the basis
and the image resultion. Currently only square images are supported.
and the image resolution. May be a 2-tuple or an integer, in which case
a square basis is assumed. Currently only square images are supported.
:param gamma_trunc: Truncation parameter of PSWFs, between 0 and 1e6,
which controls the length of the expansion and the approximation error.
Smaller values (close to zero) guarantee smaller errors, yet longer
Expand Down
6 changes: 4 additions & 2 deletions src/aspire/basis/polar_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32):
Initialize an object for the 2D polar Fourier grid class

:param size: The shape of the vectors for which to define the grid.
May be a 2-tuple or an integer, in which case a square basis is assumed.
Currently only square images are supported.
:param nrad: The number of points in the radial dimension. Default is resoltuion // 2.
:param nrad: The number of points in the radial dimension. Default is resolution // 2.
:param ntheta: The number of points in the angular dimension. Default is 8 * nrad.
"""

if isinstance(size, int):
size = (size, size)
ndim = len(size)
assert ndim == 2, "Only two-dimensional grids are supported."
assert len(set(size)) == 1, "Only square domains are supported."
Expand Down
6 changes: 4 additions & 2 deletions src/aspire/basis/pswf_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(self, size, gamma_trunc=1.0, beta=1.0, dtype=np.float32):
Initialize an object for 2D PSWF basis expansion using direct method

:param size: The size of the vectors for which to define the basis
and the image resultion. Currently only square images are supported.
and the image resolution. May be a 2-tuple or an integer, in which case
a square basis is assumed. Currently only square images are supported.
:param gamma_trunc: Truncation parameter of PSWFs, between 0 and 1e6,
which controls the length of the expansion and the approximation error.
Smaller values (close to zero) guarantee smaller errors, yet longer
Expand All @@ -51,7 +52,8 @@ def __init__(self, size, gamma_trunc=1.0, beta=1.0, dtype=np.float32):
parameter controls the bandlimit of the PSWFs.
:param dtype: Internal ndarray datatype.
"""

if isinstance(size, int):
size = (size, size)
self.rcut = size[0] // 2
self.gmcut = gamma_trunc
self.beta = beta
Expand Down
4 changes: 4 additions & 0 deletions tests/test_Diracbasis.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,7 @@ def testDiracEvaluate_t(self):
)
result = self.basis.evaluate_t(x)
self.assertTrue(np.allclose(result, m_flatten(x)))

def testInitWithIntSize(self):
# make sure we can instantiate with just an int as a shortcut
self.assertEqual((8, 8), DiracBasis(8).sz)
4 changes: 4 additions & 0 deletions tests/test_FBbasis2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,7 @@ def testComplexCoversionErrorsToReal(self):

# Try a 0d vector, should not crash.
_ = self.basis.to_real(cv1.reshape(-1))

def testInitWithIntSize(self):
# make sure we can instantiate with just an int as a shortcut
self.assertEqual((8, 8), FBBasis2D(8).sz)
4 changes: 4 additions & 0 deletions tests/test_FBbasis3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,7 @@ def testFBBasis3DExpand(self):
atol=utest_tolerance(self.dtype),
)
)

def testInitWithIntSize(self):
# make sure we can instantiate with just an int as a shortcut
self.assertEqual((8, 8, 8), FBBasis3D(8).sz)
4 changes: 4 additions & 0 deletions tests/test_FPSWFbasis2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ def testFPSWFBasis2DEvaluate(self):
result = self.basis.evaluate(coeffs)
images = np.load(os.path.join(DATA_DIR, "pswf2d_xcoeff_out_8_8.npy")).T # RCOPT
self.assertTrue(np.allclose(result.asnumpy(), images))

def testInitWithIntSize(self):
# make sure we can instantiate with just an int as a shortcut
self.assertEqual((8, 8), FPSWFBasis2D(8).sz)
4 changes: 4 additions & 0 deletions tests/test_PSWFbasis2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ def testPSWFBasis2DEvaluate(self):
result = self.basis.evaluate(coeffs)
images = np.load(os.path.join(DATA_DIR, "pswf2d_xcoeff_out_8_8.npy")).T # RCOPT
self.assertTrue(np.allclose(result.asnumpy(), images))

def testInitWithIntSize(self):
# make sure we can instantiate with just an int as a shortcut
self.assertEqual((8, 8), PSWFBasis2D(8).sz)
4 changes: 4 additions & 0 deletions tests/test_PolarBasis2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,7 @@ def testPolarBasis2DAdjoint(self):
)

self.assertTrue(np.isclose(lhs, rhs, atol=utest_tolerance(self.dtype)))

def testInitWithIntSize(self):
# make sure we can instantiate with just an int as a shortcut
self.assertEqual((8, 8), PolarBasis2D(8).sz)