Skip to content
64 changes: 63 additions & 1 deletion src/aspire/reconstruction/estimator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import os
from pathlib import Path

from aspire.basis import Coef
from aspire.reconstruction.kernel import FourierKernel
Expand All @@ -7,14 +9,41 @@


class Estimator:
def __init__(self, src, basis, batch_size=512, preconditioner="circulant"):
def __init__(
self,
src,
basis,
batch_size=512,
preconditioner="circulant",
checkpoint_iterations=10,
checkpoint_prefix="volume_checkpoint",
maxiter=100,
):
"""
An object representing a 2*L-by-2*L-by-2*L array containing the non-centered Fourier transform of the mean
least-squares estimator kernel.
Convolving a volume with this kernel is equal to projecting and backproject-ing that volume in each of the
projection directions (with the appropriate amplitude multipliers and CTFs) and averaging over the whole
dataset.
Note that this is a non-centered Fourier transform, so the zero frequency is found at index 1.

:param src: `ImageSource` to be used for estimation.
:param basis: 3D Basis to be used during estimation.
:param batch_size: Optional batch size of images drawn from
`src` during back projection and kernel estimation steps.
:param preconditioner: Optional kernel preconditioner (`string`).
Currently supported options are "circulant" or None.
:param checkpoint_iterations: Optionally save `cg` estimated `Volume`
instance periodically each `checkpoint_iterations`.
Setting to None disables, otherwise checks for positive integer.
:param checkpoint_prefix: Optional path prefix for `cg`
checkpoint files. If the parent directory does not exist,
creation is attempted. `_iter{N}` will be appended to the
prefix.
:param maxiter: Optional max number of `cg` iterations
before returning. This should be used in conjunction with
`checkpoint_iterations` to prevent excessive disk usage.
`None` disables.
"""

self.src = src
Expand All @@ -23,6 +52,7 @@ def __init__(self, src, basis, batch_size=512, preconditioner="circulant"):
self.batch_size = batch_size
self.preconditioner = preconditioner

# dtype configuration
if not self.dtype == self.basis.dtype:
logger.warning(
f"Inconsistent types in {self.dtype} Estimator."
Expand All @@ -35,6 +65,38 @@ def __init__(self, src, basis, batch_size=512, preconditioner="circulant"):
f" Given src.L={src.L} != {basis.nres}"
)

# Checkpoint configuration
if checkpoint_iterations is not None:
try:
checkpoint_iterations = int(checkpoint_iterations)
except ValueError:
# Sentinel value to emit a more descriptive message below.
checkpoint_iterations = -1

if not checkpoint_iterations > 0:
raise ValueError(
"`checkpoint_iterations` should be a positive integer or `None`."
)
self.checkpoint_iterations = checkpoint_iterations

# Create checkpointing dirs as needed
if checkpoint_prefix:
parent = Path(checkpoint_prefix).parent
if not os.path.exists(parent):
os.makedirs(parent)
self.checkpoint_prefix = checkpoint_prefix

# Maximum iteration configuration
if maxiter is not None:
try:
maxiter = int(maxiter)
except ValueError:
# Sentinel value to emit a more descriptive message below.
maxiter = -1
if not maxiter > 0:
raise ValueError("`maxiter` should be a positive integer or `None`.")
self.maxiter = maxiter

def __getattr__(self, name):
"""Lazy attributes instantiated on first-access"""

Expand Down
34 changes: 32 additions & 2 deletions src/aspire/reconstruction/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,42 @@ def conj_grad(self, b_coef, tol=1e-5, regularizer=0):
tol = tol or config.mean.cg_tol
target_residual = tol * norm(b_coef)

# callback setup
self.i = 0 # iteration counter

def cb(xk):
self.i += 1 # increment iteration count

logger.info(
f"Delta {norm(b_coef - self.apply_kernel(xk))} (target {target_residual})"
f"[Iter {self.i}]: Delta {norm(b_coef - self.apply_kernel(xk))} (target {target_residual})"
)

x, info = cg(operator, b_coef.flatten(), M=M, callback=cb, tol=tol, atol=0)
# Do checkpoint at `checkpoint_iterations`,
_do_checkpoint = (
self.checkpoint_iterations
and (self.i % self.checkpoint_iterations) == 0
)
# or the last iteration when `maxiter` provided.
if self.maxiter:
_do_checkpoint |= self.i == (self.maxiter - 1)

# Optional checkpoint
if _do_checkpoint:
# Construct checkpoint path
path = f"{self.checkpoint_prefix}_iter{self.i:04d}.npy"
# Write out the current solution
np.save(path, xk.reshape(self.r, self.basis.count))
logger.info(f"Checkpoint saved to `{path}`")

x, info = cg(
operator,
b_coef.flatten(),
M=M,
callback=cb,
tol=tol,
atol=0,
maxiter=self.maxiter,
)

if info != 0:
raise RuntimeError("Unable to converge!")
Expand Down
59 changes: 56 additions & 3 deletions tests/test_mean_estimator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os.path
import tempfile
from unittest import TestCase

import numpy as np
Expand Down Expand Up @@ -26,12 +27,12 @@ def setUp(self):
],
dtype=self.dtype,
)
basis = FBBasis3D((self.resolution,) * 3, dtype=self.dtype)
self.basis = FBBasis3D((self.resolution,) * 3, dtype=self.dtype)

self.estimator = MeanEstimator(self.sim, basis, preconditioner="none")
self.estimator = MeanEstimator(self.sim, self.basis, preconditioner="none")

self.estimator_with_preconditioner = MeanEstimator(
self.sim, basis, preconditioner="circulant"
self.sim, self.basis, preconditioner="circulant"
)

def tearDown(self):
Expand Down Expand Up @@ -675,3 +676,55 @@ def testOptimize2(self):
atol=1e-4,
)
)

def testCheckpoint(self):
"""Exercise the checkpointing and max iterations branches."""
test_iter = 2
with tempfile.TemporaryDirectory() as tmp_input_dir:
prefix = os.path.join(tmp_input_dir, "new", "dirs", "chk")
estimator = MeanEstimator(
self.sim,
self.basis,
preconditioner="none",
checkpoint_iterations=test_iter,
maxiter=test_iter + 1,
checkpoint_prefix=prefix,
)

# Assert we raise when reading `maxiter`.
with raises(RuntimeError, match="Unable to converge!"):
_ = estimator.estimate()

# Load the checkpoint coefficients while tmp_input_dir exists.
b_chk = np.load(f"{prefix}_iter{test_iter:04d}.npy")

# Restart estimate from checkpoint
_ = self.estimator.estimate(b_coef=b_chk)

def testCheckpointArgs(self):
with tempfile.TemporaryDirectory() as tmp_input_dir:
prefix = os.path.join(tmp_input_dir, "chk")

for junk in [-1, 0, "abc"]:
# Junk `checkpoint_iterations` values
with raises(
ValueError, match=r".*iterations.*should be a positive integer.*"
):
_ = MeanEstimator(
self.sim,
self.basis,
preconditioner="none",
checkpoint_iterations=junk,
checkpoint_prefix=prefix,
)
# Junk `maxiter` values
with raises(
ValueError, match=r".*maxiter.*should be a positive integer.*"
):
_ = MeanEstimator(
self.sim,
self.basis,
preconditioner="none",
maxiter=junk,
checkpoint_prefix=prefix,
)