Skip to content

Commit 2470e22

Browse files
chris-langfieldgarrettwrong
authored andcommitted
Basis size fixes (#598)
* fix size checking bug * fb2d * fb3d and pswf2d * dirac * fpswf * polar * format * sz -> size * test init with int * docstring * typo * the same typo
1 parent aeccb9d commit 2470e22

File tree

13 files changed

+53
-14
lines changed

13 files changed

+53
-14
lines changed

src/aspire/basis/basis.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,15 @@ def expand(self, x):
183183
if isinstance(x, Image) or isinstance(x, Volume):
184184
x = x.asnumpy()
185185

186-
# ensure the first dimensions with size of self.sz
187-
sz_roll = x.shape[: -self.ndim]
188-
189-
x = x.reshape((-1, *self.sz))
190-
186+
# check that last ndim values of input shape match
187+
# the shape of this basis
191188
assert (
192189
x.shape[-self.ndim :] == self.sz
193190
), f"Last {self.ndim} dimensions of x must match {self.sz}."
191+
# extract number of images/volumes, or () if only one
192+
sz_roll = x.shape[: -self.ndim]
193+
# convert to standardized shape e.g. (L,L) to (1,L,L)
194+
x = x.reshape((-1, *self.sz))
194195

195196
operator = LinearOperator(
196197
shape=(self.count, self.count),

src/aspire/basis/dirac.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,21 @@ class DiracBasis(Basis):
1414
Define a derived class for Dirac basis
1515
"""
1616

17-
def __init__(self, sz, mask=None, dtype=np.float32):
17+
def __init__(self, size, mask=None, dtype=np.float32):
1818
"""
1919
Initialize an object for Dirac basis
20-
:param sz: The shape of the vectors for which to define the basis.
20+
:param size: The shape of the vectors for which to define the basis.
21+
May be a 2-tuple or an integer, in which case, a square basis is assumed.
2122
:param mask: A boolean _mask of size sz indicating which coordinates
2223
to include in the basis (default np.full(sz, True)).
2324
"""
25+
if isinstance(size, int):
26+
size = (size, size)
2427
if mask is None:
25-
mask = np.full(sz, True)
28+
mask = np.full(size, True)
2629
self._mask = m_flatten(mask)
2730

28-
super().__init__(sz, dtype=dtype)
31+
super().__init__(size, dtype=dtype)
2932

3033
def _build(self):
3134
"""

src/aspire/basis/fb_2d.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
2929
Initialize an object for the 2D Fourier-Bessel basis class
3030
3131
:param size: The size of the vectors for which to define the basis.
32+
May be a 2-tuple or an integer, in which case a square basis is assumed.
3233
Currently only square images are supported.
3334
:ell_max: The maximum order ell of the basis elements. If no input
3435
(= None), it will be set to np.Inf and the basis includes all
3536
ell such that the resulting basis vectors are concentrated
3637
below the Nyquist frequency (default Inf).
3738
"""
3839

40+
if isinstance(size, int):
41+
size = (size, size)
3942
ndim = len(size)
4043
assert ndim == 2, "Only two-dimensional basis functions are supported."
4144
assert len(set(size)) == 1, "Only square domains are supported."

src/aspire/basis/fb_3d.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,15 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
2323
Initialize an object for the 3D Fourier-Bessel basis class
2424
2525
:param size: The size of the vectors for which to define the basis.
26+
May be a 3-tuple or an integer, in which case a cubic basis is assumed.
2627
Currently only cubic images are supported.
2728
:ell_max: The maximum order ell of the basis elements. If no input
2829
(= None), it will be set to np.Inf and the basis includes all
2930
ell such that the resulting basis vectors are concentrated
3031
below the Nyquist frequency (default Inf).
3132
"""
33+
if isinstance(size, int):
34+
size = (size, size, size)
3235
ndim = len(size)
3336
assert ndim == 3, "Only three-dimensional basis functions are supported."
3437
assert len(set(size)) == 1, "Only cubic domains are supported."

src/aspire/basis/fpswf_2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def __init__(self, size, gamma_truncation=1.0, beta=1.0, dtype=np.float32):
3434
Initialize an object for 2D prolate spheroidal wave function (PSWF) basis expansion using fast method.
3535
3636
:param size: The size of the vectors for which to define the basis
37-
and the image resultion. Currently only square images are supported.
37+
and the image resolution. May be a 2-tuple or an integer, in which case
38+
a square basis is assumed. Currently only square images are supported.
3839
:param gamma_trunc: Truncation parameter of PSWFs, between 0 and 1e6,
3940
which controls the length of the expansion and the approximation error.
4041
Smaller values (close to zero) guarantee smaller errors, yet longer

src/aspire/basis/polar_2d.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32):
2020
Initialize an object for the 2D polar Fourier grid class
2121
2222
:param size: The shape of the vectors for which to define the grid.
23+
May be a 2-tuple or an integer, in which case a square basis is assumed.
2324
Currently only square images are supported.
24-
:param nrad: The number of points in the radial dimension. Default is resoltuion // 2.
25+
:param nrad: The number of points in the radial dimension. Default is resolution // 2.
2526
:param ntheta: The number of points in the angular dimension. Default is 8 * nrad.
2627
"""
27-
28+
if isinstance(size, int):
29+
size = (size, size)
2830
ndim = len(size)
2931
assert ndim == 2, "Only two-dimensional grids are supported."
3032
assert len(set(size)) == 1, "Only square domains are supported."

src/aspire/basis/pswf_2d.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def __init__(self, size, gamma_trunc=1.0, beta=1.0, dtype=np.float32):
3939
Initialize an object for 2D PSWF basis expansion using direct method
4040
4141
:param size: The size of the vectors for which to define the basis
42-
and the image resultion. Currently only square images are supported.
42+
and the image resolution. May be a 2-tuple or an integer, in which case
43+
a square basis is assumed. Currently only square images are supported.
4344
:param gamma_trunc: Truncation parameter of PSWFs, between 0 and 1e6,
4445
which controls the length of the expansion and the approximation error.
4546
Smaller values (close to zero) guarantee smaller errors, yet longer
@@ -51,7 +52,8 @@ def __init__(self, size, gamma_trunc=1.0, beta=1.0, dtype=np.float32):
5152
parameter controls the bandlimit of the PSWFs.
5253
:param dtype: Internal ndarray datatype.
5354
"""
54-
55+
if isinstance(size, int):
56+
size = (size, size)
5557
self.rcut = size[0] // 2
5658
self.gmcut = gamma_trunc
5759
self.beta = beta

tests/test_Diracbasis.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,7 @@ def testDiracEvaluate_t(self):
189189
)
190190
result = self.basis.evaluate_t(x)
191191
self.assertTrue(np.allclose(result, m_flatten(x)))
192+
193+
def testInitWithIntSize(self):
194+
# make sure we can instantiate with just an int as a shortcut
195+
self.assertEqual((8, 8), DiracBasis(8).sz)

tests/test_FBbasis2D.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,7 @@ def testComplexCoversionErrorsToReal(self):
389389

390390
# Try a 0d vector, should not crash.
391391
_ = self.basis.to_real(cv1.reshape(-1))
392+
393+
def testInitWithIntSize(self):
394+
# make sure we can instantiate with just an int as a shortcut
395+
self.assertEqual((8, 8), FBBasis2D(8).sz)

tests/test_FBbasis3D.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,3 +697,7 @@ def testFBBasis3DExpand(self):
697697
atol=utest_tolerance(self.dtype),
698698
)
699699
)
700+
701+
def testInitWithIntSize(self):
702+
# make sure we can instantiate with just an int as a shortcut
703+
self.assertEqual((8, 8, 8), FBBasis3D(8).sz)

0 commit comments

Comments
 (0)