Skip to content

Commit

Permalink
Merge pull request #202 from carterbox/finite-support
Browse files Browse the repository at this point in the history
NEW: Implement finite probe support constraint
  • Loading branch information
carterbox committed May 3, 2022
2 parents 1eef247 + 9a8a636 commit 40c47b0
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 14 deletions.
63 changes: 61 additions & 2 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ class ProbeOptions:
m: np.array = dataclasses.field(init=False, default_factory=lambda: None)
"""The first moment for adaptive moment."""

probe_support: float = 10.0
"""Weight of the finite probe support constraint; zero or greater."""

probe_support_radius: float = 0.5 * 0.6
"""Radius of finite probe support as fraction of probe grid. [0.0, 0.5]."""

probe_support_degree: float = 5
"""Degree of the supergaussian defining the probe support."""

def copy_to_device(self):
"""Copy to the current GPU memory."""
if self.v is not None:
Expand Down Expand Up @@ -514,7 +523,10 @@ def constrain_center_peak(probe):
maximum intensity is centered.
"""
half = probe.shape[-2] // 2, probe.shape[-1] // 2
logger.info("Constrained probe intensity to center with sigma=%f", half[0])
logger.info(
"Constrained probe intensity to center with sigma=%.3e",
half[0],
)
# First reshape the probe to 3D so it is a single stack of 2D images.
stack = probe.reshape((-1, *probe.shape[-2:]))
intensity = cupyx.scipy.ndimage.gaussian_filter(
Expand All @@ -538,7 +550,7 @@ def constrain_probe_sparsity(probe, f):
"""Constrain the probe intensity so no more than f/1 elements are nonzero."""
if f == 1:
return probe
logger.info("Constrained probe intensity spasity to %f", f)
logger.info("Constrained probe intensity spasity to %.3e", f)
# First reshape the probe to 3D so it is a single stack of 2D images.
stack = probe.reshape((-1, *probe.shape[-2:]))
intensity = np.sum(np.square(np.abs(stack)), axis=0)
Expand All @@ -557,6 +569,45 @@ def constrain_probe_sparsity(probe, f):
return probe


def finite_probe_support(probe, *, radius=0.5, degree=5, p=1.0):
"""Returns a supergaussian penalty function for finite probe support.
A mask which provides an illumination penalty is determined by the equation:
penalty = p - p * exp( -( (x / radius)**2 + (y / radius)**2 )**degree)
where the maximum penalty is p and the minium penalty is 0. This penalty
function is used in the probe gradient to supress values in the probe grid
far from the center. The penalty is 0 near the center and p at the edge.
Parameters
----------
radius : float (0, 0.5]
The radius of the supergaussian.
degree : float >= 0
The exponent of the terms in the supergaussian equation. Controls how
hard the penalty transition is outside of the radius.
Degree = 0 is a flat penalty.
Degree > 0, < 1 is flatter than a gaussian.
Degree 1 is a gaussian.
Degree > 1 is more like a top-hat than a gaussian.
"""
if p <= 0:
return 0.0
logger.info(
"Probe support constraint with weight %.3e, radius %.3e, degree %.3e",
p,
radius,
degree,
)
N = probe.shape[-1]
centers = cp.linspace(-0.5, 0.5, num=N, endpoint=False) + 0.5 / N
i, j = cp.meshgrid(centers, centers)
mask = 1 - cp.exp(-(cp.square(i / radius) + cp.square(j / radius))**degree)
return p * mask.astype('float32')


if __name__ == "__main__":
cp.random.seed()
x = (cp.random.rand(7, 1, 9, 3, 3) +
Expand All @@ -569,3 +620,11 @@ def constrain_probe_sparsity(probe, f):
print(p1)
p2 = constrain_probe_sparsity(p1, 0.6)
print(p2)

import sys
np.set_printoptions(threshold=sys.maxsize, precision=2)
print(finite_probe_support(
np.zeros((24, 24)),
radius=0.5,
degree=5,
))
25 changes: 20 additions & 5 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def lstsq_grad(
num_batch=algorithm_options.num_batch,
psi_update_denominator=psi_update_denominator,
object_options=object_options,
probe_options=probe_options,
)

if position_options:
Expand Down Expand Up @@ -290,6 +291,7 @@ def _update_nearplane(
psi_update_denominator,
*,
object_options,
probe_options,
):

patches = comm.pool.map(_get_patches, nearplane, psi, scan_, op=op)
Expand Down Expand Up @@ -347,6 +349,7 @@ def _update_nearplane(
nearplane,
scan_,
unique_probe,
probe,
common_grad_psi,
common_grad_probe,
psi_update_denominator,
Expand All @@ -356,6 +359,7 @@ def _update_nearplane(
m=m,
recover_psi=recover_psi,
recover_probe=recover_probe,
probe_options=probe_options,
)))

if recover_psi:
Expand Down Expand Up @@ -590,6 +594,7 @@ def _precondition_nearplane_gradients(
nearplane,
scan_,
unique_probe,
probe,
common_grad_psi,
common_grad_probe,
psi_update_denominator,
Expand All @@ -601,6 +606,7 @@ def _precondition_nearplane_gradients(
recover_psi,
recover_probe,
alpha=0.05,
probe_options,
):

diff = nearplane[..., [m], :, :]
Expand All @@ -625,11 +631,20 @@ def _precondition_nearplane_gradients(
A1 = None

if recover_probe:
common_grad_probe /= ((1 - alpha) * probe_update_denominator +
alpha * probe_update_denominator.max(
axis=(-2, -1),
keepdims=True,
))

b = tike.ptycho.probe.finite_probe_support(
unique_probe[..., [m], :, :],
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)

common_grad_probe = (common_grad_probe - b * probe[..., [m], :, :]) / (
(1 - alpha) * probe_update_denominator +
alpha * probe_update_denominator.max(
axis=(-2, -1),
keepdims=True,
) + b)

dPO = common_grad_probe * patches
A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1))
Expand Down
25 changes: 18 additions & 7 deletions src/tike/ptycho/solvers/rpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import tike.linalg
import tike.opt
import tike.ptycho.position
import tike.ptycho.probe

from ..object import positivity_constraint, smoothness_constraint
from ..position import PositionOptions
from ..probe import orthogonalize_eig

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,7 +88,7 @@ def rpie(
bposition_options = None
else:
bposition_options = comm.pool.map(
PositionOptions.split,
tike.ptycho.position.PositionOptions.split,
position_options,
[b[n] for b in batches],
)
Expand Down Expand Up @@ -134,11 +133,12 @@ def rpie(
probe_options is not None,
position_options=bposition_options,
algorithm_options=algorithm_options,
probe_options=probe_options,
)

if position_options is not None:
comm.pool.map(
PositionOptions.join,
tike.ptycho.position.PositionOptions.join,
position_options,
bposition_options,
[b[n] for b in batches],
Expand All @@ -153,7 +153,7 @@ def rpie(
)

if probe_options and probe_options.orthogonality_constraint:
probe = comm.pool.map(orthogonalize_eig, probe)
probe = comm.pool.map(tike.ptycho.probe.orthogonalize_eig, probe)

if object_options:
psi = comm.pool.map(positivity_constraint,
Expand Down Expand Up @@ -210,6 +210,8 @@ def _update_nearplane(
step_length=1.0,
algorithm_options=None,
position_options=None,
*,
probe_options=None,
):

patches = comm.pool.map(_get_patches, nearplane_, psi, scan_, op=op)
Expand Down Expand Up @@ -268,12 +270,21 @@ def _update_nearplane(
probe_update_denominator = comm.reduce(probe_update_denominator,
'gpu')[0]

probe[0] += step_length * probe_update_numerator / (
b = tike.ptycho.probe.finite_probe_support(
probe[0],
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)

probe[0] += step_length * (
probe_update_numerator - b * probe[0]
) / (
(1 - alpha) * probe_update_denominator +
alpha * probe_update_denominator.max(
axis=(-2, -1),
keepdims=True,
))
) + b)

probe = comm.pool.bcast([probe[0]])

Expand Down
23 changes: 23 additions & 0 deletions tests/ptycho/test_probe.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import os
import unittest

import numpy as np
import cupy as cp
import tike.ptycho.probe
from tike.communicators import Comm, MPIComm

resultdir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'result',
'ptycho', 'probe')


class TestProbe(unittest.TestCase):

Expand Down Expand Up @@ -107,6 +112,24 @@ def test_init_many_varying_probe_with_multiple_basis(self):
def test_init_1_varying_probe_with_multiple_basis(self):
self.template_init_varing_probe(31, 7, 3, 16, 1)

def test_probe_support(self):
"""Finite probe support penalty function is within expected bounds."""
penalty = tike.ptycho.probe.finite_probe_support(
probe=cp.zeros((101, 101)), # must be odd shaped for min to be 0
radius=0.5 * 0.4,
degree=1.0, # must have degree >= 1 for upper bound to be p
p=2.345,
)
try:
import tifffile
os.makedirs(resultdir, exist_ok=True)
tifffile.imsave(os.path.join(resultdir, 'penalty.tiff'),
penalty.astype('float32').get())
except ImportError:
pass
assert cp.around(cp.min(penalty), 3) == 0.000
assert cp.around(cp.max(penalty), 3) == 2.345


if __name__ == '__main__':
unittest.main()

0 comments on commit 40c47b0

Please sign in to comment.