Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6ccb435
patch project/back to be adjoints
garrettwrong Apr 3, 2024
dadc94d
another pts_rot reversal
garrettwrong Apr 30, 2024
7414c7c
simplify mean est and back tests
garrettwrong Apr 30, 2024
2733d81
tmp rm hardcoded 'optimize{1,2}' mean est tests
garrettwrong Apr 30, 2024
4574422
Remove legacy volume/sim call from mean est tests
garrettwrong Apr 30, 2024
70eb82b
Remove legacy volume/sim call from wt mean est tests
garrettwrong Apr 30, 2024
b473c9e
refactor mean_est towards pytest
garrettwrong May 1, 2024
3ecac64
fix small precon init bug
garrettwrong May 2, 2024
9ec354a
cleanup unneeded asnumpy calls in mean estimate
garrettwrong May 2, 2024
c18fbe0
refactor weigthed mean est test
garrettwrong May 2, 2024
9af8f34
covar 3d pts orientation swap and m_ rm
garrettwrong May 2, 2024
3e926f3
None has no lower, rm
garrettwrong May 2, 2024
e1e01ef
pass checkpoint as x0, rename x_chk
garrettwrong May 16, 2024
4936dca
Call mean est tests with circulant
garrettwrong May 16, 2024
031fd17
Call extra mean est none-ish tests as expensive
garrettwrong May 16, 2024
4a53d29
test adjoint formula in mean_est
garrettwrong May 29, 2024
2be2cf1
add src based adjoint test
garrettwrong May 30, 2024
d72f2db
update tests with backward scaling factor
garrettwrong May 31, 2024
4691bbd
optional index cleanup commit
garrettwrong May 31, 2024
4ea38dc
revert weakening of simulation config for mean_est_tests (debugging)
garrettwrong May 31, 2024
f069f3f
add scale comment
garrettwrong May 31, 2024
62a8d80
Revert "optional index cleanup commit"
garrettwrong Jun 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/aspire/covariance/covar.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def compute_kernel(self):
weights[:, 0, :] = 0

# TODO: This is where this differs from MeanEstimator
pts_rot = np.moveaxis(pts_rot[::-1], 1, 0).reshape(-1, 3, L**2)
pts_rot = np.moveaxis(pts_rot, 1, 0).reshape(-1, 3, L**2)
weights = weights.T.reshape((-1, L**2))

batch_n = weights.shape[0]
Expand All @@ -67,7 +67,7 @@ def compute_kernel(self):
factors[j] = anufft(weights[j], pts_rot[j], (_2L, _2L, _2L), real=True)

factors = Volume(factors).to_vec()
kernel += vecmat_to_volmat(factors.T @ factors) / (n * L**8)
kernel += (factors.T @ factors).reshape((_2L,) * 6) / (n * L**8)

# Ensure symmetric kernel
kernel[0, :, :, :, :, :] = 0
Expand All @@ -90,6 +90,8 @@ def estimate(self, mean_vol, noise_variance, tol=1e-5, regularizer=0):
b_coef = self.src_backward(mean_vol, noise_variance)
est_coef = self.conj_grad(b_coef, tol=tol, regularizer=regularizer)
covar_est = self.basis.mat_evaluate(est_coef)
# Note, notice these cancel out, but requires a lot of changes elsewhere in this file,
# basically totally removing all the `utils/matrix` hacks ... todo.
covar_est = vecmat_to_volmat(make_symmat(volmat_to_vecmat(covar_est)))
return covar_est

Expand Down Expand Up @@ -180,7 +182,9 @@ def src_backward(self, mean_vol, noise_variance, shrink_method=None):
im_centered_b[j] = self.src.im_backward(im_centered[j], i + j)
im_centered_b = Volume(im_centered_b).to_vec()

covar_b += vecmat_to_volmat(im_centered_b.T @ im_centered_b) / self.src.n
covar_b += (im_centered_b.T @ im_centered_b).reshape(
(self.src.L,) * 6
) / self.src.n

covar_b_coef = self.basis.mat_evaluate_t(covar_b)
return self._shrink(covar_b_coef, noise_variance, shrink_method)
Expand Down
2 changes: 1 addition & 1 deletion src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def backproject(self, rot_matrices, symmetry_group=None):
)
pts_rot = pts_rot.reshape((3, -1))

vol += anufft(im_f, pts_rot[::-1], (L, L, L), real=True)
vol += anufft(im_f, pts_rot, (L, L, L), real=True)

vol /= L

Expand Down
13 changes: 10 additions & 3 deletions src/aspire/reconstruction/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def __init__(
self.basis = basis
self.dtype = self.src.dtype
self.batch_size = batch_size
if not preconditioner or preconditioner.lower() == "none":
# Resolve None and string nones to None
preconditioner = None
elif preconditioner not in ["circulant"]:
raise ValueError(
f"Supplied preconditioner {preconditioner} is not supported."
)
self.preconditioner = preconditioner
self.boost = boost

Expand Down Expand Up @@ -128,12 +135,12 @@ def __getattr__(self, name):
def compute_kernel(self):
raise NotImplementedError("Subclasses must implement the compute_kernel method")

def estimate(self, b_coef=None, tol=1e-5, regularizer=0):
def estimate(self, b_coef=None, x0=None, tol=1e-5, regularizer=0):
"""Return an estimate as a Volume instance."""
if b_coef is None:
b_coef = self.src_backward()
est_coef = self.conj_grad(b_coef, tol=tol, regularizer=regularizer)
est = Coef(self.basis, est_coef).evaluate().T
est_coef = self.conj_grad(b_coef, x0=x0, tol=tol, regularizer=regularizer)
est = Coef(self.basis, est_coef).evaluate()

return est

Expand Down
11 changes: 7 additions & 4 deletions src/aspire/reconstruction/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def __getattr__(self, name):
1.0 / self.kernel.circularize()
)
else:
if self.preconditioner.lower() not in (None, "none"):
if self.preconditioner and (
self.preconditioner.lower() not in ("none")
):
logger.warning(
f"Preconditioner {self.preconditioner} is not implemented, resetting to default of None."
)
Expand Down Expand Up @@ -126,7 +128,7 @@ def _compute_kernel(self):
batch_kernel += (
1
/ (self.r * self.src.L**4)
* anufft(weights, pts_rot[::-1], (_2L, _2L, _2L), real=True)
* anufft(weights, pts_rot, (_2L, _2L, _2L), real=True)
)

kernel[k, j] += batch_kernel
Expand Down Expand Up @@ -189,7 +191,7 @@ def src_backward(self):

return res

def conj_grad(self, b_coef, tol=1e-5, regularizer=0):
def conj_grad(self, b_coef, x0=None, tol=1e-5, regularizer=0):
count = b_coef.shape[-1] # b_coef should be (r, basis.count)
kernel = self.kernel

Expand Down Expand Up @@ -240,12 +242,13 @@ def cb(xk):
# Construct checkpoint path
path = f"{self.checkpoint_prefix}_iter{self.i:04d}.npy"
# Write out the current solution
np.save(path, xk.reshape(self.r, self.basis.count))
np.save(path, xk)
logger.info(f"Checkpoint saved to `{path}`")

x, info = cg(
operator,
b_coef.flatten(),
x0=x0,
M=M,
callback=cb,
tol=tol,
Expand Down
Loading